26 minute read

Build lightweight models that detect specific keywords in audio streams with minimal latency and power consumption for voice interfaces.

TL;DR

Keyword spotting enables always-on wake word detection (“Hey Siri”, “Alexa”) with sub-50ms latency and sub-1MB models on edge devices. Production systems use a two-stage pipeline: a tiny CNN (~50KB) screens continuously, then a larger model confirms detections, saving up to 90% battery. Key techniques include mel-spectrogram features, model quantization (INT8) for 4x speedup, binary search for optimal detection thresholds, and on-device personalization. Related: streaming speech pipelines for the audio transport layer and voice activity detection for silence detection fundamentals.

A tiny MEMS microphone chip on a green PCB board next to a coin for scale

Introduction

Keyword spotting (KWS) detects specific words or phrases in continuous audio streams, enabling voice-activated interfaces.

Common applications:

  • Wake word detection (“Hey Siri”, “Alexa”, “OK Google”)
  • Voice commands (“Play”, “Stop”, “Next”)
  • Accessibility features (voice navigation)
  • Security (speaker verification)

Key requirements:

  • Ultra-low latency: < 50ms detection time
  • Low power: Run continuously on battery
  • Small model: Fit on edge devices (< 1MB)
  • High accuracy: < 1% false acceptance rate
  • Noise robust: Work in real-world conditions

Problem Formulation

Task Definition

Given audio input, classify whether a target keyword is present:

Input: Audio waveform (e.g., 1 second, 16kHz = 16,000 samples)
Output: {keyword, no_keyword}

Example:
 Audio: "Hey Siri, what's the weather?"
 Output: keyword="hey_siri", timestamp=0.0s

Challenges

  1. Always-on constraint: Must run 24/7 without draining battery
  2. False positives: Accidental activations frustrate users
  3. False negatives: Missed detections break user experience
  4. Noise robustness: Background noise, music, TV
  5. Speaker variability: Different accents, ages, genders

System Architecture

┌─────────────────────────────────────────────────────────┐
│ Microphone Input │
└────────────────────┬────────────────────────────────────┘
 │ Continuous audio stream
 ▼
 ┌────────────────────────┐
 │ Audio Preprocessing │
 │ - Noise reduction │
 │ - Normalization │
 └───────────┬────────────┘
 │
 ▼
 ┌────────────────────────┐
 │ Feature Extraction │
 │ - MFCC / Mel-spec │
 │ - Sliding window │
 └───────────┬────────────┘
 │
 ▼
 ┌────────────────────────┐
 │ KWS Model (Tiny NN) │
 │ - CNN / RNN / TCN │
 │ - < 1MB, < 10ms │
 └───────────┬────────────┘
 │
 ▼
 ┌────────────────────────┐
 │ Post-processing │
 │ - Threshold / Smooth │
 │ - Reject false pos │
 └───────────┬────────────┘
 │
 ▼
 ┌────────────────┐
 │ Trigger Event │
 │ (Wake system) │
 └────────────────┘

Model Architectures

Approach 1: CNN-based KWS

Small convolutional network on spectrograms

import torch
import torch.nn as nn

class KeywordSpottingCNN(nn.Module):
    """
    Lightweight CNN for keyword spotting

    Input: Mel-spectrogram (n_mels, time_steps)
    Output: Keyword confidence score

    Model size: ~100KB
    Inference time: ~5ms on CPU
    """

    def __init__(self, n_mels=40, n_classes=2):
        super().__init__()

        # Convolutional layers
        self.conv1 = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2)
        )

        self.conv2 = nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2)
        )

        # Global average pooling
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier
        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, 1, n_mels, time_steps]

            Returns:
                [batch, n_classes]
                """
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.gap(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)
                return x

                # Create model
                model = KeywordSpottingCNN(n_mels=40, n_classes=2)

                # Count parameters
                n_params = sum(p.numel() for p in model.parameters())
                print(f"Model parameters: {n_params:,}") # ~30K parameters

                # Estimate model size
                model_size_mb = n_params * 4 / (1024 ** 2) # 4 bytes per float32
                print(f"Model size: {model_size_mb:.2f} MB")

Approach 2: RNN-based KWS

Temporal modeling with GRU

class KeywordSpottingGRU(nn.Module):
    """
    GRU-based keyword spotting

    Better for temporal patterns, slightly larger
    """

    def __init__(self, n_mels=40, hidden_size=64, n_layers=2, n_classes=2):
        super().__init__()

        self.gru = nn.GRU(
        input_size=n_mels,
        hidden_size=hidden_size,
        num_layers=n_layers,
        batch_first=True,
        bidirectional=False # Unidirectional for streaming
        )

        self.fc = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, time_steps, n_mels]

            Returns:
                [batch, n_classes]
                """
                # GRU forward pass
                out, hidden = self.gru(x)

                # Use last hidden state
                x = out[:, -1, :]

                # Classifier
                x = self.fc(x)
                return x

                model_gru = KeywordSpottingGRU(n_mels=40, hidden_size=64)

