Speech Separation
Separate overlapping speakers from a single recording: deep learning tackles the cocktail party problem for meeting transcription and voice assistants.
Problem Statement
Speech Separation (also called Source Separation or Speaker Separation) is the task of isolating individual speech sources from a mixture of overlapping speakers.
The Cocktail Party Problem
Humans can focus on a single speaker in a noisy, multi-speaker environment (like a cocktail party). Teaching machines to do the same is a fundamental challenge in speech processing.
Applications:
- Meeting transcription with overlapping speech
- Voice assistants in multi-speaker environments
- Hearing aids for selective attention
- Call center audio analysis
- Video conferencing quality improvement
Problem Formulation
Input: Mixed audio with N speakers
Output: N separated audio streams, one per speaker
Mixed Audio:
Speaker 1 + Speaker 2 + ... + Speaker N + Noise
Goal:
→ Separated Speaker 1
→ Separated Speaker 2
→ ...
→ Separated Speaker N
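A minimal sketch of this formulation with synthetic signals (sine waves standing in for speakers; everything here is illustrative):

import numpy as np

sr = 16000                                    # sample rate (Hz)
t = np.arange(sr) / sr                        # 1 second of timestamps

# Two synthetic "speakers" (in practice these would be real speech recordings)
speaker_1 = 0.5 * np.sin(2 * np.pi * 220 * t)
speaker_2 = 0.5 * np.sin(2 * np.pi * 330 * t)
noise = 0.01 * np.random.randn(sr)

# Mixing is sample-wise addition
mixture = speaker_1 + speaker_2 + noise

# A separation model takes `mixture` and must return estimates of
# speaker_1 and speaker_2 (in any order)
print(mixture.shape)  # (16000,)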
Understanding Speech Separation
Why is Speech Separation Hard?
Let’s understand the fundamental challenge with a simple analogy:
The Paint Mixing Problem
Imagine you have:
- Red paint (Speaker 1)
- Blue paint (Speaker 2)
- You mix them → Purple paint (Mixed audio)
Challenge: Given purple paint, separate back into red and blue!
This seems impossible because mixing is information-destructive. But speech separation works because:
- Speech has structure: Not random noise, but patterns (phonemes, pitch, timing)
- Speakers differ: Different voice characteristics (pitch, timbre, accent)
- Deep learning: Can learn these patterns from thousands of examples
The Human Cocktail Party Effect
At a party with multiple conversations, you can focus on one person and “tune out” others. How?
Human brain uses:
- Spatial cues: Sound comes from different directions
- Voice characteristics: Pitch, timbre, speaking style
- Linguistic context: Grammar, meaning help predict words
- Visual cues: Lip reading, body language
ML models use:
- ❌ No spatial cues (single microphone input)
- ✅ Voice characteristics (learned from data)
- ✅ Temporal patterns (speaking rhythm)
- ✅ Spectral patterns (frequency differences)
The Core Mathematical Challenge
Input: the mixed waveform M(t) = S1(t) + S2(t) + ... + SN(t)
- M(t): what we hear (the mixture)
- S1(t), S2(t), ..., SN(t): the individual speakers (what we want)
Goal: find a function f such that f(M) → [S1, S2, ..., SN]
Why this is hard:
- Underdetermined problem: One equation (the mixture), N unknowns (the sources); see the sketch after this list
- Non-linear mixing: In reality, it’s not just addition (room acoustics, etc.)
- Unknown N: We often don’t know how many speakers there are
- Permutation ambiguity: Output order doesn’t matter (Speaker 1 could be output 2)
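To see why the underdetermined point matters: many different source pairs produce exactly the same mixture, so the mixture alone cannot identify the true sources; the model has to rely on learned speech structure. A minimal numeric sketch:

import numpy as np

rng = np.random.default_rng(0)
s1 = rng.standard_normal(8)
s2 = rng.standard_normal(8)
mixture = s1 + s2

# Move arbitrary content e from one source to the other: the mixture is
# unchanged, so the mixture alone cannot identify the true (s1, s2)
e = rng.standard_normal(8)
alt_s1, alt_s2 = s1 + e, s2 - e
print(np.allclose(alt_s1 + alt_s2, mixture))  # True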
Challenges
1. Permutation Problem - The Hardest Part
When you train a model:
Attempt 1:
Ground truth: [Speaker A, Speaker B]
Model output: [Speaker A, Speaker B] ✓ Matches!
Attempt 2:
Ground truth: [Speaker A, Speaker B]
Model output: [Speaker B, Speaker A] ✓ Also correct! Just different order!
The problem: Standard loss (MSE) would say Attempt 2 is wrong!
# This would incorrectly penalize Attempt 2
loss = mse(output[0], speaker_A) + mse(output[1], speaker_B)
Solution: Try all permutations, use best one (Permutation Invariant Training)
# Try both orderings, pick better one
loss1 = mse(output[0], speaker_A) + mse(output[1], speaker_B)
loss2 = mse(output[0], speaker_B) + mse(output[1], speaker_A)
loss = min(loss1, loss2) # Use better permutation
2. Number of Speakers
| Scenario | Difficulty | Solution |
|---|---|---|
| Fixed N (always 2 speakers) | Easy | Train a model for N=2 |
| Variable N (2-5 speakers) | Hard | Train multiple models, or one model plus speaker counting |
| Unknown N | Very hard | Speaker counting + adaptive separation |
3. Overlapping Speech
Scenario 1: Sequential (Easy)
Time: 0s 1s 2s 3s 4s
Speaker A: "Hello"
Speaker B: "Hi"
↑ No overlap, trivial!
Scenario 2: Partial Overlap (Medium)
Time: 0s 1s 2s 3s 4s
Speaker A: "Hello there"
Speaker B: "Hi how are you"
↑ Some overlap
Scenario 3: Complete Overlap (Hard)
Time: 0s 1s 2s 3s 4s
Speaker A: "Hello there"
Speaker B: "Hi how are you"
↑ Both speaking simultaneously!
Why complete overlap is hard:
- Maximum information loss
- Voices blend in frequency domain
- Harder to find distinguishing features
4. Quality Metrics
How do we measure separation quality?
| Metric | What it Measures | Good Value |
|---|---|---|
| SDR (Signal-to-Distortion Ratio) | Overall quality | > 10 dB |
| SIR (Signal-to-Interference Ratio) | How well other speakers are removed | > 15 dB |
| SAR (Signal-to-Artifacts Ratio) | Artifacts introduced by processing | > 10 dB |
| SI-SDR (Scale-Invariant SDR) | Quality regardless of volume | > 15 dB |
Intuition: Higher dB = better separation (see the quick check below)
SI-SDR ≈ 0 dB → No real separation (the error is as strong as the signal)
SI-SDR = 10 dB → Good separation (signal power is 10x the error power)
SI-SDR = 20 dB → Excellent separation (signal power is 100x the error power)
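As a quick numeric check of the dB intuition (a minimal sketch; dB converts to a power ratio via 10^(dB/10)):

for si_sdr_db in [0, 10, 20]:
    power_ratio = 10 ** (si_sdr_db / 10)  # dB → power ratio
    print(f"SI-SDR = {si_sdr_db:2d} dB → signal power is {power_ratio:.0f}x the error power")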
Traditional Approaches
Independent Component Analysis (ICA):
- Assumes statistical independence
- Works for determined/overdetermined cases
- Limited by the linear mixing assumption (see the toy sketch after this list)
Beamforming:
- Uses spatial information from microphone array
- Requires known speaker locations
- Hardware-dependent
Non-Negative Matrix Factorization (NMF):
- Factorizes spectrogram into basis and activation
- Interpretable but limited capacity
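To make the ICA idea above concrete, here is a toy sketch for the determined case (two microphones, two sources) using scikit-learn's FastICA; the signals and mixing matrix are made up for illustration:

import numpy as np
from sklearn.decomposition import FastICA

# Toy determined case: 2 synthetic sources observed by 2 microphones
t = np.arange(16000) / 16000
sources = np.stack([
    np.sin(2 * np.pi * 220 * t),             # "speaker" 1
    np.sign(np.sin(2 * np.pi * 3 * t)),      # "speaker" 2
], axis=1)                                    # [n_samples, 2]

mixing_matrix = np.array([[1.0, 0.6],
                          [0.4, 1.0]])        # simulated linear mixing
mic_signals = sources @ mixing_matrix.T       # [n_samples, 2]

ica = FastICA(n_components=2, random_state=0)
estimated_sources = ica.fit_transform(mic_signals)  # [n_samples, 2]
# ICA recovers sources only up to permutation and scaling, and it breaks
# down in the single-microphone (underdetermined) case
print(estimated_sources.shape)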
Deep Learning Revolution
Modern approaches use end-to-end deep learning:
- TasNet (Time-domain Audio Separation Network)
- Conv-TasNet (Convolutional TasNet)
- Dual-Path RNN
- SepFormer (Transformer-based)
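If you only need separated audio rather than a from-scratch model, pretrained separators are available. The snippet below is a sketch assuming SpeechBrain's pretrained SepFormer interface and the speechbrain/sepformer-wsj02mix checkpoint; module paths and model names change between releases, so check the current SpeechBrain documentation.

# Sketch only: assumes `pip install speechbrain` and that the
# "speechbrain/sepformer-wsj02mix" checkpoint is available
from speechbrain.pretrained import SepformerSeparation

sep_model = SepformerSeparation.from_hparams(
    source="speechbrain/sepformer-wsj02mix",
    savedir="pretrained_models/sepformer-wsj02mix",
)
# Returns a tensor of shape [batch, time, n_src]
est_sources = sep_model.separate_file(path="two_speaker_mixture.wav")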
Solution 1: Conv-TasNet Architecture
Architecture Overview
Conv-TasNet is a strong, widely used architecture for speech separation:
┌──────────────────────────────────────────────────────────┐
│ CONV-TASNET ARCHITECTURE │
├──────────────────────────────────────────────────────────┤
│ │
│ Input Waveform │
│ [batch, time] │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Encoder │ (1D Conv) │
│ │ 512 filters│ Learns time-domain basis functions │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ [batch, 512, time'] │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Separator │ (TCN blocks) │
│ │ Temporal │ Estimates masks for each speaker │
│ │ Convolution│ │
│ │ Network │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ [batch, n_speakers, 512, time'] │
│ (Masks for each speaker) │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Apply Mask │ (Element-wise multiply) │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ [batch, n_speakers, 512, time'] │
│ (Masked representations) │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Decoder │ (1D Transposed Conv) │
│ │ n_speakers │ Reconstructs waveforms │
│ │ outputs │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ Separated Waveforms │
│ [batch, n_speakers, time] │
│ │
└──────────────────────────────────────────────────────────┘
Implementation
import torch
import torch.nn as nn
import numpy as np
class ConvTasNet(nn.Module):
"""
Conv-TasNet for speech separation
Paper: "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking
for Speech Separation" (Luo & Mesgarani, 2019)
Architecture:
1. Encoder: Waveform → learned representation
2. Separator: Mask estimation with TCN
3. Decoder: Masked representation → waveforms
"""
def __init__(
self,
n_src=2,
n_filters=512,
kernel_size=16,
stride=8,
n_blocks=8,
n_repeats=3,
bn_chan=128,
hid_chan=512,
skip_chan=128
):
"""
Args:
n_src: Number of sources (speakers)
n_filters: Number of filters in encoder
kernel_size: Encoder/decoder kernel size
stride: Encoder/decoder stride
n_blocks: Number of TCN blocks per repeat
n_repeats: Number of times to repeat TCN blocks
bn_chan: Bottleneck channels
hid_chan: Hidden channels in TCN
skip_chan: Skip connection channels
"""
super().__init__()
self.n_src = n_src
# Encoder: waveform → representation
self.encoder = nn.Conv1d(
1,
n_filters,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
bias=False
)
# Separator: Temporal Convolutional Network
self.separator = TemporalConvNet(
n_filters,
n_src,
n_blocks=n_blocks,
n_repeats=n_repeats,
bn_chan=bn_chan,
hid_chan=hid_chan,
skip_chan=skip_chan
)
# Decoder: representation → waveform
self.decoder = nn.ConvTranspose1d(
n_filters,
1,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
bias=False
)
def forward(self, mixture):
"""
Separate mixture into sources
Args:
mixture: [batch, time] mixed waveform
Returns:
separated: [batch, n_src, time] separated waveforms
"""
batch_size = mixture.size(0)
# Add channel dimension
mixture = mixture.unsqueeze(1) # [batch, 1, time]
# Encode
encoded = self.encoder(mixture) # [batch, n_filters, time']
# Estimate masks
masks = self.separator(encoded) # [batch, n_src, n_filters, time']
# Apply masks
masked = encoded.unsqueeze(1) * masks # [batch, n_src, n_filters, time']
# Decode each source
separated = []
for src_idx in range(self.n_src):
src_masked = masked[:, src_idx, :, :] # [batch, n_filters, time']
src_waveform = self.decoder(src_masked) # [batch, 1, time]
separated.append(src_waveform.squeeze(1)) # [batch, time]
# Stack sources
separated = torch.stack(separated, dim=1) # [batch, n_src, time]
# Trim to original length
if separated.size(-1) != mixture.size(-1):
separated = separated[..., :mixture.size(-1)]
return separated
class TemporalConvNet(nn.Module):
"""
Temporal Convolutional Network for mask estimation
Stack of dilated 1D conv blocks with skip connections
"""
def __init__(
self,
n_filters,
n_src,
n_blocks=8,
n_repeats=3,
bn_chan=128,
hid_chan=512,
skip_chan=128
):
super().__init__()
# Layer normalization
self.layer_norm = nn.GroupNorm(1, n_filters)
# Bottleneck (reduce dimensionality)
self.bottleneck = nn.Conv1d(n_filters, bn_chan, 1)
# TCN blocks
self.tcn_blocks = nn.ModuleList()
for r in range(n_repeats):
for b in range(n_blocks):
dilation = 2 ** b
self.tcn_blocks.append(
TCNBlock(
bn_chan,
hid_chan,
skip_chan,
kernel_size=3,
dilation=dilation
)
)
# Output projection
self.output = nn.Sequential(
nn.PReLU(),
nn.Conv1d(skip_chan, n_filters * n_src, 1),
)
self.n_filters = n_filters
self.n_src = n_src
def forward(self, x):
"""
Estimate masks for each source
Args:
x: [batch, n_filters, time']
Returns:
masks: [batch, n_src, n_filters, time']
"""
batch_size, n_filters, time = x.size()
# Normalize
x = self.layer_norm(x)
# Bottleneck
x = self.bottleneck(x) # [batch, bn_chan, time']
# Accumulate skip connections
skip_sum = 0
for block in self.tcn_blocks:
x, skip = block(x)
skip_sum = skip_sum + skip
# Output masks
masks = self.output(skip_sum) # [batch, n_filters * n_src, time']
# Reshape to [batch, n_src, n_filters, time']
masks = masks.view(batch_size, self.n_src, self.n_filters, time)
# Apply non-linearity (ReLU for masking)
masks = torch.relu(masks)
return masks
class TCNBlock(nn.Module):
"""
Single TCN block with dilated depthwise-separable convolution
"""
def __init__(self, in_chan, hid_chan, skip_chan, kernel_size=3, dilation=1):
super().__init__()
# 1x1 conv
self.conv1x1_1 = nn.Conv1d(in_chan, hid_chan, 1)
self.prelu1 = nn.PReLU()
self.norm1 = nn.GroupNorm(1, hid_chan)
# Depthwise conv with dilation
self.depthwise_conv = nn.Conv1d(
hid_chan,
hid_chan,
kernel_size,
padding=(kernel_size - 1) * dilation // 2,
dilation=dilation,
groups=hid_chan # Depthwise
)
self.prelu2 = nn.PReLU()
self.norm2 = nn.GroupNorm(1, hid_chan)
# 1x1 conv
self.conv1x1_2 = nn.Conv1d(hid_chan, in_chan, 1)
# Skip connection
self.skip_conv = nn.Conv1d(hid_chan, skip_chan, 1)
def forward(self, x):
"""
Args:
x: [batch, in_chan, time]
Returns:
output: [batch, in_chan, time]
skip: [batch, skip_chan, time]
"""
residual = x
# 1x1 conv
x = self.conv1x1_1(x)
x = self.prelu1(x)
x = self.norm1(x)
# Depthwise conv
x = self.depthwise_conv(x)
x = self.prelu2(x)
x = self.norm2(x)
# Skip connection
skip = self.skip_conv(x)
# 1x1 conv
x = self.conv1x1_2(x)
# Residual connection
output = x + residual
return output, skip
# Example usage
model = ConvTasNet(n_src=2, n_filters=512)
# Dummy mixed waveform (random noise standing in for a 2-speaker mixture)
mixture = torch.randn(4, 16000)  # [batch=4, time=16000] (1 second at 16 kHz)
# Separate
separated = model(mixture) # [4, 2, 16000]
print(f"Input shape: {mixture.shape}")
print(f"Output shape: {separated.shape}")
print(f"Separated speaker 1: {separated[:, 0, :].shape}")
print(f"Separated speaker 2: {separated[:, 1, :].shape}")
Training with Permutation Invariant Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class PermutationInvariantLoss(nn.Module):
"""
Permutation Invariant Training (PIT) loss
Problem: Model outputs are in arbitrary order
Solution: Try all permutations, use best one
For 2 speakers:
- Try (output1→target1, output2→target2)
- Try (output1→target2, output2→target1)
- Use permutation with lower loss
"""
def __init__(self, loss_fn='si_sdr'):
super().__init__()
self.loss_fn = loss_fn
def forward(self, estimated, target):
"""
Compute PIT loss
Args:
estimated: [batch, n_src, time]
target: [batch, n_src, time]
Returns:
loss: scalar
"""
batch_size, n_src, time = estimated.size()
# Generate all permutations
import itertools
perms = list(itertools.permutations(range(n_src)))
# Compute loss for each permutation
perm_losses = []
for perm in perms:
# Reorder estimated according to permutation
estimated_perm = estimated[:, perm, :]
# Compute loss
if self.loss_fn == 'si_sdr':
loss = self._si_sdr_loss(estimated_perm, target)
elif self.loss_fn == 'mse':
loss = F.mse_loss(estimated_perm, target)
else:
raise ValueError(f"Unknown loss function: {self.loss_fn}")
perm_losses.append(loss)
        # Stack losses and take the minimum (best permutation)
        # Note: this takes the min over batch-averaged losses, which is a
        # simplification; full utterance-level PIT picks the best permutation
        # per example (see the sketch after the training loop)
        perm_losses = torch.stack(perm_losses)
        return perm_losses.min()
def _si_sdr_loss(self, estimated, target):
"""
Scale-Invariant Signal-to-Distortion Ratio loss
Better than MSE for speech separation
"""
# Zero-mean
estimated = estimated - estimated.mean(dim=-1, keepdim=True)
target = target - target.mean(dim=-1, keepdim=True)
# Project estimated onto target
dot = (estimated * target).sum(dim=-1, keepdim=True)
target_energy = (target ** 2).sum(dim=-1, keepdim=True) + 1e-8
projection = dot * target / target_energy
# Noise (estimation error)
noise = estimated - projection
# SI-SDR
si_sdr = 10 * torch.log10(
(projection ** 2).sum(dim=-1) / ((noise ** 2).sum(dim=-1) + 1e-8)
)
# Negative for loss (we want to maximize SI-SDR)
return -si_sdr.mean()
# Training loop
model = ConvTasNet(n_src=2)
criterion = PermutationInvariantLoss(loss_fn='si_sdr')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def train_epoch(model, train_loader, criterion, optimizer):
"""Train one epoch"""
model.train()
total_loss = 0
for batch_idx, (mixture, target) in enumerate(train_loader):
# mixture: [batch, time]
# target: [batch, n_src, time]
# Forward
estimated = model(mixture)
# Loss with PIT
loss = criterion(estimated, target)
# Backward
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
total_loss += loss.item()
if batch_idx % 10 == 0:
print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
return total_loss / len(train_loader)
# Train
# for epoch in range(num_epochs):
# train_loss = train_epoch(model, train_loader, criterion, optimizer)
# print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}")
Evaluation Metrics
Signal-to-Distortion Ratio (SDR)
def compute_sdr(estimated, target):
"""
Compute Signal-to-Distortion Ratio
SDR = 10 * log10(||target||^2 / ||target - estimated||^2)
Higher is better. Good: > 10 dB, Great: > 15 dB
"""
target = target - target.mean()
estimated = estimated - estimated.mean()
signal_power = np.sum(target ** 2)
distortion = target - estimated
distortion_power = np.sum(distortion ** 2) + 1e-10
sdr = 10 * np.log10(signal_power / distortion_power)
return sdr
def compute_si_sdr(estimated, target):
"""
Compute Scale-Invariant SDR
Invariant to scaling of the signal
"""
# Zero-mean
estimated = estimated - estimated.mean()
target = target - target.mean()
# Project estimated onto target
alpha = np.dot(estimated, target) / (np.dot(target, target) + 1e-10)
projection = alpha * target
# Noise
noise = estimated - projection
# SI-SDR
si_sdr = 10 * np.log10(
np.sum(projection ** 2) / (np.sum(noise ** 2) + 1e-10)
)
return si_sdr
def compute_sir(estimated, target, interference):
"""
Compute Signal-to-Interference Ratio
Measures how well interfering speakers are suppressed
"""
target = target - target.mean()
estimated = estimated - estimated.mean()
# Project estimated onto target
s_target = np.dot(estimated, target) / (np.dot(target, target) + 1e-10) * target
# Interference
e_interf = 0
for interf in interference:
interf = interf - interf.mean()
e_interf += np.dot(estimated, interf) / (np.dot(interf, interf) + 1e-10) * interf
# SIR
sir = 10 * np.log10(
np.sum(s_target ** 2) / (np.sum(e_interf ** 2) + 1e-10)
)
return sir
# Comprehensive evaluation
def evaluate_separation(model, test_loader):
"""
Evaluate separation quality
Returns metrics for each source
"""
model.eval()
all_sdr = []
all_si_sdr = []
with torch.no_grad():
for mixture, targets in test_loader:
# Separate
estimated = model(mixture)
# Convert to numpy
estimated_np = estimated.cpu().numpy()
targets_np = targets.cpu().numpy()
batch_size, n_src, time = estimated_np.shape
# Compute metrics for each sample and source
for b in range(batch_size):
for s in range(n_src):
est = estimated_np[b, s, :]
tgt = targets_np[b, s, :]
sdr = compute_sdr(est, tgt)
si_sdr = compute_si_sdr(est, tgt)
all_sdr.append(sdr)
all_si_sdr.append(si_sdr)
results = {
'sdr_mean': np.mean(all_sdr),
'sdr_std': np.std(all_sdr),
'si_sdr_mean': np.mean(all_si_sdr),
'si_sdr_std': np.std(all_si_sdr)
}
print("="*60)
print("SEPARATION EVALUATION RESULTS")
print("="*60)
print(f"SDR: {results['sdr_mean']:.2f} ± {results['sdr_std']:.2f} dB")
print(f"SI-SDR: {results['si_sdr_mean']:.2f} ± {results['si_sdr_std']:.2f} dB")
print("="*60)
return results
# Example
# results = evaluate_separation(model, test_loader)
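# As a cross-check against the hand-rolled metrics above, the mir_eval
# package computes BSS metrics (SDR/SIR/SAR) and resolves the output
# permutation for you. Sketch only, assuming `pip install mir_eval`;
# the arrays below are random placeholders with shape [n_src, time].
import mir_eval

reference = np.random.randn(2, 16000)                     # ground-truth sources
estimates = reference + 0.1 * np.random.randn(2, 16000)   # fake model outputs

sdr, sir, sar, perm = mir_eval.separation.bss_eval_sources(reference, estimates)
print(f"SDR: {sdr}, SIR: {sir}, SAR: {sar}, best permutation: {perm}")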
Real-Time Separation Pipeline
Streaming Speech Separation
class StreamingSpeechSeparator:
"""
Real-time speech separation for streaming audio
Challenges:
- Causal processing (no future context)
- Low latency (< 50ms)
- State management across chunks
"""
def __init__(self, model, chunk_size=4800, overlap=1200):
"""
Args:
model: Trained separation model
chunk_size: Samples per chunk (300ms at 16kHz)
overlap: Overlap between chunks (75ms at 16kHz)
"""
self.model = model
self.model.eval()
self.chunk_size = chunk_size
self.overlap = overlap
self.hop_size = chunk_size - overlap
# Buffer for overlapping
self.input_buffer = np.zeros(overlap)
self.output_buffers = [np.zeros(overlap) for _ in range(model.n_src)]
def process_chunk(self, audio_chunk):
"""
Process single audio chunk
Args:
            audio_chunk: [hop_size] numpy array of new samples (the stored
                overlap from the previous call is prepended internally to
                form a full chunk_size window)
Returns:
separated_chunks: list of [hop_size] arrays, one per speaker
"""
# Concatenate with buffer
full_chunk = np.concatenate([self.input_buffer, audio_chunk])
# Ensure correct size
if len(full_chunk) < self.chunk_size:
full_chunk = np.pad(
full_chunk,
(0, self.chunk_size - len(full_chunk)),
mode='constant'
)
# Convert to tensor
with torch.no_grad():
chunk_tensor = torch.from_numpy(full_chunk).float().unsqueeze(0)
# Separate
separated = self.model(chunk_tensor) # [1, n_src, chunk_size]
# Convert back to numpy
separated_np = separated[0].cpu().numpy() # [n_src, chunk_size]
# Overlap-add
result_chunks = []
for src_idx in range(self.model.n_src):
src_audio = separated_np[src_idx]
# Add overlap from previous chunk
src_audio[:self.overlap] += self.output_buffers[src_idx]
# Extract output (without overlap)
output_chunk = src_audio[:self.hop_size]
result_chunks.append(output_chunk)
# Save overlap for next chunk
self.output_buffers[src_idx] = src_audio[-self.overlap:]
# Update input buffer
self.input_buffer = audio_chunk[-self.overlap:]
return result_chunks
def reset(self):
"""Reset state for new stream"""
self.input_buffer = np.zeros(self.overlap)
self.output_buffers = [np.zeros(self.overlap) for _ in range(self.model.n_src)]
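# A minimal offline usage sketch for StreamingSpeechSeparator: feed hop_size
# new samples per call and stitch the returned chunks back together. The
# model here is untrained and the waveform is random noise, so this only
# checks shapes and the overlap-add bookkeeping.
demo_model = ConvTasNet(n_src=2)
demo_separator = StreamingSpeechSeparator(demo_model, chunk_size=4800, overlap=1200)

waveform = np.random.randn(16000 * 3).astype(np.float32)   # 3 s of "audio"
hop = demo_separator.hop_size

collected = [[] for _ in range(demo_model.n_src)]
for start in range(0, len(waveform) - hop + 1, hop):
    chunks = demo_separator.process_chunk(waveform[start:start + hop])
    for src_idx, chunk in enumerate(chunks):
        collected[src_idx].append(chunk)

streams = [np.concatenate(parts) for parts in collected]
print([s.shape for s in streams])            # one stream per speaker
demo_separator.reset()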
# Example: Real-time separation server
from fastapi import FastAPI, WebSocket
import asyncio
app = FastAPI()
# Load model
model = ConvTasNet(n_src=2)
model.load_state_dict(torch.load('convtasnet_separation.pth'))
separator = StreamingSpeechSeparator(model, chunk_size=4800, overlap=1200)
@app.websocket("/separate")
async def websocket_separation(websocket: WebSocket):
"""
WebSocket endpoint for real-time separation
Client sends audio chunks, receives separated streams
"""
await websocket.accept()
try:
while True:
# Receive audio chunk
data = await websocket.receive_bytes()
# Decode audio (assuming 16-bit PCM)
audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
# Separate
separated_chunks = separator.process_chunk(audio_chunk)
# Send separated streams
for src_idx, src_chunk in enumerate(separated_chunks):
# Encode back to 16-bit PCM
src_bytes = (src_chunk * 32768).astype(np.int16).tobytes()
await websocket.send_json({
'speaker_id': src_idx,
'audio': src_bytes.hex()
})
except Exception as e:
print(f"WebSocket error: {e}")
finally:
separator.reset()
await websocket.close()
# Run server
# uvicorn.run(app, host='0.0.0.0', port=8000)
Integration with Downstream Tasks
Speech Separation + ASR Pipeline
class SeparationASRPipeline:
"""
Combined pipeline: Separate speakers → Transcribe each
Use case: Meeting transcription with overlapping speech
"""
def __init__(self, separation_model, asr_model):
self.separator = separation_model
self.asr = asr_model
def transcribe_multi_speaker(self, audio):
"""
Transcribe audio with multiple speakers
Args:
audio: Mixed audio
Returns:
List of (speaker_id, transcript) tuples
"""
# Separate speakers
with torch.no_grad():
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
separated = self.separator(audio_tensor)[0] # [n_src, time]
# Transcribe each speaker
transcripts = []
for speaker_id in range(separated.size(0)):
speaker_audio = separated[speaker_id].cpu().numpy()
# Transcribe
transcript = self.asr.transcribe(speaker_audio)
transcripts.append({
'speaker_id': speaker_id,
'transcript': transcript,
'audio_length_sec': len(speaker_audio) / 16000
})
return transcripts
def transcribe_with_diarization(self, audio):
"""
Transcribe with speaker diarization
Diarization: Who spoke when?
Separation: Isolate each speaker's audio
ASR: Transcribe each speaker
"""
# Separate speakers
with torch.no_grad():
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
separated = self.separator(audio_tensor)[0] # [n_src, time]
# Speaker diarization on each separated stream
diarization_results = []
for speaker_id in range(separated.size(0)):
speaker_audio = separated[speaker_id].cpu().numpy()
# Voice Activity Detection
vad_segments = self._detect_voice_activity(speaker_audio)
# Transcribe active segments
for segment in vad_segments:
start_idx = int(segment['start'] * 16000)
end_idx = int(segment['end'] * 16000)
segment_audio = speaker_audio[start_idx:end_idx]
transcript = self.asr.transcribe(segment_audio)
diarization_results.append({
'speaker_id': speaker_id,
'start_time': segment['start'],
'end_time': segment['end'],
'transcript': transcript
})
# Sort by start time
diarization_results.sort(key=lambda x: x['start_time'])
return diarization_results
def _detect_voice_activity(self, audio, frame_duration=0.03):
"""
Simple energy-based VAD
Returns list of (start, end) segments with voice activity
"""
import librosa
# Compute energy
frame_length = int(frame_duration * 16000)
energy = librosa.feature.rms(
y=audio,
frame_length=frame_length,
hop_length=frame_length // 2
)[0]
# Threshold
threshold = np.mean(energy) * 0.5
# Find voice segments
is_voice = energy > threshold
segments = []
in_segment = False
start = 0
for i, voice in enumerate(is_voice):
if voice and not in_segment:
start = i * frame_duration / 2
in_segment = True
elif not voice and in_segment:
end = i * frame_duration / 2
segments.append({'start': start, 'end': end})
in_segment = False
        # Close a trailing segment that runs to the end of the audio
        if in_segment:
            segments.append({'start': start, 'end': len(is_voice) * frame_duration / 2})
        return segments
# Example usage
separation_model = ConvTasNet(n_src=2)
separation_model.load_state_dict(torch.load('separation_model.pth'))
# Mock ASR model
class MockASR:
def transcribe(self, audio):
return f"Transcribed {len(audio)} samples"
asr_model = MockASR()
pipeline = SeparationASRPipeline(separation_model, asr_model)
# Transcribe multi-speaker audio
audio = np.random.randn(16000 * 10) # 10 seconds
results = pipeline.transcribe_multi_speaker(audio)
print("Transcription results:")
for result in results:
print(f"Speaker {result['speaker_id']}: {result['transcript']}")
Advanced Topics
Unknown Number of Speakers
class AdaptiveSeparationModel(nn.Module):
"""
Separate audio with unknown number of speakers
Approach:
1. Estimate number of speakers
2. Separate into estimated number of sources
3. Filter empty sources
"""
def __init__(self, max_speakers=10):
super().__init__()
self.max_speakers = max_speakers
# Speaker counting network
self.speaker_counter = nn.Sequential(
nn.Conv1d(1, 128, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv1d(128, 256, kernel_size=3, stride=2),
nn.ReLU(),
nn.AdaptiveAvgPool1d(1),
nn.Flatten(),
nn.Linear(256, max_speakers + 1), # 0 to max_speakers
nn.Softmax(dim=-1)
)
# Separation models for different numbers of speakers
self.separators = nn.ModuleList([
ConvTasNet(n_src=n) for n in range(1, max_speakers + 1)
])
def forward(self, mixture):
"""
Separate with adaptive number of sources
Args:
mixture: [batch, time]
Returns:
            separated: [batch, n_src, time] tensor of separated sources
                (an empty list if no speakers are detected)
"""
# Estimate number of speakers
mixture_1d = mixture.unsqueeze(1) # [batch, 1, time]
speaker_probs = self.speaker_counter(mixture_1d) # [batch, max_speakers + 1]
n_speakers = speaker_probs.argmax(dim=-1) # [batch]
# For simplicity, use max in batch (in practice, process per sample)
max_n_speakers = n_speakers.max().item()
if max_n_speakers == 0:
return []
# Separate using appropriate model
separator = self.separators[max_n_speakers - 1]
separated = separator(mixture) # [batch, n_src, time]
return separated
# Example
model = AdaptiveSeparationModel(max_speakers=5)
# Test with a random mixture (with untrained weights the estimated speaker
# count is arbitrary; this only checks shapes)
mixture = torch.randn(1, 16000)
separated = model(mixture)
if isinstance(separated, torch.Tensor):
    print(f"Estimated sources: {separated.size(1)}")
else:
    print("No speakers detected")
Multi-Channel Separation
class MultiChannelSeparator(nn.Module):
"""
Use multiple microphones for better separation
Microphone array provides spatial information
"""
def __init__(self, n_channels, n_src):
super().__init__()
self.n_channels = n_channels
self.n_src = n_src
# Encoder for each channel
self.encoders = nn.ModuleList([
nn.Conv1d(1, 256, kernel_size=16, stride=8)
for _ in range(n_channels)
])
# Cross-channel attention
self.cross_channel_attention = nn.MultiheadAttention(
embed_dim=256 * n_channels,
num_heads=8
)
# Separator
self.separator = TemporalConvNet(
256 * n_channels,
n_src,
n_blocks=8,
n_repeats=3,
bn_chan=128,
hid_chan=512,
skip_chan=128
)
# Decoder
self.decoder = nn.ConvTranspose1d(256, 1, kernel_size=16, stride=8)
def forward(self, multi_channel_mixture):
"""
Separate using multi-channel input
Args:
multi_channel_mixture: [batch, n_channels, time]
Returns:
separated: [batch, n_src, time]
"""
batch_size, n_channels, time = multi_channel_mixture.size()
# Encode each channel
encoded_channels = []
for ch in range(n_channels):
ch_audio = multi_channel_mixture[:, ch:ch+1, :] # [batch, 1, time]
ch_encoded = self.encoders[ch](ch_audio) # [batch, 256, time']
encoded_channels.append(ch_encoded)
# Concatenate channels
encoded = torch.cat(encoded_channels, dim=1) # [batch, 256 * n_channels, time']
# Cross-channel attention
# Reshape for attention: [time', batch, 256 * n_channels]
encoded_t = encoded.permute(2, 0, 1)
attended, _ = self.cross_channel_attention(encoded_t, encoded_t, encoded_t)
attended = attended.permute(1, 2, 0) # [batch, 256 * n_channels, time']
# Separate
masks = self.separator(attended) # [batch, n_src, 256 * n_channels, time']
# Apply masks and decode
separated = []
for src_idx in range(self.n_src):
masked = attended * masks[:, src_idx, :, :]
# Take first 256 channels for decoding
masked_single = masked[:, :256, :]
src_waveform = self.decoder(masked_single).squeeze(1)
separated.append(src_waveform)
separated = torch.stack(separated, dim=1)
return separated
# Example: 4-microphone array
model = MultiChannelSeparator(n_channels=4, n_src=2)
# 4-channel input
multi_channel_audio = torch.randn(1, 4, 16000)
separated = model(multi_channel_audio)
print(f"Separated shape: {separated.shape}") # [1, 2, 16000]
Key Takeaways
✅ Conv-TasNet - Strong time-domain separation architecture
✅ PIT loss - Handle output permutation problem
✅ SI-SDR metric - Scale-invariant quality measure
✅ Real-time streaming - Chunk-based processing with overlap-add
✅ Integration with ASR - End-to-end meeting transcription
Performance Targets:
- SI-SDR improvement: > 15 dB
- Real-time factor: < 0.1 (10x faster than real-time)
- Latency: < 50 ms for streaming (requires a causal model and smaller chunks than the 300 ms example above)
- Works with 2-5 overlapping speakers
Originally published at: arunbaby.com/speech-tech/0011-speech-separation