
Separate overlapping speakers from a single channel of mixed audio: deep learning tackles the cocktail party problem for meeting transcription and voice assistants.

TL;DR

Speech separation isolates individual speakers from mixed audio using deep learning. Conv-TasNet, the current gold standard, achieves SI-SDR improvements above 15 dB by learning time-domain representations and estimating speaker masks through Temporal Convolutional Networks. With chunk-based streaming, separation runs in real-time at sub-50ms latency — enabling production meeting transcription, voice assistants in multi-speaker environments, and hearing aid technology.


What is Speech Separation?

Speech Separation (also called Source Separation or Speaker Separation) is the task of isolating individual speech sources from a mixture of overlapping speakers.

The Cocktail Party Problem

Humans can focus on a single speaker in a noisy, multi-speaker environment (like a cocktail party). Teaching machines to do the same is a fundamental challenge in speech processing.

Applications:

  • Meeting transcription with overlapping speech
  • Voice assistants in multi-speaker environments
  • Hearing aids for selective attention
  • Call center audio analysis
  • Video conferencing quality improvement

Problem Formulation

Input: Mixed audio with N speakers
Output: N separated audio streams, one per speaker

Mixed Audio:
 Speaker 1 + Speaker 2 + ... + Speaker N + Noise

Goal:
 → Separated Speaker 1
 → Separated Speaker 2
 → ...
 → Separated Speaker N
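
To make the formulation concrete, here is a minimal sketch of how the mixture is formed, using NumPy with synthetic sine waves standing in for real speech recordings; the separation model has to invert exactly this additive process:

import numpy as np

sr = 16000                                      # sample rate (Hz)
t = np.arange(sr * 2) / sr                      # 2 seconds of time stamps

# Stand-ins for two speakers (real speech would be loaded from audio files)
speaker_1 = 0.5 * np.sin(2 * np.pi * 220 * t)   # lower-pitched "voice"
speaker_2 = 0.5 * np.sin(2 * np.pi * 440 * t)   # higher-pitched "voice"
noise = 0.01 * np.random.randn(len(t))          # background noise

# The microphone only ever observes the sum
mixture = speaker_1 + speaker_2 + noise

# A separation model receives `mixture` and must output estimates
# close to `speaker_1` and `speaker_2` (in any order).
print(mixture.shape)  # (32000,)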

Why is Speech Separation So Hard?

What Makes This Problem Fundamentally Difficult?

Let’s understand the fundamental challenge with a simple analogy:

The Paint Mixing Problem

Imagine you have:

  • Red paint (Speaker 1)
  • Blue paint (Speaker 2)
  • You mix them → Purple paint (Mixed audio)

Challenge: Given purple paint, separate back into red and blue!

This seems impossible because mixing is information-destructive. But speech separation works because:

  1. Speech has structure: Not random noise, but patterns (phonemes, pitch, timing)
  2. Speakers differ: Different voice characteristics (pitch, timbre, accent)
  3. Deep learning: Can learn these patterns from thousands of examples

The Human Cocktail Party Effect

At a party with multiple conversations, you can focus on one person and “tune out” others. How?

Human brain uses:

  • Spatial cues: Sound comes from different directions
  • Voice characteristics: Pitch, timbre, speaking style
  • Linguistic context: Grammar, meaning help predict words
  • Visual cues: Lip reading, body language

ML models use:

  • ❌ No spatial cues (single microphone input)
  • ✅ Voice characteristics (learned from data)
  • ✅ Temporal patterns (speaking rhythm)
  • ✅ Spectral patterns (frequency differences)

The Core Mathematical Challenge

Input: Mixed waveform M(t) = S1(t) + S2(t) + ... + Sn(t)

  • M(t): What we hear (mixture)
  • S1(t), S2(t), ...: Individual speakers (what we want)

Goal: Find a function f such that:

  • f(M) ≈ [S1, S2, ..., Sn]

Why this is hard:

  1. Underdetermined problem: One equation (mixture), N unknowns (sources)
  2. Non-linear mixing: In reality, it’s not just addition (room acoustics, etc.)
  3. Unknown N: We often don’t know how many speakers there are
  4. Permutation ambiguity: Output order doesn’t matter (Speaker 1 could be output 2)

Challenges

1. Permutation Problem - The Hardest Part

When you train a model:

Attempt 1:
Ground truth: [Speaker A, Speaker B]
Model output: [Speaker A, Speaker B] ✓ Matches!

Attempt 2:
Ground truth: [Speaker A, Speaker B] 
Model output: [Speaker B, Speaker A] ✓ Also correct! Just different order!

The problem: Standard loss (MSE) would say Attempt 2 is wrong!

# This would incorrectly penalize Attempt 2
loss = mse(output[0], speaker_A) + mse(output[1], speaker_B)

Solution: Try all permutations, use best one (Permutation Invariant Training)

# Try both orderings, pick better one
loss1 = mse(output[0], speaker_A) + mse(output[1], speaker_B)
loss2 = mse(output[0], speaker_B) + mse(output[1], speaker_A)
loss = min(loss1, loss2) # Use better permutation

2. Number of Speakers

Scenario | Difficulty | Solution
Fixed N (always 2 speakers) | Easy | Train model for N=2
Variable N (2-5 speakers) | Hard | Separate approaches: 1) Train multiple models, 2) Train one model + speaker counting
Unknown N | Very Hard | Need speaker counting + adaptive separation

3. Overlapping Speech

Scenario 1: Sequential (Easy)
Time:       0s    1s    2s    3s    4s
Speaker A:  "Hello"
Speaker B:                    "Hi"
            ↑ No overlap, trivial!

Scenario 2: Partial Overlap (Medium)
Time:       0s    1s    2s    3s    4s
Speaker A:  "Hello there"
Speaker B:              "Hi how are you"
            ↑ Some overlap

Scenario 3: Complete Overlap (Hard)
Time:       0s    1s    2s    3s    4s
Speaker A:  "Hello there"
Speaker B:  "Hi how are you"
            ↑ Both speaking simultaneously!

Why complete overlap is hard:

  • Maximum information loss
  • Voices blend in frequency domain
  • Harder to find distinguishing features

4. Quality Metrics

How do we measure separation quality?

Metric | What it Measures | Good Value
SDR (Signal-to-Distortion Ratio) | Overall quality | > 10 dB
SIR (Signal-to-Interference) | How well other speakers removed | > 15 dB
SAR (Signal-to-Artifacts) | Artificial noise introduced | > 10 dB
SI-SDR (Scale-Invariant SDR) | Quality regardless of volume | > 15 dB

Intuition: Higher dB = Better separation

SI-SDR = 0 dB → No separation (output = input)
SI-SDR = 10 dB → Good separation (10x better signal)
SI-SDR = 20 dB → Excellent (100x better signal!)

Traditional Approaches

Independent Component Analysis (ICA):

  • Assumes statistical independence
  • Works for determined/overdetermined cases
  • Limited by the linear mixing assumption (see the sketch after this list)

Beamforming:

  • Uses spatial information from microphone array
  • Requires known speaker locations
  • Hardware-dependent

Non-Negative Matrix Factorization (NMF):

  • Factorizes spectrogram into basis and activation
  • Interpretable but limited capacity
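
To ground the ICA bullet above, here is a minimal sketch using scikit-learn's FastICA. It assumes the determined case (two microphones, two sources) and instantaneous linear mixing on synthetic stand-in signals, which is exactly where classical ICA works and where real rooms break the assumptions:

import numpy as np
from sklearn.decomposition import FastICA

sr = 16000
t = np.arange(sr) / sr

# Two synthetic sources (stand-ins for two speakers)
s1 = np.sin(2 * np.pi * 220 * t)
s2 = np.sign(np.sin(2 * np.pi * 3 * t))         # square wave, clearly non-Gaussian
S = np.stack([s1, s2], axis=1)                  # [time, 2]

# Two microphones, each observing a different linear mix of the sources
A = np.array([[1.0, 0.5],
              [0.4, 1.0]])                      # mixing matrix
X = S @ A.T                                     # observed signals, [time, 2]

# ICA recovers the sources up to permutation and scaling
ica = FastICA(n_components=2, random_state=0)
S_est = ica.fit_transform(X)                    # [time, 2] estimated sources

print(S_est.shape)  # (16000, 2)

Real recordings add reverberation and we usually have only one microphone, which is why the field moved to the learned, single-channel methods below.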

Deep Learning Revolution

Modern approaches use end-to-end deep learning:

  • TasNet (Time-domain Audio Separation Network)
  • Conv-TasNet (Convolutional TasNet)
  • Dual-Path RNN
  • SepFormer (Transformer-based)

How Does Conv-TasNet Separate Speakers?

What is the Conv-TasNet Architecture?

Conv-TasNet is the gold standard for speech separation:

CONV-TASNET ARCHITECTURE

Input Waveform  [batch, time]
      │
      ▼
Encoder (1D Conv, 512 filters)            learns time-domain basis functions
      │
      ▼
[batch, 512, time']
      │
      ▼
Separator (TCN blocks)                    estimates a mask per speaker
      │
      ▼
[batch, n_speakers, 512, time']           (masks for each speaker)
      │
      ▼
Apply Mask (element-wise multiply)
      │
      ▼
[batch, n_speakers, 512, time']           (masked representations)
      │
      ▼
Decoder (1D Transposed Conv)              reconstructs n_speakers waveforms
      │
      ▼
Separated Waveforms  [batch, n_speakers, time]

Implementation

import torch
import torch.nn as nn
import numpy as np

class ConvTasNet(nn.Module):
    """
    Conv-TasNet for speech separation

    Paper: "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking
    for Speech Separation" (Luo & Mesgarani, 2019)

    Architecture:
        1. Encoder: Waveform → learned representation
        2. Separator: Mask estimation with TCN
        3. Decoder: Masked representation → waveforms
    """

    def __init__(
        self,
        n_src=2,
        n_filters=512,
        kernel_size=16,
        stride=8,
        n_blocks=8,
        n_repeats=3,
        bn_chan=128,
        hid_chan=512,
        skip_chan=128
    ):
        """
        Args:
            n_src: Number of sources (speakers)
            n_filters: Number of filters in encoder
            kernel_size: Encoder/decoder kernel size
            stride: Encoder/decoder stride
            n_blocks: Number of TCN blocks per repeat
            n_repeats: Number of times to repeat TCN blocks
            bn_chan: Bottleneck channels
            hid_chan: Hidden channels in TCN
            skip_chan: Skip connection channels
        """
        super().__init__()

        self.n_src = n_src

        # Encoder: waveform → representation
        self.encoder = nn.Conv1d(
            1,
            n_filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False
        )

        # Separator: Temporal Convolutional Network
        self.separator = TemporalConvNet(
            n_filters,
            n_src,
            n_blocks=n_blocks,
            n_repeats=n_repeats,
            bn_chan=bn_chan,
            hid_chan=hid_chan,
            skip_chan=skip_chan
        )

        # Decoder: representation → waveform
        self.decoder = nn.ConvTranspose1d(
            n_filters,
            1,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False
        )

    def forward(self, mixture):
        """
        Separate mixture into sources

        Args:
            mixture: [batch, time] mixed waveform

        Returns:
            separated: [batch, n_src, time] separated waveforms
        """
        # Add channel dimension
        mixture = mixture.unsqueeze(1)  # [batch, 1, time]

        # Encode
        encoded = self.encoder(mixture)  # [batch, n_filters, time']

        # Estimate masks
        masks = self.separator(encoded)  # [batch, n_src, n_filters, time']

        # Apply masks
        masked = encoded.unsqueeze(1) * masks  # [batch, n_src, n_filters, time']

        # Decode each source
        separated = []
        for src_idx in range(self.n_src):
            src_masked = masked[:, src_idx, :, :]   # [batch, n_filters, time']
            src_waveform = self.decoder(src_masked) # [batch, 1, time]
            separated.append(src_waveform.squeeze(1))  # [batch, time]

        # Stack sources
        separated = torch.stack(separated, dim=1)  # [batch, n_src, time]

        # Trim to original length
        if separated.size(-1) != mixture.size(-1):
            separated = separated[..., :mixture.size(-1)]

        return separated

class TemporalConvNet(nn.Module):
    """
    Temporal Convolutional Network for mask estimation

    Stack of dilated 1D conv blocks with skip connections
    """

    def __init__(
        self,
        n_filters,
        n_src,
        n_blocks=8,
        n_repeats=3,
        bn_chan=128,
        hid_chan=512,
        skip_chan=128
    ):
        super().__init__()

        # Layer normalization
        self.layer_norm = nn.GroupNorm(1, n_filters)

        # Bottleneck (reduce dimensionality)
        self.bottleneck = nn.Conv1d(n_filters, bn_chan, 1)

        # TCN blocks: n_repeats stacks of n_blocks with exponentially growing dilation
        self.tcn_blocks = nn.ModuleList()
        for r in range(n_repeats):
            for b in range(n_blocks):
                dilation = 2 ** b
                self.tcn_blocks.append(
                    TCNBlock(
                        bn_chan,
                        hid_chan,
                        skip_chan,
                        kernel_size=3,
                        dilation=dilation
                    )
                )

        # Output projection
        self.output = nn.Sequential(
            nn.PReLU(),
            nn.Conv1d(skip_chan, n_filters * n_src, 1),
        )

        self.n_filters = n_filters
        self.n_src = n_src

    def forward(self, x):
        """
        Estimate masks for each source

        Args:
            x: [batch, n_filters, time']

        Returns:
            masks: [batch, n_src, n_filters, time']
        """
        batch_size, n_filters, time = x.size()

        # Normalize
        x = self.layer_norm(x)

        # Bottleneck
        x = self.bottleneck(x)  # [batch, bn_chan, time']

        # Accumulate skip connections
        skip_sum = 0
        for block in self.tcn_blocks:
            x, skip = block(x)
            skip_sum = skip_sum + skip

        # Output masks
        masks = self.output(skip_sum)  # [batch, n_filters * n_src, time']

        # Reshape to [batch, n_src, n_filters, time']
        masks = masks.view(batch_size, self.n_src, self.n_filters, time)

        # Apply non-linearity (ReLU for masking)
        masks = torch.relu(masks)

        return masks

class TCNBlock(nn.Module):
    """
    Single TCN block with dilated depthwise-separable convolution
    """

    def __init__(self, in_chan, hid_chan, skip_chan, kernel_size=3, dilation=1):
        super().__init__()

        # 1x1 conv
        self.conv1x1_1 = nn.Conv1d(in_chan, hid_chan, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = nn.GroupNorm(1, hid_chan)

        # Depthwise conv with dilation
        self.depthwise_conv = nn.Conv1d(
            hid_chan,
            hid_chan,
            kernel_size,
            padding=(kernel_size - 1) * dilation // 2,
            dilation=dilation,
            groups=hid_chan  # Depthwise
        )
        self.prelu2 = nn.PReLU()
        self.norm2 = nn.GroupNorm(1, hid_chan)

        # 1x1 conv
        self.conv1x1_2 = nn.Conv1d(hid_chan, in_chan, 1)

        # Skip connection
        self.skip_conv = nn.Conv1d(hid_chan, skip_chan, 1)

    def forward(self, x):
        """
        Args:
            x: [batch, in_chan, time]

        Returns:
            output: [batch, in_chan, time]
            skip: [batch, skip_chan, time]
        """
        residual = x

        # 1x1 conv
        x = self.conv1x1_1(x)
        x = self.prelu1(x)
        x = self.norm1(x)

        # Depthwise conv
        x = self.depthwise_conv(x)
        x = self.prelu2(x)
        x = self.norm2(x)

        # Skip connection
        skip = self.skip_conv(x)

        # 1x1 conv
        x = self.conv1x1_2(x)

        # Residual connection
        output = x + residual

        return output, skip

# Example usage
model = ConvTasNet(n_src=2, n_filters=512)

# Mixed waveform (2 speakers)
mixture = torch.randn(4, 16000)  # [batch=4, time=16000 (1 second at 16kHz)]

# Separate
separated = model(mixture)  # [4, 2, 16000]

print(f"Input shape: {mixture.shape}")
print(f"Output shape: {separated.shape}")
print(f"Separated speaker 1: {separated[:, 0, :].shape}")
print(f"Separated speaker 2: {separated[:, 1, :].shape}")

Training with Permutation Invariant Loss

import torch
import torch.nn as nn
import torch.nn.functional as F

class PermutationInvariantLoss(nn.Module):
    """
    Permutation Invariant Training (PIT) loss

    Problem: Model outputs are in arbitrary order
    Solution: Try all permutations, use best one

    For 2 speakers:
        - Try (output1→target1, output2→target2)
        - Try (output1→target2, output2→target1)
        - Use permutation with lower loss
    """

    def __init__(self, loss_fn='si_sdr'):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, estimated, target):
        """
        Compute PIT loss

        Args:
            estimated: [batch, n_src, time]
            target: [batch, n_src, time]

        Returns:
            loss: scalar
        """
        batch_size, n_src, time = estimated.size()

        # Generate all permutations
        import itertools
        perms = list(itertools.permutations(range(n_src)))

        # Compute loss for each permutation
        perm_losses = []
        for perm in perms:
            # Reorder estimated sources according to the permutation
            estimated_perm = estimated[:, list(perm), :]

            # Compute loss
            if self.loss_fn == 'si_sdr':
                loss = self._si_sdr_loss(estimated_perm, target)
            elif self.loss_fn == 'mse':
                loss = F.mse_loss(estimated_perm, target)
            else:
                raise ValueError(f"Unknown loss function: {self.loss_fn}")

            perm_losses.append(loss)

        # Stack losses: [n_perms], take minimum (best permutation)
        perm_losses = torch.stack(perm_losses)
        return perm_losses.min()

    def _si_sdr_loss(self, estimated, target):
        """
        Scale-Invariant Signal-to-Distortion Ratio loss

        Better than MSE for speech separation
        """
        # Zero-mean
        estimated = estimated - estimated.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)

        # Project estimated onto target
        dot = (estimated * target).sum(dim=-1, keepdim=True)
        target_energy = (target ** 2).sum(dim=-1, keepdim=True) + 1e-8
        projection = dot * target / target_energy

        # Noise (estimation error)
        noise = estimated - projection

        # SI-SDR
        si_sdr = 10 * torch.log10(
        (projection ** 2).sum(dim=-1) / ((noise ** 2).sum(dim=-1) + 1e-8)
        )

        # Negative for loss (we want to maximize SI-SDR)
        return -si_sdr.mean()

# Training loop
model = ConvTasNet(n_src=2)
criterion = PermutationInvariantLoss(loss_fn='si_sdr')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(model, train_loader, criterion, optimizer):
    """Train one epoch"""
    model.train()
    total_loss = 0

    for batch_idx, (mixture, target) in enumerate(train_loader):
        # mixture: [batch, time]
        # target: [batch, n_src, time]

        # Forward
        estimated = model(mixture)

        # Loss with PIT
        loss = criterion(estimated, target)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    return total_loss / len(train_loader)

# Train
# for epoch in range(num_epochs):
#     train_loss = train_epoch(model, train_loader, criterion, optimizer)
#     print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}")

How Do You Measure Separation Quality?

What is Signal-to-Distortion Ratio (SDR)?

def compute_sdr(estimated, target):
    """
    Compute Signal-to-Distortion Ratio

    SDR = 10 * log10(||target||^2 / ||target - estimated||^2)

    Higher is better. Good: > 10 dB, Great: > 15 dB
    """
    target = target - target.mean()
    estimated = estimated - estimated.mean()

    signal_power = np.sum(target ** 2)
    distortion = target - estimated
    distortion_power = np.sum(distortion ** 2) + 1e-10

    sdr = 10 * np.log10(signal_power / distortion_power)

    return sdr

def compute_si_sdr(estimated, target):
    """
    Compute Scale-Invariant SDR

    Invariant to scaling of the signal
    """
    # Zero-mean
    estimated = estimated - estimated.mean()
    target = target - target.mean()

    # Project estimated onto target
    alpha = np.dot(estimated, target) / (np.dot(target, target) + 1e-10)
    projection = alpha * target

    # Noise
    noise = estimated - projection

    # SI-SDR
    si_sdr = 10 * np.log10(
    np.sum(projection ** 2) / (np.sum(noise ** 2) + 1e-10)
    )

    return si_sdr

def compute_sir(estimated, target, interference):
    """
    Compute Signal-to-Interference Ratio

    Measures how well interfering speakers are suppressed
    """
    target = target - target.mean()
    estimated = estimated - estimated.mean()

    # Project estimated onto target
    s_target = np.dot(estimated, target) / (np.dot(target, target) + 1e-10) * target

    # Interference: sum of projections onto each interfering source
    e_interf = 0
    for interf in interference:
        interf = interf - interf.mean()
        e_interf += np.dot(estimated, interf) / (np.dot(interf, interf) + 1e-10) * interf

    # SIR
    sir = 10 * np.log10(
        np.sum(s_target ** 2) / (np.sum(e_interf ** 2) + 1e-10)
    )

    return sir

# Comprehensive evaluation
def evaluate_separation(model, test_loader):
    """
    Evaluate separation quality

    Returns metrics averaged over all samples and sources
    """
    model.eval()

    all_sdr = []
    all_si_sdr = []

    with torch.no_grad():
        for mixture, targets in test_loader:
            # Separate
            estimated = model(mixture)

            # Convert to numpy
            estimated_np = estimated.cpu().numpy()
            targets_np = targets.cpu().numpy()

            batch_size, n_src, time = estimated_np.shape

            # Compute metrics for each sample and source
            for b in range(batch_size):
                for s in range(n_src):
                    est = estimated_np[b, s, :]
                    tgt = targets_np[b, s, :]

                    sdr = compute_sdr(est, tgt)
                    si_sdr = compute_si_sdr(est, tgt)

                    all_sdr.append(sdr)
                    all_si_sdr.append(si_sdr)

    results = {
        'sdr_mean': np.mean(all_sdr),
        'sdr_std': np.std(all_sdr),
        'si_sdr_mean': np.mean(all_si_sdr),
        'si_sdr_std': np.std(all_si_sdr)
    }

    print("=" * 60)
    print("SEPARATION EVALUATION RESULTS")
    print("=" * 60)
    print(f"SDR: {results['sdr_mean']:.2f} ± {results['sdr_std']:.2f} dB")
    print(f"SI-SDR: {results['si_sdr_mean']:.2f} ± {results['si_sdr_std']:.2f} dB")
    print("=" * 60)

    return results

# Example
# results = evaluate_separation(model, test_loader)

Can Speech Separation Work in Real-Time?

How Does Streaming Separation Work?

class StreamingSpeechSeparator:
    """
    Real-time speech separation for streaming audio

    Challenges:
        - Causal processing (no future context)
        - Low latency (< 50ms)
        - State management across chunks
    """

    def __init__(self, model, chunk_size=4800, overlap=1200):
        """
        Args:
            model: Trained separation model
            chunk_size: Samples per processing window (300ms at 16kHz)
            overlap: Overlap between windows (75ms at 16kHz)
        """
        self.model = model
        self.model.eval()

        self.chunk_size = chunk_size
        self.overlap = overlap
        self.hop_size = chunk_size - overlap

        # Buffers for overlapping
        self.input_buffer = np.zeros(overlap)
        self.output_buffers = [np.zeros(overlap) for _ in range(model.n_src)]

    def process_chunk(self, audio_chunk):
        """
        Process a single incoming audio chunk

        Args:
            audio_chunk: [hop_size] numpy array of new samples

        Returns:
            separated_chunks: list of [hop_size] arrays, one per speaker
        """
        # Prepend overlap carried over from the previous chunk
        full_chunk = np.concatenate([self.input_buffer, audio_chunk])

        # Ensure correct size
        if len(full_chunk) < self.chunk_size:
            full_chunk = np.pad(
                full_chunk,
                (0, self.chunk_size - len(full_chunk)),
                mode='constant'
            )

        # Convert to tensor and separate
        with torch.no_grad():
            chunk_tensor = torch.from_numpy(full_chunk).float().unsqueeze(0)
            separated = self.model(chunk_tensor)  # [1, n_src, chunk_size]

        # Convert back to numpy
        separated_np = separated[0].cpu().numpy()  # [n_src, chunk_size]

        # Overlap-add
        result_chunks = []
        for src_idx in range(self.model.n_src):
            src_audio = separated_np[src_idx]

            # Add overlap from previous chunk
            src_audio[:self.overlap] += self.output_buffers[src_idx]

            # Extract output (without overlap)
            output_chunk = src_audio[:self.hop_size]
            result_chunks.append(output_chunk)

            # Save overlap for next chunk
            self.output_buffers[src_idx] = src_audio[-self.overlap:]

        # Update input buffer
        self.input_buffer = audio_chunk[-self.overlap:]

        return result_chunks

    def reset(self):
        """Reset state for a new stream"""
        self.input_buffer = np.zeros(self.overlap)
        self.output_buffers = [np.zeros(self.overlap) for _ in range(self.model.n_src)]

# Example: Real-time separation server
from fastapi import FastAPI, WebSocket

app = FastAPI()

# Load model
model = ConvTasNet(n_src=2)
model.load_state_dict(torch.load('convtasnet_separation.pth'))
separator = StreamingSpeechSeparator(model, chunk_size=4800, overlap=1200)

@app.websocket("/separate")
async def websocket_separation(websocket: WebSocket):
    """
    WebSocket endpoint for real-time separation

    Client sends audio chunks, receives separated streams
    """
    await websocket.accept()

    try:
        while True:
            # Receive audio chunk
            data = await websocket.receive_bytes()

            # Decode audio (assuming 16-bit PCM)
            audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0

            # Separate
            separated_chunks = separator.process_chunk(audio_chunk)

            # Send separated streams
            for src_idx, src_chunk in enumerate(separated_chunks):
                # Encode back to 16-bit PCM
                src_bytes = (src_chunk * 32768).astype(np.int16).tobytes()

                await websocket.send_json({
                    'speaker_id': src_idx,
                    'audio': src_bytes.hex()
                })

    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        separator.reset()
        await websocket.close()

# Run server
# uvicorn.run(app, host='0.0.0.0', port=8000)
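
For completeness, a minimal client sketch for the endpoint above. It assumes the third-party `websockets` package, a hypothetical raw PCM file name, and the hex-encoded JSON reply format produced by the server code:

# Example: minimal streaming client (sketch)
import asyncio
import numpy as np
import websockets

async def stream_file(path_to_pcm='mixture_16k_mono.raw', hop_size=3600):
    """Send 16-bit PCM chunks and print the separated-stream replies."""
    # Hypothetical raw PCM file: 16 kHz, mono, int16
    pcm = np.fromfile(path_to_pcm, dtype=np.int16)

    async with websockets.connect('ws://localhost:8000/separate') as ws:
        for start in range(0, len(pcm) - hop_size, hop_size):
            # Send one hop of raw PCM bytes
            await ws.send(pcm[start:start + hop_size].tobytes())

            # The server replies with one JSON message per speaker
            for _ in range(2):
                reply = await ws.recv()
                print(reply[:80])  # {"speaker_id": ..., "audio": "<hex>"}

# asyncio.run(stream_file())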

How Do You Integrate Separation with ASR?

What Does a Separation + Transcription Pipeline Look Like?

class SeparationASRPipeline:
    """
    Combined pipeline: Separate speakers → Transcribe each

    Use case: Meeting transcription with overlapping speech
    """

    def __init__(self, separation_model, asr_model):
        self.separator = separation_model
        self.asr = asr_model

    def transcribe_multi_speaker(self, audio):
        """
        Transcribe audio with multiple speakers

        Args:
            audio: Mixed audio

        Returns:
            List of per-speaker transcript dicts
        """
        # Separate speakers
        with torch.no_grad():
            audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
            separated = self.separator(audio_tensor)[0]  # [n_src, time]

        # Transcribe each speaker
        transcripts = []
        for speaker_id in range(separated.size(0)):
            speaker_audio = separated[speaker_id].cpu().numpy()

            # Transcribe
            transcript = self.asr.transcribe(speaker_audio)

            transcripts.append({
                'speaker_id': speaker_id,
                'transcript': transcript,
                'audio_length_sec': len(speaker_audio) / 16000
            })

        return transcripts

    def transcribe_with_diarization(self, audio):
        """
        Transcribe with speaker diarization

        Diarization: Who spoke when?
        Separation: Isolate each speaker's audio
        ASR: Transcribe each speaker
        """
        # Separate speakers
        with torch.no_grad():
            audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
            separated = self.separator(audio_tensor)[0]  # [n_src, time]

        # Speaker diarization on each separated stream
        diarization_results = []
        for speaker_id in range(separated.size(0)):
            speaker_audio = separated[speaker_id].cpu().numpy()

            # Voice Activity Detection
            vad_segments = self._detect_voice_activity(speaker_audio)

            # Transcribe active segments
            for segment in vad_segments:
                start_idx = int(segment['start'] * 16000)
                end_idx = int(segment['end'] * 16000)

                segment_audio = speaker_audio[start_idx:end_idx]
                transcript = self.asr.transcribe(segment_audio)

                diarization_results.append({
                    'speaker_id': speaker_id,
                    'start_time': segment['start'],
                    'end_time': segment['end'],
                    'transcript': transcript
                })

        # Sort by start time
        diarization_results.sort(key=lambda x: x['start_time'])

        return diarization_results

    def _detect_voice_activity(self, audio, frame_duration=0.03):
        """
        Simple energy-based VAD

        Returns list of {'start', 'end'} segments with voice activity
        """
        import librosa

        # Compute energy
        frame_length = int(frame_duration * 16000)
        energy = librosa.feature.rms(
            y=audio,
            frame_length=frame_length,
            hop_length=frame_length // 2
        )[0]

        # Threshold
        threshold = np.mean(energy) * 0.5

        # Find voice segments
        is_voice = energy > threshold

        segments = []
        in_segment = False
        start = 0

        for i, voice in enumerate(is_voice):
            if voice and not in_segment:
                start = i * frame_duration / 2
                in_segment = True
            elif not voice and in_segment:
                end = i * frame_duration / 2
                segments.append({'start': start, 'end': end})
                in_segment = False

        return segments

# Example usage
separation_model = ConvTasNet(n_src=2)
separation_model.load_state_dict(torch.load('separation_model.pth'))

# Mock ASR model
class MockASR:
    def transcribe(self, audio):
        return f"Transcribed {len(audio)} samples"

asr_model = MockASR()

pipeline = SeparationASRPipeline(separation_model, asr_model)

# Transcribe multi-speaker audio
audio = np.random.randn(16000 * 10)  # 10 seconds
results = pipeline.transcribe_multi_speaker(audio)

print("Transcription results:")
for result in results:
    print(f"Speaker {result['speaker_id']}: {result['transcript']}")

What About Unknown Numbers of Speakers?

How Do You Handle a Variable Number of Sources?

class AdaptiveSeparationModel(nn.Module):
    """
    Separate audio with unknown number of speakers

    Approach:
        1. Estimate number of speakers
        2. Separate into estimated number of sources
        3. Filter empty sources
    """

    def __init__(self, max_speakers=10):
        super().__init__()

        self.max_speakers = max_speakers

        # Speaker counting network
        self.speaker_counter = nn.Sequential(
            nn.Conv1d(1, 128, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(256, max_speakers + 1),  # 0 to max_speakers
            nn.Softmax(dim=-1)
        )

        # Separation models for different numbers of speakers
        self.separators = nn.ModuleList([
            ConvTasNet(n_src=n) for n in range(1, max_speakers + 1)
        ])

    def forward(self, mixture):
        """
        Separate with adaptive number of sources

        Args:
            mixture: [batch, time]

        Returns:
            separated: [batch, n_src, time], or an empty list if no speech is detected
        """
        # Estimate number of speakers
        mixture_1d = mixture.unsqueeze(1)  # [batch, 1, time]
        speaker_probs = self.speaker_counter(mixture_1d)  # [batch, max_speakers + 1]

        n_speakers = speaker_probs.argmax(dim=-1)  # [batch]

        # For simplicity, use max in batch (in practice, process per sample)
        max_n_speakers = n_speakers.max().item()

        if max_n_speakers == 0:
            return []

        # Separate using appropriate model
        separator = self.separators[max_n_speakers - 1]
        separated = separator(mixture)  # [batch, n_src, time]

        return separated

# Example
model = AdaptiveSeparationModel(max_speakers=5)

# Test with 2 speakers
mixture = torch.randn(1, 16000)
separated = model(mixture)

if len(separated) > 0:
    print(f"Estimated sources: {separated.size(1)}")

Multi-Channel Separation

class MultiChannelSeparator(nn.Module):
    """
    Use multiple microphones for better separation

    Microphone array provides spatial information
    """

    def __init__(self, n_channels, n_src):
        super().__init__()

        self.n_channels = n_channels
        self.n_src = n_src

        # Encoder for each channel
        self.encoders = nn.ModuleList([
            nn.Conv1d(1, 256, kernel_size=16, stride=8)
            for _ in range(n_channels)
        ])

        # Cross-channel attention
        self.cross_channel_attention = nn.MultiheadAttention(
            embed_dim=256 * n_channels,
            num_heads=8
        )

        # Separator
        self.separator = TemporalConvNet(
            256 * n_channels,
            n_src,
            n_blocks=8,
            n_repeats=3,
            bn_chan=128,
            hid_chan=512,
            skip_chan=128
        )

        # Decoder
        self.decoder = nn.ConvTranspose1d(256, 1, kernel_size=16, stride=8)

    def forward(self, multi_channel_mixture):
        """
        Separate using multi-channel input

        Args:
            multi_channel_mixture: [batch, n_channels, time]

        Returns:
            separated: [batch, n_src, time]
        """
        batch_size, n_channels, time = multi_channel_mixture.size()

        # Encode each channel
        encoded_channels = []
        for ch in range(n_channels):
            ch_audio = multi_channel_mixture[:, ch:ch+1, :]  # [batch, 1, time]
            ch_encoded = self.encoders[ch](ch_audio)          # [batch, 256, time']
            encoded_channels.append(ch_encoded)

        # Concatenate channels
        encoded = torch.cat(encoded_channels, dim=1)  # [batch, 256 * n_channels, time']

        # Cross-channel attention
        # Reshape for attention: [time', batch, 256 * n_channels]
        encoded_t = encoded.permute(2, 0, 1)
        attended, _ = self.cross_channel_attention(encoded_t, encoded_t, encoded_t)
        attended = attended.permute(1, 2, 0)  # [batch, 256 * n_channels, time']

        # Separate
        masks = self.separator(attended)  # [batch, n_src, 256 * n_channels, time']

        # Apply masks and decode
        separated = []
        for src_idx in range(self.n_src):
            masked = attended * masks[:, src_idx, :, :]

            # Take first 256 channels for decoding
            masked_single = masked[:, :256, :]

            src_waveform = self.decoder(masked_single).squeeze(1)
            separated.append(src_waveform)

        separated = torch.stack(separated, dim=1)

        return separated

# Example: 4-microphone array
model = MultiChannelSeparator(n_channels=4, n_src=2)

# 4-channel input
multi_channel_audio = torch.randn(1, 4, 16000)

separated = model(multi_channel_audio)
print(f"Separated shape: {separated.shape}")  # [1, 2, 16000]

FAQ

Q: What is speech separation in machine learning? A: Speech separation is the task of isolating individual speech sources from a mixture of overlapping speakers. It uses deep learning models like Conv-TasNet to estimate masks that separate each speaker’s audio from the mix, achieving SI-SDR improvements above 15 dB.

Q: How does Conv-TasNet work for speech separation? A: Conv-TasNet uses a three-stage architecture: an encoder converts waveforms to learned representations, a Temporal Convolutional Network estimates masks for each speaker, and a decoder reconstructs separated waveforms. It operates in the time domain with 512 filters and achieves state-of-the-art separation quality.

Q: What is Permutation Invariant Training (PIT)? A: Permutation Invariant Training solves the problem that separated outputs can be in any order. It computes the loss for all possible output-to-target permutations and uses the permutation with the lowest loss for backpropagation.

Q: Can speech separation work in real-time? A: Yes. Streaming speech separation processes audio in chunks (typically 300ms with 75ms overlap) using overlap-add reconstruction, achieving latency under 50ms and a real-time factor below 0.1.

Q: What metrics measure speech separation quality? A: Key metrics include SDR (Signal-to-Distortion Ratio, good above 10 dB), SI-SDR (Scale-Invariant SDR, good above 15 dB), SIR (Signal-to-Interference Ratio), and SAR (Signal-to-Artifacts Ratio). Higher values indicate better separation.


Key Takeaways

✅ Conv-TasNet - State-of-the-art time-domain separation
✅ PIT loss - Handle output permutation problem
✅ SI-SDR metric - Scale-invariant quality measure
✅ Real-time streaming - Chunk-based processing with overlap-add
✅ Integration with ASR - End-to-end meeting transcription

Performance Targets:

  • SI-SDR improvement: > 15 dB
  • Real-time factor: < 0.1 (10x faster than real-time; see the measurement sketch below)
  • Latency: < 50ms for streaming
  • Works with 2-5 overlapping speakers
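
A quick way to sanity-check the real-time factor and latency targets on your own hardware is the rough sketch below. It reuses the ConvTasNet class from the Implementation section; the exact numbers depend on your CPU/GPU and are not a benchmark.

import time
import torch

model = ConvTasNet(n_src=2)
model.eval()

chunk = torch.randn(1, 4800)        # one 300ms chunk at 16 kHz

with torch.no_grad():
    # Warm-up run (first call includes one-time allocation costs)
    model(chunk)

    start = time.perf_counter()
    n_runs = 20
    for _ in range(n_runs):
        model(chunk)
    elapsed = (time.perf_counter() - start) / n_runs

audio_duration = 4800 / 16000       # 0.3 s of audio per chunk
rtf = elapsed / audio_duration
print(f"Per-chunk compute: {elapsed * 1000:.1f} ms, RTF: {rtf:.2f}")

If the per-chunk compute time stays well below the 300 ms chunk duration, the streaming separator above can keep up with the incoming audio.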

Originally published at: arunbaby.com/speech-tech/0011-speech-separation

If you found this helpful, consider sharing it with others who might benefit.

Want to work together?

I take on projects, advisory roles, and fractional CTO engagements in AI/ML. I also help businesses go AI-native with agentic workflows and agent orchestration.

Get in touch