瀏覽代碼

Embedding UI Updates

Benjamin Harris 1 月之前
父節點
當前提交
1dd8bd309a
共有 2 個文件被更改,包括 315 次插入22 次删除
  1. 292 4
      bridge/admin.py
  2. 23 18
      bridge/bridge.py

+ 292 - 4
bridge/admin.py

@@ -144,15 +144,29 @@ class AddBody(BaseModel):
 
 @app.get("/api/speakers")
 def api_list():
+    from embeddings import EmbeddingRegistry
+    registry = EmbeddingRegistry(
+        embeddings_dir=Path(__file__).parent / "embeddings",
+        speakers_file=SPEAKERS_FILE,
+    )
     speakers = _load()
     result = []
     for k, v in sorted(speakers.items()):
         if isinstance(v, dict):
-            entry = {"id": k, "name": v.get("name", ""), "role": v.get("role", ""),
-                     "location": v.get("location", ""), "has_recording": _recording_path(k) is not None}
+            entry = {
+                "id": k, "name": v.get("name", ""), "role": v.get("role", ""),
+                "location": v.get("location", ""),
+                "has_recording":  _recording_path(k) is not None,
+                "has_embedding":  registry.has(k),
+                "embedding_updated": v.get("embedding_updated"),
+            }
         else:
-            entry = {"id": k, "name": str(v), "role": "", "location": "",
-                     "has_recording": _recording_path(k) is not None}
+            entry = {
+                "id": k, "name": str(v), "role": "", "location": "",
+                "has_recording":  _recording_path(k) is not None,
+                "has_embedding":  registry.has(k),
+                "embedding_updated": None,
+            }
         result.append(entry)
     return {"speakers": result}
 
@@ -223,6 +237,98 @@ def api_playback(sid: str):
     return FileResponse(rec)
 
 
+# ── Voiceprint / embedding API ────────────────────────────────────────────────
+
+class EnrolBody(BaseModel):
+    audio_file: str   # filename in test_recordings/
+    start: float = 0.0
+    end: float | None = None
+
+class AutoEnrolBody(BaseModel):
+    audio_file: str
+    session: str | None = None
+    min_duration: float = 8.0
+
+
+def _get_registry():
+    from embeddings import EmbeddingRegistry
+    return EmbeddingRegistry(
+        embeddings_dir=Path(__file__).parent / "embeddings",
+        speakers_file=SPEAKERS_FILE,
+    )
+
+
+@app.get("/api/speakers/{sid}/voiceprint")
+def api_voiceprint_status(sid: str):
+    registry = _get_registry()
+    speakers = _load()
+    entry = speakers.get(sid, {})
+    if not isinstance(entry, dict):
+        entry = {}
+    return {
+        "has_embedding":     registry.has(sid),
+        "embedding_updated": entry.get("embedding_updated"),
+    }
+
+
+@app.post("/api/speakers/{sid}/voiceprint/enrol")
+async def api_voiceprint_enrol(sid: str, body: EnrolBody):
+    p = TEST_RECORDINGS_DIR / Path(body.audio_file).name
+    if not p.exists():
+        raise HTTPException(404, f"Recording not found: {body.audio_file}")
+    try:
+        registry = _get_registry()
+        emb = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: registry.extract_and_save(sid, p, body.start, body.end),
+        )
+        return {"ok": True, "dim": int(emb.shape[0])}
+    except Exception as exc:
+        raise HTTPException(500, str(exc))
+
+
+@app.post("/api/speakers/{sid}/voiceprint/auto-enrol")
+async def api_voiceprint_auto_enrol(sid: str, body: AutoEnrolBody):
+    p = TEST_RECORDINGS_DIR / Path(body.audio_file).name
+    if not p.exists():
+        raise HTTPException(404, f"Recording not found: {body.audio_file}")
+    try:
+        registry = _get_registry()
+        emb = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: registry.enrol_from_transcript(
+                sid, p,
+                session_id=body.session,
+                min_duration=body.min_duration,
+            ),
+        )
+        return {"ok": True, "dim": int(emb.shape[0])}
+    except ValueError as exc:
+        raise HTTPException(422, str(exc))
+    except Exception as exc:
+        raise HTTPException(500, str(exc))
+
+
+@app.delete("/api/speakers/{sid}/voiceprint")
+def api_voiceprint_delete(sid: str):
+    registry = _get_registry()
+    deleted  = registry.delete(sid)
+    return {"ok": True, "deleted": deleted}
+
+
+@app.get("/api/speakers/{sid}/voiceprint/candidates")
+def api_voiceprint_candidates(sid: str, session: str | None = None, min_dur: float = 8.0):
+    from embeddings import get_best_enrolment_segments
+    candidates = get_best_enrolment_segments(sid, session_id=session, min_duration=min_dur)
+    return {"candidates": candidates}
+
+
+@app.get("/api/voiceprints")
+def api_voiceprints_list():
+    registry = _get_registry()
+    return {"enrolled": registry.list_enrolled()}
+
+
 # ── Test playback state ───────────────────────────────────────────────────────
 
 _playback_task: asyncio.Task | None = None
