|
|
@@ -0,0 +1,760 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""
|
|
|
+embeddings.py — Speaker Voiceprint Management
|
|
|
+
|
|
|
+Handles extraction, storage, and matching of speaker voice embeddings
|
|
|
+using pyannote.audio's SpeakerEmbedding pipeline.
|
|
|
+
|
|
|
+Embeddings are stored as numpy .npy files in bridge/embeddings/.
|
|
|
+speakers.json is updated with has_embedding / embedding_updated metadata.
|
|
|
+
|
|
|
+Designed for future migration to a remote vector DB (pgvector, Qdrant, etc.)
|
|
|
+— each embedding is a 512-dim float32 array, stored as a .npy binary file.
|
|
|
+
|
|
|
+Public API
|
|
|
+----------
|
|
|
+ registry = EmbeddingRegistry()
|
|
|
+
|
|
|
+ # Extract from an audio file segment (e.g. 10s snip from a recording)
|
|
|
+ embedding = registry.extract(audio_path, start_sec=45.0, end_sec=55.0)
|
|
|
+
|
|
|
+ # Save a voiceprint for a known speaker
|
|
|
+ registry.save(speaker_id="SPEAKER_00", embedding=embedding)
|
|
|
+
|
|
|
+ # Check if a speaker has a stored voiceprint
|
|
|
+ registry.has(speaker_id="SPEAKER_00") # -> bool
|
|
|
+
|
|
|
+ # Find the best matching speaker for a live embedding
|
|
|
+ match = registry.find_match(live_embedding, threshold=0.82)
|
|
|
+ # -> ("SPEAKER_00", 0.91) or None
|
|
|
+
|
|
|
+ # Load a stored embedding directly
|
|
|
+ emb = registry.load("SPEAKER_00") # -> np.ndarray | None
|
|
|
+
|
|
|
+ # Delete a voiceprint
|
|
|
+ registry.delete("SPEAKER_00")
|
|
|
+
|
|
|
+ # List all speakers with voiceprints
|
|
|
+ registry.list_enrolled() # -> [{"id": ..., "updated": ...}, ...]
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import os
|
|
|
+from datetime import datetime, timezone
|
|
|
+from pathlib import Path
|
|
|
+from typing import Optional
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+# ── Paths ──────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
# On-disk defaults, all located next to this module (the bridge/ directory).
_HERE = Path(__file__).parent
EMBEDDINGS_DIR = _HERE.joinpath("embeddings")
SPEAKERS_FILE = _HERE.joinpath("speakers.json")

# Length of the vectors produced by the pyannote speaker-embedding model.
EMBEDDING_DIM = 512

# Minimum cosine similarity for automatically assigning a name to a speaker.
# 0.82 is conservative (fewer false positives); 0.75 is more permissive.
DEFAULT_THRESHOLD = 0.82

# Segments shorter than this many seconds are considered too short for a
# reliable embedding.
MIN_AUDIO_SEC = 5.0
|
|
|
+
|
|
|
+
|
|
|
+# ── Lazy model loader ──────────────────────────────────────────────────────────
|
|
|
+
|
|
|
_pipeline = None  # cached pyannote Inference object; set by the first successful load


