embeddings.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. #!/usr/bin/env python3
  2. """
  3. embeddings.py — Speaker Voiceprint Management
  4. Handles extraction, storage, and matching of speaker voice embeddings
  5. using pyannote.audio's SpeakerEmbedding pipeline.
  6. Embeddings are stored as numpy .npy files in bridge/embeddings/.
  7. speakers.json is updated with has_embedding / embedding_updated metadata.
  8. Designed for future migration to a remote vector DB (pgvector, Qdrant, etc.)
  9. — each embedding is a 512-dim float32 array, stored as a .npy binary file.
  10. Public API
  11. ----------
  12. registry = EmbeddingRegistry()
  13. # Extract from an audio file segment (e.g. 10s snip from a recording)
  14. embedding = registry.extract(audio_path, start_sec=45.0, end_sec=55.0)
  15. # Save a voiceprint for a known speaker
  16. registry.save(speaker_id="SPEAKER_00", embedding=embedding)
  17. # Check if a speaker has a stored voiceprint
  18. registry.has(speaker_id="SPEAKER_00") # -> bool
  19. # Find the best matching speaker for a live embedding
  20. match = registry.find_match(live_embedding, threshold=0.82)
  21. # -> ("SPEAKER_00", 0.91) or None
  22. # Load a stored embedding directly
  23. emb = registry.load("SPEAKER_00") # -> np.ndarray | None
  24. # Delete a voiceprint
  25. registry.delete("SPEAKER_00")
  26. # List all speakers with voiceprints
  27. registry.list_enrolled() # -> [{"id": ..., "updated": ...}, ...]
  28. """
  29. from __future__ import annotations
  30. import json
  31. import logging
  32. import os
  33. from datetime import datetime, timezone
  34. from pathlib import Path
  35. from typing import Optional
  36. import numpy as np
  37. logger = logging.getLogger(__name__)
  38. # ── Paths ──────────────────────────────────────────────────────────────────────
  39. # Default locations relative to this file (bridge/)
  40. _HERE = Path(__file__).parent
  41. EMBEDDINGS_DIR = _HERE / "embeddings"
  42. SPEAKERS_FILE = _HERE / "speakers.json"
  43. # Embedding dimensions produced by pyannote SpeakerEmbedding
  44. EMBEDDING_DIM = 512
  45. # Cosine similarity threshold for automatic name assignment
  46. # 0.82 = conservative (fewer false positives), 0.75 = more permissive
  47. DEFAULT_THRESHOLD = 0.82
  48. # Minimum audio duration in seconds for reliable embedding extraction
  49. MIN_AUDIO_SEC = 5.0
  50. # ── Lazy model loader ──────────────────────────────────────────────────────────
  51. _pipeline = None # loaded on first use
  52. def _get_pipeline():
  53. """
  54. Load the pyannote SpeakerEmbedding pipeline on first call.
  55. Requires HF_TOKEN env var or huggingface-cli login.
  56. Model: pyannote/embedding (3M params, fast, accurate).
  57. """
  58. global _pipeline
  59. if _pipeline is not None:
  60. return _pipeline
  61. try:
  62. from pyannote.audio import Model, Inference
  63. except ImportError as exc:
  64. raise RuntimeError(
  65. "pyannote.audio is required for voiceprint extraction. "
  66. "Install with: pip install pyannote.audio"
  67. ) from exc
  68. hf_token = os.environ.get("HF_TOKEN")
  69. if not hf_token:
  70. raise RuntimeError(
  71. "HF_TOKEN environment variable is not set. "
  72. "Set it in start.bat or run: huggingface-cli login"
  73. )
  74. logger.info("[Embeddings] Loading pyannote speaker embedding model...")
  75. try:
  76. model = Model.from_pretrained(
  77. "pyannote/embedding",
  78. use_auth_token=hf_token,
  79. )
  80. _pipeline = Inference(model, window="whole")
  81. logger.info("[Embeddings] Speaker embedding model loaded.")
  82. except Exception as exc:
  83. raise RuntimeError(
  84. f"Failed to load pyannote/embedding model: {exc}\n"
  85. "Ensure you have accepted the model conditions at "
  86. "https://huggingface.co/pyannote/embedding"
  87. ) from exc
  88. return _pipeline
  89. # ── Helpers ────────────────────────────────────────────────────────────────────
  90. def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
  91. """Cosine similarity between two 1-D vectors. Returns value in [-1, 1]."""
  92. a = a.flatten().astype(np.float32)
  93. b = b.flatten().astype(np.float32)
  94. norm_a = np.linalg.norm(a)
  95. norm_b = np.linalg.norm(b)
  96. if norm_a == 0 or norm_b == 0:
  97. return 0.0
  98. return float(np.dot(a, b) / (norm_a * norm_b))
  99. def _load_speakers_json() -> dict:
  100. """Load speakers.json, returning empty dict on any failure."""
  101. if SPEAKERS_FILE.exists():
  102. try:
  103. raw = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
  104. # Support both legacy flat format {"SPEAKER_00": "Name"}
  105. # and new rich format {"SPEAKER_00": {"name": "...", ...}}
  106. return raw if isinstance(raw, dict) else {}
  107. except (json.JSONDecodeError, OSError):
  108. pass
  109. return {}
  110. def _save_speakers_json(data: dict) -> None:
  111. try:
  112. SPEAKERS_FILE.write_text(
  113. json.dumps(data, indent=2, ensure_ascii=False),
  114. encoding="utf-8",
  115. )
  116. except OSError as exc:
  117. logger.error(f"[Embeddings] Could not save speakers.json: {exc}")
  118. def _normalise_speaker_entry(entry) -> dict:
  119. """
  120. Ensure a speakers.json entry is in rich-dict format.
  121. Handles legacy flat strings: "Pastor John" -> {"name": "Pastor John", ...}
  122. """
  123. if isinstance(entry, str):
  124. return {
  125. "name": entry,
  126. "has_embedding": False,
  127. "embedding_updated": None,
  128. "colour": None,
  129. "notes": "",
  130. }
  131. if isinstance(entry, dict):
  132. entry.setdefault("name", "")
  133. entry.setdefault("has_embedding", False)
  134. entry.setdefault("embedding_updated", None)
  135. entry.setdefault("colour", None)
  136. entry.setdefault("notes", "")
  137. return entry
  138. return {"name": str(entry), "has_embedding": False,
  139. "embedding_updated": None, "colour": None, "notes": ""}
  140. def _update_speaker_meta(speaker_id: str, has_embedding: bool) -> None:
  141. """Update speakers.json metadata for a speaker after enrolment/deletion."""
  142. data = _load_speakers_json()
  143. entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
  144. entry["has_embedding"] = has_embedding
  145. entry["embedding_updated"] = (
  146. datetime.now(timezone.utc).isoformat() if has_embedding else None
  147. )
  148. data[speaker_id] = entry
  149. _save_speakers_json(data)
  150. # ── Audio extraction helper ────────────────────────────────────────────────────
  151. def extract_audio_segment(
  152. audio_path: Path | str,
  153. start_sec: float,
  154. end_sec: float,
  155. out_path: Optional[Path] = None,
  156. ) -> Path:
  157. """
  158. Extract a segment of audio using miniaudio and save as a temporary WAV file.
  159. Args:
  160. audio_path: Source audio file (MP3, WAV, FLAC, OGG, M4A, etc.)
  161. start_sec: Start time in seconds
  162. end_sec: End time in seconds
  163. out_path: Optional explicit output path; if None, a temp file is used
  164. Returns:
  165. Path to the extracted WAV file (caller should delete when done if temp)
  166. """
  167. import tempfile
  168. import struct
  169. import wave
  170. try:
  171. import miniaudio
  172. except ImportError as exc:
  173. raise RuntimeError("miniaudio is required: pip install miniaudio") from exc
  174. duration = end_sec - start_sec
  175. if duration < MIN_AUDIO_SEC:
  176. logger.warning(
  177. f"[Embeddings] Short segment ({duration:.1f}s < {MIN_AUDIO_SEC}s) "
  178. "— embedding may be less reliable"
  179. )
  180. sample_rate = 16000
  181. n_channels = 1
  182. frames_needed = int(duration * sample_rate)
  183. skip_frames = int(start_sec * sample_rate)
  184. # Stream file and skip to start position
  185. stream = miniaudio.stream_file(
  186. str(audio_path),
  187. output_format=miniaudio.SampleFormat.SIGNED16,
  188. nchannels=n_channels,
  189. sample_rate=sample_rate,
  190. frames_to_read=4096,
  191. )
  192. collected: list[bytes] = []
  193. collected_frames = 0
  194. skipped_frames = 0
  195. for chunk in stream:
  196. chunk_bytes = bytes(chunk)
  197. chunk_frames = len(chunk_bytes) // 2 # int16 = 2 bytes per sample
  198. # Skip frames before start_sec
  199. if skipped_frames < skip_frames:
  200. remaining_skip = skip_frames - skipped_frames
  201. if chunk_frames <= remaining_skip:
  202. skipped_frames += chunk_frames
  203. continue
  204. else:
  205. # Partial skip
  206. offset = remaining_skip * 2
  207. chunk_bytes = chunk_bytes[offset:]
  208. chunk_frames = len(chunk_bytes) // 2
  209. skipped_frames = skip_frames
  210. # Collect up to frames_needed
  211. remaining_needed = frames_needed - collected_frames
  212. if chunk_frames >= remaining_needed:
  213. collected.append(chunk_bytes[: remaining_needed * 2])
  214. collected_frames += remaining_needed
  215. break
  216. else:
  217. collected.append(chunk_bytes)
  218. collected_frames += chunk_frames
  219. if not collected:
  220. raise ValueError(
  221. f"No audio extracted from {audio_path} at {start_sec}–{end_sec}s. "
  222. "Check that the file is valid and the time range is within its duration."
  223. )
  224. pcm_data = b"".join(collected)
  225. # Write to WAV
  226. if out_path is None:
  227. tmp = tempfile.NamedTemporaryFile(
  228. suffix=".wav", delete=False, prefix="embedding_"
  229. )
  230. out_path = Path(tmp.name)
  231. tmp.close()
  232. with wave.open(str(out_path), "wb") as wf:
  233. wf.setnchannels(n_channels)
  234. wf.setsampwidth(2) # int16
  235. wf.setframerate(sample_rate)
  236. wf.writeframes(pcm_data)
  237. return out_path
  238. # ── Main registry class ────────────────────────────────────────────────────────
  239. class EmbeddingRegistry:
  240. """
  241. Manages speaker voice embeddings on disk.
  242. Thread-safe for reading; write operations should be called from
  243. a single thread (the admin server) in normal use.
  244. """
  245. def __init__(
  246. self,
  247. embeddings_dir: Path | str = EMBEDDINGS_DIR,
  248. speakers_file: Path | str = SPEAKERS_FILE,
  249. threshold: float = DEFAULT_THRESHOLD,
  250. ):
  251. self.embeddings_dir = Path(embeddings_dir)
  252. self.speakers_file = Path(speakers_file)
  253. self.threshold = threshold
  254. self.embeddings_dir.mkdir(parents=True, exist_ok=True)
  255. def _path(self, speaker_id: str) -> Path:
  256. return self.embeddings_dir / f"{speaker_id}.npy"
  257. # ── Core storage ───────────────────────────────────────────────────────────
  258. def save(self, speaker_id: str, embedding: np.ndarray) -> None:
  259. """Persist a voiceprint embedding for a speaker."""
  260. emb = embedding.flatten().astype(np.float32)
  261. if emb.shape[0] != EMBEDDING_DIM:
  262. logger.warning(
  263. f"[Embeddings] Unexpected embedding dim {emb.shape[0]} "
  264. f"(expected {EMBEDDING_DIM}) for {speaker_id}"
  265. )
  266. np.save(str(self._path(speaker_id)), emb)
  267. _update_speaker_meta(speaker_id, has_embedding=True)
  268. logger.info(f"[Embeddings] Saved voiceprint for {speaker_id}")
  269. def load(self, speaker_id: str) -> Optional[np.ndarray]:
  270. """Load a stored embedding, or return None if not found."""
  271. p = self._path(speaker_id)
  272. if not p.exists():
  273. return None
  274. try:
  275. return np.load(str(p)).astype(np.float32)
  276. except Exception as exc:
  277. logger.error(f"[Embeddings] Failed to load {p}: {exc}")
  278. return None
  279. def has(self, speaker_id: str) -> bool:
  280. """Return True if a voiceprint exists for this speaker."""
  281. return self._path(speaker_id).exists()
  282. def delete(self, speaker_id: str) -> bool:
  283. """Delete a stored voiceprint. Returns True if it existed."""
  284. p = self._path(speaker_id)
  285. if p.exists():
  286. p.unlink()
  287. _update_speaker_meta(speaker_id, has_embedding=False)
  288. logger.info(f"[Embeddings] Deleted voiceprint for {speaker_id}")
  289. return True
  290. return False
  291. def list_enrolled(self) -> list[dict]:
  292. """Return list of speakers with stored voiceprints."""
  293. enrolled = []
  294. for p in sorted(self.embeddings_dir.glob("*.npy")):
  295. speaker_id = p.stem
  296. data = _load_speakers_json()
  297. entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
  298. enrolled.append({
  299. "id": speaker_id,
  300. "name": entry["name"],
  301. "updated": entry.get("embedding_updated"),
  302. "size_kb": round(p.stat().st_size / 1024, 1),
  303. })
  304. return enrolled
  305. # ── Extraction ─────────────────────────────────────────────────────────────
  306. def extract(
  307. self,
  308. audio_path: Path | str,
  309. start_sec: float = 0.0,
  310. end_sec: Optional[float] = None,
  311. ) -> np.ndarray:
  312. """
  313. Extract a speaker embedding from an audio file segment.
  314. Args:
  315. audio_path: Path to audio file (MP3, WAV, FLAC, M4A, OGG, etc.)
  316. start_sec: Start of segment in seconds (default: 0)
  317. end_sec: End of segment in seconds (default: use whole file)
  318. Returns:
  319. numpy array of shape (512,) — the speaker embedding
  320. Raises:
  321. RuntimeError: if pyannote model fails to load
  322. ValueError: if audio segment is empty or too short
  323. """
  324. import tempfile
  325. audio_path = Path(audio_path)
  326. if not audio_path.exists():
  327. raise FileNotFoundError(f"Audio file not found: {audio_path}")
  328. # If no end time, get file duration and use whole file
  329. if end_sec is None:
  330. try:
  331. import miniaudio
  332. info = miniaudio.get_file_info(str(audio_path))
  333. end_sec = info.duration
  334. except Exception:
  335. end_sec = start_sec + 30.0 # fallback
  336. duration = end_sec - start_sec
  337. if duration <= 0:
  338. raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")
  339. # Extract segment to a temp WAV
  340. tmp_wav = None
  341. try:
  342. tmp_wav = extract_audio_segment(audio_path, start_sec, end_sec)
  343. pipeline = _get_pipeline()
  344. # pyannote Inference accepts a file path directly
  345. embedding = pipeline(str(tmp_wav))
  346. # embedding may be a pyannote Annotation or ndarray depending on version
  347. if hasattr(embedding, "data"):
  348. arr = np.array(embedding.data).flatten()
  349. elif isinstance(embedding, np.ndarray):
  350. arr = embedding.flatten()
  351. else:
  352. arr = np.array(embedding).flatten()
  353. logger.info(
  354. f"[Embeddings] Extracted embedding from {audio_path.name} "
  355. f"[{start_sec:.1f}s – {end_sec:.1f}s] → shape {arr.shape}"
  356. )
  357. return arr.astype(np.float32)
  358. finally:
  359. if tmp_wav and tmp_wav.exists():
  360. try:
  361. tmp_wav.unlink()
  362. except OSError:
  363. pass
  364. def extract_and_save(
  365. self,
  366. speaker_id: str,
  367. audio_path: Path | str,
  368. start_sec: float = 0.0,
  369. end_sec: Optional[float] = None,
  370. ) -> np.ndarray:
  371. """
  372. Extract embedding from audio segment and immediately save it.
  373. Convenience wrapper around extract() + save().
  374. """
  375. embedding = self.extract(audio_path, start_sec, end_sec)
  376. self.save(speaker_id, embedding)
  377. return embedding
  378. # ── Matching ───────────────────────────────────────────────────────────────
  379. def find_match(
  380. self,
  381. live_embedding: np.ndarray,
  382. threshold: Optional[float] = None,
  383. exclude_ids: Optional[set[str]] = None,
  384. ) -> Optional[tuple[str, float]]:
  385. """
  386. Find the best matching enrolled speaker for a live embedding.
  387. Args:
  388. live_embedding: Embedding from live audio (shape: (512,))
  389. threshold: Cosine similarity threshold (default: self.threshold)
  390. exclude_ids: Speaker IDs to skip (already confirmed this session)
  391. Returns:
  392. (speaker_id, similarity_score) if match found above threshold,
  393. or None if no match.
  394. Notes:
  395. - Similarity is cosine similarity in [-1, 1]; higher = more similar
  396. - Typical same-speaker similarity: 0.85–0.98
  397. - Typical different-speaker similarity: 0.2–0.65
  398. - Threshold 0.82 is conservative; lower to 0.75 for more permissive
  399. """
  400. if threshold is None:
  401. threshold = self.threshold
  402. exclude_ids = exclude_ids or set()
  403. best_id = None
  404. best_score = -1.0
  405. for p in self.embeddings_dir.glob("*.npy"):
  406. speaker_id = p.stem
  407. if speaker_id in exclude_ids:
  408. continue
  409. stored = self.load(speaker_id)
  410. if stored is None:
  411. continue
  412. score = _cosine_similarity(live_embedding, stored)
  413. logger.debug(f"[Embeddings] {speaker_id}: similarity={score:.4f}")
  414. if score > best_score:
  415. best_score = score
  416. best_id = speaker_id
  417. if best_id and best_score >= threshold:
  418. logger.info(
  419. f"[Embeddings] Match: {best_id} (similarity={best_score:.4f})"
  420. )
  421. return (best_id, best_score)
  422. logger.debug(
  423. f"[Embeddings] No match above threshold={threshold:.2f} "
  424. f"(best={best_score:.4f})"
  425. )
  426. return None
  427. def similarity_scores(
  428. self,
  429. live_embedding: np.ndarray,
  430. ) -> list[dict]:
  431. """
  432. Return similarity scores against all enrolled speakers, sorted descending.
  433. Useful for admin UI diagnostics / confidence display.
  434. """
  435. results = []
  436. data = _load_speakers_json()
  437. for p in self.embeddings_dir.glob("*.npy"):
  438. speaker_id = p.stem
  439. stored = self.load(speaker_id)
  440. if stored is None:
  441. continue
  442. score = _cosine_similarity(live_embedding, stored)
  443. entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
  444. results.append({
  445. "id": speaker_id,
  446. "name": entry["name"],
  447. "similarity": round(score, 4),
  448. "match": score >= self.threshold,
  449. })
  450. return sorted(results, key=lambda x: x["similarity"], reverse=True)
  451. # ── Accumulator for live audio ─────────────────────────────────────────────
  452. def make_accumulator(
  453. self,
  454. min_seconds: float = 5.0,
  455. sample_rate: int = 16000,
  456. ) -> "LiveEmbeddingAccumulator":
  457. """
  458. Create an accumulator that buffers live PCM audio and
  459. extracts an embedding once enough audio has been collected.
  460. """
  461. return LiveEmbeddingAccumulator(
  462. registry = self,
  463. min_seconds = min_seconds,
  464. sample_rate = sample_rate,
  465. )
  466. # ── Live audio accumulator ─────────────────────────────────────────────────────
  467. class LiveEmbeddingAccumulator:
  468. """
  469. Buffers raw int16 PCM chunks from the live microphone stream.
  470. Once MIN_SECONDS of audio is accumulated, extracts an embedding
  471. that can be used for speaker matching or enrolment.
  472. Usage:
  473. acc = registry.make_accumulator()
  474. acc.push(pcm_chunk_bytes)
  475. ...
  476. if acc.ready():
  477. embedding = acc.extract_embedding()
  478. match = registry.find_match(embedding)
  479. """
  480. def __init__(
  481. self,
  482. registry: EmbeddingRegistry,
  483. min_seconds: float = 5.0,
  484. sample_rate: int = 16000,
  485. ):
  486. self.registry = registry
  487. self.min_frames = int(min_seconds * sample_rate)
  488. self.sample_rate = sample_rate
  489. self._frames: list[bytes] = []
  490. self._n_frames = 0
  491. def push(self, pcm_bytes: bytes) -> None:
  492. """Add a chunk of raw int16 PCM audio."""
  493. self._frames.append(pcm_bytes)
  494. self._n_frames += len(pcm_bytes) // 2 # int16 = 2 bytes
  495. def ready(self) -> bool:
  496. """Return True if enough audio has been accumulated."""
  497. return self._n_frames >= self.min_frames
  498. def seconds_accumulated(self) -> float:
  499. return self._n_frames / self.sample_rate
  500. def reset(self) -> None:
  501. self._frames = []
  502. self._n_frames = 0
  503. def extract_embedding(self) -> np.ndarray:
  504. """
  505. Extract an embedding from the buffered audio.
  506. Raises RuntimeError if not enough audio has been collected.
  507. """
  508. if not self.ready():
  509. raise RuntimeError(
  510. f"Not enough audio: {self.seconds_accumulated():.1f}s "
  511. f"accumulated, need {self.min_frames / self.sample_rate:.1f}s"
  512. )
  513. import tempfile
  514. import wave
  515. # Write buffered PCM to a temp WAV
  516. pcm_data = b"".join(self._frames)
  517. tmp = tempfile.NamedTemporaryFile(
  518. suffix=".wav", delete=False, prefix="live_emb_"
  519. )
  520. tmp_path = Path(tmp.name)
  521. tmp.close()
  522. try:
  523. with wave.open(str(tmp_path), "wb") as wf:
  524. wf.setnchannels(1)
  525. wf.setsampwidth(2)
  526. wf.setframerate(self.sample_rate)
  527. wf.writeframes(pcm_data)
  528. return self.registry.extract(tmp_path)
  529. finally:
  530. if tmp_path.exists():
  531. try:
  532. tmp_path.unlink()
  533. except OSError:
  534. pass
  535. # ── Standalone test / CLI ──────────────────────────────────────────────────────
  536. if __name__ == "__main__":
  537. import argparse
  538. import sys
  539. logging.basicConfig(level=logging.INFO, format="%(message)s")
  540. parser = argparse.ArgumentParser(
  541. description="Speaker voiceprint utility",
  542. formatter_class=argparse.RawDescriptionHelpFormatter,
  543. epilog="""
  544. Examples:
  545. # Extract and save a voiceprint from seconds 45-55 of a recording
  546. python embeddings.py enrol SPEAKER_00 recording.mp3 --start 45 --end 55
  547. # List all enrolled speakers
  548. python embeddings.py list
  549. # Test match: compare a 10s clip against enrolled speakers
  550. python embeddings.py match recording.mp3 --start 120 --end 130
  551. # Delete a voiceprint
  552. python embeddings.py delete SPEAKER_00
  553. # Show similarity scores for a clip against all enrolled speakers
  554. python embeddings.py scores recording.mp3 --start 45 --end 55
  555. """,
  556. )
  557. sub = parser.add_subparsers(dest="cmd")
  558. # enrol
  559. p_enrol = sub.add_parser("enrol", help="Extract and save a voiceprint")
  560. p_enrol.add_argument("speaker_id")
  561. p_enrol.add_argument("audio_file")
  562. p_enrol.add_argument("--start", type=float, default=0.0)
  563. p_enrol.add_argument("--end", type=float, default=None)
  564. # list
  565. sub.add_parser("list", help="List enrolled speakers")
  566. # match
  567. p_match = sub.add_parser("match", help="Find matching speaker for an audio clip")
  568. p_match.add_argument("audio_file")
  569. p_match.add_argument("--start", type=float, default=0.0)
  570. p_match.add_argument("--end", type=float, default=None)
  571. p_match.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
  572. # scores
  573. p_scores = sub.add_parser("scores", help="Show similarity against all enrolled speakers")
  574. p_scores.add_argument("audio_file")
  575. p_scores.add_argument("--start", type=float, default=0.0)
  576. p_scores.add_argument("--end", type=float, default=None)
  577. # delete
  578. p_del = sub.add_parser("delete", help="Delete a stored voiceprint")
  579. p_del.add_argument("speaker_id")
  580. args = parser.parse_args()
  581. if not args.cmd:
  582. parser.print_help()
  583. sys.exit(0)
  584. registry = EmbeddingRegistry()
  585. if args.cmd == "enrol":
  586. emb = registry.extract_and_save(
  587. args.speaker_id, args.audio_file, args.start, args.end
  588. )
  589. print(f"✓ Enrolled {args.speaker_id} — embedding shape: {emb.shape}")
  590. elif args.cmd == "list":
  591. enrolled = registry.list_enrolled()
  592. if not enrolled:
  593. print("No speakers enrolled yet.")
  594. else:
  595. print(f"{'ID':<15} {'Name':<25} {'Updated':<28} {'Size'}")
  596. print("-" * 75)
  597. for e in enrolled:
  598. print(
  599. f"{e['id']:<15} {e['name']:<25} "
  600. f"{(e['updated'] or 'unknown'):<28} {e['size_kb']} KB"
  601. )
  602. elif args.cmd == "match":
  603. emb = registry.extract(args.audio_file, args.start, args.end)
  604. match = registry.find_match(emb, threshold=args.threshold)
  605. if match:
  606. sid, score = match
  607. data = _load_speakers_json()
  608. entry = _normalise_speaker_entry(data.get(sid, sid))
  609. print(f"✓ Match: {sid} ({entry['name']}) — similarity {score:.4f}")
  610. else:
  611. print(f"✗ No match above threshold {args.threshold:.2f}")
  612. elif args.cmd == "scores":
  613. emb = registry.extract(args.audio_file, args.start, args.end)
  614. scores = registry.similarity_scores(emb)
  615. if not scores:
  616. print("No enrolled speakers to compare against.")
  617. else:
  618. print(f"{'ID':<15} {'Name':<25} {'Similarity':<12} {'Match?'}")
  619. print("-" * 60)
  620. for s in scores:
  621. flag = "✓" if s["match"] else " "
  622. print(
  623. f"{s['id']:<15} {s['name']:<25} "
  624. f"{s['similarity']:<12.4f} {flag}"
  625. )
  626. elif args.cmd == "delete":
  627. if registry.delete(args.speaker_id):
  628. print(f"✓ Deleted voiceprint for {args.speaker_id}")
  629. else:
  630. print(f"No voiceprint found for {args.speaker_id}")