ingest_ollama.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. import os
  2. import re
  3. import argparse
  4. import fitz # PyMuPDF
  5. import uuid
  6. import json
  7. import requests
  8. from typing import List, Tuple, Optional
  9. from pdfminer.high_level import extract_text as pdfminer_extract_text
  10. from qdrant_client import QdrantClient
  11. from qdrant_client.http import models as qmodels
  12. from tqdm import tqdm
  13. # import ocrmypdf
  14. import shutil, subprocess, tempfile
  15. # ---------------- CLI ----------------
  16. def get_args():
  17. p = argparse.ArgumentParser(description="Ingest PDFs -> chunks -> embeddings via Ollama -> Qdrant")
  18. p.add_argument("--pdf-dir", default="./pdfs", help="Folder with PDFs")
  19. p.add_argument("--qdrant-url", default=os.getenv("QDRANT_URL", "http://localhost:6333"))
  20. p.add_argument("--collection", default=os.getenv("QDRANT_COLLECTION", "planning_docs"))
  21. p.add_argument("--ollama-url", default=os.getenv("OLLAMA_URL", "http://192.168.8.73:11434"))
  22. p.add_argument("--embed-model", default=os.getenv("EMBED_MODEL", "nomic-embed-text"))
  23. p.add_argument("--chunk-words", type=int, default=400)
  24. p.add_argument("--overlap-words", type=int, default=60)
  25. p.add_argument("--batch-size", type=int, default=128)
  26. # New controls
  27. p.add_argument("--corpus", choices=["auto", "tps", "lps", "ncc", "as"], default="auto",
  28. help="Corpus label for all PDFs in this run (auto = infer per file).")
  29. p.add_argument("--council", default=None,
  30. help="Council token for LPS (e.g., 'brighton'). If omitted, inferred from filename.")
  31. p.add_argument("--wipe-existing", action="store_true",
  32. help="Delete existing points for each file before re-ingesting.")
  33. # Optional structured extraction
  34. p.add_argument("--extract-structure", action="store_true",
  35. help="Run LLM extraction on clause-start chunks and store JSON in payload.")
  36. p.add_argument("--extract-model", default=os.getenv("EXTRACT_MODEL", "llama3.1:8b"),
  37. help="Ollama model for structured extraction.")
  38. p.add_argument("--extract-max-per-file", type=int, default=200,
  39. help="Max extractions per file (safety cap).")
  40. p.add_argument("--json-retries", type=int, default=1,
  41. help="Retries on bad JSON from the LLM.")
  42. return p.parse_args()
  43. # ---------------- Helpers ----------------
  44. # Clause-like tokens such as C2.6.1 or 13.4.2
  45. _CLAUSE_RX = re.compile(r"\b(?:C\d+(?:\.\d+){1,3}|\d+(?:\.\d+){1,3})\b")
  46. _AS_RX = re.compile(r"\bAS\s*([0-9]{2,4}(?:\.[0-9]+)?)\s*[:\-]?\s*(\d{4})?", re.IGNORECASE)
  47. def make_point_id(source_file: str, page_no: int, chunk_index: int) -> str:
  48. ns = uuid.uuid5(uuid.NAMESPACE_DNS, "modulos.ai/ingest")
  49. return str(uuid.uuid5(ns, f"{source_file}:{page_no}:{chunk_index}"))
  50. def clean_text(s: str) -> str:
  51. return " ".join(s.replace("\u00ad", "").split())
  52. def extract_text_pymupdf(pdf_path: str) -> List[Tuple[int, str]]:
  53. """Return list of (page_number, text) with 1-based page numbers."""
  54. doc = fitz.open(pdf_path)
  55. pages = []
  56. for i, page in enumerate(doc):
  57. try:
  58. txt = page.get_text("text")
  59. except Exception:
  60. txt = ""
  61. pages.append((i + 1, txt or ""))
  62. return pages
  63. def fallback_pdfminer(pdf_path: str) -> str:
  64. try:
  65. return pdfminer_extract_text(pdf_path) or ""
  66. except Exception:
  67. return ""
  68. def try_repair_pdf(pdf_path: str) -> Optional[str]:
  69. """Try to rewrite/repair the PDF via qpdf or ghostscript and return a temp path if successful."""
  70. tmpdir = tempfile.mkdtemp(prefix="pdf_fix_")
  71. fixed = None
  72. qpdf_bin = shutil.which("qpdf")
  73. if qpdf_bin:
  74. out1 = os.path.join(tmpdir, "qpdf_fixed.pdf")
  75. r = subprocess.run([qpdf_bin, "--linearize", pdf_path, out1],
  76. stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  77. if r.returncode == 0 and os.path.exists(out1) and os.path.getsize(out1) > 0:
  78. fixed = out1
  79. if not fixed:
  80. gs_bin = shutil.which("gs") or shutil.which("ghostscript")
  81. if gs_bin:
  82. out2 = os.path.join(tmpdir, "gs_fixed.pdf")
  83. cmd = [
  84. gs_bin, "-dSAFER", "-dBATCH", "-dNOPAUSE",
  85. "-sDEVICE=pdfwrite", "-dCompatibilityLevel=1.6",
  86. f"-sOutputFile={out2}", pdf_path
  87. ]
  88. r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  89. if r.returncode == 0 and os.path.exists(out2) and os.path.getsize(out2) > 0:
  90. fixed = out2
  91. return fixed
  92. def extract_with_pypdfium2(pdf_path: str) -> List[Tuple[int, str]]:
  93. """Optional fallback using PyPDFium2 if available."""
  94. try:
  95. import pypdfium2 as pdfium
  96. except Exception:
  97. return []
  98. pages = []
  99. try:
  100. doc = pdfium.PdfDocument(pdf_path)
  101. for i in range(len(doc)):
  102. page = doc.get_page(i)
  103. textpage = page.get_textpage()
  104. txt = textpage.get_text_bounded()
  105. pages.append((i + 1, txt or ""))
  106. textpage.close()
  107. page.close()
  108. except Exception:
  109. return []
  110. return pages
  111. def pagewise_or_fallback(pdf_path: str) -> List[Tuple[int, str]]:
  112. """
  113. Robust text extraction:
  114. 1) PyMuPDF pagewise
  115. 2) pdfminer whole-doc
  116. 3) PyPDFium2 pagewise (if installed)
  117. 4) repair via qpdf/ghostscript and retry 1–3
  118. 5) (optional) OCR via ocrmypdf and retry 1–3
  119. Returns list of (page_number, text) or [] if totally unextractable.
  120. """
  121. def nonempty(pages): return sum(len(t) for _, t in pages) > 50
  122. # 1) PyMuPDF
  123. pages = extract_text_pymupdf(pdf_path)
  124. if nonempty(pages):
  125. return pages
  126. # 2) pdfminer whole
  127. whole = fallback_pdfminer(pdf_path)
  128. if whole.strip():
  129. return [(0, clean_text(whole))]
  130. # 3) PyPDFium2
  131. pf = extract_with_pypdfium2(pdf_path)
  132. if nonempty(pf):
  133. return pf
  134. # 4) repair & retry
  135. fixed = try_repair_pdf(pdf_path)
  136. if fixed:
  137. pages = extract_text_pymupdf(fixed)
  138. if nonempty(pages):
  139. return pages
  140. whole = fallback_pdfminer(fixed)
  141. if whole.strip():
  142. return [(0, clean_text(whole))]
  143. pf = extract_with_pypdfium2(fixed)
  144. if nonempty(pf):
  145. return pf
  146. # 5) optional OCR (enable by flipping the flag below)
  147. OCR_IF_NEEDED = True # set True if you want automatic OCR fallback
  148. if OCR_IF_NEEDED and shutil.which("ocrmypdf"):
  149. jobs = str(max(1, (os.cpu_count() or 1))) # e.g., cap it: min(os.cpu_count() or 1, 6)
  150. ocrd = tempfile.mkdtemp(prefix="pdf_ocr_")
  151. ocr_out = os.path.join(ocrd, "ocr.pdf")
  152. def run_ocr(args):
  153. return subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  154. try:
  155. # First pass: only OCR pages that lack text
  156. base_cmd = [
  157. "ocrmypdf",
  158. "--skip-text",
  159. "--rotate-pages",
  160. "--jobs", jobs,
  161. "--optimize", "3",
  162. "--output-type", "pdf",
  163. "-l", "eng",
  164. pdf_path, ocr_out,
  165. ]
  166. r = run_ocr(base_cmd)
  167. need_force = not (r.returncode == 0 and os.path.exists(ocr_out) and os.path.getsize(ocr_out) > 0)
  168. # If the output exists but extraction is still empty, we’ll detect that below
  169. def try_extract(p):
  170. _pages = extract_text_pymupdf(p)
  171. if sum(len(t) for _, t in _pages) > 50:
  172. return _pages
  173. _whole = fallback_pdfminer(p)
  174. if _whole.strip():
  175. return [(0, clean_text(_whole))]
  176. return []
  177. pages = []
  178. if r.returncode == 0 and os.path.exists(ocr_out):
  179. pages = try_extract(ocr_out)
  180. # If skip-text didn’t help (e.g., corrupt text layer), force OCR once
  181. if not pages:
  182. ocr_force = os.path.join(ocrd, "ocr_force.pdf")
  183. r2 = run_ocr([
  184. "ocrmypdf",
  185. "--force-ocr",
  186. "--rotate-pages",
  187. "--jobs", jobs,
  188. "--optimize", "3",
  189. "--output-type", "pdf",
  190. "-l", "eng",
  191. pdf_path, ocr_force,
  192. ])
  193. if r2.returncode == 0 and os.path.exists(ocr_force):
  194. pages = try_extract(ocr_force)
  195. if pages:
  196. return pages
  197. finally:
  198. # Clean up temp files
  199. try:
  200. shutil.rmtree(ocrd, ignore_errors=True)
  201. except Exception:
  202. pass
  203. # final: let caller see empty result
  204. return pages
  205. def chunk_words(text: str, chunk_size=400, overlap=60) -> List[str]:
  206. words = text.split()
  207. chunks = []
  208. step = max(1, chunk_size - overlap)
  209. for i in range(0, len(words), step):
  210. chunk = " ".join(words[i:i + chunk_size])
  211. if chunk.strip():
  212. chunks.append(chunk)
  213. return chunks
  214. def ollama_embed(ollama_url: str, model: str, text: str) -> List[float]:
  215. r = requests.post(f"{ollama_url}/api/embeddings",
  216. json={"model": model, "prompt": text},
  217. timeout=60)
  218. r.raise_for_status()
  219. data = r.json()
  220. if "embedding" not in data:
  221. raise RuntimeError(f"Ollama embeddings error: {data}")
  222. return data["embedding"]
  223. def determine_dim(ollama_url: str, model: str) -> int:
  224. emb = ollama_embed(ollama_url, model, "test")
  225. return len(emb)
  226. def ensure_collection(qc: QdrantClient, name: str, dim: int):
  227. try:
  228. qc.create_collection(
  229. collection_name=name,
  230. vectors_config=qmodels.VectorParams(size=dim, distance=qmodels.Distance.COSINE)
  231. )
  232. print(f"[qdrant] created collection '{name}' size={dim}")
  233. except Exception:
  234. # already exists or compatible — upsert will fail loudly if size mismatches
  235. pass
  236. def ensure_payload_indexes(qc: QdrantClient, collection: str):
  237. for field in ("corpus", "council", "source_file", "clauses", "clause_id", "doc_type", "title"):
  238. try:
  239. qc.create_payload_index(
  240. collection_name=collection,
  241. field_name=field,
  242. field_schema=qmodels.PayloadSchemaType.KEYWORD
  243. )
  244. except Exception:
  245. pass
  246. def delete_points_for_file(qc: QdrantClient, collection: str, source_file: str):
  247. filt = qmodels.Filter(must=[
  248. qmodels.FieldCondition(key="source_file", match=qmodels.MatchText(text=source_file))
  249. ])
  250. qc.delete(collection_name=collection, points_selector=qmodels.FilterSelector(filter=filt))
  251. print(f"[qdrant] wiped existing points for {source_file}")
  252. # -------- Structured extraction (optional) --------
  253. def llm_extract_structured(ollama_url: str, model: str, corpus: str, text: str,
  254. retries: int = 1) -> Optional[dict]:
  255. schema_hint = {
  256. "tps": "Tasmanian Planning Scheme (SPP)",
  257. "lps": "Local Provisions Schedule (LPS)",
  258. "ncc": "National Construction Code",
  259. "as": "Australian Standard"
  260. }.get(corpus, "Regulatory Document")
  261. template = {
  262. "doc_type": corpus,
  263. "clause_id": "",
  264. "title": "",
  265. "purpose": "",
  266. "application_scope": "",
  267. "performance_criteria": [],
  268. "acceptable_solutions": [],
  269. "exemptions": "",
  270. "cross_refs": [],
  271. "version": "",
  272. "volume": "",
  273. "year": ""
  274. }
  275. prompt = f"""
  276. You extract structured fields from regulatory clauses. Return ONLY valid JSON matching this template:
  277. {json.dumps(template, ensure_ascii=False, indent=2)}
  278. Guidance:
  279. - Prefer exact clause IDs and titles if present.
  280. - If a field is absent, keep it empty or [].
  281. - Use short faithful fragments; no speculation.
  282. - Only include cross_refs explicitly present.
  283. EXCERPT:
  284. {text}
  285. """.strip()
  286. for _ in range(max(1, retries + 1)):
  287. r = requests.post(f"{ollama_url}/api/generate", json={
  288. "model": model,
  289. "prompt": prompt,
  290. "stream": False,
  291. "options": {
  292. "temperature": 0.1,
  293. "num_ctx": 4096,
  294. "num_predict": 400,
  295. "keep_alive": "5m"
  296. }
  297. }, timeout=120)
  298. r.raise_for_status()
  299. resp = r.json().get("response", "").strip().strip("`").strip()
  300. try:
  301. obj = json.loads(resp)
  302. if isinstance(obj, dict) and obj.get("doc_type"):
  303. return obj
  304. except Exception:
  305. pass
  306. return None
  307. # ---------------- Classification & metadata ----------------
  308. def infer_corpus_and_meta(fname: str, forced_corpus: str = "auto", council_cli: Optional[str] = None):
  309. """
  310. Returns (corpus, council, extra_meta)
  311. corpus ∈ {tps, lps, ncc, as}
  312. council only when LPS
  313. extra_meta may include ncc_volume, as_code, as_year, year
  314. """
  315. lower = fname.lower()
  316. base = os.path.splitext(os.path.basename(lower))[0]
  317. corpus = None
  318. council = None
  319. extra = {}
  320. if forced_corpus != "auto":
  321. corpus = forced_corpus
  322. else:
  323. if ("local" in lower and "provision" in lower) or "lps" in lower:
  324. corpus = "lps"
  325. elif ("ncc" in lower or
  326. "national construction code" in lower or
  327. "volume one" in lower or "volume two" in lower or "volume three" in lower):
  328. corpus = "ncc"
  329. elif base.startswith("as") or "australian standard" in lower or _AS_RX.search(fname):
  330. corpus = "as"
  331. else:
  332. corpus = "tps"
  333. if corpus == "lps":
  334. council = (council_cli or base.split("_")[0].split("-")[0]).strip()
  335. if corpus == "ncc":
  336. if ("volume one" in lower or "vol 1" in lower or "vol1" in lower or
  337. re.search(r"\bvolume\s*1\b", lower)):
  338. extra["ncc_volume"] = "Volume One"
  339. elif ("volume two" in lower or "vol 2" in lower or "vol2" in lower or
  340. re.search(r"\bvolume\s*2\b", lower)):
  341. extra["ncc_volume"] = "Volume Two"
  342. elif ("volume three" in lower or "vol 3" in lower or "vol3" in lower or
  343. re.search(r"\bvolume\s*3\b", lower)):
  344. extra["ncc_volume"] = "Volume Three"
  345. m_year = re.search(r"\b(20\d{2}|19\d{2})\b", lower)
  346. if m_year:
  347. extra["year"] = m_year.group(1)
  348. if corpus == "as":
  349. m = _AS_RX.search(fname)
  350. if m:
  351. extra["as_code"] = m.group(1)
  352. if m.group(2):
  353. extra["as_year"] = m.group(2)
  354. return corpus, council, extra
  355. # ---------------- Upsert ----------------
  356. def upsert_batch(qc: QdrantClient, collection: str, ids, vectors, payloads):
  357. qc.upsert(
  358. collection_name=collection,
  359. points=qmodels.Batch(ids=ids, vectors=vectors, payloads=payloads)
  360. )
  361. # ---------------- Main ingest ----------------
  362. def ingest_folder(pdf_dir: str, ollama_url: str, embed_model: str, qdrant_url: str, collection: str,
  363. chunk_words_size: int, overlap_words: int, batch_size: int,
  364. forced_corpus: str, council_cli: Optional[str], wipe_existing: bool,
  365. extract_structure: bool, extract_model: str, extract_max_per_file: int, json_retries: int):
  366. qc = QdrantClient(url=qdrant_url)
  367. dim = determine_dim(ollama_url, embed_model)
  368. ensure_collection(qc, collection, dim)
  369. ensure_payload_indexes(qc, collection)
  370. pdfs = [f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")]
  371. if not pdfs:
  372. print(f"No PDFs found in {pdf_dir}")
  373. return
  374. for fname in pdfs:
  375. path = os.path.join(pdf_dir, fname)
  376. corpus, council, extra = infer_corpus_and_meta(fname, forced_corpus, council_cli)
  377. print(f"\nProcessing {path} -> corpus={corpus} council={council or '-'} meta={extra}")
  378. if wipe_existing:
  379. delete_points_for_file(qc, collection, fname)
  380. pages = pagewise_or_fallback(path)
  381. all_chunks = []
  382. for page_no, txt in pages:
  383. txt = clean_text(txt)
  384. if not txt.strip():
  385. continue
  386. chunks = chunk_words(txt, chunk_size=chunk_words_size, overlap=overlap_words)
  387. for idx, ch in enumerate(chunks):
  388. all_chunks.append((page_no, idx, ch))
  389. if not all_chunks:
  390. print(f" !! No usable text extracted from {fname}")
  391. continue
  392. ids, vectors, payloads = [], [], []
  393. extracts_done = 0
  394. for page_no, cidx, ch in tqdm(all_chunks, desc=f"Embedding {fname}", unit="chunk"):
  395. emb = ollama_embed(ollama_url, embed_model, ch)
  396. point_id = make_point_id(fname, page_no, cidx)
  397. meta = {
  398. "source_file": fname,
  399. "page": page_no,
  400. "chunk_index": cidx,
  401. "text": ch,
  402. "corpus": corpus,
  403. "doc_type": corpus, # for filtering even without structured JSON
  404. }
  405. if council:
  406. meta["council"] = council
  407. # carry extra fields (NCC volume, AS code/year, year hints)
  408. for k, v in extra.items():
  409. meta[k] = v
  410. # detect a likely clause ID in this chunk
  411. m_clause = _CLAUSE_RX.search(ch)
  412. if m_clause:
  413. meta["clause_id"] = m_clause.group(0)
  414. # optional structured extraction (only on clause-starty chunks)
  415. if extract_structure and m_clause and extracts_done < extract_max_per_file:
  416. obj = llm_extract_structured(ollama_url, extract_model, corpus, ch, retries=json_retries)
  417. if obj:
  418. meta["structured"] = obj
  419. # If the model gave a title/ids, surface a few for filtering
  420. if obj.get("title"):
  421. meta["title"] = obj["title"]
  422. if obj.get("clause_id") and "clause_id" not in meta:
  423. meta["clause_id"] = obj["clause_id"]
  424. extracts_done += 1
  425. ids.append(point_id)
  426. vectors.append(emb)
  427. payloads.append(meta)
  428. if len(ids) >= batch_size:
  429. upsert_batch(qc, collection, ids, vectors, payloads)
  430. ids, vectors, payloads = [], [], []
  431. if ids:
  432. upsert_batch(qc, collection, ids, vectors, payloads)
  433. print(f" ✓ Ingested {len(all_chunks)} chunks from {fname}")
  434. # ---------------- Entrypoint ----------------
  435. if __name__ == "__main__":
  436. args = get_args()
  437. ingest_folder(
  438. pdf_dir=args.pdf_dir,
  439. ollama_url=args.ollama_url,
  440. embed_model=args.embed_model,
  441. qdrant_url=args.qdrant_url,
  442. collection=args.collection,
  443. chunk_words_size=args.chunk_words,
  444. overlap_words=args.overlap_words,
  445. batch_size=args.batch_size,
  446. forced_corpus=args.corpus,
  447. council_cli=args.council,
  448. wipe_existing=args.wipe_existing,
  449. extract_structure=args.extract_structure,
  450. extract_model=args.extract_model,
  451. extract_max_per_file=args.extract_max_per_file,
  452. json_retries=args.json_retries,
  453. )