Speech Pipeline Orchestration
“Orchestrating complex speech processing pipelines from audio ingestion to final output.”
1. Speech Pipeline Architecture
A production speech system has multiple stages with complex dependencies:
Audio Input →   VAD   → Segmentation →    ASR    →   NLU   → Action
     ↓           ↓           ↓             ↓          ↓         ↓
  Quality     Speaker     Speaker      Language     Post-     Error
   Check    Diarization      ID        Detection   process   Handling
Key Differences from General ML Pipelines:
- Real-time Constraints: Audio must be processed with sub-500 ms end-to-end latency (see the latency-budget sketch below).
- Streaming Data: A continuous audio stream, not discrete batches.
- State Management: Conversational context must be maintained across chunks.
- Multi-modal: Audio is combined with text, user profiles, and location signals.
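A minimal sketch of enforcing that per-chunk latency budget in a streaming frontend (the 80 ms chunk size and 500 ms budget mirror the numbers used later in this post; `transcribe_chunk` is a hypothetical model hook, not a real API):

import time

CHUNK_MS = 80           # audio chunk duration fed to the model
LATENCY_BUDGET_S = 0.5  # end-to-end budget per chunk

def process_stream(chunks, transcribe_chunk):
    """Process chunks and flag any that blow the latency budget."""
    for chunk in chunks:
        start = time.monotonic()
        text = transcribe_chunk(chunk)  # model inference
        elapsed = time.monotonic() - start
        if elapsed > LATENCY_BUDGET_S:
            # In production: shed load, fall back to a smaller model, or alert
            print(f"latency budget exceeded: {elapsed * 1000:.0f} ms")
        yield text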
2. Real-time vs. Batch Speech Pipelines
Batch Pipeline (Offline Transcription)
Use Case: Transcribe uploaded podcast episodes, meeting recordings.
Architecture:
from airflow import DAG
from airflow.operators.python import PythonOperator
def audio_preprocessing(**context):
    audio_path = context['params']['audio_path']
    # 1. Format conversion
    wav_audio = convert_to_wav(audio_path, sample_rate=16000)
    # 2. Noise reduction
    clean_audio = noise_reduction(wav_audio)
    # 3. Volume normalization
    normalized = normalize_volume(clean_audio)
    save_to_gcs(normalized, f"preprocessed_{context['ds']}.wav")

def voice_activity_detection(**context):
    audio = load_from_gcs(f"preprocessed_{context['ds']}.wav")
    # Detect speech segments
    segments = vad_model.predict(audio)  # [(start_time, end_time), ...]
    context['ti'].xcom_push(key='segments', value=segments)

def speaker_diarization(**context):
    audio = load_from_gcs(f"preprocessed_{context['ds']}.wav")
    segments = context['ti'].xcom_pull(task_ids='voice_activity_detection', key='segments')
    # Extract speaker embeddings (x-vectors)
    embeddings = [extract_xvector(audio[start:end]) for start, end in segments]
    # Cluster speakers
    from sklearn.cluster import AgglomerativeClustering
    clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=0.5)
    speaker_labels = clustering.fit_predict(embeddings)
    context['ti'].xcom_push(key='speaker_labels', value=speaker_labels)

def asr_transcription(**context):
    audio = load_from_gcs(f"preprocessed_{context['ds']}.wav")
    segments = context['ti'].xcom_pull(task_ids='voice_activity_detection', key='segments')
    speaker_labels = context['ti'].xcom_pull(task_ids='speaker_diarization', key='speaker_labels')
    transcripts = []
    for (start, end), speaker in zip(segments, speaker_labels):
        segment_audio = audio[start:end]
        text = asr_model.transcribe(segment_audio)
        transcripts.append({
            'speaker': f'Speaker_{speaker}',
            'start': start,
            'end': end,
            'text': text
        })
    save_transcripts(transcripts, f"transcript_{context['ds']}.json")
# DAG definition (scheduling arguments elided)
dag = DAG('speech_transcription_pipeline', ...)
preprocess = PythonOperator(task_id='audio_preprocessing', python_callable=audio_preprocessing, dag=dag)
vad = PythonOperator(task_id='voice_activity_detection', python_callable=voice_activity_detection, dag=dag)
diarization = PythonOperator(task_id='speaker_diarization', python_callable=speaker_diarization, dag=dag)
transcription = PythonOperator(task_id='asr_transcription', python_callable=asr_transcription, dag=dag)
preprocess >> vad >> diarization >> transcription
Streaming Pipeline (Real-time Transcription)
Use Case: Live captioning, voice assistants.
Architecture (using Kafka + Kubernetes):
Audio Stream (Kafka) → VAD Service → ASR Service → NLU Service → Output
                           ↓             ↓             ↓
                        K8s Pod       K8s Pod       K8s Pod
                     (auto-scale)      (GPU)         (CPU)
Implementation:
import pickle

from kafka import KafkaConsumer, KafkaProducer
import torch

class StreamingASRService:
    def __init__(self):
        self.consumer = KafkaConsumer(
            'audio_stream',
            bootstrap_servers=['kafka:9092'],
            value_deserializer=lambda m: pickle.loads(m)
        )
        self.producer = KafkaProducer(
            bootstrap_servers=['kafka:9092'],
            value_serializer=lambda m: pickle.dumps(m)
        )
        self.asr_model = load_streaming_asr_model()
        self.state = {}  # conversation state per session

    def process_audio_chunk(self, chunk):
        session_id = chunk['session_id']
        audio = chunk['audio']  # 80ms chunk
        # Get or initialize state
        if session_id not in self.state:
            self.state[session_id] = {
                'hidden': None,
                'partial_transcript': ''
            }
        # Streaming inference
        logits, hidden = self.asr_model.forward_chunk(
            audio,
            prev_hidden=self.state[session_id]['hidden']
        )
        # Update state
        self.state[session_id]['hidden'] = hidden
        # Decode
        tokens = self.asr_model.decode(logits)
        text = self.asr_model.tokens_to_text(tokens)
        # Emit partial result
        return {
            'session_id': session_id,
            'text': text,
            'is_final': False,
            'timestamp': chunk['timestamp']
        }

    def run(self):
        for message in self.consumer:
            chunk = message.value
            result = self.process_audio_chunk(chunk)
            # Send to next stage
            self.producer.send('asr_output', result)
# Kubernetes Deployment (YAML)
"""
apiVersion: apps/v1
kind: Deployment
metadata:
  name: asr-service
spec:
  replicas: 10  # baseline; scale automatically with an HPA (sketch below)
  selector:
    matchLabels:
      app: asr
  template:
    metadata:
      labels:
        app: asr
    spec:
      containers:
      - name: asr
        image: asr-streaming:v1
        resources:
          requests:
            nvidia.com/gpu: 1
          limits:
            nvidia.com/gpu: 1
"""
3. Dependency Management in Speech Pipelines
Sequential Dependencies
Audio → Preprocessing → VAD → ASR
Each stage must complete before the next.
Parallel Dependencies (Fan-Out)
→ Speaker Diarization →
Audio → VAD → → Merge → Output
→ Language Detection →
Airflow Implementation:
vad_task >> [diarization_task, language_detection_task] >> merge_task
Conditional Dependencies
Audio → ASR → (if confidence > 0.8) → Output
                   ↓ (else)
              Human Review
Airflow BranchPythonOperator:
from airflow.operators.python import BranchPythonOperator

def check_confidence(**context):
    confidence = context['ti'].xcom_pull(task_ids='asr', key='confidence')
    if confidence > 0.8:
        return 'output_task'
    else:
        return 'human_review_task'

branch = BranchPythonOperator(
    task_id='check_confidence',
    python_callable=check_confidence,
    dag=dag
)
asr_task >> branch >> [output_task, human_review_task]
4. State Management in Streaming Speech
Challenge: Maintain context across audio chunks.
Example: User says “Play… [pause]… Taylor Swift.”
- Chunk 1: “Play”
- Chunk 2: “” (pause)
- Chunk 3: “Taylor Swift”
Need to remember: “Play” from Chunk 1 when processing Chunk 3.
Solution: Stateful Stream Processing
import time

class StatefulASRProcessor:
    def __init__(self):
        self.sessions = {}  # {session_id: {hidden_state, partial_text, ...}}

    def process_chunk(self, session_id, audio_chunk):
        if session_id not in self.sessions:
            self.sessions[session_id] = {
                'hidden': None,
                'partial': '',
                'last_update': time.time()
            }
        session = self.sessions[session_id]
        # Streaming RNN-T forward pass
        logits, new_hidden = self.model.forward(
            audio_chunk,
            prev_hidden=session['hidden']
        )
        # Update state
        session['hidden'] = new_hidden
        session['last_update'] = time.time()
        # Decode newly emitted tokens and append to the running transcript
        new_text = self.decode(logits)
        session['partial'] += new_text
        return session['partial']

    def cleanup_old_sessions(self, timeout=300):
        """Remove sessions inactive for > 5 minutes"""
        now = time.time()
        to_remove = [sid for sid, s in self.sessions.items() if now - s['last_update'] > timeout]
        for sid in to_remove:
            del self.sessions[sid]
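To make the "Play… Taylor Swift" example concrete, a hypothetical driver loop (the `chunk_*` variables stand in for audio tensors arriving from the transport layer):

processor = StatefulASRProcessor()

# Three chunks from one session: speech, silence, speech
for chunk in [chunk_play, chunk_silence, chunk_taylor_swift]:
    partial = processor.process_chunk(session_id='user-42', audio_chunk=chunk)
    print(partial)  # "Play" -> "Play" -> "Play Taylor Swift"

# Periodically evict sessions that have gone quiet
processor.cleanup_old_sessions()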
Deep Dive: Google Speech-to-Text Pipeline
Architecture:
Client Audio → Load Balancer → ASR Frontend → ASR Backend (BPod)
                                    ↓                 ↓
                             Language Model     Acoustic Model
Pipeline Stages:
- Audio Ingestion: Client streams audio via gRPC.
- Load Balancing: Route to nearest datacenter.
- Feature Extraction (Frontend): Compute mel-spectrograms.
- ASR Backend (BPod - Brain Pod):
- Acoustic Model: Conformer-based RNN-T.
- Language Model: Neural LM for rescoring.
- NLU (Optional): Intent classification, slot filling.
- Response: Return transcript to client.
Dependency Chain:
- Feature extraction must complete before acoustic model can run.
- Acoustic model emits tokens in real-time.
- Language model rescores N-best hypotheses (see the sketch below).
Orchestration:
- Kubernetes: Auto-scale ASR pods based on request load.
- Airflow: Manage batch model training/deployment pipelines.
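The rescoring step above can be sketched as log-linear interpolation of acoustic and language-model scores over the N-best list. This is a hedged illustration of the general technique, not Google's actual implementation; `lm_score` and the interpolation weight are assumptions:

def rescore_nbest(nbest, lm_score, lm_weight=0.3):
    """Pick the hypothesis with the best combined score.

    nbest: list of (text, acoustic_log_prob) pairs from the ASR decoder.
    lm_score: callable returning a language-model log probability for a string.
    """
    def combined(hyp):
        text, am_logp = hyp
        # Log-linear combination of acoustic and LM evidence
        return am_logp + lm_weight * lm_score(text)

    return max(nbest, key=combined)[0]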
Deep Dive: Amazon Transcribe Architecture
Batch Transcription Pipeline:
User uploads audio to S3 → S3 Event triggers Lambda → Lambda starts Transcribe Job
                                                                ↓
                                                        Job Queue (SQS)
                                                                ↓
                                                   Worker Pools (EC2/Fargate)
                                                                ↓
                                              1. VAD → 2. Diarization → 3. ASR
                                                                ↓
                                                       Results saved to S3
                                                                ↓
                                                        Notification (SNS)
Orchestration:
- Step Functions: Coordinate multi-stage processing.
- SQS: Queue jobs to handle bursts.
- Auto-Scaling: Scale worker pools based on queue depth.
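A sketch of queue-depth-based scaling with boto3 (the cluster/service names and the messages-per-worker ratio are made up for illustration; AWS's internal mechanism is not public):

import boto3

sqs = boto3.client('sqs')
ecs = boto3.client('ecs')

def scale_workers(queue_url, cluster='transcribe-workers', service='asr-worker',
                  msgs_per_worker=10, max_workers=100):
    # Read the approximate backlog from SQS
    attrs = sqs.get_queue_attributes(
        QueueUrl=queue_url,
        AttributeNames=['ApproximateNumberOfMessages']
    )
    backlog = int(attrs['Attributes']['ApproximateNumberOfMessages'])
    # One worker per N queued jobs, with a floor of 1 and a hard cap
    desired = min(max(1, backlog // msgs_per_worker), max_workers)
    ecs.update_service(cluster=cluster, service=service, desiredCount=desired)
    return desired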
Dependency Graph:
{
  "StartAt": "AudioPreprocessing",
  "States": {
    "AudioPreprocessing": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:...:function:preprocess-audio",
      "Next": "VAD"
    },
    "VAD": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:...:function:voice-activity-detection",
      "Next": "ParallelProcessing"
    },
    "ParallelProcessing": {
      "Type": "Parallel",
      "Branches": [
        {
          "StartAt": "SpeakerDiarization",
          "States": {
            "SpeakerDiarization": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:...:function:diarization",
              "End": true
            }
          }
        },
        {
          "StartAt": "LanguageDetection",
          "States": {
            "LanguageDetection": {
              "Type": "Task",
              "Resource": "arn:aws:lambda:...:function:language-detection",
              "End": true
            }
          }
        }
      ],
      "Next": "ASRTranscription"
    },
    "ASRTranscription": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:...:function:asr",
      "End": true
    }
  }
}
Deep Dive: Handling Audio Format Diversity
Challenge: Input audio comes in 100+ formats (MP3, AAC, FLAC, OGG, …).
Solution: Format Normalization Pipeline
def normalize_audio(input_path, output_path):
    """
    Convert any audio format to WAV with:
    - Sample rate: 16kHz
    - Channels: Mono
    - Bit depth: 16-bit PCM
    """
    import ffmpeg
    (
        ffmpeg
        .input(input_path)
        .output(output_path, ar=16000, ac=1, f='wav')
        .run(overwrite_output=True, quiet=True)
    )

# Airflow Task
normalize = PythonOperator(
    task_id='normalize_audio',
    python_callable=normalize_audio,
    op_kwargs={'input_path': '{{ params.input }}', 'output_path': '/tmp/normalized.wav'},
    dag=dag
)
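It is also worth rejecting corrupt or empty uploads before they enter the pipeline. A small sketch using ffmpeg-python's `probe` (the minimum-duration threshold is an arbitrary choice):

import ffmpeg

def validate_audio(path, min_duration_s=0.5):
    """Return True if the file decodes and contains a usable audio stream."""
    try:
        info = ffmpeg.probe(path)
    except ffmpeg.Error:
        return False  # undecodable / corrupt file
    streams = [s for s in info['streams'] if s['codec_type'] == 'audio']
    if not streams:
        return False  # video-only or empty container
    return float(info['format'].get('duration', 0)) >= min_duration_s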
Deep Dive: Quality Gates in Speech Pipelines
Problem: Don’t want to deploy a model with 50% WER.
Solution: Quality Gates
def evaluate_model(**context):
    model_path = context['params']['model_path']
    test_set = load_test_set()
    wer = compute_wer(model_path, test_set)
    cer = compute_cer(model_path, test_set)
    # Quality gates
    if wer > 0.15:  # 15% WER threshold
        raise ValueError(f"WER too high: {wer:.2%}")
    if cer > 0.08:  # 8% CER threshold
        raise ValueError(f"CER too high: {cer:.2%}")
    print(f"Model passed quality gates. WER: {wer:.2%}, CER: {cer:.2%}")
    return model_path

evaluate_task = PythonOperator(
    task_id='evaluate_model',
    python_callable=evaluate_model,
    dag=dag
)
# If evaluation fails, pipeline stops (model not deployed)
train_task >> evaluate_task >> deploy_task
Deep Dive: Backfilling Speech Model Training
Scenario: You have 3 years of call center recordings. Want to train ASR models for each quarter.
Backfill Strategy:
# Airflow DAG with catchup=True
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator

dag = DAG(
    'train_asr_quarterly',
    schedule_interval='0 0 1 */3 *',  # first day of every third month
    start_date=datetime(2021, 1, 1),
    catchup=True,       # Run for all past quarters
    max_active_runs=2   # Limit parallelism (training is expensive)
)

def train_for_quarter(**context):
    quarter_start = context['data_interval_start']
    quarter_end = context['data_interval_end']
    # Fetch audio from this quarter
    audio_files = fetch_audio_between(quarter_start, quarter_end)
    # Train model
    model = train_asr(audio_files)
    # Save with version = quarter
    save_model(model, f"asr_model_Q{quarter_start.quarter}_{quarter_start.year}.pt")

train_task = PythonOperator(task_id='train', python_callable=train_for_quarter, dag=dag)
Deep Dive: Monitoring Speech Pipeline Health
Key Metrics:
- WER (Word Error Rate): Track per language, per domain.
- Latency: p50, p95, p99 latency for each pipeline stage.
- Throughput: Audio-hours processed per unit of wall-clock time (the real-time factor).
- Error Rate: % of jobs that fail (network errors, bad audio, etc.).
Prometheus Metrics:
from prometheus_client import Counter, Histogram

asr_requests = Counter('asr_requests_total', 'Total ASR requests')
asr_latency = Histogram('asr_latency_seconds', 'ASR latency')
asr_wer = Histogram('asr_wer', 'Word Error Rate')

@asr_latency.time()
def transcribe(audio, ground_truth=None):
    asr_requests.inc()
    transcript = model.transcribe(audio)
    # If we have ground truth, compute WER
    if ground_truth is not None:
        wer = compute_wer(transcript, ground_truth)
        asr_wer.observe(wer)
    return transcript
Grafana Dashboard:
- Panel 1: ASR latency (p95) over time.
- Panel 2: WER by language.
- Panel 3: Throughput (audio hours/second).
- Panel 4: Error rate (%).
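Assuming the metric names from the Prometheus snippet above, the panels might be backed by PromQL queries along these lines (the `asr_failures_total` counter and the `language` label are assumptions not defined in the snippet):

# Hypothetical Grafana panel -> PromQL query mapping
PANEL_QUERIES = {
    'asr_latency_p95':
        'histogram_quantile(0.95, sum(rate(asr_latency_seconds_bucket[5m])) by (le))',
    'wer_by_language':
        'avg(rate(asr_wer_sum[1h]) / rate(asr_wer_count[1h])) by (language)',
    'request_rate':
        'sum(rate(asr_requests_total[5m]))',
    'error_rate_pct':
        '100 * sum(rate(asr_failures_total[5m])) / sum(rate(asr_requests_total[5m]))',
}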
Implementation: Complete Speech Orchestration Pipeline
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'speech-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=1),
    'execution_timeout': timedelta(hours=2),
}

dag = DAG(
    'speech_processing_pipeline',
    default_args=default_args,
    schedule_interval='@hourly',
    start_date=datetime(2024, 1, 1),
    catchup=False
)

def fetch_audio(**context):
    """Fetch new audio files from S3"""
    # List files uploaded in the last hour
    resp = s3_client.list_objects_v2(Bucket='audio-uploads', Prefix=f"{context['ds']}/")
    audio_paths = [obj['Key'] for obj in resp.get('Contents', [])]
    context['ti'].xcom_push(key='audio_paths', value=audio_paths)
    if not audio_paths:
        return 'skip_processing'  # Branch to end if no files
    return 'preprocess_audio'

def preprocess_audio(**context):
    paths = context['ti'].xcom_pull(task_ids='fetch', key='audio_paths')
    for path in paths:
        # Download, normalize, denoise
        audio = download_and_normalize(path)
        save_for_processing(audio, path)

def voice_activity_detection(**context):
    paths = context['ti'].xcom_pull(task_ids='fetch', key='audio_paths')
    segments = []
    for path in paths:
        audio = load(path)
        file_segments = vad_model.predict(audio)
        segments.append({'file': path, 'segments': file_segments})
    context['ti'].xcom_push(key='segments', value=segments)

def asr_transcription(**context):
    segments_data = context['ti'].xcom_pull(task_ids='voice_activity_detection', key='segments')
    for file_data in segments_data:
        audio = load(file_data['file'])
        transcripts = []
        for start, end in file_data['segments']:
            segment = audio[start:end]
            text, confidence = asr_model.transcribe(segment, return_confidence=True)
            transcripts.append({
                'start': start,
                'end': end,
                'text': text,
                'confidence': confidence
            })
        save_transcripts(file_data['file'], transcripts)

def quality_check(**context):
    """Check if transcriptions meet quality threshold"""
    # Load transcripts for this run
    transcripts = load_all_transcripts(context['ds'])
    if not transcripts:
        print("No transcripts produced for this run")
        return
    avg_confidence = sum(t['confidence'] for t in transcripts) / len(transcripts)
    if avg_confidence < 0.7:
        # Send alert
        send_slack_alert(f"Low confidence transcriptions: {avg_confidence:.2%}")
    print(f"Average confidence: {avg_confidence:.2%}")

# Define tasks
fetch = BranchPythonOperator(task_id='fetch', python_callable=fetch_audio, dag=dag)
preprocess = PythonOperator(task_id='preprocess_audio', python_callable=preprocess_audio, dag=dag)
vad = PythonOperator(task_id='voice_activity_detection', python_callable=voice_activity_detection, dag=dag)
transcribe = PythonOperator(task_id='asr_transcription', python_callable=asr_transcription, dag=dag)
quality = PythonOperator(task_id='quality_check', python_callable=quality_check, dag=dag)
skip = EmptyOperator(task_id='skip_processing', dag=dag)

# Dependencies
fetch >> [preprocess, skip]
preprocess >> vad >> transcribe >> quality
Top Interview Questions
Q1: How do you handle real-time speech processing latency requirements? Answer: Use streaming models (RNN-T, Conformer), process audio in small chunks (80ms), deploy on GPUs for fast inference, use edge computing to reduce network latency, and implement speculative execution.
Q2: What’s the difference between online and offline speech pipeline orchestration? Answer:
- Online (Real-time): Low latency (<500ms), stateful processing, use Kafka/gRPC streaming, K8s for auto-scaling.
- Offline (Batch): Process large audio files, can use expensive models, orchestrated with Airflow/Step Functions.
Q3: How do you handle pipeline failures in production? Answer: Idempotent tasks, automatic retries with exponential backoff (see the sketch below), dead-letter queues for permanent failures, monitoring/alerting (PagerDuty), graceful degradation (fallback to a simpler model).
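A minimal sketch of the retry-with-backoff pattern from Q3 (jitter is added to avoid thundering herds; the `dead_letter` handler is a hypothetical hook, e.g. a publish to a DLQ topic):

import random
import time

def run_with_retries(task, max_retries=3, base_delay=1.0, dead_letter=None):
    """Run `task` with exponential backoff; route permanent failures aside."""
    for attempt in range(max_retries + 1):
        try:
            return task()
        except Exception as exc:
            if attempt == max_retries:
                if dead_letter is not None:
                    dead_letter(exc)  # permanent failure: hand off to the DLQ
                raise
            # Exponential backoff with jitter: ~1s, 2s, 4s, ...
            delay = base_delay * (2 ** attempt) * (0.5 + random.random())
            time.sleep(delay)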
Q4: How do you orchestrate multi-language speech pipelines? Answer: Use language detection as first step, branch to language-specific ASR models, share common preprocessing (VAD, denoising), use multilingual models where possible (reduces pipeline complexity).
Key Takeaways
- DAG Structure: Speech pipelines are DAGs with stages: preprocessing, VAD, diarization, ASR, NLU.
- Real-time vs. Batch: Real-time uses Kafka/K8s streaming, batch uses Airflow orchestration.
- State Management: Essential for streaming speech (maintain context across chunks).
- Quality Gates: Check WER/CER before deploying models.
- Monitoring: Track latency, WER, throughput, error rate.
Summary
| Aspect | Insight |
|---|---|
| Core Challenge | Orchestrate multi-stage speech processing with dependencies |
| Real-time Tools | Kafka, Kubernetes, gRPC streaming |
| Batch Tools | Airflow, AWS Step Functions |
| Key Patterns | Sequential, Parallel (fan-out), Conditional branching |
Originally published at: arunbaby.com/speech-tech/0031-speech-pipeline-orchestration