embeddings.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. #!/usr/bin/env python3
  2. """
  3. embeddings.py — Speaker Voiceprint Management
  4. Handles extraction, storage, and matching of speaker voice embeddings
  5. using the pyannote.audio `pyannote/embedding` model (run through `Inference`).
  6. Embeddings are stored as numpy .npy files in bridge/embeddings/.
  7. speakers.json is updated with has_embedding / embedding_updated metadata.
  8. Transcript segments are logged to bridge/transcript_segments.jsonl —
  9. a line-delimited JSON log of every finalised speech segment with speaker ID,
  10. timestamps, and text. This log is used for automatic best-segment enrolment.
  11. Public API
  12. ----------
  13. registry = EmbeddingRegistry()
  14. # Extract from an audio file segment
  15. embedding = registry.extract(audio_path, start_sec=45.0, end_sec=55.0)
  16. # Save a voiceprint for a known speaker
  17. registry.save(speaker_id="SPEAKER_00", embedding=embedding)
  18. # Auto-enrol from a transcript log (picks the cleanest segment)
  19. registry.enrol_from_transcript("SPEAKER_00", audio_path)
  20. # Find the best matching speaker for a live embedding
  21. match = registry.find_match(live_embedding, threshold=0.82)
  22. # -> ("SPEAKER_00", 0.91) or None
  23. # Load a stored embedding
  24. emb = registry.load("SPEAKER_00") # -> np.ndarray | None
  25. # Delete a voiceprint
  26. registry.delete("SPEAKER_00")
  27. # List all speakers with voiceprints
  28. registry.list_enrolled()
  29. # Accumulate live audio and extract when ready
  30. acc = registry.make_accumulator()
  31. acc.push(pcm_bytes)
  32. if acc.ready():
  33. emb = acc.extract_embedding()
  34. match = registry.find_match(emb)
  35. """
  36. from __future__ import annotations
  37. import json
  38. import logging
  39. import os
  40. from datetime import datetime, timezone
  41. from pathlib import Path
  42. from typing import Optional
  43. import numpy as np
  44. logger = logging.getLogger(__name__)
