| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- import time
- from typing import Iterable, Optional, Any
- from qdrant_client._pydantic_compat import to_dict, model_fields
- from qdrant_client.client_base import QdrantBase
- from qdrant_client.http import models
- def upload_with_retry(
- client: QdrantBase,
- collection_name: str,
- points: Iterable[models.PointStruct],
- max_attempts: int = 3,
- pause: float = 3.0,
- ) -> None:
- attempts = 1
- while attempts <= max_attempts:
- try:
- client.upload_points(
- collection_name=collection_name,
- points=points,
- wait=True,
- )
- return
- except Exception as e:
- print(f"Exception: {e}, attempt {attempts}/{max_attempts}")
- if attempts < max_attempts:
- print(f"Next attempt in {pause} seconds")
- time.sleep(pause)
- attempts += 1
- raise Exception(f"Failed to upload points after {max_attempts} attempts")
- def migrate(
- source_client: QdrantBase,
- dest_client: QdrantBase,
- collection_names: Optional[list[str]] = None,
- recreate_on_collision: bool = False,
- batch_size: int = 100,
- ) -> None:
- """
- Migrate collections from source client to destination client
- Args:
- source_client (QdrantBase): Source client
- dest_client (QdrantBase): Destination client
- collection_names (list[str], optional): List of collection names to migrate.
- If None - migrate all source client collections. Defaults to None.
- recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
- raise ValueError.
- batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
- """
- collection_names = _select_source_collections(source_client, collection_names)
- if any(
- _has_custom_shards(source_client, collection_name) for collection_name in collection_names
- ):
- raise ValueError("Migration of collections with custom shards is not supported yet")
- collisions = _find_collisions(dest_client, collection_names)
- absent_dest_collections = set(collection_names) - set(collisions)
- if collisions and not recreate_on_collision:
- raise ValueError(f"Collections already exist in dest_client: {collisions}")
- for collection_name in absent_dest_collections:
- _recreate_collection(source_client, dest_client, collection_name)
- _migrate_collection(source_client, dest_client, collection_name, batch_size)
- for collection_name in collisions:
- _recreate_collection(source_client, dest_client, collection_name)
- _migrate_collection(source_client, dest_client, collection_name, batch_size)
- def _has_custom_shards(source_client: QdrantBase, collection_name: str) -> bool:
- collection_info = source_client.get_collection(collection_name)
- return (
- getattr(collection_info.config.params, "sharding_method", None)
- == models.ShardingMethod.CUSTOM
- )
- def _select_source_collections(
- source_client: QdrantBase, collection_names: Optional[list[str]] = None
- ) -> list[str]:
- source_collections = source_client.get_collections().collections
- source_collection_names = [collection.name for collection in source_collections]
- if collection_names is not None:
- assert all(
- collection_name in source_collection_names for collection_name in collection_names
- ), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
- else:
- collection_names = source_collection_names
- return collection_names
- def _find_collisions(dest_client: QdrantBase, collection_names: list[str]) -> list[str]:
- dest_collections = dest_client.get_collections().collections
- dest_collection_names = {collection.name for collection in dest_collections}
- existing_dest_collections = dest_collection_names & set(collection_names)
- return list(existing_dest_collections)
- def _recreate_collection(
- source_client: QdrantBase,
- dest_client: QdrantBase,
- collection_name: str,
- ) -> None:
- src_collection_info = source_client.get_collection(collection_name)
- src_config = src_collection_info.config
- src_payload_schema = src_collection_info.payload_schema
- if dest_client.collection_exists(collection_name):
- dest_client.delete_collection(collection_name)
- strict_mode_config: Optional[models.StrictModeConfig] = None
- if src_config.strict_mode_config is not None:
- strict_mode_config = models.StrictModeConfig(
- **{
- k: v
- for k, v in to_dict(src_config.strict_mode_config).items()
- if k in model_fields(models.StrictModeConfig)
- }
- )
- dest_client.create_collection(
- collection_name,
- vectors_config=src_config.params.vectors,
- sparse_vectors_config=src_config.params.sparse_vectors,
- shard_number=src_config.params.shard_number,
- replication_factor=src_config.params.replication_factor,
- write_consistency_factor=src_config.params.write_consistency_factor,
- on_disk_payload=src_config.params.on_disk_payload,
- hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
- optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
- wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
- quantization_config=src_config.quantization_config,
- strict_mode_config=strict_mode_config,
- )
- _recreate_payload_schema(dest_client, collection_name, src_payload_schema)
- def _recreate_payload_schema(
- dest_client: QdrantBase,
- collection_name: str,
- payload_schema: dict[str, models.PayloadIndexInfo],
- ) -> None:
- for field_name, field_info in payload_schema.items():
- dest_client.create_payload_index(
- collection_name,
- field_name=field_name,
- field_schema=field_info.data_type if field_info.params is None else field_info.params,
- )
- def _migrate_collection(
- source_client: QdrantBase,
- dest_client: QdrantBase,
- collection_name: str,
- batch_size: int = 100,
- ) -> None:
- """Migrate collection from source client to destination client
- Args:
- collection_name (str): Collection name
- source_client (QdrantBase): Source client
- dest_client (QdrantBase): Destination client
- batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
- """
- records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True)
- upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
- # upload_records has been deprecated due to the usage of models.Record; models.Record has been deprecated as a
- # structure for uploading due to a `shard_key` field, and now is used only as a result structure.
- # since shard_keys are not supported in migration, we can safely type ignore here and use Records for uploading
- while next_offset is not None:
- records, next_offset = source_client.scroll(
- collection_name, offset=next_offset, limit=batch_size, with_vectors=True
- )
- upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
- source_client_vectors_count = source_client.count(collection_name).count
- dest_client_vectors_count = dest_client.count(collection_name).count
- assert (
- source_client_vectors_count == dest_client_vectors_count
- ), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
|