Voice Enhancement & Noise Reduction
Build systems that enhance voice quality by removing noise, improving intelligibility, and optimizing audio for speech applications.
Introduction
Voice enhancement improves speech quality by:
- Removing background noise (traffic, wind, keyboard)
- Suppressing reverberation
- Normalizing volume levels
- Enhancing speech intelligibility
- Removing artifacts and distortion
Critical for:
- Video conferencing (Zoom, Teams, Meet)
- Voice assistants (Alexa, Siri, Google Assistant)
- Podcast/content creation
- Hearing aids
- Telecommunication
- Speech recognition systems
Key challenges:
- Real-time processing (< 50ms latency)
- Preserving speech quality
- Handling diverse noise types
- Low computational cost
- Avoiding artifacts
Problem Formulation
Input/Output
Input: Noisy speech signal
y(t) = s(t) + n(t)
where:
s(t) = clean speech
n(t) = noise
Output: Enhanced speech signal
ŝ(t) ≈ s(t)
Goal: Minimize ‖ŝ(t) - s(t)‖ while maintaining naturalness
Quality Metrics
import numpy as np
def calculate_snr(clean_speech, noisy_speech):
"""
Calculate Signal-to-Noise Ratio
SNR = 10 * log10(P_signal / P_noise)
Higher is better (typically 10-30 dB)
"""
signal_power = np.mean(clean_speech ** 2)
noise = noisy_speech - clean_speech
noise_power = np.mean(noise ** 2)
if noise_power == 0:
return float('inf')
snr = 10 * np.log10(signal_power / noise_power)
return snr
def calculate_pesq(reference, degraded, sr=16000):
"""
Calculate PESQ (Perceptual Evaluation of Speech Quality)
Range: -0.5 to 4.5 (higher is better)
Industry standard for speech quality
"""
from pesq import pesq
# PESQ requires 8kHz or 16kHz
if sr not in [8000, 16000]:
raise ValueError("PESQ requires sr=8000 or sr=16000")
mode = 'nb' if sr == 8000 else 'wb'
score = pesq(sr, reference, degraded, mode)
return score
def calculate_stoi(clean, enhanced, sr=16000):
"""
Calculate STOI (Short-Time Objective Intelligibility)
Range: 0 to 1 (higher is better)
Correlates well with speech intelligibility
"""
from pystoi import stoi
score = stoi(clean, enhanced, sr, extended=False)
return score
# Usage (synthetic signals as a stand-in for real speech)
clean = np.random.randn(16000)  # 1 second at 16 kHz
noisy = clean + 0.1 * np.random.randn(16000)
snr = calculate_snr(clean, noisy)
print(f"SNR: {snr:.2f} dB")
# pesq_score = calculate_pesq(clean, noisy, sr=16000)
# print(f"PESQ: {pesq_score:.2f}")
Classical Methods
1. Spectral Subtraction
Estimate the noise magnitude spectrum and subtract it from the noisy magnitude spectrum.
import librosa
import numpy as np
class SpectralSubtraction:
"""
Classic spectral subtraction for noise reduction
Steps:
1. Estimate noise spectrum (from silence periods)
2. Subtract from noisy spectrum
3. Half-wave rectification
4. Reconstruct signal
"""
def __init__(self, n_fft=512, hop_length=128):
self.n_fft = n_fft
self.hop_length = hop_length
self.noise_profile = None
    def estimate_noise(self, noise_audio):
"""
Estimate noise spectrum from noise-only segment
Args:
noise_audio: Audio containing only noise
"""
# STFT of noise
noise_stft = librosa.stft(
noise_audio,
n_fft=self.n_fft,
hop_length=self.hop_length
)
# Average magnitude spectrum
self.noise_profile = np.mean(np.abs(noise_stft), axis=1, keepdims=True)
def enhance(self, noisy_audio, alpha=2.0, beta=0.002):
"""
Apply spectral subtraction
Args:
noisy_audio: Noisy speech signal
alpha: Over-subtraction factor (higher = more aggressive)
beta: Spectral floor (prevents negative values)
Returns:
Enhanced audio
"""
if self.noise_profile is None:
raise ValueError("Must estimate noise first")
# STFT of noisy signal
noisy_stft = librosa.stft(
noisy_audio,
n_fft=self.n_fft,
hop_length=self.hop_length
)
# Magnitude and phase
mag = np.abs(noisy_stft)
phase = np.angle(noisy_stft)
# Spectral subtraction
enhanced_mag = mag - alpha * self.noise_profile
# Half-wave rectification with spectral floor
enhanced_mag = np.maximum(enhanced_mag, beta * mag)
# Reconstruct with original phase
enhanced_stft = enhanced_mag * np.exp(1j * phase)
# Inverse STFT
enhanced_audio = librosa.istft(
enhanced_stft,
hop_length=self.hop_length
)
return enhanced_audio
# Usage
sr = 16000
# Load noisy speech
noisy_speech, _ = librosa.load('noisy_speech.wav', sr=sr)
# Estimate noise from first 0.5 seconds (assumed to be silence)
noise_segment = noisy_speech[:int(0.5 * sr)]
enhancer = SpectralSubtraction()
enhancer.estimate_noise(noise_segment)
# Enhance full audio
enhanced = enhancer.enhance(noisy_speech, alpha=2.0)
# Save result
import soundfile as sf
sf.write('enhanced_speech.wav', enhanced, sr)
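Spectral subtraction often leaves isolated time-frequency peaks that are heard as "musical noise". A common mitigation is to smooth the effective gain across frames before applying it; a minimal sketch (the helper and its smoothing factor are illustrative, not part of the class above):

def smooth_gain(noisy_mag, enhanced_mag, alpha_smooth=0.7):
    """Recursively smooth the per-bin gain over frames to reduce musical noise."""
    gain = enhanced_mag / (noisy_mag + 1e-10)
    for t in range(1, gain.shape[1]):
        gain[:, t] = alpha_smooth * gain[:, t - 1] + (1 - alpha_smooth) * gain[:, t]
    return gain * noisy_mag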
2. Wiener Filtering
The optimal linear filter in the minimum mean squared error (MMSE) sense.
class WienerFilter:
"""
Wiener filtering for speech enhancement
Minimizes mean squared error between clean and enhanced speech
"""
def __init__(self, n_fft=512, hop_length=128):
self.n_fft = n_fft
self.hop_length = hop_length
self.noise_psd = None
def estimate_noise_psd(self, noise_audio):
"""Estimate noise power spectral density"""
noise_stft = librosa.stft(
noise_audio,
n_fft=self.n_fft,
hop_length=self.hop_length
)
# Power spectral density
self.noise_psd = np.mean(np.abs(noise_stft) ** 2, axis=1, keepdims=True)
    def enhance(self, noisy_audio):
"""
Apply Wiener filtering
Wiener gain: H = S / (S + N)
where S = signal PSD, N = noise PSD
"""
if self.noise_psd is None:
raise ValueError("Must estimate noise PSD first")
# STFT
noisy_stft = librosa.stft(
noisy_audio,
n_fft=self.n_fft,
hop_length=self.hop_length
)
# Noisy PSD
noisy_psd = np.abs(noisy_stft) ** 2
# Estimate clean speech PSD
speech_psd = np.maximum(noisy_psd - self.noise_psd, 0)
# Wiener gain
wiener_gain = speech_psd / (speech_psd + self.noise_psd + 1e-10)
# Apply gain
enhanced_stft = wiener_gain * noisy_stft
# Inverse STFT
enhanced_audio = librosa.istft(
enhanced_stft,
hop_length=self.hop_length
)
return enhanced_audio
# Usage
wiener = WienerFilter()
wiener.estimate_noise_psd(noise_segment)
enhanced = wiener.enhance(noisy_speech)
Deep Learning Approaches
1. Mask-Based Enhancement
Learn a time-frequency mask, such as the ideal ratio mask (IRM) or ideal binary mask (IBM), and apply it to the noisy spectrogram.
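Mask-based training needs targets computed from paired clean and noise signals. A minimal sketch of one common simplified definition (a magnitude-ratio IRM; the function name is ours):

import numpy as np
import librosa

def compute_mask_targets(clean, noise, n_fft=512, hop_length=128):
    """Compute IRM and IBM training targets from paired clean/noise audio."""
    S = np.abs(librosa.stft(clean, n_fft=n_fft, hop_length=hop_length))
    N = np.abs(librosa.stft(noise, n_fft=n_fft, hop_length=hop_length))
    irm = S / (S + N + 1e-8)           # ideal ratio mask, values in [0, 1]
    ibm = (S > N).astype(np.float32)   # ideal binary mask, values in {0, 1}
    return irm, ibm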
import torch
import torch.nn as nn
class MaskEstimationNet(nn.Module):
"""
Neural network for mask estimation
Predicts time-frequency mask to apply to noisy spectrogram
"""
def __init__(self, n_fft=512, hidden_dim=128):
super().__init__()
self.n_freq = n_fft // 2 + 1
# Bidirectional LSTM
self.lstm = nn.LSTM(
input_size=self.n_freq,
hidden_size=hidden_dim,
num_layers=2,
batch_first=True,
bidirectional=True
)
# Mask prediction
self.mask_fc = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, self.n_freq),
nn.Sigmoid() # Mask values in [0, 1]
)
def forward(self, noisy_mag):
"""
Args:
noisy_mag: Noisy magnitude spectrogram [batch, time, freq]
Returns:
mask: Predicted mask [batch, time, freq]
"""
# LSTM
lstm_out, _ = self.lstm(noisy_mag)
# Predict mask
mask = self.mask_fc(lstm_out)
return mask
class MaskBasedEnhancer:
"""
Speech enhancement using learned mask
"""
def __init__(self, model, n_fft=512, hop_length=128):
self.model = model
self.model.eval()
self.n_fft = n_fft
self.hop_length = hop_length
def enhance(self, noisy_audio):
"""
Enhance audio using learned mask
Steps:
1. Compute noisy spectrogram
2. Predict mask with neural network
3. Apply mask
4. Reconstruct audio
"""
# STFT
noisy_stft = librosa.stft(
noisy_audio,
n_fft=self.n_fft,
hop_length=self.hop_length
)
# Magnitude and phase
noisy_mag = np.abs(noisy_stft)
phase = np.angle(noisy_stft)
# Normalize magnitude
mag_mean = np.mean(noisy_mag)
mag_std = np.std(noisy_mag)
noisy_mag_norm = (noisy_mag - mag_mean) / (mag_std + 1e-8)
# Predict mask
with torch.no_grad():
# Transpose to [1, time, freq]
mag_tensor = torch.FloatTensor(noisy_mag_norm.T).unsqueeze(0)
mask = self.model(mag_tensor)
# Back to numpy
mask = mask.squeeze(0).numpy().T
# Apply mask
enhanced_mag = noisy_mag * mask
# Reconstruct
enhanced_stft = enhanced_mag * np.exp(1j * phase)
enhanced_audio = librosa.istft(
enhanced_stft,
hop_length=self.hop_length
)
return enhanced_audio
# Usage (the model must be trained before its masks are meaningful)
model = MaskEstimationNet(n_fft=512)
enhancer = MaskBasedEnhancer(model)
# Enhance
enhanced = enhancer.enhance(noisy_speech)
2. End-to-End Waveform Enhancement
Direct waveform→waveform mapping
class ConvTasNet(nn.Module):
"""
Conv-TasNet for speech enhancement
End-to-end time-domain speech separation
Based on: "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation" (Luo & Mesgarani, 2019)
"""
def __init__(self, N=256, L=20, B=256, H=512, P=3, X=8, R=3):
"""
Args:
N: Number of filters in autoencoder
        L: Length of filters (in samples)
B: Number of channels in bottleneck
H: Number of channels in conv blocks
P: Kernel size in conv blocks
X: Number of conv blocks in each repeat
R: Number of repeats
"""
super().__init__()
# Encoder (waveform → features)
self.encoder = nn.Conv1d(1, N, L, stride=L//2, padding=L//2)
# Separator (TCN blocks)
self.separator = self._build_separator(N, B, H, P, X, R)
# Decoder (features → waveform)
self.decoder = nn.ConvTranspose1d(N, 1, L, stride=L//2, padding=L//2)
def _build_separator(self, N, B, H, P, X, R):
"""Build temporal convolutional network"""
layers = []
        # Channel-wise normalization over [batch, N, T]
        # (nn.LayerNorm(N) would normalize the wrong axis for channel-first tensors)
        layers.append(nn.GroupNorm(1, N))
# Bottleneck
layers.append(nn.Conv1d(N, B, 1))
# TCN blocks
for r in range(R):
for x in range(X):
dilation = 2 ** x
layers.append(
TemporalConvBlock(B, H, P, dilation)
)
        # Output projection to mask values in [0, 1]
        layers.append(nn.PReLU())
        layers.append(nn.Conv1d(B, N, 1))
        layers.append(nn.Sigmoid())
return nn.Sequential(*layers)
def forward(self, mixture):
"""
Args:
mixture: Noisy waveform [batch, 1, samples]
Returns:
estimated_clean: Enhanced waveform [batch, 1, samples]
"""
# Encode
encoded = self.encoder(mixture) # [batch, N, T]
# Separate
mask = self.separator(encoded) # [batch, N, T]
# Apply mask
separated = encoded * mask
        # Decode
        estimated = self.decoder(separated)  # [batch, 1, samples']
        # Trim to the input length (transposed convolution can change length slightly)
        return estimated[..., :mixture.shape[-1]]
class TemporalConvBlock(nn.Module):
"""
Temporal convolutional block with dilated convolutions
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dilation):
super().__init__()
self.conv1 = nn.Conv1d(
in_channels, hidden_channels,
kernel_size, dilation=dilation,
padding=dilation * (kernel_size - 1) // 2
)
self.prelu1 = nn.PReLU()
self.norm1 = nn.GroupNorm(1, hidden_channels)
self.conv2 = nn.Conv1d(
hidden_channels, in_channels,
1
)
self.prelu2 = nn.PReLU()
self.norm2 = nn.GroupNorm(1, in_channels)
def forward(self, x):
"""
Args:
x: [batch, channels, time]
"""
residual = x
out = self.conv1(x)
out = self.prelu1(out)
out = self.norm1(out)
out = self.conv2(out)
out = self.prelu2(out)
out = self.norm2(out)
return out + residual
# Usage
model = ConvTasNet()
noisy_tensor = torch.randn(1, 1, 16000) # 1 second
enhanced_tensor = model(noisy_tensor)
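A quick sanity check on the untrained model (the exact numbers depend on the hyperparameters above):

n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params / 1e6:.1f}M")
print(f"Output shape: {enhanced_tensor.shape}")  # should match the input [1, 1, 16000]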
Real-Time Enhancement
Streaming Enhancement System
import numpy as np
from collections import deque
class StreamingEnhancer:
"""
Real-time streaming speech enhancement
Requirements:
- Low latency (< 50ms)
- Causal processing
- Minimal buffering
"""
def __init__(self, model, chunk_size=512, overlap=256, sr=16000):
"""
Args:
chunk_size: Samples per chunk
overlap: Overlap between chunks (for smooth transitions)
"""
self.model = model
self.chunk_size = chunk_size
self.overlap = overlap
self.sr = sr
# Circular buffer for overlap-add
self.buffer = deque(maxlen=overlap)
self.output_buffer = deque(maxlen=overlap)
self.processed_chunks = 0
def process_chunk(self, audio_chunk):
"""
Process single audio chunk
Args:
audio_chunk: Audio samples [chunk_size]
Returns:
Enhanced audio chunk
"""
# Add previous overlap
if len(self.buffer) > 0:
input_chunk = np.concatenate([
np.array(self.buffer),
audio_chunk
])
else:
input_chunk = audio_chunk
# Enhance
enhanced = self._enhance_chunk(input_chunk)
# Overlap-add with linear cross-fade
if len(self.output_buffer) > 0:
# Smooth transition
overlap_region = min(len(self.output_buffer), self.overlap)
for i in range(overlap_region):
weight = i / overlap_region
enhanced[i] = (1 - weight) * self.output_buffer[i] + weight * enhanced[i]
# Save overlap for next chunk
self.buffer.clear()
self.buffer.extend(audio_chunk[-self.overlap:])
self.output_buffer.clear()
self.output_buffer.extend(enhanced[-self.overlap:])
self.processed_chunks += 1
# Return non-overlap part
return enhanced[:-self.overlap] if len(enhanced) > self.overlap else enhanced
def _enhance_chunk(self, audio_chunk):
"""Enhance using model"""
# Convert to tensor
audio_tensor = torch.FloatTensor(audio_chunk).unsqueeze(0).unsqueeze(0)
# Enhance
with torch.no_grad():
enhanced_tensor = self.model(audio_tensor)
# Back to numpy
enhanced = enhanced_tensor.squeeze().numpy()
return enhanced
    def get_latency_ms(self):
        """Algorithmic (buffering) latency; total latency also includes model inference time"""
        return (self.chunk_size / self.sr) * 1000
# Usage for real-time processing
model = ConvTasNet()
enhancer = StreamingEnhancer(model, chunk_size=512, overlap=256, sr=16000)
print(f"Latency: {enhancer.get_latency_ms():.2f} ms")
# Process audio stream
import sounddevice as sd
def audio_callback(indata, outdata, frames, time, status):
"""Real-time audio callback"""
# Get input chunk
input_chunk = indata[:, 0]
# Enhance
enhanced_chunk = enhancer.process_chunk(input_chunk)
    # Output (zero-fill first; the enhanced chunk can be shorter than the block)
    outdata.fill(0)
    n = min(len(enhanced_chunk), len(outdata))
    outdata[:n, 0] = enhanced_chunk[:n]
if status:
print(f"Status: {status}")
# Start real-time processing
with sd.Stream(
samplerate=16000,
channels=1,
callback=audio_callback,
blocksize=512
):
print("Processing audio in real-time... Press Ctrl+C to stop")
sd.sleep(10000)
Multi-Channel Enhancement
Beamforming
class BeamformerEnhancer:
"""
Beamforming for multi-microphone enhancement
Uses spatial information to enhance target speech
"""
def __init__(self, n_mics=4, sr=16000):
self.n_mics = n_mics
self.sr = sr
def delay_and_sum(self, multi_channel_audio, target_direction=0):
"""
Delay-and-sum beamforming
Args:
multi_channel_audio: [n_mics, n_samples]
target_direction: Target angle in degrees (0 = front)
Returns:
Enhanced single-channel audio
"""
n_samples = multi_channel_audio.shape[1]
# Calculate delays for each microphone
# (Simplified: assumes linear array)
mic_spacing = 0.05 # 5cm between mics
speed_of_sound = 343 # m/s
delays = []
for i in range(self.n_mics):
distance_diff = i * mic_spacing * np.sin(np.deg2rad(target_direction))
delay_samples = int(distance_diff / speed_of_sound * self.sr)
delays.append(delay_samples)
# Align and sum
aligned_signals = []
for i, delay in enumerate(delays):
sig = multi_channel_audio[i]
if delay > 0:
# Delay by pre-pending zeros
padded = np.concatenate([np.zeros(delay, dtype=sig.dtype), sig])
aligned = padded[:n_samples]
elif delay < 0:
# Advance by removing first samples
aligned = sig[-delay:]
if aligned.shape[0] < n_samples:
aligned = np.pad(aligned, (0, n_samples - aligned.shape[0]), mode='constant')
else:
aligned = sig
aligned_signals.append(aligned)
# Sum aligned signals
enhanced = np.mean(aligned_signals, axis=0)
return enhanced
def mvdr_beamformer(self, multi_channel_audio, noise_segment):
"""
MVDR (Minimum Variance Distortionless Response) beamformer
Optimal beamformer for known noise covariance
"""
# Compute noise covariance matrix
noise_cov = self._compute_covariance(noise_segment)
# Compute signal+noise covariance
signal_noise_cov = self._compute_covariance(multi_channel_audio)
# MVDR weights
# w = R_n^{-1} * a / (a^H * R_n^{-1} * a)
# where a is steering vector
# Simplified: assume steering vector points to channel 0
steering_vector = np.zeros((self.n_mics, 1))
steering_vector[0] = 1
# Compute weights
inv_noise_cov = np.linalg.pinv(noise_cov)
numerator = inv_noise_cov @ steering_vector
denominator = steering_vector.T @ inv_noise_cov @ steering_vector
weights = numerator / (denominator + 1e-10)
# Apply weights
enhanced = weights.T @ multi_channel_audio
return enhanced.squeeze()
def _compute_covariance(self, signal):
"""Compute covariance matrix"""
# [n_mics, n_samples] → [n_mics, n_mics]
cov = signal @ signal.T / signal.shape[1]
return cov
# Usage
beamformer = BeamformerEnhancer(n_mics=4, sr=16000)
# Multi-channel recording
multi_ch_audio = np.random.randn(4, 16000) # 4 mics, 1 second
# Enhance using delay-and-sum
enhanced_ds = beamformer.delay_and_sum(multi_ch_audio, target_direction=0)
# Or using MVDR
noise_segment = multi_ch_audio[:, :8000] # First 0.5 seconds
enhanced_mvdr = beamformer.mvdr_beamformer(multi_ch_audio, noise_segment)
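For spatially incoherent noise, averaging M aligned channels improves SNR by roughly 10·log10(M) dB, about 6 dB for four mics. A quick check of that rule of thumb, reusing calculate_snr from the metrics section:

rng = np.random.default_rng(0)
target = rng.standard_normal(16000)                    # shared "speech" signal
mics = target + 0.5 * rng.standard_normal((4, 16000))  # independent noise per mic
print(f"Single-mic SNR: {calculate_snr(target, mics[0]):.1f} dB")
print(f"Averaged SNR:   {calculate_snr(target, mics.mean(axis=0)):.1f} dB")  # ~6 dB higher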
Connection to Caching (Day 10 ML)
Voice enhancement benefits from caching strategies:
class EnhancementCache:
"""
Cache enhanced audio segments
Connection to Day 10 ML:
- Cache expensive enhancement operations
- LRU for frequently accessed segments
- TTL for time-sensitive applications
"""
def __init__(self, capacity=1000):
from collections import OrderedDict
self.cache = OrderedDict()
self.capacity = capacity
self.hits = 0
self.misses = 0
def get_enhanced(self, audio_segment, model):
"""
Get enhanced audio with caching
Args:
audio_segment: Raw audio
model: Enhancement model
Returns:
Enhanced audio
"""
        # Create cache key (hash of raw samples; stable within one process,
        # but use hashlib for keys that must survive process restarts)
        cache_key = hash(audio_segment.tobytes())
# Check cache
if cache_key in self.cache:
self.hits += 1
self.cache.move_to_end(cache_key) # Mark as recently used
return self.cache[cache_key]
        # Compute enhancement (the model is expected to expose an enhance() method;
        # wrap an nn.Module's forward pass if necessary)
        self.misses += 1
        enhanced = model.enhance(audio_segment)
# Cache result
self.cache[cache_key] = enhanced
# Evict if over capacity
if len(self.cache) > self.capacity:
self.cache.popitem(last=False)
return enhanced
def get_hit_rate(self):
"""Calculate cache hit rate"""
total = self.hits + self.misses
return self.hits / total if total > 0 else 0
# Usage
cache = EnhancementCache(capacity=1000)
model = ConvTasNet()
# Process with caching (audio_stream is a placeholder iterable of audio segments)
for audio_segment in audio_stream:
enhanced = cache.get_enhanced(audio_segment, model)
print(f"Cache hit rate: {cache.get_hit_rate():.2%}")
Understanding Audio Enhancement Fundamentals
Why Enhancement is Critical
Voice enhancement is the foundation of any production speech system. Poor audio quality cascades through the entire pipeline:
class AudioQualityImpactAnalyzer:
"""
Analyze impact of audio quality on downstream tasks
Demonstrates how SNR affects ASR accuracy, speaker recognition, etc.
"""
def __init__(self, asr_model, speaker_model):
self.asr_model = asr_model
self.speaker_model = speaker_model
def evaluate_quality_impact(self, clean_audio, noisy_audio, transcript):
"""
Compare performance on clean vs noisy audio
Returns:
Dictionary with metrics for both conditions
"""
# ASR on clean audio
clean_prediction = self.asr_model.transcribe(clean_audio)
clean_wer = self._calculate_wer(transcript, clean_prediction)
# ASR on noisy audio
noisy_prediction = self.asr_model.transcribe(noisy_audio)
noisy_wer = self._calculate_wer(transcript, noisy_prediction)
# Speaker embedding quality
clean_embedding = self.speaker_model.extract_embedding(clean_audio)
noisy_embedding = self.speaker_model.extract_embedding(noisy_audio)
# Embedding similarity (should be close for same speaker)
similarity = np.dot(clean_embedding, noisy_embedding) / (
np.linalg.norm(clean_embedding) * np.linalg.norm(noisy_embedding)
)
# Calculate SNR
snr_db = self._calculate_snr(clean_audio, noisy_audio)
return {
'snr_db': snr_db,
'clean_wer': clean_wer,
'noisy_wer': noisy_wer,
'wer_degradation': noisy_wer - clean_wer,
'embedding_similarity': similarity,
'relative_performance': clean_wer / noisy_wer if noisy_wer > 0 else 1.0
}
def _calculate_wer(self, reference, hypothesis):
"""Calculate Word Error Rate"""
import editdistance
ref_words = reference.lower().split()
hyp_words = hypothesis.lower().split()
distance = editdistance.eval(ref_words, hyp_words)
wer = distance / len(ref_words) if len(ref_words) > 0 else 0
return wer
def _calculate_snr(self, clean, noisy):
"""Calculate Signal-to-Noise Ratio"""
noise = noisy - clean
signal_power = np.mean(clean ** 2)
noise_power = np.mean(noise ** 2)
if noise_power == 0:
return float('inf')
snr = 10 * np.log10(signal_power / noise_power)
return snr
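The demo below calls add_noise_at_snr, which is not defined above; a minimal implementation consistent with that usage:

def add_noise_at_snr(clean, noise, snr_db):
    """Scale noise so that clean + noise has the requested SNR in dB."""
    noise = noise[:len(clean)]
    clean_power = np.mean(clean ** 2)
    noise_power = np.mean(noise ** 2) + 1e-12
    scale = np.sqrt(clean_power / (noise_power * 10 ** (snr_db / 10)))
    return clean + scale * noise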
# Demo impact analysis
print("="*60)
print("AUDIO QUALITY IMPACT ANALYSIS")
print("="*60)
# Simulate different SNR levels
# (assumes analyzer = AudioQualityImpactAnalyzer(asr_model, speaker_model) and that
#  clean_audio, noise, and transcript have been loaded)
snr_levels = [-5, 0, 5, 10, 15, 20]
for snr_target in snr_levels:
# Add noise at specific SNR
noisy = add_noise_at_snr(clean_audio, noise, snr_target)
# Evaluate
results = analyzer.evaluate_quality_impact(clean_audio, noisy, transcript)
print(f"\nSNR: {snr_target} dB")
print(f" WER (clean): {results['clean_wer']:.2%}")
print(f" WER (noisy): {results['noisy_wer']:.2%}")
print(f" Degradation: {results['wer_degradation']:.2%}")
print(f" Speaker Sim: {results['embedding_similarity']:.3f}")
Frequency Domain Analysis
Understanding audio in the frequency domain is crucial for enhancement:
class FrequencyDomainAnalyzer:
"""
Analyze and visualize audio in frequency domain
Essential for understanding what noise reduction does
"""
def __init__(self, sr=16000):
self.sr = sr
def analyze_spectrum(self, audio):
"""
Compute and visualize spectrum
Returns:
frequencies, magnitudes, phases
"""
# Compute FFT
n_fft = 2048
fft = np.fft.rfft(audio, n=n_fft)
# Magnitude and phase
magnitude = np.abs(fft)
phase = np.angle(fft)
# Frequency bins
frequencies = np.fft.rfftfreq(n_fft, 1/self.sr)
return frequencies, magnitude, phase
def compare_spectra(self, clean, noisy, enhanced):
"""
Compare spectra before and after enhancement
"""
import matplotlib.pyplot as plt
# Compute spectra
freq_clean, mag_clean, _ = self.analyze_spectrum(clean)
freq_noisy, mag_noisy, _ = self.analyze_spectrum(noisy)
freq_enhanced, mag_enhanced, _ = self.analyze_spectrum(enhanced)
# Plot
fig, axes = plt.subplots(3, 1, figsize=(12, 10))
# Clean
axes[0].plot(freq_clean, 20 * np.log10(mag_clean + 1e-10))
axes[0].set_title('Clean Audio Spectrum')
axes[0].set_ylabel('Magnitude (dB)')
axes[0].grid(True)
# Noisy
axes[1].plot(freq_noisy, 20 * np.log10(mag_noisy + 1e-10), color='red')
axes[1].set_title('Noisy Audio Spectrum')
axes[1].set_ylabel('Magnitude (dB)')
axes[1].grid(True)
# Enhanced
axes[2].plot(freq_enhanced, 20 * np.log10(mag_enhanced + 1e-10), color='green')
axes[2].set_title('Enhanced Audio Spectrum')
axes[2].set_xlabel('Frequency (Hz)')
axes[2].set_ylabel('Magnitude (dB)')
axes[2].grid(True)
plt.tight_layout()
plt.savefig('spectrum_comparison.png')
plt.close()
def compute_spectral_features(self, audio):
"""
Compute spectral features for quality assessment
"""
freq, mag, _ = self.analyze_spectrum(audio)
# Spectral centroid
centroid = np.sum(freq * mag) / np.sum(mag)
# Spectral bandwidth
bandwidth = np.sqrt(np.sum(((freq - centroid) ** 2) * mag) / np.sum(mag))
# Spectral flatness (Wiener entropy)
geometric_mean = np.exp(np.mean(np.log(mag + 1e-10)))
arithmetic_mean = np.mean(mag)
flatness = geometric_mean / arithmetic_mean
# Spectral rolloff (95% of energy)
cumsum = np.cumsum(mag)
rolloff_idx = np.where(cumsum >= 0.95 * cumsum[-1])[0][0]
rolloff = freq[rolloff_idx]
return {
'centroid_hz': centroid,
'bandwidth_hz': bandwidth,
'flatness': flatness,
'rolloff_hz': rolloff
}
# Usage
analyzer = FrequencyDomainAnalyzer(sr=16000)
# Analyze audio
features = analyzer.compute_spectral_features(audio)
print("Spectral Features:")
print(f" Centroid: {features['centroid_hz']:.1f} Hz")
print(f" Bandwidth: {features['bandwidth_hz']:.1f} Hz")
print(f" Flatness: {features['flatness']:.3f}")
print(f" Rolloff: {features['rolloff_hz']:.1f} Hz")
# Compare before/after
analyzer.compare_spectra(clean_audio, noisy_audio, enhanced_audio)
Advanced Deep Learning Enhancement
State-of-the-Art Architectures
class ConvTasNetEnhancer(nn.Module):
"""
Conv-TasNet for speech enhancement
Architecture:
1. Encoder: Waveform → Feature representation
2. Separator: Mask estimation using temporal convolutions
3. Decoder: Masked features → Enhanced waveform
Advantages over STFT-based methods:
- Operates on raw waveform
- Learnable basis functions
- Better phase reconstruction
"""
def __init__(
self,
n_src=1,
n_filters=512,
kernel_size=16,
stride=8,
n_blocks=8,
n_repeats=3,
bn_chan=128,
hid_chan=512,
skip_chan=128
):
super().__init__()
# Encoder: 1D conv
self.encoder = nn.Conv1d(
1,
n_filters,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2
)
# Separator: TCN blocks
self.separator = TemporalConvNet(
n_filters,
n_src,
n_blocks=n_blocks,
n_repeats=n_repeats,
bn_chan=bn_chan,
hid_chan=hid_chan,
skip_chan=skip_chan
)
# Decoder: 1D transposed conv
self.decoder = nn.ConvTranspose1d(
n_filters,
1,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2
)
def forward(self, waveform):
"""
Enhance waveform
Args:
waveform: [batch, time]
Returns:
enhanced: [batch, time]
"""
# Add channel dimension
x = waveform.unsqueeze(1) # [batch, 1, time]
# Encode
encoded = self.encoder(x) # [batch, n_filters, time']
# Separate (estimate mask)
masks = self.separator(encoded) # [batch, n_src, n_filters, time']
# Apply mask
masked = encoded.unsqueeze(1) * masks # [batch, n_src, n_filters, time']
# Decode
enhanced = self.decoder(masked.squeeze(1)) # [batch, 1, time]
# Remove channel dimension
enhanced = enhanced.squeeze(1) # [batch, time]
# Trim to original length
if enhanced.shape[-1] != waveform.shape[-1]:
enhanced = enhanced[..., :waveform.shape[-1]]
return enhanced
class TemporalConvNet(nn.Module):
"""
Temporal Convolutional Network for Conv-TasNet
Stack of dilated conv blocks with skip connections
"""
def __init__(
self,
n_filters,
n_src,
n_blocks=8,
n_repeats=3,
bn_chan=128,
hid_chan=512,
skip_chan=128
):
super().__init__()
# Layer norm
self.layer_norm = nn.GroupNorm(1, n_filters)
# Bottleneck
self.bottleneck = nn.Conv1d(n_filters, bn_chan, 1)
# TCN blocks
self.blocks = nn.ModuleList()
for r in range(n_repeats):
for b in range(n_blocks):
dilation = 2 ** b
self.blocks.append(
TCNBlock(
bn_chan,
hid_chan,
skip_chan,
kernel_size=3,
dilation=dilation
)
)
# Output
self.output = nn.Sequential(
nn.PReLU(),
nn.Conv1d(skip_chan, n_filters, 1),
nn.Sigmoid() # Mask should be [0, 1]
)
def forward(self, x):
"""
Args:
x: [batch, n_filters, time]
Returns:
masks: [batch, n_src, n_filters, time]
"""
# Normalize
x = self.layer_norm(x)
# Bottleneck
x = self.bottleneck(x) # [batch, bn_chan, time]
# Accumulate skip connections
skip_sum = 0
for block in self.blocks:
x, skip = block(x)
skip_sum = skip_sum + skip
# Output mask
masks = self.output(skip_sum)
# Unsqueeze for n_src dimension
masks = masks.unsqueeze(1) # [batch, 1, n_filters, time]
return masks
class TCNBlock(nn.Module):
"""Single TCN block with dilated convolution"""
def __init__(self, in_chan, hid_chan, skip_chan, kernel_size=3, dilation=1):
super().__init__()
self.conv1 = nn.Conv1d(
in_chan,
hid_chan,
1
)
self.prelu1 = nn.PReLU()
self.norm1 = nn.GroupNorm(1, hid_chan)
self.depthwise_conv = nn.Conv1d(
hid_chan,
hid_chan,
kernel_size,
padding=(kernel_size - 1) * dilation // 2,
dilation=dilation,
groups=hid_chan
)
self.prelu2 = nn.PReLU()
self.norm2 = nn.GroupNorm(1, hid_chan)
self.conv2 = nn.Conv1d(hid_chan, in_chan, 1)
self.skip_conv = nn.Conv1d(hid_chan, skip_chan, 1)
def forward(self, x):
"""
Args:
x: [batch, in_chan, time]
Returns:
output: [batch, in_chan, time]
skip: [batch, skip_chan, time]
"""
residual = x
# 1x1 conv
x = self.conv1(x)
x = self.prelu1(x)
x = self.norm1(x)
# Depthwise conv
x = self.depthwise_conv(x)
x = self.prelu2(x)
x = self.norm2(x)
# Skip connection
skip = self.skip_conv(x)
# Output
x = self.conv2(x)
# Residual
output = x + residual
return output, skip
# Training Conv-TasNet
class ConvTasNetTrainer:
"""
Train Conv-TasNet for speech enhancement
"""
def __init__(self, model, device='cuda'):
self.model = model.to(device)
self.device = device
# Optimizer
self.optimizer = torch.optim.Adam(
self.model.parameters(),
lr=1e-3
)
# Learning rate scheduler
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
patience=3
)
def train_epoch(self, train_loader):
"""Train one epoch"""
self.model.train()
total_loss = 0
for batch_idx, (noisy, clean) in enumerate(train_loader):
noisy = noisy.to(self.device)
clean = clean.to(self.device)
# Forward
enhanced = self.model(noisy)
# Loss: SI-SNR (Scale-Invariant SNR)
loss = self._si_snr_loss(enhanced, clean)
# Backward
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
self.optimizer.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
return total_loss / len(train_loader)
def _si_snr_loss(self, estimate, target):
"""
Scale-Invariant Signal-to-Noise Ratio loss
Better than MSE for speech enhancement
"""
# Zero-mean
estimate_zm = estimate - estimate.mean(dim=-1, keepdim=True)
target_zm = target - target.mean(dim=-1, keepdim=True)
# <s', s>s / ||s||^2
dot = (estimate_zm * target_zm).sum(dim=-1, keepdim=True)
target_energy = (target_zm ** 2).sum(dim=-1, keepdim=True)
projection = dot * target_zm / (target_energy + 1e-8)
# Noise
noise = estimate_zm - projection
# SI-SNR
si_snr = 10 * torch.log10(
(projection ** 2).sum(dim=-1) / (noise ** 2).sum(dim=-1) + 1e-8
)
# Negative for loss (we want to maximize SI-SNR)
return -si_snr.mean()
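    def validate(self, val_loader):
        # Assumed helper: the usage below calls trainer.validate(), but the original
        # class does not define it; this mirrors train_epoch without gradient updates
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for noisy, clean in val_loader:
                noisy = noisy.to(self.device)
                clean = clean.to(self.device)
                enhanced = self.model(noisy)
                total_loss += self._si_snr_loss(enhanced, clean).item()
        return total_loss / len(val_loader)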
# Usage
model = ConvTasNetEnhancer()
trainer = ConvTasNetTrainer(model, device='cuda')
# Train (num_epochs, train_loader, and val_loader are assumed to be defined)
for epoch in range(num_epochs):
train_loss = trainer.train_epoch(train_loader)
val_loss = trainer.validate(val_loader)
print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
trainer.scheduler.step(val_loss)
Real-Time Enhancement with ONNX
class RealTimeONNXEnhancer:
"""
Real-time enhancement using ONNX Runtime
Optimized for production deployment
"""
def __init__(self, onnx_model_path, chunk_size=4800):
"""
Args:
onnx_model_path: Path to exported ONNX model
chunk_size: Audio chunk size (samples)
"""
import onnxruntime as ort
self.chunk_size = chunk_size
# Load ONNX model
self.session = ort.InferenceSession(
onnx_model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
# Get input/output names
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
# State for streaming
self.reset_state()
def reset_state(self):
"""Reset streaming state"""
self.overlap_buffer = np.zeros(self.chunk_size // 2, dtype=np.float32)
def enhance_chunk(self, audio_chunk):
"""
Enhance single audio chunk with overlap-add
Args:
audio_chunk: [chunk_size] numpy array
Returns:
            enhanced_chunk: [chunk_size - overlap] numpy array (the tail is held back for overlap-add)
"""
# Prepare input (batch dimension)
input_data = audio_chunk.astype(np.float32)[np.newaxis, :]
# Run inference
enhanced = self.session.run(
[self.output_name],
{self.input_name: input_data}
)[0][0]
# Overlap-add
overlap_size = len(self.overlap_buffer)
enhanced[:overlap_size] += self.overlap_buffer
# Save overlap for next chunk
self.overlap_buffer = enhanced[-overlap_size:].copy()
# Return without overlap region
return enhanced[:-overlap_size]
def enhance_stream(self, audio_stream):
"""
Enhance audio stream in real-time
Generator that yields enhanced chunks
"""
for chunk in audio_stream:
            # Pad the final short chunk with zeros instead of dropping audio
            if len(chunk) < self.chunk_size:
                chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
# Enhance
enhanced = self.enhance_chunk(chunk)
yield enhanced
# Export PyTorch model to ONNX
def export_to_onnx(pytorch_model, onnx_path, chunk_size=4800):
"""
Export trained PyTorch model to ONNX
"""
pytorch_model.eval()
# Dummy input
dummy_input = torch.randn(1, chunk_size)
# Export
torch.onnx.export(
pytorch_model,
dummy_input,
onnx_path,
input_names=['audio_input'],
output_names=['audio_output'],
dynamic_axes={
'audio_input': {1: 'time'},
'audio_output': {1: 'time'}
},
opset_version=14
)
print(f"Model exported to {onnx_path}")
# Usage
# Export model
export_to_onnx(trained_model, 'convtasnet_enhancer.onnx')
# Create real-time enhancer
enhancer = RealTimeONNXEnhancer('convtasnet_enhancer.onnx', chunk_size=4800)
# Stream audio
def audio_stream_generator():
"""Generate audio chunks from microphone/file"""
# Implementation depends on audio source
pass
# Enhance stream
for enhanced_chunk in enhancer.enhance_stream(audio_stream_generator()):
# Play or save enhanced audio
pass
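Before deployment it is worth verifying that inference keeps up with real time. A minimal benchmark sketch (the 0.5 real-time-factor budget is an assumption, not a hard requirement):

import time

chunk = np.random.randn(4800).astype(np.float32)
n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    enhancer.enhance_chunk(chunk)
elapsed = time.perf_counter() - start
rtf = (elapsed / n_runs) / (4800 / 16000)  # processing time / audio duration per chunk
print(f"Real-time factor: {rtf:.3f} ({'OK' if rtf < 0.5 else 'too slow'} for streaming)")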
Production Quality Assurance
Automated Quality Metrics
class EnhancementQualityAssurance:
"""
Automated quality assurance for enhancement pipeline
Monitors:
- SNR improvement
- Speech intelligibility
- Artifacts
- Latency
"""
def __init__(self):
self.metrics_history = []
def assess_quality(self, original, enhanced, reference=None):
"""
Comprehensive quality assessment
Args:
original: Noisy input
enhanced: Enhanced output
reference: Clean reference (if available)
Returns:
Quality metrics dictionary
"""
metrics = {}
# SNR improvement (requires reference)
if reference is not None:
original_snr = self._compute_snr(original, reference)
enhanced_snr = self._compute_snr(enhanced, reference)
metrics['snr_improvement_db'] = enhanced_snr - original_snr
# PESQ (Perceptual Evaluation of Speech Quality)
from pesq import pesq
metrics['pesq_original'] = pesq(16000, reference, original, 'wb')
metrics['pesq_enhanced'] = pesq(16000, reference, enhanced, 'wb')
metrics['pesq_improvement'] = (
metrics['pesq_enhanced'] - metrics['pesq_original']
)
# STOI (Short-Time Objective Intelligibility)
from pystoi import stoi
metrics['stoi_original'] = stoi(reference, original, 16000)
metrics['stoi_enhanced'] = stoi(reference, enhanced, 16000)
metrics['stoi_improvement'] = (
metrics['stoi_enhanced'] - metrics['stoi_original']
)
# Artifact detection (no reference needed)
metrics['artifact_score'] = self._detect_artifacts(enhanced)
# Spectral distortion
metrics['spectral_distortion'] = self._compute_spectral_distortion(
original, enhanced
)
# Dynamic range
metrics['dynamic_range_db'] = 20 * np.log10(
np.max(np.abs(enhanced)) / (np.mean(np.abs(enhanced)) + 1e-8)
)
# Clipping detection
metrics['clipping_ratio'] = np.mean(np.abs(enhanced) > 0.99)
# Overall quality score
metrics['quality_score'] = self._compute_overall_score(metrics)
self.metrics_history.append(metrics)
return metrics
def _compute_snr(self, signal, reference):
"""Compute SNR"""
noise = signal - reference
signal_power = np.mean(reference ** 2)
noise_power = np.mean(noise ** 2)
if noise_power == 0:
return float('inf')
snr_db = 10 * np.log10(signal_power / noise_power)
return snr_db
def _detect_artifacts(self, audio):
"""
Detect musical noise and other artifacts
Returns:
Artifact score (0-1, lower is better)
"""
# Compute spectrogram
S = librosa.stft(audio)
magnitude = np.abs(S)
# Temporal variation
temporal_diff = np.diff(magnitude, axis=1)
temporal_variance = np.var(temporal_diff)
# Spectral variation
spectral_diff = np.diff(magnitude, axis=0)
spectral_variance = np.var(spectral_diff)
# High variance indicates artifacts
artifact_score = (temporal_variance + spectral_variance) / 2
        # Normalize to [0, 1] (the /100 scale is a heuristic; tune it per deployment)
        artifact_score = np.clip(artifact_score / 100, 0, 1)
return artifact_score
def _compute_spectral_distortion(self, original, enhanced):
"""
Compute spectral distortion
Measures how much the spectrum changed
"""
# Compute spectrograms
S_orig = np.abs(librosa.stft(original))
S_enh = np.abs(librosa.stft(enhanced))
# Log magnitude
S_orig_db = librosa.amplitude_to_db(S_orig + 1e-10)
S_enh_db = librosa.amplitude_to_db(S_enh + 1e-10)
# MSE in log domain
distortion = np.mean((S_orig_db - S_enh_db) ** 2)
return distortion
def _compute_overall_score(self, metrics):
"""
Compute overall quality score
Weighted combination of metrics
"""
score = 0.0
# PESQ improvement (if available)
if 'pesq_improvement' in metrics:
score += 0.4 * np.clip(metrics['pesq_improvement'] / 2, 0, 1)
# STOI improvement (if available)
if 'stoi_improvement' in metrics:
score += 0.4 * np.clip(metrics['stoi_improvement'], 0, 1)
# Artifact penalty
score -= 0.2 * metrics['artifact_score']
# Normalize to [0, 1]
score = np.clip(score, 0, 1)
return score
def generate_report(self):
"""Generate quality assurance report"""
if not self.metrics_history:
print("No metrics recorded")
return
# Aggregate metrics
avg_metrics = {}
for key in self.metrics_history[0].keys():
values = [m[key] for m in self.metrics_history if key in m]
avg_metrics[key] = np.mean(values)
print("\n" + "="*60)
print("ENHANCEMENT QUALITY ASSURANCE REPORT")
print("="*60)
print(f"Samples Evaluated: {len(self.metrics_history)}")
print(f"\nAverage Metrics:")
for key, value in avg_metrics.items():
print(f" {key:30s}: {value:.4f}")
# Pass/fail criteria
print(f"\n{'Criterion':<30s} {'Status':>10s}")
print("-" * 42)
checks = [
('SNR Improvement', avg_metrics.get('snr_improvement_db', 0) > 3, '>3 dB'),
('PESQ Improvement', avg_metrics.get('pesq_improvement', 0) > 0.5, '>0.5'),
('STOI Improvement', avg_metrics.get('stoi_improvement', 0) > 0.1, '>0.1'),
('Artifact Score', avg_metrics.get('artifact_score', 1) < 0.3, '<0.3'),
('Clipping Ratio', avg_metrics.get('clipping_ratio', 1) < 0.01, '<1%'),
]
all_passed = True
for name, passed, threshold in checks:
status = "✓ PASS" if passed else "✗ FAIL"
all_passed = all_passed and passed
print(f" {name:<30s} {status:>10s} ({threshold})")
print("-" * 42)
print(f" {'Overall Result':<30s} {'✓ PASS' if all_passed else '✗ FAIL':>10s}")
print("="*60)
# Usage
qa = EnhancementQualityAssurance()
# Evaluate multiple files (test_pairs is a placeholder list of (noisy_path, clean_path) tuples)
for noisy_file, clean_file in test_pairs:
noisy_audio, _ = librosa.load(noisy_file, sr=16000)
clean_audio, _ = librosa.load(clean_file, sr=16000)
# Enhance
enhanced_audio = enhancer.enhance(noisy_audio)
# Assess quality
metrics = qa.assess_quality(noisy_audio, enhanced_audio, clean_audio)
# Generate report
qa.generate_report()
Key Takeaways
✅ Multiple approaches - Classical (spectral subtraction, Wiener) and deep learning
✅ Quality metrics - PESQ, STOI, SNR for evaluation
✅ Real-time processing - Streaming with latency under 50 ms
✅ Multi-channel - Beamforming for spatial enhancement
✅ Caching benefits - Reduce computational cost for repeated segments
✅ Trade-offs - Quality vs latency vs computational cost
✅ Production considerations - Monitoring, fallback, quality control
Originally published at: arunbaby.com/speech-tech/0010-voice-enhancement