瀏覽代碼

Embeddings implementation

Benjamin Harris 1 月之前
父節點
當前提交
7f07089bb6
共有 2 個文件被更改,包括 451 次插入235 次删除
  1. 64 4
      bridge/bridge.py
  2. 387 231
      bridge/embeddings.py

+ 64 - 4
bridge/bridge.py

@@ -17,6 +17,7 @@ import textwrap
 import threading
 import time
 from collections import Counter
+from datetime import datetime, timezone
 from pathlib import Path
 
 import numpy as np
@@ -25,6 +26,11 @@ import sounddevice as sd
 
 from fastapi import FastAPI, Request
 from whisperlivekit import AudioProcessor, TranscriptionEngine
+from embeddings import (
+    EmbeddingRegistry,
+    log_transcript_segment,
+)
+
 import uvicorn
 
 # ── Configuration ─────────────────────────────────────────────────────────────
@@ -123,6 +129,13 @@ class BridgeState:
         self._text_buffer                  = ""
         self._display: list[str]           = [""] * DISPLAY_LINES
         self._last_final_time              = time.monotonic()
+        # Raw diarization ID of current speaker (SPEAKER_XX)
+        self._raw_speaker_id: str | None   = None
+        # Voiceprint matching
+        self._embedding_registry           = EmbeddingRegistry()
+        self._accumulators: dict[str, object] = {}
+        self._confirmed_ids: set[str]      = set()
+        self._session_id: str              = datetime.now(timezone.utc).strftime("%Y-%m-%d")
 
     def set_speaker_name(self, speaker_id: str, name: str) -> None:
         with self._lock:
@@ -154,11 +167,24 @@ class BridgeState:
             return entry.get("name") or speaker_id
         return str(entry)
 
-    def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client) -> None:
+    def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client, seg_start: float = 0.0, seg_end: float = 0.0) -> None:
+        """Accept a finalised segment; flush on sentence boundary or speaker change."""
         with self._lock:
+            # Track raw diarization ID for PCM accumulator
+            self._raw_speaker_id = speaker_id
+
             if speaker_id:
                 self._seen.add(speaker_id)
 
+            # Log segment to transcript_segments.jsonl for later enrolment
+            log_transcript_segment(
+                speaker_id = speaker_id or "UNKNOWN",
+                text       = text,
+                start_sec  = seg_start,
+                end_sec    = seg_end,
+                session_id = self._session_id,
+            )
+
             resolved = self._resolve(speaker_id)
             if resolved != self._current_speaker:
                 if self._text_buffer:
@@ -338,13 +364,47 @@ async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, eng
             dtype="int16", blocksize=BLOCKSIZE, callback=audio_callback,
         ):
             while True:
-                # Injected test audio takes priority over live microphone
+                # Drain test audio injection first if available
                 try:
-                    chunk = _inject_queue.get_nowait()
-                except _stdlib_queue.Empty:
+                    chunk = test_audio_queue.get_nowait()
+                except asyncio.QueueEmpty:
                     chunk = await audio_queue.get()
                 await audio_processor.process_audio(chunk)
 
+                # Accumulate PCM for live speaker matching
+                current_spk = state._raw_speaker_id
+                if current_spk and current_spk not in state._confirmed_ids:
+                    if current_spk not in state._accumulators:
+                        state._accumulators[current_spk] = \
+                            state._embedding_registry.make_accumulator(min_seconds=5.0)
+                    
+                    # Convert float32 → int16 for the embedding accumulator
+                    chunk_i16 = (np.frombuffer(chunk, dtype=np.float32) * 32767).astype(np.int16).tobytes()
+                    state._accumulators[current_spk].push(chunk_i16)
+                    # state._accumulators[current_spk].push(chunk)
+                    
+                    if state._accumulators[current_spk].ready():
+                        try:
+                            live_emb = state._accumulators[current_spk].extract_embedding()
+                            match    = state._embedding_registry.find_match(live_emb)
+                            if match:
+                                matched_id, score = match
+                                resolved = state.speaker_names.get(matched_id, matched_id)
+                                print(
+                                    f"[Embeddings] Auto-matched {current_spk} → "
+                                    f"{resolved} (score={score:.3f})"
+                                )
+                                with state._lock:
+                                    state.speaker_names[current_spk] = resolved
+                                    state._confirmed_ids.add(current_spk)
+                                    _write_speakers(state.speaker_names)
+                            else:
+                                # No match yet — reset and try again with more audio
+                                state._accumulators[current_spk].reset()
+                        except Exception as exc:
+                            print(f"[Embeddings] Accumulator error: {exc}")
+                            state._accumulators[current_spk].reset()
+
     flusher  = asyncio.create_task(_flusher(state, mqtt_client))
     reloader = asyncio.create_task(_speaker_reloader(state))
     try:

文件差異過大導致無法顯示
+ 387 - 231
bridge/embeddings.py


部分文件因文件數量過多而無法顯示