Approach 3: Temporal Convolutional Network

Efficient temporal modeling

class TemporalBlock(nn.Module):
    """Single temporal convolutional block"""

    def __init__(self, n_inputs, n_outputs, kernel_size, dilation):
        super().__init__()

        self.conv1 = nn.Conv1d(
        n_inputs, n_outputs, kernel_size,
        padding=(kernel_size-1) * dilation // 2,
        dilation=dilation
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        # Residual connection
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) \
        if n_inputs != n_outputs else None

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.dropout(out)

        res = x if self.downsample is None else self.downsample(x)
        return out + res

    class KeywordSpottingTCN(nn.Module):
        """
        Temporal Convolutional Network for KWS

        Combines efficiency of CNNs with temporal modeling
        """

    def __init__(self, n_mels=40, n_classes=2):
        super().__init__()

        self.blocks = nn.Sequential(
        TemporalBlock(n_mels, 64, kernel_size=3, dilation=1),
        TemporalBlock(64, 64, kernel_size=3, dilation=2),
        TemporalBlock(64, 64, kernel_size=3, dilation=4),
        )

        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, time_steps, n_mels]

            Returns:
                [batch, n_classes]
                """
                # Transpose for conv1d: [batch, n_mels, time_steps]
                x = x.transpose(1, 2)

                # Temporal blocks
                x = self.blocks(x)

                # Global average pooling
                x = self.gap(x).squeeze(-1)

                # Classifier
                x = self.fc(x)
                return x

                model_tcn = KeywordSpottingTCN(n_mels=40)

Feature Extraction Pipeline

import librosa
import numpy as np

class KeywordSpottingFeatureExtractor:
    """
    Extract features for keyword spotting

    Optimized for real-time processing
    """

    def __init__(self, sample_rate=16000, window_size_ms=30,
    hop_size_ms=10, n_mels=40):
        self.sample_rate = sample_rate
        self.n_fft = int(sample_rate * window_size_ms / 1000)
        self.hop_length = int(sample_rate * hop_size_ms / 1000)
        self.n_mels = n_mels

        # Precompute mel filterbank
        self.mel_basis = librosa.filters.mel(
        sr=sample_rate,
        n_fft=self.n_fft,
        n_mels=n_mels,
        fmin=0,
        fmax=sample_rate / 2
        )

    def extract(self, audio):
        """
        Extract mel-spectrogram features

        Args:
            audio: Audio samples [n_samples]

            Returns:
                Mel-spectrogram [n_mels, time_steps]
                """
                # Compute STFT
                stft = librosa.stft(
                audio,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                window='hann'
                )

                # Power spectrogram
                power = np.abs(stft) ** 2

                # Apply mel filterbank on power
                mel_power = np.dot(self.mel_basis, power)

                # Log compression (power → dB)
                mel_db = librosa.power_to_db(mel_power, ref=np.max)

                return mel_db

    def extract_from_stream(self, audio_chunk):
        """
        Extract features from streaming audio

        Optimized for low latency
        """
        return self.extract(audio_chunk)

        # Usage
        extractor = KeywordSpottingFeatureExtractor(sample_rate=16000)

        # Extract features from 1-second audio
        audio = np.random.randn(16000)
        features = extractor.extract(audio)
        print(f"Feature shape: {features.shape}") # (40, 101)

Training Pipeline

Data Preparation

import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np

class KeywordSpottingDataset(Dataset):
    """
    Dataset for keyword spotting training

    Handles positive (keyword) and negative (non-keyword) examples
    """

    def __init__(self, audio_files, labels, feature_extractor,
    augment=True):
        self.audio_files = audio_files
        self.labels = labels
        self.feature_extractor = feature_extractor
        self.augment = augment

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        # Load audio
        audio, sr = librosa.load(
        self.audio_files[idx],
        sr=self.feature_extractor.sample_rate
        )

        # Pad or trim to 1 second
        target_length = self.feature_extractor.sample_rate
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)))
        else:
            audio = audio[:target_length]

            # Data augmentation
            if self.augment:
                audio = self._augment(audio)

                # Extract features
                features = self.feature_extractor.extract(audio)

                # Convert to tensor
                features = torch.FloatTensor(features).unsqueeze(0) # Add channel dim
                label = torch.LongTensor([self.labels[idx]])

                return features, label

    def _augment(self, audio):
        """
        Data augmentation

        - Add noise
        - Time shift
        - Speed perturbation
        """
        # Add background noise
        noise_level = np.random.uniform(0, 0.005)
        noise = np.random.randn(len(audio)) * noise_level
        audio = audio + noise

        # Time shift
        shift = np.random.randint(-1600, 1600) # ±100ms at 16kHz
        audio = np.roll(audio, shift)

        # Speed perturbation (simplified)
        speed_factor = np.random.uniform(0.9, 1.1)
        # In practice, use librosa.effects.time_stretch

        return audio

        # Create dataset
        dataset = KeywordSpottingDataset(
        audio_files=['audio1.wav', 'audio2.wav', ...],
        labels=[1, 0, ...], # 1=keyword, 0=no keyword
        feature_extractor=extractor,
        augment=True
        )

        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

Training Loop

import torch
import torch.nn as nn

def train_keyword_spotting_model(model, train_loader, val_loader,
n_epochs=50, device='cuda'):
    """
    Train keyword spotting model

    Args:
        model: PyTorch model
        train_loader: Training data loader
        val_loader: Validation data loader
        n_epochs: Number of epochs
        device: Device to train on
        """
        model = model.to(device)

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=5
        )

        best_val_acc = 0

        for epoch in range(n_epochs):
            # Training
            model.train()
            train_loss = 0
            train_correct = 0
            train_total = 0

            for features, labels in train_loader:
                features = features.to(device)
                labels = labels.squeeze().to(device)

                # Forward pass
                outputs = model(features)
                loss = criterion(outputs, labels)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Track metrics
                train_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                train_correct += (predicted == labels).sum().item()
                train_total += labels.size(0)

                train_acc = train_correct / train_total

                # Validation
                model.eval()
                val_correct = 0
                val_total = 0

                with torch.no_grad():
                    for features, labels in val_loader:
                        features = features.to(device)
                        labels = labels.squeeze().to(device)

                        outputs = model(features)
                        _, predicted = torch.max(outputs, 1)

                        val_correct += (predicted == labels).sum().item()
                        val_total += labels.size(0)

                        val_acc = val_correct / val_total

                        # Learning rate scheduling
                        scheduler.step(val_acc)

                        # Save best model
                        if val_acc > best_val_acc:
                            best_val_acc = val_acc
                            torch.save(model.state_dict(), 'best_kws_model.pth')

                            print(f"Epoch {epoch+1}/{n_epochs}: "
                            f"Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

                            return model

                            # Train
                            model = KeywordSpottingCNN()
                            trained_model = train_keyword_spotting_model(
                            model, train_loader, val_loader, n_epochs=50
                            )

Deployment Optimization

Model Quantization

def quantize_kws_model(model):
    """
    Apply dynamic quantization to linear layers for edge deployment
    """
    import torch
    import torch.nn as nn

    model.eval()
    qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    return qmodel

    # Quantize
    model_quantized = quantize_kws_model(model)

    # Compare serialized sizes for accurate measurement
    import io
    import torch

def get_model_size_mb(m):
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return len(buffer.getvalue()) / (1024 ** 2)

    original_size = get_model_size_mb(model)
    quantized_size = get_model_size_mb(model_quantized)

    print(f"Original: {original_size:.2f} MB")
    print(f"Quantized: {quantized_size:.2f} MB")
    print(f"Compression: {original_size / max(quantized_size, 1e-6):.1f}x")

TensorFlow Lite Conversion

def convert_to_tflite(model, sample_input):
    """
    Convert PyTorch model to TensorFlow Lite

    For deployment on mobile/edge devices
    """
    import torch
    import onnx
    import tensorflow as tf
    from onnx_tf.backend import prepare

    # Step 1: PyTorch → ONNX
    torch.onnx.export(
    model,
    sample_input,
    'kws_model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
    )

    # Step 2: ONNX → TensorFlow
    onnx_model = onnx.load('kws_model.onnx')
    tf_rep = prepare(onnx_model)
    tf_rep.export_graph('kws_model_tf')

    # Step 3: TensorFlow → TFLite
    converter = tf.lite.TFLiteConverter.from_saved_model('kws_model_tf')

    # Optimization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]

    tflite_model = converter.convert()

    # Save
    with open('kws_model.tflite', 'wb') as f:
        f.write(tflite_model)

        print(f"TFLite model size: {len(tflite_model) / 1024:.1f} KB")

        # Convert
        convert_to_tflite(model, sample_input)

Real-time Inference

Streaming KWS System

import sounddevice as sd
import numpy as np
import torch
from collections import deque

class StreamingKeywordSpotter:
    """
    Real-time keyword spotting system

    Continuously monitors audio and detects keywords
    """

    def __init__(self, model, feature_extractor,
    threshold=0.8, cooldown_ms=1000):
        self.model = model
        self.model.eval()

        self.feature_extractor = feature_extractor
        self.threshold = threshold
        self.cooldown_samples = int(cooldown_ms * 16000 / 1000)

        # Audio buffer (1 second)
        self.buffer_size = 16000
        self.audio_buffer = deque(maxlen=self.buffer_size)

        # Detection cooldown
        self.last_detection = -self.cooldown_samples
        self.sample_count = 0

    def process_audio_chunk(self, audio_chunk):
        """
        Process incoming audio chunk

        Args:
            audio_chunk: Audio samples [n_samples]

            Returns:
                (detected, confidence) tuple
                """
                # Add to buffer
                self.audio_buffer.extend(audio_chunk)
                self.sample_count += len(audio_chunk)

                # Wait until buffer is full
                if len(self.audio_buffer) < self.buffer_size:
                    return False, 0.0

                    # Check cooldown
                    if self.sample_count - self.last_detection < self.cooldown_samples:
                        return False, 0.0

                        # Extract features
                        audio = np.array(self.audio_buffer)
                        features = self.feature_extractor.extract(audio)

                        # Add batch and channel dimensions
                        features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0)

                        # Run inference
                        with torch.no_grad():
                            output = self.model(features_tensor)
                            probs = torch.softmax(output, dim=1)
                            confidence = probs[0][1].item() # Probability of keyword

                            # Check threshold
                            if confidence >= self.threshold:
                                self.last_detection = self.sample_count
                                return True, confidence

                                return False, confidence

    def start_listening(self, callback=None):
        """
        Start continuous listening

        Args:
            callback: Function called when keyword detected
            """
            print("🎤 Listening for keywords...")

    def audio_callback(indata, frames, time_info, status):
        """Process audio in callback"""
        if status:
            print(f"Audio status: {status}")

            # Process chunk
            detected, confidence = self.process_audio_chunk(indata[:, 0])

            if detected:
                print(f"✓ Keyword detected! (confidence={confidence:.3f})")
                if callback:
                    callback(confidence)

                    # Start audio stream
                    with sd.InputStream(
                    samplerate=16000,
                    channels=1,
                    blocksize=1600, # 100ms chunks
                    callback=audio_callback
                    ):
                        print("Press Ctrl+C to stop")
                        sd.sleep(1000000) # Sleep indefinitely

                        # Usage
                        model = KeywordSpottingCNN()
                        model.load_state_dict(torch.load('best_kws_model.pth'))

                        spotter = StreamingKeywordSpotter(
                        model=model,
                        feature_extractor=extractor,
                        threshold=0.8,
                        cooldown_ms=1000
                        )

    def on_keyword_detected(confidence):
        """Callback when keyword detected"""
        print(f"🔔 Activating voice assistant... (conf={confidence:.2f})")
        # Trigger downstream processing

        spotter.start_listening(callback=on_keyword_detected)

Connection to Binary Search (DSA)

Keyword spotting uses binary search for threshold optimization:

class KeywordThresholdOptimizer:
    """
    Find optimal detection threshold using binary search

    Balances false accepts vs false rejects
    """

    def __init__(self, model, feature_extractor):
        self.model = model
        self.feature_extractor = feature_extractor
        self.model.eval()

    def find_optimal_threshold(self, positive_samples, negative_samples,
    target_far=0.01):
        """
        Binary search for threshold that achieves target FAR

        FAR = False Acceptance Rate

        Args:
            positive_samples: List of keyword audio samples
            negative_samples: List of non-keyword audio samples
            target_far: Target false acceptance rate (e.g., 0.01 = 1%)

            Returns:
                Optimal threshold
                """
                # Get confidence scores for all samples
                pos_scores = self._get_scores(positive_samples)
                neg_scores = self._get_scores(negative_samples)

                # Binary search on threshold space [0, 1]
                left, right = 0.0, 1.0
                best_threshold = 0.5

                for iteration in range(20): # 20 iterations for precision
                    mid = (left + right) / 2

                    # Calculate FAR at this threshold
                    false_accepts = sum(1 for score in neg_scores if score >= mid)
                    far = false_accepts / len(neg_scores)

                    # Calculate FRR at this threshold
                    false_rejects = sum(1 for score in pos_scores if score < mid)
                    frr = false_rejects / len(pos_scores)

                    print(f"Iteration {iteration}: threshold={mid:.4f}, "
                    f"FAR={far:.4f}, FRR={frr:.4f}")

                    # Adjust search space
                    if far > target_far:
                        # Too many false accepts, increase threshold
                        left = mid
                    else:
                        # FAR is good, try lowering threshold to reduce FRR
                        right = mid
                        best_threshold = mid

                        return best_threshold

    def _get_scores(self, audio_samples):
        """Get confidence scores for audio samples"""
        scores = []

        for audio in audio_samples:
            # Extract features
            features = self.feature_extractor.extract(audio)
            features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0)

            # Inference
            with torch.no_grad():
                output = self.model(features_tensor)
                probs = torch.softmax(output, dim=1)
                confidence = probs[0][1].item()
                scores.append(confidence)

                return scores

                # Usage
                optimizer = KeywordThresholdOptimizer(model, extractor)

                optimal_threshold = optimizer.find_optimal_threshold(
                positive_samples=keyword_audios,
                negative_samples=background_audios,
                target_far=0.01 # 1% false accept rate
                )

                print(f"Optimal threshold: {optimal_threshold:.4f}")

Advanced Model Architectures

1. Attention-Based KWS

import torch
import torch.nn as nn

class AttentionKWS(nn.Module):
    """
    Keyword spotting with attention mechanism

    Learns to focus on important parts of audio
    """

    def __init__(self, n_mels=40, hidden_dim=128, n_classes=2):
        super().__init__()

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
        input_size=n_mels,
        hidden_size=hidden_dim,
        num_layers=2,
        batch_first=True,
        bidirectional=True
        )

        # Attention layer
        self.attention = nn.Sequential(
        nn.Linear(hidden_dim * 2, 64),
        nn.Tanh(),
        nn.Linear(64, 1)
        )

        # Classifier
        self.fc = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, time_steps, n_mels]

            Returns:
                [batch, n_classes]
                """
                # LSTM
                lstm_out, _ = self.lstm(x) # [batch, time, hidden*2]

                # Attention scores
                attention_scores = self.attention(lstm_out) # [batch, time, 1]
                attention_weights = torch.softmax(attention_scores, dim=1)

                # Weighted sum
                context = torch.sum(lstm_out * attention_weights, dim=1) # [batch, hidden*2]

                # Classify
                output = self.fc(context)

                return output, attention_weights

                # Usage
                model = AttentionKWS(n_mels=40, hidden_dim=128)

                # Train and visualize attention
                x = torch.randn(1, 100, 40) # 1 sample
                output, attention = model(x)

                # Visualize which parts of audio model focuses on
                import matplotlib.pyplot as plt
                plt.figure(figsize=(12, 4))
                plt.plot(attention[0].detach().numpy())
                plt.title('Attention Weights Over Time')
                plt.xlabel('Time Step')
                plt.ylabel('Attention Weight')
                plt.savefig('attention_visualization.png')

