A/B Testing Systems for ML
How to design experimentation platforms that enable rapid iteration while maintaining statistical rigor at scale.
TL;DR
A/B testing is the backbone of data-driven ML iteration. This post covers the full experimentation platform design: deterministic user assignment via consistent hashing, metrics tracking and aggregation, statistical significance testing with t-tests and chi-square, power analysis for sample sizing, guardrail metrics that block harmful rollouts, sequential testing for early stopping, multi-armed bandits for adaptive allocation, CUPED for variance reduction, and layered experiments that let hundreds of tests run concurrently. For the broader ML system design series, experimentation is what turns model improvements into validated business outcomes.

Introduction
A/B testing is the backbone of data-driven decision making in ML systems. Every major tech company runs thousands of experiments simultaneously to:
- Test new model versions
- Validate product changes
- Optimize user experience
- Measure feature impact
Why it matters:
- Validate improvements: Ensure new models actually perform better
- Reduce risk: Test changes on small cohorts before full rollout
- Quantify impact: Measure precise effect size, not just gut feeling
- Enable velocity: Run multiple experiments in parallel
What you’ll learn:
- A/B testing architecture for ML systems
- Statistical foundations (hypothesis testing, power analysis)
- Experiment assignment and randomization
- Metrics tracking and analysis
- Guardrail metrics and quality assurance
- Real-world examples from tech giants
Problem Definition
Design an A/B testing platform for ML systems.
Functional Requirements
- Experiment Setup (an example definition is sketched after this list)
  - Create experiments with control/treatment variants
  - Define success metrics and guardrails
  - Set experiment parameters (duration, traffic allocation)
  - Support multi-variant testing (A/B/C/D)
- User Assignment
  - Randomly assign users to variants
  - Ensure consistency (same user always sees same variant)
  - Support layered experiments
  - Handle new vs. returning users
- Metrics Tracking
  - Log user actions and outcomes
  - Compute experiment metrics in real time
  - Track both primary and secondary metrics
  - Monitor guardrail metrics
- Statistical Analysis
  - Calculate statistical significance
  - Compute confidence intervals
  - Detect early wins/losses
  - Generate experiment reports
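To make the setup requirements concrete, here is a sketch of what a single experiment definition might look like. The field names are illustrative, not any particular platform's schema.
# Illustrative experiment definition (hypothetical field names)
experiment_config = {
    'experiment_id': 'model_v2_test',
    'variants': ['control', 'treatment'],
    'traffic_allocation': 0.50,        # fraction of eligible users entering the experiment
    'primary_metric': 'ctr',
    'secondary_metrics': ['session_duration', 'purchases_per_user'],
    'guardrail_metrics': {
        'revenue_per_user': {'max_relative_drop_pct': 2.0},
        'error_rate': {'max_absolute_increase': 0.5},
    },
    'min_sample_size_per_variant': 31_000,   # typically computed via power analysis (see below)
    'max_duration_days': 14,
}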
Non-Functional Requirements
- Scale
  - Handle 10M+ users
  - Support 100+ concurrent experiments
  - Process billions of events/day
- Latency
  - Assignment: < 10 ms
  - Metrics updates: near real-time (< 1 minute lag)
- Reliability
  - 99.9% uptime
  - No data loss
  - Audit trail for all experiments
- Statistical Rigor
  - Type I error (false positive rate) < 5%
  - Sufficient statistical power (80%+)
  - Multiple testing corrections
High-Level Architecture
┌─────────────────────────────────────────────────────────────┐
│ Experimentation Platform │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Experiment Configuration Service │ │
│ │ - Create experiments │ │
│ │ - Define metrics │ │
│ │ - Set parameters │ │
│ └──────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Assignment Service │ │
│ │ - Hash user_id → variant │ │
│ │ - Consistent assignment │ │
│ │ - Cache assignments │ │
│ └──────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Event Logging Service │ │
│ │ - Log user actions │ │
│ │ - Track outcomes │ │
│ │ - Stream to analytics │ │
│ └──────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Metrics Aggregation Service │ │
│ │ - Compute experiment metrics │ │
│ │ - Real-time dashboards │ │
│ │ - Statistical tests │ │
│ └──────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Analysis & Reporting Service │ │
│ │ - Statistical significance │ │
│ │ - Confidence intervals │ │
│ │ - Decision recommendations │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
Data Flow:
User Request → Assignment → Show Variant → Log Events → Aggregate Metrics → Analyze
Component 1: Experiment Assignment
Assign users to experiment variants consistently and randomly.
Deterministic Assignment via Hashing
import hashlib
from typing import Dict, List, Optional
class ExperimentAssigner:
"""
Assign users to experiment variants
Requirements:
- Deterministic: Same user_id → same variant
- Random: Uniform distribution across variants
- Independent: Different experiments use different hash seeds
"""
def __init__(self):
self.experiments = {} # experiment_id → config
def create_experiment(
self,
experiment_id: str,
variants: List[str],
traffic_allocation: float = 1.0
):
"""
Create new experiment
Args:
experiment_id: Unique experiment identifier
variants: List of variant names (e.g., ['control', 'treatment'])
traffic_allocation: Fraction of users to include (0.0 to 1.0)
"""
self.experiments[experiment_id] = {
'variants': variants,
'traffic_allocation': traffic_allocation,
'num_variants': len(variants)
}
    def assign_variant(self, user_id: str, experiment_id: str) -> Optional[str]:
"""
Assign user to variant
Uses consistent hashing for deterministic assignment
Returns:
Variant name or None if user not in experiment
"""
if experiment_id not in self.experiments:
raise ValueError(f"Experiment {experiment_id} not found")
config = self.experiments[experiment_id]
# Hash user_id + experiment_id
hash_input = f"{user_id}:{experiment_id}".encode('utf-8')
hash_value = int(hashlib.md5(hash_input).hexdigest(), 16)
# Map to [0, 1]
normalized = (hash_value % 10000) / 10000.0
# Check if user is in experiment (traffic allocation)
if normalized >= config['traffic_allocation']:
return None # User not in experiment
# Assign to variant
# Re-normalize to [0, 1] within allocated traffic
variant_hash = normalized / config['traffic_allocation']
variant_idx = int(variant_hash * config['num_variants'])
return config['variants'][variant_idx]
# Usage
assigner = ExperimentAssigner()
# Create experiment: 50% control, 50% treatment, 100% of users
assigner.create_experiment(
experiment_id='model_v2_test',
variants=['control', 'treatment'],
traffic_allocation=1.0
)
# Assign users
user_1_variant = assigner.assign_variant('user_123', 'model_v2_test')
print(f"User 123 assigned to: {user_1_variant}")
# Same user always gets same variant
assert assigner.assign_variant('user_123', 'model_v2_test') == user_1_variant
Why Hashing Works
Properties of hash functions like MD5/SHA:
- Deterministic: Same input → same output
- Uniform: Output is uniformly distributed across the hash space
- Independent: Different inputs → effectively uncorrelated outputs
(MD5 is chosen here for speed and uniformity; its cryptographic weaknesses don't matter for assignment.)
Key insight: Hash(user_id + experiment_id) acts as a random number generator with a fixed seed per user-experiment pair.
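These properties are easy to sanity-check empirically. The short simulation below (a sketch reusing the ExperimentAssigner defined above) confirms that a single experiment splits users roughly 50/50 and that two different experiments assign users independently, so the joint treatment/treatment rate lands near 0.25.
from collections import Counter

# Empirical check of uniformity and cross-experiment independence
assigner = ExperimentAssigner()
assigner.create_experiment('exp_a', ['control', 'treatment'])
assigner.create_experiment('exp_b', ['control', 'treatment'])

counts_a = Counter()
both_treatment = 0
n_users = 100_000
for i in range(n_users):
    a = assigner.assign_variant(f"user_{i}", 'exp_a')
    b = assigner.assign_variant(f"user_{i}", 'exp_b')
    counts_a[a] += 1
    both_treatment += (a == 'treatment' and b == 'treatment')

print(counts_a)                  # expect roughly 50/50
print(both_treatment / n_users)  # expect roughly 0.25 if experiments are independent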
Handling Traffic Allocation
def assign_with_traffic_split(self, user_id: str, experiment_id: str) -> str:
"""
Assign with partial traffic allocation
Example: 10% of users in experiment
- Hash to [0, 1]
- If < 0.10 → assign to variant
- Else → not in experiment
"""
config = self.experiments[experiment_id]
# Hash
hash_input = f"{user_id}:{experiment_id}".encode('utf-8')
hash_value = int(hashlib.md5(hash_input).hexdigest(), 16)
normalized = (hash_value % 10000) / 10000.0
# Traffic allocation check
if normalized >= config['traffic_allocation']:
return None
# Within traffic, assign to variant
# Scale normalized to [0, traffic_allocation] → [0, 1]
variant_hash = normalized / config['traffic_allocation']
variant_idx = int(variant_hash * config['num_variants'])
return config['variants'][variant_idx]
Component 2: Metrics Tracking
Track user actions and compute experiment metrics.
Event Logging
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
import json
@dataclass
class ExperimentEvent:
"""Single experiment event"""
user_id: str
experiment_id: str
variant: str
event_type: str # e.g., 'impression', 'click', 'purchase'
timestamp: datetime
    metadata: Optional[dict] = None
def to_dict(self):
return {
'user_id': self.user_id,
'experiment_id': self.experiment_id,
'variant': self.variant,
'event_type': self.event_type,
'timestamp': self.timestamp.isoformat(),
'metadata': self.metadata or {}
}
class EventLogger:
"""
Log experiment events
In production: Stream to Kafka/Kinesis → Data warehouse
"""
def __init__(self, output_file='experiment_events.jsonl'):
self.output_file = output_file
def log_event(self, event: ExperimentEvent):
"""
Log single event
In production: Send to message queue
"""
with open(self.output_file, 'a') as f:
f.write(json.dumps(event.to_dict()) + '\n')
def log_assignment(self, user_id: str, experiment_id: str, variant: str):
"""Log when user is assigned to variant"""
event = ExperimentEvent(
user_id=user_id,
experiment_id=experiment_id,
variant=variant,
event_type='assignment',
timestamp=datetime.now()
)
self.log_event(event)
def log_metric_event(
self,
user_id: str,
experiment_id: str,
variant: str,
metric_name: str,
metric_value: float
):
"""Log metric event (e.g., click, purchase)"""
event = ExperimentEvent(
user_id=user_id,
experiment_id=experiment_id,
variant=variant,
event_type=metric_name,
timestamp=datetime.now(),
metadata={'value': metric_value}
)
self.log_event(event)
# Usage
logger = EventLogger()
# Log assignment
logger.log_assignment('user_123', 'model_v2_test', 'treatment')
# Log click
logger.log_metric_event('user_123', 'model_v2_test', 'treatment', 'click', 1.0)
# Log purchase
logger.log_metric_event('user_123', 'model_v2_test', 'treatment', 'purchase', 49.99)
Metrics Aggregation
from collections import defaultdict

import numpy as np
import pandas as pd
class MetricsAggregator:
"""
Aggregate experiment metrics from events
Computes per-variant statistics
"""
def __init__(self):
self.variant_stats = defaultdict(lambda: defaultdict(list))
def add_event(self, variant: str, metric_name: str, value: float):
"""Add metric value for variant"""
self.variant_stats[variant][metric_name].append(value)
def compute_metrics(self, experiment_id: str) -> pd.DataFrame:
"""
Compute aggregated metrics per variant
Returns DataFrame with columns:
- variant
- metric
- count
- mean
- std
        - sum
        - min
        - max
"""
results = []
for variant, metrics in self.variant_stats.items():
for metric_name, values in metrics.items():
results.append({
'variant': variant,
'metric': metric_name,
'count': len(values),
'mean': np.mean(values),
'std': np.std(values),
'sum': np.sum(values),
'min': np.min(values),
'max': np.max(values)
})
return pd.DataFrame(results)
# Usage
aggregator = MetricsAggregator()
# Simulate events
aggregator.add_event('control', 'ctr', 0.05)
aggregator.add_event('control', 'ctr', 0.04)
aggregator.add_event('treatment', 'ctr', 0.06)
aggregator.add_event('treatment', 'ctr', 0.07)
metrics_df = aggregator.compute_metrics('model_v2_test')
print(metrics_df)
Component 3: Statistical Analysis
Determine if observed differences are statistically significant.
T-Test for Continuous Metrics
from typing import List

from scipy import stats
import numpy as np
class StatisticalAnalyzer:
"""
Perform statistical tests on experiment data
"""
def t_test(
self,
control_values: List[float],
treatment_values: List[float],
alpha: float = 0.05
) -> dict:
"""
Two-sample t-test
H0: mean(treatment) = mean(control)
H1: mean(treatment) ≠ mean(control)
Args:
control_values: Metric values from control group
treatment_values: Metric values from treatment group
alpha: Significance level (default 0.05)
Returns:
Dictionary with test results
"""
control = np.array(control_values)
treatment = np.array(treatment_values)
# Perform t-test
t_statistic, p_value = stats.ttest_ind(treatment, control)
# Compute effect size (Cohen's d)
pooled_std = np.sqrt(
((len(control) - 1) * np.var(control, ddof=1) +
(len(treatment) - 1) * np.var(treatment, ddof=1)) /
(len(control) + len(treatment) - 2)
)
cohens_d = (np.mean(treatment) - np.mean(control)) / pooled_std if pooled_std > 0 else 0
# Confidence interval for difference
se = pooled_std * np.sqrt(1/len(control) + 1/len(treatment))
df = len(control) + len(treatment) - 2
t_critical = stats.t.ppf(1 - alpha/2, df)
mean_diff = np.mean(treatment) - np.mean(control)
ci_lower = mean_diff - t_critical * se
ci_upper = mean_diff + t_critical * se
# Relative lift
relative_lift = (np.mean(treatment) / np.mean(control) - 1) * 100 if np.mean(control) > 0 else 0
return {
'control_mean': np.mean(control),
'treatment_mean': np.mean(treatment),
'absolute_diff': mean_diff,
'relative_lift_pct': relative_lift,
't_statistic': t_statistic,
'p_value': p_value,
'is_significant': p_value < alpha,
'confidence_interval': (ci_lower, ci_upper),
'cohens_d': cohens_d,
'sample_size_control': len(control),
'sample_size_treatment': len(treatment)
}
# Usage
analyzer = StatisticalAnalyzer()
# Simulate metric data (e.g., session duration in seconds)
control_sessions = np.random.normal(120, 30, size=1000) # mean=120s, std=30s
treatment_sessions = np.random.normal(125, 30, size=1000) # mean=125s (5s improvement)
result = analyzer.t_test(control_sessions, treatment_sessions)
print(f"Control mean: {result['control_mean']:.2f}")
print(f"Treatment mean: {result['treatment_mean']:.2f}")
print(f"Relative lift: {result['relative_lift_pct']:.2f}%")
print(f"P-value: {result['p_value']:.4f}")
print(f"Significant: {result['is_significant']}")
print(f"95% CI: [{result['confidence_interval'][0]:.2f}, {result['confidence_interval'][1]:.2f}]")
Chi-Square Test for Binary Metrics
def chi_square_test(
self,
control_successes: int,
control_total: int,
treatment_successes: int,
treatment_total: int,
alpha: float = 0.05
) -> dict:
"""
Chi-square test for proportions (e.g., CTR, conversion rate)
H0: p(treatment) = p(control)
H1: p(treatment) ≠ p(control)
"""
# Construct contingency table
contingency = np.array([
[treatment_successes, treatment_total - treatment_successes],
[control_successes, control_total - control_successes]
])
# Chi-square test
chi2, p_value, dof, expected = stats.chi2_contingency(contingency)
# Compute rates
control_rate = control_successes / control_total if control_total > 0 else 0
treatment_rate = treatment_successes / treatment_total if treatment_total > 0 else 0
# Relative lift
relative_lift = (treatment_rate / control_rate - 1) * 100 if control_rate > 0 else 0
# Confidence interval for difference in proportions
p1 = treatment_rate
p2 = control_rate
se = np.sqrt(p1*(1-p1)/treatment_total + p2*(1-p2)/control_total)
z_critical = stats.norm.ppf(1 - alpha/2)
diff = p1 - p2
ci_lower = diff - z_critical * se
ci_upper = diff + z_critical * se
return {
'control_rate': control_rate,
'treatment_rate': treatment_rate,
'absolute_diff': diff,
'relative_lift_pct': relative_lift,
'chi2_statistic': chi2,
'p_value': p_value,
'is_significant': p_value < alpha,
'confidence_interval': (ci_lower, ci_upper),
'sample_size_control': control_total,
'sample_size_treatment': treatment_total
}
# Attach the standalone function to StatisticalAnalyzer so analyzer instances can call it
StatisticalAnalyzer.chi_square_test = chi_square_test
# Usage
# Example: Click-through rate test
control_clicks = 450
control_impressions = 10000
treatment_clicks = 520
treatment_impressions = 10000
result = analyzer.chi_square_test(
control_clicks, control_impressions,
treatment_clicks, treatment_impressions
)
print(f"Control CTR: {result['control_rate']*100:.2f}%")
print(f"Treatment CTR: {result['treatment_rate']*100:.2f}%")
print(f"Relative lift: {result['relative_lift_pct']:.2f}%")
print(f"P-value: {result['p_value']:.4f}")
print(f"Significant: {result['is_significant']}")
Sample Size Calculation & Power Analysis
Determine required sample size before running experiment.
Power Analysis
from scipy.stats import norm
class PowerAnalysis:
"""
Calculate required sample size for experiments
"""
def sample_size_for_proportions(
self,
baseline_rate: float,
mde: float, # Minimum Detectable Effect
alpha: float = 0.05,
power: float = 0.80
) -> int:
"""
Calculate sample size needed to detect effect on proportion
Args:
baseline_rate: Current conversion rate (e.g., 0.05 for 5%)
mde: Minimum relative effect to detect (e.g., 0.10 for 10% improvement)
alpha: Significance level (Type I error rate)
power: Statistical power (1 - Type II error rate)
Returns:
Required sample size per variant
"""
# Target rate after improvement
target_rate = baseline_rate * (1 + mde)
# Z-scores
z_alpha = norm.ppf(1 - alpha/2) # Two-tailed
z_beta = norm.ppf(power)
# Pooled proportion under H0
p_avg = (baseline_rate + target_rate) / 2
# Sample size formula
numerator = (z_alpha * np.sqrt(2 * p_avg * (1 - p_avg)) +
z_beta * np.sqrt(baseline_rate * (1 - baseline_rate) +
target_rate * (1 - target_rate))) ** 2
denominator = (target_rate - baseline_rate) ** 2
n = numerator / denominator
return int(np.ceil(n))
def sample_size_for_means(
self,
baseline_mean: float,
baseline_std: float,
mde: float,
alpha: float = 0.05,
power: float = 0.80
) -> int:
"""
Calculate sample size for continuous metric
Args:
baseline_mean: Current mean value
baseline_std: Standard deviation
mde: Minimum relative effect (e.g., 0.05 for 5% improvement)
alpha: Significance level
power: Statistical power
Returns:
Required sample size per variant
"""
target_mean = baseline_mean * (1 + mde)
effect_size = abs(target_mean - baseline_mean) / baseline_std
z_alpha = norm.ppf(1 - alpha/2)
z_beta = norm.ppf(power)
n = 2 * ((z_alpha + z_beta) / effect_size) ** 2
return int(np.ceil(n))
def experiment_duration(
self,
required_sample_size: int,
daily_users: int,
traffic_allocation: float = 0.5
) -> int:
"""
Calculate experiment duration in days
Args:
required_sample_size: Sample size per variant
daily_users: Daily active users
traffic_allocation: Fraction of users in experiment
Returns:
Duration in days
"""
users_per_day = daily_users * traffic_allocation
days = required_sample_size / users_per_day
return int(np.ceil(days))
# Usage
power = PowerAnalysis()
# Example: CTR improvement test
current_ctr = 0.05 # 5% baseline
mde = 0.10 # Want to detect 10% relative improvement (5% → 5.5%)
sample_size = power.sample_size_for_proportions(
baseline_rate=current_ctr,
mde=mde,
alpha=0.05,
power=0.80
)
print(f"Required sample size per variant: {sample_size:,}")
# If we have 100K daily users and allocate 50% to experiment
duration = power.experiment_duration(
required_sample_size=sample_size,
daily_users=100000,
traffic_allocation=0.5
)
print(f"Experiment duration: {duration} days")
Guardrail Metrics
Ensure experiments don’t harm key business metrics.
Implementing Guardrails
class GuardrailChecker:
"""
Monitor guardrail metrics during experiments
Guardrails: Metrics that must not degrade
"""
def __init__(self):
self.guardrails = {}
def define_guardrail(
self,
metric_name: str,
threshold_type: str, # 'relative' or 'absolute'
threshold_value: float,
direction: str # 'decrease' or 'increase'
):
"""
Define a guardrail metric
Example:
- Revenue must not decrease by more than 2%
- Error rate must not increase by more than 0.5 percentage points
"""
self.guardrails[metric_name] = {
'threshold_type': threshold_type,
'threshold_value': threshold_value,
'direction': direction
}
def check_guardrails(
self,
control_metrics: dict,
treatment_metrics: dict
) -> dict:
"""
Check if treatment violates guardrails
Returns:
Dictionary of guardrail violations
"""
violations = {}
for metric_name, guardrail in self.guardrails.items():
control_value = control_metrics.get(metric_name)
treatment_value = treatment_metrics.get(metric_name)
if control_value is None or treatment_value is None:
continue
# Calculate change
if guardrail['threshold_type'] == 'relative':
change = (treatment_value / control_value - 1) * 100
else: # absolute
change = treatment_value - control_value
# Check violation
violated = False
if guardrail['direction'] == 'decrease':
# Metric should not decrease beyond threshold
if change < -guardrail['threshold_value']:
violated = True
else: # increase
# Metric should not increase beyond threshold
if change > guardrail['threshold_value']:
violated = True
if violated:
violations[metric_name] = {
'control': control_value,
'treatment': treatment_value,
'change': change,
'threshold': guardrail['threshold_value'],
'type': guardrail['threshold_type']
}
return violations
# Usage
guardrails = GuardrailChecker()
# Define guardrails
guardrails.define_guardrail(
metric_name='revenue_per_user',
threshold_type='relative',
threshold_value=2.0, # Cannot decrease by more than 2%
direction='decrease'
)
guardrails.define_guardrail(
metric_name='error_rate',
threshold_type='absolute',
threshold_value=0.5, # Cannot increase by more than 0.5 percentage points
direction='increase'
)
# Check guardrails
control_metrics = {
'revenue_per_user': 10.0,
'error_rate': 1.0
}
treatment_metrics = {
'revenue_per_user': 9.5, # 5% decrease - violates guardrail!
'error_rate': 1.2 # 0.2pp increase - OK
}
violations = guardrails.check_guardrails(control_metrics, treatment_metrics)
if violations:
print("⚠️ Guardrail violations detected:")
for metric, details in violations.items():
print(f" {metric}: {details['change']:.2f}% change (threshold: {details['threshold']}%)")
else:
print("✅ All guardrails passed")
Real-World Examples
Netflix: Experimentation at Scale
Scale:
- 1000+ experiments running concurrently
- 200M+ users worldwide
- Multiple metrics per experiment
Key innovations:
- Quasi-experimentation: Use observational data when randomization not possible
- Interleaving: Test ranking algorithms by mixing results from two rankers in one list (see the sketch below)
- Heterogeneous treatment effects: Analyze impact per user segment
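For context, here is a sketch of one common interleaving scheme (a simplified team-draft interleave; not necessarily Netflix's exact implementation). Clicks on items credited to each ranker indicate which one users prefer.
import random

def team_draft_interleave(ranking_a, ranking_b, k=10):
    """Merge two rankings into one list, remembering which ranker contributed
    each item so that clicks can be credited to ranker A or B."""
    interleaved, attribution, used = [], {}, set()
    all_docs = set(ranking_a) | set(ranking_b)
    while len(interleaved) < k and len(used) < len(all_docs):
        # Randomize which ranker picks first this round (simplified team draft)
        for team, ranking in random.sample([('A', ranking_a), ('B', ranking_b)], 2):
            pick = next((doc for doc in ranking if doc not in used), None)
            if pick is not None:
                interleaved.append(pick)
                attribution[pick] = team
                used.add(pick)
            if len(interleaved) >= k:
                break
    return interleaved, attribution

mixed, credit = team_draft_interleave(['m1', 'm2', 'm3'], ['m3', 'm4', 'm5'], k=4)
print(mixed, credit)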
Example metrics:
- Stream starts per member: How many shows/movies a user starts watching
- Effective catalog size: Number of unique titles watched (diversity metric)
Google: Large-scale Testing
Scale:
- 10,000+ experiments per year
- 1B+ users
- Experiments across Search, Ads, YouTube, etc.
Methodology:
- Layered experiments: Run multiple experiments on same users (orthogonal layers)
- Ramping: Gradually increase traffic allocation
- Long-running holdouts: Keep small % in old version to measure long-term effects
Example: Testing new ranking algorithm in Google Search:
- Primary metric: Click-through rate on top results
- Guardrails: Ad revenue, latency, user satisfaction
- Duration: 2-4 weeks
- Traffic: Start at 1%, ramp to 50%
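The ramp itself can be expressed as a simple allocation schedule gated on guardrail checks. Below is a minimal sketch; the schedule values and helper are hypothetical, reusing the violations dict produced by the GuardrailChecker described later in this post.
# Hypothetical ramp schedule: fraction of users exposed to the treatment over time
RAMP_SCHEDULE = [0.01, 0.05, 0.10, 0.25, 0.50]

def next_allocation(current: float, guardrail_violations: dict) -> float:
    """Advance one ramp step if no guardrails are violated; otherwise roll back to 0."""
    if guardrail_violations:
        return 0.0  # roll back and investigate
    larger_steps = [a for a in RAMP_SCHEDULE if a > current]
    return larger_steps[0] if larger_steps else current  # hold at the final step

print(next_allocation(0.01, {}))                          # -> 0.05
print(next_allocation(0.10, {'error_rate': 'violated'}))  # -> 0.0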
Advanced Topics
Sequential Testing & Early Stopping
Stop experiments early when results are conclusive.
import math
from scipy.stats import norm
class SequentialTesting:
"""
Sequential probability ratio test (SPRT)
Allows stopping experiment early while controlling error rates
"""
def __init__(
self,
alpha=0.05,
beta=0.20,
mde=0.05 # Minimum detectable effect
):
self.alpha = alpha # Type I error rate
self.beta = beta # Type II error rate (1 - power)
self.mde = mde
# Calculate log-likelihood ratio bounds
self.upper_bound = math.log((1 - beta) / alpha)
self.lower_bound = math.log(beta / (1 - alpha))
def should_stop(
self,
control_successes: int,
control_total: int,
treatment_successes: int,
treatment_total: int
) -> dict:
"""
Check if experiment can be stopped
Returns:
{
'decision': 'continue' | 'stop_treatment_wins' | 'stop_control_wins',
'log_likelihood_ratio': float
}
"""
# Compute rates
p_control = control_successes / control_total if control_total > 0 else 0
p_treatment = treatment_successes / treatment_total if treatment_total > 0 else 0
# Avoid edge cases
p_control = max(min(p_control, 0.9999), 0.0001)
p_treatment = max(min(p_treatment, 0.9999), 0.0001)
        # Log-likelihood ratio (simplified SPRT: the observed control rate is
        # used as a plug-in estimate for the rate under H0)
        # H0: treatment rate = control rate
        # H1: treatment rate = control rate * (1 + mde)
        p_h1 = p_control * (1 + self.mde)
llr = 0
# Contribution from treatment group
llr += treatment_successes * math.log(p_h1 / p_control)
llr += (treatment_total - treatment_successes) * math.log((1 - p_h1) / (1 - p_control))
# Decision
if llr >= self.upper_bound:
return {'decision': 'stop_treatment_wins', 'log_likelihood_ratio': llr}
elif llr <= self.lower_bound:
return {'decision': 'stop_control_wins', 'log_likelihood_ratio': llr}
else:
return {'decision': 'continue', 'log_likelihood_ratio': llr}
# Usage
sequential = SequentialTesting(alpha=0.05, beta=0.20, mde=0.05)
# Check daily
for day in range(1, 15):
control_clicks = day * 450
control_impressions = day * 10000
treatment_clicks = day * 500
treatment_impressions = day * 10000
result = sequential.should_stop(
control_clicks, control_impressions,
treatment_clicks, treatment_impressions
)
    print(f"Day {day}: {result['decision']}")
if result['decision'] != 'continue':
print(f"🎉 Experiment can stop on day {day}!")
break
Multi-Armed Bandits
Allocate traffic dynamically to better-performing variants.
import numpy as np
class ThompsonSampling:
"""
Thompson Sampling for multi-armed bandit
Dynamically allocate traffic to maximize reward
while exploring alternatives
"""
def __init__(self, num_variants):
self.num_variants = num_variants
# Beta distribution parameters for each variant
# Beta(alpha, beta) represents posterior belief
self.alpha = np.ones(num_variants) # Success count + 1
self.beta = np.ones(num_variants) # Failure count + 1
def select_variant(self) -> int:
"""
Select variant to show to next user
Sample from each variant's posterior and pick the best
"""
# Sample from each variant's posterior distribution
sampled_values = [
np.random.beta(self.alpha[i], self.beta[i])
for i in range(self.num_variants)
]
# Select variant with highest sample
return np.argmax(sampled_values)
def update(self, variant: int, reward: float):
"""
Update beliefs after observing reward
Args:
variant: Which variant was shown
reward: 0 or 1 (failure or success)
"""
if reward > 0:
self.alpha[variant] += 1
else:
self.beta[variant] += 1
def get_statistics(self):
"""Get current statistics for each variant"""
stats = []
for i in range(self.num_variants):
# Mean of Beta(alpha, beta) = alpha / (alpha + beta)
mean = self.alpha[i] / (self.alpha[i] + self.beta[i])
# 95% credible interval
samples = np.random.beta(self.alpha[i], self.beta[i], size=10000)
ci_lower, ci_upper = np.percentile(samples, [2.5, 97.5])
stats.append({
'variant': i,
'estimated_mean': mean,
'total_samples': self.alpha[i] + self.beta[i] - 2,
'successes': self.alpha[i] - 1,
'credible_interval': (ci_lower, ci_upper)
})
return stats
# Usage
bandit = ThompsonSampling(num_variants=3)
# Simulate 10,000 users
for user in range(10000):
# Select variant to show
variant = bandit.select_variant()
# Simulate user interaction (variant 2 is best: 6% CTR)
true_ctrs = [0.04, 0.05, 0.06]
clicked = np.random.random() < true_ctrs[variant]
# Update beliefs
bandit.update(variant, 1.0 if clicked else 0.0)
# Check statistics
stats = bandit.get_statistics()
for s in stats:
print(f"Variant {s['variant']}: "
f"Estimated CTR = {s['estimated_mean']:.3f}, "
f"Samples = {s['total_samples']}, "
f"95% CI = [{s['credible_interval'][0]:.3f}, {s['credible_interval'][1]:.3f}]")
Variance Reduction: CUPED
Reduce variance by using pre-experiment covariates.
class CUPED:
"""
Controlled-experiment Using Pre-Experiment Data
Reduces variance by adjusting for pre-experiment metrics
"""
def __init__(self):
pass
def adjust_metric(
self,
y: np.ndarray, # Post-experiment metric
x: np.ndarray, # Pre-experiment metric (covariate)
) -> np.ndarray:
"""
Adjust post-experiment metric using pre-experiment data
Adjusted metric: y_adj = y - theta * (x - E[x])
Where theta is chosen to minimize variance of y_adj
"""
# Compute optimal theta
# theta = Cov(y, x) / Var(x)
mean_x = np.mean(x)
mean_y = np.mean(y)
cov_yx = np.mean((y - mean_y) * (x - mean_x))
var_x = np.var(x, ddof=1)
if var_x == 0:
return y # No adjustment possible
theta = cov_yx / var_x
# Adjust y
y_adjusted = y - theta * (x - mean_x)
return y_adjusted
def compare_variants_with_cuped(
self,
control_post: np.ndarray,
control_pre: np.ndarray,
treatment_post: np.ndarray,
treatment_pre: np.ndarray
) -> dict:
"""
Compare variants using CUPED
Returns improvement in statistical power
"""
# Original comparison (without CUPED)
from scipy import stats
original_t, original_p = stats.ttest_ind(treatment_post, control_post)
original_var = np.var(treatment_post) + np.var(control_post)
# Adjust metrics
all_pre = np.concatenate([control_pre, treatment_pre])
all_post = np.concatenate([control_post, treatment_post])
adjusted_post = self.adjust_metric(all_post, all_pre)
# Split back
n_control = len(control_post)
control_post_adj = adjusted_post[:n_control]
treatment_post_adj = adjusted_post[n_control:]
# Adjusted comparison
adjusted_t, adjusted_p = stats.ttest_ind(treatment_post_adj, control_post_adj)
adjusted_var = np.var(treatment_post_adj) + np.var(control_post_adj)
# Variance reduction
variance_reduction = (original_var - adjusted_var) / original_var * 100
return {
'original_p_value': original_p,
'adjusted_p_value': adjusted_p,
'variance_reduction_pct': variance_reduction,
'power_improvement': (original_var / adjusted_var) ** 0.5
}
# Example: Using pre-experiment purchase history to reduce variance
control_pre = np.random.normal(100, 30, size=500) # Past purchases
control_post = control_pre + np.random.normal(5, 20, size=500) # Correlated
treatment_pre = np.random.normal(100, 30, size=500)
treatment_post = treatment_pre + np.random.normal(8, 20, size=500) # Slightly better
cuped = CUPED()
result = cuped.compare_variants_with_cuped(
control_post, control_pre,
treatment_post, treatment_pre
)
print(f"Original p-value: {result['original_p_value']:.4f}")
print(f"Adjusted p-value: {result['adjusted_p_value']:.4f}")
print(f"Variance reduction: {result['variance_reduction_pct']:.1f}%")
print(f"Power improvement: {result['power_improvement']:.2f}x")
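Because the required sample size scales roughly linearly with metric variance (for a fixed minimum detectable effect, alpha, and power), any variance reduction from CUPED translates directly into smaller or shorter experiments. A back-of-the-envelope sketch:
# Rough rule of thumb: n is proportional to variance for a fixed effect size,
# so a 40% variance reduction cuts the required sample size by roughly 40%.
variance_reduction_pct = 40.0
original_n_per_variant = 31_000   # e.g., from a power analysis like the one above
adjusted_n = int(original_n_per_variant * (1 - variance_reduction_pct / 100))
print(f"Required per variant: {original_n_per_variant:,} -> {adjusted_n:,}")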
Stratified Sampling
Ensure balance across important user segments.
class StratifiedAssignment:
"""
Assign users to experiments with stratification
    Hash-based assignment gives approximately balanced variants within each stratum (e.g., country, platform); the balance report below verifies this
"""
def __init__(self, num_variants=2):
self.num_variants = num_variants
self.strata_counters = {} # stratum → variant counts
def assign_variant(self, user_id: str, stratum: str) -> int:
"""
Assign user to variant, ensuring balance within stratum
Args:
user_id: User identifier
stratum: Stratum key (e.g., "US_iOS", "UK_Android")
Returns:
Variant index
"""
# Initialize stratum if new
if stratum not in self.strata_counters:
self.strata_counters[stratum] = [0] * self.num_variants
# Hash-based assignment (deterministic)
import hashlib
hash_input = f"{user_id}:{stratum}".encode('utf-8')
hash_value = int(hashlib.md5(hash_input).hexdigest(), 16)
variant = hash_value % self.num_variants
# Update counter
self.strata_counters[stratum][variant] += 1
return variant
def get_balance_report(self) -> dict:
"""Check balance within each stratum"""
report = {}
for stratum, counts in self.strata_counters.items():
total = sum(counts)
proportions = [c / total for c in counts]
# Check if balanced (each variant should have ~1/num_variants)
expected = 1 / self.num_variants
max_deviation = max(abs(p - expected) for p in proportions)
report[stratum] = {
'counts': counts,
'proportions': proportions,
'max_deviation': max_deviation,
'balanced': max_deviation < 0.05 # Within 5% of expected
}
return report
# Usage
stratified = StratifiedAssignment(num_variants=2)
# Simulate user assignments
for i in range(10000):
user_id = f"user_{i}"
# Assign stratum based on user
if i % 3 == 0:
stratum = "US_iOS"
elif i % 3 == 1:
stratum = "US_Android"
else:
stratum = "UK_iOS"
variant = stratified.assign_variant(user_id, stratum)
# Check balance
balance = stratified.get_balance_report()
for stratum, stats in balance.items():
print(f"{stratum}: {stats['counts']}, balanced={stats['balanced']}")
Multiple Testing Correction
When running many experiments, control family-wise error rate.
Bonferroni Correction
from typing import List

def bonferroni_correction(p_values: List[float], alpha: float = 0.05) -> List[bool]:
"""
Bonferroni correction for multiple comparisons
Adjusted alpha = alpha / num_tests
Args:
p_values: List of p-values from multiple tests
alpha: Family-wise error rate
Returns:
List of booleans (True = significant after correction)
"""
num_tests = len(p_values)
adjusted_alpha = alpha / num_tests
return [p < adjusted_alpha for p in p_values]
# Example: Testing 10 variants
p_values = [0.04, 0.06, 0.03, 0.08, 0.02, 0.09, 0.07, 0.05, 0.01, 0.10]
significant_uncorrected = [p < 0.05 for p in p_values]
significant_corrected = bonferroni_correction(p_values, alpha=0.05)
print(f"Significant (uncorrected): {sum(significant_uncorrected)} / {len(p_values)}")
print(f"Significant (Bonferroni): {sum(significant_corrected)} / {len(p_values)}")
False Discovery Rate (FDR) - Benjamini-Hochberg
def benjamini_hochberg(p_values: List[float], alpha: float = 0.05) -> List[bool]:
"""
Benjamini-Hochberg procedure for FDR control
Less conservative than Bonferroni
Args:
p_values: List of p-values
alpha: Desired FDR level
Returns:
List of booleans (True = significant)
"""
num_tests = len(p_values)
# Sort p-values with original indices
indexed_p_values = [(p, i) for i, p in enumerate(p_values)]
indexed_p_values.sort()
# Find largest k such that p[k] <= (k+1)/m * alpha
significant_indices = set()
for k in range(num_tests - 1, -1, -1):
p_value, original_idx = indexed_p_values[k]
threshold = (k + 1) / num_tests * alpha
if p_value <= threshold:
# Mark this and all smaller p-values as significant
for j in range(k + 1):
significant_indices.add(indexed_p_values[j][1])
break
# Create result list
return [i in significant_indices for i in range(num_tests)]
# Compare to Bonferroni
fdr_significant = benjamini_hochberg(p_values, alpha=0.05)
print(f"Significant (FDR): {sum(fdr_significant)} / {len(p_values)}")
Layered Experiments
Run multiple experiments simultaneously on orthogonal layers.
class ExperimentLayer:
"""
Single experiment layer
"""
def __init__(self, layer_id: str, experiments: List[str]):
self.layer_id = layer_id
self.experiments = experiments
self.num_experiments = len(experiments)
def assign_experiment(self, user_id: str) -> str:
"""Assign user to one experiment in this layer"""
import hashlib
hash_input = f"{user_id}:{self.layer_id}".encode('utf-8')
hash_value = int(hashlib.md5(hash_input).hexdigest(), 16)
experiment_idx = hash_value % self.num_experiments
return self.experiments[experiment_idx]
class LayeredExperimentPlatform:
"""
Platform supporting layered experiments
Layers should be independent (orthogonal)
"""
def __init__(self):
self.layers = {}
def add_layer(self, layer_id: str, experiments: List[str]):
"""Add experiment layer"""
self.layers[layer_id] = ExperimentLayer(layer_id, experiments)
def assign_user(self, user_id: str) -> dict:
"""
Assign user to experiments across all layers
Returns:
Dict mapping layer_id → experiment_id
"""
assignments = {}
for layer_id, layer in self.layers.items():
experiment = layer.assign_experiment(user_id)
assignments[layer_id] = experiment
return assignments
# Usage
platform = LayeredExperimentPlatform()
# Layer 1: Ranking algorithm tests
platform.add_layer(
'ranking',
['ranking_baseline', 'ranking_ml_v1', 'ranking_ml_v2']
)
# Layer 2: UI tests (independent of ranking)
platform.add_layer(
'ui',
['ui_old', 'ui_new_blue', 'ui_new_green']
)
# Layer 3: Recommendation tests
platform.add_layer(
'recommendations',
['recs_baseline', 'recs_personalized']
)
# Assign user to experiments
user_experiments = platform.assign_user('user_12345')
print(f"User assigned to:")
for layer, experiment in user_experiments.items():
print(f" {layer}: {experiment}")
# User gets combination like:
# ranking: ranking_ml_v2
# ui: ui_new_blue
# recommendations: recs_personalized
Airbnb’s Experiment Framework
Real-world example of production experimentation.
Key Components:
- ERF (Experiment Reporting Framework)
  - Centralized metric definitions
  - Automated metric computation
  - Standardized reporting
- CUPED for Variance Reduction
  - Uses pre-experiment booking history
  - 50%+ variance reduction on key metrics
  - Dramatically reduces required sample size
- Quasi-experiments
  - When randomization is not possible (e.g., pricing tests)
  - Difference-in-differences analysis
  - Synthetic control methods
- Interference Handling
  - Network effects (one user's treatment affects others)
  - Cluster randomization (randomize at city/market level; sketched after this list)
  - Ego-cluster randomization
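Cluster randomization can reuse the same hashing trick as user assignment: hash the cluster key (e.g., a market or city) instead of the user ID, so users who interact with each other always share a treatment. A minimal sketch, with an illustrative market_id input:
import hashlib

def assign_cluster_variant(market_id: str, experiment_id: str,
                           variants=('control', 'treatment')) -> str:
    """Assign an entire market/city to one variant so that interacting users
    share the same treatment (reduces interference bias)."""
    hash_input = f"{market_id}:{experiment_id}".encode('utf-8')
    hash_value = int(hashlib.md5(hash_input).hexdigest(), 16)
    return variants[hash_value % len(variants)]

# Every user in the same market sees the same variant
print(assign_cluster_variant('paris', 'pricing_test_v1'))
print(assign_cluster_variant('tokyo', 'pricing_test_v1'))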
Metrics Hierarchy:
Primary Metrics (move-the-needle)
├── Bookings
├── Revenue
└── Guest Satisfaction
Secondary Metrics (understand mechanism)
├── Search engagement
├── Listing views
└── Message rate
Guardrail Metrics (protect)
├── Host satisfaction
├── Cancellation rate
└── Customer support tickets
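In a platform, this hierarchy typically lives next to the experiment definition so that reporting and ship/no-ship decisions can be automated. A hedged sketch of how the tiers might be encoded and used (the structure and decision rule below are illustrative):
# Illustrative encoding of the metric hierarchy plus a simple ship rule
METRIC_TIERS = {
    'primary': ['bookings', 'revenue', 'guest_satisfaction'],
    'secondary': ['search_engagement', 'listing_views', 'message_rate'],
    'guardrail': ['host_satisfaction', 'cancellation_rate', 'support_tickets'],
}

def ship_decision(significant_wins: set, guardrail_violations: set) -> str:
    """Ship only if at least one primary metric wins and no guardrail is violated."""
    if guardrail_violations & set(METRIC_TIERS['guardrail']):
        return 'do_not_ship (guardrail violated)'
    if significant_wins & set(METRIC_TIERS['primary']):
        return 'ship'
    return 'inconclusive (extend the experiment or iterate)'

print(ship_decision({'bookings'}, set()))                  # ship
print(ship_decision({'bookings'}, {'cancellation_rate'}))  # do_not_ship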
Key Takeaways
✅ Randomization via consistent hashing ensures unbiased assignment
✅ Statistical rigor prevents false positives: require p < 0.05 and sufficient power
✅ Sample size calculation upfront prevents underpowered experiments
✅ Guardrail metrics protect against shipping harmful changes
✅ Real-time monitoring enables early stopping for clear wins/losses
✅ Sequential testing allows stopping early while controlling error rates
✅ Multi-armed bandits dynamically optimize traffic allocation
✅ CUPED reduces variance using pre-experiment data → smaller samples needed
✅ Stratified sampling ensures balance across key user segments
✅ Multiple testing corrections control error rates when running many experiments
✅ Layered experiments increase experimentation velocity without conflicts
✅ Long-term holdouts measure sustained impact vs. novelty effects
FAQ
How does consistent hashing ensure reliable A/B test assignment?
Hashing the concatenation of user ID and experiment ID produces a deterministic, uniformly distributed value. The same user always gets the same variant, and different experiments produce independent assignments.
What is CUPED and how does it help A/B testing?
CUPED (Controlled-experiment Using Pre-Experiment Data) adjusts post-experiment metrics by subtracting the correlation with pre-experiment behavior. This can reduce variance by 50 percent or more, allowing experiments to reach significance with smaller samples or shorter durations.
When should you use sequential testing instead of fixed-horizon tests?
Sequential testing allows checking results continuously and stopping early when results are conclusive, while still controlling error rates. Use it when you want faster decisions, but be aware it requires more total samples if the effect is borderline.
What are guardrail metrics in experimentation?
Guardrail metrics are key business metrics that must not degrade during an experiment, such as revenue, error rate, or latency. If a treatment improves the primary metric but violates a guardrail, the experiment should not ship.
Originally published at: arunbaby.com/ml-system-design/0004-ab-testing-systems
Want to work together?
I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.
Get in touch