| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916 |
- #!/usr/bin/env python3
- """
- embeddings.py — Speaker Voiceprint Management
- Handles extraction, storage, and matching of speaker voice embeddings
- using pyannote.audio's SpeakerEmbedding pipeline.
- Embeddings are stored as numpy .npy files in bridge/embeddings/.
- speakers.json is updated with has_embedding / embedding_updated metadata.
- Transcript segments are logged to bridge/transcript_segments.jsonl —
- a line-delimited JSON log of every finalised speech segment with speaker ID,
- timestamps, and text. This log is used for automatic best-segment enrolment.
- Public API
- ----------
- registry = EmbeddingRegistry()
- # Extract from an audio file segment
- embedding = registry.extract(audio_path, start_sec=45.0, end_sec=55.0)
- # Save a voiceprint for a known speaker
- registry.save(speaker_id="SPEAKER_00", embedding=embedding)
- # Auto-enrol from the transcript log (picks the best-scoring segment(s))
- registry.enrol_from_transcript("SPEAKER_00", audio_path)  # optional: session_id="2024-05-01"
- # Find the best matching speaker for a live embedding
- match = registry.find_match(live_embedding, threshold=0.82)
- # -> ("SPEAKER_00", 0.91) or None
- # Load a stored embedding
- emb = registry.load("SPEAKER_00") # -> np.ndarray | None
- # Delete a voiceprint
- registry.delete("SPEAKER_00")
- # List all speakers with voiceprints
- registry.list_enrolled()
- # Accumulate live audio and extract when ready
- acc = registry.make_accumulator()
- acc.push(pcm_bytes)
- if acc.ready():
- emb = acc.extract_embedding()
- match = registry.find_match(emb)
- """
- from __future__ import annotations
- import json
- import logging
- import os
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Optional
- import numpy as np
- logger = logging.getLogger(__name__)
- # ── Paths ──────────────────────────────────────────────────────────────────────
- _HERE = Path(__file__).parent  # directory that contains this module
- EMBEDDINGS_DIR = _HERE / "embeddings"  # default folder for per-speaker .npy voiceprints
- SPEAKERS_FILE = _HERE / "speakers.json"  # speaker metadata (name, has_embedding, embedding_updated, ...)
- TRANSCRIPT_SEGMENTS_LOG = _HERE / "transcript_segments.jsonl"  # append-only JSONL log of finalised segments
- EMBEDDING_DIM = 512  # NOTE(review): not referenced in the code visible here; presumably the model's embedding size — confirm
- DEFAULT_THRESHOLD = 0.82  # default minimum cosine similarity for find_match() to report a match
- MIN_AUDIO_SEC = 5.0  # segments shorter than this trigger a reliability warning in extract_audio_segment()
- # ── Lazy model loader ──────────────────────────────────────────────────────────
_pipeline = None


