Kaynağa Gözat

Streaming responses

Benjamin Harris 2 ay önce
ebeveyn
işleme
6353a6b016
2 değiştirilmiş dosya ile 268 ekleme ve 97 silme
  1. 150 51
      backend/app.py
  2. 118 46
      public/local_state-planning-scheme.php

+ 150 - 51
backend/app.py

@@ -196,32 +196,26 @@ def ollama_embed(text: str) -> List[float]:
         raise HTTPException(status_code=502, detail="Embedding service returned unexpected response")
     return data["embedding"]
 
-def ollama_chat(prompt: str) -> str:
-    """
-    Send a prompt to Ollama and return the generated text.
-
-    keep_alive MUST be a top-level key — putting it inside options{} causes
-    Ollama to silently ignore it and unload the model between requests.
+_OLLAMA_GENERATE_BODY = lambda prompt, stream: {
+    "model": CHAT_MODEL,
+    "prompt": prompt,
+    "stream": stream,
+    "options": {
+        "num_ctx": OLLAMA_NUM_CTX,
+        "num_predict": OLLAMA_NUM_PREDICT,
+        "temperature": OLLAMA_TEMPERATURE,
+        "top_p": 0.9,
+        "repeat_penalty": 1.1,
+    },
+    "keep_alive": -1,   # keep model resident in VRAM — MUST be top-level, not inside options
+}
 
-    num_ctx is fixed at 6144. Changing it between requests forces Ollama to
-    reload the model (KV cache is resized), adding ~3–5 s of cold-start latency.
-    """
+def ollama_chat(prompt: str) -> str:
+    """Send a prompt to Ollama and return the full generated text (non-streaming)."""
     try:
         r = requests.post(
             f"{OLLAMA_URL}/api/generate",
-            json={
-              "model": CHAT_MODEL,
-              "prompt": prompt,
-              "stream": False,
-              "options": {
-                "num_ctx": OLLAMA_NUM_CTX,
-                "num_predict": OLLAMA_NUM_PREDICT,
-                "temperature": OLLAMA_TEMPERATURE,
-                "top_p": 0.9,
-                "repeat_penalty": 1.1,
-              },
-              "keep_alive": -1,  # keep model resident in VRAM between requests
-            },
+            json=_OLLAMA_GENERATE_BODY(prompt, False),
             timeout=180
         )
         r.raise_for_status()
@@ -234,8 +228,45 @@ def ollama_chat(prompt: str) -> str:
     except requests.HTTPError as e:
         logger.error("Ollama chat HTTP %s: %s", e.response.status_code, e.response.text[:200])
         raise HTTPException(status_code=502, detail="LLM service error")