# ── Paths ──────────────────────────────────────────────────────────────────────
# All data files are resolved relative to the directory containing this module.
_HERE = Path(__file__).parent
EMBEDDINGS_DIR = _HERE / "embeddings"  # default directory of <speaker_id>.npy voiceprints
SPEAKERS_FILE = _HERE / "speakers.json"  # speaker metadata read/written by the helpers below
TRANSCRIPT_SEGMENTS_LOG = _HERE / "transcript_segments.jsonl"  # one JSON record per line
# NOTE(review): not referenced anywhere in the visible code; presumably the
# expected embedding length — confirm against the model before relying on it.
EMBEDDING_DIM = 512
DEFAULT_THRESHOLD = 0.82  # default cosine-similarity cut-off used by EmbeddingRegistry
MIN_AUDIO_SEC = 5.0  # shorter segments only trigger a warning in extract_audio_segment()
  53. # ── Lazy model loader ──────────────────────────────────────────────────────────
  54. _pipeline = None
  55. def _get_pipeline():
  56. """
  57. Load the pyannote SpeakerEmbedding pipeline on first call.
  58. Requires HF_TOKEN env var or huggingface-cli login.
  59. """
  60. global _pipeline
  61. if _pipeline is not None:
  62. return _pipeline
  63. try:
  64. from pyannote.audio import Model, Inference
  65. except ImportError as exc:
  66. raise RuntimeError(
  67. "pyannote.audio is required for voiceprint extraction. "
  68. "Install with: pip install pyannote.audio"
  69. ) from exc
  70. hf_token = os.environ.get("HF_TOKEN")
  71. if not hf_token:
  72. raise RuntimeError(
  73. "HF_TOKEN environment variable is not set. "
  74. "Set it in start.bat or run: huggingface-cli login"
  75. )
  76. logger.info("[Embeddings] Loading pyannote speaker embedding model...")
  77. try:
  78. model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token)
  79. _pipeline = Inference(model, window="whole")
  80. logger.info("[Embeddings] Speaker embedding model loaded.")
  81. except Exception as exc:
  82. raise RuntimeError(
  83. f"Failed to load pyannote/embedding model: {exc}\n"
  84. "Ensure you have accepted model conditions at "
  85. "https://huggingface.co/pyannote/embedding"
  86. ) from exc
  87. return _pipeline
  88. # ── Helpers ────────────────────────────────────────────────────────────────────
  89. def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
  90. a = a.flatten().astype(np.float32)
  91. b = b.flatten().astype(np.float32)
  92. norm_a = np.linalg.norm(a)
  93. norm_b = np.linalg.norm(b)
  94. if norm_a == 0 or norm_b == 0:
  95. return 0.0
  96. return float(np.dot(a, b) / (norm_a * norm_b))
  97. def _load_speakers_json() -> dict:
  98. if SPEAKERS_FILE.exists():
  99. try:
  100. raw = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
  101. return raw if isinstance(raw, dict) else {}
  102. except (json.JSONDecodeError, OSError):
  103. pass
  104. return {}
  105. def _save_speakers_json(data: dict) -> None:
  106. try:
  107. SPEAKERS_FILE.write_text(
  108. json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8"
  109. )
  110. except OSError as exc:
  111. logger.error(f"[Embeddings] Could not save speakers.json: {exc}")
  112. def _normalise_speaker_entry(entry) -> dict:
  113. """Migrate flat string entries to rich dict format."""
  114. if isinstance(entry, str):
  115. return {"name": entry, "has_embedding": False,
  116. "embedding_updated": None, "colour": None, "notes": ""}
  117. if isinstance(entry, dict):
  118. entry.setdefault("name", "")
  119. entry.setdefault("has_embedding", False)
  120. entry.setdefault("embedding_updated", None)
  121. entry.setdefault("colour", None)
  122. entry.setdefault("notes", "")
  123. return entry
  124. return {"name": str(entry), "has_embedding": False,
  125. "embedding_updated": None, "colour": None, "notes": ""}
  126. def _update_speaker_meta(speaker_id: str, has_embedding: bool) -> None:
  127. data = _load_speakers_json()
  128. entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
  129. entry["has_embedding"] = has_embedding
  130. entry["embedding_updated"] = (
  131. datetime.now(timezone.utc).isoformat() if has_embedding else None
  132. )
  133. data[speaker_id] = entry
  134. _save_speakers_json(data)
  135. # ── Transcript segment log ─────────────────────────────────────────────────────
  136. def log_transcript_segment(
  137. speaker_id: str,
  138. text: str,
  139. start_sec: float,
  140. end_sec: float,
  141. session_id: Optional[str] = None,
  142. ) -> None:
  143. """
  144. Append a finalised transcript segment to transcript_segments.jsonl.
  145. Called by bridge.py each time a is_final segment is received.
  146. This log is later used by enrol_from_transcript() to find the best
  147. clean segment for voiceprint extraction.
  148. Args:
  149. speaker_id: Diarization label, e.g. "SPEAKER_00"
  150. text: Transcribed text for this segment
  151. start_sec: Segment start time in seconds (from audio start)
  152. end_sec: Segment end time in seconds
  153. session_id: Optional session identifier (e.g. date string)
  154. """
  155. record = {
  156. "ts": datetime.now(timezone.utc).isoformat(),
  157. "session": session_id or datetime.now(timezone.utc).strftime("%Y-%m-%d"),
  158. "speaker": speaker_id,
  159. "start": round(start_sec, 3),
  160. "end": round(end_sec, 3),
  161. "duration": round(end_sec - start_sec, 3),
  162. "text": text.strip(),
  163. }
  164. try:
  165. with TRANSCRIPT_SEGMENTS_LOG.open("a", encoding="utf-8") as f:
  166. f.write(json.dumps(record, ensure_ascii=False) + "\n")
  167. except OSError as exc:
  168. logger.warning(f"[Embeddings] Could not write segment log: {exc}")
  169. def load_transcript_segments(
  170. session_id: Optional[str] = None,
  171. speaker_id: Optional[str] = None,
  172. ) -> list[dict]:
  173. """
  174. Load transcript segments from the log file.
  175. Args:
  176. session_id: If set, only return segments from this session
  177. speaker_id: If set, only return segments for this speaker
  178. Returns:
  179. List of segment dicts sorted by start time
  180. """
  181. if not TRANSCRIPT_SEGMENTS_LOG.exists():
  182. return []
  183. segments = []
  184. try:
  185. with TRANSCRIPT_SEGMENTS_LOG.open(encoding="utf-8") as f:
  186. for line in f:
  187. line = line.strip()
  188. if not line:
  189. continue
  190. try:
  191. rec = json.loads(line)
  192. if session_id and rec.get("session") != session_id:
  193. continue
  194. if speaker_id and rec.get("speaker") != speaker_id:
  195. continue
  196. segments.append(rec)
  197. except json.JSONDecodeError:
  198. continue
  199. except OSError as exc:
  200. logger.warning(f"[Embeddings] Could not read segment log: {exc}")
  201. return sorted(segments, key=lambda x: x.get("start", 0))
  202. def get_best_enrolment_segments(
  203. speaker_id: str,
  204. session_id: Optional[str] = None,
  205. min_duration: float = 8.0,
  206. top_n: int = 5,
  207. ) -> list[dict]:
  208. """
  209. Find the best candidate segments for enrolment from the transcript log.
  210. Ranks segments by:
  211. 1. Duration (longer = more reliable embedding)
  212. 2. Isolation (gap before/after to avoid speaker overlap bleed)
  213. 3. Text length (more words = more acoustic variety = better embedding)
  214. Args:
  215. speaker_id: Speaker to find segments for
  216. session_id: Limit to a specific session (None = all sessions)
  217. min_duration: Minimum segment duration in seconds
  218. top_n: How many candidate segments to return
  219. Returns:
  220. List of segment dicts with added "score" and "recommendation" fields,
  221. sorted by score descending
  222. """
  223. all_segs = load_transcript_segments(session_id=session_id)
  224. spk_segs = [s for s in all_segs
  225. if s.get("speaker") == speaker_id
  226. and s.get("duration", 0) >= min_duration]
  227. if not spk_segs:
  228. return []
  229. # Build a set of all segment boundaries for isolation scoring
  230. all_starts = {s["start"] for s in all_segs}
  231. all_ends = {s["end"] for s in all_segs}
  232. scored = []
  233. for seg in spk_segs:
  234. duration = seg.get("duration", 0)
  235. text = seg.get("text", "")
  236. word_count = len(text.split())
  237. # Gap before this segment (silence / different speaker)
  238. gap_before = min(
  239. (seg["start"] - e for e in all_ends if e <= seg["start"]),
  240. default=seg["start"],
  241. )
  242. # Gap after this segment
  243. gap_after = min(
  244. (s - seg["end"] for s in all_starts if s >= seg["end"]),
  245. default=999.0,
  246. )
  247. # Score: weight duration heavily, reward isolation, penalise short text
  248. isolation = min(gap_before, gap_after)
  249. score = (
  250. duration * 2.0 # main driver
  251. + isolation * 1.5 # isolation bonus
  252. + word_count * 0.1 # text richness
  253. )
  254. # Human-readable recommendation
  255. if duration >= 15 and isolation >= 1.0:
  256. rec = "Excellent"
  257. elif duration >= 10 and isolation >= 0.5:
  258. rec = "Good"
  259. elif duration >= 8:
  260. rec = "Acceptable"
  261. else:
  262. rec = "Short"
  263. scored.append({**seg, "score": round(score, 2), "recommendation": rec})
  264. return sorted(scored, key=lambda x: x["score"], reverse=True)[:top_n]
  265. # ── Audio extraction ───────────────────────────────────────────────────────────
  266. def extract_audio_segment(
  267. audio_path: Path | str,
  268. start_sec: float,
  269. end_sec: float,
  270. out_path: Optional[Path] = None,
  271. ) -> Path:
  272. """
  273. Extract a segment of audio using miniaudio and save as a WAV file.
  274. Returns path to the extracted WAV (caller deletes if temp).
  275. """
  276. import tempfile
  277. import wave
  278. try:
  279. import miniaudio
  280. except ImportError as exc:
  281. raise RuntimeError("miniaudio is required: pip install miniaudio") from exc
  282. duration = end_sec - start_sec
  283. if duration < MIN_AUDIO_SEC:
  284. logger.warning(
  285. f"[Embeddings] Short segment ({duration:.1f}s < {MIN_AUDIO_SEC}s) "
  286. "— embedding may be less reliable"
  287. )
  288. sample_rate = 16000
  289. frames_needed = int(duration * sample_rate)
  290. skip_frames = int(start_sec * sample_rate)
  291. stream = miniaudio.stream_file(
  292. str(audio_path),
  293. output_format=miniaudio.SampleFormat.SIGNED16,
  294. nchannels=1,
  295. sample_rate=sample_rate,
  296. frames_to_read=4096,
  297. )
  298. collected: list[bytes] = []
  299. collected_frames = 0
  300. skipped_frames = 0
  301. for chunk in stream:
  302. chunk_bytes = bytes(chunk)
  303. chunk_frames = len(chunk_bytes) // 2
  304. if skipped_frames < skip_frames:
  305. remaining_skip = skip_frames - skipped_frames
  306. if chunk_frames <= remaining_skip:
  307. skipped_frames += chunk_frames
  308. continue
  309. else:
  310. offset = remaining_skip * 2
  311. chunk_bytes = chunk_bytes[offset:]
  312. chunk_frames = len(chunk_bytes) // 2
  313. skipped_frames = skip_frames
  314. remaining_needed = frames_needed - collected_frames
  315. if chunk_frames >= remaining_needed:
  316. collected.append(chunk_bytes[: remaining_needed * 2])
  317. collected_frames += remaining_needed
  318. break
  319. else:
  320. collected.append(chunk_bytes)
  321. collected_frames += chunk_frames
  322. if not collected:
  323. raise ValueError(
  324. f"No audio extracted from {audio_path} at {start_sec}–{end_sec}s."
  325. )
  326. pcm_data = b"".join(collected)
  327. if out_path is None:
  328. tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, prefix="embedding_")
  329. out_path = Path(tmp.name)
  330. tmp.close()
  331. with wave.open(str(out_path), "wb") as wf:
  332. wf.setnchannels(1)
  333. wf.setsampwidth(2)
  334. wf.setframerate(sample_rate)
  335. wf.writeframes(pcm_data)
  336. return out_path
  337. # ── Main registry class ────────────────────────────────────────────────────────
  338. class EmbeddingRegistry:
  339. def __init__(
  340. self,
  341. embeddings_dir: Path | str = EMBEDDINGS_DIR,
  342. speakers_file: Path | str = SPEAKERS_FILE,
  343. threshold: float = DEFAULT_THRESHOLD,
  344. ):
  345. self.embeddings_dir = Path(embeddings_dir)
  346. self.speakers_file = Path(speakers_file)
  347. self.threshold = threshold
  348. self.embeddings_dir.mkdir(parents=True, exist_ok=True)
  349. def _path(self, speaker_id: str) -> Path:
  350. return self.embeddings_dir / f"{speaker_id}.npy"
  351. # ── Storage ────────────────────────────────────────────────────────────────
  352. def save(self, speaker_id: str, embedding: np.ndarray) -> None:
  353. emb = embedding.flatten().astype(np.float32)
  354. np.save(str(self._path(speaker_id)), emb)
  355. _update_speaker_meta(speaker_id, has_embedding=True)
  356. logger.info(f"[Embeddings] Saved voiceprint for {speaker_id}")
  357. def load(self, speaker_id: str) -> Optional[np.ndarray]:
  358. p = self._path(speaker_id)
  359. if not p.exists():
  360. return None
  361. try:
  362. return np.load(str(p)).astype(np.float32)
  363. except Exception as exc:
  364. logger.error(f"[Embeddings] Failed to load {p}: {exc}")
  365. return None
  366. def has(self, speaker_id: str) -> bool:
  367. return self._path(speaker_id).exists()
  368. def delete(self, speaker_id: str) -> bool:
  369. p = self._path(speaker_id)
  370. if p.exists():
  371. p.unlink()
  372. _update_speaker_meta(speaker_id, has_embedding=False)
  373. logger.info(f"[Embeddings] Deleted voiceprint for {speaker_id}")
  374. return True
  375. return False
  376. def list_enrolled(self) -> list[dict]:
  377. enrolled = []
  378. data = _load_speakers_json()
  379. for p in sorted(self.embeddings_dir.glob("*.npy")):
  380. sid = p.stem
  381. entry = _normalise_speaker_entry(data.get(sid, sid))
  382. enrolled.append({
  383. "id": sid,
  384. "name": entry["name"],
  385. "updated": entry.get("embedding_updated"),
  386. "size_kb": round(p.stat().st_size / 1024, 1),
  387. })
  388. return enrolled
  389. # ── Extraction ─────────────────────────────────────────────────────────────
  390. def extract(
  391. self,
  392. audio_path: Path | str,
  393. start_sec: float = 0.0,
  394. end_sec: Optional[float] = None,
  395. ) -> np.ndarray:
  396. """Extract a speaker embedding from an audio file segment."""
  397. audio_path = Path(audio_path)
  398. if not audio_path.exists():
  399. raise FileNotFoundError(f"Audio file not found: {audio_path}")
  400. if end_sec is None:
  401. try:
  402. import miniaudio
  403. info = miniaudio.get_file_info(str(audio_path))
  404. end_sec = info.duration
  405. except Exception:
  406. end_sec = start_sec + 30.0
  407. if end_sec - start_sec <= 0:
  408. raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")
  409. tmp_wav = None
  410. try:
  411. tmp_wav = extract_audio_segment(audio_path, start_sec, end_sec)
  412. pipeline = _get_pipeline()
  413. result = pipeline(str(tmp_wav))
  414. if hasattr(result, "data"):
  415. arr = np.array(result.data).flatten()
  416. elif isinstance(result, np.ndarray):
  417. arr = result.flatten()
  418. else:
  419. arr = np.array(result).flatten()
  420. logger.info(
  421. f"[Embeddings] Extracted from {audio_path.name} "
  422. f"[{start_sec:.1f}s–{end_sec:.1f}s] shape={arr.shape}"
  423. )
  424. return arr.astype(np.float32)
  425. finally:
  426. if tmp_wav and tmp_wav.exists():
  427. try:
  428. tmp_wav.unlink()
  429. except OSError:
  430. pass
  431. def extract_and_save(
  432. self,
  433. speaker_id: str,
  434. audio_path: Path | str,
  435. start_sec: float = 0.0,
  436. end_sec: Optional[float] = None,
  437. ) -> np.ndarray:
  438. """Extract embedding from segment and save it immediately."""
  439. embedding = self.extract(audio_path, start_sec, end_sec)
  440. self.save(speaker_id, embedding)
  441. return embedding
  442. def enrol_from_transcript(
  443. self,
  444. speaker_id: str,
  445. audio_path: Path | str,
  446. session_id: Optional[str] = None,
  447. min_duration: float = 8.0,
  448. ) -> np.ndarray:
  449. """
  450. Automatically find the best segment from the transcript log
  451. and use it to enrol the speaker.
  452. The best segment is the highest-scoring candidate from
  453. get_best_enrolment_segments() — longest, most isolated, most words.
  454. If multiple "Excellent" or "Good" segments exist, their embeddings
  455. are averaged for a more robust voiceprint.
  456. Args:
  457. speaker_id: Speaker to enrol
  458. audio_path: Source audio file (should be the test recording
  459. that was used to generate the transcript log)
  460. session_id: Limit to segments from a specific session
  461. min_duration: Minimum segment duration to consider
  462. Returns:
  463. The saved embedding array
  464. Raises:
  465. ValueError: if no suitable segments are found in the log
  466. """
  467. candidates = get_best_enrolment_segments(
  468. speaker_id = speaker_id,
  469. session_id = session_id,
  470. min_duration = min_duration,
  471. top_n = 3,
  472. )
  473. if not candidates:
  474. raise ValueError(
  475. f"No transcript segments found for {speaker_id} "
  476. f"with duration >= {min_duration}s. "
  477. "Run a test recording first to generate the segment log."
  478. )
  479. # Use top candidate, or average top-2 if both are Excellent/Good
  480. top = candidates[0]
  481. good_candidates = [
  482. c for c in candidates
  483. if c["recommendation"] in ("Excellent", "Good")
  484. ]
  485. if len(good_candidates) >= 2:
  486. logger.info(
  487. f"[Embeddings] Averaging {len(good_candidates)} segments "
  488. f"for {speaker_id} (more robust voiceprint)"
  489. )
  490. embeddings = []
  491. for c in good_candidates:
  492. try:
  493. emb = self.extract(audio_path, c["start"], c["end"])
  494. embeddings.append(emb)
  495. logger.info(
  496. f"[Embeddings] Used segment [{c['start']:.1f}s–{c['end']:.1f}s] "
  497. f"'{c['text'][:50]}...'"
  498. )
  499. except Exception as exc:
  500. logger.warning(f"[Embeddings] Skipping segment: {exc}")
  501. if not embeddings:
  502. raise ValueError("All candidate segments failed extraction")
  503. # Average and re-normalise
  504. avg = np.mean(np.stack(embeddings), axis=0)
  505. norm = np.linalg.norm(avg)
  506. if norm > 0:
  507. avg = avg / norm
  508. self.save(speaker_id, avg)
  509. return avg
  510. else:
  511. logger.info(
  512. f"[Embeddings] Using single best segment for {speaker_id}: "
  513. f"[{top['start']:.1f}s–{top['end']:.1f}s] "
  514. f"({top['recommendation']}, {top['duration']:.1f}s)"
  515. )
  516. return self.extract_and_save(
  517. speaker_id, audio_path, top["start"], top["end"]
  518. )
  519. # ── Matching ───────────────────────────────────────────────────────────────
  520. def find_match(
  521. self,
  522. live_embedding: np.ndarray,
  523. threshold: Optional[float] = None,
  524. exclude_ids: Optional[set[str]] = None,
  525. ) -> Optional[tuple[str, float]]:
  526. """
  527. Find the best matching enrolled speaker for a live embedding.
  528. Returns (speaker_id, similarity_score) or None.
  529. """
  530. if threshold is None:
  531. threshold = self.threshold
  532. exclude_ids = exclude_ids or set()
  533. best_id = None
  534. best_score = -1.0
  535. for p in self.embeddings_dir.glob("*.npy"):
  536. sid = p.stem
  537. if sid in exclude_ids:
  538. continue
  539. stored = self.load(sid)
  540. if stored is None:
  541. continue
  542. score = _cosine_similarity(live_embedding, stored)
  543. logger.debug(f"[Embeddings] {sid}: similarity={score:.4f}")
  544. if score > best_score:
  545. best_score = score
  546. best_id = sid
  547. if best_id and best_score >= threshold:
  548. logger.info(f"[Embeddings] Match: {best_id} (score={best_score:.4f})")
  549. return (best_id, best_score)
  550. logger.debug(
  551. f"[Embeddings] No match above threshold={threshold:.2f} "
  552. f"(best={best_score:.4f})"
  553. )
  554. return None
  555. def similarity_scores(self, live_embedding: np.ndarray) -> list[dict]:
  556. """Return similarity scores against all enrolled speakers, sorted descending."""
  557. results = []
  558. data = _load_speakers_json()
  559. for p in self.embeddings_dir.glob("*.npy"):
  560. sid = p.stem
  561. stored = self.load(sid)
  562. if stored is None:
  563. continue
  564. score = _cosine_similarity(live_embedding, stored)
  565. entry = _normalise_speaker_entry(data.get(sid, sid))
  566. results.append({
  567. "id": sid,
  568. "name": entry["name"],
  569. "similarity": round(score, 4),
  570. "match": score >= self.threshold,
  571. })
  572. return sorted(results, key=lambda x: x["similarity"], reverse=True)
  573. # ── Live accumulator ───────────────────────────────────────────────────────
  574. def make_accumulator(
  575. self,
  576. min_seconds: float = 5.0,
  577. sample_rate: int = 16000,
  578. ) -> "LiveEmbeddingAccumulator":
  579. return LiveEmbeddingAccumulator(self, min_seconds, sample_rate)
  580. # ── Live audio accumulator ─────────────────────────────────────────────────────
  581. class LiveEmbeddingAccumulator:
  582. """
  583. Buffers raw int16 PCM chunks from the live microphone stream.
  584. Once min_seconds of audio is accumulated, extracts an embedding
  585. for speaker matching or enrolment.
  586. """
  587. def __init__(
  588. self,
  589. registry: EmbeddingRegistry,
  590. min_seconds: float = 5.0,
  591. sample_rate: int = 16000,
  592. ):
  593. self.registry = registry
  594. self.min_frames = int(min_seconds * sample_rate)
  595. self.sample_rate = sample_rate
  596. self._frames: list[bytes] = []
  597. self._n_frames = 0
  598. def push(self, pcm_bytes: bytes) -> None:
  599. self._frames.append(pcm_bytes)
  600. self._n_frames += len(pcm_bytes) // 2
  601. def ready(self) -> bool:
  602. return self._n_frames >= self.min_frames
  603. def seconds_accumulated(self) -> float:
  604. return self._n_frames / self.sample_rate
  605. def reset(self) -> None:
  606. self._frames = []
  607. self._n_frames = 0
  608. def extract_embedding(self) -> np.ndarray:
  609. if not self.ready():
  610. raise RuntimeError(
  611. f"Not enough audio: {self.seconds_accumulated():.1f}s accumulated, "
  612. f"need {self.min_frames / self.sample_rate:.1f}s"
  613. )
  614. import tempfile
  615. import wave
  616. pcm_data = b"".join(self._frames)
  617. tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, prefix="live_emb_")
  618. tmp_path = Path(tmp.name)
  619. tmp.close()
  620. try:
  621. with wave.open(str(tmp_path), "wb") as wf:
  622. wf.setnchannels(1)
  623. wf.setsampwidth(2)
  624. wf.setframerate(self.sample_rate)
  625. wf.writeframes(pcm_data)
  626. return self.registry.extract(tmp_path)
  627. finally:
  628. if tmp_path.exists():
  629. try:
  630. tmp_path.unlink()
  631. except OSError:
  632. pass
  633. # ── Standalone CLI ─────────────────────────────────────────────────────────────
  634. if __name__ == "__main__":
  635. import argparse
  636. import sys
  637. logging.basicConfig(level=logging.INFO, format="%(message)s")
  638. parser = argparse.ArgumentParser(
  639. description="Speaker voiceprint utility",
  640. formatter_class=argparse.RawDescriptionHelpFormatter,
  641. epilog="""
  642. Examples:
  643. # Enrol from a specific time range
  644. python embeddings.py enrol SPEAKER_00 recording.mp3 --start 45 --end 55
  645. # Auto-enrol using the best segment from the transcript log
  646. python embeddings.py auto-enrol SPEAKER_00 recording.mp3
  647. # Show best candidate segments for a speaker (before enrolling)
  648. python embeddings.py candidates SPEAKER_00
  649. # List all enrolled speakers
  650. python embeddings.py list
  651. # Test match against a clip
  652. python embeddings.py match recording.mp3 --start 120 --end 130
  653. # Show similarity scores for a clip
  654. python embeddings.py scores recording.mp3 --start 45 --end 55
  655. # Delete a voiceprint
  656. python embeddings.py delete SPEAKER_00
  657. """,
  658. )
  659. sub = parser.add_subparsers(dest="cmd")
  660. p_enrol = sub.add_parser("enrol", help="Extract and save voiceprint from time range")
  661. p_enrol.add_argument("speaker_id")
  662. p_enrol.add_argument("audio_file")
  663. p_enrol.add_argument("--start", type=float, default=0.0)
  664. p_enrol.add_argument("--end", type=float, default=None)
  665. p_auto = sub.add_parser("auto-enrol", help="Enrol using best segment from transcript log")
  666. p_auto.add_argument("speaker_id")
  667. p_auto.add_argument("audio_file")
  668. p_auto.add_argument("--session", default=None)
  669. p_auto.add_argument("--min-dur", type=float, default=8.0)
  670. p_cand = sub.add_parser("candidates", help="Show best enrolment segments from transcript log")
  671. p_cand.add_argument("speaker_id")
  672. p_cand.add_argument("--session", default=None)
  673. p_cand.add_argument("--min-dur", type=float, default=8.0)
  674. sub.add_parser("list", help="List enrolled speakers")
  675. p_match = sub.add_parser("match", help="Find matching speaker for audio clip")
  676. p_match.add_argument("audio_file")
  677. p_match.add_argument("--start", type=float, default=0.0)
  678. p_match.add_argument("--end", type=float, default=None)
  679. p_match.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
  680. p_scores = sub.add_parser("scores", help="Similarity against all enrolled speakers")
  681. p_scores.add_argument("audio_file")
  682. p_scores.add_argument("--start", type=float, default=0.0)
  683. p_scores.add_argument("--end", type=float, default=None)
  684. p_del = sub.add_parser("delete", help="Delete a stored voiceprint")
  685. p_del.add_argument("speaker_id")
  686. args = parser.parse_args()
  687. registry = EmbeddingRegistry()
  688. if not args.cmd:
  689. parser.print_help()
  690. sys.exit(0)
  691. elif args.cmd == "enrol":
  692. emb = registry.extract_and_save(
  693. args.speaker_id, args.audio_file, args.start, args.end
  694. )
  695. print(f"✓ Enrolled {args.speaker_id} — shape: {emb.shape}")
  696. elif args.cmd == "auto-enrol":
  697. emb = registry.enrol_from_transcript(
  698. args.speaker_id, args.audio_file,
  699. session_id=args.session, min_duration=args.min_dur,
  700. )
  701. print(f"✓ Auto-enrolled {args.speaker_id} — shape: {emb.shape}")
  702. elif args.cmd == "candidates":
  703. cands = get_best_enrolment_segments(
  704. args.speaker_id, session_id=args.session, min_duration=args.min_dur
  705. )
  706. if not cands:
  707. print(f"No segments found for {args.speaker_id} (min {args.min_dur}s)")
  708. else:
  709. print(f"\nBest enrolment candidates for {args.speaker_id}:\n")
  710. print(f" {'Start':>7} {'End':>7} {'Dur':>6} {'Rating':<12} {'Score':>6} Text")
  711. print(" " + "-" * 75)
  712. for c in cands:
  713. preview = c["text"][:45] + "..." if len(c["text"]) > 45 else c["text"]
  714. print(
  715. f" {c['start']:>7.1f} {c['end']:>7.1f} "
  716. f"{c['duration']:>5.1f}s {c['recommendation']:<12} "
  717. f"{c['score']:>6.1f} {preview}"
  718. )
  719. elif args.cmd == "list":
  720. enrolled = registry.list_enrolled()
  721. if not enrolled:
  722. print("No speakers enrolled yet.")
  723. else:
  724. print(f"\n{'ID':<15} {'Name':<25} {'Updated':<30} {'Size'}")
  725. print("-" * 75)
  726. for e in enrolled:
  727. print(
  728. f"{e['id']:<15} {e['name']:<25} "
  729. f"{(e['updated'] or 'unknown'):<30} {e['size_kb']} KB"
  730. )
  731. elif args.cmd == "match":
  732. emb = registry.extract(args.audio_file, args.start, args.end)
  733. match = registry.find_match(emb, threshold=args.threshold)
  734. if match:
  735. sid, score = match
  736. data = _load_speakers_json()
  737. entry = _normalise_speaker_entry(data.get(sid, sid))
  738. print(f"✓ Match: {sid} ({entry['name']}) — similarity {score:.4f}")
  739. else:
  740. print(f"✗ No match above threshold {args.threshold:.2f}")
  741. elif args.cmd == "scores":
  742. emb = registry.extract(args.audio_file, args.start, args.end)
  743. scores = registry.similarity_scores(emb)
  744. if not scores:
  745. print("No enrolled speakers to compare against.")
  746. else:
  747. print(f"\n{'ID':<15} {'Name':<25} {'Similarity':<12} {'Match?'}")
  748. print("-" * 60)
  749. for s in scores:
  750. flag = "✓" if s["match"] else " "
  751. print(
  752. f"{s['id']:<15} {s['name']:<25} "
  753. f"{s['similarity']:<12.4f} {flag}"
  754. )
  755. elif args.cmd == "delete":
  756. if registry.delete(args.speaker_id):
  757. print(f"✓ Deleted voiceprint for {args.speaker_id}")
  758. else:
  759. print(f"No voiceprint found for {args.speaker_id}")