@@ -560,6 +666,19 @@ HTML = """<!DOCTYPE html>
   }
   .pb-note { font-size: .78rem; color: #94a3b8; margin-top: 10px; }
   .pb-error { color: #dc2626; }
+
+  /* Voiceprint modal */
+  .vp-tabs { display: flex; gap: 0; border-bottom: 2px solid #e2e8f0; margin-bottom: 16px; }
+  .vp-tab  {
+    padding: 8px 16px; cursor: pointer; font-size: .9rem; font-weight: 500;
+    color: #64748b; border-bottom: 2px solid transparent; margin-bottom: -2px;
+  }
+  .vp-tab.active { color: #2563eb; border-bottom-color: #2563eb; }
+  .vp-panel { display: none; }
+  .vp-panel.active { display: block; }
+  .vp-row { display: flex; gap: 8px; align-items: flex-end; margin-bottom: 12px; }
+  .vp-row .field { flex: 1; margin: 0; }
+  .vp-status { font-size: .85rem; min-height: 20px; margin-top: 8px; }
 </style>
 </head>
 <body>
@@ -588,6 +707,7 @@ HTML = """<!DOCTYPE html>
         <th>Initials</th>
         <th>Name</th>
         <th>Locality</th>
+        <th>Voiceprint</th>
         <th>Voice Sample</th>
         <th>Actions</th>
       </tr>
