Content Delivery Networks (CDN)
Design a global CDN for ML systems: edge caching cuts latency from hundreds of milliseconds down to tens of milliseconds, which is critical for serving real-time predictions worldwide.
Problem Statement
Design a Content Delivery Network (CDN) for serving:
- ML model inference (predictions at the edge)
- Static assets (model weights, configs, embeddings)
- API responses (cached predictions, feature data)
Why Do We Need a CDN?
The Core Problem: Distance Creates Latency
Imagine you’re a user in Tokyo trying to access a website hosted in Virginia, USA:
User (Tokyo) ──────── 10,000 km ──────── Server (Virginia)
~150ms round-trip time
The physics problem:
- Light travels at ~300,000 km/s in a vacuum
- Signal in fiber: ~200,000 km/s
- Tokyo ↔ Virginia: ~10,000 km
- Theoretical minimum: ~50ms one way, ~100ms round trip
- Reality with routing and processing: 150-200ms
What if we could serve from Tokyo instead?
User (Tokyo) ── 10 km ── Edge Server (Tokyo)
~1-2ms!
That’s roughly a 100x improvement just from being geographically closer!
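A quick back-of-the-envelope check on those numbers (a sketch that only accounts for propagation delay in fiber; real round trips add routing, queuing, and TLS handshakes):

```python
# Propagation-delay floor only; real-world latency adds routing, queuing, TLS, etc.
SPEED_IN_FIBER_KM_PER_S = 200_000  # roughly 2/3 the speed of light in a vacuum

def min_round_trip_ms(distance_km: float) -> float:
    """Minimum round-trip time in milliseconds for a given one-way distance."""
    return 2 * (distance_km / SPEED_IN_FIBER_KM_PER_S) * 1000

print(f"Tokyo <-> Virginia: {min_round_trip_ms(10_000):.0f} ms minimum")  # ~100 ms
print(f"Tokyo <-> local edge: {min_round_trip_ms(10):.2f} ms minimum")    # ~0.10 ms
```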
Real-World Impact on ML Systems
Scenario: Real-time recommendation system
| Architecture | Latency | User Experience |
|---|---|---|
| Without CDN: Request → US → Model Inference → Response | 200ms+ | Noticeable delay, users leave |
| With CDN: Request → Local Edge → Cached/Local Inference → Response | 20-50ms | Feels instant ✓ |
The business impact:
- Every 100ms of latency = 1% drop in sales (Amazon study)
- For ML systems: Users won’t wait for slow predictions
- CDN makes your ML system feel instant globally
What CDN Does for You
1. Geographic Distribution: Cache content at multiple locations worldwide (edge servers)
2. Intelligent Caching: Store frequently accessed content close to users
3. Smart Routing: Direct users to the best edge server (closest + healthy + low load)
4. Fault Tolerance: If one edge fails, route to another
5. Bandwidth Savings: Serve from edge → Less traffic to origin → Lower costs
Requirements
Functional:
- Serve content from geographically distributed edge locations
- Cache popular content close to users
- Route requests to nearest/best edge server
- Handle cache invalidation and updates
- Support both static and dynamic content
Non-Functional:
- Latency: < 50ms p99 for edge hits (vs 200-500ms from origin)
- Availability: 99.99% uptime (about 4.3 minutes of downtime per month; see the quick calculation after this list)
- Scalability: Handle 1M+ requests/second globally
- Cache hit rate: > 80% for static content (fewer origin requests)
- Global coverage: Presence in 50+ regions
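For context, here is a rough sketch of how an availability target translates into a downtime budget (assuming a 30-day month):

```python
# Downtime budget implied by an availability target (30-day month assumed)
def downtime_minutes_per_month(availability: float, days_per_month: int = 30) -> float:
    return (1 - availability) * days_per_month * 24 * 60

for target in (0.999, 0.9999, 0.99999):
    print(f"{target:.3%} uptime -> {downtime_minutes_per_month(target):.1f} min downtime/month")
```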
High-Level Architecture
┌─────────────────────────────────────────────────────────────┐
│ USER REQUESTS │
│ 🌍 Asia 🌍 Europe 🌍 Americas 🌍 Africa │
└───────┬────────────┬────────────┬──────────────┬────────────┘
│ │ │ │
↓ ↓ ↓ ↓
┌────────────────────────────────────────────────────────────┐
│ DNS / GLOBAL LOAD BALANCER │
│ • GeoDNS routing │
│ • Health checks │
│ • Latency-based routing │
└────────┬────────────┬────────────┬──────────────┬──────────┘
│ │ │ │
┌────▼────┐ ┌───▼─────┐ ┌───▼─────┐ ┌───▼─────┐
│ Edge │ │ Edge │ │ Edge │ │ Edge │
│ Tokyo │ │ London │ │ N.Virginia│ │ Mumbai │
│ │ │ │ │ │ │ │
│ L1 Cache│ │ L1 Cache│ │ L1 Cache│ │ L1 Cache│
│ (Redis) │ │ (Redis) │ │ (Redis) │ │ (Redis) │
│ │ │ │ │ │ │ │
│ ML Model│ │ ML Model│ │ ML Model│ │ ML Model│
└────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘
│ │ │ │
└────────────┴────────────┴──────────────┘
│
┌───────▼────────┐
│ ORIGIN SERVERS │
│ │
│ • Master models│
│ • Databases │
│ • Feature store│
│ • Object store │
└────────────────┘
Core Components
1. Edge Servers
Purpose: Serve content from locations close to users
Before we dive into code, let’s understand the concept:
What is an Edge Server?
Think of edge servers like local convenience stores:
- Origin Server = Central warehouse (far away, has everything)
- Edge Server = Local store (nearby, has popular items)
When you need milk:
- Without edge: Drive to warehouse (30 min)
- With edge: Walk to local store (2 min)
Multi-Level Cache Strategy
Edge servers use multiple cache layers:
Request → L1 Cache (Redis, in-memory) ← Fastest, smallest
↓ Miss
L2 Cache (Disk, local SSD) ← Fast, medium
↓ Miss
Origin Server (Database) ← Slow, largest
Why multiple levels?
- L1 (Redis): Hot data, sub-millisecond to ~1ms access, expensive (~$100/GB/month)
- L2 (Disk): Warm data, 5-10ms access, cheap (~$10/GB/month)
- Origin: Cold data, 100-500ms access, cheapest (~$0.02/GB/month)
Trade-off: Speed vs Cost vs Capacity
| Cache | Speed | Cost | Capacity | Use Case |
|---|---|---|---|---|
| L1 (Redis) | ~1ms | High | Small (10GB) | Prediction results, hot features |
| L2 (Disk) | ~10ms | Medium | Medium (100GB) | Model weights, embeddings |
| Origin | ~200ms | Low | Large (TB+) | Full dataset, historical data |
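To see why the layering pays off, here is a rough expected-latency calculation; the hit rates are assumptions for illustration, and the per-tier latencies come from the table above:

```python
# Expected per-request latency for assumed hit rates (per-tier latencies from the table above)
tiers = {
    'l1':     {'latency_ms': 1,   'fraction': 0.60},  # assumed L1 hit rate
    'l2':     {'latency_ms': 10,  'fraction': 0.25},  # assumed L2 hit rate
    'origin': {'latency_ms': 200, 'fraction': 0.15},  # remaining misses
}

expected_ms = sum(t['latency_ms'] * t['fraction'] for t in tiers.values())
print(f"Expected latency: {expected_ms:.1f} ms")  # ~33 ms vs ~200 ms origin-only
```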
import fnmatch
import json
import pickle
import time
from dataclasses import dataclass, field

import numpy as np
import redis


@dataclass
class Request:
    """Simple request container used in the examples below."""
    path: str
    data: dict = field(default_factory=dict)


class DiskCache:
    """Minimal stand-in for a local-disk (L2) cache, kept in memory for illustration."""

    def __init__(self, size_gb=100):
        self.size_gb = size_gb
        self.store = {}

    def get(self, key):
        return self.store.get(key)

    def set(self, key, value):
        self.store[key] = value

    def delete(self, key):
        self.store.pop(key, None)

    def delete_pattern(self, pattern):
        # Naive glob-style pattern matcher
        keys = [k for k in self.store if fnmatch.fnmatch(k, pattern)]
        for k in keys:
            self.store.pop(k, None)


def load_model(path):
    """Stub model loader; swap in an ONNX/Torch loader in production."""
    return object()


class EdgeServer:
    """
    CDN edge server

    Components:
    - L1 cache (Redis): Hot content
    - L2 cache (local disk): Warm content
    - ML model: For edge inference
    - Origin client: Fetch misses from origin
    """

    def __init__(self, region, origin_url):
        self.region = region
        self.origin_url = origin_url

        # Multi-level cache
        self.l1_cache = redis.Redis(host='localhost', port=6379)
        self.l2_cache = DiskCache(size_gb=100)

        # ML model for edge inference
        self.model = load_model('model.onnx')

        # Metrics
        self.metrics = EdgeMetrics()
async def handle_request(self, request):
"""
Handle incoming request
Flow:
1. Check L1 cache (Redis)
2. Check L2 cache (disk)
3. Fetch from origin
4. Update caches
"""
import time, json, pickle
start_time = time.time()
# Generate cache key
cache_key = self._generate_cache_key(request)
# Try L1 cache
response = await self._check_l1_cache(cache_key)
if response:
self.metrics.record_hit('l1', time.time() - start_time)
return response
# Try L2 cache
response = await self._check_l2_cache(cache_key)
if response:
# Promote to L1
await self._store_l1_cache(cache_key, response)
self.metrics.record_hit('l2', time.time() - start_time)
return response
# Cache miss: fetch from origin
response = await self._fetch_from_origin(request)
# Update caches
await self._store_l1_cache(cache_key, response)
await self._store_l2_cache(cache_key, response)
self.metrics.record_miss(time.time() - start_time)
return response
async def _check_l1_cache(self, key):
"""Check L1 (Redis) cache"""
try:
data = self.l1_cache.get(key)
if data:
return pickle.loads(data)
except Exception as e:
print(f"L1 cache error: {e}")
return None
async def _store_l1_cache(self, key, value, ttl=300):
"""Store in L1 cache with TTL"""
try:
self.l1_cache.setex(
key,
ttl,
pickle.dumps(value)
)
except Exception as e:
print(f"L1 cache store error: {e}")
async def _check_l2_cache(self, key):
"""Check L2 (disk) cache"""
return self.l2_cache.get(key)
async def _store_l2_cache(self, key, value):
"""Store in L2 cache"""
self.l2_cache.set(key, value)
async def _fetch_from_origin(self, request):
"""Fetch from origin server"""
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.origin_url}{request.path}",
json=request.data
) as response:
return await response.json()
def _generate_cache_key(self, request):
"""Generate cache key from request"""
import hashlib
# Include path and normalized data
key_data = f"{request.path}:{json.dumps(request.data, sort_keys=True)}"
return hashlib.md5(key_data.encode()).hexdigest()
class EdgeMetrics:
"""Track edge server metrics"""
def __init__(self):
self.l1_hits = 0
self.l2_hits = 0
self.misses = 0
self.l1_latencies = []
self.l2_latencies = []
self.miss_latencies = []
def record_hit(self, level, latency):
if level == 'l1':
self.l1_hits += 1
self.l1_latencies.append(latency)
elif level == 'l2':
self.l2_hits += 1
self.l2_latencies.append(latency)
def record_miss(self, latency):
self.misses += 1
self.miss_latencies.append(latency)
def get_stats(self):
total = self.l1_hits + self.l2_hits + self.misses
return {
'l1_hit_rate': self.l1_hits / total if total > 0 else 0,
'l2_hit_rate': self.l2_hits / total if total > 0 else 0,
'miss_rate': self.misses / total if total > 0 else 0,
'avg_l1_latency_ms': np.mean(self.l1_latencies) * 1000 if self.l1_latencies else 0,
'avg_l2_latency_ms': np.mean(self.l2_latencies) * 1000 if self.l2_latencies else 0,
'avg_miss_latency_ms': np.mean(self.miss_latencies) * 1000 if self.miss_latencies else 0,
}
# Example usage
edge = EdgeServer(region='us-east-1', origin_url='https://api.example.com')
# Simulate requests
async def simulate_requests():
for i in range(100):
request = Request(
path='/predict',
data={'features': [1, 2, 3, 4, 5]}
)
response = await edge.handle_request(request)
print(f"Request {i}: {response}")
# Print metrics
stats = edge.metrics.get_stats()
print("\nEdge Server Metrics:")
for key, value in stats.items():
if 'rate' in key:
print(f" {key}: {value:.2%}")
else:
print(f" {key}: {value:.2f}")
# Run
import asyncio
asyncio.run(simulate_requests())
2. Global Load Balancer / GeoDNS
Purpose: Route requests to optimal edge server
class GlobalLoadBalancer:
"""
Route requests to best edge server
Routing strategies:
1. Geographic proximity
2. Server load
3. Health status
4. Network latency
"""
def __init__(self):
self.edge_servers = self._discover_edge_servers()
self.health_checker = HealthChecker(self.edge_servers)
# Start health checking
self.health_checker.start()
def route_request(self, client_ip, request):
"""
Route request to best edge server
Args:
client_ip: Client IP address
request: Request object
Returns:
Best edge server
"""
# Get client location
client_location = self._geolocate_ip(client_ip)
# Get healthy edge servers
healthy_servers = self.health_checker.get_healthy_servers()
if not healthy_servers:
raise Exception("No healthy edge servers available")
# Score each server
scores = []
for server in healthy_servers:
score = self._score_server(
server,
client_location,
request
)
scores.append((server, score))
# Sort by score (higher is better)
scores.sort(key=lambda x: x[1], reverse=True)
# Return best server
return scores[0][0]
def _score_server(self, server, client_location, request):
"""
Score server for given request
Factors:
- Geographic distance (weight: 0.5)
- Server load (weight: 0.3)
- Cache hit rate (weight: 0.2)
"""
# Geographic proximity
distance = self._calculate_distance(
client_location,
server.location
)
distance_score = 1.0 / (1.0 + distance / 1000) # Normalize
# Server load
load = server.get_current_load()
load_score = 1.0 - min(load, 1.0)
# Cache hit rate
hit_rate = server.metrics.get_stats()['l1_hit_rate']
# Weighted score
score = (
0.5 * distance_score +
0.3 * load_score +
0.2 * hit_rate
)
return score
def _geolocate_ip(self, ip):
"""
Get geographic location from IP
Uses MaxMind GeoIP or similar
"""
import geoip2.database
reader = geoip2.database.Reader('GeoLite2-City.mmdb')
response = reader.city(ip)
return {
'lat': response.location.latitude,
'lon': response.location.longitude,
'city': response.city.name,
'country': response.country.name
}
def _calculate_distance(self, loc1, loc2):
"""
Calculate distance between two locations (km)
Uses Haversine formula
"""
from math import radians, sin, cos, sqrt, atan2
R = 6371 # Earth radius in km
lat1, lon1 = radians(loc1['lat']), radians(loc1['lon'])
lat2, lon2 = radians(loc2['lat']), radians(loc2['lon'])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
c = 2 * atan2(sqrt(a), sqrt(1-a))
distance = R * c
return distance
def _discover_edge_servers(self):
"""Discover available edge servers"""
# In production, this would query service registry
return [
EdgeServerInfo('us-east-1', 'https://edge-us-east-1.example.com', {'lat': 39.0, 'lon': -77.5}),
EdgeServerInfo('eu-west-1', 'https://edge-eu-west-1.example.com', {'lat': 53.3, 'lon': -6.3}),
EdgeServerInfo('ap-northeast-1', 'https://edge-ap-northeast-1.example.com', {'lat': 35.7, 'lon': 139.7}),
]
class HealthChecker:
"""
Monitor health of edge servers
Checks:
- HTTP health endpoint
- Response time
- Error rate
"""
def __init__(self, servers, check_interval=10):
self.servers = servers
self.check_interval = check_interval
self.health_status = {server.region: True for server in servers}
self.last_check = {server.region: 0 for server in servers}
self.running = False
def start(self):
"""Start health checking in background"""
self.running = True
import threading
self.thread = threading.Thread(target=self._health_check_loop, daemon=True)
self.thread.start()
def stop(self):
"""Stop health checking"""
self.running = False
    def _health_check_loop(self):
        """Health check loop"""
        import time
        while self.running:
            for server in self.servers:
                healthy = self._check_server_health(server)
                self.health_status[server.region] = healthy
                self.last_check[server.region] = time.time()
            time.sleep(self.check_interval)
def _check_server_health(self, server):
"""Check if server is healthy"""
try:
import requests
response = requests.get(
f"{server.url}/health",
timeout=5
)
if response.status_code == 200:
# Check response time
if response.elapsed.total_seconds() < 1.0:
return True
return False
except Exception as e:
print(f"Health check failed for {server.region}: {e}")
return False
def get_healthy_servers(self):
"""Get list of healthy servers"""
return [
server for server in self.servers
if self.health_status[server.region]
]
class EdgeServerInfo:
"""Edge server information"""
def __init__(self, region, url, location):
self.region = region
self.url = url
self.location = location
self.metrics = EdgeMetrics()
def get_current_load(self):
"""Get current server load (0-1)"""
# In production, query server metrics
return 0.5 # Placeholder
# Example usage
glb = GlobalLoadBalancer()
# Route request
client_ip = '8.8.8.8' # Google DNS (US)
request = Request(path='/predict', data={})
best_server = glb.route_request(client_ip, request)
print(f"Routing to: {best_server.region}")
3. Cache Invalidation System
Purpose: Propagate updates across edge servers
import json
import time

import redis

class CacheInvalidationSystem:
"""
Propagate cache invalidations to edge servers
Methods:
1. Push-based: Immediate invalidation
2. Pull-based: Periodic refresh
3. TTL-based: Automatic expiration
"""
def __init__(self, edge_servers):
self.edge_servers = edge_servers
# Message queue for invalidations
self.invalidation_queue = redis.Redis(host='localhost', port=6379)
# Pub/sub for real-time propagation
self.pubsub = self.invalidation_queue.pubsub()
self.pubsub.subscribe('cache:invalidate')
def invalidate(self, keys, pattern=False):
"""
Invalidate cache keys across all edge servers
Args:
keys: List of keys to invalidate
pattern: If True, treat keys as patterns
"""
message = {
'keys': keys,
'pattern': pattern,
'timestamp': time.time()
}
# Publish to all edge servers
self.invalidation_queue.publish(
'cache:invalidate',
json.dumps(message)
)
print(f"Invalidated {len(keys)} keys across edge network")
def invalidate_prefix(self, prefix):
"""
Invalidate all keys with given prefix
Example: invalidate_prefix('user:123:*')
"""
self.invalidate([prefix], pattern=True)
def invalidate_model_update(self, model_id):
"""
Invalidate caches after model update
Invalidates:
- Model predictions
- Model metadata
- Related embeddings
"""
patterns = [
f"model:{model_id}:*",
f"prediction:{model_id}:*",
f"embedding:{model_id}:*"
]
self.invalidate(patterns, pattern=True)
print(f"Invalidated caches for model {model_id}")
class EdgeInvalidationListener:
"""
Listen for invalidation messages on edge server
"""
def __init__(self, edge_server):
self.edge_server = edge_server
# Subscribe to invalidations
self.pubsub = redis.Redis(host='localhost', port=6379).pubsub()
self.pubsub.subscribe('cache:invalidate')
self.running = False
def start(self):
"""Start listening for invalidations"""
self.running = True
import threading
self.thread = threading.Thread(target=self._listen_loop, daemon=True)
self.thread.start()
def stop(self):
"""Stop listening"""
self.running = False
def _listen_loop(self):
"""Listen for invalidation messages"""
for message in self.pubsub.listen():
if message['type'] == 'message':
data = json.loads(message['data'])
self._handle_invalidation(data)
def _handle_invalidation(self, data):
"""Handle invalidation message"""
keys = data['keys']
pattern = data['pattern']
if pattern:
# Invalidate by pattern
for key_pattern in keys:
self._invalidate_pattern(key_pattern)
else:
# Invalidate specific keys
for key in keys:
self._invalidate_key(key)
def _invalidate_key(self, key):
"""Invalidate specific key"""
# Remove from L1 cache
self.edge_server.l1_cache.delete(key)
# Remove from L2 cache
self.edge_server.l2_cache.delete(key)
print(f"Invalidated key: {key}")
def _invalidate_pattern(self, pattern):
"""Invalidate keys matching pattern"""
# Scan L1 cache
for key in self.edge_server.l1_cache.scan_iter(match=pattern):
self.edge_server.l1_cache.delete(key)
# Scan L2 cache
self.edge_server.l2_cache.delete_pattern(pattern)
print(f"Invalidated pattern: {pattern}")
# Example usage
edge_servers = [
EdgeServer('us-east-1', 'https://origin.example.com'),
EdgeServer('eu-west-1', 'https://origin.example.com'),
]
invalidation_system = CacheInvalidationSystem(edge_servers)
# Start listeners on each edge
for edge in edge_servers:
listener = EdgeInvalidationListener(edge)
listener.start()
# Trigger invalidation
invalidation_system.invalidate_model_update('model_v2')
ML Model Serving at Edge
Edge Inference
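The EdgeMLServer below relies on an LRUCache helper that isn't defined in this post. Here is a minimal sketch built on collections.OrderedDict; it returns -1 on a miss to match how the server checks for cached predictions:

```python
from collections import OrderedDict

class LRUCache:
    """Minimal LRU cache; get() returns -1 on a miss to match the usage below."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return -1
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)  # evict least recently used
```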
import numpy as np

class EdgeMLServer:
"""
Serve ML models at edge for low-latency inference
Benefits:
- Reduced latency (no round trip to origin)
- Reduced bandwidth
- Better privacy (data doesn't leave region)
"""
def __init__(self, model_path):
# Load ONNX model for edge inference
import onnxruntime as ort
self.session = ort.InferenceSession(
model_path,
providers=['CPUExecutionProvider']
)
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
# Cache for predictions
self.prediction_cache = LRUCache(capacity=10000)
def predict(self, features):
"""
Predict with caching
Args:
features: Input features (must be hashable)
Returns:
Prediction
"""
# Generate cache key
cache_key = hash(features)
# Check cache
cached_prediction = self.prediction_cache.get(cache_key)
if cached_prediction != -1:
return cached_prediction
# Run inference
features_array = np.array([features], dtype=np.float32)
prediction = self.session.run(
[self.output_name],
{self.input_name: features_array}
)[0][0]
# Cache result
self.prediction_cache.put(cache_key, prediction)
return prediction
async def batch_predict(self, features_list):
"""
Batch prediction for efficiency
Separates cache hits from misses
"""
predictions = {}
cache_misses = []
cache_miss_indices = []
# Check cache
for i, features in enumerate(features_list):
cache_key = hash(features)
cached = self.prediction_cache.get(cache_key)
if cached != -1:
predictions[i] = cached
else:
cache_misses.append(features)
cache_miss_indices.append(i)
# Batch inference for cache misses
if cache_misses:
features_array = np.array(cache_misses, dtype=np.float32)
batch_predictions = self.session.run(
[self.output_name],
{self.input_name: features_array}
)[0]
# Store in cache and results
for i, pred in zip(cache_miss_indices, batch_predictions):
cache_key = hash(features_list[i])
self.prediction_cache.put(cache_key, pred)
predictions[i] = pred
# Return in original order
return [predictions[i] for i in range(len(features_list))]
# Example: Edge API server with ML inference
from fastapi import FastAPI, HTTPException
import uvicorn
app = FastAPI()
# Load model at startup
edge_ml_server = EdgeMLServer('model.onnx')
@app.post("/predict")
async def predict(request: dict):
"""
Edge prediction endpoint
Returns cached or computed prediction
"""
    features = tuple(request['features'])
    try:
        was_cached = edge_ml_server.prediction_cache.get(hash(features)) != -1
        prediction = edge_ml_server.predict(features)
        return {
            'prediction': float(prediction),
            'cached': was_cached,
            'edge_region': 'us-east-1'
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/batch_predict")
async def batch_predict(request: dict):
"""
Batch prediction endpoint
"""
features_list = [tuple(f) for f in request['features']]
predictions = await edge_ml_server.batch_predict(features_list)
return {
'predictions': [float(p) for p in predictions],
'count': len(predictions),
'edge_region': 'us-east-1'
}
# Run edge server
# uvicorn.run(app, host='0.0.0.0', port=8000)
Model Distribution to Edge
class ModelDistributionSystem:
"""
Distribute ML models to edge servers
Challenges:
- Large model sizes (GB)
- Many edge locations
- Version management
- Atomic updates
"""
def __init__(self, s3_bucket, edge_servers):
self.s3_bucket = s3_bucket
self.edge_servers = edge_servers
# Track model versions at each edge
self.edge_versions = {
server.region: None
for server in edge_servers
}
def distribute_model(self, model_path, version):
"""
Distribute model to all edge servers
Steps:
1. Upload to S3
2. Notify edge servers
3. Edge servers download
4. Edge servers validate
5. Edge servers activate
"""
print(f"Distributing model {version} to {len(self.edge_servers)} edge servers...")
# Step 1: Upload to S3
s3_key = f"models/{version}/model.onnx"
self._upload_to_s3(model_path, s3_key)
# Step 2: Notify edge servers
results = []
for server in self.edge_servers:
result = self._distribute_to_edge(server, s3_key, version)
results.append((server.region, result))
# Check results
successful = [r for r in results if r[1]]
failed = [r for r in results if not r[1]]
print(f"\nDistribution complete:")
print(f" Successful: {len(successful)}/{len(self.edge_servers)}")
print(f" Failed: {len(failed)}")
if failed:
print(f" Failed regions: {[r[0] for r in failed]}")
return len(failed) == 0
def _upload_to_s3(self, local_path, s3_key):
"""Upload model to S3"""
import boto3
s3 = boto3.client('s3')
print(f"Uploading {local_path} to s3://{self.s3_bucket}/{s3_key}")
s3.upload_file(
local_path,
self.s3_bucket,
s3_key,
ExtraArgs={'ServerSideEncryption': 'AES256'}
)
def _distribute_to_edge(self, server, s3_key, version):
"""
Notify edge server to download model
Edge server will:
1. Download from S3
2. Validate checksum
3. Load model
4. Run health checks
5. Activate (atomic swap)
"""
try:
import requests
response = requests.post(
f"{server.url}/admin/update_model",
json={
's3_bucket': self.s3_bucket,
's3_key': s3_key,
'version': version
},
timeout=300 # 5 minutes for large models
)
if response.status_code == 200:
self.edge_versions[server.region] = version
print(f" ✓ {server.region}: Updated to {version}")
return True
else:
print(f" ✗ {server.region}: Failed - {response.text}")
return False
except Exception as e:
print(f" ✗ {server.region}: Error - {e}")
return False
    def rollback_model(self, target_version):
        """
        Rollback to a previous model version
        The artifact for target_version is already in S3, so we skip the
        re-upload and point every edge at the existing object.
        """
        print(f"Rolling back to version {target_version}...")
        s3_key = f"models/{target_version}/model.onnx"
        results = [
            self._distribute_to_edge(server, s3_key, target_version)
            for server in self.edge_servers
        ]
        return all(results)
def get_version_status(self):
"""Get model versions deployed at each edge"""
return self.edge_versions
# Example usage
edge_servers = [
EdgeServerInfo('us-east-1', 'https://edge-us-east-1.example.com', {}),
EdgeServerInfo('eu-west-1', 'https://edge-eu-west-1.example.com', {}),
EdgeServerInfo('ap-northeast-1', 'https://edge-ap-northeast-1.example.com', {}),
]
distributor = ModelDistributionSystem(
s3_bucket='my-ml-models',
edge_servers=edge_servers
)
# Distribute new model
success = distributor.distribute_model('model_v3.onnx', 'v3')
if success:
print("\nModel distribution successful!")
print("Current versions:")
for region, version in distributor.get_version_status().items():
print(f" {region}: {version}")
else:
print("\nModel distribution failed, rolling back...")
distributor.rollback_model('v2')
Monitoring & Observability
CDN Metrics Dashboard
class CDNMetricsDashboard:
"""
Aggregate and visualize CDN metrics
Key metrics:
- Cache hit rate
- Latency (p50, p95, p99)
- Bandwidth usage
- Error rate
- Request rate
"""
def __init__(self, edge_servers):
self.edge_servers = edge_servers
# Time series database for metrics
from prometheus_client import Counter, Histogram, Gauge
self.request_count = Counter(
'cdn_requests_total',
'Total CDN requests',
['region', 'status']
)
self.latency = Histogram(
'cdn_request_latency_seconds',
'CDN request latency',
['region', 'cache_level']
)
self.cache_hit_rate = Gauge(
'cdn_cache_hit_rate',
'Cache hit rate',
['region', 'cache_level']
)
def collect_metrics(self):
"""
Collect metrics from all edge servers
Returns aggregated view
"""
global_metrics = {
'total_requests': 0,
'total_cache_hits': 0,
'regions': {}
}
for server in self.edge_servers:
stats = server.metrics.get_stats()
total_requests = (
server.metrics.l1_hits +
server.metrics.l2_hits +
server.metrics.misses
)
total_cache_hits = server.metrics.l1_hits + server.metrics.l2_hits
global_metrics['total_requests'] += total_requests
global_metrics['total_cache_hits'] += total_cache_hits
global_metrics['regions'][server.region] = {
'requests': total_requests,
'cache_hits': total_cache_hits,
'stats': stats
}
# Calculate global hit rate
if global_metrics['total_requests'] > 0:
global_metrics['cache_hit_rate'] = (
global_metrics['total_cache_hits'] / global_metrics['total_requests']
)
else:
global_metrics['cache_hit_rate'] = 0
return global_metrics
def print_dashboard(self):
"""Print metrics dashboard"""
metrics = self.collect_metrics()
print("\n" + "="*70)
print("CDN METRICS DASHBOARD")
print("="*70)
print(f"\nGlobal Metrics:")
print(f" Total Requests: {metrics['total_requests']:,}")
print(f" Cache Hit Rate: {metrics['cache_hit_rate']:.2%}")
print(f"\nRegional Breakdown:")
for region, data in metrics['regions'].items():
print(f"\n {region}:")
print(f" Requests: {data['requests']:,}")
print(f" L1 Hit Rate: {data['stats']['l1_hit_rate']:.2%}")
print(f" L2 Hit Rate: {data['stats']['l2_hit_rate']:.2%}")
print(f" Miss Rate: {data['stats']['miss_rate']:.2%}")
print(f" Avg L1 Latency: {data['stats']['avg_l1_latency_ms']:.2f}ms")
print(f" Avg Miss Latency: {data['stats']['avg_miss_latency_ms']:.2f}ms")
print("="*70)
    def plot_latency_distribution(self):
        """Plot latency distribution by region"""
        import matplotlib.pyplot as plt
        import numpy as np
        fig, axes = plt.subplots(len(self.edge_servers), 1, figsize=(12, 4 * len(self.edge_servers)))
for i, server in enumerate(self.edge_servers):
ax = axes[i] if len(self.edge_servers) > 1 else axes
# Get latencies
l1_latencies = np.array(server.metrics.l1_latencies) * 1000 # ms
l2_latencies = np.array(server.metrics.l2_latencies) * 1000
miss_latencies = np.array(server.metrics.miss_latencies) * 1000
# Plot histograms
ax.hist(l1_latencies, bins=50, alpha=0.5, label='L1 Cache', color='green')
ax.hist(l2_latencies, bins=50, alpha=0.5, label='L2 Cache', color='blue')
ax.hist(miss_latencies, bins=50, alpha=0.5, label='Origin', color='red')
ax.set_xlabel('Latency (ms)')
ax.set_ylabel('Frequency')
ax.set_title(f'Latency Distribution - {server.region}')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('cdn_latency_distribution.png')
plt.close()
print("Latency distribution plot saved to cdn_latency_distribution.png")
# Example usage
edge_servers = [
# ... initialize edge servers
]
dashboard = CDNMetricsDashboard(edge_servers)
# Collect and display metrics
dashboard.print_dashboard()
# Plot latency distribution
dashboard.plot_latency_distribution()
Cost Optimization
Tiered Caching Strategy
class TieredCachingStrategy:
"""
Optimize costs with tiered caching
Tiers:
1. Hot (L1 - Redis): Most accessed, expensive, fast
2. Warm (L2 - Local disk): Frequently accessed, cheap, medium speed
3. Cold (S3): Rarely accessed, cheapest, slow
Move items between tiers based on access patterns
"""
def __init__(self):
self.l1_cost_per_gb_per_month = 100 # Redis
self.l2_cost_per_gb_per_month = 10 # SSD
self.l3_cost_per_gb_per_month = 0.02 # S3
self.l1_size_gb = 10
self.l2_size_gb = 100
self.l3_size_gb = 1000
def calculate_monthly_cost(self):
"""Calculate monthly storage cost"""
l1_cost = self.l1_size_gb * self.l1_cost_per_gb_per_month
l2_cost = self.l2_size_gb * self.l2_cost_per_gb_per_month
l3_cost = self.l3_size_gb * self.l3_cost_per_gb_per_month
total_cost = l1_cost + l2_cost + l3_cost
return {
'l1_cost': l1_cost,
'l2_cost': l2_cost,
'l3_cost': l3_cost,
'total_cost': total_cost
}
def optimize_tier_sizes(self, access_patterns):
"""
Optimize tier sizes based on access patterns
Goal: Minimize cost while maintaining hit rate
"""
# Analyze access frequency
access_freq = {}
for item_id, accesses in access_patterns.items():
access_freq[item_id] = len(accesses)
# Sort by frequency
sorted_items = sorted(
access_freq.items(),
key=lambda x: x[1],
reverse=True
)
# Allocate to tiers
l1_items = sorted_items[:100] # Top 100
l2_items = sorted_items[100:1000] # Next 900
l3_items = sorted_items[1000:] # Rest
print(f"Tier allocation:")
print(f" L1 (Hot): {len(l1_items)} items")
print(f" L2 (Warm): {len(l2_items)} items")
print(f" L3 (Cold): {len(l3_items)} items")
# Calculate expected hit rate
total_accesses = sum(access_freq.values())
l1_accesses = sum(freq for _, freq in l1_items)
l2_accesses = sum(freq for _, freq in l2_items)
l1_hit_rate = l1_accesses / total_accesses
l2_hit_rate = l2_accesses / total_accesses
print(f"\nExpected hit rates:")
print(f" L1: {l1_hit_rate:.2%}")
print(f" L2: {l2_hit_rate:.2%}")
print(f" Combined (L1+L2): {(l1_hit_rate + l2_hit_rate):.2%}")
# Example
strategy = TieredCachingStrategy()
# Calculate costs
costs = strategy.calculate_monthly_cost()
print("Monthly CDN storage costs:")
for key, value in costs.items():
print(f" {key}: ${value:.2f}")
# Simulate access patterns
import random
import time

access_patterns = {
    f"item_{i}": [time.time() - random.random() * 86400 for _ in range(random.randint(1, 100))]
    for i in range(10000)
}
# Optimize
strategy.optimize_tier_sizes(access_patterns)
Key Takeaways
✅ Edge caching - Serve content close to users for low latency
✅ Multi-level cache - L1 (Redis), L2 (disk), origin (database)
✅ Smart routing - GeoDNS + latency-based + load-based
✅ Cache invalidation - Pub/sub for real-time propagation
✅ Edge ML serving - Deploy models to edge for fast inference
✅ Cost optimization - Tiered storage based on access patterns
Key Metrics:
- Cache hit rate: > 80%
- P99 latency: < 50ms for cache hits
- Origin latency: 200-500ms
- Bandwidth savings: 70-90%
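As a rough sketch of where the bandwidth savings come from (the request rate, response size, and hit rate below are assumptions for illustration):

```python
# Origin offload implied by a given cache hit rate (all numbers are illustrative)
requests_per_s = 1_000_000
avg_response_kb = 20
hit_rate = 0.85

origin_rps = requests_per_s * (1 - hit_rate)
origin_gbps = origin_rps * avg_response_kb * 8 / 1e6          # kilobits -> gigabits
no_cdn_gbps = requests_per_s * avg_response_kb * 8 / 1e6

print(f"Origin requests/s: {origin_rps:,.0f}")
print(f"Origin egress: ~{origin_gbps:.0f} Gbps vs ~{no_cdn_gbps:.0f} Gbps without a CDN "
      f"({hit_rate:.0%} saved)")
```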
Originally published at: arunbaby.com/ml-system-design/0011-content-delivery-network
If you found this helpful, consider sharing it with others who might benefit.