Batch vs Real-Time Inference
How to choose between batch and real-time inference, the architectural decision that shapes your entire ML serving infrastructure.
Introduction
After training a model, you need to serve its predictions. There are two fundamental approaches:
- Batch Inference: Precompute predictions for all users/items periodically
- Real-Time Inference: Compute predictions on-demand when requested
Why this matters:
- Different latency requirements → Different architectures
- Cost implications → Batch can be 10-100x cheaper
- System complexity → Real-time requires more infrastructure
- Feature freshness → Real-time uses latest data
What you’ll learn:
- When to use batch vs real-time
- Architecture for each approach
- Hybrid systems combining both
- Trade-offs and decision framework
- Production implementation patterns
Problem Definition
Design an ML inference system that serves predictions efficiently.
Functional Requirements
- Prediction Serving
- Batch: Generate predictions for all entities periodically
- Real-time: Serve predictions on-demand with low latency
- Hybrid: Combine both approaches
- Data Freshness
- Access to latest features
- Handle feature staleness
- Feature computation strategy
- Scalability
- Handle millions of predictions
- Scale horizontally
- Handle traffic spikes
Non-Functional Requirements
- Latency
- Batch: Minutes to hours acceptable
- Real-time: < 100ms for most applications
- Throughput
- Batch: Process millions of predictions in one run
- Real-time: Thousands of requests/second (a back-of-envelope sizing sketch follows this list)
- Cost
- Optimize compute resources
- Minimize infrastructure costs
- Reliability
- 99.9%+ uptime for real-time
- Graceful degradation
- Fallback mechanisms
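These targets can be sanity-checked with a quick back-of-envelope calculation before committing to an architecture. A minimal sizing sketch; the throughput, concurrency, and headroom figures below are illustrative assumptions, not benchmarks:

import math

def realtime_instances_needed(peak_qps: float, per_request_ms: float,
                              concurrency_per_instance: int = 8,
                              headroom: float = 1.5) -> int:
    """Instances needed to sustain peak QPS at a given per-request latency."""
    qps_per_instance = concurrency_per_instance * (1000.0 / per_request_ms)
    return math.ceil(peak_qps / qps_per_instance * headroom)

def batch_runtime_hours(num_entities: int, preds_per_second: float) -> float:
    """Wall-clock hours for one batch run at a sustained prediction throughput."""
    return num_entities / preds_per_second / 3600

# Example: 2,000 QPS peak at 40 ms per request, and scoring 50M users nightly
print(realtime_instances_needed(peak_qps=2000, per_request_ms=40))              # 15 instances
print(f"{batch_runtime_hours(50_000_000, preds_per_second=20_000):.1f} hours")  # ~0.7 hours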
Batch Inference
Precompute predictions periodically (daily, hourly, etc.).
Architecture
┌─────────────────────────────────────────────────────────┐
│ Batch Inference Pipeline │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Data Lake │ │ Feature │ │
│ │ (HDFS/S3) │─────▶│ Engineering │ │
│ └──────────────┘ └──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Batch Job │ │
│ │ (Spark/Ray) │ │
│ │ - Load model│ │
│ │ - Predict │ │
│ └──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Write to │ │
│ │ Cache/DB │ │
│ │ (Redis/DDB) │ │
│ └──────────────┘ │
│ │ │
│ ┌──────────────┐ │ │
│ │ Application │◀───────────┘ │
│ │ Server │ Lookup predictions │
│ └──────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
Flow:
1. Extract features from data warehouse
2. Run batch prediction job
3. Store predictions in fast lookup store
4. Application does simple lookup
Implementation
from typing import List, Dict, Optional
import numpy as np
import redis
import json
import time
class BatchInferenceSystem:
"""
Batch inference system
Precomputes predictions for all users/items
"""
def __init__(self, model, redis_client):
self.model = model
self.redis = redis_client
self.batch_size = 1000
def run_batch_prediction(self, entity_ids: List[str], features_df):
"""
Run batch prediction for all entities
Args:
entity_ids: List of user/item IDs
features_df: DataFrame with features for all entities
Returns:
Number of predictions generated
"""
num_predictions = 0
# Process in batches for memory efficiency
for i in range(0, len(entity_ids), self.batch_size):
batch_ids = entity_ids[i:i+self.batch_size]
batch_features = features_df.iloc[i:i+self.batch_size]
# Predict
predictions = self.model.predict(batch_features.values)
# Store in Redis
self._store_predictions(batch_ids, predictions)
num_predictions += len(batch_ids)
if num_predictions % 10000 == 0:
print(f"Processed {num_predictions} predictions...")
return num_predictions
def _store_predictions(self, entity_ids: List[str], predictions: np.ndarray):
"""Store predictions in Redis with TTL"""
pipeline = self.redis.pipeline()
ttl_seconds = 24 * 3600 # 24 hours
for entity_id, prediction in zip(entity_ids, predictions):
# Store as JSON
key = f"pred:{entity_id}"
value = json.dumps({
'prediction': float(prediction),
'timestamp': time.time()
})
pipeline.setex(key, ttl_seconds, value)
pipeline.execute()
    def get_prediction(self, entity_id: str) -> Optional[float]:
"""
Lookup precomputed prediction
Fast O(1) lookup
"""
key = f"pred:{entity_id}"
value = self.redis.get(key)
if value is None:
# Prediction not found or expired
return None
data = json.loads(value)
return data['prediction']
# Usage
import pandas as pd
import time
# Initialize
redis_client = redis.Redis(host='localhost', port=6379, db=0)
model = load_trained_model() # Your trained model
batch_system = BatchInferenceSystem(model, redis_client)
# Run batch prediction (e.g., daily cron job)
user_ids = fetch_all_user_ids() # Get all users
features_df = fetch_user_features(user_ids) # Get features
num_preds = batch_system.run_batch_prediction(user_ids, features_df)
print(f"Generated {num_preds} predictions")
# Later, application looks up prediction
prediction = batch_system.get_prediction("user_12345")
print(f"Prediction: {prediction}")
Spark-based Batch Inference
For large-scale batch processing:
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd
class SparkBatchInference:
"""
Distributed batch inference using PySpark
Scales to billions of predictions
"""
def __init__(self, model_path):
self.spark = SparkSession.builder \
.appName("BatchInference") \
.getOrCreate()
self.model_path = model_path
def predict_spark(self, features_df_spark):
"""
Distribute prediction across cluster
Args:
features_df_spark: Spark DataFrame with features
Returns:
Spark DataFrame with predictions
"""
model_path = self.model_path
# Define pandas UDF for prediction
@pandas_udf("double", PandasUDFType.SCALAR)
def predict_udf(*features):
            # Load the model on the executor; a scalar pandas UDF runs this per
            # batch of rows, so cache the model (e.g. at module level) if loading
            # is expensive. For s3:// paths, download or open via fsspec first.
            import joblib
            model = joblib.load(model_path)
# Create feature matrix
X = pd.DataFrame({
f'feature_{i}': features[i]
for i in range(len(features))
})
# Predict
predictions = model.predict(X.values)
return pd.Series(predictions)
# Apply UDF
feature_cols = [col for col in features_df_spark.columns if col.startswith('feature_')]
result_df = features_df_spark.withColumn(
'prediction',
predict_udf(*feature_cols)
)
return result_df
def run_batch_job(self, input_path, output_path):
"""
Full batch inference pipeline
Args:
input_path: S3/HDFS path to input data
output_path: S3/HDFS path to save predictions
"""
# Read input
df = self.spark.read.parquet(input_path)
# Predict
predictions_df = self.predict_spark(df)
# Write output
predictions_df.write.parquet(output_path, mode='overwrite')
print(f"Batch prediction complete. Output: {output_path}")
# Usage
spark_batch = SparkBatchInference(model_path='s3://models/my_model.pkl')
spark_batch.run_batch_job(
input_path='s3://data/user_features/',
output_path='s3://predictions/daily/2025-01-15/'
)
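In production this job would typically run on a schedule (cron, Airflow, or similar). A minimal daily driver sketch, assuming the SparkBatchInference class above and a date-partitioned output prefix like the one shown in the usage example:

from datetime import date

def run_daily_batch(model_path: str, input_prefix: str, output_prefix: str):
    """Build today's output partition and run one batch inference pass."""
    run_date = date.today().isoformat()  # e.g. '2025-01-15'
    job = SparkBatchInference(model_path=model_path)
    job.run_batch_job(
        input_path=f"{input_prefix}/",
        output_path=f"{output_prefix}/{run_date}/"
    )

# Invoked once per day by the scheduler:
# run_daily_batch('s3://models/my_model.pkl', 's3://data/user_features', 's3://predictions/daily')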
Real-Time Inference
Compute predictions on-demand when requested.
Architecture
┌─────────────────────────────────────────────────────────┐
│ Real-Time Inference System │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ │
│ │ Load │◀─── Model Registry │
│ │ Balancer │ │
│ └──────┬───────┘ │
│ │ │
│ ┌────▼─────────────────────────────┐ │
│ │ Model Serving Instances │ │
│ │ ┌─────────┐ ┌─────────┐ │ │
│ │ │ Model 1 │ │ Model 2 │ ... │ │
│ │ │ (GPU) │ │ (GPU) │ │ │
│ │ └─────────┘ └─────────┘ │ │
│ └────┬─────────────────────────────┘ │
│ │ │
│ ┌────▼──────────┐ ┌──────────────┐ │
│ │ Feature │────▶│ Feature │ │
│ │ Service │ │ Store │ │
│ │ - Online │ │ (Redis) │ │
│ │ features │ └──────────────┘ │
│ └───────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
Flow:
1. Request arrives with user/item ID
2. Fetch features from feature store
3. Compute additional online features
4. Model predicts
5. Return prediction
Implementation
from fastapi import FastAPI
import numpy as np
from typing import Dict
import json
import time
app = FastAPI()
class RealTimeInferenceService:
"""
Real-time inference service
Serves predictions with low latency
"""
def __init__(self, model, feature_store):
self.model = model
self.feature_store = feature_store
# Warm up model
self._warmup()
def _warmup(self):
"""Warm up model with dummy prediction"""
dummy_features = np.random.randn(1, self.model.input_dim)
_ = self.model.predict(dummy_features)
def get_features(self, entity_id: str) -> Dict:
"""
Fetch features for entity
Combines precomputed + real-time features
"""
# Fetch precomputed features from Redis
precomputed_raw = self.feature_store.get(f"features:{entity_id}")
precomputed = {}
if precomputed_raw:
try:
precomputed = json.loads(precomputed_raw)
except Exception:
precomputed = {}
        if not precomputed:
# Fallback: compute features on-the-fly
precomputed = self._compute_features_fallback(entity_id)
# Add real-time features
realtime_features = self._compute_realtime_features(entity_id)
# Combine
features = {**precomputed, **realtime_features}
return features
def _compute_realtime_features(self, entity_id: str) -> Dict:
"""
Compute features that must be fresh
E.g., time of day, user's current session, etc.
"""
import datetime
now = datetime.datetime.now()
return {
'hour_of_day': now.hour,
'day_of_week': now.weekday(),
'is_weekend': 1 if now.weekday() >= 5 else 0
}
def _compute_features_fallback(self, entity_id: str) -> Dict:
"""Fallback feature computation"""
# Query database, compute on-the-fly
# This is slower but ensures we can always serve
return {}
def predict(self, entity_id: str) -> float:
"""
Real-time prediction
Returns:
Prediction score
"""
# Get features
features = self.get_features(entity_id)
        # Convert to a fixed-order vector; in practice, map named features
        # (e.g. hour_of_day) to model input positions via a feature schema
feature_vector = np.array([
features.get(f'feature_{i}', 0.0)
for i in range(self.model.input_dim)
]).reshape(1, -1)
# Predict
prediction = self.model.predict(feature_vector)[0]
return float(prediction)
# FastAPI endpoints
realtime_service = RealTimeInferenceService(model, redis_client)
@app.get("/predict/{entity_id}")
async def predict_endpoint(entity_id: str):
"""
Real-time prediction endpoint
GET /predict/user_12345
"""
try:
prediction = realtime_service.predict(entity_id)
return {
'entity_id': entity_id,
'prediction': prediction,
'timestamp': time.time()
}
except Exception as e:
from fastapi import HTTPException
raise HTTPException(status_code=500, detail={'error': str(e), 'entity_id': entity_id})
# Run with: uvicorn app:app --host 0.0.0.0 --port 8000
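Once the service is running, clients call it over plain HTTP. A small caller sketch, assuming the service is reachable at localhost:8000:

import requests

resp = requests.get("http://localhost:8000/predict/user_12345", timeout=1.0)
resp.raise_for_status()
body = resp.json()
print(body['prediction'], body['timestamp'])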
TensorFlow Serving
Production-grade model serving:
import requests
import json
from typing import List
class TensorFlowServingClient:
"""
Client for TensorFlow Serving
High-performance model serving
"""
def __init__(self, server_url, model_name, model_version=None):
self.server_url = server_url
self.model_name = model_name
self.model_version = model_version or 'latest'
# Endpoint
if self.model_version == 'latest':
self.endpoint = f"{server_url}/v1/models/{model_name}:predict"
else:
self.endpoint = f"{server_url}/v1/models/{model_name}/versions/{model_version}:predict"
def predict(self, instances: List[List[float]]) -> List[float]:
"""
Send prediction request to TF Serving
Args:
instances: List of feature vectors
Returns:
List of predictions
"""
# Prepare request
payload = {
"signature_name": "serving_default",
"instances": instances
}
# Send request
response = requests.post(
self.endpoint,
data=json.dumps(payload),
headers={'Content-Type': 'application/json'}
)
if response.status_code != 200:
raise Exception(f"Prediction failed: {response.text}")
# Parse response
result = response.json()
predictions = result['predictions']
return predictions
# Usage
tf_client = TensorFlowServingClient(
server_url='http://localhost:8501',
model_name='recommendation_model',
model_version='3'
)
# Predict
features = [[0.1, 0.5, 0.3, 0.9]]
predictions = tf_client.predict(features)
print(f"Prediction: {predictions[0]}")
Hybrid Approach
Combine batch and real-time for optimal performance.
Architecture
┌────────────────────────────────────────────────────────┐
│ Hybrid Inference System │
├────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────┐ ┌────────────────┐ │
│ │ Batch Pipeline │ │ Real-Time API │ │
│ │ (Daily) │ │ │ │
│ └───────┬────────┘ └───────┬────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────────────────────────────────┐ │
│ │ Prediction Cache (Redis) │ │
│ │ ┌────────────┐ ┌────────────┐ │ │
│ │ │ Batch │ │ Real-time │ │ │
│ │ │ Predictions│ │ Predictions│ │ │
│ │ │ (TTL: 24h) │ │ (TTL: 1h) │ │ │
│ │ └────────────┘ └────────────┘ │ │
│ └──────────────────────────────────────────┘ │
│ ▲ │
│ │ │
│ ┌──────┴────────┐ │
│ │ Application │ │
│ │ 1. Check cache│ │
│ │ 2. Fallback to│ │
│ │ real-time │ │
│ └───────────────┘ │
│ │
└────────────────────────────────────────────────────────┘
Implementation
class HybridInferenceSystem:
"""
Hybrid system: batch + real-time
- Fast path: Use batch predictions if available
- Slow path: Compute real-time if needed
"""
def __init__(self, batch_system, realtime_system):
self.batch = batch_system
self.realtime = realtime_system
self.cache_hit_counter = 0
self.cache_miss_counter = 0
def predict(self, entity_id: str, max_staleness_hours: int = 24) -> Dict:
"""
Get prediction with automatic fallback
Args:
entity_id: Entity to predict for
max_staleness_hours: Maximum age of batch prediction
Returns:
{
'prediction': float,
'source': 'batch' | 'realtime',
'timestamp': float
}
"""
# Try batch prediction first
batch_pred_value = self.batch.get_prediction(entity_id)
if batch_pred_value is not None:
            # The batch store returns only the value; freshness is bounded by
            # the Redis TTL (max_staleness_hours is advisory in this sketch)
self.cache_hit_counter += 1
return {
'prediction': batch_pred_value,
'source': 'batch',
'timestamp': time.time(),
'cache_hit': True
}
# Fallback to real-time
self.cache_miss_counter += 1
realtime_pred = self.realtime.predict(entity_id)
return {
'prediction': realtime_pred,
'source': 'realtime',
'timestamp': time.time(),
'cache_hit': False
}
def get_cache_hit_rate(self) -> float:
"""Calculate cache hit rate"""
total = self.cache_hit_counter + self.cache_miss_counter
if total == 0:
return 0.0
return self.cache_hit_counter / total
# Usage
hybrid = HybridInferenceSystem(batch_system, realtime_service)
# Predict for user
result = hybrid.predict('user_12345', max_staleness_hours=12)
print(f"Prediction: {result['prediction']}")
print(f"Source: {result['source']}")
print(f"Cache hit rate: {hybrid.get_cache_hit_rate():.2%}")
Decision Framework
When to use which approach (a small helper that encodes these rules as code follows the examples below):
Use Batch Inference When:
✅ Latency is not critical (recommendations, email campaigns)
✅ Predictions needed for all entities (e.g., all users)
✅ Features are expensive to compute
✅ Model is large/slow
✅ Cost optimization is priority
✅ Predictions don’t change frequently
Examples:
- Daily email recommendations
- Product catalog rankings
- Weekly personalized content
- Batch fraud scoring
Use Real-Time Inference When:
✅ Low latency required (< 100ms)
✅ Fresh features critical (current context)
✅ Predictions for small subset (active users)
✅ Immediate user feedback (search, ads)
✅ High-value decisions (fraud detection)
Examples:
- Search ranking
- Ad serving
- Real-time fraud detection
- Live recommendation widgets
Use Hybrid When:
✅ Mix of latency requirements
✅ Want cost + performance
✅ Can tolerate some staleness
✅ Variable traffic patterns
✅ Graceful degradation needed
Examples:
- Homepage recommendations (batch) + search (real-time)
- Social feed (batch) + stories (real-time)
- Product pages (batch) + checkout (real-time)
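These rules of thumb can be encoded as a small helper for design discussions. A sketch; the thresholds are illustrative, not prescriptive:

def choose_serving_strategy(
    latency_budget_ms: float,
    needs_fresh_features: bool,
    fraction_of_entities_served: float,   # share of all entities that actually get requests
    predictions_change_frequently: bool
) -> str:
    """Rough heuristic mirroring the checklists above; tune thresholds to your system."""
    if latency_budget_ms < 200 and (needs_fresh_features or predictions_change_frequently):
        # Low latency plus fresh context: real-time, possibly fronted by a short-TTL cache
        return 'realtime' if fraction_of_entities_served < 0.2 else 'hybrid'
    if fraction_of_entities_served > 0.5 and not needs_fresh_features:
        # Most entities need scores and some staleness is acceptable: precompute
        return 'batch'
    return 'hybrid'

print(choose_serving_strategy(80, True, 0.05, True))     # 'realtime' (search/ads style)
print(choose_serving_strategy(5000, False, 0.9, False))  # 'batch' (email campaign style)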
Cost Comparison
class CostAnalyzer:
"""
Estimate costs for batch vs real-time
"""
def estimate_batch_cost(
self,
num_entities: int,
predictions_per_day: int,
cost_per_compute_hour: float = 3.0
) -> Dict:
"""Estimate daily batch inference cost"""
# Assume 10K predictions/second throughput
throughput = 10_000
# Total predictions
total_preds = num_entities * predictions_per_day
# Compute time needed
compute_seconds = total_preds / throughput
compute_hours = compute_seconds / 3600
# Cost
compute_cost = compute_hours * cost_per_compute_hour
# Storage cost (Redis/DDB)
storage_gb = total_preds * 100 / 1e9 # 100 bytes per prediction
        storage_cost = storage_gb * 0.25 / 30  # $0.25/GB/month, prorated per day
total_cost = compute_cost + storage_cost
return {
'compute_hours': compute_hours,
'compute_cost': compute_cost,
'storage_cost': storage_cost,
'total_daily_cost': total_cost,
'cost_per_prediction': total_cost / total_preds
}
def estimate_realtime_cost(
self,
requests_per_second: int,
cost_per_instance_hour: float = 5.0,
requests_per_instance: int = 100
) -> Dict:
"""Estimate real-time serving cost"""
# Number of instances needed
num_instances = requests_per_second / requests_per_instance
num_instances = int(np.ceil(num_instances * 1.5)) # 50% headroom
# Daily cost
daily_hours = 24
daily_cost = num_instances * cost_per_instance_hour * daily_hours
# Predictions per day
daily_requests = requests_per_second * 86400
return {
'num_instances': num_instances,
'daily_cost': daily_cost,
'cost_per_prediction': daily_cost / daily_requests
}
# Compare costs
analyzer = CostAnalyzer()
# Batch: 1M users, predict once/day
batch_cost = analyzer.estimate_batch_cost(
num_entities=1_000_000,
predictions_per_day=1
)
print("Batch Inference:")
print(f" Daily cost: ${batch_cost['total_daily_cost']:.2f}")
print(f" Cost per prediction: ${batch_cost['cost_per_prediction']:.6f}")
# Real-time: 100 QPS average
realtime_cost = analyzer.estimate_realtime_cost(
requests_per_second=100
)
print("\nReal-Time Inference:")
print(f" Daily cost: ${realtime_cost['daily_cost']:.2f}")
print(f" Cost per prediction: ${realtime_cost['cost_per_prediction']:.6f}")
# Compare
savings = (realtime_cost['daily_cost'] - batch_cost['total_daily_cost']) / realtime_cost['daily_cost'] * 100
print(f"\nBatch is {savings:.1f}% cheaper!")
Advanced Patterns
Multi-Tier Caching
Layer multiple caches for optimal performance.
class MultiTierInferenceSystem:
"""
Multi-tier caching: Memory → Redis → Compute
Optimizes for different latency/cost profiles
"""
def __init__(self, model, redis_client):
self.model = model
self.redis = redis_client
# In-memory cache (fastest)
self.memory_cache = {}
self.memory_cache_size = 10000
# Statistics
self.stats = {
'memory_hits': 0,
'redis_hits': 0,
'compute': 0,
'total_requests': 0
}
def predict(self, entity_id: str) -> float:
"""
Predict with multi-tier caching
Tier 1: In-memory cache (~1ms)
Tier 2: Redis cache (~5ms)
Tier 3: Compute prediction (~50ms)
"""
self.stats['total_requests'] += 1
# Tier 1: Memory cache
if entity_id in self.memory_cache:
self.stats['memory_hits'] += 1
return self.memory_cache[entity_id]
# Tier 2: Redis cache
redis_key = f"pred:{entity_id}"
cached = self.redis.get(redis_key)
if cached is not None:
self.stats['redis_hits'] += 1
prediction = float(cached)
# Promote to memory cache
self._add_to_memory_cache(entity_id, prediction)
return prediction
# Tier 3: Compute
self.stats['compute'] += 1
prediction = self._compute_prediction(entity_id)
# Write to both caches
self.redis.setex(redis_key, 3600, str(prediction)) # 1 hour TTL
self._add_to_memory_cache(entity_id, prediction)
return prediction
def _add_to_memory_cache(self, entity_id: str, prediction: float):
"""Add to memory cache with LRU eviction"""
if len(self.memory_cache) >= self.memory_cache_size:
# Simple eviction: remove first item
# In production, use LRU cache
self.memory_cache.pop(next(iter(self.memory_cache)))
self.memory_cache[entity_id] = prediction
def _compute_prediction(self, entity_id: str) -> float:
"""Compute prediction from model"""
# Fetch features
features = self._get_features(entity_id)
# Predict
prediction = self.model.predict([features])[0]
return float(prediction)
def _get_features(self, entity_id: str):
"""Fetch features for entity"""
# Placeholder
return [0.1, 0.2, 0.3, 0.4, 0.5]
def get_cache_stats(self) -> dict:
"""Get cache performance statistics"""
total = self.stats['total_requests']
if total == 0:
return self.stats
return {
**self.stats,
'memory_hit_rate': self.stats['memory_hits'] / total * 100,
'redis_hit_rate': self.stats['redis_hits'] / total * 100,
'compute_rate': self.stats['compute'] / total * 100,
'overall_cache_hit_rate':
(self.stats['memory_hits'] + self.stats['redis_hits']) / total * 100
}
# Usage
system = MultiTierInferenceSystem(model, redis_client)
# Make predictions
for entity_id in ['user_1', 'user_2', 'user_1', 'user_3', 'user_1']:
prediction = system.predict(entity_id)
print(f"{entity_id}: {prediction:.4f}")
stats = system.get_cache_stats()
print(f"\nCache hit rate: {stats['overall_cache_hit_rate']:.1f}%")
print(f"Memory: {stats['memory_hit_rate']:.1f}%, Redis: {stats['redis_hit_rate']:.1f}%, Compute: {stats['compute_rate']:.1f}%")
Prediction Warming
Precompute predictions for likely requests.
class PredictionWarmer:
"""
Warm cache with predictions for likely-to-be-requested entities
Use case: Preload predictions for active users
"""
def __init__(self, model, cache):
self.model = model
self.cache = cache
def warm_predictions(
self,
entity_ids: List[str],
batch_size: int = 100
):
"""
Warm cache for list of entities
Args:
entity_ids: Entities to warm
batch_size: Batch size for efficient computation
"""
num_warmed = 0
for i in range(0, len(entity_ids), batch_size):
batch_ids = entity_ids[i:i+batch_size]
# Batch feature fetching
features = self._batch_get_features(batch_ids)
# Batch prediction
predictions = self.model.predict(features)
# Write to cache
for entity_id, prediction in zip(batch_ids, predictions):
self.cache.set(f"pred:{entity_id}", float(prediction), ex=3600)
num_warmed += 1
return num_warmed
def _batch_get_features(self, entity_ids: List[str]):
"""Fetch features for multiple entities"""
# In production: Batch query to feature store
return [[0.1] * 5 for _ in entity_ids]
def warm_by_activity(
self,
lookback_hours: int = 24,
top_k: int = 10000
):
"""
Warm cache for most active entities
Args:
lookback_hours: Look back this many hours for activity
top_k: Warm top K most active entities
"""
# Query activity logs
active_entities = self._get_active_entities(lookback_hours, top_k)
# Warm predictions
num_warmed = self.warm_predictions(active_entities)
return {
'num_warmed': num_warmed,
'lookback_hours': lookback_hours,
'timestamp': time.time()
}
def _get_active_entities(self, lookback_hours: int, top_k: int) -> List[str]:
"""Get most active entities from activity logs"""
# Placeholder: Query activity database
return [f'user_{i}' for i in range(top_k)]
# Usage: Warm cache every hour for active users
warmer = PredictionWarmer(model, redis_client)
# Warm cache for top 10K active users
result = warmer.warm_by_activity(lookback_hours=1, top_k=10000)
print(f"Warmed {result['num_warmed']} predictions")
Conditional Batch Updates
Update batch predictions conditionally based on staleness/changes.
class ConditionalBatchUpdater:
"""
Update batch predictions only when necessary
Strategies:
- Update only if features changed significantly
- Update only if prediction is stale
- Update only for active entities
"""
def __init__(self, model, cache, feature_store):
self.model = model
self.cache = cache
self.feature_store = feature_store
def update_if_changed(
self,
entity_ids: List[str],
change_threshold: float = 0.1
) -> dict:
"""
Update predictions only if features changed significantly
Args:
entity_ids: Entities to check
change_threshold: Update if features changed by this much
Returns:
Statistics on updates
"""
num_checked = 0
num_updated = 0
for entity_id in entity_ids:
num_checked += 1
# Get current features
current_features = self.feature_store.get(f"features:{entity_id}")
# Get cached features (when prediction was made)
cached_features = self.feature_store.get(f"cached_features:{entity_id}")
# Check if features changed significantly
if self._features_changed(cached_features, current_features, change_threshold):
# Recompute prediction
prediction = self.model.predict([current_features])[0]
# Update cache
self.cache.set(f"pred:{entity_id}", float(prediction), ex=3600)
self.feature_store.set(f"cached_features:{entity_id}", current_features)
num_updated += 1
return {
'num_checked': num_checked,
'num_updated': num_updated,
'update_rate': num_updated / num_checked * 100 if num_checked > 0 else 0
}
def _features_changed(
self,
old_features,
new_features,
threshold: float
) -> bool:
"""Check if features changed significantly"""
if old_features is None or new_features is None:
return True
# Compute L2 distance
diff = np.linalg.norm(np.array(new_features) - np.array(old_features))
return diff > threshold
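A usage sketch, assuming a feature_store client that returns numeric feature vectors under the keys used above:

# Usage (illustrative): check a set of users and update only the changed ones
updater = ConditionalBatchUpdater(model, redis_client, feature_store)
stats = updater.update_if_changed(
    entity_ids=[f'user_{i}' for i in range(1000)],
    change_threshold=0.1
)
print(f"Checked {stats['num_checked']}, updated {stats['num_updated']} "
      f"({stats['update_rate']:.1f}%)")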
Graceful Degradation
Handle failures gracefully with fallback strategies.
class GracefulDegradationSystem:
"""
Inference system with graceful degradation
Fallback chain:
1. Try real-time prediction
2. Fallback to batch prediction (if available)
3. Fallback to default/fallback prediction
"""
def __init__(
self,
realtime_service,
batch_cache,
default_prediction: float = 0.5
):
self.realtime = realtime_service
self.batch_cache = batch_cache
self.default_prediction = default_prediction
# Monitoring
self.degradation_stats = {
'realtime': 0,
'batch_fallback': 0,
'default_fallback': 0
}
def predict_with_fallback(
self,
entity_id: str,
max_latency_ms: int = 100
) -> dict:
"""
Predict with fallback strategies
Args:
entity_id: Entity to predict for
max_latency_ms: Maximum acceptable latency
Returns:
{
'prediction': float,
'source': str,
'latency_ms': float
}
"""
start = time.perf_counter()
# Try real-time prediction
try:
prediction = self.realtime.predict(entity_id)
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms <= max_latency_ms:
self.degradation_stats['realtime'] += 1
return {
'prediction': prediction,
'source': 'realtime',
'latency_ms': elapsed_ms
}
except Exception as e:
print(f"Real-time prediction failed: {e}")
# Fallback 1: Batch cache
try:
batch_pred = self.batch_cache.get(f"pred:{entity_id}")
if batch_pred is not None:
elapsed_ms = (time.perf_counter() - start) * 1000
self.degradation_stats['batch_fallback'] += 1
return {
'prediction': float(batch_pred),
'source': 'batch_fallback',
'latency_ms': elapsed_ms,
'warning': 'Using stale batch prediction'
}
except Exception as e:
print(f"Batch fallback failed: {e}")
# Fallback 2: Default prediction
elapsed_ms = (time.perf_counter() - start) * 1000
self.degradation_stats['default_fallback'] += 1
return {
'prediction': self.default_prediction,
'source': 'default_fallback',
'latency_ms': elapsed_ms,
'warning': 'Using default prediction - service degraded'
}
def get_health_status(self) -> dict:
"""Get system health metrics"""
total = sum(self.degradation_stats.values())
if total == 0:
return {'status': 'no_traffic'}
realtime_rate = self.degradation_stats['realtime'] / total * 100
if realtime_rate > 95:
status = 'healthy'
elif realtime_rate > 80:
status = 'degraded'
else:
status = 'critical'
return {
'status': status,
'realtime_rate': realtime_rate,
'batch_fallback_rate': self.degradation_stats['batch_fallback'] / total * 100,
'default_fallback_rate': self.degradation_stats['default_fallback'] / total * 100,
'total_requests': total
}
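A usage sketch, reusing the real-time service and Redis client from earlier sections as the fallback chain:

# Usage (illustrative)
degradation_system = GracefulDegradationSystem(
    realtime_service=realtime_service,
    batch_cache=redis_client,
    default_prediction=0.5
)

result = degradation_system.predict_with_fallback('user_12345', max_latency_ms=100)
print(f"{result['prediction']:.4f} via {result['source']} in {result['latency_ms']:.1f} ms")
print(degradation_system.get_health_status())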
Real-World Case Studies
Netflix: Hybrid Recommendations
Challenge: Personalized recommendations for 200M+ users
Solution:
- Batch: Precompute top-N recommendations for all users daily
- Real-time: Rerank based on current session context
- Result: < 100ms latency with personalized results
Architecture:
Daily Batch Job (Spark)
↓
Precompute Top 1000 movies per user
↓
Store in Cassandra
↓
Real-time API fetches top 1000 + reranks based on:
- Current time of day
- Device type
- Recent viewing history
↓
Return Top 20 to UI
Uber: Real-Time ETA Prediction
Challenge: Predict arrival time for millions of rides
Solution:
- Real-time only: ETA must reflect current traffic
- Strategy: Fast model (< 50ms inference)
- Features: Current location, traffic data, historical patterns
Why not batch:
- Traffic changes rapidly
- Each ride is unique
- Requires current GPS coordinates
LinkedIn: People You May Know
Challenge: Suggest connections for 800M+ users
Solution:
- Batch: Graph algorithms compute connection candidates (weekly)
- Real-time: Scoring based on user activity
- Result: Balance compute cost with personalization
Hybrid Strategy:
Weekly Batch:
- Graph traversal (2nd, 3rd degree connections)
- Identify ~1000 candidates per user
- Store in candidate DB
Real-time (on page load):
- Fetch candidates from DB
- Score based on:
* Recent profile views
* Shared groups/companies
* Mutual connections
- Return top 10
Monitoring & Observability
Key Metrics to Track
class InferenceMetrics:
    """
    Track comprehensive inference metrics
    Stores raw per-request measurements; aggregates (percentiles, rates)
    are computed on demand for dashboards and alerting.
    """
    def __init__(self):
        self.latencies_ms = []
        self.cache_hits = []
        self.errors = []
        self.costs = []
        self.timestamps = []

    def record_prediction(
        self,
        latency_ms: float,
        cache_hit: bool,
        error: bool,
        cost: float
    ):
        """Record single prediction metrics"""
        self.latencies_ms.append(latency_ms)
        self.cache_hits.append(1 if cache_hit else 0)
        self.errors.append(1 if error else 0)
        self.costs.append(cost)
        self.timestamps.append(time.time())

    def get_dashboard_metrics(self) -> dict:
        """
        Get metrics for monitoring dashboard
        Returns:
            Key metrics for alerting
        """
        if not self.latencies_ms:
            return {}
        window_seconds = max(self.timestamps) - min(self.timestamps)
        qps = len(self.timestamps) / window_seconds if window_seconds > 0 else 0.0
        return {
            'latency_p50_ms': float(np.percentile(self.latencies_ms, 50)),
            'latency_p95_ms': float(np.percentile(self.latencies_ms, 95)),
            'latency_p99_ms': float(np.percentile(self.latencies_ms, 99)),
            'cache_hit_rate': float(np.mean(self.cache_hits)) * 100,
            'error_rate': float(np.mean(self.errors)) * 100,
            'qps': qps,
            'cost_per_1k_predictions': float(np.mean(self.costs)) * 1000
        }
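Metrics are typically recorded in the request path and scraped periodically. A small usage sketch with made-up numbers:

# Usage (illustrative): record each request, then export aggregates
metrics = InferenceMetrics()
metrics.record_prediction(latency_ms=12.5, cache_hit=True, error=False, cost=0.00001)
metrics.record_prediction(latency_ms=85.0, cache_hit=False, error=False, cost=0.0004)
print(metrics.get_dashboard_metrics())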
SLA Definition
class InferenceSLA:
"""
Define and monitor SLA for inference service
"""
def __init__(self):
self.sla_targets = {
'p99_latency_ms': 100,
'availability': 99.9,
'error_rate': 0.1 # 0.1%
}
def check_sla_compliance(self, metrics: dict) -> dict:
"""
Check if current metrics meet SLA
Returns:
SLA compliance report
"""
compliance = {}
for metric, target in self.sla_targets.items():
actual = metrics.get(metric, 0)
if metric == 'error_rate':
# Lower is better
meets_sla = actual <= target
else:
# Check if within range (e.g., latency or availability)
meets_sla = actual <= target if 'latency' in metric else actual >= target
compliance[metric] = {
'target': target,
'actual': actual,
'meets_sla': meets_sla,
'margin': target - actual if 'latency' in metric or 'error' in metric else actual - target
}
return compliance
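A sketch of checking compliance against current metrics; note the metric names must match the SLA target keys, and the values below are illustrative:

# Usage (illustrative): compare current metrics against the SLA targets
sla = InferenceSLA()
current = {'p99_latency_ms': 85.0, 'availability': 99.95, 'error_rate': 0.05}
report = sla.check_sla_compliance(current)
for metric, info in report.items():
    flag = 'OK' if info['meets_sla'] else 'VIOLATION'
    print(f"{metric}: {info['actual']} vs target {info['target']} -> {flag}")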
Key Takeaways
✅ Batch inference precomputes predictions: cheap per prediction and fast to look up, but results can be hours stale
✅ Real-time inference computes on-demand: fresh features and per-request context, but more expensive to serve
✅ Hybrid approach combines both for optimal cost/performance
✅ Multi-tier caching (memory → Redis → compute) optimizes latency
✅ Prediction warming preloads cache for likely requests
✅ Conditional updates reduce unnecessary recomputation
✅ Graceful degradation ensures reliability via fallback strategies
✅ Latency vs cost is the fundamental trade-off
✅ Feature freshness often determines the choice
✅ Most systems use hybrid: batch for bulk, real-time for edge cases
✅ Cache hit rate critical metric for hybrid systems
✅ SLA monitoring ensures service quality
Originally published at: arunbaby.com/ml-system-design/0005-batch-realtime-inference