
How voice assistants recognize “turn on the lights” from raw audio in under 100ms without full ASR transcription.

Introduction

When you say “Alexa, turn off the lights” or “Hey Google, set a timer,” your voice assistant doesn’t actually transcribe your speech to text first. Instead, it uses a direct audio-to-intent classification system that’s:

  • Faster than ASR + NLU (50-100ms vs 200-500ms)
  • Smaller (< 10MB models vs 100MB+)
  • Able to run offline (on-device inference)
  • More privacy-preserving (no text sent to the cloud)

This approach is a perfect fit for a limited command vocabulary (30-100 commands) where speed and privacy matter more than open-ended understanding.

What you’ll learn:

  • Why direct audio→intent beats ASR→NLU for commands
  • Audio feature extraction (MFCCs, mel-spectrograms)
  • Model architectures (CNN, RNN, Attention)
  • Training strategies and data augmentation
  • On-device deployment and optimization
  • Unknown command handling (OOD detection)
  • Real-world examples from Google, Amazon, Apple

Problem Definition

Design a speech command classification system for a voice assistant that:

Functional Requirements

  1. Multi-class Classification
    • 30-50 predefined commands
    • Examples: “lights on”, “volume up”, “play music”, “stop timer”
    • Support synonyms and variations
  2. Unknown Detection
    • Detect and reject out-of-vocabulary audio
    • Handle background conversation
    • Distinguish commands from non-commands
  3. Multi-language Support
    • 5+ languages initially
    • Shared model or separate models per language
  4. Context Awareness
    • Optional: Use device state as context
    • Example: “turn it off” depends on what’s currently on

Non-Functional Requirements

  1. Latency
    • End-to-end < 100ms
    • Includes audio buffering, processing, inference
  2. Model Constraints
    • Model size < 10MB (on-device)
    • RAM usage < 50MB during inference
    • CPU-only (no GPU on most devices)
  3. Accuracy
    • 95% on target commands (clean audio)
    • 90% on noisy audio
    • < 5% false positive rate
  4. Throughput
    • 1000 QPS per server (cloud)
    • Single inference on device

Why Not ASR + NLU?

Traditional Pipeline

Audio → ASR → Text → NLU → Intent
"lights on" → ASR (200ms) → "lights on" → NLU (50ms) → {action: "lights", state: "on"}
Total latency: 250ms

Direct Classification

Audio → Audio Features → CNN → Intent
"lights on" → Mel-spec (5ms) → CNN (40ms) → {action: "lights", state: "on"}
Total latency: 45ms

Advantages:

  • ✅ 5x faster (45ms vs 250ms)
  • ✅ 10x smaller model (5MB vs 50MB)
  • ✅ Works offline
  • ✅ More private (no text)
  • ✅ Fewer points of failure

Disadvantages:

  • ❌ Limited vocabulary (30-50 commands vs unlimited)
  • ❌ Less flexible (new commands need retraining)
  • ❌ Can’t handle complex queries (“turn on the lights in the living room at 8pm”)

When to use each:

  • Direct classification: Simple commands, latency-critical, on-device
  • ASR + NLU: Complex queries, unlimited vocabulary, cloud-based

Architecture

Audio Input (1-2 seconds @ 16kHz)
    ↓
Audio Preprocessing
    ├─ Resampling (if needed)
    ├─ Padding/Trimming to fixed length
    └─ Normalization
    ↓
Feature Extraction
    ├─ MFCCs (40 coefficients)
    or
    ├─ Mel-Spectrogram (40 bins)
    ↓
Neural Network
    ├─ CNN (fastest, on-device)
    or
    ├─ RNN (better temporal modeling)
    or
    ├─ Attention (best accuracy, slower)
    ↓
Softmax Layer (31 classes)
    ├─ 30 command classes
    └─ 1 unknown class
    ↓
Post-processing
    ├─ Confidence thresholding
    ├─ Unknown detection
    └─ Output filtering
    ↓
Prediction: {command: "lights_on", confidence: 0.94}

Component 1: Audio Preprocessing

Fixed-Length Input

Problem: Audio clips have variable duration (0.5s - 3s)

Solution: Standardize to fixed length (e.g., 1 second)

import numpy as np

def preprocess_audio(audio: np.ndarray, sr=16000, target_duration=1.0):
    """
    Ensure all audio clips are same length
    
    Args:
        audio: Audio waveform
        sr: Sample rate
        target_duration: Target duration in seconds
    
    Returns:
        Processed audio of length sr * target_duration
    """
    target_length = int(sr * target_duration)
    
    # Pad if too short
    if len(audio) < target_length:
        pad_length = target_length - len(audio)
        audio = np.pad(audio, (0, pad_length), mode='constant')
    
    # Trim if too long
    elif len(audio) > target_length:
        # Take central portion
        start = (len(audio) - target_length) // 2
        audio = audio[start:start + target_length]
    
    return audio

Why fixed length?

  • Neural networks expect fixed-size inputs
  • Enables batching during training
  • Simplifies model architecture

Alternative: Variable-length with padding

def pad_sequence(audios: list, sr=16000):
    """
    Pad multiple audio clips to longest length
    Used during batched inference
    """
    max_length = max(len(a) for a in audios)
    
    padded = []
    masks = []
    
    for audio in audios:
        pad_length = max_length - len(audio)
        padded_audio = np.pad(audio, (0, pad_length))
        mask = np.ones(len(audio)).tolist() + [0] * pad_length
        
        padded.append(padded_audio)
        masks.append(mask)
    
    return np.array(padded), np.array(masks)

Normalization

def normalize_audio(audio: np.ndarray) -> np.ndarray:
    """
    Normalize audio to [-1, 1] range
    
    Improves model convergence and generalization
    """
    # Peak normalization
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = audio / max_val
    
    return audio


def normalize_rms(audio: np.ndarray, target_rms=0.1) -> np.ndarray:
    """
    Normalize by RMS (root mean square) energy
    
    Better for handling volume variations
    """
    current_rms = np.sqrt(np.mean(audio ** 2))
    if current_rms > 0:
        audio = audio * (target_rms / current_rms)
    
    return audio

Component 2: Feature Extraction

Option 1: MFCCs (Mel-Frequency Cepstral Coefficients)

MFCCs capture the spectral envelope of speech, which is important for phonetic content.

import librosa

def extract_mfcc(audio, sr=16000, n_mfcc=40, n_fft=512, hop_length=160):
    """
    Extract MFCC features
    
    Args:
        audio: Waveform
        sr: Sample rate (Hz)
        n_mfcc: Number of MFCC coefficients
        n_fft: FFT window size
        hop_length: Hop length between frames (10ms at 16kHz)
    
    Returns:
        MFCCs: (n_mfcc, time_steps)
    """
    # Compute MFCCs
    mfccs = librosa.feature.mfcc(
        y=audio,
        sr=sr,
        n_mfcc=n_mfcc,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=40,          # Number of mel bands
        fmin=20,            # Minimum frequency
        fmax=sr//2          # Maximum frequency (Nyquist)
    )
    
    # Add delta (velocity) and delta-delta (acceleration)
    delta = librosa.feature.delta(mfccs)
    delta2 = librosa.feature.delta(mfccs, order=2)
    
    # Stack all features
    features = np.vstack([mfccs, delta, delta2])  # (120, time)
    
    return features.T  # (time, 120)

Why delta features?

  • MFCCs: Spectral shape (what phonemes)
  • Delta: How spectral shape is changing (dynamics)
  • Delta-delta: Rate of change (acceleration)

Together they capture both static and dynamic characteristics of speech.

Option 2: Mel-Spectrogram

Mel-spectrograms preserve more temporal resolution than MFCCs.

def extract_mel_spectrogram(audio, sr=16000, n_mels=40, n_fft=512, hop_length=160):
    """
    Extract log mel-spectrogram
    
    Returns:
        Log mel-spectrogram: (time, n_mels)
    """
    # Compute mel spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        fmin=20,
        fmax=sr//2
    )
    
    # Convert to log scale (dB)
    log_mel = librosa.power_to_db(mel_spec, ref=np.max)
    
    return log_mel.T  # (time, n_mels)

MFCCs vs Mel-Spectrogram:

| Feature | MFCCs | Mel-Spectrogram |
| --- | --- | --- |
| Size | (time, 13-40) | (time, 40-80) |
| Information | Spectral envelope | Full spectrum |
| Works better with | Small models | CNNs (image-like) |
| Training time | Faster | Slower |
| Accuracy | Slightly lower | Slightly higher |

Recommendation: Use mel-spectrograms with CNNs for best accuracy.
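
Before wiring features into a model, it's worth verifying shapes end-to-end. A quick sanity check using the helpers above (with hop_length=160 and librosa's default centered frames, one second of 16kHz audio yields 101 frames):

import numpy as np
import torch

audio = np.random.randn(12000).astype(np.float32)   # 0.75s clip at 16kHz
audio = preprocess_audio(audio, sr=16000, target_duration=1.0)
audio = normalize_rms(audio)

features = extract_mel_spectrogram(audio)           # (time, n_mels)
print(features.shape)                               # (101, 40): 1 + 16000 // 160 frames

x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
print(x.shape)                                      # (1, 1, 101, 40) → CNN input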


Component 3: Model Architectures

Architecture 1: CNN (Fastest for On-Device)

import torch
import torch.nn as nn

class CommandCNN(nn.Module):
    """
    CNN for audio command classification
    
    Treats mel-spectrogram as 2D image
    """
    def __init__(self, num_classes=31, input_channels=1):
        super().__init__()
        
        # Convolutional layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        # Global average pooling (instead of fully-connected)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        # x: (batch, 1, time, freq)
        
        x = self.conv1(x)   # → (batch, 32, time/2, freq/2)
        x = self.conv2(x)   # → (batch, 64, time/4, freq/4)
        x = self.conv3(x)   # → (batch, 128, time/8, freq/8)
        
        x = self.gap(x)     # → (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # → (batch, 128)
        
        x = self.classifier(x)  # → (batch, num_classes)
        
        return x

# Model size: ~2MB
# Inference time (CPU): 15ms
# Accuracy: ~93%

Why CNNs work for audio:

  • Local patterns: Phonemes have localized frequency patterns
  • Translation invariance: Command can start at different times
  • Parameter sharing: Same filters across time/frequency
  • Efficient: Mostly matrix operations, highly optimized
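
The size and latency figures quoted above are easy to sanity-check yourself. A minimal sketch (numbers will vary with your CPU):

import time
import torch

model = CommandCNN(num_classes=31).eval()

# Parameter count and rough fp32 size (4 bytes per weight)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} params, ~{n_params * 4 / 1e6:.1f} MB fp32")

# Rough single-sample CPU latency, averaged after warm-up
x = torch.randn(1, 1, 101, 40)
with torch.no_grad():
    for _ in range(10):           # warm-up runs
        model(x)
    start = time.perf_counter()
    for _ in range(100):
        model(x)
    print(f"{(time.perf_counter() - start) / 100 * 1000:.1f} ms / inference")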

Architecture 2: RNN (Better Temporal Modeling)

class CommandRNN(nn.Module):
    """
    RNN for command classification
    
    Better at capturing temporal dependencies
    """
    def __init__(self, input_dim=40, hidden_dim=128, num_layers=2, num_classes=31):
        super().__init__()
        
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )
        
        # Attention mechanism (optional)
        self.attention = nn.Linear(hidden_dim * 2, 1)
        
        # Classification head
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)
    
    def forward(self, x):
        # x: (batch, time, features)
        
        # LSTM
        lstm_out, _ = self.lstm(x)  # → (batch, time, hidden*2)
        
        # Attention pooling (instead of taking last time step)
        attention_weights = torch.softmax(
            self.attention(lstm_out),  # → (batch, time, 1)
            dim=1
        )
        
        # Weighted sum
        context = torch.sum(attention_weights * lstm_out, dim=1)  # → (batch, hidden*2)
        
        # Classify
        logits = self.classifier(context)  # → (batch, num_classes)
        
        return logits

# Model size: ~5MB
# Inference time (CPU): 30ms
# Accuracy: ~95%

Architecture 3: Attention-Based (Best Accuracy)

class CommandTransformer(nn.Module):
    """
    Transformer for command classification
    
    Best accuracy but slower inference
    """
    def __init__(self, input_dim=40, d_model=128, nhead=4, num_layers=2, num_classes=31):
        super().__init__()
        
        # Input projection
        self.embedding = nn.Linear(input_dim, d_model)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        # x: (batch, time, features)
        
        # Project to d_model
        x = self.embedding(x)  # → (batch, time, d_model)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Transformer expects (time, batch, d_model)
        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)
        
        # Average pool over time
        x = x.mean(dim=1)  # → (batch, d_model)
        
        # Classify
        logits = self.classifier(x)  # → (batch, num_classes)
        
        return logits

import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Model size: ~8MB
# Inference time (CPU): 50ms
# Accuracy: ~97%

Model Comparison

| Model | Params | Size | CPU Latency | GPU Latency | Accuracy | Best For |
| --- | --- | --- | --- | --- | --- | --- |
| CNN | 500K | 2MB | 15ms | 3ms | 93% | Mobile devices |
| RNN | 1.2M | 5MB | 30ms | 5ms | 95% | Balanced |
| Transformer | 2M | 8MB | 50ms | 8ms | 97% | Cloud/high-end |

Production choice: CNN for on-device, RNN for cloud


Training Strategy

Data Collection

Per command, you need:

  • 1000-5000 examples
  • 100+ speakers (diversity)
  • Both genders, various ages
  • Different accents
  • Background noise variations
  • Different recording devices

Example dataset structure:

data/
├── lights_on/
│   ├── speaker001_01.wav
│   ├── speaker001_02.wav
│   ├── speaker002_01.wav
│   └── ...
├── lights_off/
│   └── ...
├── volume_up/
│   └── ...
└── unknown/
    ├── random_speech/
    ├── music/
    ├── noise/
    └── silence/
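
A minimal PyTorch Dataset for this layout (a sketch: it assumes torchaudio for decoding, reuses preprocess_audio from earlier, and treats class folder names as labels):

import os
import torch
import torchaudio
from torch.utils.data import Dataset

class CommandDataset(Dataset):
    """Yields (fixed-length waveform, label index) pairs from the layout above."""
    def __init__(self, root: str, sr: int = 16000):
        self.sr = sr
        self.classes = sorted(
            d for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d))
        )
        self.samples = []
        for idx, cls in enumerate(self.classes):
            # os.walk handles nested folders like unknown/random_speech/
            for dirpath, _, files in os.walk(os.path.join(root, cls)):
                for f in sorted(files):
                    if f.endswith('.wav'):
                        self.samples.append((os.path.join(dirpath, f), idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label = self.samples[i]
        waveform, sr = torchaudio.load(path)
        if sr != self.sr:  # resample if the file's rate differs
            waveform = torchaudio.functional.resample(waveform, sr, self.sr)
        audio = preprocess_audio(waveform[0].numpy(), sr=self.sr)
        return audio, label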

Data Augmentation

Critical for robustness! Augment during training:

import random

def augment_audio(audio, sr=16000):
    """
    Apply random augmentation
    
    Each training example augmented differently
    """
    augmentations = [
        add_noise,
        time_shift,
        time_stretch,
        pitch_shift,
        add_reverb
    ]
    
    # Apply 1-3 random augmentations
    num_augs = random.randint(1, 3)
    selected = random.sample(augmentations, num_augs)
    
    for aug_fn in selected:
        audio = aug_fn(audio, sr)
    
    # time_stretch can change the clip length, so re-standardize
    return preprocess_audio(audio, sr=sr)


def add_noise(audio, sr, snr_db=None):
    """Add background noise at a target SNR"""
    # Sample per call; a random default argument would be frozen at import time
    if snr_db is None:
        snr_db = random.uniform(5, 20)
    # Load random noise sample (assumed helper)
    noise = load_random_noise_sample(len(audio))
    
    # Calculate noise power for target SNR
    audio_power = np.mean(audio ** 2)
    noise_power = audio_power / (10 ** (snr_db / 10))
    noise_scaled = noise * np.sqrt(noise_power / np.mean(noise ** 2))
    
    return audio + noise_scaled


def time_shift(audio, sr, shift_max=0.1):
    """Shift audio in time (simulates different reaction times)"""
    shift = int(sr * shift_max * (random.random() - 0.5))
    return np.roll(audio, shift)


def time_stretch(audio, sr, rate=None):
    """Change speed without changing pitch"""
    if rate is None:
        rate = random.uniform(0.9, 1.1)
    return librosa.effects.time_stretch(audio, rate=rate)


def pitch_shift(audio, sr, n_steps=None):
    """Shift pitch (simulates different speakers)"""
    if n_steps is None:
        n_steps = random.randint(-2, 2)
    return librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)


def add_reverb(audio, sr):
    """Add room reverb (simulates different environments)"""
    # Simple reverb via convolution with an impulse response (assumed helper)
    impulse_response = generate_simple_reverb(sr)
    return np.convolve(audio, impulse_response, mode='same')

Impact: 2-3x effective dataset size, 10-20% accuracy improvement

Training Loop

def train_command_classifier(
    model, 
    train_loader, 
    val_loader, 
    epochs=100, 
    lr=0.001
):
    """
    Train speech command classifier
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5,
        verbose=True
    )
    
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (audio, labels) in enumerate(train_loader):
            # Extract features
            features = extract_features_batch(audio, sr=16000)
            features = torch.tensor(features, dtype=torch.float32)
            
            # Add channel dimension for CNN
            if len(features.shape) == 3:
                features = features.unsqueeze(1)  # (batch, 1, time, freq)
            
            labels = torch.tensor(labels, dtype=torch.long)
            
            # Forward
            outputs = model(features)
            loss = criterion(outputs, labels)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track accuracy
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            train_loss += loss.item()
        
        train_acc = train_correct / train_total
        avg_loss = train_loss / len(train_loader)
        
        # Validation
        val_acc = validate(model, val_loader)
        
        # Learning rate scheduling
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"✓ New best model: {val_acc:.4f}")
        
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Loss={avg_loss:.4f}, "
              f"Train Acc={train_acc:.4f}, "
              f"Val Acc={val_acc:.4f}")
    
    return model


def validate(model, val_loader):
    """Evaluate on validation set"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for audio, labels in val_loader:
            features = extract_features_batch(audio)
            features = torch.tensor(features).unsqueeze(1)
            labels = torch.tensor(labels)
            
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)
            
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    
    return correct / total

Component 4: Handling Unknown Commands

Strategy 1: Add “Unknown” Class

# Training data
command_classes = [
    "lights_on", "lights_off", "volume_up", "volume_down",
    "play_music", "stop", "pause", "next", "previous",
    # ... 30 total commands
]

# Collect negative examples
unknown_class = [
    "random_speech",  # Conversations
    "music",          # Background music
    "noise",          # Environmental sounds
    "silence"         # No speech
]

# Labels: 0-29 for commands, 30 for unknown
all_classes = command_classes + ["unknown"]

Collecting unknown data:

# Record actual user interactions
# Label anything that's NOT a command as "unknown"

unknown_samples = []

for audio in production_audio_stream:
    if not is_valid_command(audio):
        unknown_samples.append(audio)
        
        if len(unknown_samples) >= 10000:
            # Add this batch to the training set, then start a new buffer
            augment_and_save(unknown_samples, label="unknown")
            unknown_samples = []

Strategy 2: Confidence Thresholding

def predict_with_threshold(model, audio, threshold=0.7):
    """
    Reject low-confidence predictions as unknown
    """
    # Extract features
    features = extract_mel_spectrogram(audio)
    features = torch.tensor(features).unsqueeze(0).unsqueeze(0)
    
    # Predict
    with torch.no_grad():
        logits = model(features)
        probs = torch.softmax(logits, dim=1)[0]
    
    # Get top prediction
    max_prob, predicted_class = torch.max(probs, 0)
    
    # Threshold check
    if max_prob < threshold:
        return "unknown", float(max_prob)
    
    return command_classes[predicted_class], float(max_prob)
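
The threshold itself is a product decision: higher values reject more real commands but cut false activations. A sketch for sweeping it over a labeled validation set of (audio, label) pairs, reusing predict_with_threshold:

import numpy as np

def sweep_rejection_threshold(model, val_data, thresholds=np.arange(0.5, 0.95, 0.05)):
    """Trade off missed commands (false rejects) vs false activations (false accepts)."""
    for t in thresholds:
        false_rejects = false_accepts = 0
        n_commands = n_unknown = 0
        for audio, label in val_data:   # label is a command name or "unknown"
            pred, _ = predict_with_threshold(model, audio, threshold=t)
            if label == "unknown":
                n_unknown += 1
                false_accepts += (pred != "unknown")
            else:
                n_commands += 1
                false_rejects += (pred == "unknown")
        print(f"t={t:.2f}: "
              f"false-reject rate={false_rejects / max(n_commands, 1):.3f}, "
              f"false-accept rate={false_accepts / max(n_unknown, 1):.3f}")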

Strategy 3: Out-of-Distribution (OOD) Detection

def detect_ood_with_entropy(probs):
    """
    High entropy = model is uncertain = likely OOD
    """
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))
    
    # Calibrate threshold on validation set
    # In-distribution: entropy ~0.5
    # Out-of-distribution: entropy > 2.0
    
    if entropy > 2.0:
        return True  # OOD
    return False


def detect_ood_with_mahalanobis(features, class_means, class_covariances):
    """
    Mahalanobis distance to class centroids
    
    Far from all classes = likely OOD
    """
    min_distance = float('inf')
    
    for class_idx in range(len(class_means)):
        mean = class_means[class_idx]
        cov = class_covariances[class_idx]
        
        # Mahalanobis distance
        diff = features - mean
        distance = np.sqrt(diff.T @ np.linalg.inv(cov) @ diff)
        
        min_distance = min(min_distance, distance)
    
    # Threshold: 3-sigma rule
    if min_distance > 3.0:
        return True  # OOD
    return False
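
The entropy cutoff of 2.0 above is illustrative. In practice you calibrate it from the entropy distribution on held-out in-distribution data, e.g. at the 95th percentile (so roughly 5% of valid commands get flagged):

import torch

def calibrate_entropy_threshold(model, val_loader, percentile=95):
    """Pick the OOD entropy cutoff from in-distribution validation data."""
    entropies = []
    model.eval()
    with torch.no_grad():
        for features, _ in val_loader:
            probs = torch.softmax(model(features), dim=1)
            ent = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
            entropies.extend(ent.tolist())
    # Anything above this value is flagged as out-of-distribution
    return float(torch.tensor(entropies).quantile(percentile / 100))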

Model Optimization for Edge Deployment

Quantization

# Post-training quantization (dynamic quantization targets Linear; Conv2d not supported)
model_fp32 = CommandCNN(num_classes=31)
model_fp32.load_state_dict(torch.load('model.pth'))
model_fp32.eval()

# Dynamic quantization (Linear layers)
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},
    dtype=torch.qint8
)

# Save
torch.save(model_int8.state_dict(), 'model_int8.pth')

# Results (typical on CPU with CNN head including Linear):
# - Model size: 2MB → ~1.2MB (1.6x smaller)
# - Inference: 15ms → ~10-12ms (1.3-1.5x faster)
# - Accuracy: ~93.2% → ~93.0% (≤0.2% drop)
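
Dynamic quantization leaves the conv layers in fp32. To quantize them as well, post-training static quantization inserts observers, runs a calibration pass, then converts. A sketch (calibration_loader is an assumed loader of representative feature tensors; in practice you'd also fuse Conv+BN+ReLU with torch.quantization.fuse_modules before preparing):

import torch
import torch.nn as nn

class QuantizableCommandCNN(nn.Module):
    """Wraps the fp32 CNN with quant/dequant stubs for static quantization."""
    def __init__(self, model_fp32):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = model_fp32
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.model(self.quant(x)))

wrapped = QuantizableCommandCNN(model_fp32).eval()
wrapped.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # x86; 'qnnpack' for ARM
prepared = torch.quantization.prepare(wrapped)

# Calibration: run representative inputs so observers record activation ranges
with torch.no_grad():
    for features, _ in calibration_loader:   # assumed; a few hundred samples suffice
        prepared(features)

# Static int8 quantizes weights AND activations, so conv layers shrink too
model_int8_static = torch.quantization.convert(prepared)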

Pruning

import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """
    Remove 30% of weights with lowest magnitude
    """
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')  # bake the zeros into the weight permanently
    
    return model

# Results with 30% pruning (unstructured zeros shrink size/latency only
# with sparse storage or sparsity-aware kernels):
# - Model size: 2MB → 1.4MB
# - Inference: 15ms → 12ms
# - Accuracy: 93.2% → 92.7%

Knowledge Distillation

def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    """
    Train small student to mimic large teacher
    
    Args:
        temperature: Soften probability distributions
        alpha: Weight between soft and hard targets
    """
    # Soft targets from teacher
    soft_targets = torch.softmax(teacher_logits / temperature, dim=1)
    soft_prob = torch.log_softmax(student_logits / temperature, dim=1)
    soft_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0]
    soft_loss = soft_loss * (temperature ** 2)
    
    # Hard targets (ground truth)
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)
    
    # Combine
    return alpha * soft_loss + (1 - alpha) * hard_loss


# Train student
teacher = CommandTransformer(num_classes=31)  # 8MB, 97% accuracy
student = CommandCNN(num_classes=31)          # 2MB, 93% accuracy

for audio, labels in train_loader:
    # Teacher predictions (frozen)
    with torch.no_grad():
        teacher_logits = teacher(audio)
    
    # Student predictions
    student_logits = student(audio)
    
    # Distillation loss
    loss = distillation_loss(student_logits, teacher_logits, labels)
    
    # Optimize student
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Result: Student achieves 95% (vs 93% without distillation)

On-Device Deployment

Export to Mobile Formats

TensorFlow Lite (Android):

import tensorflow as tf

# Convert PyTorch to TensorFlow (via ONNX)
# 1. Export PyTorch to ONNX
dummy_input = torch.randn(1, 1, 100, 40)  # (batch, channel, time, freq)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input'],
    output_names=['output']
)

# 2. Convert ONNX to TF
import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("model.onnx")
tf_model = prepare(onnx_model)
tf_model.export_graph("model_tf")

# 3. Convert TF to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("model_tf")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('command_classifier.tflite', 'wb') as f:
    f.write(tflite_model)

Core ML (iOS):

import coremltools as ct

# Trace PyTorch model
example_input = torch.randn(1, 1, 100, 40)
traced_model = torch.jit.trace(model, example_input)

# Convert to Core ML
coreml_model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="audio", shape=(1, 1, 100, 40))],
    outputs=[ct.TensorType(name="logits")]
)

# Add metadata
coreml_model.author = "Arun Baby"
coreml_model.short_description = "Speech command classifier"
coreml_model.version = "1.0"

# Save
coreml_model.save("CommandClassifier.mlmodel")

Mobile Inference Code

Android (Kotlin):

import android.content.Context
import org.tensorflow.lite.Interpreter
import java.nio.ByteBuffer
import java.nio.ByteOrder

class CommandClassifier(private val context: Context) {
    private lateinit var interpreter: Interpreter
    
    init {
        // Load model
        val model = loadModelFile("command_classifier.tflite")
        interpreter = Interpreter(model)
    }
    
    fun classify(audio: FloatArray): Pair<String, Float> {
        // Extract features
        val features = extractMelSpectrogram(audio)
        
        // Prepare input
        val inputBuffer = ByteBuffer.allocateDirect(4 * features.size)
        inputBuffer.order(ByteOrder.nativeOrder())
        features.forEach { inputBuffer.putFloat(it) }
        
        // Prepare output
        val output = Array(1) { FloatArray(31) }
        
        // Run inference
        interpreter.run(inputBuffer, output)
        
        // Get top prediction
        val probabilities = output[0]
        val maxIndex = probabilities.indices.maxByOrNull { probabilities[it] } ?: 0
        val confidence = probabilities[maxIndex]
        
        return Pair(commandNames[maxIndex], confidence)
    }
}

iOS (Swift):

import CoreML

class CommandClassifier {
    private var model: CommandClassifierModel!
    
    init() {
        model = try! CommandClassifierModel(configuration: MLModelConfiguration())
    }
    
    func classify(audio: [Float]) -> (command: String, confidence: Double) {
        // Extract features
        let features = extractMelSpectrogram(audio)
        
        // Create MLMultiArray
        let input = try! MLMultiArray(shape: [1, 1, 100, 40], dataType: .float32)
        for i in 0..<features.count {
            input[i] = NSNumber(value: features[i])
        }
        
        // Run inference
        let output = try! model.prediction(audio: input)
        
        // Get top prediction (MLMultiArray has no argmax; scan manually)
        let probabilities = output.logits
        var maxIndex = 0
        for i in 1..<probabilities.count {
            if probabilities[i].floatValue > probabilities[maxIndex].floatValue {
                maxIndex = i
            }
        }
        let confidence = probabilities[maxIndex].floatValue
        
        return (commandNames[maxIndex], Double(confidence))
    }
}

Monitoring & Evaluation

Metrics Dashboard

from dataclasses import dataclass
from typing import List

@dataclass
class ClassificationMetrics:
    """Per-class metrics"""
    precision: float
    recall: float
    f1_score: float
    support: int  # Number of samples
    
def compute_metrics(y_true: List[int], y_pred: List[int], num_classes: int):
    """
    Compute detailed metrics per class
    """
    from sklearn.metrics import classification_report, confusion_matrix
    
    # Per-class metrics
    report = classification_report(y_true, y_pred, output_dict=True)
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Identify problematic classes
    for i in range(num_classes):
        if report[str(i)]['f1-score'] < 0.85:
            print(f"⚠️  Class {i} ({command_names[i]}) has low F1: {report[str(i)]['f1-score']:.3f}")
            
            # Find most confused class
            confused_with = cm[i].argmax()
            if confused_with != i:
                print(f"   Most confused with class {confused_with} ({command_names[confused_with]})")
    
    return report, cm

Online Monitoring

class OnlineMetricsTracker:
    """
    Track metrics in production
    """
    def __init__(self):
        self.predictions = []
        self.confidences = []
        self.latencies = []
    
    def record(self, prediction: int, confidence: float, latency_ms: float):
        """Record single prediction"""
        self.predictions.append(prediction)
        self.confidences.append(confidence)
        self.latencies.append(latency_ms)
    
    def get_stats(self, last_n=1000):
        """Get recent statistics"""
        recent_preds = self.predictions[-last_n:]
        recent_confs = self.confidences[-last_n:]
        recent_lats = self.latencies[-last_n:]
        
        # Class distribution
        from collections import Counter
        class_dist = Counter(recent_preds)
        
        return {
            'total_predictions': len(recent_preds),
            'class_distribution': dict(class_dist),
            'avg_confidence': np.mean(recent_confs),
            'low_confidence_rate': sum(c < 0.7 for c in recent_confs) / len(recent_confs),
            'p50_latency': np.percentile(recent_lats, 50),
            'p95_latency': np.percentile(recent_lats, 95),
            'p99_latency': np.percentile(recent_lats, 99)
        }

Multi-Language Support

Approach 1: Separate Models per Language

Pros:

  • Best accuracy per language
  • Language-specific optimizations
  • Easier to add new languages

Cons:

  • Multiple models to maintain
  • Higher storage footprint
  • Language detection needed first

class MultilingualClassifier:
    """
    Separate model per language
    """
    def __init__(self):
        self.models = {
            'en': load_model('command_en.pth'),
            'es': load_model('command_es.pth'),
            'fr': load_model('command_fr.pth'),
            'de': load_model('command_de.pth'),
            'ja': load_model('command_ja.pth')
        }
        self.language_detector = load_model('lang_detect.pth')
    
    def predict(self, audio):
        # Detect language first
        language = self.language_detector.predict(audio)
        
        # Use language-specific model
        model = self.models[language]
        prediction = model.predict(audio)
        
        return prediction, language

Storage requirement: 5 languages × 2MB = 10MB

Approach 2: Multilingual Shared Model

Training strategy:

def train_multilingual_model():
    """
    Single model trained on all languages
    
    Add language ID as auxiliary input
    """
    model = MultilingualCommandCNN(
        num_classes=30,
        num_languages=5
    )
    
    # Training data from all languages
    for audio, command_label, lang_id in train_loader:
        features = extract_features(audio)
        
        # Forward pass with language embedding
        command_pred = model(features, lang_id)
        
        # Loss
        loss = criterion(command_pred, command_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return model

Model architecture:

class MultilingualCommandCNN(nn.Module):
    """
    Shared model with language embeddings
    """
    def __init__(self, num_classes=30, num_languages=5, embedding_dim=16):
        super().__init__()
        
        # Language embedding
        self.lang_embedding = nn.Embedding(num_languages, embedding_dim)
        
        # Shared CNN backbone
        self.cnn = CommandCNN(num_classes=128)  # reused as a 128-dim feature extractor
        
        # Language-conditioned classifier
        self.classifier = nn.Linear(128 + embedding_dim, num_classes)
    
    def forward(self, audio_features, language_id):
        # CNN features
        cnn_features = self.cnn(audio_features)  # (batch, 128)
        
        # Language embedding
        lang_emb = self.lang_embedding(language_id)  # (batch, 16)
        
        # Concatenate
        combined = torch.cat([cnn_features, lang_emb], dim=1)  # (batch, 144)
        
        # Classify
        logits = self.classifier(combined)  # (batch, num_classes)
        
        return logits

Pros:

  • Single model (2-3MB)
  • Shared representations across languages
  • Transfer learning for low-resource languages

Cons:

  • Slightly lower accuracy per language
  • All languages must use same command set

Failure Cases & Mitigation

Common Failure Modes

1. Background Speech/TV

Problem: Model activates on TV dialogue or background conversation

Mitigation:

def detect_background_speech(audio, sr=16000):
    """
    Detect if audio is from TV/background vs direct user speech
    
    Features:
    - Energy envelope variation (TV more consistent)
    - Reverb characteristics (TV more reverberant)
    - Spectral rolloff (TV often compressed)
    """
    # Energy variation
    frame_energy = librosa.feature.rms(y=audio)[0]
    energy_std = np.std(frame_energy)
    
    # TV has lower energy variation
    if energy_std < 0.01:
        return True  # Likely background
    
    # Spectral centroid (TV often band-limited)
    spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)[0]
    avg_centroid = np.mean(spectral_centroid)
    
    if avg_centroid < 1000:  # Hz
        return True  # Likely background
    
    return False

Additional strategy: Use speaker verification to check if it’s the registered user
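
A sketch of that idea, assuming a pretrained speaker encoder (embed_fn) and an embedding enrolled during device setup, neither of which is defined here:

import numpy as np

def is_registered_user(audio, embed_fn, enrolled_embedding, threshold=0.75):
    """Cosine similarity between the utterance embedding and the enrolled
    voice profile; embed_fn is an assumed pretrained speaker encoder."""
    emb = embed_fn(audio)
    cos = np.dot(emb, enrolled_embedding) / (
        np.linalg.norm(emb) * np.linalg.norm(enrolled_embedding) + 1e-10
    )
    return cos >= threshold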

2. Accented Speech

Problem: Model trained on standard accent performs poorly on regional accents

Mitigation:

# Data collection strategy
accent_distribution = {
    'general_american': 0.3,
    'british': 0.15,
    'australian': 0.1,
    'indian': 0.15,
    'southern_us': 0.1,
    'canadian': 0.1,
    'other': 0.1
}

# Ensure balanced training data
for accent, proportion in accent_distribution.items():
    required_samples = total_samples * proportion
    collect_samples(accent, required_samples)

# Use accent-aware data augmentation
def accent_aware_augmentation(audio, sr, accent_type):
    """Apply accent-specific augmentations"""
    if accent_type == 'indian':
        # Indian English: Stronger pitch variation
        audio = pitch_shift(audio, sr, n_steps=random.randint(-3, 3))
    elif accent_type == 'southern_us':
        # Southern US: Slower speech
        audio = time_stretch(audio, sr, rate=random.uniform(0.85, 1.0))
    
    return audio

3. Noisy Environments

Problem: Model degrades in cafes, cars, streets

Mitigation:

def enhance_audio_for_inference(audio, sr=16000):
    """
    Lightweight denoising for inference
    
    Must be < 5ms to maintain latency budget
    """
    # Spectral gating (simple but effective)
    stft = librosa.stft(audio, n_fft=512, hop_length=160)
    magnitude = np.abs(stft)
    
    # Estimate noise floor from the first 10 frames (~100ms at hop 160)
    noise_frames = magnitude[:, :10]
    noise_threshold = np.mean(noise_frames, axis=1, keepdims=True) * 1.5
    
    # Gate
    mask = magnitude > noise_threshold
    stft_denoised = stft * mask
    
    # Inverse STFT
    audio_denoised = librosa.istft(stft_denoised, n_fft=512, hop_length=160)
    
    return audio_denoised

Better approach: Train with noisy data

# Use diverse noise types during training
noise_types = [
    'cafe_ambiance',
    'car_interior',
    'street_traffic',
    'office_chatter',
    'home_appliances',
    'rain',
    'wind'
]

for audio, label in train_loader:
    # Mix in a randomly chosen noise type at a random SNR,
    # then train on noisy_audio with the original label
    noise_type = random.choice(noise_types)
    noisy_audio = add_noise(audio, noise_type, snr_db=random.uniform(5, 20))  # assumes a noise-type-aware variant of add_noise

4. Similar Sounding Commands

Problem: “lights on” vs “lights off”, “volume up” vs “volume down”

Mitigation:

# Use contrastive learning during training
def contrastive_loss(anchor, positive, negative, margin=1.0):
    """
    Pull together similar commands, push apart confusable ones
    """
    pos_distance = torch.norm(anchor - positive, dim=1)
    neg_distance = torch.norm(anchor - negative, dim=1)
    
    loss = torch.relu(pos_distance - neg_distance + margin)
    
    return loss.mean()

# Identify confusable pairs
confusable_pairs = [
    ('lights_on', 'lights_off'),
    ('volume_up', 'volume_down'),
    ('next', 'previous'),
    ('play', 'pause')
]

# During training
for audio, label in train_loader:
    features = model.extract_features(audio)
    
    # For confusable commands, add contrastive loss
    if label in confusable_commands:
        opposite_label = get_opposite_command(label)
        opposite_audio = sample_from_class(opposite_label)
        opposite_features = model.extract_features(opposite_audio)
        
        # Anchor and positive are the same embedding here, so the term
        # reduces to a margin-based push away from the confusable class
        total_loss = classification_loss + 0.2 * contrastive_loss(
            features, 
            features,
            opposite_features
        )

Production Deployment Architecture

Edge Deployment (Smart Speaker)

┌─────────────────────────────────────────┐
│         Smart Speaker Device            │
├─────────────────────────────────────────┤
│                                         │
│  Microphone Array                       │
│       ↓                                 │
│  Beamforming (5ms)                      │
│       ↓                                 │
│  Wake Word Detection (10ms)             │
│       ↓                                 │
│  [If wake word detected]                │
│       ↓                                 │
│  Audio Buffer (1 second)                │
│       ↓                                 │
│  Feature Extraction (5ms)               │
│       ↓                                 │
│  Command CNN Inference (15ms)           │
│       ↓                                 │
│  ┌──────────────┐                       │
│  │ Confidence   │                       │
│  │   > 0.85?    │                       │
│  └──────┬───────┘                       │
│         │                               │
│    Yes  │  No                           │
│         ↓                               │
│  Execute Command    Send to Cloud ASR   │
│                                         │
└─────────────────────────────────────────┘

Total latency (on-device): < 40ms
Power consumption: < 100mW during inference

Hybrid Edge-Cloud Architecture

class HybridCommandClassifier:
    """
    Intelligent routing between edge and cloud
    """
    def __init__(self):
        self.edge_model = load_edge_model()  # Small CNN
        self.cloud_client = CloudASRClient()
        
        # Common commands handled on-device
        self.edge_commands = {
            'lights_on', 'lights_off', 
            'volume_up', 'volume_down',
            'play', 'pause', 'stop',
            'next', 'previous'
        }
    
    async def classify(self, audio):
        # Try edge first
        edge_pred, edge_conf = self.edge_model.predict(audio)
        
        # High confidence + known command → use edge
        if edge_conf > 0.85 and edge_pred in self.edge_commands:
            return {
                'command': edge_pred,
                'confidence': edge_conf,
                'source': 'edge',
                'latency_ms': 35
            }
        
        # Otherwise → cloud ASR
        cloud_result = await self.cloud_client.recognize(audio)
        
        return {
            'command': cloud_result['text'],
            'confidence': cloud_result['confidence'],
            'source': 'cloud',
            'latency_ms': 250
        }

Benefits:

  • ✅ 90% of commands handled on-device (< 50ms)
  • ✅ 10% fall back to cloud for complex queries
  • ✅ Privacy for common commands
  • ✅ Graceful degradation if network unavailable

A/B Testing & Gradual Rollout

Experiment Framework

import hashlib
import time

class ModelExperiment:
    """
    A/B test new model versions
    """
    def __init__(self, control_model, treatment_model, treatment_percentage=10):
        self.control = control_model
        self.treatment = treatment_model
        self.treatment_pct = treatment_percentage
    
    def predict(self, audio, user_id):
        # Deterministic assignment based on user_id
        # (hashlib rather than hash(), which is salted per process)
        bucket = int(hashlib.md5(str(user_id).encode()).hexdigest(), 16) % 100
        
        if bucket < self.treatment_pct:
            # Treatment group
            pred, conf = self.treatment.predict(audio)
            variant = 'treatment'
        else:
            # Control group
            pred, conf = self.control.predict(audio)
            variant = 'control'
        
        # Log for analysis
        self.log_prediction(user_id, variant, pred, conf)
        
        return pred, conf
    
    def log_prediction(self, user_id, variant, prediction, confidence):
        """Log to analytics system"""
        event = {
            'user_id': user_id,
            'timestamp': time.time(),
            'variant': variant,
            'prediction': prediction,
            'confidence': confidence
        }
        
        analytics_logger.log(event)

Metrics to Track

def compute_experiment_metrics(control_group, treatment_group):
    """
    Compare model versions
    """
    metrics = {}
    
    # Accuracy (if ground truth available)
    if has_ground_truth:
        metrics['accuracy_control'] = compute_accuracy(control_group)
        metrics['accuracy_treatment'] = compute_accuracy(treatment_group)
    
    # Confidence distribution
    metrics['avg_confidence_control'] = np.mean([x['confidence'] for x in control_group])
    metrics['avg_confidence_treatment'] = np.mean([x['confidence'] for x in treatment_group])
    
    # Latency
    metrics['p95_latency_control'] = np.percentile([x['latency'] for x in control_group], 95)
    metrics['p95_latency_treatment'] = np.percentile([x['latency'] for x in treatment_group], 95)
    
    # User engagement (proxy for accuracy)
    metrics['retry_rate_control'] = compute_retry_rate(control_group)
    metrics['retry_rate_treatment'] = compute_retry_rate(treatment_group)
    
    # Statistical significance
    from scipy.stats import ttest_ind
    
    control_success = [x['success'] for x in control_group]
    treatment_success = [x['success'] for x in treatment_group]
    
    t_stat, p_value = ttest_ind(control_success, treatment_success)
    metrics['p_value'] = p_value
    metrics['is_significant'] = p_value < 0.05
    
    return metrics

Real-World Examples

Google Assistant

“Hey Google” Wake Word:

  • Always-on detection using tiny model (< 1MB)
  • Runs on low-power co-processor (DSP)
  • < 10ms latency, ~0.5mW power
  • ~ 99.5% accuracy on target phrase
  • Personalized over time with on-device learning

Command Classification:

  • Separate model for common commands (~30 commands)
  • Fallback to full ASR for complex queries
  • On-device for privacy (no audio sent to cloud)
  • Multi-language support (40+ languages)

Architecture:

Microphone → Beamformer → Wake Word → Command CNN → Execute
                                              ↓
                                         (if low conf)
                                              ↓
                                         Cloud ASR

Amazon Alexa

“Alexa” Wake Word:

  • Multi-stage cascade (sketched after this list):
    • Stage 1: Energy detector (< 1ms, filters silence)
    • Stage 2: Keyword spotter (< 10ms, CNN)
    • Stage 3: Full verification (< 50ms, larger model)
  • Reduces false positives by 10x
  • Power-efficient (only stage 3 uses main CPU)
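
A sketch of the cascade control flow (the three stage models and their thresholds are assumptions standing in for proprietary components):

def cascade_detect(audio, energy_gate, keyword_cnn, verifier):
    """Run cheap stages first; most audio never reaches the expensive model."""
    # Stage 1: energy gate (< 1ms) rejects silence
    if energy_gate.score(audio) < energy_gate.threshold:
        return False
    # Stage 2: small keyword CNN (< 10ms) rejects most non-wake speech
    if keyword_cnn.score(audio) < keyword_cnn.threshold:
        return False
    # Stage 3: larger verification model (< 50ms) runs only on survivors
    return verifier.score(audio) >= verifier.threshold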

Custom Skills:

  • Slot-filling approach for structured commands
  • Template: “play {song} by {artist}” (toy sketch below)
  • Combined classification + entity extraction
  • ~100K custom skills available
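
To illustrate the template idea, a toy regex-based slot filler (real skills use trained slot taggers, and the text here would come from ASR rather than direct classification):

import re

TEMPLATE = re.compile(r"play (?P<song>.+) by (?P<artist>.+)", re.IGNORECASE)

def fill_slots(text: str):
    """Toy slot extraction for the 'play {song} by {artist}' template."""
    match = TEMPLATE.match(text)
    if match:
        return {"intent": "play_music", **match.groupdict()}
    return None

# fill_slots("play Yesterday by The Beatles")
# → {'intent': 'play_music', 'song': 'Yesterday', 'artist': 'The Beatles'}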

Deployment:

  • Edge: Wake word + simple commands
  • Cloud: Everything else (200ms latency acceptable)

Apple Siri

“Hey Siri” Detection:

  • Neural network on Neural Engine (dedicated ML chip)
  • Personalized to user’s voice during setup
  • Continuously adapts to voice changes
  • < 50ms latency
  • Works offline (completely on-device)
  • Power: < 1mW in always-listening mode

Privacy Design:

  • Audio never sent to cloud without explicit activation
  • Voice profile stored locally (encrypted)
  • Random identifier (not tied to Apple ID)

Technical Details:

  • Uses LSTM for temporal modeling
  • Trained on millions of “Hey Siri” variations
  • Negative examples: TV shows, movies, other voices

Key Takeaways

  • Direct audio→intent is faster than ASR→NLU for a limited command set
  • CNNs on mel-spectrograms work excellently on-device
  • Data augmentation is critical for robustness (noise, time shift, pitch)
  • Unknown-class handling prevents false activations
  • Quantization can achieve ~4x compression with < 1% accuracy loss
  • Threshold tuning balances precision/recall for business needs



Originally published at: arunbaby.com/speech-tech/0002-speech-classification

If you found this helpful, consider sharing it with others who might benefit.