25 minute read

How to choose between batch and real-time inference, the architectural decision that shapes your entire ML serving infrastructure.

TL;DR

Batch inference precomputes predictions periodically and stores them for O(1) lookup — it’s 10-100x cheaper but sacrifices feature freshness. Real-time inference computes on-demand with sub-100ms latency but requires always-on GPU capacity. Most production systems at Netflix, Uber, and LinkedIn use a hybrid: batch for bulk predictions cached in Redis, with real-time fallback for cache misses. The deciding factors are latency requirements, feature freshness needs, and cost constraints.

(Illustration: two parallel water channels — one a wide slow-moving river and one a narrow high-pressure jet.)

How Do You Serve ML Predictions in Production?

After training a model, you need to serve its predictions. There are two fundamental approaches:

  1. Batch Inference: Precompute predictions for all users/items periodically
  2. Real-Time Inference: Compute predictions on-demand when requested

Why this matters:

  • Different latency requirements → Different architectures
  • Cost implications → Batch can be 10-100x cheaper
  • System complexity → Real-time requires more infrastructure
  • Feature freshness → Real-time uses latest data

What you’ll learn:

  • When to use batch vs real-time
  • Architecture for each approach
  • Hybrid systems combining both
  • Trade-offs and decision framework
  • Production implementation patterns

What Are the Requirements for an Inference System?

Design an ML inference system that serves predictions efficiently.

Functional Requirements

  1. Prediction Serving
    • Batch: Generate predictions for all entities periodically
    • Real-time: Serve predictions on-demand with low latency
    • Hybrid: Combine both approaches
  2. Data Freshness
    • Access to latest features
    • Handle feature staleness
    • Feature computation strategy
  3. Scalability
    • Handle millions of predictions
    • Scale horizontally
    • Handle traffic spikes

Non-Functional Requirements

  1. Latency
    • Batch: Minutes to hours acceptable
    • Real-time: < 100ms for most applications
  2. Throughput
    • Batch: Process millions of predictions in one run
    • Real-time: 1000s of requests/second
  3. Cost
    • Optimize compute resources
    • Minimize infrastructure costs
  4. Reliability
    • 99.9%+ uptime for real-time
    • Graceful degradation
    • Fallback mechanisms

How Does Batch Inference Work?

Batch inference precomputes predictions periodically (daily, hourly) and stores them for fast O(1) lookup — typically 10-100x cheaper than real-time because you batch GPU utilization and avoid idle capacity.

Architecture

┌─────────────────────────────────────────────────────────┐
│ Batch Inference Pipeline │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Data Lake │ │ Feature │ │
│ │ (HDFS/S3) │─────▶│ Engineering │ │
│ └──────────────┘ └──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Batch Job │ │
│ │ (Spark/Ray) │ │
│ │ - Load model│ │
│ │ - Predict │ │
│ └──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Write to │ │
│ │ Cache/DB │ │
│ │ (Redis/DDB) │ │
│ └──────────────┘ │
│ │ │
│ ┌──────────────┐ │ │
│ │ Application │◀───────────┘ │
│ │ Server │ Lookup predictions │
│ └──────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘

Flow:
1. Extract features from data warehouse
2. Run batch prediction job
3. Store predictions in fast lookup store
4. Application does simple lookup

Implementation

from typing import List, Dict
import numpy as np
import redis
import json
import time

class BatchInferenceSystem:
    """
    Batch inference system

    Precomputes predictions for all users/items
    """

    def __init__(self, model, redis_client):
        self.model = model
        self.redis = redis_client
        self.batch_size = 1000

    def run_batch_prediction(self, entity_ids: List[str], features_df):
        """
        Run batch prediction for all entities

        Args:
            entity_ids: List of user/item IDs
            features_df: DataFrame with features for all entities

        Returns:
            Number of predictions generated
        """
        num_predictions = 0

        # Process in batches for memory efficiency
        for i in range(0, len(entity_ids), self.batch_size):
            batch_ids = entity_ids[i:i+self.batch_size]
            batch_features = features_df.iloc[i:i+self.batch_size]

            # Predict
            predictions = self.model.predict(batch_features.values)

            # Store in Redis
            self._store_predictions(batch_ids, predictions)

            num_predictions += len(batch_ids)

            if num_predictions % 10000 == 0:
                print(f"Processed {num_predictions} predictions...")

        return num_predictions

    def _store_predictions(self, entity_ids: List[str], predictions: np.ndarray):
        """Store predictions in Redis with TTL"""
        pipeline = self.redis.pipeline()

        ttl_seconds = 24 * 3600  # 24 hours

        for entity_id, prediction in zip(entity_ids, predictions):
            # Store as JSON
            key = f"pred:{entity_id}"
            value = json.dumps({
                'prediction': float(prediction),
                'timestamp': time.time()
            })

            pipeline.setex(key, ttl_seconds, value)

        # Execute all writes in one round trip
        pipeline.execute()

    def get_prediction(self, entity_id: str) -> float:
        """
        Lookup precomputed prediction

        Fast O(1) lookup
        """
        key = f"pred:{entity_id}"
        value = self.redis.get(key)

        if value is None:
            # Prediction not found or expired
            return None

        data = json.loads(value)
        return data['prediction']

# Usage
import pandas as pd

# Initialize
redis_client = redis.Redis(host='localhost', port=6379, db=0)
model = load_trained_model()  # Your trained model

batch_system = BatchInferenceSystem(model, redis_client)

# Run batch prediction (e.g., daily cron job)
user_ids = fetch_all_user_ids()  # Get all users
features_df = fetch_user_features(user_ids)  # Get features

num_preds = batch_system.run_batch_prediction(user_ids, features_df)
print(f"Generated {num_preds} predictions")

# Later, application looks up prediction
prediction = batch_system.get_prediction("user_12345")
print(f"Prediction: {prediction}")

Spark-based Batch Inference

For large-scale batch processing:

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd

class SparkBatchInference:
    """
    Distributed batch inference using PySpark

    Scales to billions of predictions
    """

    def __init__(self, model_path):
        self.spark = SparkSession.builder \
            .appName("BatchInference") \
            .getOrCreate()

        self.model_path = model_path

    def predict_spark(self, features_df_spark):
        """
        Distribute prediction across cluster

        Args:
            features_df_spark: Spark DataFrame with features

        Returns:
            Spark DataFrame with predictions
        """
        model_path = self.model_path

        # Define pandas UDF for prediction
        @pandas_udf("double", PandasUDFType.SCALAR)
        def predict_udf(*features):
            # Load model once per executor
            import joblib
            model = joblib.load(model_path)

            # Create feature matrix
            X = pd.DataFrame({
                f'feature_{i}': features[i]
                for i in range(len(features))
            })

            # Predict
            predictions = model.predict(X.values)
            return pd.Series(predictions)

        # Apply UDF
        feature_cols = [col for col in features_df_spark.columns if col.startswith('feature_')]

        result_df = features_df_spark.withColumn(
            'prediction',
            predict_udf(*feature_cols)
        )

        return result_df

    def run_batch_job(self, input_path, output_path):
        """
        Full batch inference pipeline

        Args:
            input_path: S3/HDFS path to input data
            output_path: S3/HDFS path to save predictions
        """
        # Read input
        df = self.spark.read.parquet(input_path)

        # Predict
        predictions_df = self.predict_spark(df)

        # Write output
        predictions_df.write.parquet(output_path, mode='overwrite')

        print(f"Batch prediction complete. Output: {output_path}")

# Usage
spark_batch = SparkBatchInference(model_path='s3://models/my_model.pkl')

spark_batch.run_batch_job(
    input_path='s3://data/user_features/',
    output_path='s3://predictions/daily/2025-01-15/'
)

When Should You Use Real-Time Inference?

Real-time inference computes predictions on-demand when requested. Use it when sub-100ms latency is required, features must be fresh (current context), or predictions serve only a small active subset of entities.

Architecture

┌─────────────────────────────────────────────────────────┐
│ Real-Time Inference System │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ │
│ │ Load │◀─── Model Registry │
│ │ Balancer │ │
│ └──────┬───────┘ │
│ │ │
│ ┌────▼─────────────────────────────┐ │
│ │ Model Serving Instances │ │
│ │ ┌─────────┐ ┌─────────┐ │ │
│ │ │ Model 1 │ │ Model 2 │ ... │ │
│ │ │ (GPU) │ │ (GPU) │ │ │
│ │ └─────────┘ └─────────┘ │ │
│ └────┬─────────────────────────────┘ │
│ │ │
│ ┌────▼──────────┐ ┌──────────────┐ │
│ │ Feature │────▶│ Feature │ │
│ │ Service │ │ Store │ │
│ │ - Online │ │ (Redis) │ │
│ │ features │ └──────────────┘ │
│ └───────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘

Flow:
1. Request arrives with user/item ID
2. Fetch features from feature store
3. Compute additional online features
4. Model predicts
5. Return prediction

Implementation

from fastapi import FastAPI, HTTPException
import numpy as np
from typing import Dict
import json
import time

app = FastAPI()

class RealTimeInferenceService:
    """
    Real-time inference service

    Serves predictions with low latency
    """

    def __init__(self, model, feature_store):
        self.model = model
        self.feature_store = feature_store

        # Warm up model
        self._warmup()

    def _warmup(self):
        """Warm up model with dummy prediction"""
        dummy_features = np.random.randn(1, self.model.input_dim)
        _ = self.model.predict(dummy_features)

    def get_features(self, entity_id: str) -> Dict:
        """
        Fetch features for entity

        Combines precomputed + real-time features
        """
        # Fetch precomputed features from Redis
        precomputed_raw = self.feature_store.get(f"features:{entity_id}")
        precomputed = {}
        if precomputed_raw:
            try:
                precomputed = json.loads(precomputed_raw)
            except Exception:
                precomputed = {}

        if not precomputed:
            # Fallback: compute features on-the-fly
            precomputed = self._compute_features_fallback(entity_id)

        # Add real-time features
        realtime_features = self._compute_realtime_features(entity_id)

        # Combine
        features = {**precomputed, **realtime_features}

        return features

    def _compute_realtime_features(self, entity_id: str) -> Dict:
        """
        Compute features that must be fresh

        E.g., time of day, user's current session, etc.
        """
        import datetime

        now = datetime.datetime.now()

        return {
            'hour_of_day': now.hour,
            'day_of_week': now.weekday(),
            'is_weekend': 1 if now.weekday() >= 5 else 0
        }

    def _compute_features_fallback(self, entity_id: str) -> Dict:
        """Fallback feature computation"""
        # Query database, compute on-the-fly
        # This is slower but ensures we can always serve
        return {}

    def predict(self, entity_id: str) -> float:
        """
        Real-time prediction

        Returns:
            Prediction score
        """
        # Get features
        features = self.get_features(entity_id)

        # Convert to numpy array (assuming fixed feature order)
        feature_vector = np.array([
            features.get(f'feature_{i}', 0.0)
            for i in range(self.model.input_dim)
        ]).reshape(1, -1)

        # Predict
        prediction = self.model.predict(feature_vector)[0]

        return float(prediction)

# FastAPI endpoints
realtime_service = RealTimeInferenceService(model, redis_client)

@app.get("/predict/{entity_id}")
async def predict_endpoint(entity_id: str):
    """
    Real-time prediction endpoint

    GET /predict/user_12345
    """
    try:
        prediction = realtime_service.predict(entity_id)

        return {
            'entity_id': entity_id,
            'prediction': prediction,
            'timestamp': time.time()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail={'error': str(e), 'entity_id': entity_id})

# Run with: uvicorn app:app --host 0.0.0.0 --port 8000

TensorFlow Serving

Production-grade model serving:

from typing import List
import requests
import json

class TensorFlowServingClient:
    """
    Client for TensorFlow Serving

    High-performance model serving
    """

    def __init__(self, server_url, model_name, model_version=None):
        self.server_url = server_url
        self.model_name = model_name
        self.model_version = model_version or 'latest'

        # Endpoint
        if self.model_version == 'latest':
            self.endpoint = f"{server_url}/v1/models/{model_name}:predict"
        else:
            self.endpoint = f"{server_url}/v1/models/{model_name}/versions/{model_version}:predict"

    def predict(self, instances: List[List[float]]) -> List[float]:
        """
        Send prediction request to TF Serving

        Args:
            instances: List of feature vectors

        Returns:
            List of predictions
        """
        # Prepare request
        payload = {
            "signature_name": "serving_default",
            "instances": instances
        }

        # Send request
        response = requests.post(
            self.endpoint,
            data=json.dumps(payload),
            headers={'Content-Type': 'application/json'}
        )

        if response.status_code != 200:
            raise Exception(f"Prediction failed: {response.text}")

        # Parse response
        result = response.json()
        predictions = result['predictions']

        return predictions

# Usage
tf_client = TensorFlowServingClient(
    server_url='http://localhost:8501',
    model_name='recommendation_model',
    model_version='3'
)

# Predict
features = [[0.1, 0.5, 0.3, 0.9]]
predictions = tf_client.predict(features)
print(f"Prediction: {predictions[0]}")

How Do Hybrid Systems Combine Both Approaches?

Hybrid systems combine batch and real-time for optimal cost-performance. This is what most production systems at scale actually use.

Architecture

┌────────────────────────────────────────────────────────┐
│ Hybrid Inference System │
├────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────┐ ┌────────────────┐ │
│ │ Batch Pipeline │ │ Real-Time API │ │
│ │ (Daily) │ │ │ │
│ └───────┬────────┘ └───────┬────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────────────────────────────────┐ │
│ │ Prediction Cache (Redis) │ │
│ │ ┌────────────┐ ┌────────────┐ │ │
│ │ │ Batch │ │ Real-time │ │ │
│ │ │ Predictions│ │ Predictions│ │ │
│ │ │ (TTL: 24h) │ │ (TTL: 1h) │ │ │
│ │ └────────────┘ └────────────┘ │ │
│ └──────────────────────────────────────────┘ │
│ ▲ │
│ │ │
│ ┌──────┴────────┐ │
│ │ Application │ │
│ │ 1. Check cache│ │
│ │ 2. Fallback to│ │
│ │ real-time │ │
│ └───────────────┘ │
│ │
└────────────────────────────────────────────────────────┘

Implementation

class HybridInferenceSystem:
    """
    Hybrid system: batch + real-time

    - Fast path: Use batch predictions if available
    - Slow path: Compute real-time if needed
    """

    def __init__(self, batch_system, realtime_system):
        self.batch = batch_system
        self.realtime = realtime_system
        self.cache_hit_counter = 0
        self.cache_miss_counter = 0

    def predict(self, entity_id: str, max_staleness_hours: int = 24) -> Dict:
        """
        Get prediction with automatic fallback

        Args:
            entity_id: Entity to predict for
            max_staleness_hours: Maximum age of batch prediction

        Returns:
            {
                'prediction': float,
                'source': 'batch' | 'realtime',
                'timestamp': float
            }
        """
        # Try batch prediction first
        batch_pred_value = self.batch.get_prediction(entity_id)

        if batch_pred_value is not None:
            # Batch system returns only a value; staleness is bounded by the Redis TTL
            self.cache_hit_counter += 1
            return {
                'prediction': batch_pred_value,
                'source': 'batch',
                'timestamp': time.time(),
                'cache_hit': True
            }

        # Fallback to real-time
        self.cache_miss_counter += 1

        realtime_pred = self.realtime.predict(entity_id)

        return {
            'prediction': realtime_pred,
            'source': 'realtime',
            'timestamp': time.time(),
            'cache_hit': False
        }

    def get_cache_hit_rate(self) -> float:
        """Calculate cache hit rate"""
        total = self.cache_hit_counter + self.cache_miss_counter
        if total == 0:
            return 0.0
        return self.cache_hit_counter / total

# Usage
hybrid = HybridInferenceSystem(batch_system, realtime_service)

# Predict for user
result = hybrid.predict('user_12345', max_staleness_hours=12)

print(f"Prediction: {result['prediction']}")
print(f"Source: {result['source']}")
print(f"Cache hit rate: {hybrid.get_cache_hit_rate():.2%}")

How Do You Decide Between Batch, Real-Time, or Hybrid?

Use Batch Inference When:

Latency is not critical (recommendations, email campaigns) ✅ Predictions needed for all entities (e.g., all users) ✅ Features are expensive to compute ✅ Model is large/slow ✅ Cost optimization is a priority ✅ Predictions don’t change frequently

Examples:

  • Daily email recommendations
  • Product catalog rankings
  • Weekly personalized content
  • Batch fraud scoring

Use Real-Time Inference When:

Low latency required (< 100ms) ✅ Fresh features critical (current context) ✅ Predictions for small subset (active users) ✅ Immediate user feedback (search, ads) ✅ High-value decisions (fraud detection)

Examples:

  • Search ranking
  • Ad serving
  • Real-time fraud detection
  • Live recommendation widgets

Use Hybrid When:

Mix of latency requirements ✅ Want cost + performance ✅ Can tolerate some staleness ✅ Variable traffic patterns ✅ Graceful degradation needed (see the decision sketch after the examples below)

Examples:

  • Homepage recommendations (batch) + search (real-time)
  • Social feed (batch) + stories (real-time)
  • Product pages (batch) + checkout (real-time)
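
To condense these checklists into something executable, here is a minimal decision-helper sketch. The thresholds and the function name choose_serving_strategy are illustrative assumptions, not a standard API or the exact rules above.

def choose_serving_strategy(
    max_latency_ms: float,
    needs_fresh_features: bool,
    fraction_of_entities_served: float,
    cost_sensitive: bool,
) -> str:
    """
    Illustrative decision helper for batch vs real-time vs hybrid.

    Thresholds are rough rules of thumb, not hard limits.
    """
    # Hard latency budget with fresh context -> real-time is unavoidable
    if max_latency_ms < 100 and needs_fresh_features:
        return "realtime"

    # Predictions needed for (almost) everyone and latency is relaxed -> batch
    if fraction_of_entities_served > 0.5 and max_latency_ms >= 1000:
        return "batch"

    # Cost-sensitive with a tight-ish latency budget -> batch cache + real-time fallback
    if cost_sensitive and max_latency_ms < 1000:
        return "hybrid"

    return "hybrid"  # safe default for mixed requirements

# Example: homepage recommendations with a 200ms budget, mostly cacheable
print(choose_serving_strategy(200, False, 0.9, True))  # -> 'hybrid'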

How Much Does Each Approach Cost?

class CostAnalyzer:
    """
    Estimate costs for batch vs real-time
    """

    def estimate_batch_cost(
        self,
        num_entities: int,
        predictions_per_day: int,
        cost_per_compute_hour: float = 3.0
    ) -> Dict:
        """Estimate daily batch inference cost"""

        # Assume 10K predictions/second throughput
        throughput = 10_000

        # Total predictions
        total_preds = num_entities * predictions_per_day

        # Compute time needed
        compute_seconds = total_preds / throughput
        compute_hours = compute_seconds / 3600

        # Cost
        compute_cost = compute_hours * cost_per_compute_hour

        # Storage cost (Redis/DDB)
        storage_gb = total_preds * 100 / 1e9  # 100 bytes per prediction
        storage_cost = storage_gb * 0.25  # $0.25/GB/month

        total_cost = compute_cost + storage_cost

        return {
            'compute_hours': compute_hours,
            'compute_cost': compute_cost,
            'storage_cost': storage_cost,
            'total_daily_cost': total_cost,
            'cost_per_prediction': total_cost / total_preds
        }

    def estimate_realtime_cost(
        self,
        requests_per_second: int,
        cost_per_instance_hour: float = 5.0,
        requests_per_instance: int = 100
    ) -> Dict:
        """Estimate real-time serving cost"""

        # Number of instances needed
        num_instances = requests_per_second / requests_per_instance
        num_instances = int(np.ceil(num_instances * 1.5))  # 50% headroom

        # Daily cost
        daily_hours = 24
        daily_cost = num_instances * cost_per_instance_hour * daily_hours

        # Predictions per day
        daily_requests = requests_per_second * 86400

        return {
            'num_instances': num_instances,
            'daily_cost': daily_cost,
            'cost_per_prediction': daily_cost / daily_requests
        }

# Compare costs
analyzer = CostAnalyzer()

# Batch: 1M users, predict once/day
batch_cost = analyzer.estimate_batch_cost(
    num_entities=1_000_000,
    predictions_per_day=1
)

print("Batch Inference:")
print(f" Daily cost: ${batch_cost['total_daily_cost']:.2f}")
print(f" Cost per prediction: ${batch_cost['cost_per_prediction']:.6f}")

# Real-time: 100 QPS average
realtime_cost = analyzer.estimate_realtime_cost(
    requests_per_second=100
)

print("\nReal-Time Inference:")
print(f" Daily cost: ${realtime_cost['daily_cost']:.2f}")
print(f" Cost per prediction: ${realtime_cost['cost_per_prediction']:.6f}")

# Compare
savings = (realtime_cost['daily_cost'] - batch_cost['total_daily_cost']) / realtime_cost['daily_cost'] * 100
print(f"\nBatch is {savings:.1f}% cheaper!")

What Are the Advanced Patterns for Production Inference?

How Does Multi-Tier Caching Optimize Latency?

Layer multiple caches for optimal performance.

class MultiTierInferenceSystem:
    """
    Multi-tier caching: Memory → Redis → Compute

    Optimizes for different latency/cost profiles
    """

    def __init__(self, model, redis_client):
        self.model = model
        self.redis = redis_client

        # In-memory cache (fastest)
        self.memory_cache = {}
        self.memory_cache_size = 10000

        # Statistics
        self.stats = {
            'memory_hits': 0,
            'redis_hits': 0,
            'compute': 0,
            'total_requests': 0
        }

    def predict(self, entity_id: str) -> float:
        """
        Predict with multi-tier caching

        Tier 1: In-memory cache (~1ms)
        Tier 2: Redis cache (~5ms)
        Tier 3: Compute prediction (~50ms)
        """
        self.stats['total_requests'] += 1

        # Tier 1: Memory cache
        if entity_id in self.memory_cache:
            self.stats['memory_hits'] += 1
            return self.memory_cache[entity_id]

        # Tier 2: Redis cache
        redis_key = f"pred:{entity_id}"
        cached = self.redis.get(redis_key)

        if cached is not None:
            self.stats['redis_hits'] += 1
            prediction = float(cached)

            # Promote to memory cache
            self._add_to_memory_cache(entity_id, prediction)

            return prediction

        # Tier 3: Compute
        self.stats['compute'] += 1
        prediction = self._compute_prediction(entity_id)

        # Write to both caches
        self.redis.setex(redis_key, 3600, str(prediction))  # 1 hour TTL
        self._add_to_memory_cache(entity_id, prediction)

        return prediction

    def _add_to_memory_cache(self, entity_id: str, prediction: float):
        """Add to memory cache with LRU eviction"""
        if len(self.memory_cache) >= self.memory_cache_size:
            # Simple eviction: remove first item
            # In production, use a true LRU cache (see the sketch below)
            self.memory_cache.pop(next(iter(self.memory_cache)))

        self.memory_cache[entity_id] = prediction

    def _compute_prediction(self, entity_id: str) -> float:
        """Compute prediction from model"""
        # Fetch features
        features = self._get_features(entity_id)

        # Predict
        prediction = self.model.predict([features])[0]

        return float(prediction)

    def _get_features(self, entity_id: str):
        """Fetch features for entity"""
        # Placeholder
        return [0.1, 0.2, 0.3, 0.4, 0.5]

    def get_cache_stats(self) -> dict:
        """Get cache performance statistics"""
        total = self.stats['total_requests']

        if total == 0:
            return self.stats

        return {
            **self.stats,
            'memory_hit_rate': self.stats['memory_hits'] / total * 100,
            'redis_hit_rate': self.stats['redis_hits'] / total * 100,
            'compute_rate': self.stats['compute'] / total * 100,
            'overall_cache_hit_rate':
                (self.stats['memory_hits'] + self.stats['redis_hits']) / total * 100
        }

# Usage
system = MultiTierInferenceSystem(model, redis_client)

# Make predictions
for entity_id in ['user_1', 'user_2', 'user_1', 'user_3', 'user_1']:
    prediction = system.predict(entity_id)
    print(f"{entity_id}: {prediction:.4f}")

stats = system.get_cache_stats()
print(f"\nCache hit rate: {stats['overall_cache_hit_rate']:.1f}%")
print(f"Memory: {stats['memory_hit_rate']:.1f}%, Redis: {stats['redis_hit_rate']:.1f}%, Compute: {stats['compute_rate']:.1f}%")

Prediction Warming

Precompute predictions for likely requests.

class PredictionWarmer:
    """
    Warm cache with predictions for likely-to-be-requested entities

    Use case: Preload predictions for active users
    """

    def __init__(self, model, cache):
        self.model = model
        self.cache = cache

    def warm_predictions(
        self,
        entity_ids: List[str],
        batch_size: int = 100
    ):
        """
        Warm cache for list of entities

        Args:
            entity_ids: Entities to warm
            batch_size: Batch size for efficient computation
        """
        num_warmed = 0

        for i in range(0, len(entity_ids), batch_size):
            batch_ids = entity_ids[i:i+batch_size]

            # Batch feature fetching
            features = self._batch_get_features(batch_ids)

            # Batch prediction
            predictions = self.model.predict(features)

            # Write to cache
            for entity_id, prediction in zip(batch_ids, predictions):
                self.cache.set(f"pred:{entity_id}", float(prediction), ex=3600)
                num_warmed += 1

        return num_warmed

    def _batch_get_features(self, entity_ids: List[str]):
        """Fetch features for multiple entities"""
        # In production: Batch query to feature store
        return [[0.1] * 5 for _ in entity_ids]

    def warm_by_activity(
        self,
        lookback_hours: int = 24,
        top_k: int = 10000
    ):
        """
        Warm cache for most active entities

        Args:
            lookback_hours: Look back this many hours for activity
            top_k: Warm top K most active entities
        """
        # Query activity logs
        active_entities = self._get_active_entities(lookback_hours, top_k)

        # Warm predictions
        num_warmed = self.warm_predictions(active_entities)

        return {
            'num_warmed': num_warmed,
            'lookback_hours': lookback_hours,
            'timestamp': time.time()
        }

    def _get_active_entities(self, lookback_hours: int, top_k: int) -> List[str]:
        """Get most active entities from activity logs"""
        # Placeholder: Query activity database
        return [f'user_{i}' for i in range(top_k)]

# Usage: Warm cache every hour for active users
warmer = PredictionWarmer(model, redis_client)

# Warm cache for top 10K active users
result = warmer.warm_by_activity(lookback_hours=1, top_k=10000)
print(f"Warmed {result['num_warmed']} predictions")

Conditional Batch Updates

Update batch predictions conditionally based on staleness/changes.

class ConditionalBatchUpdater:
    """
    Update batch predictions only when necessary

    Strategies:
        - Update only if features changed significantly
        - Update only if prediction is stale
        - Update only for active entities
        """

    def __init__(self, model, cache, feature_store):
        self.model = model
        self.cache = cache
        self.feature_store = feature_store

    def update_if_changed(
        self,
        entity_ids: List[str],
        change_threshold: float = 0.1
    ) -> dict:
        """
        Update predictions only if features changed significantly

        Args:
            entity_ids: Entities to check
            change_threshold: Update if features changed by this much

        Returns:
            Statistics on updates
        """
        num_checked = 0
        num_updated = 0

        for entity_id in entity_ids:
            num_checked += 1

            # Get current features
            current_features = self.feature_store.get(f"features:{entity_id}")

            # Get cached features (when prediction was made)
            cached_features = self.feature_store.get(f"cached_features:{entity_id}")

            # Check if features changed significantly
            if self._features_changed(cached_features, current_features, change_threshold):
                # Recompute prediction
                prediction = self.model.predict([current_features])[0]

                # Update cache
                self.cache.set(f"pred:{entity_id}", float(prediction), ex=3600)
                self.feature_store.set(f"cached_features:{entity_id}", current_features)

                num_updated += 1

        return {
            'num_checked': num_checked,
            'num_updated': num_updated,
            'update_rate': num_updated / num_checked * 100 if num_checked > 0 else 0
        }

    def _features_changed(
        self,
        old_features,
        new_features,
        threshold: float
    ) -> bool:
        """Check if features changed significantly"""
        if old_features is None or new_features is None:
            return True

        # Compute L2 distance
        diff = np.linalg.norm(np.array(new_features) - np.array(old_features))

        return diff > threshold
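
A brief usage sketch in the same style as the earlier examples; feature_store here is assumed to be a client exposing get/set for raw feature vectors (it is not defined elsewhere in this post).

# Usage: nightly job that only refreshes entities whose features moved
updater = ConditionalBatchUpdater(model, redis_client, feature_store)

stats = updater.update_if_changed(
    entity_ids=[f'user_{i}' for i in range(1000)],
    change_threshold=0.1
)
print(f"Checked {stats['num_checked']}, updated {stats['num_updated']} "
      f"({stats['update_rate']:.1f}%)")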

Graceful Degradation

Handle failures gracefully with fallback strategies.

class GracefulDegradationSystem:
    """
    Inference system with graceful degradation

    Fallback chain:
        1. Try real-time prediction
        2. Fallback to batch prediction (if available)
        3. Fallback to default/fallback prediction
        """

    def __init__(
        self,
        realtime_service,
        batch_cache,
        default_prediction: float = 0.5
    ):
        self.realtime = realtime_service
        self.batch_cache = batch_cache
        self.default_prediction = default_prediction

        # Monitoring
        self.degradation_stats = {
            'realtime': 0,
            'batch_fallback': 0,
            'default_fallback': 0
        }

    def predict_with_fallback(
        self,
        entity_id: str,
        max_latency_ms: int = 100
    ) -> dict:
        """
        Predict with fallback strategies

        Args:
            entity_id: Entity to predict for
            max_latency_ms: Maximum acceptable latency

        Returns:
            {
                'prediction': float,
                'source': str,
                'latency_ms': float
            }
        """
        start = time.perf_counter()

        # Try real-time prediction
        try:
            prediction = self.realtime.predict(entity_id)
            elapsed_ms = (time.perf_counter() - start) * 1000

            if elapsed_ms <= max_latency_ms:
                self.degradation_stats['realtime'] += 1
                return {
                    'prediction': prediction,
                    'source': 'realtime',
                    'latency_ms': elapsed_ms
                }
        except Exception as e:
            print(f"Real-time prediction failed: {e}")

        # Fallback 1: Batch cache
        try:
            batch_pred = self.batch_cache.get(f"pred:{entity_id}")

            if batch_pred is not None:
                elapsed_ms = (time.perf_counter() - start) * 1000
                self.degradation_stats['batch_fallback'] += 1

                return {
                    'prediction': float(batch_pred),
                    'source': 'batch_fallback',
                    'latency_ms': elapsed_ms,
                    'warning': 'Using stale batch prediction'
                }
        except Exception as e:
            print(f"Batch fallback failed: {e}")

        # Fallback 2: Default prediction
        elapsed_ms = (time.perf_counter() - start) * 1000
        self.degradation_stats['default_fallback'] += 1

        return {
            'prediction': self.default_prediction,
            'source': 'default_fallback',
            'latency_ms': elapsed_ms,
            'warning': 'Using default prediction - service degraded'
        }

    def get_health_status(self) -> dict:
        """Get system health metrics"""
        total = sum(self.degradation_stats.values())

        if total == 0:
            return {'status': 'no_traffic'}

        realtime_rate = self.degradation_stats['realtime'] / total * 100

        if realtime_rate > 95:
            status = 'healthy'
        elif realtime_rate > 80:
            status = 'degraded'
        else:
            status = 'critical'

        return {
            'status': status,
            'realtime_rate': realtime_rate,
            'batch_fallback_rate': self.degradation_stats['batch_fallback'] / total * 100,
            'default_fallback_rate': self.degradation_stats['default_fallback'] / total * 100,
            'total_requests': total
        }
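
A short usage sketch following the pattern of the earlier examples; realtime_service and redis_client are the objects built above, and default_prediction is a domain-specific prior rather than a recommended value.

# Usage
degradation_system = GracefulDegradationSystem(
    realtime_service=realtime_service,
    batch_cache=redis_client,
    default_prediction=0.5
)

result = degradation_system.predict_with_fallback('user_12345', max_latency_ms=100)
print(f"Prediction: {result['prediction']} (source: {result['source']}, "
      f"{result['latency_ms']:.1f}ms)")

print(degradation_system.get_health_status())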

How Do Companies Like Netflix, Uber, and LinkedIn Handle This?

Netflix: Hybrid Recommendations

Challenge: Personalized recommendations for 200M+ users

Solution:

  • Batch: Precompute top-N recommendations for all users daily
  • Real-time: Rerank based on current session context
  • Result: < 100ms latency with personalized results

Architecture:

Daily Batch Job (Spark)
 ↓
Precompute Top 1000 movies per user
 ↓
Store in Cassandra
 ↓
Real-time API fetches top 1000 + reranks based on:
 - Current time of day
 - Device type
 - Recent viewing history
 ↓
Return Top 20 to UI
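
To make the rerank step concrete, here is a minimal sketch of the batch-candidates-plus-real-time-rerank pattern. The candidate fields, context features, and the rerank_candidates function are illustrative assumptions, not Netflix's actual implementation.

from typing import List, Dict

def rerank_candidates(
    candidates: List[Dict],   # precomputed daily, e.g. top 1000 titles per user
    context: Dict,            # fresh session context: hour, device, recent views
    top_n: int = 20
) -> List[Dict]:
    """Rerank batch candidates with cheap real-time context features."""
    def contextual_score(candidate: Dict) -> float:
        score = candidate['batch_score']
        # Boost titles matching the current time of day
        if context['hour'] in candidate.get('peak_hours', []):
            score *= 1.2
        # Penalize titles the user just watched
        if candidate['title_id'] in context.get('recent_views', set()):
            score *= 0.5
        return score

    return sorted(candidates, key=contextual_score, reverse=True)[:top_n]

# Example: two precomputed candidates, reranked for an evening session
candidates = [
    {'title_id': 'a', 'batch_score': 0.9, 'peak_hours': [9, 10]},
    {'title_id': 'b', 'batch_score': 0.8, 'peak_hours': [20, 21]},
]
print(rerank_candidates(candidates, {'hour': 21, 'recent_views': set()}, top_n=2))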

Uber: Real-Time ETA Prediction

Challenge: Predict arrival time for millions of rides

Solution:

  • Real-time only: ETA must reflect current traffic
  • Strategy: Fast model (< 50ms inference)
  • Features: Current location, traffic data, historical patterns

Why not batch:

  • Traffic changes rapidly
  • Each ride is unique
  • Requires current GPS coordinates

LinkedIn: People You May Know

Challenge: Suggest connections for 800M+ users

Solution:

  • Batch: Graph algorithms compute connection candidates (weekly)
  • Real-time: Scoring based on user activity
  • Result: Balance compute cost with personalization

Hybrid Strategy:

Weekly Batch:
 - Graph traversal (2nd, 3rd degree connections)
 - Identify ~1000 candidates per user
 - Store in candidate DB

Real-time (on page load):
 - Fetch candidates from DB
 - Score based on:
 * Recent profile views
 * Shared groups/companies
 * Mutual connections
 - Return top 10
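
A toy sketch of the candidate-generation half: 2nd-degree connections from an adjacency map, ranked by mutual connections. The graph shape and the helper name second_degree_candidates are assumptions for illustration, not LinkedIn's production code.

from collections import Counter
from typing import Dict, List, Set

def second_degree_candidates(
    user: str,
    graph: Dict[str, Set[str]],   # user -> set of direct connections
    max_candidates: int = 1000
) -> List[str]:
    """Find 2nd-degree connections ranked by number of mutual connections."""
    direct = graph.get(user, set())
    mutual_counts = Counter()

    for friend in direct:
        for candidate in graph.get(friend, set()):
            if candidate != user and candidate not in direct:
                mutual_counts[candidate] += 1

    return [c for c, _ in mutual_counts.most_common(max_candidates)]

# Example: tiny graph
graph = {
    'alice': {'bob', 'carol'},
    'bob': {'alice', 'dave'},
    'carol': {'alice', 'dave', 'erin'},
    'dave': {'bob', 'carol'},
    'erin': {'carol'},
}
print(second_degree_candidates('alice', graph))  # ['dave', 'erin']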

What Metrics Should You Monitor?

What Key Metrics Should You Track?

class InferenceMetrics:
    """
    Track comprehensive inference metrics
    """

    def __init__(self):
        self.metrics = {
            'latency_ms': [],
            'cache_hit': [],
            'error': [],
            'cost': [],
            'timestamp': []
        }

    def record_prediction(
        self,
        latency_ms: float,
        cache_hit: bool,
        error: bool,
        cost: float
    ):
        """Record single prediction metrics"""
        self.metrics['latency_ms'].append(latency_ms)
        self.metrics['cache_hit'].append(1 if cache_hit else 0)
        self.metrics['error'].append(1 if error else 0)
        self.metrics['cost'].append(cost)
        self.metrics['timestamp'].append(time.time())

    def get_dashboard_metrics(self) -> dict:
        """
        Get metrics for monitoring dashboard

        Returns:
            Key metrics for alerting
        """
        latencies = self.metrics['latency_ms']
        timestamps = self.metrics['timestamp']

        # Approximate QPS over the recorded window
        window_s = timestamps[-1] - timestamps[0] if len(timestamps) > 1 else 1.0
        qps = len(timestamps) / max(window_s, 1e-9)

        return {
            'latency_p50_ms': np.percentile(latencies, 50),
            'latency_p95_ms': np.percentile(latencies, 95),
            'latency_p99_ms': np.percentile(latencies, 99),
            'cache_hit_rate': np.mean(self.metrics['cache_hit']) * 100,
            'error_rate': np.mean(self.metrics['error']) * 100,
            'qps': qps,
            'cost_per_1k_predictions': np.mean(self.metrics['cost']) * 1000
        }

SLA Definition

class InferenceSLA:
    """
    Define and monitor SLA for inference service
    """

    def __init__(self):
        self.sla_targets = {
            'p99_latency_ms': 100,
            'availability': 99.9,
            'error_rate': 0.1  # 0.1%
        }

    def check_sla_compliance(self, metrics: dict) -> dict:
        """
        Check if current metrics meet SLA

        Returns:
            SLA compliance report
        """
        compliance = {}

        for metric, target in self.sla_targets.items():
            actual = metrics.get(metric, 0)

            if metric == 'error_rate' or 'latency' in metric:
                # Lower is better
                meets_sla = actual <= target
                margin = target - actual
            else:
                # Higher is better (e.g., availability)
                meets_sla = actual >= target
                margin = actual - target

            compliance[metric] = {
                'target': target,
                'actual': actual,
                'meets_sla': meets_sla,
                'margin': margin
            }

        return compliance
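
A quick usage sketch with hypothetical metric values (the numbers below are made up for illustration).

# Usage
sla = InferenceSLA()

current_metrics = {
    'p99_latency_ms': 85,    # hypothetical observed p99
    'availability': 99.95,   # hypothetical observed availability
    'error_rate': 0.05       # hypothetical observed error rate (%)
}

report = sla.check_sla_compliance(current_metrics)
for metric, info in report.items():
    status = "OK" if info['meets_sla'] else "VIOLATION"
    print(f"{metric}: {status} (actual={info['actual']}, target={info['target']})")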

FAQ

Q: When should I use batch inference vs real-time inference? A: Use batch when latency is not critical (email recommendations, catalog rankings), predictions are needed for all entities, and cost matters. Use real-time when sub-100ms latency is required (search, ads, fraud detection), features must be fresh, and predictions serve a small active subset. Most systems use hybrid.

Q: How much cheaper is batch inference than real-time? A: Batch inference is typically 10-100x cheaper than real-time. A batch job processing 1M predictions costs around $0.30/day (compute + storage), while real-time serving at 100 QPS costs roughly $180/day in always-on GPU instances. The savings come from efficient batching and no idle capacity.

Q: How does Netflix use hybrid inference? A: Netflix precomputes top 1000 movie candidates per user daily via Spark batch jobs stored in Cassandra. When a user opens the app, a real-time API fetches these candidates and reranks based on current context (time of day, device, recent views), returning the top 20 in under 100ms.

Q: What is prediction warming? A: Prediction warming preloads the cache with predictions for entities likely to be requested soon (e.g., recently active users). It combines the cost efficiency of batch prediction with the freshness advantage of real-time, reducing cache misses by proactively computing for the most active users.


Key Takeaways

Batch inference precomputes predictions, cheaper, higher latency ✅ Real-time inference computes on-demand, expensive, lower latency ✅ Hybrid approach combines both for optimal cost/performance ✅ Multi-tier caching (memory → Redis → compute) optimizes latency ✅ Prediction warming preloads cache for likely requests ✅ Conditional updates reduce unnecessary recomputation ✅ Graceful degradation ensures reliability via fallback strategies ✅ Latency vs cost is the fundamental trade-off ✅ Feature freshness often determines the choice ✅ Most systems use hybrid: batch for bulk, real-time for edge cases ✅ Cache hit rate critical metric for hybrid systems ✅ SLA monitoring ensures service quality


Originally published at: arunbaby.com/ml-system-design/0005-batch-realtime-inference

If you found this helpful, consider sharing it with others who might benefit.

Want to work together?

I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.

Get in touch