bridge.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. #!/usr/bin/env python3
  2. """
  3. bridge.py — Live Transcription Bridge
  4. Uses WhisperLiveKit's AudioProcessor API directly (no WebSocket).
  5. Publishes rolling 3-line JSON to Mosquitto MQTT.
  6. Run this script:
  7. python bridge.py
  8. """
  9. import asyncio
  10. import json
  11. import queue as _stdlib_queue
  12. import re
  13. import textwrap
  14. import threading
  15. import time
  16. from collections import Counter
  17. from datetime import datetime, timezone
  18. from pathlib import Path
  19. import numpy as np
  20. import paho.mqtt.client as mqtt
  21. import sounddevice as sd
  22. from fastapi import FastAPI, Request
  23. from whisperlivekit import AudioProcessor, TranscriptionEngine
  24. from embeddings import (
  25. EmbeddingRegistry,
  26. log_transcript_segment,
  27. )
  28. import uvicorn
  29. # ── Configuration ─────────────────────────────────────────────────────────────
  30. MQTT_HOST = "localhost"
  31. MQTT_PORT = 1883
  32. MQTT_TOPIC_TEXT = "display/text"
  33. MQTT_TOPIC_CLEAR = "display/clear"
  34. SAMPLE_RATE = 16000
  35. CHANNELS = 1
  36. BLOCKSIZE = 4096 # ~256 ms per chunk at 16 kHz
  37. SENTENCE_TIMEOUT = 4.0 # seconds of silence before forcing a flush
  38. MAX_LINE_CHARS = 80 # characters per line
  39. DISPLAY_LINES = 3
  40. # Set to a device index (integer) to force a specific microphone.
  41. # Leave as None to use the Windows default input device.
  42. AUDIO_DEVICE: int | None = 12
  43. SPEAKERS_FILE = Path(__file__).parent / "speakers.json"
  44. DEFAULT_SPEAKERS: dict[str, dict] = {
  45. "SPEAKER_00": {"name": "A.A.A", "role": "Serving Brother", "location": "Sydney", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  46. "SPEAKER_01": {"name": "A.A.A", "role": "Contributor", "location": "London", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  47. "SPEAKER_02": {"name": "A.A.A", "role": "Contributor", "location": "Hobart", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  48. "SPEAKER_03": {"name": "A.A.A", "role": "Contributor", "location": "Perth", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
  49. }
  50. # ── Audio injection queue ─────────────────────────────────────────────────────
  51. # stdlib queue.Queue is thread-safe across event loops; asyncio.Queue is not.
  52. # admin.py POSTs test audio chunks to /inject (port 8002) which puts them here.
  53. # _send_audio() drains this queue in preference to the live microphone.
  54. _inject_queue: _stdlib_queue.Queue = _stdlib_queue.Queue(maxsize=240)
  55. # ── Audio injection API ───────────────────────────────────────────────────────
  56. _bridge_app = FastAPI()
  57. @_bridge_app.post("/inject")
  58. async def inject_audio(request: Request):
  59. chunk = await request.body()
  60. if chunk:
  61. try:
  62. _inject_queue.put_nowait(chunk)
  63. except _stdlib_queue.Full:
  64. pass
  65. return {"ok": True}
  66. @_bridge_app.post("/inject/clear")
  67. async def inject_clear():
  68. while True:
  69. try:
  70. _inject_queue.get_nowait()
  71. except _stdlib_queue.Empty:
  72. break
  73. return {"ok": True}
  74. # ── Speaker persistence ───────────────────────────────────────────────────────
  75. def _load_speakers() -> dict:
  76. if SPEAKERS_FILE.exists():
  77. try:
  78. data = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
  79. if isinstance(data, dict):
  80. return data
  81. except (json.JSONDecodeError, OSError):
  82. pass
  83. _write_speakers(DEFAULT_SPEAKERS)
  84. return dict(DEFAULT_SPEAKERS)
  85. def _write_speakers(names: dict) -> None:
  86. try:
  87. SPEAKERS_FILE.write_text(
  88. json.dumps(names, indent=2, ensure_ascii=False),
  89. encoding="utf-8",
  90. )
  91. except OSError as exc:
  92. print(f"[Speakers] Save failed: {exc}")
  93. # ── State ─────────────────────────────────────────────────────────────────────
  94. class BridgeState:
  95. """All mutable state, protected by a single lock."""
  96. def __init__(self):
  97. self._lock = threading.Lock()
  98. self.speaker_names: dict = _load_speakers()
  99. self._seen: set[str] = set(self.speaker_names)
  100. self._current_speaker: str | None = None
  101. self._speaker_changed = False
  102. self._text_buffer = ""
  103. self._display: list[str] = [""] * DISPLAY_LINES
  104. self._last_final_time = time.monotonic()
  105. # Raw diarization ID of current speaker (SPEAKER_XX)
  106. self._raw_speaker_id: str | None = None
  107. # Voiceprint matching
  108. self._embedding_registry = EmbeddingRegistry()
  109. self._accumulators: dict[str, object] = {}
  110. self._confirmed_ids: set[str] = set()
  111. self._session_id: str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
  112. def set_speaker_name(self, speaker_id: str, name: str) -> None:
  113. with self._lock:
  114. entry = self.speaker_names.get(speaker_id, {})
  115. if isinstance(entry, dict):
  116. self.speaker_names[speaker_id] = {**entry, "name": name.strip()}
  117. else:
  118. self.speaker_names[speaker_id] = {"name": name.strip()}
  119. self._seen.add(speaker_id)
  120. _write_speakers(self.speaker_names)
  121. def delete_speaker(self, speaker_id: str) -> None:
  122. with self._lock:
  123. self.speaker_names.pop(speaker_id, None)
  124. self._seen.discard(speaker_id)
  125. _write_speakers(self.speaker_names)
  126. def seen_speakers_snapshot(self) -> set[str]:
  127. with self._lock:
  128. return set(self._seen)
  129. def _resolve(self, speaker_id: str | None) -> str | None:
  130. if not speaker_id:
  131. return None
  132. entry = self.speaker_names.get(speaker_id)
  133. if entry is None:
  134. return speaker_id
  135. if isinstance(entry, dict):
  136. return entry.get("name") or speaker_id
  137. return str(entry)
  138. def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client, seg_start: float = 0.0, seg_end: float = 0.0) -> None:
  139. """Accept a finalised segment; flush on sentence boundary or speaker change."""
  140. with self._lock:
  141. # Track raw diarization ID for PCM accumulator
  142. self._raw_speaker_id = speaker_id
  143. if speaker_id:
  144. self._seen.add(speaker_id)
  145. # Log segment to transcript_segments.jsonl for later enrolment
  146. log_transcript_segment(
  147. speaker_id = speaker_id or "UNKNOWN",
  148. text = text,
  149. start_sec = seg_start,
  150. end_sec = seg_end,
  151. session_id = self._session_id,
  152. )
  153. resolved = self._resolve(speaker_id)
  154. if resolved != self._current_speaker:
  155. if self._text_buffer:
  156. self._flush(mqtt_client)
  157. self._current_speaker = resolved
  158. self._speaker_changed = True
  159. sep = " " if self._text_buffer else ""
  160. self._text_buffer += sep + text.strip()
  161. self._last_final_time = time.monotonic()
  162. if _is_sentence_end(text):
  163. self._flush(mqtt_client)
  164. def maybe_timeout_flush(self, mqtt_client: mqtt.Client) -> None:
  165. with self._lock:
  166. if self._text_buffer and (time.monotonic() - self._last_final_time) > SENTENCE_TIMEOUT:
  167. self._flush(mqtt_client)
  168. def _flush(self, mqtt_client: mqtt.Client) -> None:
  169. text = self._text_buffer.strip()
  170. self._text_buffer = ""
  171. if not text:
  172. return
  173. new_lines: list[str] = []
  174. if self._speaker_changed and self._current_speaker:
  175. new_lines.append(f"[{self._current_speaker.upper()}]")
  176. self._speaker_changed = False
  177. new_lines.extend(textwrap.wrap(text, MAX_LINE_CHARS) or [""])
  178. for line in new_lines:
  179. self._display.append(line)
  180. self._display = self._display[-DISPLAY_LINES:]
  181. while len(self._display) < DISPLAY_LINES:
  182. self._display.insert(0, "")
  183. payload = json.dumps({"lines": list(self._display)})
  184. mqtt_client.publish(MQTT_TOPIC_TEXT, payload)
  185. print(f"[Display] {self._display}")
  186. def clear(self, mqtt_client: mqtt.Client) -> None:
  187. with self._lock:
  188. self._display = [""] * DISPLAY_LINES
  189. self._text_buffer = ""
  190. self._current_speaker = None
  191. self._speaker_changed = False
  192. mqtt_client.publish(MQTT_TOPIC_CLEAR, "")
  193. print("[Display] Cleared")
  194. # ── Helpers ───────────────────────────────────────────────────────────────────
  195. def _is_sentence_end(text: str) -> bool:
  196. return bool(re.search(r'[.!?…]\s*$', text.strip()))
  197. # ── MQTT ──────────────────────────────────────────────────────────────────────
  198. def build_mqtt_client() -> mqtt.Client:
  199. client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
  200. def on_connect(client, userdata, flags, rc, props):
  201. print("[MQTT] Connected" if rc == 0 else f"[MQTT] Failed: {rc}")
  202. def on_disconnect(client, userdata, flags, rc, props):
  203. print(f"[MQTT] Disconnected ({rc}), will reconnect...")
  204. client.on_connect = on_connect
  205. client.on_disconnect = on_disconnect
  206. client.reconnect_delay_set(min_delay=1, max_delay=30)
  207. client.connect_async(MQTT_HOST, MQTT_PORT)
  208. client.loop_start()
  209. return client
  210. # ── Audio pipeline ────────────────────────────────────────────────────────────
  211. def _choose_audio_device() -> int | None:
  212. try:
  213. devices = sd.query_devices()
  214. default_in = sd.default.device[0]
  215. except Exception as exc:
  216. print(f"[Audio] Cannot query devices: {exc}")
  217. return None
  218. print("[Audio] Available input devices:")
  219. input_devices: list[tuple[int, str]] = []
  220. for i, dev in enumerate(devices):
  221. if dev["max_input_channels"] > 0:
  222. marker = " ← default" if i == default_in else ""
  223. print(f" [{i}] {dev['name']}{marker}")
  224. input_devices.append((i, dev["name"]))
  225. if not input_devices:
  226. print("[Audio] ERROR: No input devices found.")
  227. return None
  228. if AUDIO_DEVICE is not None:
  229. print(f"[Audio] Using configured device [{AUDIO_DEVICE}]")
  230. return AUDIO_DEVICE
  231. if default_in >= 0:
  232. print(f"[Audio] Using default input device [{default_in}]")
  233. return default_in
  234. idx, name = input_devices[0]
  235. print(f"[Audio] No system default — using [{idx}] {name}")
  236. return idx
  237. async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, engine: TranscriptionEngine) -> None:
  238. audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=120)
  239. loop = asyncio.get_running_loop()
  240. def audio_callback(indata: np.ndarray, frames: int, time_info, status) -> None:
  241. if status:
  242. print(f"[Audio] {status}")
  243. chunk = indata.tobytes()
  244. loop.call_soon_threadsafe(
  245. lambda: audio_queue.put_nowait(chunk) if not audio_queue.full() else None
  246. )
  247. device = _choose_audio_device()
  248. if device is None:
  249. print("[Audio] No input device — cannot start.")
  250. return
  251. audio_processor = AudioProcessor(transcription_engine=engine)
  252. results_generator = await audio_processor.create_tasks()
  253. async def _receive_results():
  254. # FrontData.lines is validated_segments + a growing current-segment snapshot.
  255. # The last element's text GROWS silently between calls, so index-counting
  256. # misses incremental content. Instead, track the full concatenated
  257. # transcript and push only the delta each time it grows.
  258. prev_full_text = ""
  259. async for response in results_generator:
  260. lines = response.lines or []
  261. current_full_text = " ".join(
  262. (seg.text or "").strip()
  263. for seg in lines
  264. if not seg.is_silence() and (seg.text or "").strip()
  265. )
  266. if current_full_text == prev_full_text:
  267. continue
  268. if prev_full_text and current_full_text.startswith(prev_full_text):
  269. new_text = current_full_text[len(prev_full_text):].strip()
  270. # Drop leading punctuation that belongs to the previous sentence
  271. while new_text and new_text[0] in ".,;:!?":
  272. new_text = new_text[1:].strip()
  273. else:
  274. # First segment or context reset after a long silence
  275. new_text = current_full_text
  276. prev_full_text = current_full_text
  277. if not new_text or len(new_text) < 2:
  278. continue
  279. last_seg = next(
  280. (s for s in reversed(lines) if not s.is_silence() and (s.text or "").strip()),
  281. None,
  282. )
  283. spk = getattr(last_seg, "speaker", None) if last_seg else None
  284. speaker_id = f"SPEAKER_{spk:02d}" if isinstance(spk, int) and spk >= 0 else None
  285. seg_start = float(getattr(last_seg, "start", 0.0) or 0.0) if last_seg else 0.0
  286. seg_end = float(getattr(last_seg, "end", 0.0) or 0.0) if last_seg else 0.0
  287. print(f"[Whisper] ({speaker_id or '?'}) {new_text}")
  288. state.push_final(new_text, speaker_id, mqtt_client, seg_start, seg_end)
  289. async def _send_audio():
  290. with sd.InputStream(
  291. device=device, samplerate=SAMPLE_RATE, channels=CHANNELS,
  292. dtype="int16", blocksize=BLOCKSIZE, callback=audio_callback,
  293. ):
  294. while True:
  295. # Injected test audio takes priority over live microphone
  296. try:
  297. chunk = _inject_queue.get_nowait()
  298. except _stdlib_queue.Empty:
  299. chunk = await audio_queue.get()
  300. await audio_processor.process_audio(chunk)
  301. # Accumulate int16 PCM for live speaker matching
  302. current_spk = state._raw_speaker_id
  303. if current_spk and current_spk not in state._confirmed_ids:
  304. if current_spk not in state._accumulators:
  305. state._accumulators[current_spk] = \
  306. state._embedding_registry.make_accumulator(min_seconds=5.0)
  307. state._accumulators[current_spk].push(chunk)
  308. if state._accumulators[current_spk].ready():
  309. try:
  310. live_emb = state._accumulators[current_spk].extract_embedding()
  311. match = state._embedding_registry.find_match(live_emb)
  312. if match:
  313. matched_id, score = match
  314. matched_entry = state.speaker_names.get(matched_id, {})
  315. matched_name = (
  316. matched_entry.get("name", matched_id)
  317. if isinstance(matched_entry, dict)
  318. else str(matched_entry)
  319. )
  320. print(
  321. f"[Embeddings] Auto-matched {current_spk} → "
  322. f"{matched_name} (score={score:.3f})"
  323. )
  324. with state._lock:
  325. entry = state.speaker_names.get(current_spk, {})
  326. if not isinstance(entry, dict):
  327. entry = {}
  328. state.speaker_names[current_spk] = {**entry, "name": matched_name}
  329. state._confirmed_ids.add(current_spk)
  330. _write_speakers(state.speaker_names)
  331. else:
  332. # No match yet — reset and try again with more audio
  333. state._accumulators[current_spk].reset()
  334. except Exception as exc:
  335. print(f"[Embeddings] Accumulator error: {exc}")
  336. state._accumulators[current_spk].reset()
  337. flusher = asyncio.create_task(_flusher(state, mqtt_client))
  338. reloader = asyncio.create_task(_speaker_reloader(state))
  339. try:
  340. await asyncio.gather(_send_audio(), _receive_results())
  341. finally:
  342. flusher.cancel()
  343. reloader.cancel()
  344. async def _flusher(state: BridgeState, mqtt_client: mqtt.Client) -> None:
  345. while True:
  346. await asyncio.sleep(1.0)
  347. state.maybe_timeout_flush(mqtt_client)
  348. async def _speaker_reloader(state: BridgeState) -> None:
  349. last_mtime = 0.0
  350. while True:
  351. await asyncio.sleep(5.0)
  352. try:
  353. mtime = SPEAKERS_FILE.stat().st_mtime
  354. if mtime != last_mtime:
  355. fresh = _load_speakers()
  356. with state._lock:
  357. state.speaker_names = fresh
  358. last_mtime = mtime
  359. print("[Bridge] Speaker names reloaded from disk")
  360. except OSError:
  361. pass
  362. # ── Entry point ───────────────────────────────────────────────────────────────
  363. def main() -> None:
  364. state = BridgeState()
  365. mqtt_client = build_mqtt_client()
  366. engine = TranscriptionEngine(
  367. model_size="large-v3",
  368. lan="en",
  369. diarization=False,
  370. pcm_input=True,
  371. backend_policy="localagreement",
  372. confidence_validation=True,
  373. min_chunk_size=3,
  374. vac=False,
  375. )
  376. # Inject API must start before the audio loop so test playback works immediately
  377. def _run_inject_api():
  378. uvicorn.run(_bridge_app, host="127.0.0.1", port=8002, log_level="warning")
  379. inject_thread = threading.Thread(target=_run_inject_api, daemon=True)
  380. inject_thread.start()
  381. print("[Bridge] Test audio inject API at http://127.0.0.1:8002")
  382. def _run():
  383. asyncio.run(audio_processor_loop(state, mqtt_client, engine))
  384. ws_thread = threading.Thread(target=_run, daemon=True)
  385. ws_thread.start()
  386. print(f"[Bridge] Speaker names loaded from {SPEAKERS_FILE}")
  387. print("[Bridge] Audio pipeline running — speaker admin at http://localhost:8001")
  388. print("[Bridge] Close this window to quit")
  389. try:
  390. ws_thread.join()
  391. except KeyboardInterrupt:
  392. pass
  393. if __name__ == "__main__":
  394. main()