-    data = r.json()
-    return data.get("response", "").strip()
+    return r.json().get("response", "").strip()
+
+
+def ollama_chat_stream(prompt: str):
+    """
+    Yield raw text tokens from Ollama using streaming mode.
+    Each yielded value is a string (one or more characters).
+    Raises HTTPException on connection/HTTP errors before the first token.
+    """
+    try:
+        r = requests.post(
+            f"{OLLAMA_URL}/api/generate",
+            json=_OLLAMA_GENERATE_BODY(prompt, True),
+            stream=True,
+            timeout=180
+        )
+        r.raise_for_status()
+    except requests.Timeout:
+        logger.error("Ollama stream timeout (url=%s model=%s)", OLLAMA_URL, CHAT_MODEL)
+        raise HTTPException(status_code=503, detail="LLM service timed out")
+    except requests.ConnectionError:
+        logger.error("Ollama stream connection error (url=%s)", OLLAMA_URL)
+        raise HTTPException(status_code=503, detail="LLM service unavailable")
+    except requests.HTTPError as e:
+        logger.error("Ollama stream HTTP %s: %s", e.response.status_code, e.response.text[:200])
+        raise HTTPException(status_code=502, detail="LLM service error")
+
+    for line in r.iter_lines():
+        if not line:
+            continue
+        try:
+            chunk = json.loads(line)
+        except json.JSONDecodeError:
+            continue
+        token = chunk.get("response", "")
+        if token:
+            yield token
+        if chunk.get("done"):
+            break
 
 def _scroll_points(collection: str, qfilter=None, include_vector: bool=False, page_size: int=200):
     """
@@ -661,7 +692,7 @@ def _allowed(p: dict, scope: str, cslug: Optional[str]) -> bool:
         return corp == "tps" or (corp == "lps" and cslug and council == cslug)
     return True
 
-def do_ask(
+def _prepare_ask(
     query: str,
     top_k: int = 10,
     council: Optional[str] = None,
@@ -670,15 +701,16 @@ def do_ask(
     source_contains: Optional[str] = None,
     scope: str = "state_plus_local",
     section_id: Optional[str] = None,
-    context_only: bool = False,
-):
-    top_k = max(1, min(top_k, 30))  # clamp: at least 1, at most 30
+) -> tuple:
+    """
+    Embed the query, run Qdrant retrieval, and build the LLM prompt.
+    Returns (prompt, all_sources) — does NOT call the LLM.
+    Shared by do_ask() and the streaming endpoint.
+    """
+    top_k = max(1, min(top_k, 30))
     vec = ollama_embed(query)
     cslug = slug(council) if council else None
 
-    # Build the list of (section_heading, qdrant_filter) pairs based on scope.
-    # Each pair is searched independently so we can control the chunk budget
-    # per corpus — avoids TPS drowning out LPS results or vice versa.
     scopes: List[Tuple[str, qmodels.Filter]] = []
     if scope in ("state_only", "state_plus_local", "any"):
         scopes.append(("Tasmanian Planning Scheme (SPP)", filter_tps()))
@@ -689,11 +721,8 @@ def do_ask(
     if include_standards:
         scopes.append(("Australian Standards (AS)", filter_as()))
 
-    # Apply additional filename filter if requested (AND)
     scopes = [(name, with_source_contains(flt, source_contains)) for name, flt in scopes]
 
-    # Divide top_k across scopes: SPP and LPS each get ~1/3, the remainder
-    # is split evenly across any extra corpora (NCC, AS).
     per_spp = max(3, top_k // 3) if any(n.startswith("Tasmanian Planning Scheme") for n, _ in scopes) else 0
     per_lps = max(3, top_k // 3) if any(n.startswith("Local Provisions Schedule") for n, _ in scopes) else 0
     remaining = max(1, top_k - (per_spp + per_lps))
@@ -716,23 +745,15 @@ def do_ask(
         if lim <= 0:
             continue
         hits = q_search(vec, flt, lim)
-
-        # Guardrail: drop any hit that violates scope/council
         hits = [h for h in hits if _allowed(h.payload or {}, scope, cslug)]
-
         blocks, sources = render_blocks(hits)
         sections.append((name, blocks))
         all_sources.extend(sources)
 
     context = combine_context(sections)
-
     format_guide = _section_format_guide(
-        section_id,
-        section_title="(auto)",
-        ctx={
-            "council": council,           # from do_ask parameter
-            "planning_zones": [],         # populate if you have zone detection
-        }
+        section_id, section_title="(auto)",
+        ctx={"council": council, "planning_zones": []}
     )
 
     prompt = f"""