def _get_pipeline():
    """
    Return the shared pyannote speaker-embedding ``Inference`` object.

    The model is loaded on the first call and cached in the module-level
    ``_pipeline`` variable; later calls return the cached object.

    Authentication uses the ``HF_TOKEN`` environment variable when it is set.
    When it is not set, ``use_auth_token=True`` is passed instead so that the
    token cached by ``huggingface-cli login`` can be used, which is what the
    error messages of this module advertise.

    Raises:
        RuntimeError: if pyannote.audio is not installed, or if the model
            cannot be loaded (for example because no valid token is found).
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    try:
        from pyannote.audio import Inference, Model
    except ImportError as exc:
        raise RuntimeError(
            "pyannote.audio is required for voiceprint extraction. "
            "Install with: pip install pyannote.audio"
        ) from exc

    # An explicit token wins; True requests the locally cached login token.
    auth_token = os.environ.get("HF_TOKEN") or True

    logger.info("[Embeddings] Loading pyannote speaker embedding model...")
    try:
        model = Model.from_pretrained("pyannote/embedding", use_auth_token=auth_token)
        pipeline = Inference(model, window="whole")
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load pyannote/embedding model: {exc}\n"
            "Set HF_TOKEN (e.g. in start.bat) or run: huggingface-cli login, "
            "and ensure you have accepted model conditions at "
            "https://huggingface.co/pyannote/embedding"
        ) from exc

    # Only publish the pipeline once it is fully constructed.
    _pipeline = pipeline
    logger.info("[Embeddings] Speaker embedding model loaded.")
    return _pipeline
- # ── Helpers ────────────────────────────────────────────────────────────────────
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """
    Return the cosine similarity of two vectors as a Python float.

    Both inputs are flattened and converted to float32, so any array-like
    (ndarray of any shape, list, tuple) with the same number of elements is
    accepted. If either vector has zero norm the similarity is defined
    as 0.0. The result is clamped to [-1.0, 1.0] because float32 rounding can
    push the ratio marginally outside that range.

    Raises:
        ValueError: (from numpy) if the two inputs differ in element count.
    """
    vec_a = np.asarray(a, dtype=np.float32).ravel()
    vec_b = np.asarray(b, dtype=np.float32).ravel()
    norm_a = np.linalg.norm(vec_a)
    norm_b = np.linalg.norm(vec_b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    similarity = float(np.dot(vec_a, vec_b) / (norm_a * norm_b))
    return max(-1.0, min(1.0, similarity))
def _load_speakers_json() -> dict:
    """
    Read speakers.json (``SPEAKERS_FILE``) and return its top-level object.

    Returns an empty dict when the file does not exist, cannot be read,
    is not valid UTF-8 JSON, or does not contain a JSON object. Every case
    other than "file missing" is logged, because callers write the returned
    dict back and a silently-empty result would wipe existing entries
    without any trace.
    """
    try:
        raw = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return {}
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(
            f"[Embeddings] Could not read speakers.json ({exc}); treating it as empty"
        )
        return {}
    if not isinstance(raw, dict):
        logger.warning(
            "[Embeddings] speakers.json does not contain a JSON object; treating it as empty"
        )
        return {}
    return raw
def _save_speakers_json(data: dict) -> None:
    """
    Write ``data`` to speakers.json (``SPEAKERS_FILE``) as indented UTF-8 JSON.

    The JSON is first written to a sibling ``<name>.tmp`` file and then moved
    over the target with ``os.replace``, so a failure part-way through the
    write leaves the previous speakers.json intact instead of truncated.
    I/O errors are logged and swallowed (best-effort save); serialisation
    errors from ``json.dumps`` propagate before anything is written.
    """
    payload = json.dumps(data, indent=2, ensure_ascii=False)
    tmp_path = SPEAKERS_FILE.with_name(SPEAKERS_FILE.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        os.replace(tmp_path, SPEAKERS_FILE)
    except OSError as exc:
        logger.error(f"[Embeddings] Could not save speakers.json: {exc}")
        # Do not leave a stray temp file behind after a failed save.
        try:
            tmp_path.unlink()
        except OSError:
            pass
def _normalise_speaker_entry(entry) -> dict:
    """
    Return ``entry`` in the rich dict format used by speakers.json.

    * A dict is copied and any missing standard key is filled with its
      default; existing keys (including unknown extras) are preserved.
      The caller's dict is NOT modified.
    * A string is treated as a legacy "flat" entry holding just the name.
    * Any other value is converted with ``str()`` and used as the name.
    """
    defaults = {
        "name": "",
        "has_embedding": False,
        "embedding_updated": None,
        "colour": None,
        "notes": "",
    }
    if isinstance(entry, dict):
        normalised = dict(entry)  # copy so the caller's data is left untouched
        for key, value in defaults.items():
            normalised.setdefault(key, value)
        return normalised
    name = entry if isinstance(entry, str) else str(entry)
    return {**defaults, "name": name}
def _update_speaker_meta(speaker_id: str, has_embedding: bool) -> None:
    """
    Record in speakers.json whether ``speaker_id`` currently has a voiceprint.

    The speaker's entry is loaded (or created from the bare ID), normalised
    to the rich dict format, stamped with the current UTC time when a
    voiceprint exists (``None`` otherwise), and the whole file is saved back.
    """
    speakers = _load_speakers_json()
    record = _normalise_speaker_entry(speakers.get(speaker_id, speaker_id))

    if has_embedding:
        stamp = datetime.now(timezone.utc).isoformat()
    else:
        stamp = None

    record["has_embedding"] = has_embedding
    record["embedding_updated"] = stamp
    speakers[speaker_id] = record
    _save_speakers_json(speakers)
- # ── Transcript segment log ─────────────────────────────────────────────────────
def log_transcript_segment(
    speaker_id: str,
    text: str,
    start_sec: float,
    end_sec: float,
    session_id: Optional[str] = None,
) -> None:
    """
    Append a finalised transcript segment to transcript_segments.jsonl.

    The record written is one JSON object per line with the keys
    ``ts``, ``session``, ``speaker``, ``start``, ``end``, ``duration`` and
    ``text``. get_best_enrolment_segments() later reads this log to find
    clean segments for voiceprint extraction.

    Args:
        speaker_id: Diarization label, e.g. "SPEAKER_00"
        text: Transcribed text for this segment (``None`` is stored as "")
        start_sec: Segment start time in seconds (from audio start)
        end_sec: Segment end time in seconds
        session_id: Optional session identifier; defaults to today's UTC
            date in ``YYYY-MM-DD`` form

    Write failures are logged as warnings and otherwise ignored.
    """
    # Read the clock once so "ts" and the default session always agree,
    # even when the call straddles midnight UTC.
    now = datetime.now(timezone.utc)
    record = {
        "ts": now.isoformat(),
        "session": session_id or now.strftime("%Y-%m-%d"),
        "speaker": speaker_id,
        "start": round(start_sec, 3),
        "end": round(end_sec, 3),
        "duration": round(end_sec - start_sec, 3),
        "text": (text or "").strip(),
    }
    try:
        with TRANSCRIPT_SEGMENTS_LOG.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    except OSError as exc:
        logger.warning(f"[Embeddings] Could not write segment log: {exc}")
def load_transcript_segments(
    session_id: Optional[str] = None,
    speaker_id: Optional[str] = None,
) -> list[dict]:
    """
    Load transcript segments from the JSONL log file.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped, so one damaged line never aborts the whole load.

    Args:
        session_id: If set, only return segments from this session
        speaker_id: If set, only return segments for this speaker

    Returns:
        List of segment dicts sorted by their "start" value; records whose
        "start" is missing or non-numeric sort as 0. Returns an empty list
        when the log file does not exist.
    """
    segments: list[dict] = []
    try:
        with TRANSCRIPT_SEGMENTS_LOG.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(rec, dict):
                    continue  # e.g. a bare list/number/string line
                if session_id and rec.get("session") != session_id:
                    continue
                if speaker_id and rec.get("speaker") != speaker_id:
                    continue
                segments.append(rec)
    except FileNotFoundError:
        return []
    except (OSError, UnicodeDecodeError) as exc:
        # Keep whatever was read before the failure.
        logger.warning(f"[Embeddings] Could not read segment log: {exc}")

    def start_of(rec: dict) -> float:
        start = rec.get("start", 0)
        return start if isinstance(start, (int, float)) else 0

    return sorted(segments, key=start_of)
def _has_numeric_bounds(seg: dict) -> bool:
    """Return True when ``seg`` carries numeric "start" and "end" values."""
    return isinstance(seg.get("start"), (int, float)) and isinstance(
        seg.get("end"), (int, float)
    )


def get_best_enrolment_segments(
    speaker_id: str,
    session_id: Optional[str] = None,
    min_duration: float = 8.0,
    top_n: int = 5,
) -> list[dict]:
    """
    Find the best candidate segments for enrolment from the transcript log.

    Ranks segments by:
    1. Duration (longer = more reliable embedding)
    2. Isolation (gap before/after to avoid speaker overlap bleed)
    3. Text length (more words = more acoustic variety = better embedding)

    Isolation is measured only against segments of the SAME session:
    start/end times are relative to each recording, so comparing boundaries
    across sessions would produce meaningless gaps. Log records without
    numeric "start"/"end" values are ignored.

    Args:
        speaker_id: Speaker to find segments for
        session_id: Limit to a specific session (None = all sessions)
        min_duration: Minimum segment duration in seconds
        top_n: How many candidate segments to return

    Returns:
        List of segment dicts with added "score" and "recommendation" fields,
        sorted by score descending
    """
    segments = [
        s for s in load_transcript_segments(session_id=session_id)
        if _has_numeric_bounds(s)
    ]

    # Group by session so gaps are computed on a single timeline.
    by_session: dict = {}
    for seg in segments:
        by_session.setdefault(seg.get("session"), []).append(seg)

    scored = []
    for seg in segments:
        duration = seg.get("duration", 0)
        if not isinstance(duration, (int, float)):
            continue
        if seg.get("speaker") != speaker_id or duration < min_duration:
            continue

        start, end = seg["start"], seg["end"]
        neighbours = by_session[seg.get("session")]
        word_count = len((seg.get("text") or "").split())

        # Silence between the closest preceding segment end and our start;
        # with nothing before us, the gap is the time since audio start.
        gap_before = min(
            (start - other["end"] for other in neighbours if other["end"] <= start),
            default=start,
        )
        # Silence between our end and the closest following segment start.
        gap_after = min(
            (other["start"] - end for other in neighbours if other["start"] >= end),
            default=999.0,
        )
        isolation = min(gap_before, gap_after)

        # Score: weight duration heavily, reward isolation, add text richness.
        score = duration * 2.0 + isolation * 1.5 + word_count * 0.1

        # Human-readable recommendation
        if duration >= 15 and isolation >= 1.0:
            rating = "Excellent"
        elif duration >= 10 and isolation >= 0.5:
            rating = "Good"
        elif duration >= 8:
            rating = "Acceptable"
        else:
            rating = "Short"

        scored.append({**seg, "score": round(score, 2), "recommendation": rating})

    scored.sort(key=lambda item: item["score"], reverse=True)
    return scored[:top_n]
- # ── Audio extraction ───────────────────────────────────────────────────────────
def extract_audio_segment(
    audio_path: Path | str,
    start_sec: float,
    end_sec: float,
    out_path: Optional[Path] = None,
    sample_rate: int = 16000,
) -> Path:
    """
    Decode ``[start_sec, end_sec)`` of an audio file with miniaudio and save
    it as a mono, 16-bit PCM WAV file.

    Args:
        audio_path: Source audio file
        start_sec: Segment start in seconds; negative values are clamped to 0
        end_sec: Segment end in seconds
        out_path: Where to write the WAV. When None, a temporary
            ``embedding_*.wav`` file is created and the caller is responsible
            for deleting it.
        sample_rate: Output sample rate in Hz

    Returns:
        Path to the written WAV file.

    Raises:
        ValueError: if the segment is empty/inverted, or no audio could be
            decoded for it (e.g. it lies past the end of the file)
        RuntimeError: if miniaudio is not installed
    """
    import tempfile
    import wave

    start_sec = max(0.0, start_sec)
    duration = end_sec - start_sec
    if duration <= 0:
        # A non-positive frame count would otherwise yield an empty WAV or a
        # nonsensical negative slice of the first decoded chunk.
        raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")

    try:
        import miniaudio
    except ImportError as exc:
        raise RuntimeError("miniaudio is required: pip install miniaudio") from exc

    if duration < MIN_AUDIO_SEC:
        logger.warning(
            f"[Embeddings] Short segment ({duration:.1f}s < {MIN_AUDIO_SEC}s) "
            "— embedding may be less reliable"
        )

    bytes_per_frame = 2  # mono, signed 16-bit
    frames_to_skip = int(start_sec * sample_rate)
    frames_wanted = int(duration * sample_rate)

    stream = miniaudio.stream_file(
        str(audio_path),
        output_format=miniaudio.SampleFormat.SIGNED16,
        nchannels=1,
        sample_rate=sample_rate,
        frames_to_read=4096,
    )
    collected: list[bytes] = []
    try:
        for chunk in stream:
            data = bytes(chunk)
            n_frames = len(data) // bytes_per_frame

            if frames_to_skip:
                if n_frames <= frames_to_skip:
                    frames_to_skip -= n_frames
                    continue
                # The segment starts inside this chunk: drop its leading part.
                data = data[frames_to_skip * bytes_per_frame:]
                n_frames = len(data) // bytes_per_frame
                frames_to_skip = 0

            if n_frames >= frames_wanted:
                collected.append(data[: frames_wanted * bytes_per_frame])
                frames_wanted = 0
                break
            collected.append(data)
            frames_wanted -= n_frames
    finally:
        # Leaving the loop early would otherwise keep the decoder open
        # until the stream object happens to be garbage-collected.
        close_stream = getattr(stream, "close", None)
        if callable(close_stream):
            close_stream()

    pcm_data = b"".join(collected)
    if not pcm_data:
        raise ValueError(
            f"No audio extracted from {audio_path} at {start_sec}–{end_sec}s."
        )

    created_temp = out_path is None
    if created_temp:
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, prefix="embedding_")
        out_path = Path(tmp.name)
        tmp.close()

    try:
        with wave.open(str(out_path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(bytes_per_frame)
            wf.setframerate(sample_rate)
            wf.writeframes(pcm_data)
    except Exception:
        # Do not leak our own temp file when writing fails.
        if created_temp:
            try:
                out_path.unlink()
            except OSError:
                pass
        raise
    return out_path
- # ── Main registry class ────────────────────────────────────────────────────────
class EmbeddingRegistry:
    """
    File-backed store of speaker voiceprints.

    Each enrolled speaker has one ``<speaker_id>.npy`` file in
    ``embeddings_dir``. ``speakers_file`` is the JSON file that carries the
    per-speaker ``has_embedding`` / ``embedding_updated`` metadata; the path
    given to the constructor is the one this registry reads and writes.
    """

    def __init__(
        self,
        embeddings_dir: Path | str = EMBEDDINGS_DIR,
        speakers_file: Path | str = SPEAKERS_FILE,
        threshold: float = DEFAULT_THRESHOLD,
    ):
        """
        Args:
            embeddings_dir: Folder for the .npy voiceprints (created if missing)
            speakers_file: speakers.json metadata file
            threshold: Default minimum cosine similarity for a match
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.speakers_file = Path(speakers_file)
        self.threshold = threshold
        self.embeddings_dir.mkdir(parents=True, exist_ok=True)

    def _path(self, speaker_id: str) -> Path:
        """Return the .npy path used for ``speaker_id``."""
        return self.embeddings_dir / f"{speaker_id}.npy"

    # ── Speaker metadata (speakers.json) ───────────────────────────────────────

    def _read_speakers(self) -> dict:
        """Load this registry's speakers file; empty dict if missing/unusable."""
        try:
            raw = json.loads(self.speakers_file.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return {}
        except (ValueError, OSError) as exc:
            logger.warning(
                f"[Embeddings] Could not read {self.speakers_file.name} ({exc}); "
                "treating it as empty"
            )
            return {}
        return raw if isinstance(raw, dict) else {}

    def _write_speaker_meta(self, speaker_id: str, has_embedding: bool) -> None:
        """Update has_embedding / embedding_updated for one speaker and save."""
        data = self._read_speakers()
        entry = _normalise_speaker_entry(data.get(speaker_id, speaker_id))
        entry["has_embedding"] = has_embedding
        entry["embedding_updated"] = (
            datetime.now(timezone.utc).isoformat() if has_embedding else None
        )
        data[speaker_id] = entry
        try:
            self.speakers_file.write_text(
                json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8"
            )
        except OSError as exc:
            logger.error(f"[Embeddings] Could not save speakers.json: {exc}")

    # ── Storage ────────────────────────────────────────────────────────────────

    def save(self, speaker_id: str, embedding: np.ndarray) -> None:
        """Store ``embedding`` (flattened, float32) as the speaker's voiceprint."""
        emb = np.asarray(embedding, dtype=np.float32).ravel()
        np.save(str(self._path(speaker_id)), emb)
        self._write_speaker_meta(speaker_id, has_embedding=True)
        logger.info(f"[Embeddings] Saved voiceprint for {speaker_id}")

    def load(self, speaker_id: str) -> Optional[np.ndarray]:
        """Return the stored voiceprint as float32, or None if absent/unreadable."""
        p = self._path(speaker_id)
        try:
            return np.load(str(p)).astype(np.float32)
        except FileNotFoundError:
            return None
        except Exception as exc:
            # A damaged file must not break callers that iterate all speakers.
            logger.error(f"[Embeddings] Failed to load {p}: {exc}")
            return None

    def has(self, speaker_id: str) -> bool:
        """Return True when a voiceprint file exists for ``speaker_id``."""
        return self._path(speaker_id).exists()

    def delete(self, speaker_id: str) -> bool:
        """Delete the speaker's voiceprint; return False if there was none."""
        try:
            self._path(speaker_id).unlink()
        except FileNotFoundError:
            return False
        self._write_speaker_meta(speaker_id, has_embedding=False)
        logger.info(f"[Embeddings] Deleted voiceprint for {speaker_id}")
        return True

    def list_enrolled(self) -> list[dict]:
        """Return id/name/updated/size_kb for every stored voiceprint, by id."""
        data = self._read_speakers()
        enrolled = []
        for p in sorted(self.embeddings_dir.glob("*.npy")):
            sid = p.stem
            entry = _normalise_speaker_entry(data.get(sid, sid))
            enrolled.append({
                "id": sid,
                "name": entry["name"],
                "updated": entry.get("embedding_updated"),
                "size_kb": round(p.stat().st_size / 1024, 1),
            })
        return enrolled

    # ── Extraction ─────────────────────────────────────────────────────────────

    def extract(
        self,
        audio_path: Path | str,
        start_sec: float = 0.0,
        end_sec: Optional[float] = None,
    ) -> np.ndarray:
        """
        Extract a speaker embedding from an audio file segment.

        When ``end_sec`` is None the file duration reported by miniaudio is
        used; if that cannot be determined, a 30 s window from ``start_sec``
        is used instead (and a warning is logged).

        Returns:
            1-D float32 embedding.

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist
            ValueError: if the segment is empty or inverted
            RuntimeError: if miniaudio / pyannote are unavailable
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        if end_sec is None:
            try:
                import miniaudio
                end_sec = miniaudio.get_file_info(str(audio_path)).duration
            except Exception as exc:
                end_sec = start_sec + 30.0
                logger.warning(
                    f"[Embeddings] Could not read duration of {audio_path.name} "
                    f"({exc}); using a 30s window"
                )
        if end_sec <= start_sec:
            raise ValueError(f"Invalid segment: start={start_sec}, end={end_sec}")

        tmp_wav = None
        try:
            tmp_wav = extract_audio_segment(audio_path, start_sec, end_sec)
            result = _get_pipeline()(str(tmp_wav))
            # Check ndarray first: ndarrays also expose a ``.data`` attribute,
            # so testing for ``.data`` first would shadow this branch.
            if isinstance(result, np.ndarray):
                arr = result.flatten()
            elif hasattr(result, "data"):
                arr = np.array(result.data).flatten()
            else:
                arr = np.array(result).flatten()
            logger.info(
                f"[Embeddings] Extracted from {audio_path.name} "
                f"[{start_sec:.1f}s–{end_sec:.1f}s] shape={arr.shape}"
            )
            return arr.astype(np.float32)
        finally:
            if tmp_wav and tmp_wav.exists():
                try:
                    tmp_wav.unlink()
                except OSError:
                    pass

    def extract_and_save(
        self,
        speaker_id: str,
        audio_path: Path | str,
        start_sec: float = 0.0,
        end_sec: Optional[float] = None,
    ) -> np.ndarray:
        """Extract embedding from segment and save it immediately."""
        embedding = self.extract(audio_path, start_sec, end_sec)
        self.save(speaker_id, embedding)
        return embedding

    def enrol_from_transcript(
        self,
        speaker_id: str,
        audio_path: Path | str,
        session_id: Optional[str] = None,
        min_duration: float = 8.0,
    ) -> np.ndarray:
        """
        Automatically find the best segment(s) in the transcript log and use
        them to enrol the speaker.

        The top three candidates from get_best_enrolment_segments() are
        considered. If at least two of them are rated "Excellent" or "Good",
        the embeddings of all such candidates are averaged and re-normalised;
        otherwise the single highest-scoring candidate is used.

        Args:
            speaker_id: Speaker to enrol
            audio_path: Source audio file (should be the recording that was
                used to generate the transcript log)
            session_id: Limit to segments from a specific session
            min_duration: Minimum segment duration to consider

        Returns:
            The saved embedding array

        Raises:
            ValueError: if no suitable segments are found in the log, or
                every candidate segment fails extraction
        """
        candidates = get_best_enrolment_segments(
            speaker_id=speaker_id,
            session_id=session_id,
            min_duration=min_duration,
            top_n=3,
        )
        if not candidates:
            raise ValueError(
                f"No transcript segments found for {speaker_id} "
                f"with duration >= {min_duration}s. "
                "Run a test recording first to generate the segment log."
            )

        top = candidates[0]
        good_candidates = [
            c for c in candidates
            if c["recommendation"] in ("Excellent", "Good")
        ]

        if len(good_candidates) < 2:
            logger.info(
                f"[Embeddings] Using single best segment for {speaker_id}: "
                f"[{top['start']:.1f}s–{top['end']:.1f}s] "
                f"({top['recommendation']}, {top['duration']:.1f}s)"
            )
            return self.extract_and_save(
                speaker_id, audio_path, top["start"], top["end"]
            )

        logger.info(
            f"[Embeddings] Averaging {len(good_candidates)} segments "
            f"for {speaker_id} (more robust voiceprint)"
        )
        embeddings = []
        for c in good_candidates:
            try:
                emb = self.extract(audio_path, c["start"], c["end"])
            except Exception as exc:
                # One bad segment should not prevent enrolment from the rest.
                logger.warning(f"[Embeddings] Skipping segment: {exc}")
                continue
            embeddings.append(emb)
            logger.info(
                f"[Embeddings] Used segment [{c['start']:.1f}s–{c['end']:.1f}s] "
                f"'{c['text'][:50]}...'"
            )
        if not embeddings:
            raise ValueError("All candidate segments failed extraction")

        # Average and re-normalise
        avg = np.mean(np.stack(embeddings), axis=0)
        norm = np.linalg.norm(avg)
        if norm > 0:
            avg = avg / norm
        self.save(speaker_id, avg)
        return avg

    # ── Matching ───────────────────────────────────────────────────────────────

    def _iter_similarities(self, live_embedding: np.ndarray, exclude_ids=frozenset()):
        """
        Yield ``(speaker_id, similarity)`` for every usable stored voiceprint.

        Voiceprints that cannot be loaded, or whose size is incompatible with
        ``live_embedding``, are skipped so one stale file cannot break
        matching for everybody else.
        """
        for p in sorted(self.embeddings_dir.glob("*.npy")):
            sid = p.stem
            if sid in exclude_ids:
                continue
            stored = self.load(sid)
            if stored is None:
                continue
            try:
                score = _cosine_similarity(live_embedding, stored)
            except ValueError as exc:
                logger.warning(
                    f"[Embeddings] Skipping {sid}: incompatible voiceprint ({exc})"
                )
                continue
            yield sid, score

    def find_match(
        self,
        live_embedding: np.ndarray,
        threshold: Optional[float] = None,
        exclude_ids: Optional[set[str]] = None,
    ) -> Optional[tuple[str, float]]:
        """
        Find the best matching enrolled speaker for a live embedding.

        Args:
            live_embedding: Embedding to identify
            threshold: Minimum similarity; defaults to the registry threshold
            exclude_ids: Speaker IDs to leave out of the comparison

        Returns:
            (speaker_id, similarity_score) or None.
        """
        if threshold is None:
            threshold = self.threshold
        best_id = None
        best_score = -1.0
        for sid, score in self._iter_similarities(live_embedding, exclude_ids or set()):
            logger.debug(f"[Embeddings] {sid}: similarity={score:.4f}")
            if score > best_score:
                best_score = score
                best_id = sid
        if best_id is not None and best_score >= threshold:
            logger.info(f"[Embeddings] Match: {best_id} (score={best_score:.4f})")
            return (best_id, best_score)
        logger.debug(
            f"[Embeddings] No match above threshold={threshold:.2f} "
            f"(best={best_score:.4f})"
        )
        return None

    def similarity_scores(self, live_embedding: np.ndarray) -> list[dict]:
        """Return similarity scores against all enrolled speakers, sorted descending."""
        data = self._read_speakers()
        results = []
        for sid, score in self._iter_similarities(live_embedding):
            entry = _normalise_speaker_entry(data.get(sid, sid))
            results.append({
                "id": sid,
                "name": entry["name"],
                "similarity": round(score, 4),
                "match": score >= self.threshold,
            })
        return sorted(results, key=lambda x: x["similarity"], reverse=True)

    # ── Live accumulator ───────────────────────────────────────────────────────

    def make_accumulator(
        self,
        min_seconds: float = 5.0,
        sample_rate: int = 16000,
    ) -> "LiveEmbeddingAccumulator":
        """Create a LiveEmbeddingAccumulator bound to this registry."""
        return LiveEmbeddingAccumulator(self, min_seconds, sample_rate)
- # ── Live audio accumulator ─────────────────────────────────────────────────────
class LiveEmbeddingAccumulator:
    """
    Buffers raw int16 mono PCM chunks from the live microphone stream.
    Once min_seconds of audio is accumulated, extracts an embedding
    for speaker matching or enrolment.

    Pushed chunks are snapshotted as immutable ``bytes``, so callers may
    reuse their capture buffer. Frames are counted from the total number of
    bytes received, so odd-sized chunks do not make the count drift.
    """

    def __init__(
        self,
        registry: "EmbeddingRegistry",
        min_seconds: float = 5.0,
        sample_rate: int = 16000,
    ):
        """
        Args:
            registry: Object whose ``extract(path)`` turns a WAV into an embedding
            min_seconds: Audio required before ready() becomes True
            sample_rate: Sample rate of the pushed PCM, in Hz
        """
        self.registry = registry
        self.min_frames = int(min_seconds * sample_rate)
        self.sample_rate = sample_rate
        self._frames: list[bytes] = []
        self._n_bytes = 0
        self._n_frames = 0

    def push(self, pcm_bytes: bytes) -> None:
        """Append one chunk of 16-bit PCM (any bytes-like object)."""
        chunk = bytes(pcm_bytes)  # copies mutable buffers; no-op for bytes
        if not chunk:
            return
        self._frames.append(chunk)
        self._n_bytes += len(chunk)
        self._n_frames = self._n_bytes // 2  # 2 bytes per int16 sample

    def ready(self) -> bool:
        """Return True once at least min_seconds of audio is buffered."""
        return self._n_frames >= self.min_frames

    def seconds_accumulated(self) -> float:
        """Return the buffered audio duration in seconds."""
        return self._n_frames / self.sample_rate

    def reset(self) -> None:
        """Discard all buffered audio."""
        self._frames = []
        self._n_bytes = 0
        self._n_frames = 0

    def extract_embedding(self) -> np.ndarray:
        """
        Write the buffered audio to a temporary WAV and return
        ``registry.extract()`` of that file. The buffer is left untouched;
        call reset() to start over.

        Raises:
            RuntimeError: if less than min_seconds of audio is buffered
        """
        if not self.ready():
            raise RuntimeError(
                f"Not enough audio: {self.seconds_accumulated():.1f}s accumulated, "
                f"need {self.min_frames / self.sample_rate:.1f}s"
            )
        import tempfile
        import wave

        # Keep whole samples only; a dangling half-sample would corrupt the WAV.
        pcm_data = b"".join(self._frames)[: self._n_frames * 2]

        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, prefix="live_emb_")
        tmp_path = Path(tmp.name)
        tmp.close()
        try:
            with wave.open(str(tmp_path), "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(self.sample_rate)
                wf.writeframes(pcm_data)
            return self.registry.extract(tmp_path)
        finally:
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
- # ── Standalone CLI ─────────────────────────────────────────────────────────────
def _build_cli_parser():
    """Build the argparse parser for the standalone voiceprint utility."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Speaker voiceprint utility",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Enrol from a specific time range
  python embeddings.py enrol SPEAKER_00 recording.mp3 --start 45 --end 55
  # Auto-enrol using the best segment from the transcript log
  python embeddings.py auto-enrol SPEAKER_00 recording.mp3
  # Show best candidate segments for a speaker (before enrolling)
  python embeddings.py candidates SPEAKER_00
  # List all enrolled speakers
  python embeddings.py list
  # Test match against a clip
  python embeddings.py match recording.mp3 --start 120 --end 130
  # Show similarity scores for a clip
  python embeddings.py scores recording.mp3 --start 45 --end 55
  # Delete a voiceprint
  python embeddings.py delete SPEAKER_00
""",
    )
    sub = parser.add_subparsers(dest="cmd")

    p_enrol = sub.add_parser("enrol", help="Extract and save voiceprint from time range")
    p_enrol.add_argument("speaker_id")
    p_enrol.add_argument("audio_file")
    p_enrol.add_argument("--start", type=float, default=0.0)
    p_enrol.add_argument("--end", type=float, default=None)

    p_auto = sub.add_parser("auto-enrol", help="Enrol using best segment from transcript log")
    p_auto.add_argument("speaker_id")
    p_auto.add_argument("audio_file")
    p_auto.add_argument("--session", default=None)
    p_auto.add_argument("--min-dur", type=float, default=8.0)

    p_cand = sub.add_parser("candidates", help="Show best enrolment segments from transcript log")
    p_cand.add_argument("speaker_id")
    p_cand.add_argument("--session", default=None)
    p_cand.add_argument("--min-dur", type=float, default=8.0)

    sub.add_parser("list", help="List enrolled speakers")

    p_match = sub.add_parser("match", help="Find matching speaker for audio clip")
    p_match.add_argument("audio_file")
    p_match.add_argument("--start", type=float, default=0.0)
    p_match.add_argument("--end", type=float, default=None)
    p_match.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)

    p_scores = sub.add_parser("scores", help="Similarity against all enrolled speakers")
    p_scores.add_argument("audio_file")
    p_scores.add_argument("--start", type=float, default=0.0)
    p_scores.add_argument("--end", type=float, default=None)

    p_del = sub.add_parser("delete", help="Delete a stored voiceprint")
    p_del.add_argument("speaker_id")
    return parser


def _run_cli_command(args, registry) -> None:
    """Execute the sub-command selected in ``args`` against ``registry``."""
    if args.cmd == "enrol":
        emb = registry.extract_and_save(
            args.speaker_id, args.audio_file, args.start, args.end
        )
        print(f"✓ Enrolled {args.speaker_id} — shape: {emb.shape}")

    elif args.cmd == "auto-enrol":
        emb = registry.enrol_from_transcript(
            args.speaker_id, args.audio_file,
            session_id=args.session, min_duration=args.min_dur,
        )
        print(f"✓ Auto-enrolled {args.speaker_id} — shape: {emb.shape}")

    elif args.cmd == "candidates":
        cands = get_best_enrolment_segments(
            args.speaker_id, session_id=args.session, min_duration=args.min_dur
        )
        if not cands:
            print(f"No segments found for {args.speaker_id} (min {args.min_dur}s)")
            return
        print(f"\nBest enrolment candidates for {args.speaker_id}:\n")
        print(f" {'Start':>7} {'End':>7} {'Dur':>6} {'Rating':<12} {'Score':>6} Text")
        print(" " + "-" * 75)
        for c in cands:
            preview = c["text"][:45] + "..." if len(c["text"]) > 45 else c["text"]
            print(
                f" {c['start']:>7.1f} {c['end']:>7.1f} "
                f"{c['duration']:>5.1f}s {c['recommendation']:<12} "
                f"{c['score']:>6.1f} {preview}"
            )

    elif args.cmd == "list":
        enrolled = registry.list_enrolled()
        if not enrolled:
            print("No speakers enrolled yet.")
            return
        print(f"\n{'ID':<15} {'Name':<25} {'Updated':<30} {'Size'}")
        print("-" * 75)
        for e in enrolled:
            print(
                f"{e['id']:<15} {e['name']:<25} "
                f"{(e['updated'] or 'unknown'):<30} {e['size_kb']} KB"
            )

    elif args.cmd == "match":
        emb = registry.extract(args.audio_file, args.start, args.end)
        match = registry.find_match(emb, threshold=args.threshold)
        if match:
            sid, score = match
            data = _load_speakers_json()
            entry = _normalise_speaker_entry(data.get(sid, sid))
            print(f"✓ Match: {sid} ({entry['name']}) — similarity {score:.4f}")
        else:
            print(f"✗ No match above threshold {args.threshold:.2f}")

    elif args.cmd == "scores":
        emb = registry.extract(args.audio_file, args.start, args.end)
        scores = registry.similarity_scores(emb)
        if not scores:
            print("No enrolled speakers to compare against.")
            return
        print(f"\n{'ID':<15} {'Name':<25} {'Similarity':<12} {'Match?'}")
        print("-" * 60)
        for s in scores:
            flag = "✓" if s["match"] else " "
            print(
                f"{s['id']:<15} {s['name']:<25} "
                f"{s['similarity']:<12.4f} {flag}"
            )

    elif args.cmd == "delete":
        if registry.delete(args.speaker_id):
            print(f"✓ Deleted voiceprint for {args.speaker_id}")
        else:
            print(f"No voiceprint found for {args.speaker_id}")


def main(argv=None) -> int:
    """
    Entry point of the standalone CLI.

    Args:
        argv: Argument list (without the program name); None = sys.argv[1:]

    Returns:
        Process exit code: 0 on success (including "no match" / "nothing to
        delete" results and the bare help screen), 1 when the command fails
        with a missing file, invalid segment, or missing model/dependency.
    """
    import sys

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    parser = _build_cli_parser()
    args = parser.parse_args(argv)
    if not args.cmd:
        parser.print_help()
        return 0
    try:
        # Created only now so that printing help has no filesystem side effects.
        _run_cli_command(args, EmbeddingRegistry())
    except (FileNotFoundError, ValueError, RuntimeError) as exc:
        # Expected user-facing failures: report them without a traceback.
        print(f"✗ {exc}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main())
|