|
|
@@ -21,7 +21,7 @@ import time
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
from typing import Optional, Literal, List, Tuple
|
|
|
-from fastapi import FastAPI, Query, HTTPException, Request
|
|
|
+from fastapi import BackgroundTasks, FastAPI, Query, HTTPException, Request
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from slowapi.middleware import SlowAPIMiddleware
|
|
|
@@ -45,6 +45,9 @@ COLLECTION = os.getenv("QDRANT_COLLECTION", "planning_docs")
|
|
|
EMBED_MODEL = os.getenv("EMBED_MODEL", "nomic-embed-text")
|
|
|
CHAT_MODEL = os.getenv("CHAT_MODEL", "llama3.1:8b-instruct-q4_K_M")
|
|
|
CORS_ORIGINS = [o.strip() for o in os.getenv("CORS_ORIGINS", "https://tasplanning.report").split(",") if o.strip()]
|
|
|
+OLLAMA_NUM_CTX = int(os.getenv("OLLAMA_NUM_CTX", "6144"))
|
|
|
+OLLAMA_NUM_PREDICT = int(os.getenv("OLLAMA_NUM_PREDICT", "512"))
|
|
|
+OLLAMA_TEMPERATURE = float(os.getenv("OLLAMA_TEMPERATURE", "0.2"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
# Demo token gate (disabled by default)
|
|
|
@@ -91,6 +94,14 @@ app.add_middleware(
|
|
|
qc = QdrantClient(url=QDRANT_URL)
|
|
|
app.include_router(telemetry_router)
|
|
|
|
|
|
+@app.on_event("startup")
|
|
|
+def check_qdrant():
|
|
|
+ try:
|
|
|
+ qc.get_collection(COLLECTION)
|
|
|
+ logger.info("Qdrant collection '%s' ready", COLLECTION)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error("Qdrant startup check failed for collection '%s': %s", COLLECTION, e)
|
|
|
+
|
|
|
# ---------------------------------------------------------------------------
|
|
|
# Rate limiting (slowapi — in-memory, per IP)
|
|
|
# Shared limiter instance lives in limiter.py to avoid circular imports with
|
|
|
@@ -203,9 +214,9 @@ def ollama_chat(prompt: str) -> str:
|
|
|
"prompt": prompt,
|
|
|
"stream": False,
|
|
|
"options": {
|
|
|
- "num_ctx": 6144, # was 8192,
|
|
|
- "num_predict": 512,
|
|
|
- "temperature": 0.2,
|
|
|
+ "num_ctx": OLLAMA_NUM_CTX,
|
|
|
+ "num_predict": OLLAMA_NUM_PREDICT,
|
|
|
+ "temperature": OLLAMA_TEMPERATURE,
|
|
|
"top_p": 0.9,
|
|
|
"repeat_penalty": 1.1,
|
|
|
},
|
|
|
@@ -777,10 +788,32 @@ You are an expert Tasmanian planning and building compliance assistant with deep
|
|
|
return {"answer": answer, "sources": all_sources}
|
|
|
|
|
|
|
|
|
+def _log_ask(ts, sid, ip, query, scope, allow_tps, latency_ms, model, sources, answer):
|
|
|
+ """Write one ask_logs row. Runs in a background task — never raises to the caller."""
|
|
|
+ try:
|
|
|
+ topk = [{"id": f"{s.get('source_file')}#p{s.get('page')}", "score": s.get("score")} for s in sources]
|
|
|
+ with db() as conn:
|
|
|
+ conn.execute("""
|
|
|
+ INSERT INTO ask_logs
|
|
|
+ (ts, sid, ip_hash, query, normalized, scope, allow_tps, latency_ms,
|
|
|
+ model, ok, topk_json, tokens_in, tokens_out, answer)
|
|
|
+ VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
|
|
+ """, (
|
|
|
+ ts, sid, ip_hash(ip), query, _normalize(query),
|
|
|
+ scope, int(allow_tps),
|
|
|
+ latency_ms, model, 1, _json_dumps(topk), 0, 0,
|
|
|
+ _trunc(answer, 8000, "ask_logs.answer"),
|
|
|
+ ))
|
|
|
+ conn.commit()
|
|
|
+ except Exception:
|
|
|
+ logger.exception("[telemetry] ask insert failed")
|
|
|
+
|
|
|
+
|
|
|
@app.get("/ask")
|
|
|
@limiter.limit("20/minute")
|
|
|
def ask_get(
|
|
|
request: Request,
|
|
|
+ background_tasks: BackgroundTasks,
|
|
|
query: str = Query(..., description="User question"),
|
|
|
top_k: int = 10,
|
|
|
council: Optional[str] = None,
|
|
|
@@ -797,36 +830,21 @@ def ask_get(
|
|
|
out = do_ask(query, top_k, council, include_ncc, include_standards, source_contains, scope, section_id, context_only)
|
|
|
latency_ms = int((time.perf_counter() - started) * 1000)
|
|
|
|
|
|
- # Telemetry insert — never allowed to break the response
|
|
|
- try:
|
|
|
- ip = request.client.host if request.client else "0.0.0.0"
|
|
|
- sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
|
|
|
- allow_tps = scope in ("state_only", "state_plus_local")
|
|
|
- topk = [{"id": f"{s.get('source_file')}#p{s.get('page')}", "score": s.get("score")} for s in (out.get("sources") or [])]
|
|
|
-
|
|
|
- with db() as conn:
|
|
|
- conn.execute("""
|
|
|
- INSERT INTO ask_logs
|
|
|
- (ts, sid, ip_hash, query, normalized, scope, allow_tps, latency_ms,
|
|
|
- model, ok, topk_json, tokens_in, tokens_out, answer)
|
|
|
- VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
|
|
- """, (
|
|
|
- datetime.utcnow().isoformat(),
|
|
|
- sid, ip_hash(ip), query, _normalize(query),
|
|
|
- scope, int(allow_tps),
|
|
|
- latency_ms, CHAT_MODEL, 1, _json_dumps(topk), 0, 0,
|
|
|
- _trunc(out.get("answer") or "", 8000, "ask_get.answer"),
|
|
|
- ))
|
|
|
- conn.commit()
|
|
|
- except Exception as e:
|
|
|
- logger.exception("[telemetry] ask_get insert failed")
|
|
|
+ ip = request.client.host if request.client else "0.0.0.0"
|
|
|
+ sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
|
|
|
+ background_tasks.add_task(
|
|
|
+ _log_ask,
|
|
|
+ datetime.utcnow().isoformat(), sid, ip, query, scope,
|
|
|
+ scope in ("state_only", "state_plus_local"),
|
|
|
+ latency_ms, CHAT_MODEL, out.get("sources") or [], out.get("answer") or "",
|
|
|
+ )
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
@app.post("/ask")
|
|
|
@limiter.limit("20/minute")
|
|
|
-def ask_post(request: Request, body: AskBody):
|
|
|
+def ask_post(request: Request, background_tasks: BackgroundTasks, body: AskBody):
|
|
|
_verify_demo_token_if_needed(request)
|
|
|
qtxt = (body.query or body.question or body.q or body.prompt or "").strip()
|
|
|
if not qtxt:
|
|
|
@@ -846,28 +864,13 @@ def ask_post(request: Request, body: AskBody):
|
|
|
)
|
|
|
latency_ms = int((time.perf_counter() - started) * 1000)
|
|
|
|
|
|
- # Telemetry insert — never allowed to break the response
|
|
|
- try:
|
|
|
- ip = request.client.host if request.client else "0.0.0.0"
|
|
|
- sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
|
|
|
- allow_tps = body.scope in ("state_only", "state_plus_local")
|
|
|
- topk = [{"id": f"{s.get('source_file')}#p{s.get('page')}", "score": s.get("score")} for s in (out.get("sources") or [])]
|
|
|
-
|
|
|
- with db() as conn:
|
|
|
- conn.execute("""
|
|
|
- INSERT INTO ask_logs
|
|
|
- (ts, sid, ip_hash, query, normalized, scope, allow_tps, latency_ms,
|
|
|
- model, ok, topk_json, tokens_in, tokens_out, answer)
|
|
|
- VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
|
|
|
- """, (
|
|
|
- datetime.utcnow().isoformat(),
|
|
|
- sid, ip_hash(ip), qtxt, _normalize(qtxt),
|
|
|
- body.scope, int(allow_tps),
|
|
|
- latency_ms, CHAT_MODEL, 1, _json_dumps(topk), 0, 0,
|
|
|
- _trunc(out.get("answer") or "", 8000, "ask_post.answer"),
|
|
|
- ))
|
|
|
- conn.commit()
|
|
|
- except Exception as e:
|
|
|
- logger.exception("[telemetry] ask_post insert failed")
|
|
|
+ ip = request.client.host if request.client else "0.0.0.0"
|
|
|
+ sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
|
|
|
+ background_tasks.add_task(
|
|
|
+ _log_ask,
|
|
|
+ datetime.utcnow().isoformat(), sid, ip, qtxt, body.scope,
|
|
|
+ body.scope in ("state_only", "state_plus_local"),
|
|
|
+ latency_ms, CHAT_MODEL, out.get("sources") or [], out.get("answer") or "",
|
|
|
+ )
|
|
|
|
|
|
return out
|