28 minute read

Design efficient caching layers for ML systems to reduce latency, save compute costs, and improve user experience at scale.

TL;DR

Caching is essential for achieving sub-100ms ML serving latency. This post covers cache eviction policies (LRU, LFU, TTL), distributed caching with Redis, multi-level cache architectures combining local and shared caches, cache warming strategies, invalidation via pub/sub, feature store caching, stampede prevention with locking and probabilistic early expiration, write-through vs write-back patterns, and cache performance analysis with size optimization. For the broader ML system design series, caching sits at the intersection of infrastructure and model performance. See also how model serving architecture integrates caching at every layer.

(Image: a three-tier shelf of transparent containers, with small fast-access jars on the top shelf.)

Introduction

Caching temporarily stores computed results so that future requests can be served faster. In ML systems, caching is critical for several reasons.

Why caching matters:

  • Latency reduction: ms instead of seconds for predictions
  • Cost savings: Avoid expensive model inference
  • Scalability: Handle more requests with same resources
  • Availability: Serve cached results if model service is down

Common caching scenarios in ML:

  • Model predictions (feature → prediction)
  • Feature computations (raw data → engineered features)
  • Embeddings (entity → vector representation)
  • Model artifacts (model weights, config)
  • Training data (preprocessed datasets)
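All of these scenarios follow the same basic pattern: derive a key from the inputs, check the cache, and only compute (and store) on a miss. As a quick orientation before the full implementations below, here is a minimal sketch of that pattern as a decorator; the `cached` helper, `make_key`, and the toy `predict` function are illustrative placeholders rather than any particular library's API.

import functools

def cached(cache: dict, make_key):
    """Minimal cache-aside decorator: check the cache, compute on miss, store the result."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            key = make_key(*args)
            if key in cache:          # hit: skip the expensive call
                return cache[key]
            result = fn(*args)        # miss: compute...
            cache[key] = result       # ...and remember it for next time
            return result
        return wrapper
    return decorator

prediction_cache = {}

@cached(prediction_cache, make_key=lambda features: hash(tuple(features)))
def predict(features):
    # Stand-in for an expensive model call
    return sum(features) > 0

print(predict([0.2, -0.5, 1.1]))  # computed
print(predict([0.2, -0.5, 1.1]))  # served from cache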

Cache Hierarchy

┌────────────────────────────────────────────────────┐
│ Client/Browser │
│ (Local Storage, Cookies) │
└──────────────────────┬─────────────────────────────┘
 │
 ▼
┌────────────────────────────────────────────────────┐
│ CDN Cache │
│ (CloudFlare, Akamai, CloudFront) │
└──────────────────────┬─────────────────────────────┘
 │
 ▼
┌────────────────────────────────────────────────────┐
│ Application Cache │
│ (Redis, Memcached, Local) │
└──────────────────────┬─────────────────────────────┘
 │
 ▼
┌────────────────────────────────────────────────────┐
│ ML Model Service │
│ (TensorFlow Serving, etc.) │
└──────────────────────┬─────────────────────────────┘
 │
 ▼
┌────────────────────────────────────────────────────┐
│ Database │
│ (PostgreSQL, MongoDB, etc.) │
└────────────────────────────────────────────────────┘
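Each layer answers the requests it can and passes the rest down. Below is a minimal sketch of that read path for the application-side layers only; the client and CDN tiers live outside the application, and the two dicts are stand-ins for an in-process cache and Redis (a full MultiLevelCache implementation appears later in this post).

# Sketch of the read path through the application-side layers of the hierarchy
local_cache = {}   # per-instance, fastest
shared_cache = {}  # stands in for Redis, shared across instances

def run_model(key):
    # Placeholder for expensive inference or a database read
    return f"prediction_for_{key}"

def get_with_hierarchy(key):
    if key in local_cache:                 # 1. application-local cache
        return local_cache[key]
    if key in shared_cache:                # 2. shared cache (e.g. Redis)
        value = shared_cache[key]
        local_cache[key] = value           #    promote to the local layer
        return value
    value = run_model(key)                 # 3. fall through to the model service
    shared_cache[key] = value
    local_cache[key] = value
    return value

print(get_with_hierarchy("user_42"))  # computed
print(get_with_hierarchy("user_42"))  # served from the local layer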

Cache Eviction Policies

LRU (Least Recently Used)

Most common for ML systems

from collections import OrderedDict

class LRUCache:
    """
    LRU Cache implementation

    Evicts least recently used items when capacity is reached
    """

    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """
        Get value and mark as recently used

        Time: O(1)
        """
        if key not in self.cache:
            return None

        # Move to end (most recent)
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        """
        Put key-value pair

        Time: O(1)
        """
        if key in self.cache:
            # Update and move to end
            self.cache.move_to_end(key)

        self.cache[key] = value

        # Evict if over capacity
        if len(self.cache) > self.capacity:
            # Remove first item (least recently used)
            self.cache.popitem(last=False)

    def stats(self):
        """Get cache statistics"""
        return {
        'size': len(self.cache),
        'capacity': self.capacity,
        'utilization': len(self.cache) / self.capacity
        }

# Usage
cache = LRUCache(capacity=1000)

# Cache predictions
def get_prediction_cached(features, model):
    cache_key = hash(tuple(features))

    # Check cache
    cached_result = cache.get(cache_key)
    if cached_result is not None:
        return cached_result

    # Compute prediction
    prediction = model.predict([features])[0]

    # Cache result
    cache.put(cache_key, prediction)

    return prediction

LFU (Least Frequently Used)

Good for skewed access patterns

from collections import defaultdict
import heapq

class LFUCache:
    """
    LFU Cache - evicts least frequently used items

    Better for "hot" items that are accessed repeatedly
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = {} # key -> (value, frequency)
        self.freq_map = defaultdict(set) # frequency -> set of keys
        self.min_freq = 0
        self.access_count = 0

    def get(self, key):
        """Get value and increment frequency"""
        if key not in self.cache:
            return None

        value, freq = self.cache[key]

        # Update frequency
        self.freq_map[freq].remove(key)
        if not self.freq_map[freq] and freq == self.min_freq:
            self.min_freq += 1

        new_freq = freq + 1
        self.freq_map[new_freq].add(key)
        self.cache[key] = (value, new_freq)

        return value

    def put(self, key, value):
        """Put key-value pair"""
        if self.capacity == 0:
            return

        if key in self.cache:
            # Update existing key
            _, freq = self.cache[key]
            self.cache[key] = (value, freq)
            self.get(key)  # Update frequency
            return

        # Evict if at capacity
        if len(self.cache) >= self.capacity:
            # Remove item with minimum frequency
            evict_key = next(iter(self.freq_map[self.min_freq]))
            self.freq_map[self.min_freq].remove(evict_key)
            del self.cache[evict_key]

        # Add new key
        self.cache[key] = (value, 1)
        self.freq_map[1].add(key)
        self.min_freq = 1

    def get_top_k(self, k: int):
        """Get top k most frequently accessed items"""
        items = [(freq, key) for key, (val, freq) in self.cache.items()]
        return heapq.nlargest(k, items)

# Usage for embeddings (frequently accessed)
embedding_cache = LFUCache(capacity=10000)

def get_embedding_cached(entity_id, embedding_model):
    cached_emb = embedding_cache.get(entity_id)
    if cached_emb is not None:
        return cached_emb

    embedding = embedding_model.encode(entity_id)
    embedding_cache.put(entity_id, embedding)

    return embedding

TTL (Time-To-Live) Cache

Good for time-sensitive data

import time

class TTLCache:
    """
    TTL Cache - items expire after specified time

    Perfect for:
        - User sessions
        - Real-time features (stock prices, weather)
        - Model predictions that become stale
        """

    def __init__(self, default_ttl_seconds=3600):
        self.cache = {} # key -> (value, expiration_time)
        self.default_ttl = default_ttl_seconds

    def get(self, key):
        """Get value if not expired"""
        if key not in self.cache:
            return None

        value, expiration = self.cache[key]

        # Check expiration
        if time.time() > expiration:
            del self.cache[key]
            return None

        return value

    def put(self, key, value, ttl=None):
        """Put key-value pair with TTL"""
        if ttl is None:
            ttl = self.default_ttl

        expiration = time.time() + ttl
        self.cache[key] = (value, expiration)

    def cleanup(self):
        """Remove expired entries"""
        current_time = time.time()
        expired_keys = [
            k for k, (v, exp) in self.cache.items()
            if current_time > exp
        ]

        for key in expired_keys:
            del self.cache[key]

        return len(expired_keys)

# Usage for time-sensitive predictions
prediction_cache = TTLCache(default_ttl_seconds=300)  # 5 minutes

def predict_stock_price(symbol, model):
    """Predictions expire quickly for real-time data"""
    cached = prediction_cache.get(symbol)
    if cached is not None:
        return cached

    prediction = model.predict(symbol)
    prediction_cache.put(symbol, prediction, ttl=60)  # 1 minute TTL

    return prediction

Distributed Caching

Redis-Based Cache

import redis
import json
import pickle
import hashlib

class RedisMLCache:
    """
    Redis-based cache for ML predictions

    Features:
        - Distributed across multiple servers
        - Persistence
        - TTL support
        - Pub/sub for cache invalidation
        """

    def __init__(self, host='localhost', port=6379, db=0):
        self.redis_client = redis.Redis(
        host=host,
        port=port,
        db=db,
        decode_responses=False
        )

        self.hits = 0
        self.misses = 0

    def _serialize(self, obj):
        """Serialize Python object"""
        return pickle.dumps(obj)

    def _deserialize(self, data):
        """Deserialize to Python object"""
        if data is None:
            return None
        return pickle.loads(data)

    def _make_key(self, prefix, *args):
        """Generate cache key"""
        # Hash arguments for consistent key
        key_str = f"{prefix}:{':'.join(map(str, args))}"
        return key_str

    def get_prediction(self, model_id, features):
        """
        Get cached prediction

        Args:
            model_id: Model identifier
            features: Feature vector (hashable)

        Returns:
            Cached prediction or None
        """
        # Create cache key
        feature_hash = hashlib.md5(
            str(features).encode()
        ).hexdigest()
        key = self._make_key('prediction', model_id, feature_hash)

        # Get from Redis
        cached = self.redis_client.get(key)

        if cached is not None:
            self.hits += 1
            return self._deserialize(cached)

        self.misses += 1
        return None

    def set_prediction(self, model_id, features, prediction, ttl=3600):
        """Cache prediction with TTL"""
        feature_hash = hashlib.md5(
        str(features).encode()
        ).hexdigest()
        key = self._make_key('prediction', model_id, feature_hash)

        # Serialize and store
        value = self._serialize(prediction)
        self.redis_client.setex(key, ttl, value)

    def get_embedding(self, entity_id):
        """Get cached embedding"""
        key = self._make_key('embedding', entity_id)
        cached = self.redis_client.get(key)

        if cached:
            self.hits += 1
            # Embeddings stored as JSON arrays
            return json.loads(cached)

        self.misses += 1
        return None

    def set_embedding(self, entity_id, embedding, ttl=None):
        """Cache embedding"""
        key = self._make_key('embedding', entity_id)
        value = json.dumps(embedding.tolist() if hasattr(embedding, 'tolist') else embedding)

        if ttl:
            self.redis_client.setex(key, ttl, value)
        else:
            self.redis_client.set(key, value)

    def invalidate_model(self, model_id):
        """Invalidate all predictions for a model (SCAN + DEL)"""
        pattern = self._make_key('prediction', model_id, '*')
        cursor = 0
        total_deleted = 0

        while True:
            cursor, keys = self.redis_client.scan(cursor=cursor, match=pattern, count=1000)
            if keys:
                total_deleted += self.redis_client.delete(*keys)
            if cursor == 0:
                break

        return total_deleted

    def get_stats(self):
        """Get cache statistics"""
        total_requests = self.hits + self.misses
        hit_rate = self.hits / total_requests if total_requests > 0 else 0

        return {
        'hits': self.hits,
        'misses': self.misses,
        'hit_rate': hit_rate,
        'total_keys': self.redis_client.dbsize()
        }

# Usage
cache = RedisMLCache(host='localhost', port=6379)

def predict_with_cache(features, model, model_id):
    """Predict with Redis caching"""
    # Check cache
    cached = cache.get_prediction(model_id, features)
    if cached is not None:
        return cached

    # Compute prediction
    prediction = model.predict([features])[0]

    # Cache result
    cache.set_prediction(model_id, features, prediction, ttl=3600)

    return prediction

# Check cache performance
stats = cache.get_stats()
print(f"Cache hit rate: {stats['hit_rate']:.2%}")

Multi-Level Cache

class MultiLevelCache:
    """
    Multi-level caching with L1 (local) and L2 (Redis)

    Pattern:
        1. Check L1 (in-memory, fastest)
        2. If miss, check L2 (Redis, shared)
        3. If miss, compute and populate both levels
        """

    def __init__(self, l1_capacity=1000, redis_host='localhost'):
        # L1: Local LRU cache
        self.l1 = LRUCache(capacity=l1_capacity)

        # L2: Redis cache
        self.l2 = RedisMLCache(host=redis_host)

        self.l1_hits = 0
        self.l2_hits = 0
        self.misses = 0

    def get(self, key):
        """Get value from multi-level cache"""
        # Try L1
        value = self.l1.get(key)
        if value is not None:
            self.l1_hits += 1
            return value

        # Try L2
        value = self.l2.redis_client.get(key)
        if value is not None:
            self.l2_hits += 1

            # Populate L1
            value = self.l2._deserialize(value)
            self.l1.put(key, value)

            return value

        # Miss
        self.misses += 1
        return None

    def put(self, key, value, ttl=3600):
        """Put value in both cache levels"""
        # Store in L1
        self.l1.put(key, value)

        # Store in L2
        self.l2.redis_client.setex(
        key,
        ttl,
        self.l2._serialize(value)
        )

    def get_stats(self):
        """Get multi-level cache statistics"""
        total = self.l1_hits + self.l2_hits + self.misses

        return {
        'l1_hits': self.l1_hits,
        'l2_hits': self.l2_hits,
        'misses': self.misses,
        'total_requests': total,
        'l1_hit_rate': self.l1_hits / total if total > 0 else 0,
        'l2_hit_rate': self.l2_hits / total if total > 0 else 0,
        'overall_hit_rate': (self.l1_hits + self.l2_hits) / total if total > 0 else 0
        }

# Usage
ml_cache = MultiLevelCache(l1_capacity=1000, redis_host='localhost')

def get_user_embedding(user_id, embedding_model):
    """Get user embedding with multi-level caching"""
    key = f"user_emb:{user_id}"

    # Try cache
    embedding = ml_cache.get(key)
    if embedding is not None:
        return embedding

    # Compute
    embedding = embedding_model.encode(user_id)

    # Cache
    ml_cache.put(key, embedding, ttl=3600)

    return embedding

Cache Warming Strategies

Proactive Cache Warming

import threading
import time
from queue import Queue

class CacheWarmer:
    """
    Proactively warm cache before requests arrive

    Strategies:
        1. Popular items (based on historical data)
        2. Scheduled warmup (daily, hourly)
        3. Predictive warmup (ML-based)
        """

    def __init__(self, cache, compute_fn):
        self.cache = cache
        self.compute_fn = compute_fn

        self.warmup_queue = Queue()
        self.is_running = False

    def warm_popular_items(self, items, priority='high'):
        """Warm cache with popular items"""
        print(f"Warming {len(items)} popular items...")

        for item in items:
            key, args = item

            # Check if already cached
            if self.cache.get(key) is not None:
                continue

            # Compute and cache
            try:
                result = self.compute_fn(*args)
                self.cache.put(key, result)
            except Exception as e:
                print(f"Error warming {key}: {e}")

    def warm_on_schedule(self, items, interval_seconds=3600):
        """Periodically warm cache"""
        def warmup_worker():
            while self.is_running:
                self.warm_popular_items(items)
                time.sleep(interval_seconds)

        self.is_running = True
        worker = threading.Thread(target=warmup_worker, daemon=True)
        worker.start()

    def stop(self):
        """Stop scheduled warmup"""
        self.is_running = False

# Usage
def compute_recommendation(user_id, model):
    """Expensive recommendation computation"""
    return model.recommend(user_id, n=10)

cache = LRUCache(capacity=10000)
warmer = CacheWarmer(cache, compute_recommendation)

# Warm cache with top 1000 users
popular_users = get_top_1000_active_users()
items = [
    (f"rec:{user_id}", (user_id, recommendation_model))
    for user_id in popular_users
]

warmer.warm_popular_items(items)

# Or schedule periodic warmup
warmer.warm_on_schedule(items, interval_seconds=3600)

Cache Invalidation

Push-Based Invalidation

import redis
import threading

class CacheInvalidator:
    """
    Cache invalidation using Redis Pub/Sub

    Pattern:
        - When model updates, publish invalidation message
        - All cache instances subscribe and clear relevant entries
        """

    def __init__(self, redis_host='localhost'):
        self.redis_pub = redis.Redis(host=redis_host)
        self.redis_sub = redis.Redis(host=redis_host)

        self.cache = {}
        self.invalidation_count = 0

    def subscribe_to_invalidations(self, channel='cache:invalidate'):
        """Subscribe to invalidation messages"""
        pubsub = self.redis_sub.pubsub()
        pubsub.subscribe(channel)

        def listen():
            for message in pubsub.listen():
                if message['type'] == 'message':
                    self._handle_invalidation(message['data'])

        # Start listener thread
        listener = threading.Thread(target=listen, daemon=True)
        listener.start()

    def _handle_invalidation(self, message):
        """Handle invalidation message"""
        # Message is a key prefix, e.g. a model_id
        invalidation_key = message.decode('utf-8')

        # Remove matching cache entries
        keys_to_remove = [
            k for k in self.cache.keys()
            if k.startswith(invalidation_key)
        ]

        for key in keys_to_remove:
            del self.cache[key]

        self.invalidation_count += len(keys_to_remove)
        print(f"Invalidated {len(keys_to_remove)} cache entries")

    def invalidate_model(self, model_id):
        """Publish invalidation message (model_id is used as a key prefix)"""
        self.redis_pub.publish('cache:invalidate', str(model_id))

# Usage
invalidator = CacheInvalidator()
invalidator.subscribe_to_invalidations()

# When model is updated
def update_model(model_id, new_model):
    """Update model and invalidate cache"""
    # Deploy new model
    deploy_model(new_model)

    # Invalidate all predictions for this model
    invalidator.invalidate_model(model_id)

Feature Store Caching

import json

class FeatureStoreCache:
    """
    Caching layer for feature store

    Features:
        - Cache precomputed features
        - Batch feature retrieval
        - Freshness guarantees
        """

    def __init__(self, redis_client, ttl=3600):
        self.redis = redis_client
        self.ttl = ttl

    def get_features(self, entity_ids, feature_names):
        """
        Get features for multiple entities (batch)

        Args:
            entity_ids: List of entity IDs
            feature_names: List of feature names

        Returns:
            Dict of entity_id -> feature_dict
        """
        results = {}
        cache_misses = []

        # Try cache first
        for entity_id in entity_ids:
            cache_key = f"features:{entity_id}"
            cached = self.redis.get(cache_key)

            if cached:
                # Parse cached features
                features = json.loads(cached)

                # Filter to requested features
                filtered = {
                    fname: features[fname]
                    for fname in feature_names
                    if fname in features
                }

                if len(filtered) == len(feature_names):
                    results[entity_id] = filtered
                else:
                    cache_misses.append(entity_id)
            else:
                cache_misses.append(entity_id)

        # Compute missing features
        if cache_misses:
            computed = self._compute_features(cache_misses, feature_names)

            # Cache computed features
            for entity_id, features in computed.items():
                self._cache_features(entity_id, features)
                results[entity_id] = features

        return results

    def _compute_features(self, entity_ids, feature_names):
        """Compute features from feature store"""
        # Call actual feature store
        return compute_features_batch(entity_ids, feature_names)

    def _cache_features(self, entity_id, features):
        """Cache features for entity"""
        cache_key = f"features:{entity_id}"
        self.redis.setex(
        cache_key,
        self.ttl,
        json.dumps(features)
        )

    def invalidate_entity(self, entity_id):
        """Invalidate features for entity"""
        cache_key = f"features:{entity_id}"
        self.redis.delete(cache_key)

# Usage
feature_cache = FeatureStoreCache(redis_client, ttl=300)

# Get features for batch of users
user_ids = [123, 456, 789]
feature_names = ['age', 'location', 'purchase_count']

features = feature_cache.get_features(user_ids, feature_names)

Connection to Linked Lists (DSA)

Cache implementations heavily use linked list concepts:

class DoublyLinkedNode:
    """Node for doubly-linked list (used in LRU)"""
    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.prev = None
        self.next = None

class ProductionLRUCache:
    """
    Production LRU cache using doubly-linked list

    Connection to DSA:
    - Uses linked list for maintaining order
    - Pointer manipulation similar to reversal
    - O(1) operations through careful pointer management
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = {}

        # Dummy head and tail
        self.head = DoublyLinkedNode(0, 0)
        self.tail = DoublyLinkedNode(0, 0)
        self.head.next = self.tail
        self.tail.prev = self.head

    def _add_node(self, node):
        """Add node right after head"""
        node.prev = self.head
        node.next = self.head.next

        self.head.next.prev = node
        self.head.next = node

    def _remove_node(self, node):
        """Remove node from list"""
        prev_node = node.prev
        next_node = node.next

        prev_node.next = next_node
        next_node.prev = prev_node

    def _move_to_head(self, node):
        """Move node to head (most recently used)"""
        self._remove_node(node)
        self._add_node(node)

    def _pop_tail(self):
        """Remove least recently used (tail.prev)"""
        res = self.tail.prev
        self._remove_node(res)
        return res

    def get(self, key):
        """Get value"""
        node = self.cache.get(key)
        if not node:
            return -1

        self._move_to_head(node)
        return node.value

    def put(self, key, value):
        """Put key-value"""
        node = self.cache.get(key)

        if node:
            node.value = value
            self._move_to_head(node)
        else:
            new_node = DoublyLinkedNode(key, value)
            self.cache[key] = new_node
            self._add_node(new_node)

            if len(self.cache) > self.capacity:
                tail = self._pop_tail()
                del self.cache[tail.key]

Understanding Cache Performance

Cache Hit Rate Analysis

class CachePerformanceAnalyzer:
    """
    Analyze and optimize cache performance

    Key metrics:
        - Hit rate: % of requests served from cache
        - Miss rate: % of requests requiring computation
        - Latency reduction: Time saved by caching
        - Memory efficiency: Cache size vs hit rate
        """

    def __init__(self):
        self.total_requests = 0
        self.cache_hits = 0
        self.cache_misses = 0

        self.hit_latencies = []
        self.miss_latencies = []

    def record_hit(self, latency_ms):
        """Record cache hit"""
        self.cache_hits += 1
        self.total_requests += 1
        self.hit_latencies.append(latency_ms)

    def record_miss(self, latency_ms):
        """Record cache miss"""
        self.cache_misses += 1
        self.total_requests += 1
        self.miss_latencies.append(latency_ms)

    def get_metrics(self):
        """Calculate performance metrics"""
        if self.total_requests == 0:
            return {}

        hit_rate = self.cache_hits / self.total_requests
        miss_rate = self.cache_misses / self.total_requests

        avg_hit_latency = (
            sum(self.hit_latencies) / len(self.hit_latencies)
            if self.hit_latencies else 0
        )

        avg_miss_latency = (
            sum(self.miss_latencies) / len(self.miss_latencies)
            if self.miss_latencies else 0
        )

        # Calculate latency reduction
        avg_latency_with_cache = (
            hit_rate * avg_hit_latency + miss_rate * avg_miss_latency
        )

        latency_reduction = (
            (avg_miss_latency - avg_latency_with_cache) / avg_miss_latency
            if avg_miss_latency > 0 else 0
        )

        return {
            'total_requests': self.total_requests,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': hit_rate,
            'miss_rate': miss_rate,
            'avg_hit_latency_ms': avg_hit_latency,
            'avg_miss_latency_ms': avg_miss_latency,
            'avg_overall_latency_ms': avg_latency_with_cache,
            'latency_reduction_pct': latency_reduction * 100
        }

    def print_report(self):
        """Print performance report"""
        metrics = self.get_metrics()

        print("\n" + "="*60)
        print("CACHE PERFORMANCE REPORT")
        print("="*60)
        print(f"Total Requests: {metrics['total_requests']:,}")
        print(f"Cache Hits: {metrics['cache_hits']:,}")
        print(f"Cache Misses: {metrics['cache_misses']:,}")
        print(f"Hit Rate: {metrics['hit_rate']:.2%}")
        print(f"Miss Rate: {metrics['miss_rate']:.2%}")
        print(f"\nLatency Analysis:")
        print(f" Cache Hit: {metrics['avg_hit_latency_ms']:.2f} ms")
        print(f" Cache Miss: {metrics['avg_miss_latency_ms']:.2f} ms")
        print(f" Overall Average: {metrics['avg_overall_latency_ms']:.2f} ms")
        print(f" Latency Reduction: {metrics['latency_reduction_pct']:.1f}%")
        print("="*60)

# Usage example
analyzer = CachePerformanceAnalyzer()

# Simulate requests
import random
import time

cache = LRUCache(capacity=100)

for i in range(1000):
    key = f"key_{random.randint(1, 150)}"

    # Check cache
    start = time.perf_counter()
    value = cache.get(key)

    if value is not None:
        # Cache hit (fast)
        latency = (time.perf_counter() - start) * 1000
        analyzer.record_hit(latency)
    else:
        # Cache miss (slow - simulate computation)
        time.sleep(0.001)  # 1ms computation
        latency = (time.perf_counter() - start) * 1000
        analyzer.record_miss(latency)

        # Store in cache
        cache.put(key, f"value_{key}")

analyzer.print_report()

Cache Size Optimization

class CacheSizeOptimizer:
    """
    Find optimal cache size for given workload

    Trade-off: Larger cache = higher hit rate but more memory
    """

    def __init__(self, workload):
        """
        Args:
            workload: List of access patterns (keys)
        """
        self.workload = workload

    def find_optimal_size(self, max_size=10000, step=100):
        """
        Test different cache sizes

        Returns optimal size based on diminishing returns
        """
        results = []

        print("Testing cache sizes...")
        print(f"{'Size':<10} {'Hit Rate':<12} {'Marginal Gain':<15}")
        print("-" * 40)

        prev_hit_rate = 0

        for size in range(step, max_size + 1, step):
            hit_rate = self._simulate_cache(size)
            marginal_gain = hit_rate - prev_hit_rate

            results.append({
            'size': size,
            'hit_rate': hit_rate,
            'marginal_gain': marginal_gain
            })

            print(f"{size:<10} {hit_rate:<12.2%} {marginal_gain:<15.4%}")

            prev_hit_rate = hit_rate

            # Stop if marginal gain is too small
            if marginal_gain < 0.001: # 0.1% gain
                print(f"\nDiminishing returns detected at size {size}")
                break

        return results

    def _simulate_cache(self, size):
        """Simulate cache with given size"""
        cache = LRUCache(capacity=size)
        hits = 0

        for key in self.workload:
            if cache.get(key) is not None:
                hits += 1
            else:
                cache.put(key, True)

        return hits / len(self.workload)

# Generate workload (Zipf distribution - realistic for many applications)
import numpy as np

def generate_zipf_workload(n_items=1000, n_requests=10000, alpha=1.5):
    """
    Generate Zipf-distributed workload

    Zipf law: Some items are accessed much more frequently
    (80/20 rule, power law distribution)
    """
    # Zipf distribution
    probabilities = np.array([1.0 / (i ** alpha) for i in range(1, n_items + 1)])
    probabilities /= probabilities.sum()

    # Generate requests
    workload = np.random.choice(
        [f"key_{i}" for i in range(n_items)],
        size=n_requests,
        p=probabilities
    )

    return workload.tolist()

# Find optimal cache size
workload = generate_zipf_workload(n_items=1000, n_requests=10000)
optimizer = CacheSizeOptimizer(workload)
results = optimizer.find_optimal_size(max_size=500, step=50)

# Plot results
import matplotlib.pyplot as plt

sizes = [r['size'] for r in results]
hit_rates = [r['hit_rate'] for r in results]

plt.figure(figsize=(10, 6))
plt.plot(sizes, hit_rates, marker='o')
plt.xlabel('Cache Size')
plt.ylabel('Hit Rate')
plt.title('Cache Size vs Hit Rate')
plt.grid(True)
plt.savefig('cache_size_optimization.png')

Advanced Caching Patterns

Write-Through vs Write-Back Cache

class WriteThroughCache:
    """
    Write-through cache: Write to cache and database simultaneously

    Pros:
    - Data consistency
    - Simple to implement

    Cons:
    - Slower writes
    - Every write hits database
    """

    def __init__(self, cache, database):
        self.cache = cache
        self.database = database

    def get(self, key):
        """Read with cache"""
        # Try cache first
        value = self.cache.get(key)
        if value is not None:
            return value

        # Cache miss: read from database
        value = self.database.get(key)
        if value is not None:
            self.cache.put(key, value)

        return value

    def put(self, key, value):
        """Write to both cache and database"""
        # Write to database first
        self.database.put(key, value)

        # Then update cache
        self.cache.put(key, value)

class WriteBackCache:
    """
    Write-back cache: Write to cache only, flush to database later

    Pros:
    - Fast writes
    - Batching possible

    Cons:
    - Risk of data loss
    - More complex
    """

    def __init__(self, cache, database, flush_interval=5):
        self.cache = cache
        self.database = database
        self.flush_interval = flush_interval

        self.dirty_keys = set()
        self.last_flush = time.time()

    def get(self, key):
        """Read with cache"""
        value = self.cache.get(key)
        if value is not None:
            return value

        value = self.database.get(key)
        if value is not None:
            self.cache.put(key, value)

        return value

    def put(self, key, value):
        """Write to cache only"""
        self.cache.put(key, value)
        self.dirty_keys.add(key)

        # Check if we need to flush
        if time.time() - self.last_flush > self.flush_interval:
            self.flush()

    def flush(self):
        """Flush dirty keys to database"""
        if not self.dirty_keys:
            return

            print(f"Flushing {len(self.dirty_keys)} dirty keys...")

            for key in self.dirty_keys:
                value = self.cache.get(key)
                if value is not None:
                    self.database.put(key, value)

                    self.dirty_keys.clear()
                    self.last_flush = time.time()

# Example database simulation
class SimpleDatabase:
    def __init__(self):
        self.data = {}
        self.read_count = 0
        self.write_count = 0

    def get(self, key):
        self.read_count += 1
        time.sleep(0.001)  # Simulate latency
        return self.data.get(key)

    def put(self, key, value):
        self.write_count += 1
        time.sleep(0.001)  # Simulate latency
        self.data[key] = value

# Compare write-through vs write-back
db1 = SimpleDatabase()
cache1 = LRUCache(capacity=100)
write_through = WriteThroughCache(cache1, db1)

db2 = SimpleDatabase()
cache2 = LRUCache(capacity=100)
write_back = WriteBackCache(cache2, db2)

# Benchmark writes
import time

# Write-through
start = time.time()
for i in range(100):
    write_through.put(f"key_{i}", f"value_{i}")
wt_time = time.time() - start

# Write-back
start = time.time()
for i in range(100):
    write_back.put(f"key_{i}", f"value_{i}")
write_back.flush()  # Final flush
wb_time = time.time() - start

print(f"Write-through: {wt_time:.3f}s, DB writes: {db1.write_count}")
print(f"Write-back: {wb_time:.3f}s, DB writes: {db2.write_count}")

Cache Aside Pattern

class CacheAsidePattern:
    """
    Cache-aside (lazy loading): Application manages cache

    Most common pattern for ML systems

    Flow:
        1. Check cache
        2. If miss, query database
        3. Store in cache
        4. Return result
        """

    def __init__(self, cache, database):
        self.cache = cache
        self.database = database

        self.stats = {
        'reads': 0,
        'cache_hits': 0,
        'cache_misses': 0,
        'writes': 0
        }

    def get(self, key):
        """
        Get with cache-aside pattern

        Application is responsible for loading cache
        """
        self.stats['reads'] += 1

        # Try cache first
        value = self.cache.get(key)
        if value is not None:
            self.stats['cache_hits'] += 1
            return value

        # Cache miss: load from database
        self.stats['cache_misses'] += 1
        value = self.database.get(key)

        if value is not None:
            # Populate cache for next time
            self.cache.put(key, value)

        return value

    def put(self, key, value):
        """
        Write to database, invalidate cache

        Simple approach: Just write to DB and remove from cache
        Next read will repopulate
        """
        self.stats['writes'] += 1

        # Write to database
        self.database.put(key, value)

        # Invalidate cache entry
        # (Could also update cache here - depends on use case)
        if key in self.cache.cache:
            del self.cache.cache[key]

    def get_stats(self):
        """Get cache statistics"""
        hit_rate = (
        self.stats['cache_hits'] / self.stats['reads']
        if self.stats['reads'] > 0 else 0
        )

        return {
        **self.stats,
        'hit_rate': hit_rate
        }

# Usage for ML predictions
class MLPredictionService:
    """
    ML prediction service with cache-aside pattern
    """

    def __init__(self, model, cache_capacity=1000):
        self.model = model
        self.cache = LRUCache(capacity=cache_capacity)

        # Fake database for persisted predictions; anything with .get/.put works,
        # e.g. the SimpleDatabase class from the write-through/write-back example
        self.prediction_db = SimpleDatabase()

        self.pattern = CacheAsidePattern(
            self.cache,
            self.prediction_db
        )

    def predict(self, features):
        """
        Predict with caching

        Args:
            features: Feature vector (tuple for hashability)

        Returns:
            Prediction
        """
        # Create cache key from features
        cache_key = hash(features)

        # Try cache-aside pattern
        cached_prediction = self.pattern.get(cache_key)
        if cached_prediction is not None:
            return cached_prediction

        # Compute prediction (expensive)
        prediction = self.model.predict([features])[0]

        # Store in database and cache
        self.pattern.put(cache_key, prediction)

        return prediction

    def get_cache_stats(self):
        """Get caching statistics"""
        return self.pattern.get_stats()

# Example usage
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# Train simple model
X_train = np.random.randn(100, 5)
y_train = (X_train.sum(axis=1) > 0).astype(int)
model = RandomForestClassifier(n_estimators=10)
model.fit(X_train, y_train)

# Create prediction service
service = MLPredictionService(model, cache_capacity=100)

# Make predictions (some repeated)
for _ in range(1000):
    # Generate features (with some repetition)
    features = tuple(np.random.randint(0, 10, size=5))
    prediction = service.predict(features)

print("Cache statistics:")
print(service.get_cache_stats())

Cache Stampede Prevention

Problem: Thundering Herd

import threading
import time

class CacheStampedeProtection:
    """
    Prevent cache stampede (thundering herd)

    Problem:
    - Cache entry expires
    - Many requests try to regenerate simultaneously
    - Database/model gets overwhelmed

    Solution:
    - Use locking to ensure only one request regenerates
    - Others wait for that request to complete
    """

    def __init__(self, cache, compute_fn):
        self.cache = cache
        self.compute_fn = compute_fn

        # Lock for each key
        self.locks = {}
        self.master_lock = threading.Lock()

    def get(self, key):
        """
        Get with stampede protection

        Uses double-check locking pattern
        """
        # First check: Try cache (no lock)
        value = self.cache.get(key)
        if value is not None:
            return value

        # Get or create lock for this key
        with self.master_lock:
            if key not in self.locks:
                self.locks[key] = threading.Lock()
            key_lock = self.locks[key]

        # Acquire key-specific lock
        with key_lock:
            # Second check: Try cache again (another thread might have filled it)
            value = self.cache.get(key)
            if value is not None:
                return value

            # Compute value (only one thread does this)
            print(f"Computing value for {key} (thread: {threading.current_thread().name})")
            value = self.compute_fn(key)

            # Store in cache
            self.cache.put(key, value)

            return value

# Demo: Simulate stampede
def expensive_computation(key):
    """Simulate expensive computation"""
    time.sleep(0.1)  # 100ms
    return f"computed_value_for_{key}"

cache = LRUCache(capacity=100)
protector = CacheStampedeProtection(cache, expensive_computation)

# Simulate stampede: 10 threads requesting same key
def make_request(key, results, index):
    start = time.time()
    result = protector.get(key)
    duration = time.time() - start
    results[index] = duration

results = [0] * 10
threads = []

# Clear cache to force computation
cache = LRUCache(capacity=100)
protector.cache = cache

print("Simulating cache stampede for key 'popular_item'...")
start_time = time.time()

for i in range(10):
    t = threading.Thread(
        target=make_request,
        args=('popular_item', results, i),
        name=f"Thread-{i}"
    )
    threads.append(t)
    t.start()

for t in threads:
    t.join()

total_time = time.time() - start_time

print(f"\nTotal time: {total_time:.3f}s")
print(f"Average request time: {sum(results)/len(results):.3f}s")
print(f"Max request time: {max(results):.3f}s")
print(f"Min request time: {min(results):.3f}s")
print("\nWith protection, only one thread computed (others waited)")

Probabilistic Early Expiration

class ProbabilisticCache:
    """
    Probabilistic early expiration to prevent stampede

    Idea: Refresh cache before expiration with increasing probability
    This spreads out refresh operations
    """

    def __init__(self, cache, compute_fn, ttl=60, beta=1.0):
        """
        Args:
            ttl: Time to live in seconds
            beta: Controls early expiration probability (higher = earlier refresh)
        """
        self.cache = cache
        self.compute_fn = compute_fn
        self.ttl = ttl
        self.beta = beta

        # Track insertion times and recomputation durations (for XFetch)
        self.insertion_times = {}
        self.compute_times = {}

    def get(self, key):
        """
        Get with probabilistic early expiration (XFetch)

        Refresh early if:
            age - delta * beta * log(random()) >= ttl
        where delta is how long the last recomputation took.
        log(random()) is negative, so a refresh becomes increasingly
        likely as the entry's age approaches the TTL.
        """
        import random
        import math

        # Check cache
        value = self.cache.get(key)

        if value is not None and key in self.insertion_times:
            # Calculate age
            age = time.time() - self.insertion_times[key]
            delta = self.compute_times.get(key, 0.01)

            # XFetch algorithm
            if age - delta * self.beta * math.log(random.random()) >= self.ttl:
                # Refresh early
                print(f"Early refresh for {key} (age: {age:.1f}s)")
                value = self._refresh(key)

            return value

        # Cache miss or expired
        return self._refresh(key)

    def _refresh(self, key):
        """Refresh cache entry and record how long recomputation took"""
        start = time.time()
        value = self.compute_fn(key)
        self.compute_times[key] = time.time() - start
        self.cache.put(key, value)
        self.insertion_times[key] = time.time()
        return value

# Demo
def compute_value(key):
    time.sleep(0.01)
    return f"value_{key}_{time.time()}"

pcache = ProbabilisticCache(
    LRUCache(capacity=100),
    compute_value,
    ttl=5,  # 5 second TTL
    beta=1.0
)

# Access same key multiple times
for i in range(20):
    value = pcache.get('test_key')
    time.sleep(0.3)  # 300ms between requests

Distributed Cache Challenges

Cache Consistency

import json
import pickle
import threading

class DistributedCacheCoordinator:
    """
    Coordinate cache across multiple instances

    Challenges:
        1. Keeping caches in sync
        2. Handling partial failures
        3. Eventual consistency
        """

    def __init__(self, redis_client, instance_id):
        self.redis = redis_client
        self.instance_id = instance_id

        # Local L1 cache
        self.local_cache = LRUCache(capacity=1000)

        # Subscribe to invalidation messages
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('cache:invalidate')

        # Start listener thread
        self.listener_thread = threading.Thread(
        target=self._listen_for_invalidations,
        daemon=True
        )
        self.listener_thread.start()

    def get(self, key):
        """
        Get from multi-level cache

        L1 (local) -> L2 (Redis) -> Compute
        """
        # Try local cache
        value = self.local_cache.get(key)
        if value is not None:
            return value

        # Try Redis
        value = self.redis.get(key)
        if value is not None:
            value = pickle.loads(value)
            # Populate local cache
            self.local_cache.put(key, value)
            return value

        return None

    def put(self, key, value, ttl=3600):
        """
        Put in both levels and notify others
        """
        # Store in local cache
        self.local_cache.put(key, value)

        # Store in Redis
        self.redis.setex(key, ttl, pickle.dumps(value))

        # Notify other instances to invalidate their L1
        self.redis.publish(
        'cache:invalidate',
        json.dumps({
        'key': key,
        'source_instance': self.instance_id
        })
        )

    def _listen_for_invalidations(self):
        """Listen for invalidation messages"""
        for message in self.pubsub.listen():
            if message['type'] == 'message':
                data = json.loads(message['data'])

                # Don't invalidate if we sent the message
                if data['source_instance'] != self.instance_id:
                    key = data['key']

                    # Invalidate local cache
                    if key in self.local_cache.cache:
                        del self.local_cache.cache[key]
                        print(f"Invalidated {key} from local cache")

# Usage across multiple instances
# Instance 1
coordinator1 = DistributedCacheCoordinator(redis_client, instance_id='instance1')

# Instance 2
coordinator2 = DistributedCacheCoordinator(redis_client, instance_id='instance2')

# Instance 1 writes
coordinator1.put('shared_key', 'value_from_instance1')

# Instance 2 reads (will get from Redis)
value = coordinator2.get('shared_key')

Key Takeaways

✅ Multiple eviction policies - LRU, LFU, TTL for different use cases
✅ Distributed caching - Redis for shared cache across services
✅ Multi-level caching - L1 (local) + L2 (distributed) for optimal performance
✅ Cache warming - Proactive population of hot items
✅ Invalidation strategies - Push-based and pull-based
✅ Linked list connection - Understanding pointer manipulation helps with cache implementation
✅ Monitor cache metrics - Hit rate, latency, memory usage

FAQ

Which cache eviction policy should I use for ML predictions?

Use LRU for general prediction caching where recent requests are likely to recur. Use LFU for embeddings and features with skewed access patterns where some entities are consistently popular. Use TTL for time-sensitive data like real-time features that become stale.

What is a cache stampede and how do you prevent it?

A cache stampede occurs when a popular cache entry expires and hundreds of concurrent requests all try to recompute it simultaneously, overwhelming the model service. Prevent it with locking so only one request recomputes, or use probabilistic early expiration to spread refreshes over time.

How does multi-level caching work for ML serving?

L1 is an in-memory local cache on each serving instance for sub-millisecond access. L2 is a shared Redis cache for cross-instance hits. On L1 miss, check L2 and populate L1. On L2 miss, compute the prediction and populate both levels.

How do you invalidate cached predictions when a model is updated?

Use Redis pub/sub to broadcast invalidation messages when a model is deployed. Each serving instance subscribes and clears its local cache entries matching the invalidated model. Redis entries can be cleared using SCAN with pattern matching.


Originally published at: arunbaby.com/ml-system-design/0010-caching-strategies

Want to work together?

I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.

Get in touch