
Design a global CDN for ML systems: Edge caching reduces latency from 500ms to 50ms. Critical for real-time predictions worldwide.

TL;DR

A CDN for ML systems places edge servers with multi-level caches (Redis for hot predictions, local SSD for model weights, origin for everything else) close to users worldwide. GeoDNS and health-aware load balancing route each request to the optimal edge server. When models update, a pub/sub invalidation system propagates changes across the network. Edge inference using ONNX models means many predictions never leave the user’s region. This architecture is essential for any ML system serving global users – see also model serving architecture and caching strategies for related patterns.

[Hero image: a spider web covered in morning dew drops]

Problem Statement

Design a Content Delivery Network (CDN) for serving:

  1. ML model inference (predictions at the edge)
  2. Static assets (model weights, configs, embeddings)
  3. API responses (cached predictions, feature data)

Why Do We Need a CDN?

The Core Problem: Distance Creates Latency

Imagine you’re a user in Tokyo trying to access a website hosted in Virginia, USA:

User (Tokyo) ──────── 10,000 km ──────── Server (Virginia)
 ~150ms round-trip time

The physics problem:

  • Light travels at 300,000 km/s in a vacuum
  • Signal in fiber: ~200,000 km/s
  • Tokyo ↔ Virginia: ~10,000 km
  • Theoretical minimum: ~50ms one way, ~100ms round trip
  • Reality with routing and processing: 150-200ms

What if we could serve from Tokyo instead?

User (Tokyo) ── 10 km ── Edge Server (Tokyo)
 ~1-2ms!

That’s roughly a 100x improvement just from being geographically closer!
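
To make the numbers concrete, here is a back-of-the-envelope propagation-delay calculation (a rough sketch; it ignores routing hops, queuing, and handshakes):

FIBER_SPEED_KM_S = 200_000  # approximate signal speed in optical fiber

def min_round_trip_ms(distance_km):
    """Lower bound on round-trip time from distance alone."""
    one_way_s = distance_km / FIBER_SPEED_KM_S
    return 2 * one_way_s * 1000

print(min_round_trip_ms(10_000))  # Tokyo <-> Virginia: ~100 ms floor
print(min_round_trip_ms(10))      # Tokyo <-> local edge: ~0.1 ms floor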

Real-World Impact on ML Systems

Scenario: Real-time recommendation system

Without CDN: Request → US → Model Inference → Response
  Latency: 200ms+ · User experience: noticeable delay, users leave

With CDN: Request → Local Edge → Cached/Local Inference → Response
  Latency: 20-50ms · User experience: feels instant ✓

The business impact:

  • Every 100ms of latency = 1% drop in sales (Amazon study)
  • For ML systems: Users won’t wait for slow predictions
  • CDN makes your ML system feel instant globally
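
To put that rule of thumb in code (the revenue figure below is a made-up example, not data from the study):

# Rule of thumb: every extra 100ms of latency ≈ 1% drop in sales
def estimated_revenue_loss(annual_revenue, extra_latency_ms):
    return annual_revenue * 0.01 * (extra_latency_ms / 100)

# Hypothetical shop with $50M annual revenue and 150ms of avoidable latency
print(f"${estimated_revenue_loss(50_000_000, 150):,.0f} per year")  # ~$750,000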

What CDN Does for You

1. Geographic Distribution Cache content at multiple locations worldwide (edge servers)

2. Intelligent Caching Store frequently accessed content close to users

3. Smart Routing Direct users to the best edge server (closest + healthy + low load)

4. Fault Tolerance If one edge fails, route to another

5. Bandwidth Savings Serve from edge → Less traffic to origin → Lower costs

Requirements

Functional:

  • Serve content from geographically distributed edge locations
  • Cache popular content close to users
  • Route requests to nearest/best edge server
  • Handle cache invalidation and updates
  • Support both static and dynamic content

Non-Functional:

  • Latency: < 50ms p99 for edge hits (vs 200-500ms from origin)
  • Availability: 99.99% uptime (4 minutes downtime/month)
  • Scalability: Handle 1M+ requests/second globally
  • Cache hit rate: > 80% for static content (fewer origin requests)
  • Global coverage: Presence in 50+ regions
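
As a sanity check on the availability target, the allowed-downtime figure is simple arithmetic (sketch, assuming a 30-day month):

# Allowed downtime for a given availability target
def allowed_downtime_minutes_per_month(availability, days=30):
    minutes_per_month = days * 24 * 60
    return minutes_per_month * (1 - availability)

print(allowed_downtime_minutes_per_month(0.9999))  # ~4.3 minutes/month
print(allowed_downtime_minutes_per_month(0.999))   # ~43 minutes/month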

High-Level Architecture

┌─────────────────────────────────────────────────────────────┐
│ USER REQUESTS │
│ 🌍 Asia 🌍 Europe 🌍 Americas 🌍 Africa │
└───────┬────────────┬────────────┬──────────────┬────────────┘
 │ │ │ │
 ↓ ↓ ↓ ↓
┌────────────────────────────────────────────────────────────┐
│ DNS / GLOBAL LOAD BALANCER │
│ • GeoDNS routing │
│ • Health checks │
│ • Latency-based routing │
└────────┬────────────┬────────────┬──────────────┬──────────┘
 │ │ │ │
 ┌────▼────┐ ┌───▼─────┐ ┌───▼─────┐ ┌───▼─────┐
 │ Edge │ │ Edge │ │ Edge │ │ Edge │
 │ Tokyo │ │ London │ │ N.Virginia│ │ Mumbai │
 │ │ │ │ │ │ │ │
 │ L1 Cache│ │ L1 Cache│ │ L1 Cache│ │ L1 Cache│
 │ (Redis) │ │ (Redis) │ │ (Redis) │ │ (Redis) │
 │ │ │ │ │ │ │ │
 │ ML Model│ │ ML Model│ │ ML Model│ │ ML Model│
 └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘
 │ │ │ │
 └────────────┴────────────┴──────────────┘
 │
 ┌───────▼────────┐
 │ ORIGIN SERVERS │
 │ │
 │ • Master models│
 │ • Databases │
 │ • Feature store│
 │ • Object store │
 └────────────────┘

Core Components

1. Edge Servers

Purpose: Serve content from locations close to users

Before we dive into code, let’s understand the concept:

What is an Edge Server?

Think of edge servers like local convenience stores:

  • Origin Server = Central warehouse (far away, has everything)
  • Edge Server = Local store (nearby, has popular items)

When you need milk:

  • Without edge: Drive to warehouse (30 min)
  • With edge: Walk to local store (2 min)

Multi-Level Cache Strategy

Edge servers use multiple cache layers:

Request → L1 Cache (Redis, in-memory) ← Fastest, smallest
 ↓ Miss
 L2 Cache (Disk, local SSD) ← Fast, medium
 ↓ Miss
 Origin Server (Database) ← Slow, largest

Why multiple levels?

  1. L1 (Redis): Hot data, ~1ms access, expensive ($100/GB/month)
  2. L2 (Disk): Warm data, 5-10ms access, cheap ($10/GB/month)
  3. Origin: Cold data, 100-500ms access, cheapest ($0.02/GB/month)

Trade-off: Speed vs Cost vs Capacity

Cache        Speed    Cost     Capacity         Use Case
L1 (Redis)   ~1ms     High     Small (10GB)     Prediction results, hot features
L2 (Disk)    ~10ms    Medium   Medium (100GB)   Model weights, embeddings
Origin       ~200ms   Low      Large (TB+)      Full dataset, historical data
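
One way to reason about this trade-off is expected latency per request: weight each tier's latency by the fraction of requests it serves. A small sketch using the illustrative numbers from the table above (the hit rates are assumptions):

# Expected per-request latency for a multi-level cache
def expected_latency_ms(l1_hit, l2_hit, l1_ms=1, l2_ms=10, origin_ms=200):
    miss = 1 - l1_hit - l2_hit
    return l1_hit * l1_ms + l2_hit * l2_ms + miss * origin_ms

# Assume 60% of requests hit L1, 25% hit L2, 15% go to origin
print(expected_latency_ms(0.60, 0.25))  # ~33 ms on average

With that intuition in place, here is an edge server that implements the multi-level lookup: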
import asyncio
import hashlib
import json
import pickle
import time

import numpy as np
import redis


class DiskCache:
    """Minimal local-disk cache stub for illustration (backed by a dict)."""

    def __init__(self, size_gb=100):
        self.size_gb = size_gb
        self.store = {}

    def get(self, key):
        return self.store.get(key)

    def set(self, key, value):
        self.store[key] = value

    def delete(self, key):
        self.store.pop(key, None)

    def delete_pattern(self, pattern):
        # Naive glob-style pattern matcher
        import fnmatch
        for k in [k for k in self.store if fnmatch.fnmatch(k, pattern)]:
            self.store.pop(k, None)


def load_model(path):
    """Placeholder model loader (swap in onnxruntime in production)."""
    return object()


class EdgeServer:
    """
    CDN edge server

    Components:
        - L1 cache (Redis): Hot content
        - L2 cache (local disk): Warm content
        - ML model: For edge inference
        - Origin client: Fetch misses from origin
    """

    def __init__(self, region, origin_url):
        self.region = region
        self.origin_url = origin_url

        # Multi-level cache
        self.l1_cache = redis.Redis(host='localhost', port=6379)
        self.l2_cache = DiskCache(size_gb=100)

        # ML model for edge inference
        self.model = load_model('model.onnx')

        # Metrics
        self.metrics = EdgeMetrics()

    async def handle_request(self, request):
        """
        Handle incoming request

        Flow:
            1. Check L1 cache (Redis)
            2. Check L2 cache (disk)
            3. Fetch from origin
            4. Update caches
        """
        start_time = time.time()

        # Generate cache key
        cache_key = self._generate_cache_key(request)

        # Try L1 cache
        response = await self._check_l1_cache(cache_key)
        if response:
            self.metrics.record_hit('l1', time.time() - start_time)
            return response

        # Try L2 cache
        response = await self._check_l2_cache(cache_key)
        if response:
            # Promote to L1
            await self._store_l1_cache(cache_key, response)
            self.metrics.record_hit('l2', time.time() - start_time)
            return response

        # Cache miss: fetch from origin
        response = await self._fetch_from_origin(request)

        # Update caches
        await self._store_l1_cache(cache_key, response)
        await self._store_l2_cache(cache_key, response)

        self.metrics.record_miss(time.time() - start_time)

        return response

    async def _check_l1_cache(self, key):
        """Check L1 (Redis) cache"""
        try:
            data = self.l1_cache.get(key)
            if data:
                return pickle.loads(data)
        except Exception as e:
            print(f"L1 cache error: {e}")
        return None

    async def _store_l1_cache(self, key, value, ttl=300):
        """Store in L1 cache with TTL"""
        try:
            self.l1_cache.setex(key, ttl, pickle.dumps(value))
        except Exception as e:
            print(f"L1 cache store error: {e}")

    async def _check_l2_cache(self, key):
        """Check L2 (disk) cache"""
        return self.l2_cache.get(key)

    async def _store_l2_cache(self, key, value):
        """Store in L2 cache"""
        self.l2_cache.set(key, value)

    async def _fetch_from_origin(self, request):
        """Fetch from origin server"""
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.origin_url}{request.path}",
                json=request.data
            ) as response:
                return await response.json()

    def _generate_cache_key(self, request):
        """Generate cache key from request"""
        # Include path and normalized data
        key_data = f"{request.path}:{json.dumps(request.data, sort_keys=True)}"
        return hashlib.md5(key_data.encode()).hexdigest()


class EdgeMetrics:
    """Track edge server metrics"""

    def __init__(self):
        self.l1_hits = 0
        self.l2_hits = 0
        self.misses = 0

        self.l1_latencies = []
        self.l2_latencies = []
        self.miss_latencies = []

    def record_hit(self, level, latency):
        if level == 'l1':
            self.l1_hits += 1
            self.l1_latencies.append(latency)
        elif level == 'l2':
            self.l2_hits += 1
            self.l2_latencies.append(latency)

    def record_miss(self, latency):
        self.misses += 1
        self.miss_latencies.append(latency)

    def get_stats(self):
        total = self.l1_hits + self.l2_hits + self.misses

        return {
            'l1_hit_rate': self.l1_hits / total if total > 0 else 0,
            'l2_hit_rate': self.l2_hits / total if total > 0 else 0,
            'miss_rate': self.misses / total if total > 0 else 0,
            'avg_l1_latency_ms': np.mean(self.l1_latencies) * 1000 if self.l1_latencies else 0,
            'avg_l2_latency_ms': np.mean(self.l2_latencies) * 1000 if self.l2_latencies else 0,
            'avg_miss_latency_ms': np.mean(self.miss_latencies) * 1000 if self.miss_latencies else 0,
        }


class Request:
    """Minimal request object used by the examples in this post."""

    def __init__(self, path, data):
        self.path = path
        self.data = data


# Example usage
edge = EdgeServer(region='us-east-1', origin_url='https://api.example.com')


# Simulate requests
async def simulate_requests():
    for i in range(100):
        request = Request(path='/predict', data={'features': [1, 2, 3, 4, 5]})
        response = await edge.handle_request(request)
        print(f"Request {i}: {response}")

    # Print metrics after the run
    stats = edge.metrics.get_stats()
    print("\nEdge Server Metrics:")
    for key, value in stats.items():
        if 'rate' in key:
            print(f"  {key}: {value:.2%}")
        else:
            print(f"  {key}: {value:.2f}")


# Run
asyncio.run(simulate_requests())

2. Global Load Balancer / GeoDNS

Purpose: Route requests to optimal edge server

import threading
import time

import requests


class EdgeServerInfo:
    """Edge server information"""

    def __init__(self, region, url, location):
        self.region = region
        self.url = url
        self.location = location
        self.metrics = EdgeMetrics()

    def get_current_load(self):
        """Get current server load (0-1)"""
        # In production, query server metrics
        return 0.5  # Placeholder


class HealthChecker:
    """
    Monitor health of edge servers

    Checks:
        - HTTP health endpoint
        - Response time
        - Error rate
    """

    def __init__(self, servers, check_interval=10):
        self.servers = servers
        self.check_interval = check_interval

        self.health_status = {server.region: True for server in servers}
        self.last_check = {server.region: 0 for server in servers}

        self.running = False

    def start(self):
        """Start health checking in background"""
        self.running = True
        self.thread = threading.Thread(target=self._health_check_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop health checking"""
        self.running = False

    def _health_check_loop(self):
        """Check every server, then sleep until the next cycle"""
        while self.running:
            for server in self.servers:
                healthy = self._check_server_health(server)
                self.health_status[server.region] = healthy
                self.last_check[server.region] = time.time()

            time.sleep(self.check_interval)

    def _check_server_health(self, server):
        """Check if server is healthy"""
        try:
            response = requests.get(f"{server.url}/health", timeout=5)

            # Healthy only if the endpoint responds OK and quickly enough
            if response.status_code == 200 and response.elapsed.total_seconds() < 1.0:
                return True
            return False
        except Exception as e:
            print(f"Health check failed for {server.region}: {e}")
            return False

    def get_healthy_servers(self):
        """Get list of healthy servers"""
        return [
            server for server in self.servers
            if self.health_status[server.region]
        ]


class GlobalLoadBalancer:
    """
    Route requests to best edge server

    Routing strategies:
        1. Geographic proximity
        2. Server load
        3. Health status
        4. Network latency
    """

    def __init__(self):
        self.edge_servers = self._discover_edge_servers()
        self.health_checker = HealthChecker(self.edge_servers)

        # Start health checking
        self.health_checker.start()

    def route_request(self, client_ip, request):
        """
        Route request to best edge server

        Args:
            client_ip: Client IP address
            request: Request object

        Returns:
            Best edge server
        """
        # Get client location
        client_location = self._geolocate_ip(client_ip)

        # Get healthy edge servers
        healthy_servers = self.health_checker.get_healthy_servers()

        if not healthy_servers:
            raise Exception("No healthy edge servers available")

        # Score each server
        scores = []
        for server in healthy_servers:
            score = self._score_server(server, client_location, request)
            scores.append((server, score))

        # Sort by score (higher is better)
        scores.sort(key=lambda x: x[1], reverse=True)

        # Return best server
        return scores[0][0]

    def _score_server(self, server, client_location, request):
        """
        Score server for given request

        Factors:
            - Geographic distance (weight: 0.5)
            - Server load (weight: 0.3)
            - Cache hit rate (weight: 0.2)
        """
        # Geographic proximity
        distance = self._calculate_distance(client_location, server.location)
        distance_score = 1.0 / (1.0 + distance / 1000)  # Normalize

        # Server load
        load = server.get_current_load()
        load_score = 1.0 - min(load, 1.0)

        # Cache hit rate
        hit_rate = server.metrics.get_stats()['l1_hit_rate']

        # Weighted score
        score = (
            0.5 * distance_score +
            0.3 * load_score +
            0.2 * hit_rate
        )

        return score

    def _geolocate_ip(self, ip):
        """
        Get geographic location from IP

        Uses MaxMind GeoIP or similar
        """
        import geoip2.database

        reader = geoip2.database.Reader('GeoLite2-City.mmdb')
        response = reader.city(ip)

        return {
            'lat': response.location.latitude,
            'lon': response.location.longitude,
            'city': response.city.name,
            'country': response.country.name
        }

    def _calculate_distance(self, loc1, loc2):
        """
        Calculate distance between two locations (km)

        Uses Haversine formula
        """
        from math import radians, sin, cos, sqrt, atan2

        R = 6371  # Earth radius in km

        lat1, lon1 = radians(loc1['lat']), radians(loc1['lon'])
        lat2, lon2 = radians(loc2['lat']), radians(loc2['lon'])

        dlat = lat2 - lat1
        dlon = lon2 - lon1

        a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
        c = 2 * atan2(sqrt(a), sqrt(1-a))

        return R * c

    def _discover_edge_servers(self):
        """Discover available edge servers"""
        # In production, this would query a service registry
        return [
            EdgeServerInfo('us-east-1', 'https://edge-us-east-1.example.com', {'lat': 39.0, 'lon': -77.5}),
            EdgeServerInfo('eu-west-1', 'https://edge-eu-west-1.example.com', {'lat': 53.3, 'lon': -6.3}),
            EdgeServerInfo('ap-northeast-1', 'https://edge-ap-northeast-1.example.com', {'lat': 35.7, 'lon': 139.7}),
        ]


# Example usage
glb = GlobalLoadBalancer()

# Route request
client_ip = '8.8.8.8'  # Google DNS (US)
request = Request(path='/predict', data={})

best_server = glb.route_request(client_ip, request)
print(f"Routing to: {best_server.region}")

3. Cache Invalidation System

Purpose: Propagate updates across edge servers

import json
import time

import redis


class CacheInvalidationSystem:
    """
    Propagate cache invalidations to edge servers

    Methods:
        1. Push-based: Immediate invalidation
        2. Pull-based: Periodic refresh
        3. TTL-based: Automatic expiration
    """

    def __init__(self, edge_servers):
        self.edge_servers = edge_servers

        # Message queue for invalidations
        self.invalidation_queue = redis.Redis(host='localhost', port=6379)

        # Pub/sub for real-time propagation
        self.pubsub = self.invalidation_queue.pubsub()
        self.pubsub.subscribe('cache:invalidate')

    def invalidate(self, keys, pattern=False):
        """
        Invalidate cache keys across all edge servers

        Args:
            keys: List of keys to invalidate
            pattern: If True, treat keys as patterns
        """
        message = {
            'keys': keys,
            'pattern': pattern,
            'timestamp': time.time()
        }

        # Publish to all edge servers
        self.invalidation_queue.publish('cache:invalidate', json.dumps(message))

        print(f"Invalidated {len(keys)} keys across edge network")

    def invalidate_prefix(self, prefix):
        """
        Invalidate all keys with given prefix

        Example: invalidate_prefix('user:123:*')
        """
        self.invalidate([prefix], pattern=True)

    def invalidate_model_update(self, model_id):
        """
        Invalidate caches after model update

        Invalidates:
            - Model predictions
            - Model metadata
            - Related embeddings
        """
        patterns = [
            f"model:{model_id}:*",
            f"prediction:{model_id}:*",
            f"embedding:{model_id}:*"
        ]

        self.invalidate(patterns, pattern=True)

        print(f"Invalidated caches for model {model_id}")


class EdgeInvalidationListener:
    """
    Listen for invalidation messages on edge server
    """

    def __init__(self, edge_server):
        self.edge_server = edge_server

        # Subscribe to invalidations
        self.pubsub = redis.Redis(host='localhost', port=6379).pubsub()
        self.pubsub.subscribe('cache:invalidate')

        self.running = False

    def start(self):
        """Start listening for invalidations"""
        self.running = True

        import threading
        self.thread = threading.Thread(target=self._listen_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop listening"""
        self.running = False

    def _listen_loop(self):
        """Listen for invalidation messages"""
        for message in self.pubsub.listen():
            if message['type'] == 'message':
                data = json.loads(message['data'])
                self._handle_invalidation(data)

    def _handle_invalidation(self, data):
        """Handle invalidation message"""
        keys = data['keys']
        pattern = data['pattern']

        if pattern:
            # Invalidate by pattern
            for key_pattern in keys:
                self._invalidate_pattern(key_pattern)
        else:
            # Invalidate specific keys
            for key in keys:
                self._invalidate_key(key)

    def _invalidate_key(self, key):
        """Invalidate specific key"""
        # Remove from L1 cache
        self.edge_server.l1_cache.delete(key)

        # Remove from L2 cache
        self.edge_server.l2_cache.delete(key)

        print(f"Invalidated key: {key}")

    def _invalidate_pattern(self, pattern):
        """Invalidate keys matching pattern"""
        # Scan L1 cache
        for key in self.edge_server.l1_cache.scan_iter(match=pattern):
            self.edge_server.l1_cache.delete(key)

        # Scan L2 cache
        self.edge_server.l2_cache.delete_pattern(pattern)

        print(f"Invalidated pattern: {pattern}")


# Example usage
edge_servers = [
    EdgeServer('us-east-1', 'https://origin.example.com'),
    EdgeServer('eu-west-1', 'https://origin.example.com'),
]

invalidation_system = CacheInvalidationSystem(edge_servers)

# Start listeners on each edge
for edge in edge_servers:
    listener = EdgeInvalidationListener(edge)
    listener.start()

# Trigger invalidation
invalidation_system.invalidate_model_update('model_v2')

ML Model Serving at Edge

Edge Inference

from collections import OrderedDict

import numpy as np
import onnxruntime as ort


class LRUCache:
    """Minimal LRU cache; get() returns -1 on a miss (stub for illustration)."""

    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)


class EdgeMLServer:
    """
    Serve ML models at edge for low-latency inference

    Benefits:
        - Reduced latency (no round trip to origin)
        - Reduced bandwidth
        - Better privacy (data doesn't leave region)
    """

    def __init__(self, model_path):
        # Load ONNX model for edge inference
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )

        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Cache for predictions
        self.prediction_cache = LRUCache(capacity=10000)

    def predict(self, features):
        """
        Predict with caching

        Args:
            features: Input features (must be hashable)

        Returns:
            Prediction
        """
        # Generate cache key
        cache_key = hash(features)

        # Check cache
        cached_prediction = self.prediction_cache.get(cache_key)
        if cached_prediction != -1:
            return cached_prediction

        # Run inference
        features_array = np.array([features], dtype=np.float32)

        prediction = self.session.run(
            [self.output_name],
            {self.input_name: features_array}
        )[0][0]

        # Cache result
        self.prediction_cache.put(cache_key, prediction)

        return prediction

    async def batch_predict(self, features_list):
        """
        Batch prediction for efficiency

        Separates cache hits from misses
        """
        predictions = {}
        cache_misses = []
        cache_miss_indices = []

        # Check cache
        for i, features in enumerate(features_list):
            cache_key = hash(features)
            cached = self.prediction_cache.get(cache_key)

            if cached != -1:
                predictions[i] = cached
            else:
                cache_misses.append(features)
                cache_miss_indices.append(i)

        # Batch inference for cache misses
        if cache_misses:
            features_array = np.array(cache_misses, dtype=np.float32)

            batch_predictions = self.session.run(
                [self.output_name],
                {self.input_name: features_array}
            )[0]

            # Store in cache and results
            for i, pred in zip(cache_miss_indices, batch_predictions):
                cache_key = hash(features_list[i])
                self.prediction_cache.put(cache_key, pred)
                predictions[i] = pred

        # Return in original order
        return [predictions[i] for i in range(len(features_list))]


# Example: Edge API server with ML inference
from fastapi import FastAPI, HTTPException
import uvicorn

app = FastAPI()

# Load model at startup
edge_ml_server = EdgeMLServer('model.onnx')


@app.post("/predict")
async def predict(request: dict):
    """
    Edge prediction endpoint

    Returns cached or computed prediction
    """
    features = tuple(request['features'])

    try:
        prediction = edge_ml_server.predict(features)

        return {
            'prediction': float(prediction),
            'cached': edge_ml_server.prediction_cache.get(hash(features)) != -1,
            'edge_region': 'us-east-1'
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/batch_predict")
async def batch_predict(request: dict):
    """
    Batch prediction endpoint
    """
    features_list = [tuple(f) for f in request['features']]

    predictions = await edge_ml_server.batch_predict(features_list)

    return {
        'predictions': [float(p) for p in predictions],
        'count': len(predictions),
        'edge_region': 'us-east-1'
    }


# Run edge server
# uvicorn.run(app, host='0.0.0.0', port=8000)

Model Distribution to Edge

import boto3
import requests


class ModelDistributionSystem:
    """
    Distribute ML models to edge servers

    Challenges:
        - Large model sizes (GB)
        - Many edge locations
        - Version management
        - Atomic updates
    """

    def __init__(self, s3_bucket, edge_servers):
        self.s3_bucket = s3_bucket
        self.edge_servers = edge_servers

        # Track model versions at each edge
        self.edge_versions = {
            server.region: None
            for server in edge_servers
        }

    def distribute_model(self, model_path, version):
        """
        Distribute model to all edge servers

        Steps:
            1. Upload to S3
            2. Notify edge servers
            3. Edge servers download
            4. Edge servers validate
            5. Edge servers activate
        """
        print(f"Distributing model {version} to {len(self.edge_servers)} edge servers...")

        # Step 1: Upload to S3
        s3_key = f"models/{version}/model.onnx"
        self._upload_to_s3(model_path, s3_key)

        # Step 2: Notify edge servers
        results = []
        for server in self.edge_servers:
            result = self._distribute_to_edge(server, s3_key, version)
            results.append((server.region, result))

        # Check results
        successful = [r for r in results if r[1]]
        failed = [r for r in results if not r[1]]

        print("\nDistribution complete:")
        print(f"  Successful: {len(successful)}/{len(self.edge_servers)}")
        print(f"  Failed: {len(failed)}")

        if failed:
            print(f"  Failed regions: {[r[0] for r in failed]}")

        return len(failed) == 0

    def _upload_to_s3(self, local_path, s3_key):
        """Upload model to S3"""
        s3 = boto3.client('s3')

        print(f"Uploading {local_path} to s3://{self.s3_bucket}/{s3_key}")

        s3.upload_file(
            local_path,
            self.s3_bucket,
            s3_key,
            ExtraArgs={'ServerSideEncryption': 'AES256'}
        )

    def _distribute_to_edge(self, server, s3_key, version):
        """
        Notify edge server to download model

        Edge server will:
            1. Download from S3
            2. Validate checksum
            3. Load model
            4. Run health checks
            5. Activate (atomic swap)
        """
        try:
            response = requests.post(
                f"{server.url}/admin/update_model",
                json={
                    's3_bucket': self.s3_bucket,
                    's3_key': s3_key,
                    'version': version
                },
                timeout=300  # 5 minutes for large models
            )

            if response.status_code == 200:
                self.edge_versions[server.region] = version
                print(f"{server.region}: Updated to {version}")
                return True
            else:
                print(f"{server.region}: Failed - {response.text}")
                return False
        except Exception as e:
            print(f"{server.region}: Error - {e}")
            return False

    def rollback_model(self, target_version):
        """
        Rollback to previous model version

        Useful if the new model has issues (assumes the old artifact is still
        available locally)
        """
        print(f"Rolling back to version {target_version}...")

        return self.distribute_model(f"/tmp/model_{target_version}.onnx", target_version)

    def get_version_status(self):
        """Get model versions deployed at each edge"""
        return self.edge_versions


# Example usage
edge_servers = [
    EdgeServerInfo('us-east-1', 'https://edge-us-east-1.example.com', {}),
    EdgeServerInfo('eu-west-1', 'https://edge-eu-west-1.example.com', {}),
    EdgeServerInfo('ap-northeast-1', 'https://edge-ap-northeast-1.example.com', {}),
]

distributor = ModelDistributionSystem(
    s3_bucket='my-ml-models',
    edge_servers=edge_servers
)

# Distribute new model
success = distributor.distribute_model('model_v3.onnx', 'v3')

if success:
    print("\nModel distribution successful!")
    print("Current versions:")
    for region, version in distributor.get_version_status().items():
        print(f"  {region}: {version}")
else:
    print("\nModel distribution failed, rolling back...")
    distributor.rollback_model('v2')

Monitoring & Observability

CDN Metrics Dashboard

import numpy as np
from prometheus_client import Counter, Histogram, Gauge


class CDNMetricsDashboard:
    """
    Aggregate and visualize CDN metrics

    Key metrics:
        - Cache hit rate
        - Latency (p50, p95, p99)
        - Bandwidth usage
        - Error rate
        - Request rate
    """

    def __init__(self, edge_servers):
        self.edge_servers = edge_servers

        # Prometheus metrics for time-series export
        self.request_count = Counter(
            'cdn_requests_total',
            'Total CDN requests',
            ['region', 'status']
        )

        self.latency = Histogram(
            'cdn_request_latency_seconds',
            'CDN request latency',
            ['region', 'cache_level']
        )

        self.cache_hit_rate = Gauge(
            'cdn_cache_hit_rate',
            'Cache hit rate',
            ['region', 'cache_level']
        )

    def collect_metrics(self):
        """
        Collect metrics from all edge servers

        Returns aggregated view
        """
        global_metrics = {
            'total_requests': 0,
            'total_cache_hits': 0,
            'regions': {}
        }

        for server in self.edge_servers:
            stats = server.metrics.get_stats()

            total_requests = (
                server.metrics.l1_hits +
                server.metrics.l2_hits +
                server.metrics.misses
            )
            total_cache_hits = server.metrics.l1_hits + server.metrics.l2_hits

            global_metrics['total_requests'] += total_requests
            global_metrics['total_cache_hits'] += total_cache_hits

            global_metrics['regions'][server.region] = {
                'requests': total_requests,
                'cache_hits': total_cache_hits,
                'stats': stats
            }

        # Calculate global hit rate
        if global_metrics['total_requests'] > 0:
            global_metrics['cache_hit_rate'] = (
                global_metrics['total_cache_hits'] / global_metrics['total_requests']
            )
        else:
            global_metrics['cache_hit_rate'] = 0

        return global_metrics

    def print_dashboard(self):
        """Print metrics dashboard"""
        metrics = self.collect_metrics()

        print("\n" + "="*70)
        print("CDN METRICS DASHBOARD")
        print("="*70)

        print("\nGlobal Metrics:")
        print(f"  Total Requests: {metrics['total_requests']:,}")
        print(f"  Cache Hit Rate: {metrics['cache_hit_rate']:.2%}")

        print("\nRegional Breakdown:")

        for region, data in metrics['regions'].items():
            print(f"\n  {region}:")
            print(f"    Requests: {data['requests']:,}")
            print(f"    L1 Hit Rate: {data['stats']['l1_hit_rate']:.2%}")
            print(f"    L2 Hit Rate: {data['stats']['l2_hit_rate']:.2%}")
            print(f"    Miss Rate: {data['stats']['miss_rate']:.2%}")
            print(f"    Avg L1 Latency: {data['stats']['avg_l1_latency_ms']:.2f}ms")
            print(f"    Avg Miss Latency: {data['stats']['avg_miss_latency_ms']:.2f}ms")

        print("="*70)

    def plot_latency_distribution(self):
        """Plot latency distribution by region"""
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(len(self.edge_servers), 1, figsize=(12, 4 * len(self.edge_servers)))

        for i, server in enumerate(self.edge_servers):
            ax = axes[i] if len(self.edge_servers) > 1 else axes

            # Get latencies in milliseconds
            l1_latencies = np.array(server.metrics.l1_latencies) * 1000
            l2_latencies = np.array(server.metrics.l2_latencies) * 1000
            miss_latencies = np.array(server.metrics.miss_latencies) * 1000

            # Plot histograms
            ax.hist(l1_latencies, bins=50, alpha=0.5, label='L1 Cache', color='green')
            ax.hist(l2_latencies, bins=50, alpha=0.5, label='L2 Cache', color='blue')
            ax.hist(miss_latencies, bins=50, alpha=0.5, label='Origin', color='red')

            ax.set_xlabel('Latency (ms)')
            ax.set_ylabel('Frequency')
            ax.set_title(f'Latency Distribution - {server.region}')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('cdn_latency_distribution.png')
        plt.close()

        print("Latency distribution plot saved to cdn_latency_distribution.png")


# Example usage
edge_servers = [
    # ... initialize edge servers
]

dashboard = CDNMetricsDashboard(edge_servers)

# Collect and display metrics
dashboard.print_dashboard()

# Plot latency distribution
dashboard.plot_latency_distribution()

Cost Optimization

Tiered Caching Strategy

import random
import time


class TieredCachingStrategy:
    """
    Optimize costs with tiered caching

    Tiers:
        1. Hot (L1 - Redis): Most accessed, expensive, fast
        2. Warm (L2 - Local disk): Frequently accessed, cheap, medium speed
        3. Cold (S3): Rarely accessed, cheapest, slow

    Move items between tiers based on access patterns
    """

    def __init__(self):
        self.l1_cost_per_gb_per_month = 100    # Redis
        self.l2_cost_per_gb_per_month = 10     # SSD
        self.l3_cost_per_gb_per_month = 0.02   # S3

        self.l1_size_gb = 10
        self.l2_size_gb = 100
        self.l3_size_gb = 1000

    def calculate_monthly_cost(self):
        """Calculate monthly storage cost"""
        l1_cost = self.l1_size_gb * self.l1_cost_per_gb_per_month
        l2_cost = self.l2_size_gb * self.l2_cost_per_gb_per_month
        l3_cost = self.l3_size_gb * self.l3_cost_per_gb_per_month

        total_cost = l1_cost + l2_cost + l3_cost

        return {
            'l1_cost': l1_cost,
            'l2_cost': l2_cost,
            'l3_cost': l3_cost,
            'total_cost': total_cost
        }

    def optimize_tier_sizes(self, access_patterns):
        """
        Optimize tier sizes based on access patterns

        Goal: Minimize cost while maintaining hit rate
        """
        # Analyze access frequency
        access_freq = {
            item_id: len(accesses)
            for item_id, accesses in access_patterns.items()
        }

        # Sort by frequency
        sorted_items = sorted(
            access_freq.items(),
            key=lambda x: x[1],
            reverse=True
        )

        # Allocate to tiers
        l1_items = sorted_items[:100]       # Top 100
        l2_items = sorted_items[100:1000]   # Next 900
        l3_items = sorted_items[1000:]      # Rest

        print("Tier allocation:")
        print(f"  L1 (Hot): {len(l1_items)} items")
        print(f"  L2 (Warm): {len(l2_items)} items")
        print(f"  L3 (Cold): {len(l3_items)} items")

        # Calculate expected hit rate
        total_accesses = sum(access_freq.values())
        l1_accesses = sum(freq for _, freq in l1_items)
        l2_accesses = sum(freq for _, freq in l2_items)

        l1_hit_rate = l1_accesses / total_accesses
        l2_hit_rate = l2_accesses / total_accesses

        print("\nExpected hit rates:")
        print(f"  L1: {l1_hit_rate:.2%}")
        print(f"  L2: {l2_hit_rate:.2%}")
        print(f"  Combined (L1+L2): {(l1_hit_rate + l2_hit_rate):.2%}")


# Example
strategy = TieredCachingStrategy()

# Calculate costs
costs = strategy.calculate_monthly_cost()
print("Monthly CDN storage costs:")
for key, value in costs.items():
    print(f"  {key}: ${value:.2f}")

# Simulate access patterns
access_patterns = {
    f"item_{i}": [time.time() - random.random() * 86400 for _ in range(random.randint(1, 100))]
    for i in range(10000)
}

# Optimize
strategy.optimize_tier_sizes(access_patterns)

Key Takeaways

✅ Edge caching - Serve content close to users for low latency
✅ Multi-level cache - L1 (Redis), L2 (disk), origin (database)
✅ Smart routing - GeoDNS + latency-based + load-based
✅ Cache invalidation - Pub/sub for real-time propagation
✅ Edge ML serving - Deploy models to edge for fast inference
✅ Cost optimization - Tiered storage based on access patterns

Key Metrics:

  • Cache hit rate: > 80%
  • P99 latency: < 50ms for cache hits
  • Origin latency: 200-500ms
  • Bandwidth savings: 70-90%
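
The bandwidth-savings figure follows directly from the hit rate: every request answered at the edge is response traffic the origin never serves. A rough sketch with assumed traffic volumes:

# Origin bandwidth saved by edge caching (assumed traffic volumes)
def origin_bandwidth_gb_per_day(requests_per_day, avg_response_kb, hit_rate):
    total_gb = requests_per_day * avg_response_kb / (1024 * 1024)
    return total_gb * hit_rate, total_gb * (1 - hit_rate)

saved_gb, origin_gb = origin_bandwidth_gb_per_day(100_000_000, 20, 0.85)
print(f"Served from edge: {saved_gb:,.0f} GB/day, hits origin: {origin_gb:,.0f} GB/day")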

FAQ

How does a CDN reduce ML inference latency?

A CDN places edge servers with cached predictions and lightweight models close to users, eliminating the round-trip to a central origin server. Instead of a user in Tokyo hitting a server in Virginia (150-200ms round trip), the request goes to a local edge server (1-2ms). This cuts inference latency from 200-500ms to under 50ms for cached or edge-computed predictions.

What caching strategy works best for ML predictions at the edge?

A multi-level cache provides the best balance of speed and cost. Redis (L1) stores hot prediction results with sub-millisecond access at roughly $100/GB/month. Local SSD (L2) holds model weights and embeddings at ~$10/GB/month. The origin server handles cold data and full dataset access. Items promote from L2 to L1 based on access frequency, and TTL-based expiration keeps the cache fresh.

How do you handle model updates across a global CDN?

Use a pub/sub system like Redis to broadcast cache invalidation messages to all edge servers when a model is updated. Each edge server subscribes to invalidation channels, clears affected cache entries by key or pattern, then downloads the new model from S3. Atomic version swaps ensure the edge never serves a partially loaded model.

When should you deploy ML models at the edge versus the origin?

Deploy at the edge when latency requirements are strict (under 50ms), when data privacy regulations require regional processing, or when predictions are highly cacheable. Keep complex models at the origin when accuracy demands full-size models that exceed edge memory, when predictions require real-time access to a central feature store, or when the model changes too frequently for efficient distribution.
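
As an illustration, here is a hypothetical helper that encodes those rules of thumb; the thresholds are assumptions, not fixed guidance:

# Rule-of-thumb check for edge vs origin deployment (illustrative thresholds)
def should_deploy_at_edge(p99_budget_ms, model_size_gb, needs_central_features,
                          cache_hit_estimate, edge_memory_gb=8):
    if needs_central_features:
        return False   # real-time feature-store access keeps inference at the origin
    if model_size_gb > edge_memory_gb:
        return False   # model doesn't fit on edge hardware
    return p99_budget_ms < 50 or cache_hit_estimate > 0.8

print(should_deploy_at_edge(30, 0.5, False, 0.9))   # True: tight latency budget, small model
print(should_deploy_at_edge(200, 20, False, 0.5))   # False: large model, loose budget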


Cross-links: Model Serving Architecture · Caching Strategies for ML · Distributed ML Systems

Want to work together?

I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.

Get in touch