25 changes: 24 additions & 1 deletion zaban_backend/app/api/v1/voiceprint.py
@@ -34,6 +34,11 @@ def get_verifier(request: Request):
return verifier


def get_diarization_service(request: Request):
"""Dependency to get the diarization service from app state."""
return getattr(request.app.state, "diarization_service", None)


def _service_unavailable_response() -> JSONResponse:
"""Return a graceful response when voiceprint service is unavailable."""
if not settings.VOICEPRINT_ENABLED:
@@ -121,7 +126,8 @@ async def verify_voiceprint(
device_id: Optional[str] = Form(None),
file: UploadFile = File(None),
db: Session = Depends(get_db),
verifier=Depends(get_verifier)
verifier=Depends(get_verifier),
diarization_service=Depends(get_diarization_service)
):
"""Verify a user's voice against their enrolled voiceprint."""
if verifier is None:
@@ -147,6 +153,23 @@
async with await open_file(temp_path, "wb") as f:
await f.write(audio_content)

# Diarization check (ensure single speaker)
if diarization_service:
# Run diarization in executor as it's CPU/GPU heavy
import asyncio
loop = asyncio.get_running_loop()
diarization_result = await loop.run_in_executor(
None, diarization_service.diarize, temp_path
)

unique_speakers = len(set(segment["speaker"] for segment in diarization_result))
if unique_speakers > 1:
return VerificationResponse(
verified=False,
threshold=settings.VERIFICATION_THRESHOLD,
error=f"Multiple speakers detected ({unique_speakers}). Please provide audio with only one speaker."
)

# Verify
result = await verifier.verify_speaker(temp_path, str(customer_id))

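Note for reviewers: the gate above only assumes that diarize() returns a list of {"start", "end", "speaker"} dicts (see Segmenter.format_output below). A minimal sketch of the check in isolation, with invented timings and labels:

# Illustrative diarization output; values are made up for the example.
diarization_result = [
    {"start": 0.0, "end": 3.1, "speaker": "speaker_0"},
    {"start": 3.4, "end": 5.0, "speaker": "speaker_1"},
]
unique_speakers = len({segment["speaker"] for segment in diarization_result})
# unique_speakers == 2 here, so the endpoint would return the "multiple speakers" rejection.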
11 changes: 11 additions & 0 deletions zaban_backend/app/main.py
@@ -64,6 +64,17 @@ async def startup_event():
else:
print("ℹ️ Voiceprint service disabled (VOICEPRINT_ENABLED=false)")

# Initialize Diarization Service
try:
from .services.diarization.service import DiarizationService
app.state.diarization_service = DiarizationService()
print("✅ Diarization service initialized.")
except Exception as e:
import traceback
print(f"⚠️ Diarization service initialization failed: {e}")
traceback.print_exc()
app.state.diarization_service = None


@app.get("/up")
async def up():
69 changes: 69 additions & 0 deletions zaban_backend/app/services/diarization/clustering.py
@@ -0,0 +1,69 @@
# app/services/diarization/clustering.py

from sklearn.cluster import AgglomerativeClustering
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def is_single_speaker(embeddings, threshold=0.75):
"""
Check if all embeddings are similar → single speaker
"""
if len(embeddings) < 2:
return True

sims = cosine_similarity(embeddings)

    # average similarity over the off-diagonal entries (each diagonal entry is 1)
    n = len(sims)
    avg_sim = (np.sum(sims) - n) / (n ** 2 - n)

return avg_sim > threshold

class SpeakerClustering:
def __init__(self, distance_threshold=0.6, single_speaker_threshold=0.7):
"""
distance_threshold: For AgglomerativeClustering (1 - cosine_similarity)
single_speaker_threshold: Cosine similarity threshold for Pre-Diarization Decision Layer
"""
self.distance_threshold = distance_threshold
self.single_speaker_threshold = single_speaker_threshold

def cluster(self, segments):
"""
segments: [{start, end, embedding}]
"""
if not segments:
return []

embeddings = np.array([s["embedding"] for s in segments])

# Normalize embeddings for cosine distance
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1.0 # avoid division by zero
embeddings = embeddings / norms

# Pre-Diarization Decision Layer
if is_single_speaker(embeddings, threshold=self.single_speaker_threshold):
print(f"Single speaker detected (sim > {self.single_speaker_threshold}). Skipping clustering.")
for seg in segments:
seg["speaker"] = "speaker_0"
return segments

# Clustering for multiple speakers
try:
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=self.distance_threshold,
metric="cosine",
linkage="average"
)
labels = clustering.fit_predict(embeddings)
except Exception as e:
print(f"Clustering failed: {e}. Falling back to single speaker.")
for seg in segments:
seg["speaker"] = "speaker_0"
return segments

for seg, label in zip(segments, labels):
seg["speaker"] = f"speaker_{label}"

return segments
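For context, a minimal, self-contained sketch of exercising SpeakerClustering on its own; the 2-D vectors below are synthetic stand-ins for real ECAPA embeddings, and the import path assumes the package layout introduced in this PR:

import numpy as np

from app.services.diarization.clustering import SpeakerClustering

# Two nearly identical vectors (same speaker) and one orthogonal vector (different speaker).
segments = [
    {"start": 0.0, "end": 1.5, "embedding": np.array([1.0, 0.0])},
    {"start": 1.5, "end": 3.0, "embedding": np.array([0.95, 0.05])},
    {"start": 3.0, "end": 4.5, "embedding": np.array([0.0, 1.0])},
]

clustering = SpeakerClustering(distance_threshold=0.6, single_speaker_threshold=0.7)
for seg in clustering.cluster(segments):
    print(seg["start"], seg["end"], seg["speaker"])  # expect two distinct speaker labels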
91 changes: 91 additions & 0 deletions zaban_backend/app/services/diarization/pipeline.py
@@ -0,0 +1,91 @@
# app/services/diarization/pipeline.py

import torch
import librosa
from speechbrain.inference import EncoderClassifier
from speechbrain.inference.VAD import VAD


class DiarizationPipeline:
def __init__(self):
# Speaker embedding model (ECAPA)
self.embedding_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
run_opts={"device": "cpu"}
)

# Voice Activity Detection
self.vad = VAD.from_hparams(
source="speechbrain/vad-crdnn-libriparty",
run_opts={"device": "cpu"}
)

def load_audio(self, path):
waveform, sr = librosa.load(path, sr=16000)
return torch.tensor(waveform).unsqueeze(0), sr

def get_speech_segments(self, audio_path):
"""
Returns speech segments [(start, end)]
"""
boundaries = self.vad.get_speech_segments(audio_path)
segments = []

for seg in boundaries:
start = float(seg[0])
end = float(seg[1])
segments.append((start, end))
print("VAD output:", boundaries)

return segments

def get_embedding(self, waveform):
"""
Extract speaker embedding
"""
with torch.no_grad():
emb = self.embedding_model.encode_batch(waveform)
return emb.squeeze().numpy()

def extract_embeddings(self, audio_path, segments):
waveform, sr = self.load_audio(audio_path)

embeddings = []

        CHUNK_DURATION = 1.5  # seconds per embedding window
        STRIDE = 0.75  # hop between window starts (50% overlap)

chunk_size = int(CHUNK_DURATION * sr)
stride_size = int(STRIDE * sr)

for start, end in segments:
duration = end - start
start_sample = int(start * sr)
end_sample = int(end * sr)
segment_wave = waveform[:, start_sample:end_sample]

if segment_wave.shape[1] == 0:
continue

# If segment is shorter than chunk_size, take the whole segment
if duration < CHUNK_DURATION:
emb = self.get_embedding(segment_wave)
embeddings.append({
"start": start,
"end": end,
"embedding": emb
})
continue

# For longer segments, use sliding window
for i in range(0, segment_wave.shape[1] - chunk_size + 1, stride_size):
chunk = segment_wave[:, i:i + chunk_size]

emb = self.get_embedding(chunk)
embeddings.append({
"start": start + (i / sr),
"end": start + (min(i + chunk_size, segment_wave.shape[1]) / sr),
"embedding": emb
})

return embeddings
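As a reference for the data flowing between the two stages, a short sketch of the pipeline used stand-alone; "sample.wav" is a placeholder path, and the (192,) shape reflects the ECAPA model's embedding size:

from app.services.diarization.pipeline import DiarizationPipeline

pipeline = DiarizationPipeline()

# Stage 1: VAD produces (start, end) tuples in seconds.
speech_segments = pipeline.get_speech_segments("sample.wav")

# Stage 2: each speech segment is split into ~1.5 s windows with a 0.75 s hop;
# each window yields {"start", "end", "embedding"} with a (192,) numpy vector.
chunks = pipeline.extract_embeddings("sample.wav", speech_segments)
if chunks:
    print(len(chunks), chunks[0]["embedding"].shape)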
77 changes: 77 additions & 0 deletions zaban_backend/app/services/diarization/segmenter.py
@@ -0,0 +1,77 @@
# app/services/diarization/segmenter.py

class Segmenter:
@staticmethod
def resolve_overlaps(segments):
"""
Convert overlapping segments into timeline using majority voting
"""
timeline = []

for seg in segments:
timeline.append((seg["start"], "start", seg["speaker"]))
timeline.append((seg["end"], "end", seg["speaker"]))

timeline.sort()

active = []
result = []

last_time = None

for time, typ, speaker in timeline:
if last_time is not None and active:
# pick most frequent speaker
speaker_counts = {}
for s in active:
speaker_counts[s] = speaker_counts.get(s, 0) + 1

dominant = max(speaker_counts, key=speaker_counts.get)

result.append({
"start": last_time,
"end": time,
"speaker": dominant
})

if typ == "start":
active.append(speaker)
else:
if speaker in active:
active.remove(speaker)

last_time = time

return result

@staticmethod
def merge_segments(segments, gap_threshold=0.5):
if not segments:
return []

segments = sorted(segments, key=lambda x: x["start"])
merged = [segments[0]]

for current in segments[1:]:
last = merged[-1]

if (
current["speaker"] == last["speaker"]
and (current["start"] - last["end"]) <= gap_threshold
):
last["end"] = current["end"]
else:
merged.append(current)

return merged

@staticmethod
def format_output(segments):
return [
{
"start": s["start"],
"end": s["end"],
"speaker": s["speaker"]
}
for s in segments
]
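A toy walk-through of the two Segmenter steps, with invented timings, to show what "resolve then merge" produces:

from app.services.diarization.segmenter import Segmenter

# Two overlapping windows from the same speaker, then a second speaker after a short gap.
windows = [
    {"start": 0.0, "end": 1.5, "speaker": "speaker_0"},
    {"start": 0.75, "end": 2.25, "speaker": "speaker_0"},
    {"start": 2.5, "end": 4.0, "speaker": "speaker_1"},
]

timeline = Segmenter.resolve_overlaps(windows)  # non-overlapping intervals, dominant speaker each
merged = Segmenter.merge_segments(timeline, gap_threshold=0.5)
print(Segmenter.format_output(merged))
# -> [{'start': 0.0, 'end': 2.25, 'speaker': 'speaker_0'}, {'start': 2.5, 'end': 4.0, 'speaker': 'speaker_1'}]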
37 changes: 37 additions & 0 deletions zaban_backend/app/services/diarization/service.py
@@ -0,0 +1,37 @@
# app/services/diarization/service.py

from .pipeline import DiarizationPipeline
from .clustering import SpeakerClustering
from .segmenter import Segmenter


class DiarizationService:
def __init__(self, distance_threshold=0.6, single_speaker_threshold=0.7):
self.pipeline = DiarizationPipeline()
self.clustering = SpeakerClustering(
distance_threshold=distance_threshold,
single_speaker_threshold=single_speaker_threshold
)

def diarize(self, audio_path: str):
"""
Full diarization pipeline
"""

# Step 1: VAD → segments
segments = self.pipeline.get_speech_segments(audio_path)

# Step 2: embeddings
segments_with_embeddings = self.pipeline.extract_embeddings(
audio_path, segments
)

# Step 3: clustering → assign speakers
clustered = self.clustering.cluster(segments_with_embeddings)

# Step 4: merge segments
resolved = Segmenter.resolve_overlaps(clustered)
merged = Segmenter.merge_segments(resolved)

# Step 5: format output
return Segmenter.format_output(merged)
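End to end, this is the call the verify endpoint makes in an executor; a minimal usage sketch with a placeholder path:

from app.services.diarization.service import DiarizationService

service = DiarizationService()
segments = service.diarize("recording.wav")  # [{"start": ..., "end": ..., "speaker": "speaker_N"}, ...]

unique_speakers = {seg["speaker"] for seg in segments}
print(f"{len(unique_speakers)} speaker(s) detected")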