migrate.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import time
  2. from typing import Iterable, Optional, Any
  3. from qdrant_client._pydantic_compat import to_dict, model_fields
  4. from qdrant_client.client_base import QdrantBase
  5. from qdrant_client.http import models
  6. def upload_with_retry(
  7. client: QdrantBase,
  8. collection_name: str,
  9. points: Iterable[models.PointStruct],
  10. max_attempts: int = 3,
  11. pause: float = 3.0,
  12. ) -> None:
  13. attempts = 1
  14. while attempts <= max_attempts:
  15. try:
  16. client.upload_points(
  17. collection_name=collection_name,
  18. points=points,
  19. wait=True,
  20. )
  21. return
  22. except Exception as e:
  23. print(f"Exception: {e}, attempt {attempts}/{max_attempts}")
  24. if attempts < max_attempts:
  25. print(f"Next attempt in {pause} seconds")
  26. time.sleep(pause)
  27. attempts += 1
  28. raise Exception(f"Failed to upload points after {max_attempts} attempts")
  29. def migrate(
  30. source_client: QdrantBase,
  31. dest_client: QdrantBase,
  32. collection_names: Optional[list[str]] = None,
  33. recreate_on_collision: bool = False,
  34. batch_size: int = 100,
  35. ) -> None:
  36. """
  37. Migrate collections from source client to destination client
  38. Args:
  39. source_client (QdrantBase): Source client
  40. dest_client (QdrantBase): Destination client
  41. collection_names (list[str], optional): List of collection names to migrate.
  42. If None - migrate all source client collections. Defaults to None.
  43. recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
  44. raise ValueError.
  45. batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
  46. """
  47. collection_names = _select_source_collections(source_client, collection_names)
  48. if any(
  49. _has_custom_shards(source_client, collection_name) for collection_name in collection_names
  50. ):
  51. raise ValueError("Migration of collections with custom shards is not supported yet")
  52. collisions = _find_collisions(dest_client, collection_names)
  53. absent_dest_collections = set(collection_names) - set(collisions)
  54. if collisions and not recreate_on_collision:
  55. raise ValueError(f"Collections already exist in dest_client: {collisions}")
  56. for collection_name in absent_dest_collections:
  57. _recreate_collection(source_client, dest_client, collection_name)
  58. _migrate_collection(source_client, dest_client, collection_name, batch_size)
  59. for collection_name in collisions:
  60. _recreate_collection(source_client, dest_client, collection_name)
  61. _migrate_collection(source_client, dest_client, collection_name, batch_size)
  62. def _has_custom_shards(source_client: QdrantBase, collection_name: str) -> bool:
  63. collection_info = source_client.get_collection(collection_name)
  64. return (
  65. getattr(collection_info.config.params, "sharding_method", None)
  66. == models.ShardingMethod.CUSTOM
  67. )
  68. def _select_source_collections(
  69. source_client: QdrantBase, collection_names: Optional[list[str]] = None
  70. ) -> list[str]:
  71. source_collections = source_client.get_collections().collections
  72. source_collection_names = [collection.name for collection in source_collections]
  73. if collection_names is not None:
  74. assert all(
  75. collection_name in source_collection_names for collection_name in collection_names
  76. ), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
  77. else:
  78. collection_names = source_collection_names
  79. return collection_names
  80. def _find_collisions(dest_client: QdrantBase, collection_names: list[str]) -> list[str]:
  81. dest_collections = dest_client.get_collections().collections
  82. dest_collection_names = {collection.name for collection in dest_collections}
  83. existing_dest_collections = dest_collection_names & set(collection_names)
  84. return list(existing_dest_collections)
  85. def _recreate_collection(
  86. source_client: QdrantBase,
  87. dest_client: QdrantBase,
  88. collection_name: str,
  89. ) -> None:
  90. src_collection_info = source_client.get_collection(collection_name)
  91. src_config = src_collection_info.config
  92. src_payload_schema = src_collection_info.payload_schema
  93. if dest_client.collection_exists(collection_name):
  94. dest_client.delete_collection(collection_name)
  95. strict_mode_config: Optional[models.StrictModeConfig] = None
  96. if src_config.strict_mode_config is not None:
  97. strict_mode_config = models.StrictModeConfig(
  98. **{
  99. k: v
  100. for k, v in to_dict(src_config.strict_mode_config).items()
  101. if k in model_fields(models.StrictModeConfig)
  102. }
  103. )
  104. dest_client.create_collection(
  105. collection_name,
  106. vectors_config=src_config.params.vectors,
  107. sparse_vectors_config=src_config.params.sparse_vectors,
  108. shard_number=src_config.params.shard_number,
  109. replication_factor=src_config.params.replication_factor,
  110. write_consistency_factor=src_config.params.write_consistency_factor,
  111. on_disk_payload=src_config.params.on_disk_payload,
  112. hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
  113. optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
  114. wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
  115. quantization_config=src_config.quantization_config,
  116. strict_mode_config=strict_mode_config,
  117. )
  118. _recreate_payload_schema(dest_client, collection_name, src_payload_schema)
  119. def _recreate_payload_schema(
  120. dest_client: QdrantBase,
  121. collection_name: str,
  122. payload_schema: dict[str, models.PayloadIndexInfo],
  123. ) -> None:
  124. for field_name, field_info in payload_schema.items():
  125. dest_client.create_payload_index(
  126. collection_name,
  127. field_name=field_name,
  128. field_schema=field_info.data_type if field_info.params is None else field_info.params,
  129. )
  130. def _migrate_collection(
  131. source_client: QdrantBase,
  132. dest_client: QdrantBase,
  133. collection_name: str,
  134. batch_size: int = 100,
  135. ) -> None:
  136. """Migrate collection from source client to destination client
  137. Args:
  138. collection_name (str): Collection name
  139. source_client (QdrantBase): Source client
  140. dest_client (QdrantBase): Destination client
  141. batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
  142. """
  143. records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True)
  144. upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
  145. # upload_records has been deprecated due to the usage of models.Record; models.Record has been deprecated as a
  146. # structure for uploading due to a `shard_key` field, and now is used only as a result structure.
  147. # since shard_keys are not supported in migration, we can safely type ignore here and use Records for uploading
  148. while next_offset is not None:
  149. records, next_offset = source_client.scroll(
  150. collection_name, offset=next_offset, limit=batch_size, with_vectors=True
  151. )
  152. upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
  153. source_client_vectors_count = source_client.count(collection_name).count
  154. dest_client_vectors_count = dest_client.count(collection_name).count
  155. assert (
  156. source_client_vectors_count == dest_client_vectors_count
  157. ), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"