diff --git a/zaban_backend/app/api/v1/voiceprint.py b/zaban_backend/app/api/v1/voiceprint.py
index 4c3b3ab..f7df806 100644
--- a/zaban_backend/app/api/v1/voiceprint.py
+++ b/zaban_backend/app/api/v1/voiceprint.py
@@ -34,6 +34,11 @@ def get_verifier(request: Request):
     return verifier
 
 
+def get_diarization_service(request: Request):
+    """Dependency to get the diarization service from app state."""
+    return getattr(request.app.state, "diarization_service", None)
+
+
 def _service_unavailable_response() -> JSONResponse:
     """Return a graceful response when voiceprint service is unavailable."""
     if not settings.VOICEPRINT_ENABLED:
@@ -121,7 +126,8 @@ async def verify_voiceprint(
     device_id: Optional[str] = Form(None),
     file: UploadFile = File(None),
     db: Session = Depends(get_db),
-    verifier=Depends(get_verifier)
+    verifier=Depends(get_verifier),
+    diarization_service=Depends(get_diarization_service)
 ):
     """Verify a user's voice against their enrolled voiceprint."""
     if verifier is None:
@@ -147,6 +153,23 @@ async def verify_voiceprint(
         async with await open_file(temp_path, "wb") as f:
             await f.write(audio_content)
 
+        # Diarization check (ensure single speaker)
+        if diarization_service:
+            # Run diarization in executor as it's CPU/GPU heavy
+            import asyncio
+            loop = asyncio.get_running_loop()
+            diarization_result = await loop.run_in_executor(
+                None, diarization_service.diarize, temp_path
+            )
+
+            unique_speakers = len(set(segment["speaker"] for segment in diarization_result))
+            if unique_speakers > 1:
+                return VerificationResponse(
+                    verified=False,
+                    threshold=settings.VERIFICATION_THRESHOLD,
+                    error=f"Multiple speakers detected ({unique_speakers}). Please provide audio with only one speaker."
+                )
+
         # Verify
         result = await verifier.verify_speaker(temp_path, str(customer_id))
 
diff --git a/zaban_backend/app/main.py b/zaban_backend/app/main.py
index 32ea699..336a095 100644
--- a/zaban_backend/app/main.py
+++ b/zaban_backend/app/main.py
@@ -64,6 +64,17 @@ async def startup_event():
     else:
         print("ℹ️ Voiceprint service disabled (VOICEPRINT_ENABLED=false)")
 
+    # Initialize Diarization Service
+    try:
+        from .services.diarization.service import DiarizationService
+        app.state.diarization_service = DiarizationService()
+        print("✅ Diarization service initialized.")
+    except Exception as e:
+        import traceback
+        print(f"⚠️ Diarization service initialization failed: {e}")
+        traceback.print_exc()
+        app.state.diarization_service = None
+
 
 @app.get("/up")
 async def up():
diff --git a/zaban_backend/app/services/diarization/clustering.py b/zaban_backend/app/services/diarization/clustering.py
new file mode 100644
index 0000000..bead967
--- /dev/null
+++ b/zaban_backend/app/services/diarization/clustering.py
@@ -0,0 +1,69 @@
+# app/services/diarization/clustering.py
+
+from sklearn.cluster import AgglomerativeClustering
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+def is_single_speaker(embeddings, threshold=0.75):
+    """
+    Check if all embeddings are similar → single speaker
+    """
+    if len(embeddings) < 2:
+        return True
+
+    sims = cosine_similarity(embeddings)
+
+    # ignore diagonal
+    avg_sim = (np.sum(sims) - len(sims)) / (len(sims)**2 - len(sims))
+
+    return avg_sim > threshold
+
+class SpeakerClustering:
+    def __init__(self, distance_threshold=0.6, single_speaker_threshold=0.7):
+        """
+        distance_threshold: For AgglomerativeClustering (1 - cosine_similarity)
+        single_speaker_threshold: Cosine similarity threshold for Pre-Diarization Decision Layer
+        """
+        self.distance_threshold = distance_threshold
+        self.single_speaker_threshold = single_speaker_threshold
+
+    def cluster(self, segments):
+        """
+        segments: [{start, end, embedding}]
+        """
+        if not segments:
+            return []
+
+        embeddings = np.array([s["embedding"] for s in segments])
+
+        # Normalize embeddings for cosine distance
+        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+        norms[norms == 0] = 1.0  # avoid division by zero
+        embeddings = embeddings / norms
+
+        # Pre-Diarization Decision Layer
+        if is_single_speaker(embeddings, threshold=self.single_speaker_threshold):
+            print(f"Single speaker detected (sim > {self.single_speaker_threshold}). Skipping clustering.")
+            for seg in segments:
+                seg["speaker"] = "speaker_0"
+            return segments
+
+        # Clustering for multiple speakers
+        try:
+            clustering = AgglomerativeClustering(
+                n_clusters=None,
+                distance_threshold=self.distance_threshold,
+                metric="cosine",
+                linkage="average"
+            )
+            labels = clustering.fit_predict(embeddings)
+        except Exception as e:
+            print(f"Clustering failed: {e}. Falling back to single speaker.")
+            for seg in segments:
+                seg["speaker"] = "speaker_0"
+            return segments
+
+        for seg, label in zip(segments, labels):
+            seg["speaker"] = f"speaker_{label}"
+
+        return segments
\ No newline at end of file
diff --git a/zaban_backend/app/services/diarization/config.py b/zaban_backend/app/services/diarization/config.py
new file mode 100644
index 0000000..e69de29
diff --git a/zaban_backend/app/services/diarization/pipeline.py b/zaban_backend/app/services/diarization/pipeline.py
new file mode 100644
index 0000000..3ff986d
--- /dev/null
+++ b/zaban_backend/app/services/diarization/pipeline.py
@@ -0,0 +1,91 @@
+# app/services/diarization/pipeline.py
+
+import torch
+import librosa
+from speechbrain.inference import EncoderClassifier
+from speechbrain.inference.VAD import VAD
+
+
+class DiarizationPipeline:
+    def __init__(self):
+        # Speaker embedding model (ECAPA)
+        self.embedding_model = EncoderClassifier.from_hparams(
+            source="speechbrain/spkrec-ecapa-voxceleb",
+            run_opts={"device": "cpu"}
+        )
+
+        # Voice Activity Detection
+        self.vad = VAD.from_hparams(
+            source="speechbrain/vad-crdnn-libriparty",
+            run_opts={"device": "cpu"}
+        )
+
+    def load_audio(self, path):
+        waveform, sr = librosa.load(path, sr=16000)
+        return torch.tensor(waveform).unsqueeze(0), sr
+
+    def get_speech_segments(self, audio_path):
+        """
+        Returns speech segments [(start, end)]
+        """
+        boundaries = self.vad.get_speech_segments(audio_path)
+        segments = []
+
+        for seg in boundaries:
+            start = float(seg[0])
+            end = float(seg[1])
+            segments.append((start, end))
+        print("VAD output:", boundaries)
+
+        return segments
+
+    def get_embedding(self, waveform):
+        """
+        Extract speaker embedding
+        """
+        with torch.no_grad():
+            emb = self.embedding_model.encode_batch(waveform)
+        return emb.squeeze().numpy()
+
+    def extract_embeddings(self, audio_path, segments):
+        waveform, sr = self.load_audio(audio_path)
+
+        embeddings = []
+
+        CHUNK_DURATION = 1.5  # seconds
+        STRIDE = 0.75  # overlap
+
+        chunk_size = int(CHUNK_DURATION * sr)
+        stride_size = int(STRIDE * sr)
+
+        for start, end in segments:
+            duration = end - start
+            start_sample = int(start * sr)
+            end_sample = int(end * sr)
+            segment_wave = waveform[:, start_sample:end_sample]
+
+            if segment_wave.shape[1] == 0:
+                continue
+
+            # If segment is shorter than chunk_size, take the whole segment
+            if duration < CHUNK_DURATION:
+                emb = self.get_embedding(segment_wave)
+                embeddings.append({
+                    "start": start,
+                    "end": end,
+                    "embedding": emb
+                })
+                continue
+
+            # For longer segments, use sliding window
+            for i in range(0, segment_wave.shape[1] - chunk_size + 1, stride_size):
+                chunk = segment_wave[:, i:i + chunk_size]
+
+                emb = self.get_embedding(chunk)
+                embeddings.append({
+                    "start": start + (i / sr),
+                    "end": start + (min(i + chunk_size, segment_wave.shape[1]) / sr),
+                    "embedding": emb
+                })
+
+        return embeddings
\ No newline at end of file
diff --git a/zaban_backend/app/services/diarization/segmenter.py b/zaban_backend/app/services/diarization/segmenter.py
new file mode 100644
index 0000000..e0057b4
--- /dev/null
+++ b/zaban_backend/app/services/diarization/segmenter.py
@@ -0,0 +1,77 @@
+# app/services/diarization/segmenter.py
+
+class Segmenter:
+    @staticmethod
+    def resolve_overlaps(segments):
+        """
+        Convert overlapping segments into timeline using majority voting
+        """
+        timeline = []
+
+        for seg in segments:
+            timeline.append((seg["start"], "start", seg["speaker"]))
+            timeline.append((seg["end"], "end", seg["speaker"]))
+
+        timeline.sort()
+
+        active = []
+        result = []
+
+        last_time = None
+
+        for time, typ, speaker in timeline:
+            if last_time is not None and active:
+                # pick most frequent speaker
+                speaker_counts = {}
+                for s in active:
+                    speaker_counts[s] = speaker_counts.get(s, 0) + 1
+
+                dominant = max(speaker_counts, key=speaker_counts.get)
+
+                result.append({
+                    "start": last_time,
+                    "end": time,
+                    "speaker": dominant
+                })
+
+            if typ == "start":
+                active.append(speaker)
+            else:
+                if speaker in active:
+                    active.remove(speaker)
+
+            last_time = time
+
+        return result
+
+    @staticmethod
+    def merge_segments(segments, gap_threshold=0.5):
+        if not segments:
+            return []
+
+        segments = sorted(segments, key=lambda x: x["start"])
+        merged = [segments[0]]
+
+        for current in segments[1:]:
+            last = merged[-1]
+
+            if (
+                current["speaker"] == last["speaker"]
+                and (current["start"] - last["end"]) <= gap_threshold
+            ):
+                last["end"] = current["end"]
+            else:
+                merged.append(current)
+
+        return merged
+
+    @staticmethod
+    def format_output(segments):
+        return [
+            {
+                "start": s["start"],
+                "end": s["end"],
+                "speaker": s["speaker"]
+            }
+            for s in segments
+        ]
\ No newline at end of file
diff --git a/zaban_backend/app/services/diarization/service.py b/zaban_backend/app/services/diarization/service.py
new file mode 100644
index 0000000..ba59c8e
--- /dev/null
+++ b/zaban_backend/app/services/diarization/service.py
@@ -0,0 +1,37 @@
+# app/services/diarization/service.py
+
+from .pipeline import DiarizationPipeline
+from .clustering import SpeakerClustering
+from .segmenter import Segmenter
+
+
+class DiarizationService:
+    def __init__(self, distance_threshold=0.6, single_speaker_threshold=0.7):
+        self.pipeline = DiarizationPipeline()
+        self.clustering = SpeakerClustering(
+            distance_threshold=distance_threshold,
+            single_speaker_threshold=single_speaker_threshold
+        )
+
+    def diarize(self, audio_path: str):
+        """
+        Full diarization pipeline
+        """
+
+        # Step 1: VAD → segments
+        segments = self.pipeline.get_speech_segments(audio_path)
+
+        # Step 2: embeddings
+        segments_with_embeddings = self.pipeline.extract_embeddings(
+            audio_path, segments
+        )
+
+        # Step 3: clustering → assign speakers
+        clustered = self.clustering.cluster(segments_with_embeddings)
+
+        # Step 4: merge segments
+        resolved = Segmenter.resolve_overlaps(clustered)
+        merged = Segmenter.merge_segments(resolved)
+
+        # Step 5: format output
+        return Segmenter.format_output(merged)
\ No newline at end of file
diff --git a/zaban_backend/app/services/diarization/utils/audio.py b/zaban_backend/app/services/diarization/utils/audio.py
new file mode 100644
index 0000000..e69de29
diff --git a/zaban_backend/app/services/diarization/utils/embeddings.py b/zaban_backend/app/services/diarization/utils/embeddings.py
new file mode 100644
index 0000000..e69de29