| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import base64
- import dbm
- import logging
- import pickle
- import sqlite3
- from pathlib import Path
- from typing import Iterable, Optional
- from qdrant_client.http import models
- STORAGE_FILE_NAME_OLD = "storage.dbm"  # legacy dbm storage file; try_migrate_to_sqlite converts it and then deletes it
- STORAGE_FILE_NAME = "storage.sqlite"  # sqlite storage file, located inside the collection directory
def try_migrate_to_sqlite(location: str) -> None:
    """
    Migrate a legacy dbm storage of a collection to sqlite, if there is one.

    Nothing is done when the sqlite file already exists or when there is no dbm file.
    After a successful migration the dbm file is removed. If the migration fails before
    the data is committed, the incomplete sqlite file is removed again, so that the
    migration is retried on the next call instead of being skipped.

    Args:
        location: path to the collection directory.

    Raises:
        Exception: any error raised during the migration is logged and re-raised.
    """
    dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
    sql_path = Path(location) / STORAGE_FILE_NAME

    if sql_path.exists() or not dbm_path.exists():
        return

    committed = False
    try:
        dbm_storage = dbm.open(str(dbm_path), "c")
        try:
            con = sqlite3.connect(str(sql_path))
            try:
                cur = con.cursor()

                # Create table
                cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")

                for key in dbm_storage.keys():
                    value = dbm_storage[key]
                    if isinstance(key, str):
                        key = key.encode("utf-8")
                    # NOTE: keys are unpickled, the storage file must come from a trusted source
                    point_id = pickle.loads(key)
                    sqlite_key = CollectionPersistence.encode_key(point_id)

                    # Insert a row of data
                    cur.execute(
                        "INSERT INTO points VALUES (?, ?)",
                        (
                            sqlite_key,
                            sqlite3.Binary(value),
                        ),
                    )

                con.commit()
                committed = True
            finally:
                # release the sqlite file on success and on failure
                con.close()
        finally:
            dbm_storage.close()

        dbm_path.unlink()
    except Exception as e:
        logging.error("Failed to migrate dbm to sqlite: %s", e)
        logging.error(
            "Please try to use previous version of qdrant-client or re-create collection"
        )
        # The sqlite file did not exist before this call. A partially written file would make
        # the next call return early and leave the dbm data unmigrated, so remove it.
        if not committed and sql_path.exists():
            try:
                sql_path.unlink()
            except OSError as cleanup_error:
                logging.error(
                    "Failed to remove incomplete sqlite storage %s: %s", sql_path, cleanup_error
                )
        raise
class CollectionPersistence:
    """
    Sqlite-backed storage for the points of a local collection.

    Points are pickled and stored in the `points` table, keyed by the encoded point id.
    """

    # Value passed as `check_same_thread` to `sqlite3.connect`. None until the first instance
    # resolves it; the resolved value is stored on the class and reused by later instances.
    CHECK_SAME_THREAD: Optional[bool] = None

    @classmethod
    def encode_key(cls, key: models.ExtendedPointId) -> str:
        """
        Encode a point id into the text key used in the `points` table.

        Args:
            key: point id to encode

        Returns:
            base64 text of the pickled point id
        """
        return base64.b64encode(pickle.dumps(key)).decode("utf-8")

    @staticmethod
    def _detect_check_same_thread() -> bool:
        """
        Decide whether sqlite connections have to be checked for same-thread usage.

        Returns:
            False if sqlite reports the compile option THREADSAFE=1, True otherwise
        """
        # it is unsafe to use `sqlite3.threadsafety` until python3.11 since it was hardcoded to 1, thus we
        # need to fetch threadsafe with a query
        # THREADSAFE = 0: Threads may not share the module
        # THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux.
        # THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS.
        tmp_conn = sqlite3.connect(":memory:")
        try:
            row = tmp_conn.execute(
                "select * from pragma_compile_options where compile_options like 'THREADSAFE=%'"
            ).fetchone()
        finally:
            # `with connection:` only commits or rolls back, it does not close the connection
            tmp_conn.close()

        # no THREADSAFE option reported: keep the same-thread check enabled
        return row is None or row[0] != "THREADSAFE=1"

    def __init__(self, location: str, force_disable_check_same_thread: bool = False):
        """
        Create or load a collection from the local storage.

        Args:
            location: path to the collection directory.
            force_disable_check_same_thread: if True, connect with `check_same_thread=False`
                without querying sqlite. The value is stored on the class, so it is also
                used by instances created afterwards.
        """
        try_migrate_to_sqlite(location)

        self.location = Path(location) / STORAGE_FILE_NAME
        self.location.parent.mkdir(exist_ok=True, parents=True)

        if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False:
            self.__class__.CHECK_SAME_THREAD = self._detect_check_same_thread()

        if force_disable_check_same_thread:
            self.__class__.CHECK_SAME_THREAD = False

        self.storage = sqlite3.connect(
            str(self.location), check_same_thread=self.CHECK_SAME_THREAD  # type: ignore
        )
        self._ensure_table()

    def close(self) -> None:
        """
        Close the connection to the local storage.
        """
        self.storage.close()

    def _ensure_table(self) -> None:
        """
        Create the `points` table if it does not exist yet.
        """
        cursor = self.storage.cursor()
        cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
        self.storage.commit()

    def persist(self, point: models.PointStruct) -> None:
        """
        Persist a point in the local storage.

        Args:
            point: point to persist
        """
        key = self.encode_key(point.id)
        value = pickle.dumps(point)

        cursor = self.storage.cursor()
        # Insert or update by key
        cursor.execute(
            "INSERT OR REPLACE INTO points VALUES (?, ?)",
            (
                key,
                sqlite3.Binary(value),
            ),
        )
        self.storage.commit()

    def delete(self, point_id: models.ExtendedPointId) -> None:
        """
        Delete a point from the local storage.

        Args:
            point_id: id of the point to delete
        """
        key = self.encode_key(point_id)

        cursor = self.storage.cursor()
        cursor.execute(
            "DELETE FROM points WHERE id = ?",
            (key,),
        )
        self.storage.commit()

    def load(self) -> Iterable[models.PointStruct]:
        """
        Load all points from the local storage.

        Returns:
            points: iterable over every stored point
        """
        cursor = self.storage.cursor()
        cursor.execute("SELECT point FROM points")

        # all rows are fetched before the first point is yielded
        for row in cursor.fetchall():
            # NOTE: points are unpickled, the storage file must come from a trusted source
            yield pickle.loads(row[0])
def test_persistence() -> None:
    """
    Check that a point can be persisted, loaded again after re-opening the storage, and deleted.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})

        persistence = CollectionPersistence(tmpdir)
        try:
            persistence.persist(point)
            # compare the whole content, so that an empty storage fails the test
            assert list(persistence.load()) == [point]
        finally:
            persistence.close()

        # re-open the storage from disk
        persistence = CollectionPersistence(tmpdir)
        try:
            assert list(persistence.load()) == [point]

            persistence.delete(point.id)
            # deleting a point which is already gone must not fail
            persistence.delete(point.id)

            assert list(persistence.load()) == [], "Should not load anything"
        finally:
            # close before the temporary directory is removed
            persistence.close()
|