uploader.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from abc import ABC
  2. from itertools import count, islice
  3. from typing import Any, Generator, Iterable, Optional, Union
  4. import numpy as np
  5. from qdrant_client.conversions import common_types as types
  6. from qdrant_client.conversions.common_types import Record
  7. from qdrant_client.http.models import ExtendedPointId
  8. from qdrant_client.parallel_processor import Worker
  9. def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
  10. """
  11. >>> list(iter_batch([1,2,3,4,5], 3))
  12. [[1, 2, 3], [4, 5]]
  13. """
  14. source_iter = iter(iterable)
  15. while source_iter:
  16. b = list(islice(source_iter, size))
  17. if len(b) == 0:
  18. break
  19. yield b
  20. class BaseUploader(Worker, ABC):
  21. @classmethod
  22. def iterate_records_batches(
  23. cls,
  24. records: Iterable[Union[Record, types.PointStruct]],
  25. batch_size: int,
  26. ) -> Iterable:
  27. record_batches = iter_batch(records, batch_size)
  28. for record_batch in record_batches:
  29. ids_batch, vectors_batch, payload_batch = [], [], []
  30. for record in record_batch:
  31. ids_batch.append(record.id)
  32. vectors_batch.append(record.vector)
  33. payload_batch.append(record.payload)
  34. yield ids_batch, vectors_batch, payload_batch
  35. @classmethod
  36. def iterate_batches(
  37. cls,
  38. vectors: Union[
  39. dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
  40. ],
  41. payload: Optional[Iterable[dict]],
  42. ids: Optional[Iterable[ExtendedPointId]],
  43. batch_size: int,
  44. ) -> Iterable:
  45. if ids is None:
  46. ids_batches: Iterable = (None for _ in count())
  47. else:
  48. ids_batches = iter_batch(ids, batch_size)
  49. if payload is None:
  50. payload_batches: Iterable = (None for _ in count())
  51. else:
  52. payload_batches = iter_batch(payload, batch_size)
  53. if isinstance(vectors, np.ndarray):
  54. vector_batches: Iterable[Any] = cls._vector_batches_from_numpy(vectors, batch_size)
  55. elif isinstance(vectors, dict) and any(
  56. isinstance(value, np.ndarray) for value in vectors.values()
  57. ):
  58. vector_batches = cls._vector_batches_from_numpy_named_vectors(vectors, batch_size)
  59. else:
  60. vector_batches = iter_batch(vectors, batch_size)
  61. yield from zip(ids_batches, vector_batches, payload_batches)
  62. @staticmethod
  63. def _vector_batches_from_numpy(vectors: types.NumpyArray, batch_size: int) -> Iterable[float]:
  64. for i in range(0, vectors.shape[0], batch_size):
  65. yield vectors[i : i + batch_size].tolist()
  66. @staticmethod
  67. def _vector_batches_from_numpy_named_vectors(
  68. vectors: dict[str, types.NumpyArray], batch_size: int
  69. ) -> Iterable[dict[str, list[float]]]:
  70. assert (
  71. len(set([arr.shape[0] for arr in vectors.values()])) == 1
  72. ), "Each named vector should have the same number of vectors"
  73. num_vectors = next(iter(vectors.values())).shape[0]
  74. # Convert dict[str, np.ndarray] to Generator(dict[str, list[float]])
  75. vector_batches = (
  76. {name: vectors[name][i].tolist() for name in vectors.keys()}
  77. for i in range(num_vectors)
  78. )
  79. yield from iter_batch(vector_batches, batch_size)