| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484 |
- #!/usr/bin/env python3
- """
- bridge.py — Live Transcription Bridge
- Uses WhisperLiveKit's AudioProcessor API directly (no WebSocket).
- Publishes rolling 3-line JSON to Mosquitto MQTT.
- Run this script:
- python bridge.py
- """
- import asyncio
- import json
- import queue as _stdlib_queue
- import re
- import textwrap
- import threading
- import time
- from collections import Counter
- from datetime import datetime, timezone
- from pathlib import Path
- import numpy as np
- import paho.mqtt.client as mqtt
- import sounddevice as sd
- from fastapi import FastAPI, Request
- from whisperlivekit import AudioProcessor, TranscriptionEngine
- from embeddings import (
- EmbeddingRegistry,
- log_transcript_segment,
- )
- import uvicorn
- # ── Configuration ─────────────────────────────────────────────────────────────
- MQTT_HOST = "localhost"
- MQTT_PORT = 1883
- MQTT_TOPIC_TEXT = "display/text"
- MQTT_TOPIC_CLEAR = "display/clear"
- SAMPLE_RATE = 16000
- CHANNELS = 1
- BLOCKSIZE = 4096 # ~256 ms per chunk at 16 kHz
- SENTENCE_TIMEOUT = 4.0 # seconds of silence before forcing a flush
- MAX_LINE_CHARS = 80 # characters per line
- DISPLAY_LINES = 3
- # Set to a device index (integer) to force a specific microphone.
- # Leave as None to use the Windows default input device.
- AUDIO_DEVICE: int | None = 12
- SPEAKERS_FILE = Path(__file__).parent / "speakers.json"
- DEFAULT_SPEAKERS: dict[str, dict] = {
- "SPEAKER_00": {"name": "A.A.A", "role": "Serving Brother", "location": "Sydney", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
- "SPEAKER_01": {"name": "A.A.A", "role": "Contributor", "location": "London", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
- "SPEAKER_02": {"name": "A.A.A", "role": "Contributor", "location": "Hobart", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
- "SPEAKER_03": {"name": "A.A.A", "role": "Contributor", "location": "Perth", "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
- }
# ── Audio injection queue ─────────────────────────────────────────────────────
# Test audio POSTed to /inject (port 8002, e.g. by admin.py) lands here, and
# _send_audio() drains it ahead of the live microphone. The inject API and the
# audio pipeline each run their own event loop on separate threads, so this is
# a thread-safe stdlib queue.Queue rather than an asyncio.Queue (which is bound
# to a single loop).
_INJECT_QUEUE_MAXSIZE = 240
_inject_queue: _stdlib_queue.Queue = _stdlib_queue.Queue(maxsize=_INJECT_QUEUE_MAXSIZE)
# ── Audio injection API ───────────────────────────────────────────────────────
_bridge_app = FastAPI()


@_bridge_app.post("/inject")
async def inject_audio(request: Request):
    """Queue one audio chunk (the raw request body) for the audio pipeline.

    The body is handed unchanged to ``AudioProcessor.process_audio``, so it is
    presumably expected in the same format as the microphone stream (int16
    mono at SAMPLE_RATE). TODO confirm with the sender.

    Returns:
        ``{"ok": True}`` when the chunk was queued or the body was empty.
        ``{"ok": False, "error": "inject queue full"}`` when the queue was
        full and the chunk was dropped. This lets the caller detect lost
        audio instead of assuming it was accepted.
    """
    chunk = await request.body()
    if not chunk:
        return {"ok": True}
    try:
        _inject_queue.put_nowait(chunk)
    except _stdlib_queue.Full:
        # Never block the API thread. Report the loss rather than hiding it.
        print(f"[Inject] Queue full ({_inject_queue.maxsize} chunks) — chunk dropped")
        return {"ok": False, "error": "inject queue full"}
    return {"ok": True}
@_bridge_app.post("/inject/clear")
async def inject_clear():
    """Discard every injected chunk that the audio pipeline has not consumed yet."""
    # Pop until the queue reports Empty; get_nowait never blocks the API loop.
    try:
        while True:
            _inject_queue.get_nowait()
    except _stdlib_queue.Empty:
        pass
    return {"ok": True}
# ── Speaker persistence ───────────────────────────────────────────────────────
def _load_speakers() -> dict:
    """Load the speaker table from SPEAKERS_FILE.

    Returns:
        The parsed JSON object when the file exists and holds a dict.
        Otherwise a fresh copy of DEFAULT_SPEAKERS.

    When the file does not exist, the defaults are also written to disk. When
    the file exists but cannot be read, is not valid UTF-8 JSON, or is not a
    JSON object, the file is left untouched. That way a partially written or
    hand-edited file is never clobbered by the defaults.
    """
    # Copy the inner dicts too, so callers can never mutate DEFAULT_SPEAKERS.
    defaults = {sid: dict(entry) for sid, entry in DEFAULT_SPEAKERS.items()}
    if not SPEAKERS_FILE.exists():
        _write_speakers(defaults)
        return defaults
    try:
        data = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"[Speakers] Could not read {SPEAKERS_FILE} ({exc}); using defaults")
        return defaults
    if isinstance(data, dict):
        return data
    print(f"[Speakers] {SPEAKERS_FILE} does not hold a JSON object; using defaults")
    return defaults
def _write_speakers(names: dict) -> None:
    """Persist the speaker table to SPEAKERS_FILE as pretty-printed UTF-8 JSON.

    The JSON is first written to a sibling ``.tmp`` file, which is then moved
    over the target with ``Path.replace`` (an atomic rename on POSIX). Readers
    such as the reload task therefore see either the old or the new file,
    never a truncated one. I/O failures are printed rather than raised.
    """
    tmp_path = SPEAKERS_FILE.with_name(SPEAKERS_FILE.name + ".tmp")
    try:
        tmp_path.write_text(
            json.dumps(names, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        tmp_path.replace(SPEAKERS_FILE)
    except OSError as exc:
        print(f"[Speakers] Save failed: {exc}")
# ── State ─────────────────────────────────────────────────────────────────────
class BridgeState:
    """All mutable bridge state, protected by a single lock.

    Holds the speaker table, the pending text buffer, the rolling
    DISPLAY_LINES-line display published over MQTT, and the voiceprint
    matching state used by the audio loop.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Speaker table: SPEAKER_XX -> entry dict (or a plain string name).
        self.speaker_names: dict = _load_speakers()
        # Speaker IDs known so far, seeded with every ID in the table.
        self._seen: set[str] = set(self.speaker_names)
        # Display name of the speaker whose text is currently being buffered.
        self._current_speaker: str | None = None
        # Set on a speaker change; the next flush emits a "[NAME]" header line.
        self._speaker_changed = False
        # Finalised text not yet moved onto the display.
        self._text_buffer = ""
        self._display: list[str] = [""] * DISPLAY_LINES
        self._last_final_time = time.monotonic()
        # Raw diarization ID of current speaker (SPEAKER_XX)
        self._raw_speaker_id: str | None = None
        # Voiceprint matching
        self._embedding_registry = EmbeddingRegistry()
        self._accumulators: dict[str, object] = {}
        self._confirmed_ids: set[str] = set()
        self._session_id: str = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    def set_speaker_name(self, speaker_id: str, name: str) -> None:
        """Set the display name for *speaker_id* (whitespace-stripped) and persist the table."""
        with self._lock:
            entry = self.speaker_names.get(speaker_id, {})
            if isinstance(entry, dict):
                # Keep the entry's other fields (role, colour, ...) intact.
                self.speaker_names[speaker_id] = {**entry, "name": name.strip()}
            else:
                self.speaker_names[speaker_id] = {"name": name.strip()}
            self._seen.add(speaker_id)
            _write_speakers(self.speaker_names)

    def delete_speaker(self, speaker_id: str) -> None:
        """Remove *speaker_id* from the table and the seen set, then persist the table."""
        with self._lock:
            self.speaker_names.pop(speaker_id, None)
            self._seen.discard(speaker_id)
            _write_speakers(self.speaker_names)

    def seen_speakers_snapshot(self) -> set[str]:
        """Return a copy of the speaker IDs seen so far."""
        with self._lock:
            return set(self._seen)

    def _resolve(self, speaker_id: str | None) -> str | None:
        """Map a speaker ID to its display name.

        Falls back to the ID itself when the ID is unknown or has an empty
        name. Returns None for a falsy ID. Callers hold self._lock.
        """
        if not speaker_id:
            return None
        entry = self.speaker_names.get(speaker_id)
        if entry is None:
            return speaker_id
        if isinstance(entry, dict):
            return entry.get("name") or speaker_id
        return str(entry)

    def push_final(self, text: str, speaker_id: str | None, mqtt_client: mqtt.Client, seg_start: float = 0.0, seg_end: float = 0.0) -> None:
        """Accept a finalised segment; flush on sentence boundary or speaker change.

        Args:
            text: New transcript text for this segment.
            speaker_id: Raw diarization ID (SPEAKER_XX), or None if unknown.
            mqtt_client: Client used to publish the display when a flush occurs.
            seg_start: Segment start time, forwarded to the transcript log.
            seg_end: Segment end time, forwarded to the transcript log.
        """
        with self._lock:
            # Track raw diarization ID for PCM accumulator
            self._raw_speaker_id = speaker_id
            if speaker_id:
                self._seen.add(speaker_id)
            # Log segment to transcript_segments.jsonl for later enrolment.
            # Logging is best-effort: a failure here must not propagate to the
            # caller's transcription loop and stop the live display.
            try:
                log_transcript_segment(
                    speaker_id = speaker_id or "UNKNOWN",
                    text = text,
                    start_sec = seg_start,
                    end_sec = seg_end,
                    session_id = self._session_id,
                )
            except Exception as exc:
                print(f"[Transcript] Segment log failed: {exc}")
            resolved = self._resolve(speaker_id)
            if resolved != self._current_speaker:
                # Text buffered so far belongs to the previous speaker.
                if self._text_buffer:
                    self._flush(mqtt_client)
                self._current_speaker = resolved
                self._speaker_changed = True
            sep = " " if self._text_buffer else ""
            self._text_buffer += sep + text.strip()
            self._last_final_time = time.monotonic()
            if _is_sentence_end(text):
                self._flush(mqtt_client)

    def maybe_timeout_flush(self, mqtt_client: mqtt.Client) -> None:
        """Flush buffered text if no new segment arrived for SENTENCE_TIMEOUT seconds."""
        with self._lock:
            if self._text_buffer and (time.monotonic() - self._last_final_time) > SENTENCE_TIMEOUT:
                self._flush(mqtt_client)

    def _flush(self, mqtt_client: mqtt.Client) -> None:
        """Move the buffered text onto the rolling display and publish it.

        Called with self._lock held. Prepends an upper-cased "[NAME]" line
        when the speaker changed since the last flush, wraps the text at
        MAX_LINE_CHARS, and keeps only the last DISPLAY_LINES lines.
        """
        text = self._text_buffer.strip()
        self._text_buffer = ""
        if not text:
            return
        new_lines: list[str] = []
        if self._speaker_changed and self._current_speaker:
            new_lines.append(f"[{self._current_speaker.upper()}]")
            self._speaker_changed = False
        new_lines.extend(textwrap.wrap(text, MAX_LINE_CHARS) or [""])
        for line in new_lines:
            self._display.append(line)
        self._display = self._display[-DISPLAY_LINES:]
        while len(self._display) < DISPLAY_LINES:
            self._display.insert(0, "")
        payload = json.dumps({"lines": list(self._display)})
        mqtt_client.publish(MQTT_TOPIC_TEXT, payload)
        print(f"[Display] {self._display}")

    def clear(self, mqtt_client: mqtt.Client) -> None:
        """Blank the display, drop pending text and speaker context, and publish a clear message."""
        with self._lock:
            self._display = [""] * DISPLAY_LINES
            self._text_buffer = ""
            self._current_speaker = None
            self._speaker_changed = False
            mqtt_client.publish(MQTT_TOPIC_CLEAR, "")
            print("[Display] Cleared")
# ── Helpers ───────────────────────────────────────────────────────────────────
# A sentence terminator, optionally followed by closing quotes/brackets and
# trailing whitespace: `done.`, `really?"`, `(see above.)`, `wait!’`.
_SENTENCE_END_RE = re.compile(r"[.!?…][\"'”’)\]]*\s*$")


def _is_sentence_end(text: str) -> bool:
    """Return True when *text* ends a sentence.

    Surrounding whitespace is ignored. Closing quotes or brackets after the
    terminator still count as a sentence end, so quoted speech flushes
    immediately instead of waiting for SENTENCE_TIMEOUT.
    """
    return bool(_SENTENCE_END_RE.search(text.strip()))
# ── MQTT ──────────────────────────────────────────────────────────────────────
def build_mqtt_client() -> mqtt.Client:
    """Create an MQTT client that connects to MQTT_HOST:MQTT_PORT in the background.

    Uses ``connect_async`` with a 1–30 s reconnect backoff and starts paho's
    network loop thread, so this returns without waiting for the broker.
    Connection changes are only logged to stdout.
    """
    mqtt_client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)

    def _log_connect(_client, _userdata, _flags, reason, _properties):
        message = "[MQTT] Connected" if reason == 0 else f"[MQTT] Failed: {reason}"
        print(message)

    def _log_disconnect(_client, _userdata, _flags, reason, _properties):
        print(f"[MQTT] Disconnected ({reason}), will reconnect...")

    mqtt_client.on_connect = _log_connect
    mqtt_client.on_disconnect = _log_disconnect
    mqtt_client.reconnect_delay_set(min_delay=1, max_delay=30)
    mqtt_client.connect_async(MQTT_HOST, MQTT_PORT)
    mqtt_client.loop_start()
    return mqtt_client
# ── Audio pipeline ────────────────────────────────────────────────────────────
def _choose_audio_device() -> int | None:
    """Pick the input device index for the microphone stream.

    Prints every input-capable device, then chooses, in order of preference:
    the configured AUDIO_DEVICE (only if it is one of the listed input
    devices), the system default input, and finally the first input device.

    Returns:
        A device index, or None when devices cannot be queried or no input
        device exists.
    """
    try:
        devices = sd.query_devices()
        default_in = sd.default.device[0]
    except Exception as exc:
        print(f"[Audio] Cannot query devices: {exc}")
        return None
    print("[Audio] Available input devices:")
    input_devices: list[tuple[int, str]] = []
    for i, dev in enumerate(devices):
        if dev["max_input_channels"] > 0:
            marker = " ← default" if i == default_in else ""
            print(f" [{i}] {dev['name']}{marker}")
            input_devices.append((i, dev["name"]))
    if not input_devices:
        print("[Audio] ERROR: No input devices found.")
        return None
    if AUDIO_DEVICE is not None:
        if any(idx == AUDIO_DEVICE for idx, _ in input_devices):
            print(f"[Audio] Using configured device [{AUDIO_DEVICE}]")
            return AUDIO_DEVICE
        # A stale index (device unplugged or renumbered) would make the input
        # stream fail to open, so fall back instead of returning it.
        print(
            f"[Audio] WARNING: configured device [{AUDIO_DEVICE}] is not an "
            "available input device — falling back"
        )
    if default_in >= 0:
        print(f"[Audio] Using default input device [{default_in}]")
        return default_in
    idx, name = input_devices[0]
    print(f"[Audio] No system default — using [{idx}] {name}")
    return idx
async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, engine: TranscriptionEngine) -> None:
    """Run the live transcription pipeline until the audio stream stops.

    Captures microphone audio, or audio injected through /inject, and feeds
    it to a WhisperLiveKit AudioProcessor. Each new piece of transcript is
    pushed into *state*, which publishes the display over MQTT. Voiceprints
    are accumulated per speaker ID to auto-name speakers. The timeout flusher
    and the speakers-file reloader run alongside and are cancelled on exit.
    """
    audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=120)
    loop = asyncio.get_running_loop()
    # After the last injected chunk, keep the microphone muted for this many
    # seconds. This stops live audio from being spliced in between injected
    # chunks that arrive a little apart.
    inject_hold_seconds = 2.0

    def audio_callback(indata: np.ndarray, frames: int, time_info, status) -> None:
        # Runs outside the event loop (sounddevice callback), so hand the chunk
        # over with call_soon_threadsafe; drop it if the queue is already full.
        if status:
            print(f"[Audio] {status}")
        chunk = indata.tobytes()
        loop.call_soon_threadsafe(
            lambda: audio_queue.put_nowait(chunk) if not audio_queue.full() else None
        )

    device = _choose_audio_device()
    if device is None:
        print("[Audio] No input device — cannot start.")
        return
    audio_processor = AudioProcessor(transcription_engine=engine)
    results_generator = await audio_processor.create_tasks()

    async def _receive_results():
        # FrontData.lines is validated_segments + a growing current-segment snapshot.
        # The last element's text GROWS silently between calls, so index-counting
        # misses incremental content. Instead, track the full concatenated
        # transcript and push only the delta each time it grows.
        prev_full_text = ""
        async for response in results_generator:
            lines = response.lines or []
            current_full_text = " ".join(
                (seg.text or "").strip()
                for seg in lines
                if not seg.is_silence() and (seg.text or "").strip()
            )
            if current_full_text == prev_full_text:
                continue
            if prev_full_text and current_full_text.startswith(prev_full_text):
                new_text = current_full_text[len(prev_full_text):].strip()
                # Drop leading punctuation that belongs to the previous sentence
                while new_text and new_text[0] in ".,;:!?":
                    new_text = new_text[1:].strip()
            else:
                # First segment or context reset after a long silence.
                # NOTE(review): if an earlier segment is ever revised (rather
                # than extended), this re-sends the whole transcript. Confirm
                # that lines only ever grow.
                new_text = current_full_text
            prev_full_text = current_full_text
            # Skip empty deltas and lone punctuation-like characters. A plain
            # length check would also discard real one-letter words ("I",
            # "a"), and because prev_full_text has already advanced they
            # would be lost for good.
            if not new_text or (len(new_text) < 2 and not new_text.isalnum()):
                continue
            last_seg = next(
                (s for s in reversed(lines) if not s.is_silence() and (s.text or "").strip()),
                None,
            )
            spk = getattr(last_seg, "speaker", None) if last_seg else None
            speaker_id = f"SPEAKER_{spk:02d}" if isinstance(spk, int) and spk >= 0 else None
            seg_start = float(getattr(last_seg, "start", 0.0) or 0.0) if last_seg else 0.0
            seg_end = float(getattr(last_seg, "end", 0.0) or 0.0) if last_seg else 0.0
            print(f"[Whisper] ({speaker_id or '?'}) {new_text}")
            state.push_final(new_text, speaker_id, mqtt_client, seg_start, seg_end)

    def _update_voiceprint(speaker_id: str, chunk: bytes) -> None:
        """Add *chunk* to *speaker_id*'s PCM accumulator; once it holds enough
        audio, try to match it against enrolled voiceprints."""
        accumulator = state._accumulators.get(speaker_id)
        if accumulator is None:
            accumulator = state._embedding_registry.make_accumulator(min_seconds=5.0)
            state._accumulators[speaker_id] = accumulator
        accumulator.push(chunk)
        if not accumulator.ready():
            return
        # NOTE(review): extraction and matching run synchronously on the event
        # loop. If they are slow they stall transcription, so consider
        # asyncio.to_thread once EmbeddingRegistry's thread-safety is confirmed.
        try:
            live_emb = accumulator.extract_embedding()
            match = state._embedding_registry.find_match(live_emb)
            if not match:
                # No match yet — reset and try again with more audio
                accumulator.reset()
                return
            matched_id, score = match
            matched_entry = state.speaker_names.get(matched_id, {})
            matched_name = (
                matched_entry.get("name", matched_id)
                if isinstance(matched_entry, dict)
                else str(matched_entry)
            )
            print(
                f"[Embeddings] Auto-matched {speaker_id} → "
                f"{matched_name} (score={score:.3f})"
            )
            with state._lock:
                entry = state.speaker_names.get(speaker_id, {})
                if not isinstance(entry, dict):
                    entry = {}
                state.speaker_names[speaker_id] = {**entry, "name": matched_name}
                state._confirmed_ids.add(speaker_id)
                _write_speakers(state.speaker_names)
        except Exception as exc:
            print(f"[Embeddings] Accumulator error: {exc}")
            accumulator.reset()

    async def _send_audio():
        last_injected: float | None = None
        with sd.InputStream(
            device=device, samplerate=SAMPLE_RATE, channels=CHANNELS,
            dtype="int16", blocksize=BLOCKSIZE, callback=audio_callback,
        ):
            while True:
                # Injected test audio takes priority over live microphone
                try:
                    chunk = _inject_queue.get_nowait()
                    last_injected = time.monotonic()
                except _stdlib_queue.Empty:
                    if (
                        last_injected is not None
                        and time.monotonic() - last_injected < inject_hold_seconds
                    ):
                        # Injection still in progress. Discard live mic audio,
                        # including any stale backlog, instead of interleaving
                        # it with the injected stream, then poll again shortly.
                        while not audio_queue.empty():
                            audio_queue.get_nowait()
                        await asyncio.sleep(0.02)
                        continue
                    last_injected = None
                    chunk = await audio_queue.get()
                await audio_processor.process_audio(chunk)
                # Accumulate int16 PCM for live speaker matching
                current_spk = state._raw_speaker_id
                if current_spk and current_spk not in state._confirmed_ids:
                    _update_voiceprint(current_spk, chunk)

    flusher = asyncio.create_task(_flusher(state, mqtt_client))
    reloader = asyncio.create_task(_speaker_reloader(state))
    try:
        await asyncio.gather(_send_audio(), _receive_results())
    finally:
        flusher.cancel()
        reloader.cancel()
async def _flusher(state: BridgeState, mqtt_client: mqtt.Client) -> None:
    """Once per second, let *state* flush text left idle longer than SENTENCE_TIMEOUT.

    Loops forever; the caller stops it by cancelling the task.
    """
    poll_interval = 1.0
    while True:
        await asyncio.sleep(poll_interval)
        state.maybe_timeout_flush(mqtt_client)
async def _speaker_reloader(state: BridgeState) -> None:
    """Every 5 s, reload the speaker table from SPEAKERS_FILE if its mtime changed.

    Some files are skipped without touching the in-memory table: a missing or
    unreadable file, invalid JSON, or JSON that is not an object. Their mtime
    is not recorded, so they are retried on the next poll. This covers, for
    example, a file caught mid-write. Loops until the task is cancelled.
    """
    last_mtime = 0.0
    while True:
        await asyncio.sleep(5.0)
        try:
            mtime = SPEAKERS_FILE.stat().st_mtime
            if mtime == last_mtime:
                continue
            fresh = json.loads(SPEAKERS_FILE.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(fresh, dict):
            continue
        with state._lock:
            state.speaker_names = fresh
        last_mtime = mtime
        print("[Bridge] Speaker names reloaded from disk")
# ── Entry point ───────────────────────────────────────────────────────────────
def main() -> None:
    """Start the inject API and the audio pipeline, then wait for the pipeline to end.

    Both run on daemon threads. The main thread waits on the audio thread in
    short timed joins so that Ctrl+C is delivered, and always shuts the MQTT
    client down on exit.
    """
    state = BridgeState()
    mqtt_client = build_mqtt_client()
    engine = TranscriptionEngine(
        model_size="large-v3",
        lan="en",
        diarization=False,
        pcm_input=True,
        backend_policy="localagreement",
        confidence_validation=True,
        min_chunk_size=3,
        vac=False,
    )

    # Inject API must start before the audio loop so test playback works immediately
    def _run_inject_api():
        uvicorn.run(_bridge_app, host="127.0.0.1", port=8002, log_level="warning")

    inject_thread = threading.Thread(target=_run_inject_api, daemon=True)
    inject_thread.start()
    print("[Bridge] Test audio inject API at http://127.0.0.1:8002")

    def _run_audio():
        asyncio.run(audio_processor_loop(state, mqtt_client, engine))

    audio_thread = threading.Thread(target=_run_audio, daemon=True)
    audio_thread.start()
    print(f"[Bridge] Speaker names loaded from {SPEAKERS_FILE}")
    print("[Bridge] Audio pipeline running — speaker admin at http://localhost:8001")
    print("[Bridge] Press Ctrl+C or close this window to quit")
    try:
        # An untimed join() can keep Ctrl+C from being delivered (notably on
        # Windows), so wait in short slices instead.
        while audio_thread.is_alive():
            audio_thread.join(timeout=0.5)
    except KeyboardInterrupt:
        print("[Bridge] Interrupted — shutting down")
    finally:
        # Send DISCONNECT while paho's network loop is still running, then stop it.
        mqtt_client.disconnect()
        mqtt_client.loop_stop()


if __name__ == "__main__":
    main()
|