bridge.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. #!/usr/bin/env python3
  2. """
  3. bridge.py — Live Transcription Bridge
  4. Uses WhisperLiveKit's AudioProcessor API directly (no WebSocket).
  5. Publishes rolling 3-line JSON to Mosquitto MQTT.
  6. Run this script:
  7. python bridge.py
  8. """
  9. import asyncio
  10. import json
  11. import queue as _stdlib_queue
  12. import re
  13. import textwrap
  14. import threading
  15. import time
  16. from collections import Counter
  17. from pathlib import Path
  18. import numpy as np
  19. import paho.mqtt.client as mqtt
  20. import sounddevice as sd
  21. from fastapi import FastAPI, Request
  22. from whisperlivekit import AudioProcessor, TranscriptionEngine
  23. import uvicorn
  24. # ── Configuration ─────────────────────────────────────────────────────────────
  25. MQTT_HOST = "localhost"
  26. MQTT_PORT = 1883
  27. MQTT_TOPIC_TEXT = "display/text"
  28. MQTT_TOPIC_CLEAR = "display/clear"
  29. SAMPLE_RATE = 16000
  30. CHANNELS = 1
  31. BLOCKSIZE = 4096 # ~256 ms per chunk at 16 kHz
  32. SENTENCE_TIMEOUT = 4.0 # seconds of silence before forcing a flush
  33. MAX_LINE_CHARS = 38 # characters per line
  34. DISPLAY_LINES = 3
  35. # Set to a device index (integer) to force a specific microphone.
  36. # Leave as None to use the Windows default input device.
  37. AUDIO_DEVICE: int | None = 12
  38. SPEAKERS_FILE = Path(__file__).parent / "speakers.json"
  39. DEFAULT_SPEAKERS: dict[str, str] = {
  40. "SPEAKER_00": "Pastor",
  41. "SPEAKER_01": "Reader",
  42. "SPEAKER_02": "Guest",
  43. "SPEAKER_03": "Choir",
  44. }
  45. # ── Audio injection queue ─────────────────────────────────────────────────────
  46. # stdlib queue.Queue is thread-safe across event loops; asyncio.Queue is not.
  47. # admin.py POSTs test audio chunks to /inject (port 8002) which puts them here.
  48. # _send_audio() drains this queue in preference to the live microphone.
  49. _inject_queue: _stdlib_queue.Queue = _stdlib_queue.Queue(maxsize=240)
  50. # ── Audio injection API ───────────────────────────────────────────────────────
  51. _bridge_app = FastAPI()
  52. @_bridge_app.post("/inject")
  53. async def inject_audio(request: Request):
  54. chunk = await request.body()
  55. if chunk:
  56. try:
  57. _inject_queue.put_nowait(chunk)
  58. except _stdlib_queue.Full:
  59. pass
  60. return {"ok": True}
  61. @_bridge_app.post("/inject/clear")
  62. async def inject_clear():
  63. while True:
  64. try:
  65. _inject_queue.get_nowait()
  66. except _stdlib_queue.Empty:
  67. break
  68. return {"ok": True}
  69. # ── Speaker persistence ───────────────────────────────────────────────────────
  70. def _load_speakers() -> dict[str, str]:
  71. if SPEAKERS_FILE.exists():
  72. try:
  73. data = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
  74. if isinstance(data, dict):
  75. return data
  76. except (json.JSONDecodeError, OSError):
  77. pass
  78. _write_speakers(DEFAULT_SPEAKERS)
  79. return dict(DEFAULT_SPEAKERS)
  80. def _write_speakers(names: dict[str, str]) -> None:
  81. try:
  82. SPEAKERS_FILE.write_text(
  83. json.dumps(names, indent=2, ensure_ascii=False),
  84. encoding="utf-8",
  85. )
  86. except OSError as exc:
  87. print(f"[Speakers] Save failed: {exc}")
  88. # ── State ─────────────────────────────────────────────────────────────────────
  89. class BridgeState:
  90. """All mutable state, protected by a single lock."""
  91. def __init__(self):
  92. self._lock = threading.Lock()
  93. self.speaker_names: dict[str, str] = _load_speakers()
  94. self._seen: set[str] = set(self.speaker_names)
  95. self._current_speaker: str | None = None
  96. self._speaker_changed = False
  97. self._text_buffer = ""
  98. self._display: list[str] = [""] * DISPLAY_LINES
  99. self._last_final_time = time.monotonic()
  100. def set_speaker_name(self, speaker_id: str, name: str) -> None:
  101. with self._lock:
  102. self.speaker_names[speaker_id] = name.strip()
  103. self._seen.add(speaker_id)
  104. _write_speakers(self.speaker_names)
  105. def delete_speaker(self, speaker_id: str) -> None:
  106. with self._lock:
  107. self.speaker_names.pop(speaker_id, None)
  108. self._seen.discard(speaker_id)
  109. _write_speakers(self.speaker_names)
  110. def seen_speakers_snapshot(self) -> set[str]:
  111. with self._lock:
  112. return set(self._seen)
  113. def _resolve(self, speaker_id: str | None) -> str | None:
  114. if not speaker_id:
  115. return None
  116. return self.speaker_names.get(speaker_id, speaker_id)
  117. def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client) -> None:
  118. with self._lock:
  119. if speaker_id:
  120. self._seen.add(speaker_id)
  121. resolved = self._resolve(speaker_id)
  122. if resolved != self._current_speaker:
  123. if self._text_buffer:
  124. self._flush(mqtt_client)
  125. self._current_speaker = resolved
  126. self._speaker_changed = True
  127. sep = " " if self._text_buffer else ""
  128. self._text_buffer += sep + text.strip()
  129. self._last_final_time = time.monotonic()
  130. if _is_sentence_end(text):
  131. self._flush(mqtt_client)
  132. def maybe_timeout_flush(self, mqtt_client: mqtt.Client) -> None:
  133. with self._lock:
  134. if self._text_buffer and (time.monotonic() - self._last_final_time) > SENTENCE_TIMEOUT:
  135. self._flush(mqtt_client)
  136. def _flush(self, mqtt_client: mqtt.Client) -> None:
  137. text = self._text_buffer.strip()
  138. self._text_buffer = ""
  139. if not text:
  140. return
  141. new_lines: list[str] = []
  142. if self._speaker_changed and self._current_speaker:
  143. new_lines.append(f"[{self._current_speaker.upper()}]")
  144. self._speaker_changed = False
  145. new_lines.extend(textwrap.wrap(text, MAX_LINE_CHARS) or [""])
  146. self._display.extend(new_lines)
  147. self._display = self._display[-DISPLAY_LINES:]
  148. while len(self._display) < DISPLAY_LINES:
  149. self._display.insert(0, "")
  150. payload = json.dumps({"lines": list(self._display)})
  151. mqtt_client.publish(MQTT_TOPIC_TEXT, payload)
  152. print(f"[Display] {self._display}")
  153. def clear(self, mqtt_client: mqtt.Client) -> None:
  154. with self._lock:
  155. self._display = [""] * DISPLAY_LINES
  156. self._text_buffer = ""
  157. self._current_speaker = None
  158. self._speaker_changed = False
  159. mqtt_client.publish(MQTT_TOPIC_CLEAR, "")
  160. print("[Display] Cleared")
  161. # ── Helpers ───────────────────────────────────────────────────────────────────
  162. def _is_sentence_end(text: str) -> bool:
  163. return bool(re.search(r'[.!?…]\s*$', text.strip()))
  164. # ── MQTT ──────────────────────────────────────────────────────────────────────
  165. def build_mqtt_client() -> mqtt.Client:
  166. client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
  167. def on_connect(client, userdata, flags, rc, props):
  168. print("[MQTT] Connected" if rc == 0 else f"[MQTT] Failed: {rc}")
  169. def on_disconnect(client, userdata, flags, rc, props):
  170. print(f"[MQTT] Disconnected ({rc}), will reconnect...")
  171. client.on_connect = on_connect
  172. client.on_disconnect = on_disconnect
  173. client.reconnect_delay_set(min_delay=1, max_delay=30)
  174. client.connect_async(MQTT_HOST, MQTT_PORT)
  175. client.loop_start()
  176. return client
  177. # ── Audio pipeline ────────────────────────────────────────────────────────────
  178. def _choose_audio_device() -> int | None:
  179. try:
  180. devices = sd.query_devices()
  181. default_in = sd.default.device[0]
  182. except Exception as exc:
  183. print(f"[Audio] Cannot query devices: {exc}")
  184. return None
  185. print("[Audio] Available input devices:")
  186. input_devices: list[tuple[int, str]] = []
  187. for i, dev in enumerate(devices):
  188. if dev["max_input_channels"] > 0:
  189. marker = " ← default" if i == default_in else ""
  190. print(f" [{i}] {dev['name']}{marker}")
  191. input_devices.append((i, dev["name"]))
  192. if not input_devices:
  193. print("[Audio] ERROR: No input devices found.")
  194. return None
  195. if AUDIO_DEVICE is not None:
  196. print(f"[Audio] Using configured device [{AUDIO_DEVICE}]")
  197. return AUDIO_DEVICE
  198. if default_in >= 0:
  199. print(f"[Audio] Using default input device [{default_in}]")
  200. return default_in
  201. idx, name = input_devices[0]
  202. print(f"[Audio] No system default — using [{idx}] {name}")
  203. return idx
  204. async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, engine: TranscriptionEngine) -> None:
  205. audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=120)
  206. loop = asyncio.get_running_loop()
  207. def audio_callback(indata: np.ndarray, frames: int, time_info, status) -> None:
  208. if status:
  209. print(f"[Audio] {status}")
  210. chunk = indata.tobytes()
  211. loop.call_soon_threadsafe(
  212. lambda: audio_queue.put_nowait(chunk) if not audio_queue.full() else None
  213. )
  214. device = _choose_audio_device()
  215. if device is None:
  216. print("[Audio] No input device — cannot start.")
  217. return
  218. audio_processor = AudioProcessor(transcription_engine=engine)
  219. results_generator = await audio_processor.create_tasks()
  220. async def _receive_results():
  221. # FrontData.lines is a cumulative list of committed Segment objects.
  222. # Track how many we've already processed so we only push new ones.
  223. seen_lines = 0
  224. async for response in results_generator:
  225. lines = response.lines or []
  226. # Guard against unexpected shrink (e.g. processor reset)
  227. if len(lines) < seen_lines:
  228. seen_lines = 0
  229. for seg in lines[seen_lines:]:
  230. text = (seg.text or "").strip()
  231. if text and not seg.is_silence():
  232. spk = seg.speaker
  233. speaker_id = f"SPEAKER_{spk:02d}" if isinstance(spk, int) and spk >= 0 else None
  234. print(f"[Whisper] ({speaker_id or '?'}) {text}")
  235. state.push_final(text, speaker_id, mqtt_client)
  236. seen_lines = len(lines)
  237. async def _send_audio():
  238. with sd.InputStream(
  239. device=device, samplerate=SAMPLE_RATE, channels=CHANNELS,
  240. dtype="int16", blocksize=BLOCKSIZE, callback=audio_callback,
  241. ):
  242. while True:
  243. # Injected test audio takes priority over live microphone
  244. try:
  245. chunk = _inject_queue.get_nowait()
  246. except _stdlib_queue.Empty:
  247. chunk = await audio_queue.get()
  248. await audio_processor.process_audio(chunk)
  249. flusher = asyncio.create_task(_flusher(state, mqtt_client))
  250. reloader = asyncio.create_task(_speaker_reloader(state))
  251. try:
  252. await asyncio.gather(_send_audio(), _receive_results())
  253. finally:
  254. flusher.cancel()
  255. reloader.cancel()
  256. async def _flusher(state: BridgeState, mqtt_client: mqtt.Client) -> None:
  257. while True:
  258. await asyncio.sleep(1.0)
  259. state.maybe_timeout_flush(mqtt_client)
  260. async def _speaker_reloader(state: BridgeState) -> None:
  261. last_mtime = 0.0
  262. while True:
  263. await asyncio.sleep(5.0)
  264. try:
  265. mtime = SPEAKERS_FILE.stat().st_mtime
  266. if mtime != last_mtime:
  267. fresh = _load_speakers()
  268. with state._lock:
  269. state.speaker_names = fresh
  270. last_mtime = mtime
  271. print("[Bridge] Speaker names reloaded from disk")
  272. except OSError:
  273. pass
  274. # ── Entry point ───────────────────────────────────────────────────────────────
  275. def main() -> None:
  276. state = BridgeState()
  277. mqtt_client = build_mqtt_client()
  278. engine = TranscriptionEngine(model_size="large-v3", lan="en", diarization=False, pcm_input=True)
  279. # Inject API must start before the audio loop so test playback works immediately
  280. def _run_inject_api():
  281. uvicorn.run(_bridge_app, host="127.0.0.1", port=8002, log_level="warning")
  282. inject_thread = threading.Thread(target=_run_inject_api, daemon=True)
  283. inject_thread.start()
  284. print("[Bridge] Test audio inject API at http://127.0.0.1:8002")
  285. def _run():
  286. asyncio.run(audio_processor_loop(state, mqtt_client, engine))
  287. ws_thread = threading.Thread(target=_run, daemon=True)
  288. ws_thread.start()
  289. print(f"[Bridge] Speaker names loaded from {SPEAKERS_FILE}")
  290. print("[Bridge] Audio pipeline running — speaker admin at http://localhost:8001")
  291. print("[Bridge] Close this window to quit")
  292. try:
  293. ws_thread.join()
  294. except KeyboardInterrupt:
  295. pass
  296. if __name__ == "__main__":
  297. main()