def _get_pipeline():
    """
    Return the pyannote speaker-embedding inference object, loading it on first use.

    The "pyannote/embedding" model is loaded once, wrapped in
    ``Inference(model, window="whole")`` and cached in the module-level
    ``_pipeline``; subsequent calls return the cached object.

    Authentication: the HF_TOKEN environment variable is used when it is set.
    When it is not set, ``use_auth_token=True`` is passed instead so that a
    token stored by ``huggingface-cli login`` can be picked up.

    Raises:
        RuntimeError: if pyannote.audio is not installed, or the model
            cannot be loaded (missing credentials, model conditions not
            accepted, network failure, ...).
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    try:
        from pyannote.audio import Inference, Model
    except ImportError as exc:
        raise RuntimeError(
            "pyannote.audio is required for voiceprint extraction. "
            "Install with: pip install pyannote.audio"
        ) from exc

    # An explicit token wins; otherwise ``True`` asks the Hugging Face hub
    # client to use the credentials cached by `huggingface-cli login`.
    auth_token = os.environ.get("HF_TOKEN") or True

    logger.info("[Embeddings] Loading pyannote speaker embedding model...")
    try:
        model = Model.from_pretrained(
            "pyannote/embedding",
            use_auth_token=auth_token,
        )
        if model is None:
            # Guard against a loader that reports failure by returning None
            # instead of raising.
            raise RuntimeError("Model.from_pretrained returned no model")
        _pipeline = Inference(model, window="whole")
        logger.info("[Embeddings] Speaker embedding model loaded.")
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load pyannote/embedding model: {exc}\n"
            "Ensure you have accepted the model conditions at "
            "https://huggingface.co/pyannote/embedding "
            "and that HF_TOKEN is set or you have run: huggingface-cli login"
        ) from exc

    return _pipeline
|
|
|
+
|
|
|
+
|
|
|
+# ── Helpers ────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """
    Compute the cosine similarity of two embeddings.

    Both inputs are flattened to 1-D and converted to float32 first.
    Returns 0.0 when either vector has a zero norm; otherwise the dot
    product divided by the product of the two norms, as a Python float.
    """
    vec_a = a.flatten().astype(np.float32)
    vec_b = b.flatten().astype(np.float32)
    magnitudes = (np.linalg.norm(vec_a), np.linalg.norm(vec_b))
    if not all(magnitudes):
        # A zero vector has no direction; report "no similarity".
        return 0.0
    return float(np.dot(vec_a, vec_b) / (magnitudes[0] * magnitudes[1]))
|
|
|
+
|
|
|
+
|
|
|
def _load_speakers_json(path: Optional[Path] = None) -> dict:
    """
    Read the speakers metadata file and return its top-level JSON object.

    Both entry formats are passed through untouched: the legacy flat form
    ``{"SPEAKER_00": "Name"}`` and the rich form
    ``{"SPEAKER_00": {"name": "...", ...}}``.

    Args:
        path: File to read; defaults to the module-level SPEAKERS_FILE.

    Returns:
        The parsed dict, or an empty dict when the file is missing,
        unreadable, not valid UTF-8 / JSON, or does not contain a JSON object.
    """
    target = SPEAKERS_FILE if path is None else Path(path)
    try:
        raw = json.loads(target.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # Nothing has been written yet - a normal state, not worth a warning.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"[Embeddings] Could not read {target}: {exc}")
        return {}
    return raw if isinstance(raw, dict) else {}
|
|
|
+
|
|
|
+
|
|
|
def _save_speakers_json(data: dict, path: Optional[Path] = None) -> None:
    """
    Write the speakers metadata dict as pretty-printed UTF-8 JSON.

    The JSON is written to a sibling ``<name>.tmp`` file first and then moved
    over the destination with ``os.replace``, so an interrupted write cannot
    leave a truncated speakers file behind.

    Args:
        data: Mapping to serialise (must be JSON-serialisable).
        path: Destination file; defaults to the module-level SPEAKERS_FILE.

    I/O failures are logged and swallowed (best effort); serialisation errors
    from ``json.dumps`` propagate to the caller.
    """
    target = SPEAKERS_FILE if path is None else Path(path)
    # Serialise before touching the disk so bad data never clobbers the file.
    payload = json.dumps(data, indent=2, ensure_ascii=False)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        os.replace(tmp_path, target)
    except OSError as exc:
        logger.error(f"[Embeddings] Could not save speakers.json: {exc}")
        # Do not leave a stray temp file next to the real one.
        try:
            tmp_path.unlink()
        except OSError:
            pass
|
|
|
+
|
|
|
+
|
|
|
def _normalise_speaker_entry(entry) -> dict:
    """
    Return a speakers.json entry in rich-dict format.

    - A dict is copied and any missing standard key is filled with its
      default; existing keys (including extra ones) and their order are kept.
      The caller's dict is never modified.
    - A legacy flat string such as "Pastor John" becomes
      ``{"name": "Pastor John", ...defaults}``.
    - Any other value is converted with ``str()`` and used as the name.

    Returns:
        A new dict containing at least the keys name, has_embedding,
        embedding_updated, colour and notes.
    """
    defaults = {
        "name": "",
        "has_embedding": False,
        "embedding_updated": None,
        "colour": None,
        "notes": "",
    }
    if isinstance(entry, dict):
        # Work on a shallow copy so normalising has no side effects on the
        # mapping the entry was read from.
        normalised = dict(entry)
        for key, value in defaults.items():
            normalised.setdefault(key, value)
        return normalised

    defaults["name"] = entry if isinstance(entry, str) else str(entry)
    return defaults
|
|
|
+
|
|
|
+
|
|
|
def _update_speaker_meta(speaker_id: str, has_embedding: bool) -> None:
    """
    Record in speakers.json whether ``speaker_id`` currently has a voiceprint.

    The speaker's entry is normalised to the rich-dict format, its
    ``has_embedding`` flag is set, and ``embedding_updated`` becomes the
    current UTC time in ISO-8601 form when a voiceprint exists, or None when
    it was removed. The whole file is then written back.
    """
    speakers = _load_speakers_json()
    # Unknown speakers fall back to using their id as the display name.
    record = _normalise_speaker_entry(speakers.get(speaker_id, speaker_id))

    stamp = None
    if has_embedding:
        stamp = datetime.now(timezone.utc).isoformat()

    record["has_embedding"] = has_embedding
    record["embedding_updated"] = stamp
    speakers[speaker_id] = record
    _save_speakers_json(speakers)
|
|
|
+
|
|
|
+
|
|
|
+# ── Audio extraction helper ────────────────────────────────────────────────────
|
|
|
+
|
|
|
def _slice_pcm_chunks(chunks, skip_frames: int, frames_needed: int) -> bytes:
    """Return up to ``frames_needed`` mono int16 frames from ``chunks`` after dropping the first ``skip_frames``."""
    collected: list[bytes] = []
    to_skip = skip_frames
    to_take = frames_needed

    for chunk in chunks:
        data = bytes(chunk)
        n_frames = len(data) // 2  # int16 mono = 2 bytes per frame

        # Drop audio that lies before the requested start.
        if to_skip > 0:
            if n_frames <= to_skip:
                to_skip -= n_frames
                continue
            data = data[to_skip * 2:]
            n_frames = len(data) // 2
            to_skip = 0

        # Take what is still needed and stop as soon as the quota is met.
        if n_frames >= to_take:
            collected.append(data[: to_take * 2])
            break
        collected.append(data)
        to_take -= n_frames

    return b"".join(collected)


def extract_audio_segment(
    audio_path: Path | str,
    start_sec: float,
    end_sec: float,
    out_path: Optional[Path] = None,
) -> Path:
    """
    Decode part of an audio file with miniaudio and save it as a WAV file.

    The source is decoded as 16 kHz, mono, signed 16-bit PCM; the frames
    between ``start_sec`` and ``end_sec`` are written to the output file.
    If the source ends before ``end_sec`` the output is simply shorter.

    Args:
        audio_path: Source audio file (MP3, WAV, FLAC, OGG, M4A, etc.)
        start_sec:  Start time in seconds
        end_sec:    End time in seconds (must be greater than ``start_sec``)
        out_path:   Optional explicit output path; if None, a temp file is used

    Returns:
        Path to the extracted WAV file (caller should delete when done if temp)

    Raises:
        ValueError:   if the segment has no positive duration, or no audio
                      could be read for the requested range
        RuntimeError: if miniaudio is not installed
    """
    import tempfile
    import wave

    duration = end_sec - start_sec
    if duration <= 0:
        # A non-positive duration would otherwise turn into a negative slice
        # length below and silently produce garbage or an empty WAV.
        raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")

    try:
        import miniaudio
    except ImportError as exc:
        raise RuntimeError("miniaudio is required: pip install miniaudio") from exc

    if duration < MIN_AUDIO_SEC:
        logger.warning(
            f"[Embeddings] Short segment ({duration:.1f}s < {MIN_AUDIO_SEC}s) "
            "— embedding may be less reliable"
        )

    sample_rate = 16000
    n_channels = 1

    # Decode the file chunk by chunk; skipping and trimming happen in the helper.
    stream = miniaudio.stream_file(
        str(audio_path),
        output_format=miniaudio.SampleFormat.SIGNED16,
        nchannels=n_channels,
        sample_rate=sample_rate,
        frames_to_read=4096,
    )
    try:
        pcm_data = _slice_pcm_chunks(
            stream,
            skip_frames=int(start_sec * sample_rate),
            frames_needed=int(duration * sample_rate),
        )
    finally:
        # Release the decoder promptly instead of waiting for garbage
        # collection (we usually stop reading before the stream is exhausted).
        close_stream = getattr(stream, "close", None)
        if close_stream is not None:
            close_stream()

    if not pcm_data:
        raise ValueError(
            f"No audio extracted from {audio_path} at {start_sec}–{end_sec}s. "
            "Check that the file is valid and the time range is within its duration."
        )

    # Write to WAV
    created_temp = out_path is None
    if created_temp:
        tmp = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False, prefix="embedding_"
        )
        out_path = Path(tmp.name)
        tmp.close()
    out_path = Path(out_path)

    try:
        with wave.open(str(out_path), "wb") as wf:
            wf.setnchannels(n_channels)
            wf.setsampwidth(2)  # int16
            wf.setframerate(sample_rate)
            wf.writeframes(pcm_data)
    except Exception:
        # Do not leak the temp file we created if the write fails.
        if created_temp:
            try:
                out_path.unlink()
            except OSError:
                pass
        raise

    return out_path
|
|
|
+
|
|
|
+
|
|
|
+# ── Main registry class ────────────────────────────────────────────────────────
|
|
|
+
|
|
|
class EmbeddingRegistry:
    """
    Manages speaker voice embeddings stored on disk as ``<speaker_id>.npy``.

    Each registry works on its own ``embeddings_dir`` and keeps the
    has_embedding / embedding_updated metadata in its own ``speakers_file``.

    No locking is performed: write operations (save / delete) are expected
    to be called from a single thread (the admin server) in normal use.
    """

    def __init__(
        self,
        embeddings_dir: Path | str = EMBEDDINGS_DIR,
        speakers_file: Path | str = SPEAKERS_FILE,
        threshold: float = DEFAULT_THRESHOLD,
    ):
        """
        Args:
            embeddings_dir: Directory holding the .npy files (created if missing)
            speakers_file:  JSON file holding per-speaker metadata
            threshold:      Default cosine-similarity threshold for matching
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.speakers_file = Path(speakers_file)
        self.threshold = threshold
        self.embeddings_dir.mkdir(parents=True, exist_ok=True)

    def _path(self, speaker_id: str) -> Path:
        """Return the .npy path used for ``speaker_id``."""
        return self.embeddings_dir / f"{speaker_id}.npy"

    # ── Speaker metadata (this registry's speakers_file) ───────────────────────

    def _load_speakers(self) -> Optional[dict]:
        """
        Read this registry's speakers file.

        Returns the parsed dict, ``{}`` when the file does not exist, or
        ``None`` when it exists but cannot be read as a JSON object.
        """
        try:
            raw = json.loads(self.speakers_file.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(
                f"[Embeddings] Could not read {self.speakers_file}: {exc}"
            )
            return None
        return raw if isinstance(raw, dict) else None

    def _update_meta(self, speaker_id: str, has_embedding: bool) -> None:
        """Set has_embedding / embedding_updated for a speaker in speakers_file."""
        data = self._load_speakers()
        if data is None:
            # Rewriting an unreadable file would discard every other speaker.
            logger.error(
                f"[Embeddings] {self.speakers_file} is unreadable; "
                f"metadata for {speaker_id} was not updated"
            )
            return
        entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
        entry["has_embedding"] = has_embedding
        entry["embedding_updated"] = (
            datetime.now(timezone.utc).isoformat() if has_embedding else None
        )
        data[speaker_id] = entry
        try:
            self.speakers_file.write_text(
                json.dumps(data, indent=2, ensure_ascii=False),
                encoding="utf-8",
            )
        except OSError as exc:
            logger.error(f"[Embeddings] Could not save speakers.json: {exc}")

    # ── Core storage ───────────────────────────────────────────────────────────

    def save(self, speaker_id: str, embedding: np.ndarray) -> None:
        """
        Persist a voiceprint embedding for a speaker.

        The embedding is flattened and stored as float32. A size other than
        EMBEDDING_DIM is logged as a warning but still saved.
        """
        emb = np.asarray(embedding).flatten().astype(np.float32)
        if emb.shape[0] != EMBEDDING_DIM:
            logger.warning(
                f"[Embeddings] Unexpected embedding dim {emb.shape[0]} "
                f"(expected {EMBEDDING_DIM}) for {speaker_id}"
            )
        np.save(str(self._path(speaker_id)), emb)
        self._update_meta(speaker_id, has_embedding=True)
        logger.info(f"[Embeddings] Saved voiceprint for {speaker_id}")

    def load(self, speaker_id: str) -> Optional[np.ndarray]:
        """Load a stored embedding as float32, or return None if missing/unreadable."""
        p = self._path(speaker_id)
        if not p.exists():
            return None
        try:
            return np.load(str(p)).astype(np.float32)
        except Exception as exc:
            logger.error(f"[Embeddings] Failed to load {p}: {exc}")
            return None

    def has(self, speaker_id: str) -> bool:
        """Return True if a voiceprint file exists for this speaker."""
        return self._path(speaker_id).exists()

    def delete(self, speaker_id: str) -> bool:
        """Delete a stored voiceprint. Returns True if it existed."""
        try:
            self._path(speaker_id).unlink()
        except FileNotFoundError:
            return False
        self._update_meta(speaker_id, has_embedding=False)
        logger.info(f"[Embeddings] Deleted voiceprint for {speaker_id}")
        return True

    def list_enrolled(self) -> list[dict]:
        """
        Return one dict per stored voiceprint, sorted by file name.

        Each dict has the keys id, name, updated and size_kb.
        """
        # Read the metadata once rather than once per voiceprint file.
        data = self._load_speakers() or {}
        enrolled = []
        for p in sorted(self.embeddings_dir.glob("*.npy")):
            speaker_id = p.stem
            entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
            enrolled.append({
                "id": speaker_id,
                "name": entry["name"],
                "updated": entry.get("embedding_updated"),
                "size_kb": round(p.stat().st_size / 1024, 1),
            })
        return enrolled

    # ── Extraction ─────────────────────────────────────────────────────────────

    def extract(
        self,
        audio_path: Path | str,
        start_sec: float = 0.0,
        end_sec: Optional[float] = None,
    ) -> np.ndarray:
        """
        Extract a speaker embedding from an audio file segment.

        Args:
            audio_path: Path to audio file (MP3, WAV, FLAC, M4A, OGG, etc.)
            start_sec:  Start of segment in seconds (default: 0)
            end_sec:    End of segment in seconds (default: the file's
                        duration; start + 30s if the duration is unavailable)

        Returns:
            1-D float32 numpy array — the speaker embedding

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist
            RuntimeError: if the pyannote model (or miniaudio) is unavailable
            ValueError: if the segment is empty or has no positive duration
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # If no end time, get file duration and use whole file
        if end_sec is None:
            try:
                import miniaudio
                info = miniaudio.get_file_info(str(audio_path))
                end_sec = info.duration
            except Exception as exc:
                end_sec = start_sec + 30.0  # fallback
                logger.warning(
                    f"[Embeddings] Could not determine duration of "
                    f"{audio_path.name} ({exc}); using {end_sec:.1f}s as end"
                )

        duration = end_sec - start_sec
        if duration <= 0:
            raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")

        # Extract segment to a temp WAV
        tmp_wav = None
        try:
            tmp_wav = extract_audio_segment(audio_path, start_sec, end_sec)
            pipeline = _get_pipeline()

            # The inference object is called with a file path.
            embedding = pipeline(str(tmp_wav))

            # The result may be a plain ndarray or a wrapper exposing ``.data``.
            # ndarray must be tested first: it has a ``.data`` attribute too
            # (its raw buffer), which would otherwise shadow this branch.
            if isinstance(embedding, np.ndarray):
                arr = embedding.flatten()
            elif hasattr(embedding, "data"):
                arr = np.array(embedding.data).flatten()
            else:
                arr = np.array(embedding).flatten()

            logger.info(
                f"[Embeddings] Extracted embedding from {audio_path.name} "
                f"[{start_sec:.1f}s – {end_sec:.1f}s] → shape {arr.shape}"
            )
            return arr.astype(np.float32)

        finally:
            if tmp_wav and tmp_wav.exists():
                try:
                    tmp_wav.unlink()
                except OSError:
                    pass

    def extract_and_save(
        self,
        speaker_id: str,
        audio_path: Path | str,
        start_sec: float = 0.0,
        end_sec: Optional[float] = None,
    ) -> np.ndarray:
        """
        Extract embedding from audio segment and immediately save it.
        Convenience wrapper around extract() + save(); returns the embedding.
        """
        embedding = self.extract(audio_path, start_sec, end_sec)
        self.save(speaker_id, embedding)
        return embedding

    # ── Matching ───────────────────────────────────────────────────────────────

    def _iter_scores(self, live_embedding: np.ndarray, exclude_ids=frozenset()):
        """Yield (speaker_id, cosine similarity) for every comparable stored voiceprint."""
        live = np.asarray(live_embedding).flatten()
        for p in sorted(self.embeddings_dir.glob("*.npy")):
            speaker_id = p.stem
            if speaker_id in exclude_ids:
                continue
            stored = self.load(speaker_id)
            if stored is None:
                continue
            if stored.size != live.size:
                # save() only warns about odd sizes, so such files can exist;
                # comparing them would raise inside np.dot.
                logger.warning(
                    f"[Embeddings] Skipping {speaker_id}: stored dim "
                    f"{stored.size} != live dim {live.size}"
                )
                continue
            yield speaker_id, _cosine_similarity(live, stored)

    def find_match(
        self,
        live_embedding: np.ndarray,
        threshold: Optional[float] = None,
        exclude_ids: Optional[set[str]] = None,
    ) -> Optional[tuple[str, float]]:
        """
        Find the best matching enrolled speaker for a live embedding.

        Args:
            live_embedding: Embedding from live audio
            threshold:      Cosine similarity threshold (default: self.threshold)
            exclude_ids:    Speaker IDs to skip (already confirmed this session)

        Returns:
            (speaker_id, similarity_score) if the best score is at or above
            the threshold, or None if no match.

        Notes:
            - Similarity is cosine similarity; higher = more similar
            - Stored voiceprints whose size differs from the live embedding
              are skipped with a warning
            - Threshold 0.82 is conservative; lower to 0.75 for more permissive
        """
        if threshold is None:
            threshold = self.threshold

        best_id = None
        best_score = -1.0

        for speaker_id, score in self._iter_scores(
            live_embedding, exclude_ids or frozenset()
        ):
            logger.debug(f"[Embeddings] {speaker_id}: similarity={score:.4f}")
            if score > best_score:
                best_score = score
                best_id = speaker_id

        if best_id is not None and best_score >= threshold:
            logger.info(
                f"[Embeddings] Match: {best_id} (similarity={best_score:.4f})"
            )
            return (best_id, best_score)

        logger.debug(
            f"[Embeddings] No match above threshold={threshold:.2f} "
            f"(best={best_score:.4f})"
        )
        return None

    def similarity_scores(
        self,
        live_embedding: np.ndarray,
    ) -> list[dict]:
        """
        Return similarity scores against all enrolled speakers, sorted descending.
        Useful for admin UI diagnostics / confidence display.

        Each dict has the keys id, name, similarity (rounded to 4 decimals)
        and match (True when the score reaches self.threshold).
        """
        data = self._load_speakers() or {}
        results = []
        for speaker_id, score in self._iter_scores(live_embedding):
            entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
            results.append({
                "id": speaker_id,
                "name": entry["name"],
                "similarity": round(score, 4),
                "match": score >= self.threshold,
            })
        return sorted(results, key=lambda x: x["similarity"], reverse=True)

    # ── Accumulator for live audio ─────────────────────────────────────────────

    def make_accumulator(
        self,
        min_seconds: float = 5.0,
        sample_rate: int = 16000,
    ) -> "LiveEmbeddingAccumulator":
        """
        Create an accumulator that buffers live PCM audio and
        extracts an embedding once enough audio has been collected.
        """
        return LiveEmbeddingAccumulator(
            registry=self,
            min_seconds=min_seconds,
            sample_rate=sample_rate,
        )
|
|
|
+
|
|
|
+
|
|
|
+# ── Live audio accumulator ─────────────────────────────────────────────────────
|
|
|
+
|
|
|
class LiveEmbeddingAccumulator:
    """
    Buffers raw mono int16 PCM chunks from the live microphone stream.
    Once ``min_seconds`` of audio is accumulated, an embedding can be
    extracted for speaker matching or enrolment.

    Usage:
        acc = registry.make_accumulator()
        acc.push(pcm_chunk_bytes)
        ...
        if acc.ready():
            embedding = acc.extract_embedding()
            match = registry.find_match(embedding)

    The buffer is not cleared by extract_embedding(); call reset() to start
    a new accumulation.
    """

    def __init__(
        self,
        registry: "EmbeddingRegistry",
        min_seconds: float = 5.0,
        sample_rate: int = 16000,
    ):
        """
        Args:
            registry:    Registry whose extract() is used to compute the embedding
            min_seconds: Amount of audio required before ready() is True
            sample_rate: Sample rate of the pushed PCM, in Hz
        """
        self.registry = registry
        self.min_frames = int(min_seconds * sample_rate)
        self.sample_rate = sample_rate
        self._frames: list[bytes] = []
        self._n_bytes = 0   # total bytes buffered across all chunks
        self._n_frames = 0  # whole int16 frames buffered (= _n_bytes // 2)

    def push(self, pcm_bytes: bytes) -> None:
        """Add a chunk of raw int16 PCM audio."""
        # Take an immutable copy: audio callbacks commonly reuse their buffer,
        # which would otherwise rewrite audio that is already queued here.
        chunk = bytes(pcm_bytes)
        if not chunk:
            return
        self._frames.append(chunk)
        # Count on the running byte total so that chunks split in the middle
        # of a sample do not lose a frame each.
        self._n_bytes += len(chunk)
        self._n_frames = self._n_bytes // 2  # int16 = 2 bytes

    def ready(self) -> bool:
        """Return True if enough audio has been accumulated."""
        return self._n_frames >= self.min_frames

    def seconds_accumulated(self) -> float:
        """Return the buffered audio duration in seconds."""
        return self._n_frames / self.sample_rate

    def reset(self) -> None:
        """Discard all buffered audio."""
        self._frames = []
        self._n_bytes = 0
        self._n_frames = 0

    def extract_embedding(self) -> np.ndarray:
        """
        Extract an embedding from the buffered audio.

        The buffer is written to a temporary mono 16-bit WAV file, passed to
        ``registry.extract()``, and the temporary file is removed afterwards.

        Raises RuntimeError if not enough audio has been collected.
        """
        if not self.ready():
            raise RuntimeError(
                f"Not enough audio: {self.seconds_accumulated():.1f}s "
                f"accumulated, need {self.min_frames / self.sample_rate:.1f}s"
            )

        import tempfile
        import wave

        # Write buffered PCM to a temp WAV (dropping a dangling half sample).
        pcm_data = b"".join(self._frames)
        pcm_data = pcm_data[: len(pcm_data) - (len(pcm_data) % 2)]
        tmp = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False, prefix="live_emb_"
        )
        tmp_path = Path(tmp.name)
        tmp.close()

        try:
            with wave.open(str(tmp_path), "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(self.sample_rate)
                wf.writeframes(pcm_data)

            return self.registry.extract(tmp_path)
        finally:
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
|
|
|
+
|
|
|
+
|
|
|
+# ── Standalone test / CLI ──────────────────────────────────────────────────────
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Command-line front end: enrol / list / match / scores / delete.
    import argparse
    import sys

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    usage_examples = """
Examples:
  # Extract and save a voiceprint from seconds 45-55 of a recording
  python embeddings.py enrol SPEAKER_00 recording.mp3 --start 45 --end 55

  # List all enrolled speakers
  python embeddings.py list

  # Test match: compare a 10s clip against enrolled speakers
  python embeddings.py match recording.mp3 --start 120 --end 130

  # Delete a voiceprint
  python embeddings.py delete SPEAKER_00

  # Show similarity scores for a clip against all enrolled speakers
  python embeddings.py scores recording.mp3 --start 45 --end 55
        """

    cli = argparse.ArgumentParser(
        description="Speaker voiceprint utility",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    commands = cli.add_subparsers(dest="cmd")

    def _add_clip_arguments(sub_parser):
        """Register the audio file plus optional --start / --end clip bounds."""
        sub_parser.add_argument("audio_file")
        sub_parser.add_argument("--start", type=float, default=0.0)
        sub_parser.add_argument("--end", type=float, default=None)

    # enrol: speaker id followed by the clip to learn from
    enrol_parser = commands.add_parser("enrol", help="Extract and save a voiceprint")
    enrol_parser.add_argument("speaker_id")
    _add_clip_arguments(enrol_parser)

    # list: no arguments
    commands.add_parser("list", help="List enrolled speakers")

    # match: clip plus an optional threshold override
    match_parser = commands.add_parser(
        "match", help="Find matching speaker for an audio clip"
    )
    _add_clip_arguments(match_parser)
    match_parser.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)

    # scores: clip only
    scores_parser = commands.add_parser(
        "scores", help="Show similarity against all enrolled speakers"
    )
    _add_clip_arguments(scores_parser)

    # delete: speaker id only
    delete_parser = commands.add_parser("delete", help="Delete a stored voiceprint")
    delete_parser.add_argument("speaker_id")

    def _run_enrol(reg, opts):
        """Extract a voiceprint from the clip and store it under the speaker id."""
        vector = reg.extract_and_save(
            opts.speaker_id, opts.audio_file, opts.start, opts.end
        )
        print(f"✓ Enrolled {opts.speaker_id} — embedding shape: {vector.shape}")

    def _run_list(reg, opts):
        """Print a table of every enrolled speaker."""
        rows = reg.list_enrolled()
        if not rows:
            print("No speakers enrolled yet.")
            return
        print(f"{'ID':<15} {'Name':<25} {'Updated':<28} {'Size'}")
        print("-" * 75)
        for row in rows:
            print(
                f"{row['id']:<15} {row['name']:<25} "
                f"{(row['updated'] or 'unknown'):<28} {row['size_kb']} KB"
            )

    def _run_match(reg, opts):
        """Report the best enrolled speaker for the clip, if any reaches the threshold."""
        vector = reg.extract(opts.audio_file, opts.start, opts.end)
        hit = reg.find_match(vector, threshold=opts.threshold)
        if not hit:
            print(f"✗ No match above threshold {opts.threshold:.2f}")
            return
        sid, score = hit
        speakers = _load_speakers_json()
        entry = _normalise_speaker_entry(speakers.get(sid, sid))
        print(f"✓ Match: {sid} ({entry['name']}) — similarity {score:.4f}")

    def _run_scores(reg, opts):
        """Print the clip's similarity against every enrolled speaker."""
        vector = reg.extract(opts.audio_file, opts.start, opts.end)
        scored = reg.similarity_scores(vector)
        if not scored:
            print("No enrolled speakers to compare against.")
            return
        print(f"{'ID':<15} {'Name':<25} {'Similarity':<12} {'Match?'}")
        print("-" * 60)
        for item in scored:
            flag = "✓" if item["match"] else " "
            print(
                f"{item['id']:<15} {item['name']:<25} "
                f"{item['similarity']:<12.4f} {flag}"
            )

    def _run_delete(reg, opts):
        """Remove the stored voiceprint for the speaker id, if present."""
        if reg.delete(opts.speaker_id):
            print(f"✓ Deleted voiceprint for {opts.speaker_id}")
        else:
            print(f"No voiceprint found for {opts.speaker_id}")

    args = cli.parse_args()
    if not args.cmd:
        # No sub-command given: show usage and leave with status 0.
        cli.print_help()
        sys.exit(0)

    registry = EmbeddingRegistry()

    handlers = {
        "enrol": _run_enrol,
        "list": _run_list,
        "match": _run_match,
        "scores": _run_scores,
        "delete": _run_delete,
    }
    handler = handlers.get(args.cmd)
    if handler is not None:
        handler(registry, args)
|