@@ -713,6 +833,66 @@ HTML = """<!DOCTYPE html>
   </div>
 </div>
 
+<!-- Voiceprint modal -->
+<div class="modal-bg" id="vp-modal" onclick="closeVoiceprintModal(event)">
+  <div class="modal" style="width:480px">
+    <h2 id="vp-title">Voiceprint</h2>
+    <div class="vp-tabs">
+      <div class="vp-tab active" onclick="vpTab('auto')">Auto-Enrol</div>
+      <div class="vp-tab" onclick="vpTab('manual')">Manual Segment</div>
+    </div>
+
+    <!-- Auto-enrol tab -->
+    <div class="vp-panel active" id="vp-panel-auto">
+      <p style="color:#64748b;font-size:.85rem;margin-bottom:12px">
+        Picks the cleanest isolated segment from the transcript log automatically.
+        Run a test recording first to populate the log.
+      </p>
+      <div class="field">
+        <label>Source Recording</label>
+        <select id="vp-auto-file" class="pb-select"></select>
+      </div>
+      <div class="field">
+        <label>Min segment length (seconds)</label>
+        <input id="vp-auto-mindur" type="number" value="8" min="3" max="60" step="1">
+      </div>
+      <div class="vp-status" id="vp-auto-status"></div>
+      <div class="modal-actions">
+        <button class="btn btn-ghost" onclick="closeVoiceprintModal()">Cancel</button>
+        <button class="btn btn-primary" onclick="vpAutoEnrol()">Auto-Enrol</button>
+      </div>
+    </div>
+
+    <!-- Manual segment tab -->
+    <div class="vp-panel" id="vp-panel-manual">
+      <p style="color:#64748b;font-size:.85rem;margin-bottom:12px">
+        Extract a voiceprint from a specific time range in a recording.
+        Choose a section with clear isolated speech (10–30 seconds).
+      </p>
+      <div class="field">
+        <label>Source Recording</label>
+        <select id="vp-manual-file" class="pb-select"></select>
+      </div>
+      <div class="vp-row">
+        <div class="field">
+          <label>Start (seconds)</label>
+          <input id="vp-manual-start" type="number" value="0" min="0" step="1">
+        </div>
+        <div class="field">
+          <label>End (seconds)</label>
+          <input id="vp-manual-end" type="number" value="30" min="1" step="1">
+        </div>
+      </div>
+      <div class="vp-status" id="vp-manual-status"></div>
+      <div class="modal-actions">
+        <button class="btn btn-ghost" onclick="closeVoiceprintModal()">Cancel</button>
+        <button class="btn btn-danger btn-sm" onclick="vpDelete()" style="margin-right:auto">Delete Voiceprint</button>
+        <button class="btn btn-primary" onclick="vpManualEnrol()">Extract &amp; Save</button>
+      </div>
+    </div>
+  </div>
+</div>
+
 <!-- Voice sample upload modal -->
 <div class="modal-bg" id="upload-modal" onclick="closeUploadModal(event)">
   <div class="modal">
@@ -779,11 +959,19 @@ function makeRow(s) {
          src="/api/speakers/${encodeURIComponent(s.id)}/recording"></audio>`
     : `<span class="rec-badge rec-no">No sample</span>`;
 
+  const vpHtml = s.has_embedding
+    ? `<span class="rec-badge rec-yes" title="${esc(s.embedding_updated||'')}">&#128304; Enrolled</span>
+       <button class="btn btn-ghost btn-sm" style="margin-left:4px"
+               onclick="openVoiceprintModal('${esc(s.id)}','${esc(s.name)}')">&#9998;</button>`
+    : `<button class="btn btn-ghost btn-sm"
+               onclick="openVoiceprintModal('${esc(s.id)}','${esc(s.name)}')">Enrol</button>`;
+
   tr.innerHTML = `
     <td class="sid">${esc(s.id)}</td>
     <td>${mkEdit(esc(s.id), 'name', esc(s.name))}</td>
     <td>${mkEdit(esc(s.id), 'role', esc(s.role))}</td>
     <td>${mkEdit(esc(s.id), 'location', esc(s.location))}</td>
+    <td>${vpHtml}</td>
     <td>${recHtml}</td>
     <td>
       <div class="actions">
@@ -903,6 +1091,106 @@ async function deleteSpeaker(id) {
   else        { toast('Delete failed', true); }
 }
 
+// ── Voiceprint modal ──────────────────────────────────────────────────────────
+
+let vpTarget = null;
+
+function vpTab(name) {
+  document.querySelectorAll('.vp-tab').forEach((t, i) => {
+    const panels = ['auto', 'manual'];
+    t.classList.toggle('active', panels[i] === name);
+  });
+  document.getElementById('vp-panel-auto').classList.toggle('active', name === 'auto');
+  document.getElementById('vp-panel-manual').classList.toggle('active', name === 'manual');
+}
+
+function vpPopulateFiles() {
+  const files = pbFiles.map(f => f.filename);
+  ['vp-auto-file', 'vp-manual-file'].forEach(id => {
+    const sel = document.getElementById(id);
+    sel.innerHTML = files.length
+      ? files.map(f => `<option value="${esc(f)}">${esc(f)}</option>`).join('')
+      : '<option value="">— no recordings uploaded —</option>';
+  });
+}
+
+function openVoiceprintModal(id, name) {
+  vpTarget = id;
+  document.getElementById('vp-title').textContent = `Voiceprint — ${name}`;
+  document.getElementById('vp-auto-status').textContent   = '';
+  document.getElementById('vp-manual-status').textContent = '';
+  vpPopulateFiles();
+  vpTab('auto');
+  document.getElementById('vp-modal').classList.add('open');
+}
+
+function closeVoiceprintModal(e) {
+  if (!e || e.target === document.getElementById('vp-modal')) {
+    document.getElementById('vp-modal').classList.remove('open');
+    vpTarget = null;
+    load();
+  }
+}
+
+async function vpAutoEnrol() {
+  if (!vpTarget) return;
+  const file    = document.getElementById('vp-auto-file').value;
+  const minDur  = parseFloat(document.getElementById('vp-auto-mindur').value) || 8;
+  const status  = document.getElementById('vp-auto-status');
+  if (!file) { toast('No recording selected', true); return; }
+  status.style.color = '#2563eb';
+  status.textContent = 'Extracting voiceprint… (this may take 10–30 seconds)';
+  const res = await fetch(`/api/speakers/${encodeURIComponent(vpTarget)}/voiceprint/auto-enrol`, {
+    method: 'POST',
+    headers: {'Content-Type': 'application/json'},
+    body: JSON.stringify({audio_file: file, min_duration: minDur}),
+  });
+  if (res.ok) {
+    status.style.color = '#166534';
+    status.textContent = '✓ Voiceprint saved';
+    toast('Voiceprint enrolled');
+  } else {
+    const err = await res.json().catch(() => ({detail: 'Failed'}));
+    status.style.color = '#dc2626';
+    status.textContent = `✗ ${err.detail || 'Failed'}`;
+    toast(err.detail || 'Auto-enrol failed', true);
+  }
+}
+
+async function vpManualEnrol() {
+  if (!vpTarget) return;
+  const file   = document.getElementById('vp-manual-file').value;
+  const start  = parseFloat(document.getElementById('vp-manual-start').value) || 0;
+  const end    = parseFloat(document.getElementById('vp-manual-end').value)   || null;
+  const status = document.getElementById('vp-manual-status');
+  if (!file) { toast('No recording selected', true); return; }
+  status.style.color = '#2563eb';
+  status.textContent = `Extracting ${start}s–${end ?? 'end'}s… (may take 10–30 seconds)`;
+  const res = await fetch(`/api/speakers/${encodeURIComponent(vpTarget)}/voiceprint/enrol`, {
+    method: 'POST',
+    headers: {'Content-Type': 'application/json'},
+    body: JSON.stringify({audio_file: file, start, end}),
+  });
+  if (res.ok) {
+    status.style.color = '#166534';
+    status.textContent = '✓ Voiceprint saved';
+    toast('Voiceprint enrolled');
+  } else {
+    const err = await res.json().catch(() => ({detail: 'Failed'}));
+    status.style.color = '#dc2626';
+    status.textContent = `✗ ${err.detail || 'Failed'}`;
+    toast(err.detail || 'Enrol failed', true);
+  }
+}
+
+async function vpDelete() {
+  if (!vpTarget) return;
+  if (!confirm('Delete the stored voiceprint for this speaker?')) return;
+  const res = await fetch(`/api/speakers/${encodeURIComponent(vpTarget)}/voiceprint`, {method: 'DELETE'});
+  if (res.ok) { toast('Voiceprint deleted'); closeVoiceprintModal(); }
+  else        { toast('Delete failed', true); }
+}
+
 function openUploadModal(id, name) {
   uploadTarget = id;
   document.getElementById('upload-title').textContent = `Voice Sample — ${name}`;

+ 23 - 18
bridge/bridge.py

@@ -55,10 +55,10 @@ AUDIO_DEVICE: int | None = 12
 SPEAKERS_FILE = Path(__file__).parent / "speakers.json"
 
 DEFAULT_SPEAKERS: dict[str, dict] = {
-    "SPEAKER_00": {"name": "A.A.A", "role": "Serving Brother", "location": "Sydney","has_embedding": false,"embedding_updated": null,"colour": "#16a34a","notes": ""},
-    "SPEAKER_01": {"name": "A.A.A", "role": "Contributor",     "location": "London","has_embedding": false,"embedding_updated": null,"colour": "#16a34a","notes": ""},
-    "SPEAKER_02": {"name": "A.A.A", "role": "Contributor",     "location": "Hobart","has_embedding": false,"embedding_updated": null,"colour": "#16a34a","notes": ""},
-    "SPEAKER_03": {"name": "A.A.A", "role": "Contributor",     "location": "Perth","has_embedding": false,"embedding_updated": null,"colour": "#16a34a","notes": ""},
+    "SPEAKER_00": {"name": "A.A.A", "role": "Serving Brother", "location": "Sydney",  "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
+    "SPEAKER_01": {"name": "A.A.A", "role": "Contributor",     "location": "London",  "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
+    "SPEAKER_02": {"name": "A.A.A", "role": "Contributor",     "location": "Hobart",  "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
+    "SPEAKER_03": {"name": "A.A.A", "role": "Contributor",     "location": "Perth",   "has_embedding": False, "embedding_updated": None, "colour": "#16a34a", "notes": ""},
 }
 
 # ── Audio injection queue ─────────────────────────────────────────────────────
@@ -354,9 +354,11 @@ async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, eng
             )
             spk = getattr(last_seg, "speaker", None) if last_seg else None
             speaker_id = f"SPEAKER_{spk:02d}" if isinstance(spk, int) and spk >= 0 else None
+            seg_start  = float(getattr(last_seg, "start", 0.0) or 0.0) if last_seg else 0.0
+            seg_end    = float(getattr(last_seg, "end",   0.0) or 0.0) if last_seg else 0.0
 
             print(f"[Whisper] ({speaker_id or '?'}) {new_text}")
-            state.push_final(new_text, speaker_id, mqtt_client)
+            state.push_final(new_text, speaker_id, mqtt_client, seg_start, seg_end)
 
     async def _send_audio():
         with sd.InputStream(
@@ -364,38 +366,41 @@ async def audio_processor_loop(state: BridgeState, mqtt_client: mqtt.Client, eng
             dtype="int16", blocksize=BLOCKSIZE, callback=audio_callback,
         ):
             while True:
-                # Drain test audio injection first if available
+                # Injected test audio takes priority over live microphone
                 try:
-                    chunk = test_audio_queue.get_nowait()
-                except asyncio.QueueEmpty:
+                    chunk = _inject_queue.get_nowait()
+                except _stdlib_queue.Empty:
                     chunk = await audio_queue.get()
                 await audio_processor.process_audio(chunk)
 
-                # Accumulate PCM for live speaker matching
+                # Accumulate int16 PCM for live speaker matching
                 current_spk = state._raw_speaker_id
                 if current_spk and current_spk not in state._confirmed_ids:
                     if current_spk not in state._accumulators:
                         state._accumulators[current_spk] = \
                             state._embedding_registry.make_accumulator(min_seconds=5.0)
-                    
-                    # Convert float32 → int16 for the embedding accumulator
-                    chunk_i16 = (np.frombuffer(chunk, dtype=np.float32) * 32767).astype(np.int16).tobytes()
-                    state._accumulators[current_spk].push(chunk_i16)
-                    # state._accumulators[current_spk].push(chunk)
-                    
+                    state._accumulators[current_spk].push(chunk)
                     if state._accumulators[current_spk].ready():
                         try:
                             live_emb = state._accumulators[current_spk].extract_embedding()
                             match    = state._embedding_registry.find_match(live_emb)
                             if match:
                                 matched_id, score = match
-                                resolved = state.speaker_names.get(matched_id, matched_id)
+                                matched_entry = state.speaker_names.get(matched_id, {})
+                                matched_name  = (
+                                    matched_entry.get("name", matched_id)
+                                    if isinstance(matched_entry, dict)
+                                    else str(matched_entry)
+                                )
                                 print(
                                     f"[Embeddings] Auto-matched {current_spk} → "
-                                    f"{resolved} (score={score:.3f})"
+                                    f"{matched_name} (score={score:.3f})"
                                 )
                                 with state._lock:
-                                    state.speaker_names[current_spk] = resolved
+                                    entry = state.speaker_names.get(current_spk, {})
+                                    if not isinstance(entry, dict):
+                                        entry = {}
+                                    state.speaker_names[current_spk] = {**entry, "name": matched_name}
                                     state._confirmed_ids.add(current_spk)
                                     _write_speakers(state.speaker_names)
                             else: