#!/usr/bin/env python3
"""
bridge.py — Live Transcription Bridge

Uses WhisperLiveKit's AudioProcessor API directly (no WebSocket).
Publishes a rolling 3-line JSON payload to Mosquitto MQTT.

Run this script:
    python bridge.py
"""
import asyncio
import copy
import json
import queue as _stdlib_queue
import re
import textwrap
import threading
import time
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import paho.mqtt.client as mqtt
import sounddevice as sd
import uvicorn
from fastapi import FastAPI, Request
from whisperlivekit import AudioProcessor, TranscriptionEngine

from embeddings import (
    EmbeddingRegistry,
    log_transcript_segment,
)
  29. # ── Configuration ─────────────────────────────────────────────────────────────
  30. MQTT_HOST = "localhost"
  31. MQTT_PORT = 1883
  32. MQTT_TOPIC_TEXT = "display/text"
  33. MQTT_TOPIC_CLEAR = "display/clear"
  34. SAMPLE_RATE = 16000
  35. CHANNELS = 1
  36. BLOCKSIZE = 4096 # ~256 ms per chunk at 16 kHz
  37. SENTENCE_TIMEOUT = 4.0 # seconds of silence before forcing a flush
  38. MAX_LINE_CHARS = 80 # characters per line
  39. DISPLAY_LINES = 3
  40. # Set to a device index (integer) to force a specific input device.
  41. # Leave as None to use the Windows default input device.
  42. AUDIO_DEVICE: int | None = None
  43. # Set True to capture system audio (everything playing on the PC speakers).
  44. # Requires "Stereo Mix" to be enabled in Windows Sound settings → Recording tab.
  45. # When True, AUDIO_DEVICE is ignored and the loopback device is auto-detected.
  46. USE_SYSTEM_AUDIO: bool = False
  47. _LOOPBACK_NAMES = ("stereo mix", "what u hear", "wave out mix", "loopback")
  48. SPEAKERS_FILE = Path(__file__).parent / "speakers.json"
  49. DEFAULT_SPEAKERS: dict[str, dict] = {
  50. "SPEAKER_00": {"name": "A.A.A", "role": "Serving Brother", "location": "Sydney", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  51. "SPEAKER_01": {"name": "A.A.A", "role": "Contributor", "location": "London", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  52. "SPEAKER_02": {"name": "A.A.A", "role": "Contributor", "location": "Hobart", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  53. "SPEAKER_03": {"name": "A.A.A", "role": "Contributor", "location": "Perth", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  54. }
# ── Audio injection queue ─────────────────────────────────────────────────────
# stdlib queue.Queue is thread-safe across event loops; asyncio.Queue is not.
# admin.py POSTs test audio chunks to /inject (port 8002) which puts them here.
# _send_audio() drains this queue in preference to the live microphone.
# Bounded at 240 chunks; /inject drops a new chunk when the queue is full.
_inject_queue: _stdlib_queue.Queue = _stdlib_queue.Queue(maxsize=240)

# ── Audio injection API ───────────────────────────────────────────────────────
# FastAPI application that main() serves with uvicorn on 127.0.0.1:8002.
_bridge_app = FastAPI()
  62. @_bridge_app.post("/inject")
  63. async def inject_audio(request: Request):
  64. chunk = await request.body()
  65. if chunk:
  66. try:
  67. _inject_queue.put_nowait(chunk)
  68. except _stdlib_queue.Full:
  69. pass
  70. return {"ok": True}
  71. @_bridge_app.post("/inject/clear")
  72. async def inject_clear():
  73. while True:
  74. try:
  75. _inject_queue.get_nowait()
  76. except _stdlib_queue.Empty:
  77. break
  78. return {"ok": True}
  79. # ── Speaker persistence ───────────────────────────────────────────────────────
  80. def _load_speakers() -> dict:
  81. if SPEAKERS_FILE.exists():
  82. try:
  83. data = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
  84. if isinstance(data, dict):
  85. return data
  86. except (json.JSONDecodeError, OSError):
  87. pass
  88. _write_speakers(DEFAULT_SPEAKERS)
  89. return dict(DEFAULT_SPEAKERS)
  90. def _write_speakers(names: dict) -> None:
  91. try:
  92. SPEAKERS_FILE.write_text(
  93. json.dumps(names, indent=2, ensure_ascii=False),
  94. encoding="utf-8",
  95. )
  96. except OSError as exc:
  97. print(f"[Speakers] Save failed: {exc}")
  98. # ── State ─────────────────────────────────────────────────────────────────────
  99. class BridgeState:
  100. """All mutable state, protected by a single lock."""
  101. def __init__(self):
  102. self._lock = threading.Lock()
  103. self.speaker_names: dict = _load_speakers()
  104. self._seen: set[str] = set(self.speaker_names)
  105. self._current_speaker: str | None = None
  106. self._speaker_changed = False
  107. self._text_buffer = ""
  108. self._display: list[str] = [""] * DISPLAY_LINES
  109. self._last_final_time = time.monotonic()
  110. # Raw diarization ID of current speaker (SPEAKER_XX)
  111. self._raw_speaker_id: str | None = None
  112. # Voiceprint matching
  113. self._embedding_registry = EmbeddingRegistry()
  114. self._accumulators: dict[str, object] = {}
  115. self._confirmed_ids: set[str] = set()
  116. self._session_id: str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
  117. def set_speaker_name(self, speaker_id: str, name: str) -> None:
  118. with self._lock:
  119. entry = self.speaker_names.get(speaker_id, {})
  120. if isinstance(entry, dict):
  121. self.speaker_names[speaker_id] = {**entry, "name": name.strip()}
  122. else:
  123. self.speaker_names[speaker_id] = {"name": name.strip()}
  124. self._seen.add(speaker_id)
  125. _write_speakers(self.speaker_names)
  126. def delete_speaker(self, speaker_id: str) -> None:
  127. with self._lock:
  128. self.speaker_names.pop(speaker_id, None)
  129. self._seen.discard(speaker_id)
  130. _write_speakers(self.speaker_names)
  131. def seen_speakers_snapshot(self) -> set[str]:
  132. with self._lock:
  133. return set(self._seen)
  134. def _resolve(self, speaker_id: str | None) -> str | None:
  135. if not speaker_id:
  136. return None
  137. entry = self.speaker_names.get(speaker_id)
  138. if entry is None:
  139. return speaker_id
  140. if isinstance(entry, dict):
  141. return entry.get("name") or speaker_id
  142. return str(entry)
  143. def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client, seg_start: float = 0.0, seg_end: float = 0.0) -> None:
  144. """Accept a finalised segment; flush on sentence boundary or speaker change."""
  145. with self._lock:
  146. # Track raw diarization ID for PCM accumulator
  147. self._raw_speaker_id = speaker_id
  148. if speaker_id:
  149. self._seen.add(speaker_id)
  150. # Log segment to transcript_segments.jsonl for later enrolment
  151. log_transcript_segment(
  152. speaker_id = speaker_id or "UNKNOWN",
  153. text = text,
  154. start_sec = seg_start,
  155. end_sec = seg_end,
  156. session_id = self._session_id,
  157. )
  158. resolved = self._resolve(speaker_id)
  159. if resolved != self._current_speaker:
  160. if self._text_buffer:
  161. self._flush(mqtt_client)
  162. self._current_speaker = resolved
  163. self._speaker_changed = True
  164. sep = " " if self._text_buffer else ""
  165. self._text_buffer += sep + text.strip()
  166. self._last_final_time = time.monotonic()
  167. if _is_sentence_end(text):
  168. self._flush(mqtt_client)
  169. def maybe_timeout_flush(self, mqtt_client: mqtt.Client) -> None:
  170. with self._lock:
  171. if self._text_buffer and (time.monotonic() - self._last_final_time) > SENTENCE_TIMEOUT:
  172. self._flush(mqtt_client)
  173. def _flush(self, mqtt_client: mqtt.Client) -> None:
  174. text = self._text_buffer.strip()
  175. self._text_buffer = ""
  176. if not text:
  177. return
  178. new_lines: list[str] = []
  179. if self._speaker_changed and self._current_speaker:
  180. new_lines.append(f"[{self._current_speaker.upper()}]")
  181. self._speaker_changed = False
  182. new_lines.extend(textwrap.wrap(text, MAX_LINE_CHARS) or [""])
  183. for line in new_lines:
  184. self._display.append(line)
  185. self._display = self._display[-DISPLAY_LINES:]
  186. while len(self._display) < DISPLAY_LINES:
  187. self._display.insert(0, "")
  188. payload = json.dumps({"lines": list(self._display)})
  189. mqtt_client.publish(MQTT_TOPIC_TEXT, payload)
  190. print(f"[Display] {self._display}")
  191. def clear(self, mqtt_client: mqtt.Client) -> None:
  192. with self._lock:
  193. self._display = [""] * DISPLAY_LINES
  194. self._text_buffer = ""
  195. self._current_speaker = None
  196. self._speaker_changed = False
  197. mqtt_client.publish(MQTT_TOPIC_CLEAR, "")
  198. print("[Display] Cleared")
  199. # ── Helpers ───────────────────────────────────────────────────────────────────
  200. def _is_sentence_end(text: str) -> bool:
  201. return bool(re.search(r'[.!?…]\s*$', text.strip()))
  202. # ── MQTT ──────────────────────────────────────────────────────────────────────
  203. def build_mqtt_client() -> mqtt.Client:
  204. client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
  205. def on_connect(client, userdata, flags, rc, props):
  206. print("[MQTT] Connected" if rc == 0 else f"[MQTT] Failed: {rc}")
  207. def on_disconnect(client, userdata, flags, rc, props):
  208. print(f"[MQTT] Disconnected ({rc}), will reconnect...")
  209. client.on_connect = on_connect
  210. client.on_disconnect = on_disconnect
  211. client.reconnect_delay_set(min_delay=1, max_delay=30)
  212. client.connect_async(MQTT_HOST, MQTT_PORT)
  213. client.loop_start()
  214. return client
  215. # ── Audio pipeline ────────────────────────────────────────────────────────────
  216. def _choose_audio_device() -> int | None:
  217. try:
  218. devices = sd.query_devices()
  219. default_in = sd.default.device[0]
  220. except Exception as exc:
  221. print(f"[Audio] Cannot query devices: {exc}")
  222. return None
  223. print("[Audio] Available input devices:")
  224. input_devices: list[tuple[int, str]] = []
  225. for i, dev in enumerate(devices):
  226. if dev["max_input_channels"] > 0:
  227. marker = " ← default" if i == default_in else ""
  228. print(f" [{i}] {dev['name']}{marker}")
  229. input_devices.append((i, dev["name"]))
  230. if not input_devices:
  231. print("[Audio] ERROR: No input devices found.")
  232. return None
  233. if USE_SYSTEM_AUDIO:
  234. for idx, name in input_devices:
  235. if any(k in name.lower() for k in _LOOPBACK_NAMES):
  236. print(f"[Audio] System audio (loopback) device: [{idx}] {name}")
  237. return idx
  238. print("[Audio] WARNING: No loopback device found. Enable 'Stereo Mix' in")
  239. print("[Audio] Windows Sound Settings → Recording tab → right-click → Show Disabled Devices")
  240. print("[Audio] Falling back to default input device.")
  241. if AUDIO_DEVICE is not None:
  242. print(f"[Audio] Using configured device [{AUDIO_DEVICE}]")
  243. return AUDIO_DEVICE
  244. if default_in >= 0:
  245. print(f"[Audio] Using default input device [{default_in}]")
  246. return default_in
  247. idx, name = input_devices[0]
  248. print(f"[Audio] No system default — using [{idx}] {name}")
  249. return idx
  250. async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, engine: TranscriptionEngine) -> None:
  251. audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=120)
  252. loop = asyncio.get_running_loop()
  253. device = _choose_audio_device()
  254. if device is None:
  255. print("[Audio] No input device — cannot start.")
  256. return
  257. # Query the device's native channel count (loopback devices are often stereo)
  258. try:
  259. dev_info = sd.query_devices(device)
  260. in_ch = max(1, min(int(dev_info["max_input_channels"]), 2))
  261. except Exception:
  262. in_ch = CHANNELS
  263. if in_ch > 1:
  264. print(f"[Audio] Device has {in_ch} channels — will mix down to mono")
  265. def audio_callback(indata: np.ndarray, frames: int, time_info, status) -> None:
  266. if status:
  267. print(f"[Audio] {status}")
  268. # indata is float32; mix stereo → mono then convert to int16
  269. mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
  270. chunk = (mono * 32767).clip(-32768, 32767).astype(np.int16).tobytes()
  271. loop.call_soon_threadsafe(
  272. lambda: audio_queue.put_nowait(chunk) if not audio_queue.full() else None
  273. )
  274. audio_processor = AudioProcessor(transcription_engine=engine)
  275. results_generator = await audio_processor.create_tasks()
  276. async def _receive_results():
  277. # FrontData.lines is validated_segments + a growing current-segment snapshot.
  278. # The last element's text GROWS silently between calls, so index-counting
  279. # misses incremental content. Instead, track the full concatenated
  280. # transcript and push only the delta each time it grows.
  281. prev_full_text = ""
  282. async for response in results_generator:
  283. lines = response.lines or []
  284. current_full_text = " ".join(
  285. (seg.text or "").strip()
  286. for seg in lines
  287. if not seg.is_silence() and (seg.text or "").strip()
  288. )
  289. if current_full_text == prev_full_text:
  290. continue
  291. if prev_full_text and current_full_text.startswith(prev_full_text):
  292. new_text = current_full_text[len(prev_full_text):].strip()
  293. # Drop leading punctuation that belongs to the previous sentence
  294. while new_text and new_text[0] in ".,;:!?":
  295. new_text = new_text[1:].strip()
  296. else:
  297. # First segment or context reset after a long silence
  298. new_text = current_full_text
  299. prev_full_text = current_full_text
  300. if not new_text or len(new_text) < 2:
  301. continue
  302. last_seg = next(
  303. (s for s in reversed(lines) if not s.is_silence() and (s.text or "").strip()),
  304. None,
  305. )
  306. spk = getattr(last_seg, "speaker", None) if last_seg else None
  307. speaker_id = f"SPEAKER_{spk:02d}" if isinstance(spk, int) and spk >= 0 else None
  308. seg_start = float(getattr(last_seg, "start", 0.0) or 0.0) if last_seg else 0.0
  309. seg_end = float(getattr(last_seg, "end", 0.0) or 0.0) if last_seg else 0.0
  310. print(f"[Whisper] ({speaker_id or '?'}) {new_text}")
  311. state.push_final(new_text, speaker_id, mqtt_client, seg_start, seg_end)
  312. async def _send_audio():
  313. with sd.InputStream(
  314. device=device, samplerate=SAMPLE_RATE, channels=in_ch,
  315. dtype="float32", blocksize=BLOCKSIZE, callback=audio_callback,
  316. ):
  317. while True:
  318. # Injected test audio takes priority over live microphone
  319. try:
  320. chunk = _inject_queue.get_nowait()
  321. except _stdlib_queue.Empty:
  322. chunk = await audio_queue.get()
  323. await audio_processor.process_audio(chunk)
  324. # Accumulate int16 PCM for live speaker matching
  325. current_spk = state._raw_speaker_id
  326. if current_spk and current_spk not in state._confirmed_ids:
  327. if current_spk not in state._accumulators:
  328. state._accumulators[current_spk] = \
  329. state._embedding_registry.make_accumulator(min_seconds=5.0)
  330. state._accumulators[current_spk].push(chunk)
  331. if state._accumulators[current_spk].ready():
  332. try:
  333. live_emb = state._accumulators[current_spk].extract_embedding()
  334. match = state._embedding_registry.find_match(live_emb)
  335. if match:
  336. matched_id, score = match
  337. matched_entry = state.speaker_names.get(matched_id, {})
  338. matched_name = (
  339. matched_entry.get("name", matched_id)
  340. if isinstance(matched_entry, dict)
  341. else str(matched_entry)
  342. )
  343. print(
  344. f"[Embeddings] Auto-matched {current_spk} → "
  345. f"{matched_name} (score={score:.3f})"
  346. )
  347. with state._lock:
  348. entry = state.speaker_names.get(current_spk, {})
  349. if not isinstance(entry, dict):
  350. entry = {}
  351. state.speaker_names[current_spk] = {**entry, "name": matched_name}
  352. state._confirmed_ids.add(current_spk)
  353. _write_speakers(state.speaker_names)
  354. else:
  355. # No match yet — reset and try again with more audio
  356. state._accumulators[current_spk].reset()
  357. except Exception as exc:
  358. print(f"[Embeddings] Accumulator error: {exc}")
  359. state._accumulators[current_spk].reset()
  360. flusher = asyncio.create_task(_flusher(state, mqtt_client))
  361. reloader = asyncio.create_task(_speaker_reloader(state))
  362. try:
  363. await asyncio.gather(_send_audio(), _receive_results())
  364. finally:
  365. flusher.cancel()
  366. reloader.cancel()
  367. async def _flusher(state: BridgeState, mqtt_client: mqtt.Client) -> None:
  368. while True:
  369. await asyncio.sleep(1.0)
  370. state.maybe_timeout_flush(mqtt_client)
  371. async def _speaker_reloader(state: BridgeState) -> None:
  372. last_mtime = 0.0
  373. while True:
  374. await asyncio.sleep(5.0)
  375. try:
  376. mtime = SPEAKERS_FILE.stat().st_mtime
  377. if mtime != last_mtime:
  378. fresh = _load_speakers()
  379. with state._lock:
  380. state.speaker_names = fresh
  381. last_mtime = mtime
  382. print("[Bridge] Speaker names reloaded from disk")
  383. except OSError:
  384. pass
  385. # ── Entry point ───────────────────────────────────────────────────────────────
  386. def main() -> None:
  387. state = BridgeState()
  388. mqtt_client = build_mqtt_client()
  389. engine = TranscriptionEngine(
  390. model_size="large-v3",
  391. lan="en",
  392. diarization=False,
  393. pcm_input=True,
  394. backend_policy="localagreement",
  395. confidence_validation=True,
  396. min_chunk_size=3,
  397. vac=False,
  398. )
  399. # Inject API must start before the audio loop so test playback works immediately
  400. def _run_inject_api():
  401. uvicorn.run(_bridge_app, host="127.0.0.1", port=8002, log_level="warning")
  402. inject_thread = threading.Thread(target=_run_inject_api, daemon=True)
  403. inject_thread.start()
  404. print("[Bridge] Test audio inject API at http://127.0.0.1:8002")
  405. def _run():
  406. asyncio.run(audio_processor_loop(state, mqtt_client, engine))
  407. ws_thread = threading.Thread(target=_run, daemon=True)
  408. ws_thread.start()
  409. print(f"[Bridge] Speaker names loaded from {SPEAKERS_FILE}")
  410. print("[Bridge] Audio pipeline running — speaker admin at http://localhost:8001")
  411. print("[Bridge] Close this window to quit")
  412. try:
  413. ws_thread.join()
  414. except KeyboardInterrupt:
  415. pass
  416. if __name__ == "__main__":
  417. main()