@@ -769,19 +790,36 @@ You are an expert Tasmanian planning and building compliance assistant with deep
 ## ANSWER:
 """.strip()
 
-    # BYOK mode: skip Ollama and return the context + prompt so the
-    # browser can call its own LLM provider (Claude, GPT, Grok, etc.)
+    return prompt, all_sources, sections
+
+
+def do_ask(
+    query: str,
+    top_k: int = 10,
+    council: Optional[str] = None,
+    include_ncc: bool = False,
+    include_standards: bool = False,
+    source_contains: Optional[str] = None,
+    scope: str = "state_plus_local",
+    section_id: Optional[str] = None,
+    context_only: bool = False,
+):
+    prompt, all_sources, sections = _prepare_ask(
+        query, top_k, council, include_ncc, include_standards,
+        source_contains, scope, section_id
+    )
+
     if context_only:
+        # Extract context from the prompt for BYOK mode
+        ctx_start = prompt.find("## CONTEXT")
+        ctx_end   = prompt.find("## QUESTION")
+        context   = prompt[ctx_start:ctx_end].strip() if ctx_start != -1 and ctx_end != -1 else ""
         return {
             "context_only": True,
             "context": context,
             "prompt": prompt,
             "sources": all_sources,
-            # Include the raw section blocks so the browser can inspect them
-            "sections": [
-                {"heading": name, "blocks": blocks}
-                for name, blocks in sections
-            ]
+            "sections": [{"heading": name, "blocks": blocks} for name, blocks in sections],
         }
 
     answer = ollama_chat(prompt)
@@ -874,3 +912,64 @@ def ask_post(request: Request, background_tasks: BackgroundTasks, body: AskBody)
     )
 
     return out
+
+
+# ---------------------------------------------------------------------------
+# /ask/stream — Server-Sent Events streaming endpoint
+# Embedding + retrieval run synchronously first (fast: ~0.5s).
+# Tokens stream as they arrive from Ollama — no waiting for full completion.
+#
+# SSE event types:
+#   {"type": "sources", "sources": [...]}   — sent first, before any tokens
+#   {"type": "token",   "text": "..."}      — one per Ollama chunk
+#   {"type": "done"}                        — stream complete
+#   {"type": "error",   "detail": "..."}    — on failure mid-stream
+# ---------------------------------------------------------------------------
+@app.post("/ask/stream")
+@limiter.limit("20/minute")
+def ask_stream(request: Request, body: AskBody):
+    _verify_demo_token_if_needed(request)
+    qtxt = (body.query or body.question or body.q or body.prompt or "").strip()
+    if not qtxt:
+        raise HTTPException(status_code=422, detail="Missing query/question")
+
+    # Embedding + retrieval + prompt building run before streaming starts.
+    # Sources are sent as the first SSE event so the UI can render them
+    # while tokens are still arriving from Ollama.
+    started = time.perf_counter()
+    prompt, all_sources, _ = _prepare_ask(
+        qtxt, body.top_k, body.council, body.include_ncc,
+        body.include_standards, body.source_contains, body.scope, body.section_id
+    )
+
+    ip  = request.client.host if request.client else "0.0.0.0"
+    sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
+
+    def generate():
+        tokens = []
+        try:
+            yield f"data: {json.dumps({'type': 'sources', 'sources': all_sources})}\n\n"
+            for token in ollama_chat_stream(prompt):
+                tokens.append(token)
+                yield f"data: {json.dumps({'type': 'token', 'text': token})}\n\n"
+            yield f"data: {json.dumps({'type': 'done'})}\n\n"
+        except Exception as e:
+            logger.error("[stream] error mid-stream: %s", e)
+            yield f"data: {json.dumps({'type': 'error', 'detail': str(e)})}\n\n"
+        finally:
+            # Telemetry written inline at stream end (~1ms SQLite write)
+            latency_ms = int((time.perf_counter() - started) * 1000)
+            _log_ask(
+                datetime.utcnow().isoformat(), sid, ip, qtxt, body.scope,
+                body.scope in ("state_only", "state_plus_local"),
+                latency_ms, CHAT_MODEL, all_sources, "".join(tokens),
+            )
+
+    return StreamingResponse(
+        generate(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "X-Accel-Buffering": "no",   # tell Nginx/Cloudflare not to buffer SSE
+        },
+    )

+ 118 - 46
public/local_state-planning-scheme.php

@@ -279,6 +279,16 @@
     .fb-btn.active-up { border-color: var(--accent); color: var(--accent); }
     .fb-btn.active-dn { border-color: var(--danger); color: var(--danger); }
 
+    /* Streaming cursor */
+    .streaming-cursor::after {
+      content: '▍';
+      display: inline-block;
+      color: var(--accent);
+      animation: blink 0.8s step-end infinite;
+      margin-left: 1px;
+    }
+    @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0} }
+
     /* Thinking indicator */
     .thinking {
       display: flex; align-items: center; gap: 10px;
@@ -918,27 +928,57 @@ async function ask(queryOverride) {
       addToHistory(rawQuery);
 
     } else {
-      // ── Internal Ollama path (unchanged) ───────────────────────────
-      const res = await fetch(`${API}/ask`, {
+      // ── Internal Ollama path — streaming ──────────────────────────
+      const res = await fetch(`${API}/ask/stream`, {
         method: 'POST',
         headers: { 'Content-Type': 'application/json', 'X-TPR-SID': sessionId },
         body: JSON.stringify({ query, council: council || null, top_k: 8, scope })
       });
-      const raw = await res.text();
-      if (!res.ok) throw new Error(`HTTP ${res.status} — ${raw.slice(0,200)}`);
-      const data = JSON.parse(raw);
+      if (!res.ok) {
+        const raw = await res.text();
+        throw new Error(`HTTP ${res.status} — ${raw.slice(0,200)}`);
+      }
 
       thinkEl.remove();
-      lastSources = Array.isArray(data.sources) ? data.sources : [];
-
+      const msgEl      = appendStreamingMsg(rawQuery, scope);
+      const streamText = msgEl.querySelector('.stream-text');
+      const reader     = res.body.getReader();
+      const decoder    = new TextDecoder();
+      let buf = '', fullAnswer = '';
+      lastSources = [];
+
+      outer: while (true) {
+        const { done, value } = await reader.read();
+        if (done) break;
+        buf += decoder.decode(value, { stream: true });
+        const lines = buf.split('\n');
+        buf = lines.pop();                          // keep incomplete line
+        for (const line of lines) {
+          if (!line.startsWith('data: ')) continue;
+          let evt;
+          try { evt = JSON.parse(line.slice(6)); } catch { continue; }
+
+          if (evt.type === 'sources') {
+            lastSources = evt.sources || [];
+          } else if (evt.type === 'token') {
+            fullAnswer += evt.text;
+            streamText.textContent = fullAnswer;    // raw text while streaming
+            scrollBottom();
+          } else if (evt.type === 'done') {
+            break outer;
+          } else if (evt.type === 'error') {
+            throw new Error(evt.detail || 'Stream error');
+          }
+        }
+      }
+
+      finalizeStreamingMsg(msgEl, fullAnswer || 'No answer returned.', lastSources);
       const latencyMs = Math.round(performance.now() - startedAt);
       sendEvent('search_result', {
         latency_ms: latencyMs,
         topk: lastSources.slice(0,10).map(s => ({ id:`${s.source_file}#p${s.page}`, score:s.score })),
-        model: data.model || 'unknown', ok: true,
+        model: 'stream', ok: true,
       });
-
-      appendAssistantMsg(data.answer || 'No answer returned.', scope, lastSources, rawQuery, 'internal');
       addToHistory(rawQuery);
     }
   } catch(e) {
@@ -983,51 +1023,19 @@ function appendThinking() {
 }
 
 function appendAssistantMsg(answer, scope, sources, query, provider = 'internal') {
+  const msgId = `msg-${Date.now()}`;
   const div = document.createElement('div');
   div.className = 'msg assistant';
-
-  const providerNames = { internal:'Ollama', anthropic:'Claude', openai:'GPT-4o', grok:'Grok', ollama:'Local Ollama' };
-  const providerName = providerNames[provider] || provider;
-  const providerIcon = provider === 'internal' ? 'cpu' : 'key';
-  const scopeHtml = `
-    <div style="display:flex;gap:6px;margin-bottom:10px;flex-wrap:wrap;">
-      <div class="scope-badge"><i class="bi bi-filter"></i> ${esc(scope)}</div>
-      <div class="scope-badge" style="background:${provider !== 'internal' ? 'rgba(192,132,252,0.1)' : 'var(--accent-dim)'};border-color:${provider !== 'internal' ? 'rgba(192,132,252,0.25)' : 'rgba(45,220,138,0.2)'};color:${provider !== 'internal' ? '#c084fc' : 'var(--accent)'};">
-        <i class="bi bi-${providerIcon}"></i> ${esc(providerName)}
-      </div>
-    </div>`;
-  const answerHtml = md2html(answer);
-
-  let sourcesHtml = '';
-  if (sources && sources.length) {
-    const chips = sources.map((s, i) => {
-      const label = `${s.source_file} p.${s.page}`;
-      const score = typeof s.score === 'number' ? `<span class="source-score">${s.score.toFixed(2)}</span>` : '';
-      return `<span class="source-chip" data-cite="${esc(`${s.source_file}#p${s.page}`)}" data-index="${i}"
-        onclick="openSourceInViewer(${i})">
-        <i class="bi bi-file-earmark-text"></i>${esc(label)}${score}
-      </span>`;
-    }).join('');
-    sourcesHtml = `
-      <div class="msg-sources">
-        <div class="sources-label">Sources</div>
-        <div class="source-chips">${chips}</div>
-      </div>`;
-  }
-
-  const msgId = `msg-${Date.now()}`;
   div.id = msgId;
-  // Store context on the element so feedback() can read it without closure issues
   div.dataset.query    = query || '';
   div.dataset.scope    = scope || '';
   div.dataset.provider = provider || 'internal';
-  // Store answer as plain text (strip HTML tags) for the feedback payload
   div.dataset.answer   = answer.replace(/<[^>]*>/g, '').substring(0, 4000);
   div.innerHTML = `
     <div class="msg-role"><i class="bi bi-stars"></i> Assistant</div>
-    ${scopeHtml}
-    <div class="msg-content">${answerHtml}</div>
-    ${sourcesHtml}
+    ${_scopeHtml(scope, provider)}
+    <div class="msg-content">${md2html(answer)}</div>
+    ${_sourceChipsHtml(sources)}
     <div class="msg-feedback">
       <button class="fb-btn" onclick="feedback('${msgId}','up',this)"><i class="bi bi-hand-thumbs-up"></i> Helpful</button>
       <button class="fb-btn" onclick="feedback('${msgId}','down',this)"><i class="bi bi-hand-thumbs-down"></i> Not helpful</button>
@@ -1037,6 +1045,70 @@ function appendAssistantMsg(answer, scope, sources, query, provider = 'internal'
   scrollBottom();
 }
 
+/* ── Streaming message helpers ───────────────────────────────────────── */
+
+function _scopeHtml(scope, provider) {
+  const providerNames = { internal:'Ollama', anthropic:'Claude', openai:'GPT-4o', grok:'Grok', ollama:'Local Ollama' };
+  const providerName = providerNames[provider] || provider;
+  const providerIcon = provider === 'internal' ? 'cpu' : 'key';
+  return `<div style="display:flex;gap:6px;margin-bottom:10px;flex-wrap:wrap;">
+    <div class="scope-badge"><i class="bi bi-filter"></i> ${esc(scope)}</div>
+    <div class="scope-badge" style="background:${provider !== 'internal' ? 'rgba(192,132,252,0.1)' : 'var(--accent-dim)'};border-color:${provider !== 'internal' ? 'rgba(192,132,252,0.25)' : 'rgba(45,220,138,0.2)'};color:${provider !== 'internal' ? '#c084fc' : 'var(--accent)'};">
+      <i class="bi bi-${providerIcon}"></i> ${esc(providerName)}
+    </div>
+  </div>`;
+}
+
+function _sourceChipsHtml(sources) {
+  if (!sources || !sources.length) return '';
+  const chips = sources.map((s, i) => {
+    const label = `${s.source_file} p.${s.page}`;
+    const score = typeof s.score === 'number' ? `<span class="source-score">${s.score.toFixed(2)}</span>` : '';
+    return `<span class="source-chip" data-cite="${esc(`${s.source_file}#p${s.page}`)}" data-index="${i}" onclick="openSourceInViewer(${i})">
+      <i class="bi bi-file-earmark-text"></i>${esc(label)}${score}
+    </span>`;
+  }).join('');
+  return `<div class="msg-sources"><div class="sources-label">Sources</div><div class="source-chips">${chips}</div></div>`;
+}
+
+// Create a message container for a streaming response.
+// Returns the div so the caller can access .querySelector('.stream-text') to append tokens.
+function appendStreamingMsg(rawQuery, scope) {
+  hideEmpty();
+  const msgId = `msg-${Date.now()}`;
+  const div = document.createElement('div');
+  div.className = 'msg assistant';
+  div.id = msgId;
+  div.dataset.query    = rawQuery;
+  div.dataset.scope    = scope;
+  div.dataset.provider = 'internal';
+  div.dataset.answer   = '';
+  div.innerHTML = `
+    <div class="msg-role"><i class="bi bi-stars"></i> Assistant</div>
+    ${_scopeHtml(scope, 'internal')}
+    <div class="msg-content"><span class="stream-text streaming-cursor"></span></div>
+  `;
+  chatThread.appendChild(div);
+  scrollBottom();
+  return div;
+}
+
+// Called when the stream is complete: renders markdown, appends sources + feedback.
+function finalizeStreamingMsg(msgEl, fullAnswer, sources) {
+  const contentEl = msgEl.querySelector('.msg-content');
+  contentEl.innerHTML = md2html(fullAnswer);
+  msgEl.dataset.answer = fullAnswer.replace(/<[^>]*>/g, '').substring(0, 4000);
+
+  const msgId = msgEl.id;
+  const trailing = _sourceChipsHtml(sources) + `
+    <div class="msg-feedback">
+      <button class="fb-btn" onclick="feedback('${msgId}','up',this)"><i class="bi bi-hand-thumbs-up"></i> Helpful</button>
+      <button class="fb-btn" onclick="feedback('${msgId}','down',this)"><i class="bi bi-hand-thumbs-down"></i> Not helpful</button>
+    </div>`;
+  msgEl.insertAdjacentHTML('beforeend', trailing);
+  scrollBottom();
+}
+
 function appendErrorMsg(msg) {
   const div = document.createElement('div');
   div.className = 'msg assistant';