2. Res-Net Based KWS

import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Residual block for audio"""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu(out)
        return out

    class ResNetKWS(nn.Module):
        """
        ResNet-based keyword spotting

        Deeper network for better accuracy
        """

    def __init__(self, n_mels=40, n_classes=2):
        super().__init__()

        # Initial conv
        self.conv1 = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU()
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
        ResNetBlock(32),
        ResNetBlock(32),
        ResNetBlock(32)
        )

        # Pooling
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier
        self.fc = nn.Linear(32, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, 1, n_mels, time_steps]
            """
            x = self.conv1(x)
            x = self.res_blocks(x)
            x = self.pool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

            model_resnet = ResNetKWS(n_mels=40)

3. Transformer-Based KWS

import torch
import torch.nn as nn
import numpy as np

class TransformerKWS(nn.Module):
    """
    Transformer for keyword spotting

    State-of-the-art performance but larger model
    """

    def __init__(self, n_mels=40, d_model=128, nhead=4,
    num_layers=2, n_classes=2):
        super().__init__()

        # Input projection
        self.input_proj = nn.Linear(n_mels, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=d_model * 4,
        dropout=0.1,
        batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
        encoder_layer,
        num_layers=num_layers
        )

        # Classifier
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch, time_steps, n_mels]
            """
            # Project input
            x = self.input_proj(x)

            # Add positional encoding
            x = self.pos_encoder(x)

            # Transformer
            x = self.transformer(x)

            # Global average pooling
            x = x.mean(dim=1)

            # Classify
            x = self.fc(x)
            return x

    class PositionalEncoding(nn.Module):
        """Positional encoding for transformer"""

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
        torch.arange(0, d_model, 2).float() *
        (-np.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

        model_transformer = TransformerKWS(n_mels=40)

Data Augmentation Strategies

For foundational augmentation techniques (time stretching, pitch shifting, noise mixing, SpecAugment), see Audio Preprocessing & Signal Processing.

Advanced Audio Augmentation

import librosa
import numpy as np

class AudioAugmenter:
    """
    Comprehensive audio augmentation for KWS training

    Improves robustness to real-world conditions
    """

    def __init__(self):
        self.sample_rate = 16000

    def time_stretch(self, audio, rate=None):
        """
        Stretch/compress audio in time

        Args:
            audio: Audio samples
            rate: Stretch factor (0.8-1.2 typical)
            """
            if rate is None:
                rate = np.random.uniform(0.9, 1.1)

                return librosa.effects.time_stretch(audio, rate=rate)

    def pitch_shift(self, audio, n_steps=None):
        """
        Shift pitch without changing speed

        Args:
            n_steps: Semitones to shift (-3 to +3 typical)
            """
            if n_steps is None:
                n_steps = np.random.randint(-2, 3)

                return librosa.effects.pitch_shift(
                audio,
                sr=self.sample_rate,
                n_steps=n_steps
                )

    def add_background_noise(self, audio, noise_audio, snr_db=None):
        """
        Add background noise at specified SNR

        Args:
            noise_audio: Background noise samples
            snr_db: Signal-to-noise ratio in dB (10-30 typical)
            """
            if snr_db is None:
                snr_db = np.random.uniform(10, 30)

                # Calculate noise scaling factor
                audio_power = np.mean(audio ** 2)
                noise_power = np.mean(noise_audio ** 2)

                snr_linear = 10 ** (snr_db / 10)
                noise_scale = np.sqrt(audio_power / (snr_linear * noise_power))

                # Mix audio and noise
                return audio + noise_scale * noise_audio

    def room_simulation(self, audio, room_size='medium'):
        """
        Simulate room acoustics (reverb)

        Args:
            room_size: 'small', 'medium', or 'large'
            """
            # Room impulse response parameters
            params = {
            'small': {'delay': 0.05, 'decay': 0.3},
            'medium': {'delay': 0.1, 'decay': 0.5},
            'large': {'delay': 0.2, 'decay': 0.7}
            }

            delay_samples = int(params[room_size]['delay'] * self.sample_rate)
            decay = params[room_size]['decay']

            # Simple reverb simulation
            reverb = np.zeros_like(audio)
            reverb[delay_samples:] = audio[:-delay_samples] * decay

            return audio + reverb

    def apply_compression(self, audio, threshold_db=-20):
        """
        Dynamic range compression

        Makes quiet sounds louder, loud sounds quieter
        """
        threshold = 10 ** (threshold_db / 20)
        compressed = np.copy(audio)

        # Compress samples above threshold
        mask = np.abs(audio) > threshold
        compressed[mask] = threshold + (audio[mask] - threshold) * 0.5

        return compressed

    def augment(self, audio):
        """
        Apply random augmentation pipeline

        Returns augmented audio
        """
        # Random selection of augmentations
        aug_functions = [
        lambda x: self.time_stretch(x),
        lambda x: self.pitch_shift(x),
        lambda x: self.apply_compression(x),
        ]

        # Apply 1-2 random augmentations
        n_augs = np.random.randint(1, 3)
        for _ in range(n_augs):
            aug_fn = np.random.choice(aug_functions)
            audio = aug_fn(audio)

            # Add background noise (always)
            noise = np.random.randn(len(audio)) * 0.005
            audio = self.add_background_noise(audio, noise)

            return audio

            # Usage in training
            augmenter = AudioAugmenter()

            # Augment training data
            augmented_audio = augmenter.augment(original_audio)

Production Deployment Patterns

Multi-Stage Detection Pipeline

import torch

class MultiStageKWSPipeline:
    """
    Multi-stage KWS for production

    Stage 1: Lightweight detector (always-on)
    Stage 2: Accurate model (triggered by stage 1)

    Optimizes power consumption vs accuracy
    """

    def __init__(self, stage1_model, stage2_model,
    stage1_threshold=0.7, stage2_threshold=0.9):
        self.stage1_model = stage1_model # Tiny model (~50KB)
        self.stage2_model = stage2_model # Accurate model (~500KB)

        self.stage1_threshold = stage1_threshold
        self.stage2_threshold = stage2_threshold

        self.stats = {
        'stage1_triggers': 0,
        'stage2_confirms': 0,
        'false_positives': 0
        }
        self.total_chunks = 0

    def detect(self, audio_chunk):
        """
        Two-stage detection

        Returns: (detected, confidence, stage)
        """
        # Increment processed chunks counter
        self.total_chunks += 1

        # Stage 1: Lightweight screening
        stage1_conf = self._run_stage1(audio_chunk)

        if stage1_conf < self.stage1_threshold:
            # Not a keyword, skip stage 2
            return False, stage1_conf, 1

            self.stats['stage1_triggers'] += 1

            # Stage 2: Accurate verification
            stage2_conf = self._run_stage2(audio_chunk)

            if stage2_conf >= self.stage2_threshold:
                self.stats['stage2_confirms'] += 1
                return True, stage2_conf, 2
            else:
                self.stats['false_positives'] += 1
                return False, stage2_conf, 2

    def _run_stage1(self, audio_chunk):
        """Run lightweight model"""
        features = extract_features_fast(audio_chunk)

        with torch.no_grad():
            output = self.stage1_model(features)
            confidence = torch.softmax(output, dim=1)[0][1].item()

            return confidence

    def _run_stage2(self, audio_chunk):
        """Run accurate model"""
        features = extract_features_high_quality(audio_chunk)

        with torch.no_grad():
            output = self.stage2_model(features)
            confidence = torch.softmax(output, dim=1)[0][1].item()

            return confidence

    def get_precision(self):
        """Calculate precision of two-stage system"""
        total_detections = self.stats['stage2_confirms'] + self.stats['false_positives']
        if total_detections == 0:
            return 0.0

            return self.stats['stage2_confirms'] / total_detections

    def get_power_savings(self):
        """Estimate power savings from two-stage approach"""
        # Stage 2 ~10x power of stage 1 (normalized units)
        stage2_invocations = self.stats['stage1_triggers']
        total = max(self.total_chunks, 1)
        cost_stage1 = 1.0
        cost_stage2 = 10.0

        energy_two_stage = total * cost_stage1 + stage2_invocations * cost_stage2
        energy_single_stage = total * cost_stage2

        savings = 1.0 - (energy_two_stage / energy_single_stage)
        return max(0.0, min(1.0, savings))

        # Usage
        pipeline = MultiStageKWSPipeline(
        stage1_model=lightweight_model,
        stage2_model=accurate_model
        )

        # Continuous monitoring
        for chunk in audio_stream:
            detected, confidence, stage = pipeline.detect(chunk)

            if detected:
                print(f"Keyword detected! (stage={stage}, conf={confidence:.3f})")

                print(f"Precision: {pipeline.get_precision():.2%}")
                print(f"Power savings: {pipeline.get_power_savings():.2%}")

On-Device Learning

class OnDeviceKWSLearner:
    """
    Personalized KWS with on-device learning

    Adapts to user's voice without sending data to cloud
    """

    def __init__(self, base_model):
        self.base_model = base_model

        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False

            # Add personalization layer
            self.personalization_layer = nn.Linear(
            self.base_model.output_dim,
            2
            )

            self.optimizer = torch.optim.SGD(
            self.personalization_layer.parameters(),
            lr=0.01
            )

            self.user_examples = []
            self.max_examples = 50 # Limited on-device storage

    def collect_user_example(self, audio, label):
        """
        Collect user-specific training example

        Args:
            audio: User's audio sample
            label: 1 for keyword, 0 for non-keyword
            """
            features = extract_features(audio)

            self.user_examples.append((features, label))

            # Keep only recent examples
            if len(self.user_examples) > self.max_examples:
                self.user_examples.pop(0)

    def personalize(self, n_epochs=10):
        """
        Personalize model to user

        Quick fine-tuning on device
        """
        if len(self.user_examples) < 5:
            print("Not enough user examples yet")
            return

            print(f"Personalizing with {len(self.user_examples)} examples...")

            for epoch in range(n_epochs):
                total_loss = 0

                for features, label in self.user_examples:
                    # Extract base features
                    with torch.no_grad():
                        base_output = self.base_model(features)

                        # Personalization layer
                        output = self.personalization_layer(base_output)

                        # Loss
                        loss = nn.CrossEntropyLoss()(
                        output.unsqueeze(0),
                        torch.tensor([label])
                        )

                        # Update
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                        total_loss += loss.item()

                        if epoch % 5 == 0:
                            print(f"Epoch {epoch}: Loss = {total_loss / len(self.user_examples):.4f}")

                            print("Personalization complete!")

    def predict(self, audio):
        """Predict with personalized model"""
        features = extract_features(audio)

        with torch.no_grad():
            base_output = self.base_model(features)
            output = self.personalization_layer(base_output)
            confidence = torch.softmax(output, dim=1)[0][1].item()

            return confidence

            # Usage
            learner = OnDeviceKWSLearner(base_model)

            # User trains their custom wake word
            print("Please say your wake word 5 times...")
            for i in range(5):
                audio = record_audio()
                learner.collect_user_example(audio, label=1)

                print("Please say 5 non-wake-word phrases...")
                for i in range(5):
                    audio = record_audio()
                    learner.collect_user_example(audio, label=0)

                    # Personalize on-device
                    learner.personalize(n_epochs=20)

                    # Use personalized model
                    confidence = learner.predict(test_audio)

Real-World Integration Examples

Smart Speaker Integration

This pattern shows how KWS feeds into streaming speech pipelines for full voice assistant workflows.

import time

class SmartSpeakerKWS:
    """
    KWS integrated with smart speaker

    Handles wake word → command processing pipeline
    """

    def __init__(self, wake_word_model, command_asr_model):
        self.wake_word_model = wake_word_model
        self.command_asr_model = command_asr_model

        self.state = 'listening' # 'listening' or 'processing'
        self.wake_word_detected = False
        self.command_timeout = 5.0 # seconds

        async def process_audio_stream(self, audio_stream):
            """
            Main processing loop

            Always listening for wake word, then processes command
            """
            wake_word_detector = StreamingKeywordSpotter(
            model=self.wake_word_model,
            feature_extractor=KeywordSpottingFeatureExtractor(sample_rate=16000)
            )

            async for chunk in audio_stream:
                if self.state == 'listening':
                    # Check for wake word
                    detected, confidence = wake_word_detector.process_audio_chunk(chunk)

                    if detected:
                        print("🔊 Wake word detected!")
                        await self.play_sound('ding.wav') # Audio feedback

                        # Switch to command processing
                        self.state = 'processing'
                        self.wake_word_detected = True

                        # Start command capture
                        command_audio = await self.capture_command(audio_stream)

                        # Process command
                        await self.process_command(command_audio)

                        # Return to listening
                        self.state = 'listening'

                        async def capture_command(self, audio_stream, timeout=5.0):
                            """Capture user command after wake word"""
                            command_chunks = []
                            start_time = time.time()

                            async for chunk in audio_stream:
                                command_chunks.append(chunk)

                                # Check timeout
                                if time.time() - start_time > timeout:
                                    break

                                    # Check for silence (end of command)
                                    if self.is_silence(chunk):
                                        break

                                        return np.concatenate(command_chunks)

                                        async def process_command(self, command_audio):
                                            """Process voice command"""
                                            # Transcribe command
                                            transcription = self.command_asr_model.transcribe(command_audio)
                                            print(f"Command: {transcription}")

                                            # Execute command
                                            response = await self.execute_command(transcription)

                                            # Speak response
                                            await self.speak(response)

                                            async def execute_command(self, command):
                                                """Execute voice command"""
                                                # Command routing
                                                if 'weather' in command.lower():
                                                    return await self.get_weather()
                                                elif 'music' in command.lower():
                                                    return await self.play_music()
                                                elif 'timer' in command.lower():
                                                    return await self.set_timer(command)
                                                else:
                                                    return "Sorry, I didn't understand that."

                                                    # Usage
                                                    speaker = SmartSpeakerKWS(wake_word_model, command_asr_model)
                                                    await speaker.process_audio_stream(microphone_stream)

Mobile App Integration

class MobileKWSManager:
    """
    KWS manager for mobile apps

    Handles battery optimization and background processing
    """

    def __init__(self, model_path):
        self.model = self.load_optimized_model(model_path)
        self.is_active = False
        self.battery_saver_mode = False

        # Performance tracking
        self.battery_usage = 0
        self.detections = 0

    def load_optimized_model(self, model_path):
        """Load quantized model for mobile"""
        # Load TFLite model
        import tensorflow as tf
        interpreter = tf.lite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()

        return interpreter

    def start_listening(self, battery_level=100):
        """Start KWS with battery-aware mode"""
        self.is_active = True

        # Enable battery saver if low battery
        if battery_level < 20:
            self.enable_battery_saver()

            # Start audio capture thread
            self.audio_thread = threading.Thread(target=self._audio_processing_loop)
            self.audio_thread.start()

    def enable_battery_saver(self):
        """Enable battery saving mode"""
        self.battery_saver_mode = True

        # Reduce processing frequency
        self.chunk_duration_ms = 200 # Longer chunks = less processing

        # Lower threshold for stage 1
        self.stage1_threshold = 0.8 # Higher threshold = fewer stage 2 triggers

        print("⚡ Battery saver mode enabled")

    def _audio_processing_loop(self):
        """Background audio processing"""
        while self.is_active:
            # Capture audio
            audio_chunk = self.capture_audio_chunk()

            # Process
            detected, confidence = self.detect_keyword(audio_chunk)

            if detected:
                self.detections += 1
                self.trigger_callback(confidence)

                # Track battery usage (simplified)
                self.battery_usage += 0.001 # mAh per iteration

                # Sleep to save battery
                if self.battery_saver_mode:
                    time.sleep(0.1)

    def detect_keyword(self, audio_chunk):
        """Run inference on mobile"""
        # Extract features
        features = extract_features(audio_chunk)

        # TFLite inference
        input_details = self.model.get_input_details()
        output_details = self.model.get_output_details()

        self.model.set_tensor(input_details[0]['index'], features)
        self.model.invoke()

        output = self.model.get_tensor(output_details[0]['index'])
        confidence = output[0][1]

        return confidence > 0.8, confidence

    def get_battery_impact(self):
        """Estimate battery impact"""
        return {
        'total_usage_mah': self.battery_usage,
        'detections': self.detections,
        'usage_per_hour': self.battery_usage * 3600 # Extrapolate
        }

        # Usage in mobile app
        kws_manager = MobileKWSManager('kws_model.tflite')
        kws_manager.start_listening(battery_level=get_battery_level())

        # Check battery impact
        impact = kws_manager.get_battery_impact()
        print(f"Battery usage: {impact['usage_per_hour']:.2f} mAh/hour")

Key Takeaways

Ultra-lightweight models - < 1MB for edge deployment ✅ Real-time processing - < 50ms latency requirement ✅ Always-on capability - Low power consumption ✅ Noise robustness - Data augmentation and preprocessing critical ✅ Binary search optimization - Find optimal detection thresholds ✅ Model compression - Quantization, pruning for deployment ✅ Streaming architecture - Process continuous audio efficiently


FAQ

Q: What model architecture is best for keyword spotting on edge devices? A: A lightweight CNN operating on mel-spectrogram input is the best starting point for edge keyword spotting. With roughly 30K parameters (~100KB), it achieves sub-5ms inference on CPU. For higher accuracy, use a two-stage pipeline where the CNN screens candidates and a larger GRU or attention-based model confirms detections.

Q: How do I reduce false activations in wake word detection? A: Use a multi-stage detection pipeline: a lightweight Stage 1 model with a lower threshold screens candidates, then a more accurate Stage 2 model with a higher threshold confirms. Combine this with a cooldown period (1-2 seconds) between detections and binary search to find the optimal threshold that achieves less than 1% false acceptance rate.

Q: How much training data do I need for keyword spotting? A: You need at least a few thousand positive examples of the target keyword from diverse speakers, plus 5-10x more negative examples including similar-sounding words. The Google Speech Commands dataset provides a good baseline. Data augmentation (noise, pitch shift, time stretch, room simulation) can effectively multiply your dataset size by 5-10x.


Originally published at: arunbaby.com/speech-tech/0009-keyword-spotting

If you found this helpful, consider sharing it with others who might benefit.

Want to work together?

I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.

Get in touch