persistence.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import base64
  2. import dbm
  3. import logging
  4. import pickle
  5. import sqlite3
  6. from pathlib import Path
  7. from typing import Iterable, Optional
  8. from qdrant_client.http import models
# File name of the legacy dbm storage inside a collection directory.
# `try_migrate_to_sqlite` copies its records into SQLite and then deletes it.
STORAGE_FILE_NAME_OLD = "storage.dbm"
# File name of the SQLite storage inside a collection directory.
STORAGE_FILE_NAME = "storage.sqlite"
  11. def try_migrate_to_sqlite(location: str) -> None:
  12. dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
  13. sql_path = Path(location) / STORAGE_FILE_NAME
  14. if sql_path.exists():
  15. return
  16. if not dbm_path.exists():
  17. return
  18. try:
  19. dbm_storage = dbm.open(str(dbm_path), "c")
  20. con = sqlite3.connect(str(sql_path))
  21. cur = con.cursor()
  22. # Create table
  23. cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
  24. for key in dbm_storage.keys():
  25. value = dbm_storage[key]
  26. if isinstance(key, str):
  27. key = key.encode("utf-8")
  28. key = pickle.loads(key)
  29. sqlite_key = CollectionPersistence.encode_key(key)
  30. # Insert a row of data
  31. cur.execute(
  32. "INSERT INTO points VALUES (?, ?)",
  33. (
  34. sqlite_key,
  35. sqlite3.Binary(value),
  36. ),
  37. )
  38. con.commit()
  39. con.close()
  40. dbm_storage.close()
  41. dbm_path.unlink()
  42. except Exception as e:
  43. logging.error("Failed to migrate dbm to sqlite:", e)
  44. logging.error(
  45. "Please try to use previous version of qdrant-client or re-create collection"
  46. )
  47. raise e
  48. class CollectionPersistence:
  49. CHECK_SAME_THREAD: Optional[bool] = None
  50. @classmethod
  51. def encode_key(cls, key: models.ExtendedPointId) -> str:
  52. return base64.b64encode(pickle.dumps(key)).decode("utf-8")
  53. def __init__(self, location: str, force_disable_check_same_thread: bool = False):
  54. """
  55. Create or load a collection from the local storage.
  56. Args:
  57. location: path to the collection directory.
  58. """
  59. try_migrate_to_sqlite(location)
  60. self.location = Path(location) / STORAGE_FILE_NAME
  61. self.location.parent.mkdir(exist_ok=True, parents=True)
  62. if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False:
  63. with sqlite3.connect(":memory:") as tmp_conn:
  64. # it is unsafe to use `sqlite3.threadsafety` until python3.11 since it was hardcoded to 1, thus we
  65. # need to fetch threadsafe with a query
  66. # THREADSAFE = 0: Threads may not share the module
  67. # THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux.
  68. # THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS.
  69. threadsafe = tmp_conn.execute(
  70. "select * from pragma_compile_options where compile_options like 'THREADSAFE=%'"
  71. ).fetchone()[0]
  72. self.__class__.CHECK_SAME_THREAD = threadsafe != "THREADSAFE=1"
  73. if force_disable_check_same_thread:
  74. self.__class__.CHECK_SAME_THREAD = False
  75. self.storage = sqlite3.connect(
  76. str(self.location), check_same_thread=self.CHECK_SAME_THREAD # type: ignore
  77. )
  78. self._ensure_table()
  79. def close(self) -> None:
  80. self.storage.close()
  81. def _ensure_table(self) -> None:
  82. cursor = self.storage.cursor()
  83. cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
  84. self.storage.commit()
  85. def persist(self, point: models.PointStruct) -> None:
  86. """
  87. Persist a point in the local storage.
  88. Args:
  89. point: point to persist
  90. """
  91. key = self.encode_key(point.id)
  92. value = pickle.dumps(point)
  93. cursor = self.storage.cursor()
  94. # Insert or update by key
  95. cursor.execute(
  96. "INSERT OR REPLACE INTO points VALUES (?, ?)",
  97. (
  98. key,
  99. sqlite3.Binary(value),
  100. ),
  101. )
  102. self.storage.commit()
  103. def delete(self, point_id: models.ExtendedPointId) -> None:
  104. """
  105. Delete a point from the local storage.
  106. Args:
  107. point_id: id of the point to delete
  108. """
  109. key = self.encode_key(point_id)
  110. cursor = self.storage.cursor()
  111. cursor.execute(
  112. "DELETE FROM points WHERE id = ?",
  113. (key,),
  114. )
  115. self.storage.commit()
  116. def load(self) -> Iterable[models.PointStruct]:
  117. """
  118. Load a point from the local storage.
  119. Returns:
  120. point: loaded point
  121. """
  122. cursor = self.storage.cursor()
  123. cursor.execute("SELECT point FROM points")
  124. for row in cursor.fetchall():
  125. yield pickle.loads(row[0])
  126. def test_persistence() -> None:
  127. import tempfile
  128. with tempfile.TemporaryDirectory() as tmpdir:
  129. persistence = CollectionPersistence(tmpdir)
  130. point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})
  131. persistence.persist(point)
  132. for loaded_point in persistence.load():
  133. assert loaded_point == point
  134. break
  135. del persistence
  136. persistence = CollectionPersistence(tmpdir)
  137. for loaded_point in persistence.load():
  138. assert loaded_point == point
  139. break
  140. persistence.delete(point.id)
  141. persistence.delete(point.id)
  142. for _ in persistence.load():
  143. assert False, "Should not load anything"