app.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. """
  2. app.py — FastAPI backend for tasplanning.report
  3. RAG pipeline:
  4. 1. Embed the user query via Ollama (nomic-embed-text)
  5. 2. Search Qdrant for the closest chunks, split by corpus/scope
  6. 3. Inject retrieved context into a structured prompt
  7. 4. Call Ollama (llama3.1:8b) and return the answer + source citations
  8. BYOK mode (context_only=True): skip step 4 and return the prompt so
  9. the browser can call its own LLM (Claude, GPT, Grok, local Ollama).
  10. Restart required after any change to this file:
  11. docker compose restart backend
  12. """
  13. import os, re, hmac, logging
  14. import json
  15. import requests
  16. import time
  17. logger = logging.getLogger(__name__)
  18. from typing import Optional, Literal, List, Tuple
  19. from fastapi import BackgroundTasks, FastAPI, Query, HTTPException, Request
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from fastapi.responses import StreamingResponse
  22. from slowapi.middleware import SlowAPIMiddleware
  23. from slowapi.errors import RateLimitExceeded
  24. from limiter import limiter
  25. from fastapi.responses import JSONResponse
  26. from pydantic import BaseModel
  27. from qdrant_client import QdrantClient
  28. from qdrant_client.http import models as qmodels
  29. from collections import Counter, defaultdict
  30. from datetime import datetime
  31. from telemetry import router as telemetry_router, db, ip_hash
  32. # ---------------------------------------------------------------------------
  33. # Environment
  34. # ---------------------------------------------------------------------------
  # Service endpoints — each value can be overridden through the environment.
  OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://192.168.8.73:11434")
  QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
  # -1 = keep loaded forever
  OLLAMA_KEEP_ALIVE = os.environ.get("OLLAMA_KEEP_ALIVE", "-1")
  # Qdrant collection plus the embedding / chat model names.
  COLLECTION = os.environ.get("QDRANT_COLLECTION", "planning_docs")
  EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
  CHAT_MODEL = os.environ.get("CHAT_MODEL", "llama3.1:8b-instruct-q4_K_M")
  # Comma-separated origin list; surrounding whitespace and empty entries are dropped.
  CORS_ORIGINS = [
      origin.strip()
      for origin in os.environ.get("CORS_ORIGINS", "https://tasplanning.report").split(",")
      if origin.strip()
  ]
  # Generation options, parsed eagerly so a malformed value fails at import time.
  OLLAMA_NUM_CTX = int(os.environ.get("OLLAMA_NUM_CTX", "6144"))
  OLLAMA_NUM_PREDICT = int(os.environ.get("OLLAMA_NUM_PREDICT", "512"))
  OLLAMA_TEMPERATURE = float(os.environ.get("OLLAMA_TEMPERATURE", "0.2"))
  45. # ---------------------------------------------------------------------------
  46. # Demo token gate (disabled by default)
  47. # Enable by setting DEMO_REQUIRE_TOKEN=1 and DEMO_TOKEN=<secret> in .env.
  48. # When enabled, every request to /ask and /admin/* must include:
  49. # Authorization: Bearer <DEMO_TOKEN>
  50. # ---------------------------------------------------------------------------
  51. DEMO_REQUIRE_TOKEN = os.getenv("DEMO_REQUIRE_TOKEN", "0") == "1"
  52. DEMO_TOKEN = os.getenv("DEMO_TOKEN", "")
  53. def _verify_demo_token_if_needed(request):
  54. if not DEMO_REQUIRE_TOKEN:
  55. return
  56. auth = request.headers.get("Authorization", "")
  57. parts = auth.split(" ", 1)
  58. if len(parts) != 2 or parts[0] != "Bearer":
  59. raise HTTPException(status_code=401, detail="Unauthorized")
  60. # compare_digest runs in constant time — prevents timing-based token guessing
  61. if not hmac.compare_digest(parts[1], DEMO_TOKEN):
  62. raise HTTPException(status_code=401, detail="Unauthorized")
  63. # ---------------------------------------------------------------------------
  64. # FastAPI app + CORS
  65. # ---------------------------------------------------------------------------
  app = FastAPI()

  # An empty CORS_ORIGINS (not expected in production) switches to the
  # wildcard fallback. Credentials are only allowed with explicit origins.
  # NOTE(review): the wildcard entry appears to make the subdomain regex
  # redundant — confirm against the CORS middleware's matching rules.
  _origins = CORS_ORIGINS or []
  _allow_all = not _origins
  if _allow_all:
      _cors_kwargs = {
          "allow_origins": ["*"],
          "allow_origin_regex": r"https://.*\.tasplanning\.report",
          "allow_credentials": False,
      }
  else:
      _cors_kwargs = {
          "allow_origins": _origins,
          "allow_origin_regex": None,
          "allow_credentials": True,
      }
  app.add_middleware(
      CORSMiddleware,
      allow_methods=["GET", "POST", "OPTIONS"],
      allow_headers=["Content-Type", "Authorization", "X-TPR-SID"],
      expose_headers=["X-TPR-SID"],
      **_cors_kwargs,
  )

  # Shared Qdrant client and the telemetry routes.
  qc = QdrantClient(url=QDRANT_URL)
  app.include_router(telemetry_router)
  @app.on_event("startup")
  def check_qdrant():
      """Log at startup whether the configured Qdrant collection can be fetched.

      A failure is logged and swallowed, so the app still starts.
      """
      try:
          qc.get_collection(COLLECTION)
      except Exception as exc:
          logger.error("Qdrant startup check failed for collection '%s': %s", COLLECTION, exc)
      else:
          logger.info("Qdrant collection '%s' ready", COLLECTION)
  90. # ---------------------------------------------------------------------------
  91. # Rate limiting (slowapi — in-memory, per IP)
  92. # Shared limiter instance lives in limiter.py to avoid circular imports with
  93. # telemetry.py, which also needs to decorate its own endpoints.
  94. # ---------------------------------------------------------------------------
  app.state.limiter = limiter  # type: ignore
  app.add_middleware(SlowAPIMiddleware)


  @app.exception_handler(RateLimitExceeded)
  def ratelimit_handler(request, exc):
      """Turn a RateLimitExceeded error into a JSON 429 response."""
      payload = {"error": "rate_limited", "detail": "Too many requests"}
      return JSONResponse(status_code=429, content=payload)
  100. # ---------------------------------------------------------------------------
  101. # Feedback endpoint
  102. # Stores thumbs-up/down ratings alongside the query + answer for prompt tuning.
  103. # Fields are truncated before insert to keep the SQLite row size reasonable.
  104. # ---------------------------------------------------------------------------
  class FeedbackBody(BaseModel):
      """Request body for POST /feedback; only `verdict` is required."""

      # "up" or "down"
      verdict: str
      # the question that was asked
      query: Optional[str] = None
      # the answer that was rated
      answer: Optional[str] = None
      # optional free-text from thumbs-down
      note: Optional[str] = None
      # session id from browser
      sid: Optional[str] = None
      # which model answered
      model: Optional[str] = None
      # which scope was used
      scope: Optional[str] = None
      # which sources were cited
      sources: Optional[list] = None
  @app.post("/feedback")
  @limiter.limit("60/minute")
  def feedback(request: Request, body: FeedbackBody):
      """Store a thumbs-up/down rating in the `feedback` table.

      The query, answer and note are truncated before insert. A database
      failure is logged and swallowed: the caller still receives
      ``{"ok": True}``.

      Raises:
          HTTPException: 422 when `verdict` is neither "up" nor "down".
      """
      # Local import: the module-level import only brings in `datetime`.
      from datetime import timezone

      if body.verdict not in ("up", "down"):
          raise HTTPException(status_code=422, detail="verdict must be 'up' or 'down'")
      ip = request.client.host if request.client else "0.0.0.0"
      sid = body.sid or request.headers.get("X-TPR-SID") or ""
      # datetime.utcnow() is deprecated; dropping tzinfo keeps the stored
      # timestamp in the same naive-UTC ISO format as before.
      ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
      sql = (
          "INSERT INTO feedback "
          "(ts, sid, ip_hash, verdict, query, answer, note, model, scope, sources_json) "
          "VALUES (?,?,?,?,?,?,?,?,?,?)"
      )
      try:
          row = (
              ts,
              sid,
              ip_hash(ip),
              body.verdict,
              _trunc(body.query or "", 2000, "feedback.query"),
              _trunc(body.answer or "", 8000, "feedback.answer"),
              _trunc(body.note or "", 1000, "feedback.note"),
              body.model or CHAT_MODEL,
              body.scope or "",
              _json_dumps(body.sources or []),
          )
          with db() as conn:
              conn.execute(sql, row)
              conn.commit()
      except Exception:
          # Still return ok — don't surface DB errors to users.
          logger.exception("[feedback] telemetry insert failed")
      return {"ok": True}
  142. # ---------------------------------------------------------------------------
  143. # Ollama helpers
  144. # ---------------------------------------------------------------------------
  def slug(s: Optional[str]) -> Optional[str]:
      """Normalise a council name to a URL-safe slug for Qdrant filter matching.

      Returns None for empty input or when nothing alphanumeric remains.
      """
      if not s:
          return None
      lowered = s.strip().lower()
      dashed = re.sub(r'[^a-z0-9]+', '-', lowered)
      trimmed = dashed.strip('-')
      return trimmed if trimmed else None
  def ollama_embed(text: str) -> List[float]:
      """Call the Ollama embeddings API and return the float vector.

      Raises:
          HTTPException: 503 on timeout or connection failure; 502 on an HTTP
              error status, any other request failure, a non-JSON body, or a
              body without an "embedding" key.
      """
      try:
          r = requests.post(
              f"{OLLAMA_URL}/api/embeddings",
              json={"model": EMBED_MODEL, "prompt": text},
              timeout=60,
          )
          r.raise_for_status()
      except requests.Timeout:
          logger.error("Ollama embed timeout after 60s (url=%s model=%s)", OLLAMA_URL, EMBED_MODEL)
          raise HTTPException(status_code=503, detail="Embedding service timed out")
      except requests.ConnectionError:
          logger.error("Ollama embed connection error (url=%s)", OLLAMA_URL)
          raise HTTPException(status_code=503, detail="Embedding service unavailable")
      except requests.HTTPError as e:
          logger.error("Ollama embed HTTP %s: %s", e.response.status_code, e.response.text[:200])
          raise HTTPException(status_code=502, detail="Embedding service error")
      except requests.RequestException as e:
          # Any other transport failure — keep it from escaping as a bare 500.
          logger.error("Ollama embed request failed: %s", e)
          raise HTTPException(status_code=502, detail="Embedding service error")
      try:
          data = r.json()
      except ValueError:
          # Body was not valid JSON.
          logger.error("Ollama embed returned non-JSON body: %s", r.text[:200])
          raise HTTPException(status_code=502, detail="Embedding service returned unexpected response")
      if not isinstance(data, dict) or "embedding" not in data:
          logger.error("Ollama embed unexpected response: %s", str(data)[:200])
          raise HTTPException(status_code=502, detail="Embedding service returned unexpected response")
      return data["embedding"]
  def _OLLAMA_GENERATE_BODY(prompt, stream):
      """Build the JSON request body for Ollama's /api/generate endpoint.

      Args:
          prompt: The full prompt text.
          stream: Whether Ollama should stream the response.

      Returns:
          A dict with the model, prompt, stream flag, sampling options and the
          top-level keep_alive setting.
      """
      # Honour OLLAMA_KEEP_ALIVE instead of a hard-coded -1. A numeric string
      # is sent as a number; anything else is passed through unchanged.
      try:
          keep_alive = int(OLLAMA_KEEP_ALIVE)
      except (TypeError, ValueError):
          keep_alive = OLLAMA_KEEP_ALIVE
      return {
          "model": CHAT_MODEL,
          "prompt": prompt,
          "stream": stream,
          "options": {
              "num_ctx": OLLAMA_NUM_CTX,
              "num_predict": OLLAMA_NUM_PREDICT,
              "temperature": OLLAMA_TEMPERATURE,
              "top_p": 0.9,
              "repeat_penalty": 1.1,
          },
          # keep model resident in VRAM — MUST be top-level, not inside options
          "keep_alive": keep_alive,
      }
  def ollama_chat(prompt: str) -> str:
      """Send a prompt to Ollama and return the full generated text (non-streaming).

      Raises:
          HTTPException: 503 on timeout or connection failure; 502 on an HTTP
              error status, any other request failure, or a body that is not a
              JSON object.
      """
      try:
          r = requests.post(
              f"{OLLAMA_URL}/api/generate",
              json=_OLLAMA_GENERATE_BODY(prompt, False),
              timeout=180,
          )
          r.raise_for_status()
      except requests.Timeout:
          logger.error("Ollama chat timeout after 180s (url=%s model=%s)", OLLAMA_URL, CHAT_MODEL)
          raise HTTPException(status_code=503, detail="LLM service timed out")
      except requests.ConnectionError:
          logger.error("Ollama chat connection error (url=%s)", OLLAMA_URL)
          raise HTTPException(status_code=503, detail="LLM service unavailable")
      except requests.HTTPError as e:
          logger.error("Ollama chat HTTP %s: %s", e.response.status_code, e.response.text[:200])
          raise HTTPException(status_code=502, detail="LLM service error")
      except requests.RequestException as e:
          # Any other transport failure — keep it from escaping as a bare 500.
          logger.error("Ollama chat request failed: %s", e)
          raise HTTPException(status_code=502, detail="LLM service error")
      try:
          data = r.json()
      except ValueError:
          logger.error("Ollama chat returned non-JSON body: %s", r.text[:200])
          raise HTTPException(status_code=502, detail="LLM service error")
      if not isinstance(data, dict):
          logger.error("Ollama chat unexpected response: %s", str(data)[:200])
          raise HTTPException(status_code=502, detail="LLM service error")
      # `or ""` also covers an explicit null "response" value.
      return (data.get("response") or "").strip()
  def ollama_chat_stream(prompt: str):
      """Yield raw text tokens from Ollama using streaming mode.

      Each yielded value is a non-empty string. This is a generator, so the
      HTTP request is only sent when iteration starts; connection and HTTP
      errors therefore surface on the first ``next()``, before any token.

      Raises:
          HTTPException: 503 on timeout or connection failure; 502 on an HTTP
              error status or any other request failure.
      """
      try:
          r = requests.post(
              f"{OLLAMA_URL}/api/generate",
              json=_OLLAMA_GENERATE_BODY(prompt, True),
              stream=True,
              timeout=180,
          )
          r.raise_for_status()
      except requests.Timeout:
          logger.error("Ollama stream timeout (url=%s model=%s)", OLLAMA_URL, CHAT_MODEL)
          raise HTTPException(status_code=503, detail="LLM service timed out")
      except requests.ConnectionError:
          logger.error("Ollama stream connection error (url=%s)", OLLAMA_URL)
          raise HTTPException(status_code=503, detail="LLM service unavailable")
      except requests.HTTPError as e:
          logger.error("Ollama stream HTTP %s: %s", e.response.status_code, e.response.text[:200])
          raise HTTPException(status_code=502, detail="LLM service error")
      except requests.RequestException as e:
          logger.error("Ollama stream request failed: %s", e)
          raise HTTPException(status_code=502, detail="LLM service error")
      try:
          for line in r.iter_lines():
              if not line:
                  continue
              try:
                  chunk = json.loads(line)
              except json.JSONDecodeError:
                  # Skip lines that are not valid JSON.
                  continue
              token = chunk.get("response", "")
              if token:
                  yield token
              if chunk.get("done"):
                  break
      finally:
          # Close the response even when the consumer stops iterating early
          # or the generator is closed.
          r.close()
  def _scroll_points(collection: str, qfilter=None, include_vector: bool=False, page_size: int=200):
      """Yield every point of `collection`, paging through Qdrant's scroll API.

      Payloads are always fetched; vectors only when `include_vector` is set.
      Iteration stops on an empty page or when no next offset is returned.
      """
      cursor = None
      while True:
          batch, cursor = qc.scroll(
              collection_name=collection,
              limit=page_size,
              with_payload=True,
              with_vectors=include_vector,
              offset=cursor,
              scroll_filter=qfilter,
          )
          if not batch:
              return
          yield from batch
          if cursor is None:
              return
  262. # ---------------------------------------------------------------------------
  263. # Health + utility endpoints
  264. # ---------------------------------------------------------------------------
  @app.get("/readyz")
  def readyz():
      """Readiness probe: always returns a static OK payload."""
      status = {"ok": True}
      return status
  268. def _normalize(q: Optional[str]) -> str:
  269. """Collapse whitespace and lowercase — used for dedup tracking in ask_logs."""
  270. return re.sub(r"\s+", " ", (q or "").strip().lower())
  271. def _json_dumps(o) -> str:
  272. return json.dumps(o, ensure_ascii=False, separators=(",",":"))
  273. def _trunc(s: str, limit: int, field: str) -> str:
  274. """Truncate `s` to `limit` chars. Logs a warning if truncation occurs so
  275. data loss is visible in docker logs rather than silently discarded."""
  276. if len(s) > limit:
  277. logger.warning("telemetry field %r truncated from %d to %d chars", field, len(s), limit)
  278. return s[:limit]
  279. return s
  280. # ---- Councils list (prefers payload 'council', falls back to filename token) ----
  @app.get("/councils")
  def councils():
      """Return the sorted list of council tokens found in the collection.

      Samples up to 50 pages of 100 points. The payload's `council` value is
      preferred; otherwise the token is the lowercased `source_file` with
      ".pdf" removed, up to the first "_" or "-".
      """
      found = set()  # renamed so the local no longer shadows this function
      offset = None
      for _ in range(50):
          points, offset = qc.scroll(
              collection_name=COLLECTION,
              limit=100,
              # Only these two keys are read below; skip the rest of the payload.
              with_payload=["council", "source_file"],
              offset=offset,
          )
          for pt in points:
              p = pt.payload or {}
              token = (p.get("council") or "").strip().lower()
              if not token:
                  sf = (p.get("source_file") or "").lower()
                  if sf:
                      token = sf.replace(".pdf", "").split("_")[0].split("-")[0]
              if token:
                  found.add(token)
          if offset is None:
              break
      return sorted(found)
  305. # ---------------------------------------------------------------------------
  306. # Qdrant filter builders
  307. # _mv — exact MatchValue (keyword field, case-sensitive)
  308. # _mt — MatchText (full-text / substring match)
  309. # ---------------------------------------------------------------------------
  def _mv(key: str, value: str) -> qmodels.FieldCondition:
      """Build a MatchValue condition on payload field `key`."""
      exact = qmodels.MatchValue(value=value)
      return qmodels.FieldCondition(key=key, match=exact)
  def _mt(key: str, text: str) -> qmodels.FieldCondition:
      """Build a MatchText condition on payload field `key`."""
      textual = qmodels.MatchText(text=text)
      return qmodels.FieldCondition(key=key, match=textual)
  def filter_tps() -> qmodels.Filter:
      """Filter for points whose corpus is exactly "tps"."""
      corpus_is_tps = _mv("corpus", "tps")
      return qmodels.Filter(must=[corpus_is_tps])
  def filter_lps(council: str) -> qmodels.Filter:
      """Filter for LPS points of one council, exact match on both fields.

      The council is slugified first; if slugging yields nothing, the
      lowercased input is used instead.
      """
      council_slug = slug(council)
      if not council_slug:
          council_slug = council.lower()
      conditions = [_mv("corpus", "lps"), _mv("council", council_slug)]
      return qmodels.Filter(must=conditions)
  def filter_ncc() -> qmodels.Filter:
      """Filter for points whose corpus is exactly "ncc"."""
      corpus_is_ncc = _mv("corpus", "ncc")
      return qmodels.Filter(must=[corpus_is_ncc])
  def filter_as() -> qmodels.Filter:
      """Filter for points whose corpus is exactly "as"."""
      corpus_is_as = _mv("corpus", "as")
      return qmodels.Filter(must=[corpus_is_as])
  def with_source_contains(flt: Optional[qmodels.Filter], source_contains: Optional[str]) -> Optional[qmodels.Filter]:
      """AND an additional source_file text condition onto an existing filter.

      Args:
          flt: Existing filter, or None.
          source_contains: Text to match against `source_file`; when empty,
              `flt` is returned unchanged (which may be None).

      Returns:
          A new Filter carrying the extra `must` condition, or `flt` itself
          when there is nothing to add.
      """
      if not source_contains:
          return flt
      add = _mt("source_file", source_contains)
      if not flt:
          return qmodels.Filter(must=[add])
      # Preserve existing must/should/must_not and AND the filename condition.
      existing = getattr(flt, "must", None)
      if existing is None:
          must = []
      elif isinstance(existing, (list, tuple)):
          must = list(existing)
      else:
          # NOTE(review): some qdrant-client versions appear to allow a single
          # condition object here; wrap it rather than iterating over it.
          must = [existing]
      must.append(add)
      return qmodels.Filter(
          must=must,
          should=getattr(flt, "should", None),
          must_not=getattr(flt, "must_not", None),
      )
  def q_search(vec: List[float], flt: Optional[qmodels.Filter], limit: int):
      """Query the default collection with `vec` and return the scored points.

      `limit` is raised to at least 1; payloads are included in the results.
      """
      wanted = max(1, limit)
      response = qc.query_points(
          collection_name=COLLECTION,
          query=vec,
          limit=wanted,
          query_filter=flt,
          with_payload=True,
      )
      return response.points
  def render_blocks(hits) -> Tuple[List[str], List[dict]]:
      """Turn Qdrant hits into prompt text blocks plus source-citation dicts.

      Returns a pair ``(blocks, sources)`` with one entry per hit, in order.
      Missing payload fields are rendered as None.
      """
      blocks: List[str] = []
      sources: List[dict] = []
      for hit in hits:
          payload = hit.payload or {}
          source_file = payload.get("source_file")
          page = payload.get("page")
          chunk_index = payload.get("chunk_index")
          label = f"{source_file} (p.{page} chunk {chunk_index})"
          text = payload.get("text", "")
          blocks.append(f"Source: {label}\nText: {text}")
          sources.append({
              "source_file": source_file,
              "page": page,
              "chunk_index": chunk_index,
              "score": hit.score,
          })
      return blocks, sources
  def combine_context(sections: List[Tuple[str, List[str]]]) -> str:
      """Merge headed sections into one context string for the prompt.

      Sections without blocks are skipped; each remaining one contributes a
      "=== heading ===" line followed by its blocks, separated by blank lines.
      """
      parts = []
      for heading, blocks in sections:
          if blocks:
              parts.append(f"=== {heading} ===")
              parts += blocks
      if not parts:
          return "No context found."
      return "\n\n".join(parts)
  def _scan_points(qfilter: Optional[qmodels.Filter] = None, max_pages: int = 10000, page_size: int = 200):
      """Yield points (payload only, no vectors) from the default collection.

      Pages through Qdrant's scroll API, stopping after `max_pages` pages, on
      an empty page, or when no next offset is returned. This walks the whole
      collection, so it suits the current dataset size; a much larger
      collection would call for pre-aggregated counts instead.
      """
      cursor = None
      for _ in range(max_pages):
          batch, cursor = qc.scroll(
              collection_name=COLLECTION,
              limit=page_size,
              with_payload=True,
              offset=cursor,
              scroll_filter=qfilter,
          )
          if not batch:
              break
          yield from batch
          if cursor is None:
              break
  401. # ---------------------------------------------------------------------------
  402. # Admin endpoints — require DEMO_TOKEN when DEMO_REQUIRE_TOKEN=1
  403. # All endpoints are rate-limited; /export is tighter (streams full DB).
  404. # ---------------------------------------------------------------------------
  @app.get("/admin/stats")
  @limiter.limit("30/minute")
  def admin_stats(request: Request, council: Optional[str] = None, corpus: Optional[str] = None):
      """Count points per corpus and per council, optionally filtered.

      Token-gated when the demo gate is enabled. Counts are chunks (points),
      not documents.
      """
      _verify_demo_token_if_needed(request)
      conditions = []
      if council:
          conditions.append(_mt("council", council.lower()))
      if corpus:
          conditions.append(_mt("corpus", corpus.lower()))
      qfilter = qmodels.Filter(must=conditions) if conditions else None

      by_corpus = Counter()
      by_council = Counter()
      total = 0
      for pt in _scan_points(qfilter=qfilter):
          payload = pt.payload or {}
          by_corpus[(payload.get("corpus") or "").lower()] += 1
          council_name = payload.get("council")
          if council_name:
              by_council[council_name.lower()] += 1
          total += 1
      return {
          "collection": COLLECTION,
          "total_points": total,
          "by_corpus": dict(by_corpus),
          "by_council": dict(by_council),
          "note": "Counts are points (chunks), not documents.",
      }
  @app.get("/admin/files")
  @limiter.limit("30/minute")
  def admin_files(request: Request, council: Optional[str] = None, corpus: Optional[str] = None, contains: Optional[str] = None, limit: int = 200):
      """List source files with point counts, largest first.

      Each entry carries the first corpus/council seen for the file and the
      number of distinct pages observed. At most ``max(1, limit)`` entries
      are returned.
      """
      _verify_demo_token_if_needed(request)
      conditions = []
      if council:
          conditions.append(_mt("council", council.lower()))
      if corpus:
          conditions.append(_mt("corpus", corpus.lower()))
      if contains:
          conditions.append(_mt("source_file", contains))
      qfilter = qmodels.Filter(must=conditions) if conditions else None

      files = defaultdict(lambda: {"points": 0, "corpus": None, "council": None, "pages": set()})
      for pt in _scan_points(qfilter=qfilter):
          payload = pt.payload or {}
          name = (payload.get("source_file") or "").strip()
          if not name:
              continue
          rec = files[name]
          rec["points"] += 1
          # Keep the first truthy corpus/council seen for this file.
          if not rec["corpus"]:
              rec["corpus"] = payload.get("corpus")
          if not rec["council"]:
              rec["council"] = payload.get("council")
          page = payload.get("page")
          if page is not None:
              rec["pages"].add(page)

      rows = [
          {
              "source_file": name,
              "corpus": rec["corpus"],
              "council": rec["council"],
              "points": rec["points"],
              "page_count_est": len(rec["pages"]) or None,
          }
          for name, rec in files.items()
      ]
      ranked = sorted(rows, key=lambda row: row["points"], reverse=True)
      return ranked[:max(1, limit)]
  @app.get("/admin/sample")
  @limiter.limit("30/minute")
  def admin_sample(request: Request, council: Optional[str] = None, corpus: Optional[str] = None, n: int = 5):
      """Return up to ``max(1, n)`` points that have non-empty text.

      Each sample includes source metadata and a preview of at most 400
      characters, with an ellipsis appended when the text was cut.
      """
      _verify_demo_token_if_needed(request)
      conditions = []
      if council:
          conditions.append(_mt("council", council.lower()))
      if corpus:
          conditions.append(_mt("corpus", corpus.lower()))
      qfilter = qmodels.Filter(must=conditions) if conditions else None

      wanted = max(1, n)
      samples = []
      for pt in _scan_points(qfilter=qfilter):
          payload = pt.payload or {}
          text = (payload.get("text") or "").strip()
          if not text:
              continue
          preview = text if len(text) <= 400 else text[:400] + "…"
          samples.append({
              "source_file": payload.get("source_file"),
              "corpus": payload.get("corpus"),
              "council": payload.get("council"),
              "page": payload.get("page"),
              "chunk_index": payload.get("chunk_index"),
              "preview": preview,
          })
          if len(samples) >= wanted:
              break
      return samples
  @app.get("/admin/export")
  @limiter.limit("5/minute")
  def admin_export(
      request: Request,
      collection: str = COLLECTION,
      council: Optional[str] = None,
      corpus: Optional[str] = None,
      source_contains: Optional[str] = None,
      include_vector: bool = False,
      limit: Optional[int] = None
  ):
      """Stream matching points as newline-delimited JSON.

      Each line holds a point's id and payload, plus its vector when
      `include_vector` is set. A falsy `limit` (None or 0) means no cap.
      The response is served as an attachment whose filename is derived from
      the collection, corpus and council.
      """
      _verify_demo_token_if_needed(request)
      must = []
      if council:
          must.append(_mt("council", council.lower()))
      if corpus:
          must.append(_mt("corpus", corpus.lower()))
      if source_contains:
          must.append(_mt("source_file", source_contains))
      qfilter = qmodels.Filter(must=must) if must else None

      def gen():
          count = 0
          for pt in _scroll_points(collection, qfilter=qfilter, include_vector=include_vector):
              obj = {
                  "id": str(getattr(pt, "id", None)),
                  "payload": pt.payload or {},
              }
              if include_vector:
                  obj["vector"] = pt.vector
              yield json.dumps(obj, ensure_ascii=False) + "\n"
              count += 1
              if limit and count >= limit:
                  break

      def _safe(part: str) -> str:
          # These values come from the query string: strip anything that could
          # break out of the quoted filename or the header itself.
          return re.sub(r"[^A-Za-z0-9._-]+", "_", part)

      filename = f'{_safe(collection)}-{_safe(corpus or "all")}-{_safe(council or "all")}.ndjson'
      headers = {"Content-Disposition": f'attachment; filename="{filename}"'}
      return StreamingResponse(gen(), media_type="application/x-ndjson", headers=headers)
  531. # ---------------------------------------------------------------------------
  532. # Section-specific format guides
  533. # Each section_id maps to a tightly-scoped formatting instruction injected at
  534. # the end of the prompt. This steers the LLM output for structured report
  535. # sections without changing the core RAG prompt.
  536. # ---------------------------------------------------------------------------
  537. def _section_format_guide(section_id: Optional[str], section_title: str, ctx: dict) -> str:
  538. """
  539. Return strict, section-specific formatting guidance for the LLM.
  540. Keep these short, prescriptive, and impossible to ignore.
  541. """
  542. sid = (section_id or "").lower()
  543. # Utility bits from context
  544. zones = ctx.get("planning_zones") or []
  545. zone_label = ", ".join(zones) if zones else "the applicable zone"
  546. council_label = ctx.get("council") or ""
  547. # ---- ZONING (tables of clauses like your sample) ----
  548. if sid in {"zoning", "zoning-41", "zoning-42", "zoning-43", "zoning-44", "zoning-441", "zoning-442"}:
  549. return f"""
  550. FORMAT REQUIREMENTS (MANDATORY):
  551. - Produce a concise preface (≤ 2 sentences) naming {zone_label}.
  552. - Then include a Markdown table listing EACH visible clause found in CONTEXT that applies to the zone or LPS for **{council_label or 'the selected council'}**.
  553. - One row per subclause. If an A/P pair exists (e.g., A1 / P1), include both in the same row.
  554. - Columns (exact):
  555. | Clause | Topic | Acceptable Solution (A) | Performance Criteria (P) | Assessment | Source |
  556. - "Clause": the clause number (e.g., "12.3.1 A1" or "DOR-S1.7.1").
  557. - "Topic": short label extracted from the clause heading.
  558. - "Acceptable Solution (A)" and "Performance Criteria (P)": quote briefly—no more than 1–2 lines each.
  559. - "Assessment": state clearly whether the proposal meets A, or relies on P. If unknown from CONTEXT, write "TBC".
  560. - "Source": filename + page (from CONTEXT).
  561. - Only include clauses actually present in CONTEXT; NEVER invent clause numbers or text.
  562. - After the table, add a one-paragraph summary noting any items assessed as TBC or non-compliant.
  563. """.strip()
  564. # ---- Codes overview list/table (optional future) ----
  565. if sid.startswith("code-"):
  566. return """
  567. FORMAT REQUIREMENTS:
  568. - Start with one sentence stating which Code and why it is triggered.
  569. - Then provide a short checklist or table of the relevant sub-clauses (A vs P), with Source for each.
  570. - Keep to 150–250 words + table.
  571. """.strip()
  572. # ---- Permit Overview (concise triggers) ----
  573. if sid == "permit-overview":
  574. return """
  575. FORMAT REQUIREMENTS:
  576. - Produce 3 blocks with headings:
  577. 1) "Project Context" – 3–5 bullet points (site, proposal, zone).
  578. 2) "Applicable Provisions" – bullets grouping TPS SPP, LPS (selected council), and triggered Codes.
  579. 3) "Assessment Path" – bullet list of key clauses to assess next.
  580. - Cite specific clause numbers ONLY if present in CONTEXT (include Source).
  581. """.strip()
  582. # ---- Default (no special formatting) ----
  583. return """
  584. FORMAT REQUIREMENTS:
  585. - Use concise Markdown with short paragraphs and bullets as needed.
  586. - Cite briefly (filename + page) when quoting a control.
  587. """.strip()
  588. # ---------------------------------------------------------------------------
  589. # /ask — core RAG endpoint
  590. # ---------------------------------------------------------------------------
  class AskBody(BaseModel):
      """Request body for the /ask endpoint."""

      # The question text; several key names are accepted because different
      # frontends send different ones.
      query: Optional[str] = None
      question: Optional[str] = None
      q: Optional[str] = None
      prompt: Optional[str] = None

      # Retrieval controls.
      top_k: int = 10
      council: Optional[str] = None
      include_ncc: bool = False
      include_standards: bool = False
      source_contains: Optional[str] = None
      scope: Literal['state_plus_local','local_only','state_only','any'] = 'state_plus_local'
      section_id: Optional[str] = None

      # BYOK mode: return context blocks without calling Ollama, so the
      # browser can call its own LLM with the returned context + prompt.
      context_only: bool = False
  607. def _allowed(p: dict, scope: str, cslug: Optional[str]) -> bool:
  608. """
  609. Secondary guardrail applied after the Qdrant vector search.
  610. Qdrant filters are the primary gate; this catches any edge-case leakage
  611. (e.g. MatchText returning a partial match across corpora).
  612. """
  613. corp = (p.get("corpus") or "").lower()
  614. council = (p.get("council") or "").lower()
  615. if scope == "local_only":
  616. return corp == "lps" and cslug and council == cslug
  617. if scope == "state_only":
  618. return corp == "tps"
  619. if scope == "state_plus_local":
  620. return corp == "tps" or (corp == "lps" and cslug and council == cslug)
  621. return True
  622. def _prepare_ask(
  623. query: str,
  624. top_k: int = 10,
  625. council: Optional[str] = None,
  626. include_ncc: bool = False,
  627. include_standards: bool = False,
  628. source_contains: Optional[str] = None,
  629. scope: str = "state_plus_local",
  630. section_id: Optional[str] = None,
  631. ) -> tuple:
  632. """
  633. Embed the query, run Qdrant retrieval, and build the LLM prompt.
  634. Returns (prompt, all_sources) — does NOT call the LLM.
  635. Shared by do_ask() and the streaming endpoint.
  636. """
  637. top_k = max(1, min(top_k, 30))
  638. vec = ollama_embed(query)
  639. cslug = slug(council) if council else None
  640. scopes: List[Tuple[str, qmodels.Filter]] = []
  641. if scope in ("state_only", "state_plus_local", "any"):
  642. scopes.append(("Tasmanian Planning Scheme (SPP)", filter_tps()))
  643. if scope in ("local_only", "state_plus_local", "any") and cslug:
  644. scopes.append((f"Local Provisions Schedule — {cslug}", filter_lps(cslug)))
  645. if include_ncc:
  646. scopes.append(("National Construction Code (NCC)", filter_ncc()))
  647. if include_standards:
  648. scopes.append(("Australian Standards (AS)", filter_as()))
  649. scopes = [(name, with_source_contains(flt, source_contains)) for name, flt in scopes]
  650. per_spp = max(3, top_k // 3) if any(n.startswith("Tasmanian Planning Scheme") for n, _ in scopes) else 0
  651. per_lps = max(3, top_k // 3) if any(n.startswith("Local Provisions Schedule") for n, _ in scopes) else 0
  652. remaining = max(1, top_k - (per_spp + per_lps))
  653. extra_scopes = sum(1 for n, _ in scopes if not (n.startswith("Tasmanian Planning Scheme") or n.startswith("Local Provisions Schedule")))
  654. per_extra = max(1, remaining // max(1, extra_scopes)) if extra_scopes else 0
  655. limits: List[int] = []
  656. for name, _ in scopes:
  657. if name.startswith("Tasmanian Planning Scheme"):
  658. limits.append(per_spp)
  659. elif name.startswith("Local Provisions Schedule"):
  660. limits.append(per_lps)
  661. else:
  662. limits.append(per_extra)
  663. sections: List[Tuple[str, List[str]]] = []
  664. all_sources: List[dict] = []
  665. for (name, flt), lim in zip(scopes, limits):
  666. if lim <= 0:
  667. continue
  668. hits = q_search(vec, flt, lim)
  669. hits = [h for h in hits if _allowed(h.payload or {}, scope, cslug)]
  670. blocks, sources = render_blocks(hits)
  671. sections.append((name, blocks))
  672. all_sources.extend(sources)
  673. context = combine_context(sections)
  674. format_guide = _section_format_guide(
  675. section_id, section_title="(auto)",
  676. ctx={"council": council, "planning_zones": []}
  677. )
  678. prompt = f"""
  679. You are an expert Tasmanian planning and building compliance assistant with deep knowledge of the Tasmanian Planning Scheme structure.
  680. ## AUTHORITY ORDER — always apply in this sequence:
  681. 1. State Planning Provisions (SPP) — the statewide baseline. Cite clause numbers exactly.
  682. 2. Local Provisions Schedule (LPS) for the selected council — overrides SPP where it differs.
  683. 3. National Construction Code (NCC) — building controls only, keep separate from planning.
  684. 4. Australian Standards — only when directly referenced by a clause in CONTEXT.
  685. ## STRICT RULES:
  686. - Use ONLY information present in CONTEXT below. Never invent clause numbers, standards, or measurements.
  687. - If CONTEXT does not contain enough information to answer, say: "The provided context does not cover this — check the TPSO viewer directly at tpso.planning.tas.gov.au"
  688. - Every specific standard or requirement you state MUST include its source: (filename, p.N)
  689. - Quote clause text briefly (1–2 lines max) then explain in plain English.
  690. - Distinguish clearly between Acceptable Solutions (A) and Performance Criteria (P).
  691. ## OUTPUT FORMAT:
  692. - Use Markdown: ## for main headings, ### for sub-headings, **bold** for clause numbers.
  693. - For setbacks, parking rates, or multiple standards: use a Markdown table with columns: Clause | Requirement | A or P | Source
  694. - End every response with a ## Sources section listing each cited document and page.
  695. - Keep answers concise but complete — do not pad or repeat information.
  696. - Professional planning language; avoid informal phrasing.
  697. ## CONTEXT (retrieved from Tasmanian Planning Scheme documents):
  698. {context}
  699. {format_guide}
  700. ## QUESTION:
  701. {query}
  702. ## ANSWER:
  703. """.strip()
  704. return prompt, all_sources, sections
  705. def do_ask(
  706. query: str,
  707. top_k: int = 10,
  708. council: Optional[str] = None,
  709. include_ncc: bool = False,
  710. include_standards: bool = False,
  711. source_contains: Optional[str] = None,
  712. scope: str = "state_plus_local",
  713. section_id: Optional[str] = None,
  714. context_only: bool = False,
  715. ):
  716. prompt, all_sources, sections = _prepare_ask(
  717. query, top_k, council, include_ncc, include_standards,
  718. source_contains, scope, section_id
  719. )
  720. if context_only:
  721. # Extract context from the prompt for BYOK mode
  722. ctx_start = prompt.find("## CONTEXT")
  723. ctx_end = prompt.find("## QUESTION")
  724. context = prompt[ctx_start:ctx_end].strip() if ctx_start != -1 and ctx_end != -1 else ""
  725. return {
  726. "context_only": True,
  727. "context": context,
  728. "prompt": prompt,
  729. "sources": all_sources,
  730. "sections": [{"heading": name, "blocks": blocks} for name, blocks in sections],
  731. }
  732. answer = ollama_chat(prompt)
  733. return {"answer": answer, "sources": all_sources}
  734. def _log_ask(ts, sid, ip, query, scope, allow_tps, latency_ms, model, sources, answer):
  735. """Write one ask_logs row. Runs in a background task — never raises to the caller."""
  736. try:
  737. topk = [{"id": f"{s.get('source_file')}#p{s.get('page')}", "score": s.get("score")} for s in sources]
  738. with db() as conn:
  739. conn.execute("""
  740. INSERT INTO ask_logs
  741. (ts, sid, ip_hash, query, normalized, scope, allow_tps, latency_ms,
  742. model, ok, topk_json, tokens_in, tokens_out, answer)
  743. VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
  744. """, (
  745. ts, sid, ip_hash(ip), query, _normalize(query),
  746. scope, int(allow_tps),
  747. latency_ms, model, 1, _json_dumps(topk), 0, 0,
  748. _trunc(answer, 8000, "ask_logs.answer"),
  749. ))
  750. conn.commit()
  751. except Exception:
  752. logger.exception("[telemetry] ask insert failed")
  753. @app.get("/ask")
  754. @limiter.limit("20/minute")
  755. def ask_get(
  756. request: Request,
  757. background_tasks: BackgroundTasks,
  758. query: str = Query(..., description="User question"),
  759. top_k: int = 10,
  760. council: Optional[str] = None,
  761. include_ncc: bool = False,
  762. include_standards: bool = False,
  763. source_contains: Optional[str] = None,
  764. scope: str = "state_plus_local",
  765. section_id: Optional[str] = None,
  766. context_only: bool = False,
  767. ):
  768. _verify_demo_token_if_needed(request)
  769. started = time.perf_counter()
  770. out = do_ask(query, top_k, council, include_ncc, include_standards, source_contains, scope, section_id, context_only)
  771. latency_ms = int((time.perf_counter() - started) * 1000)
  772. ip = request.client.host if request.client else "0.0.0.0"
  773. sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
  774. background_tasks.add_task(
  775. _log_ask,
  776. datetime.utcnow().isoformat(), sid, ip, query, scope,
  777. scope in ("state_only", "state_plus_local"),
  778. latency_ms, CHAT_MODEL, out.get("sources") or [], out.get("answer") or "",
  779. )
  780. return out
  781. @app.post("/ask")
  782. @limiter.limit("20/minute")
  783. def ask_post(request: Request, background_tasks: BackgroundTasks, body: AskBody):
  784. _verify_demo_token_if_needed(request)
  785. qtxt = (body.query or body.question or body.q or body.prompt or "").strip()
  786. if not qtxt:
  787. raise HTTPException(status_code=422, detail="Missing query/question")
  788. started = time.perf_counter()
  789. out = do_ask(
  790. query=qtxt,
  791. top_k=body.top_k,
  792. council=body.council,
  793. include_ncc=body.include_ncc,
  794. include_standards=body.include_standards,
  795. source_contains=body.source_contains,
  796. scope=body.scope,
  797. section_id=body.section_id,
  798. context_only=body.context_only,
  799. )
  800. latency_ms = int((time.perf_counter() - started) * 1000)
  801. ip = request.client.host if request.client else "0.0.0.0"
  802. sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
  803. background_tasks.add_task(
  804. _log_ask,
  805. datetime.utcnow().isoformat(), sid, ip, qtxt, body.scope,
  806. body.scope in ("state_only", "state_plus_local"),
  807. latency_ms, CHAT_MODEL, out.get("sources") or [], out.get("answer") or "",
  808. )
  809. return out
  810. # ---------------------------------------------------------------------------
  811. # /ask/stream — Server-Sent Events streaming endpoint
  812. # Embedding + retrieval run synchronously first (fast: ~0.5s).
  813. # Tokens stream as they arrive from Ollama — no waiting for full completion.
  814. #
  815. # SSE event types:
  816. # {"type": "sources", "sources": [...]} — sent first, before any tokens
  817. # {"type": "token", "text": "..."} — one per Ollama chunk
  818. # {"type": "done"} — stream complete
  819. # {"type": "error", "detail": "..."} — on failure mid-stream
  820. # ---------------------------------------------------------------------------
  821. @app.post("/ask/stream")
  822. @limiter.limit("20/minute")
  823. def ask_stream(request: Request, body: AskBody):
  824. _verify_demo_token_if_needed(request)
  825. qtxt = (body.query or body.question or body.q or body.prompt or "").strip()
  826. if not qtxt:
  827. raise HTTPException(status_code=422, detail="Missing query/question")
  828. # Embedding + retrieval + prompt building run before streaming starts.
  829. # Sources are sent as the first SSE event so the UI can render them
  830. # while tokens are still arriving from Ollama.
  831. started = time.perf_counter()
  832. prompt, all_sources, _ = _prepare_ask(
  833. qtxt, body.top_k, body.council, body.include_ncc,
  834. body.include_standards, body.source_contains, body.scope, body.section_id
  835. )
  836. ip = request.client.host if request.client else "0.0.0.0"
  837. sid = request.headers.get("X-TPR-SID") or request.cookies.get("sid") or ""
  838. def generate():
  839. tokens = []
  840. try:
  841. yield f"data: {json.dumps({'type': 'sources', 'sources': all_sources})}\n\n"
  842. for token in ollama_chat_stream(prompt):
  843. tokens.append(token)
  844. yield f"data: {json.dumps({'type': 'token', 'text': token})}\n\n"
  845. yield f"data: {json.dumps({'type': 'done'})}\n\n"
  846. except Exception as e:
  847. logger.error("[stream] error mid-stream: %s", e)
  848. yield f"data: {json.dumps({'type': 'error', 'detail': str(e)})}\n\n"
  849. finally:
  850. # Telemetry written inline at stream end (~1ms SQLite write)
  851. latency_ms = int((time.perf_counter() - started) * 1000)
  852. _log_ask(
  853. datetime.utcnow().isoformat(), sid, ip, qtxt, body.scope,
  854. body.scope in ("state_only", "state_plus_local"),
  855. latency_ms, CHAT_MODEL, all_sources, "".join(tokens),
  856. )
  857. return StreamingResponse(
  858. generate(),
  859. media_type="text/event-stream",
  860. headers={
  861. "Cache-Control": "no-cache",
  862. "X-Accel-Buffering": "no", # tell Nginx/Cloudflare not to buffer SSE
  863. },
  864. )