qdrant_remote.py 135 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389
  1. import importlib.metadata
  2. import logging
  3. import math
  4. import platform
  5. from multiprocessing import get_all_start_methods
  6. from typing import (
  7. Any,
  8. Awaitable,
  9. Callable,
  10. Iterable,
  11. Mapping,
  12. Optional,
  13. Sequence,
  14. Type,
  15. Union,
  16. get_args,
  17. )
  18. import httpx
  19. import numpy as np
  20. from grpc import Compression
  21. from urllib3.util import Url, parse_url
  22. from urllib.parse import urljoin
  23. from qdrant_client.common.client_warnings import show_warning, show_warning_once
  24. from qdrant_client import grpc as grpc
  25. from qdrant_client._pydantic_compat import construct
  26. from qdrant_client.auth import BearerAuth
  27. from qdrant_client.client_base import QdrantBase
  28. from qdrant_client.common.version_check import is_compatible, get_server_version
  29. from qdrant_client.connection import get_channel
  30. from qdrant_client.conversions import common_types as types
  31. from qdrant_client.conversions.common_types import get_args_subscribed
  32. from qdrant_client.conversions.conversion import (
  33. GrpcToRest,
  34. RestToGrpc,
  35. grpc_payload_schema_to_field_type,
  36. )
  37. from qdrant_client.http import ApiClient, SyncApis, models
  38. from qdrant_client.parallel_processor import ParallelWorkerPool
  39. from qdrant_client.uploader.grpc_uploader import GrpcBatchUploader
  40. from qdrant_client.uploader.rest_uploader import RestBatchUploader
  41. from qdrant_client.uploader.uploader import BaseUploader
  42. class QdrantRemote(QdrantBase):
  43. DEFAULT_GRPC_TIMEOUT = 5 # seconds
  44. def __init__(
  45. self,
  46. url: Optional[str] = None,
  47. port: Optional[int] = 6333,
  48. grpc_port: int = 6334,
  49. prefer_grpc: bool = False,
  50. https: Optional[bool] = None,
  51. api_key: Optional[str] = None,
  52. prefix: Optional[str] = None,
  53. timeout: Optional[int] = None,
  54. host: Optional[str] = None,
  55. grpc_options: Optional[dict[str, Any]] = None,
  56. auth_token_provider: Optional[
  57. Union[Callable[[], str], Callable[[], Awaitable[str]]]
  58. ] = None,
  59. check_compatibility: bool = True,
  60. **kwargs: Any,
  61. ):
  62. super().__init__(**kwargs)
  63. self._prefer_grpc = prefer_grpc
  64. self._grpc_port = grpc_port
  65. self._grpc_options = grpc_options or {}
  66. self._https = https if https is not None else api_key is not None
  67. self._scheme = "https" if self._https else "http"
  68. self._prefix = prefix or ""
  69. if len(self._prefix) > 0 and self._prefix[0] != "/":
  70. self._prefix = f"/{self._prefix}"
  71. if url is not None and host is not None:
  72. raise ValueError(f"Only one of (url, host) can be set. url is {url}, host is {host}")
  73. if host is not None and (host.startswith("http://") or host.startswith("https://")):
  74. raise ValueError(
  75. f"`host` param is not expected to contain protocol (http:// or https://). "
  76. f"Try to use `url` parameter instead."
  77. )
  78. elif url:
  79. if url.startswith("localhost"):
  80. # Handle for a special case when url is localhost:port
  81. # Which is not parsed correctly by urllib
  82. url = f"//{url}"
  83. parsed_url: Url = parse_url(url)
  84. self._host, self._port = parsed_url.host, parsed_url.port
  85. if parsed_url.scheme:
  86. self._https = parsed_url.scheme == "https"
  87. self._scheme = parsed_url.scheme
  88. self._port = self._port if self._port else port
  89. if self._prefix and parsed_url.path:
  90. raise ValueError(
  91. "Prefix can be set either in `url` or in `prefix`. "
  92. f"url is {url}, prefix is {parsed_url.path}"
  93. )
  94. elif parsed_url.path:
  95. self._prefix = parsed_url.path
  96. if self._scheme not in ("http", "https"):
  97. raise ValueError(f"Unknown scheme: {self._scheme}")
  98. else:
  99. self._host = host or "localhost"
  100. self._port = port
  101. _timeout = (
  102. math.ceil(timeout) if timeout is not None else None
  103. ) # it has been changed from float to int.
  104. # convert it to the closest greater or equal int value (e.g. 0.5 -> 1)
  105. self._api_key = api_key
  106. self._auth_token_provider = auth_token_provider
  107. limits = kwargs.pop("limits", None)
  108. if limits is None:
  109. if self._host in ["localhost", "127.0.0.1"]:
  110. # Disable keep-alive for local connections
  111. # Cause in some cases, it may cause extra delays
  112. limits = httpx.Limits(max_connections=None, max_keepalive_connections=0)
  113. http2 = kwargs.pop("http2", False)
  114. self._grpc_headers = []
  115. self._rest_headers = {k: v for k, v in kwargs.pop("metadata", {}).items()}
  116. if api_key is not None:
  117. if self._scheme == "http":
  118. show_warning(
  119. message="Api key is used with an insecure connection.",
  120. category=UserWarning,
  121. stacklevel=4,
  122. )
  123. # http2 = True
  124. self._rest_headers["api-key"] = api_key
  125. self._grpc_headers.append(("api-key", api_key))
  126. client_version = importlib.metadata.version("qdrant-client")
  127. python_version = platform.python_version()
  128. user_agent = f"python-client/{client_version} python/{python_version}"
  129. self._rest_headers["User-Agent"] = user_agent
  130. self._grpc_options["grpc.primary_user_agent"] = user_agent
  131. # GRPC Channel-Level Compression
  132. grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
  133. if grpc_compression is not None and not isinstance(grpc_compression, Compression):
  134. raise TypeError(
  135. f"Expected 'grpc_compression' to be of type "
  136. f"grpc.Compression or None, but got {type(grpc_compression)}"
  137. )
  138. if grpc_compression == Compression.Deflate:
  139. raise ValueError(
  140. "grpc.Compression.Deflate is not supported. Try grpc.Compression.Gzip or grpc.Compression.NoCompression"
  141. )
  142. self._grpc_compression = grpc_compression
  143. address = f"{self._host}:{self._port}" if self._port is not None else self._host
  144. base_url = f"{self._scheme}://{address}"
  145. self.rest_uri = urljoin(base_url, self._prefix)
  146. self._rest_args = {"headers": self._rest_headers, "http2": http2, **kwargs}
  147. if limits is not None:
  148. self._rest_args["limits"] = limits
  149. if _timeout is not None:
  150. self._rest_args["timeout"] = _timeout
  151. self._timeout = _timeout
  152. else:
  153. self._timeout = self.DEFAULT_GRPC_TIMEOUT
  154. if self._auth_token_provider is not None:
  155. if self._scheme == "http":
  156. show_warning(
  157. message="Auth token provider is used with an insecure connection.",
  158. category=UserWarning,
  159. stacklevel=4,
  160. )
  161. bearer_auth = BearerAuth(self._auth_token_provider)
  162. self._rest_args["auth"] = bearer_auth
  163. self.openapi_client: SyncApis[ApiClient] = SyncApis(
  164. host=self.rest_uri,
  165. **self._rest_args,
  166. )
  167. self._grpc_channel = None
  168. self._grpc_points_client: Optional[grpc.PointsStub] = None
  169. self._grpc_collections_client: Optional[grpc.CollectionsStub] = None
  170. self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
  171. self._grpc_root_client: Optional[grpc.QdrantStub] = None
  172. self._aio_grpc_points_client: Optional[grpc.PointsStub] = None
  173. self._aio_grpc_collections_client: Optional[grpc.CollectionsStub] = None
  174. self._aio_grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
  175. self._aio_grpc_root_client: Optional[grpc.QdrantStub] = None
  176. self._closed: bool = False
  177. if check_compatibility:
  178. try:
  179. client_version = importlib.metadata.version("qdrant-client")
  180. server_version = get_server_version(
  181. self.rest_uri, self._rest_headers, self._rest_args.get("auth")
  182. )
  183. if not server_version:
  184. show_warning(
  185. message="Failed to obtain server version. Unable to check client-server compatibility."
  186. " Set check_compatibility=False to skip version check.",
  187. category=UserWarning,
  188. stacklevel=4,
  189. )
  190. elif not is_compatible(client_version, server_version):
  191. show_warning(
  192. message=f"Qdrant client version {client_version} is incompatible with server "
  193. f"version {server_version}. Major versions should match and minor version difference "
  194. "must not exceed 1. Set check_compatibility=False to skip version check.",
  195. category=UserWarning,
  196. stacklevel=4,
  197. )
  198. except Exception as er:
  199. logging.debug(
  200. f"Unable to get server version: {er}, server version defaults to None"
  201. )
  202. @property
  203. def closed(self) -> bool:
  204. return self._closed
  205. def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
  206. if hasattr(self, "_grpc_channel") and self._grpc_channel is not None:
  207. try:
  208. self._grpc_channel.close()
  209. except AttributeError:
  210. show_warning(
  211. message="Unable to close grpc_channel. Connection was interrupted on the server side",
  212. category=RuntimeWarning,
  213. stacklevel=4,
  214. )
  215. try:
  216. self.openapi_client.close()
  217. except Exception:
  218. show_warning(
  219. message="Unable to close http connection. Connection was interrupted on the server side",
  220. category=RuntimeWarning,
  221. stacklevel=4,
  222. )
  223. self._closed = True
  224. @staticmethod
  225. def _parse_url(url: str) -> tuple[Optional[str], str, Optional[int], Optional[str]]:
  226. parse_result: Url = parse_url(url)
  227. scheme, host, port, prefix = (
  228. parse_result.scheme,
  229. parse_result.host,
  230. parse_result.port,
  231. parse_result.path,
  232. )
  233. return scheme, host, port, prefix
  234. def _init_grpc_channel(self) -> None:
  235. if self._closed:
  236. raise RuntimeError("Client was closed. Please create a new QdrantClient instance.")
  237. if self._grpc_channel is None:
  238. self._grpc_channel = get_channel(
  239. host=self._host,
  240. port=self._grpc_port,
  241. ssl=self._https,
  242. metadata=self._grpc_headers,
  243. options=self._grpc_options,
  244. compression=self._grpc_compression,
  245. # sync get_channel does not accept coroutine functions,
  246. # but we can't check type here, since it'll get into async client as well
  247. auth_token_provider=self._auth_token_provider, # type: ignore
  248. )
  249. def _init_grpc_points_client(self) -> None:
  250. self._init_grpc_channel()
  251. self._grpc_points_client = grpc.PointsStub(self._grpc_channel)
  252. def _init_grpc_collections_client(self) -> None:
  253. self._init_grpc_channel()
  254. self._grpc_collections_client = grpc.CollectionsStub(self._grpc_channel)
  255. def _init_grpc_snapshots_client(self) -> None:
  256. self._init_grpc_channel()
  257. self._grpc_snapshots_client = grpc.SnapshotsStub(self._grpc_channel)
  258. def _init_grpc_root_client(self) -> None:
  259. self._init_grpc_channel()
  260. self._grpc_root_client = grpc.QdrantStub(self._grpc_channel)
  261. @property
  262. def grpc_collections(self) -> grpc.CollectionsStub:
  263. """gRPC client for collections methods
  264. Returns:
  265. An instance of raw gRPC client, generated from Protobuf
  266. """
  267. if self._grpc_collections_client is None:
  268. self._init_grpc_collections_client()
  269. return self._grpc_collections_client
  270. @property
  271. def grpc_points(self) -> grpc.PointsStub:
  272. """gRPC client for points methods
  273. Returns:
  274. An instance of raw gRPC client, generated from Protobuf
  275. """
  276. if self._grpc_points_client is None:
  277. self._init_grpc_points_client()
  278. return self._grpc_points_client
  279. @property
  280. def grpc_snapshots(self) -> grpc.SnapshotsStub:
  281. """gRPC client for snapshots methods
  282. Returns:
  283. An instance of raw gRPC client, generated from Protobuf
  284. """
  285. if self._grpc_snapshots_client is None:
  286. self._init_grpc_snapshots_client()
  287. return self._grpc_snapshots_client
  288. @property
  289. def grpc_root(self) -> grpc.QdrantStub:
  290. """gRPC client for info methods
  291. Returns:
  292. An instance of raw gRPC client, generated from Protobuf
  293. """
  294. if self._grpc_root_client is None:
  295. self._init_grpc_root_client()
  296. return self._grpc_root_client
  297. @property
  298. def rest(self) -> SyncApis[ApiClient]:
  299. """REST Client
  300. Returns:
  301. An instance of raw REST API client, generated from OpenAPI schema
  302. """
  303. return self.openapi_client
  304. @property
  305. def http(self) -> SyncApis[ApiClient]:
  306. """REST Client
  307. Returns:
  308. An instance of raw REST API client, generated from OpenAPI schema
  309. """
  310. return self.openapi_client
  311. def search_batch(
  312. self,
  313. collection_name: str,
  314. requests: Sequence[types.SearchRequest],
  315. consistency: Optional[types.ReadConsistency] = None,
  316. timeout: Optional[int] = None,
  317. **kwargs: Any,
  318. ) -> list[list[types.ScoredPoint]]:
  319. if self._prefer_grpc:
  320. requests = [
  321. (
  322. RestToGrpc.convert_search_request(r, collection_name)
  323. if isinstance(r, models.SearchRequest)
  324. else r
  325. )
  326. for r in requests
  327. ]
  328. if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
  329. consistency = RestToGrpc.convert_read_consistency(consistency)
  330. grpc_res: grpc.SearchBatchResponse = self.grpc_points.SearchBatch(
  331. grpc.SearchBatchPoints(
  332. collection_name=collection_name,
  333. search_points=requests,
  334. read_consistency=consistency,
  335. timeout=timeout,
  336. ),
  337. timeout=timeout if timeout is not None else self._timeout,
  338. )
  339. return [
  340. [GrpcToRest.convert_scored_point(hit) for hit in r.result] for r in grpc_res.result
  341. ]
  342. else:
  343. requests = [
  344. (GrpcToRest.convert_search_points(r) if isinstance(r, grpc.SearchPoints) else r)
  345. for r in requests
  346. ]
  347. http_res: Optional[list[list[models.ScoredPoint]]] = (
  348. self.http.search_api.search_batch_points(
  349. collection_name=collection_name,
  350. consistency=consistency,
  351. timeout=timeout,
  352. search_request_batch=models.SearchRequestBatch(searches=requests),
  353. ).result
  354. )
  355. assert http_res is not None, "Search batch returned None"
  356. return http_res
  357. def search(
  358. self,
  359. collection_name: str,
  360. query_vector: Union[
  361. Sequence[float],
  362. tuple[str, list[float]],
  363. types.NamedVector,
  364. types.NamedSparseVector,
  365. types.NumpyArray,
  366. ],
  367. query_filter: Optional[types.Filter] = None,
  368. search_params: Optional[types.SearchParams] = None,
  369. limit: int = 10,
  370. offset: Optional[int] = None,
  371. with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
  372. with_vectors: Union[bool, Sequence[str]] = False,
  373. score_threshold: Optional[float] = None,
  374. append_payload: bool = True,
  375. consistency: Optional[types.ReadConsistency] = None,
  376. shard_key_selector: Optional[types.ShardKeySelector] = None,
  377. timeout: Optional[int] = None,
  378. **kwargs: Any,
  379. ) -> list[types.ScoredPoint]:
  380. if not append_payload:
  381. show_warning_once(
  382. message="Usage of `append_payload` is deprecated. Please consider using `with_payload` instead",
  383. category=DeprecationWarning,
  384. stacklevel=5,
  385. idx="search-append-payload",
  386. )
  387. with_payload = append_payload
  388. if isinstance(query_vector, np.ndarray):
  389. query_vector = query_vector.tolist()
  390. if self._prefer_grpc:
  391. vector_name = None
  392. sparse_indices = None
  393. if isinstance(query_vector, types.NamedVector):
  394. vector = query_vector.vector
  395. vector_name = query_vector.name
  396. elif isinstance(query_vector, types.NamedSparseVector):
  397. vector_name = query_vector.name
  398. sparse_indices = grpc.SparseIndices(data=query_vector.vector.indices)
  399. vector = query_vector.vector.values
  400. elif isinstance(query_vector, tuple):
  401. vector_name = query_vector[0]
  402. vector = query_vector[1]
  403. else:
  404. vector = list(query_vector)
  405. if isinstance(query_filter, models.Filter):
  406. query_filter = RestToGrpc.convert_filter(model=query_filter)
  407. if isinstance(search_params, models.SearchParams):
  408. search_params = RestToGrpc.convert_search_params(search_params)
  409. if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
  410. with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
  411. if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
  412. with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
  413. if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
  414. consistency = RestToGrpc.convert_read_consistency(consistency)
  415. if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
  416. shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
  417. res: grpc.SearchResponse = self.grpc_points.Search(
  418. grpc.SearchPoints(
  419. collection_name=collection_name,
  420. vector=vector,
  421. vector_name=vector_name,
  422. filter=query_filter,
  423. limit=limit,
  424. offset=offset,
  425. with_vectors=with_vectors,
  426. with_payload=with_payload,
  427. params=search_params,
  428. score_threshold=score_threshold,
  429. read_consistency=consistency,
  430. timeout=timeout,
  431. sparse_indices=sparse_indices,
  432. shard_key_selector=shard_key_selector,
  433. ),
  434. timeout=timeout if timeout is not None else self._timeout,
  435. )
  436. return [GrpcToRest.convert_scored_point(hit) for hit in res.result]
  437. else:
  438. if isinstance(query_vector, tuple):
  439. query_vector = types.NamedVector(name=query_vector[0], vector=query_vector[1])
  440. if isinstance(query_filter, grpc.Filter):
  441. query_filter = GrpcToRest.convert_filter(model=query_filter)
  442. if isinstance(search_params, grpc.SearchParams):
  443. search_params = GrpcToRest.convert_search_params(search_params)
  444. if isinstance(with_payload, grpc.WithPayloadSelector):
  445. with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
  446. search_result = self.http.search_api.search_points(
  447. collection_name=collection_name,
  448. consistency=consistency,
  449. timeout=timeout,
  450. search_request=models.SearchRequest(
  451. vector=query_vector,
  452. filter=query_filter,
  453. limit=limit,
  454. offset=offset,
  455. params=search_params,
  456. with_vector=with_vectors,
  457. with_payload=with_payload,
  458. score_threshold=score_threshold,
  459. shard_key=shard_key_selector,
  460. ),
  461. )
  462. result: Optional[list[types.ScoredPoint]] = search_result.result
  463. assert result is not None, "Search returned None"
  464. return result
  def query_points(
      self,
      collection_name: str,
      query: Union[
          types.PointId,
          list[float],
          list[list[float]],
          types.SparseVector,
          types.Query,
          types.NumpyArray,
          types.Document,
          types.Image,
          types.InferenceObject,
          None,
      ] = None,
      using: Optional[str] = None,
      prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
      query_filter: Optional[types.Filter] = None,
      search_params: Optional[types.SearchParams] = None,
      limit: int = 10,
      offset: Optional[int] = None,
      with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      score_threshold: Optional[float] = None,
      lookup_from: Optional[types.LookupLocation] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.QueryResponse:
      """Query points of a collection over gRPC or REST.

      When ``self._prefer_grpc`` is set the request goes to the gRPC ``Query``
      endpoint; otherwise it goes to the REST ``query_points`` endpoint. Each
      argument may be a REST model or a gRPC message. It is converted to the
      form the selected transport expects; values already in that form are
      passed through unchanged.

      Args:
          collection_name: Name of the collection to query.
          query: The query; ``None`` sends no query. A ``grpc.Query`` is
              accepted on both transports.
          using: Forwarded as the request's ``using``.
          prefetch: A single prefetch or a list of prefetches, REST or gRPC.
          query_filter, search_params, limit, offset, with_payload,
          with_vectors, score_threshold, lookup_from, consistency,
          shard_key_selector: Converted if needed and forwarded to the
              request field of the same meaning.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          A ``QueryResponse`` wrapping the scored points.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if self._prefer_grpc:
          # A grpc.Query is already in wire form (the REST branch accepts it
          # too), so only values that are not gRPC messages yet are converted.
          if query is not None and not isinstance(query, grpc.Query):
              query = RestToGrpc.convert_query(query)
          # QueryPoints receives a list of prefetches. Wrap a lone prefetch of
          # either flavour, so a single grpc.PrefetchQuery behaves like a
          # single models.Prefetch.
          if isinstance(prefetch, (models.Prefetch, grpc.PrefetchQuery)):
              prefetch = [prefetch]
          if isinstance(prefetch, list):
              prefetch = [
                  RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p
                  for p in prefetch
              ]
          if isinstance(query_filter, models.Filter):
              query_filter = RestToGrpc.convert_filter(model=query_filter)
          if isinstance(search_params, models.SearchParams):
              search_params = RestToGrpc.convert_search_params(search_params)
          if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
              with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
          if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
              with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
          if isinstance(lookup_from, models.LookupLocation):
              lookup_from = RestToGrpc.convert_lookup_location(lookup_from)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          res: grpc.QueryResponse = self.grpc_points.Query(
              grpc.QueryPoints(
                  collection_name=collection_name,
                  query=query,
                  prefetch=prefetch,
                  filter=query_filter,
                  limit=limit,
                  offset=offset,
                  with_vectors=with_vectors,
                  with_payload=with_payload,
                  params=search_params,
                  score_threshold=score_threshold,
                  using=using,
                  lookup_from=lookup_from,
                  timeout=timeout,
                  shard_key_selector=shard_key_selector,
                  read_consistency=consistency,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          )
          scored_points = [GrpcToRest.convert_scored_point(hit) for hit in res.result]
          return models.QueryResponse(points=scored_points)
      else:
          if isinstance(query, grpc.Query):
              query = GrpcToRest.convert_query(query)
          if isinstance(prefetch, grpc.PrefetchQuery):
              prefetch = GrpcToRest.convert_prefetch_query(prefetch)
          if isinstance(prefetch, list):
              prefetch = [
                  GrpcToRest.convert_prefetch_query(p)
                  if isinstance(p, grpc.PrefetchQuery)
                  else p
                  for p in prefetch
              ]
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          if isinstance(lookup_from, grpc.LookupLocation):
              lookup_from = GrpcToRest.convert_lookup_location(lookup_from)
          query_request = models.QueryRequest(
              shard_key=shard_key_selector,
              prefetch=prefetch,
              query=query,
              using=using,
              filter=query_filter,
              params=search_params,
              score_threshold=score_threshold,
              limit=limit,
              offset=offset,
              with_vector=with_vectors,
              with_payload=with_payload,
              lookup_from=lookup_from,
          )
          query_result = self.http.search_api.query_points(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              query_request=query_request,
          )
          result: Optional[models.QueryResponse] = query_result.result
          assert result is not None, "Search returned None"
          return result
  def query_batch_points(
      self,
      collection_name: str,
      requests: Sequence[types.QueryRequest],
      consistency: Optional[types.ReadConsistency] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[types.QueryResponse]:
      """Run several query requests against one collection in a single call.

      Requests may be REST ``QueryRequest`` models or gRPC ``QueryPoints``
      messages. They are converted for the transport in use (gRPC when
      ``self._prefer_grpc`` is set, REST otherwise); anything else is passed
      on unchanged.

      Args:
          collection_name: Collection the requests run against.
          requests: The individual query requests.
          consistency: Read consistency for the batch.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          One ``QueryResponse`` per entry of the server's batch result.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if not self._prefer_grpc:
          rest_requests = []
          for request in requests:
              rest_requests.append(
                  GrpcToRest.convert_query_points(request)
                  if isinstance(request, grpc.QueryPoints)
                  else request
              )
          http_res: Optional[list[models.QueryResponse]] = (
              self.http.search_api.query_batch_points(
                  collection_name=collection_name,
                  consistency=consistency,
                  timeout=timeout,
                  query_request_batch=models.QueryRequestBatch(searches=rest_requests),
              ).result
          )
          assert http_res is not None, "Query batch returned None"
          return http_res

      grpc_requests = []
      for request in requests:
          grpc_requests.append(
              RestToGrpc.convert_query_request(request, collection_name)
              if isinstance(request, models.QueryRequest)
              else request
          )
      if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
          consistency = RestToGrpc.convert_read_consistency(consistency)
      batch_response: grpc.QueryBatchResponse = self.grpc_points.QueryBatch(
          grpc.QueryBatchPoints(
              collection_name=collection_name,
              query_points=grpc_requests,
              read_consistency=consistency,
              timeout=timeout,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      responses = []
      for batch_item in batch_response.result:
          points = [GrpcToRest.convert_scored_point(hit) for hit in batch_item.result]
          responses.append(models.QueryResponse(points=points))
      return responses
  def query_points_groups(
      self,
      collection_name: str,
      group_by: str,
      query: Union[
          types.PointId,
          list[float],
          list[list[float]],
          types.SparseVector,
          types.Query,
          types.NumpyArray,
          types.Document,
          types.Image,
          types.InferenceObject,
          None,
      ] = None,
      using: Optional[str] = None,
      prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
      query_filter: Optional[types.Filter] = None,
      search_params: Optional[types.SearchParams] = None,
      limit: int = 10,
      group_size: int = 3,
      with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      score_threshold: Optional[float] = None,
      with_lookup: Optional[types.WithLookupInterface] = None,
      lookup_from: Optional[types.LookupLocation] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.GroupsResult:
      """Query points grouped by a payload field, over gRPC or REST.

      When ``self._prefer_grpc`` is set the request goes to the gRPC
      ``QueryGroups`` endpoint; otherwise it goes to the REST
      ``query_points_groups`` endpoint. Arguments may be REST models or gRPC
      messages. They are converted to the form the selected transport
      expects; values already in that form pass through unchanged.

      Args:
          collection_name: Name of the collection to query.
          group_by: Forwarded as the request's ``group_by``.
          query: The query; ``None`` sends no query. A ``grpc.Query`` is
              accepted on both transports.
          prefetch: A single prefetch or a list of prefetches, REST or gRPC.
          with_lookup: A ``WithLookup`` (REST or gRPC). On gRPC a plain ``str``
              is turned into ``grpc.WithLookup(collection=...)``.
          using, query_filter, search_params, limit, group_size, with_payload,
          with_vectors, score_threshold, lookup_from, consistency,
          shard_key_selector: Converted if needed and forwarded.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          The groups result in REST form.

      Raises:
          AssertionError: If the REST endpoint returns no response or no result.
      """
      if self._prefer_grpc:
          # A grpc.Query is already in wire form (the REST branch accepts it
          # too), so only values that are not gRPC messages yet are converted.
          if query is not None and not isinstance(query, grpc.Query):
              query = RestToGrpc.convert_query(query)
          # QueryPointGroups receives a list of prefetches. Wrap a lone
          # prefetch of either flavour before converting the list.
          if isinstance(prefetch, (models.Prefetch, grpc.PrefetchQuery)):
              prefetch = [prefetch]
          if isinstance(prefetch, list):
              prefetch = [
                  RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p
                  for p in prefetch
              ]
          if isinstance(query_filter, models.Filter):
              query_filter = RestToGrpc.convert_filter(model=query_filter)
          if isinstance(search_params, models.SearchParams):
              search_params = RestToGrpc.convert_search_params(search_params)
          if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
              with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
          if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
              with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
          if isinstance(with_lookup, models.WithLookup):
              with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
          if isinstance(with_lookup, str):
              # A bare string is treated as the lookup collection name.
              with_lookup = grpc.WithLookup(collection=with_lookup)
          if isinstance(lookup_from, models.LookupLocation):
              lookup_from = RestToGrpc.convert_lookup_location(lookup_from)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          # ``.result`` of the QueryGroups response is the groups result itself.
          result: grpc.GroupsResult = self.grpc_points.QueryGroups(
              grpc.QueryPointGroups(
                  collection_name=collection_name,
                  query=query,
                  prefetch=prefetch,
                  filter=query_filter,
                  limit=limit,
                  with_vectors=with_vectors,
                  with_payload=with_payload,
                  params=search_params,
                  score_threshold=score_threshold,
                  using=using,
                  group_by=group_by,
                  group_size=group_size,
                  with_lookup=with_lookup,
                  lookup_from=lookup_from,
                  timeout=timeout,
                  shard_key_selector=shard_key_selector,
                  read_consistency=consistency,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          ).result
          return GrpcToRest.convert_groups_result(result)
      else:
          if isinstance(query, grpc.Query):
              query = GrpcToRest.convert_query(query)
          if isinstance(prefetch, grpc.PrefetchQuery):
              prefetch = GrpcToRest.convert_prefetch_query(prefetch)
          if isinstance(prefetch, list):
              prefetch = [
                  GrpcToRest.convert_prefetch_query(p)
                  if isinstance(p, grpc.PrefetchQuery)
                  else p
                  for p in prefetch
              ]
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          if isinstance(with_lookup, grpc.WithLookup):
              with_lookup = GrpcToRest.convert_with_lookup(with_lookup)
          if isinstance(lookup_from, grpc.LookupLocation):
              lookup_from = GrpcToRest.convert_lookup_location(lookup_from)
          query_request = models.QueryGroupsRequest(
              shard_key=shard_key_selector,
              prefetch=prefetch,
              query=query,
              using=using,
              filter=query_filter,
              params=search_params,
              score_threshold=score_threshold,
              limit=limit,
              group_by=group_by,
              group_size=group_size,
              with_vector=with_vectors,
              with_payload=with_payload,
              with_lookup=with_lookup,
              lookup_from=lookup_from,
          )
          query_result = self.http.search_api.query_points_groups(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              query_groups_request=query_request,
          )
          assert query_result is not None, "Query points groups API returned None"
          # Guard the payload too, as the sibling query methods do, so callers
          # never receive None in place of a GroupsResult.
          groups_result: Optional[types.GroupsResult] = query_result.result
          assert groups_result is not None, "Query points groups API returned None"
          return groups_result
  def search_groups(
      self,
      collection_name: str,
      query_vector: Union[
          Sequence[float],
          tuple[str, list[float]],
          types.NamedVector,
          types.NamedSparseVector,
          types.NumpyArray,
      ],
      group_by: str,
      query_filter: Optional[models.Filter] = None,
      search_params: Optional[models.SearchParams] = None,
      limit: int = 10,
      group_size: int = 1,
      with_payload: Union[bool, Sequence[str], models.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      score_threshold: Optional[float] = None,
      with_lookup: Optional[types.WithLookupInterface] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.GroupsResult:
      """Search for a vector and return hits grouped by a payload field.

      Uses the gRPC ``SearchGroups`` endpoint when ``self._prefer_grpc`` is
      set, and the REST ``search_point_groups`` endpoint otherwise.

      Args:
          collection_name: Name of the collection to search.
          query_vector: A plain vector, a ``(name, vector)`` tuple, a
              ``NamedVector``, a ``NamedSparseVector`` or a numpy array. On
              gRPC it is split into ``vector``, ``vector_name`` and (for sparse
              vectors) ``sparse_indices``. On REST a tuple becomes a
              ``NamedVector`` and a numpy array becomes a list.
          group_by: Forwarded as the request's ``group_by``.
          with_lookup: A ``WithLookup`` (REST or gRPC). On gRPC a plain ``str``
              becomes ``grpc.WithLookup(collection=...)``.
          query_filter, search_params, limit, group_size, with_payload,
          with_vectors, score_threshold, consistency, shard_key_selector:
              Converted if needed and forwarded.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          The groups result in REST form.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if self._prefer_grpc:
          vector_name = None
          sparse_indices = None
          if isinstance(with_lookup, models.WithLookup):
              with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
          if isinstance(with_lookup, str):
              with_lookup = grpc.WithLookup(collection=with_lookup)
          if isinstance(query_vector, types.NamedVector):
              vector = query_vector.vector
              vector_name = query_vector.name
          elif isinstance(query_vector, types.NamedSparseVector):
              vector_name = query_vector.name
              sparse_indices = grpc.SparseIndices(data=query_vector.vector.indices)
              vector = query_vector.vector.values
          elif isinstance(query_vector, tuple):
              vector_name = query_vector[0]
              vector = query_vector[1]
          else:
              vector = list(query_vector)
          if isinstance(query_filter, models.Filter):
              query_filter = RestToGrpc.convert_filter(model=query_filter)
          if isinstance(search_params, models.SearchParams):
              search_params = RestToGrpc.convert_search_params(search_params)
          if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
              with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
          if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
              with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          result: grpc.GroupsResult = self.grpc_points.SearchGroups(
              grpc.SearchPointGroups(
                  collection_name=collection_name,
                  vector=vector,
                  vector_name=vector_name,
                  filter=query_filter,
                  limit=limit,
                  group_size=group_size,
                  with_vectors=with_vectors,
                  with_payload=with_payload,
                  params=search_params,
                  score_threshold=score_threshold,
                  group_by=group_by,
                  read_consistency=consistency,
                  with_lookup=with_lookup,
                  timeout=timeout,
                  sparse_indices=sparse_indices,
                  shard_key_selector=shard_key_selector,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          ).result
          return GrpcToRest.convert_groups_result(result)
      else:
          if isinstance(with_lookup, grpc.WithLookup):
              with_lookup = GrpcToRest.convert_with_lookup(with_lookup)
          if isinstance(query_vector, tuple):
              query_vector = construct(
                  models.NamedVector,
                  name=query_vector[0],
                  vector=query_vector[1],
              )
          if isinstance(query_vector, np.ndarray):
              query_vector = query_vector.tolist()
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          search_groups_request = construct(
              models.SearchGroupsRequest,
              vector=query_vector,
              filter=query_filter,
              params=search_params,
              with_payload=with_payload,
              with_vector=with_vectors,
              score_threshold=score_threshold,
              group_by=group_by,
              group_size=group_size,
              limit=limit,
              with_lookup=with_lookup,
              shard_key=shard_key_selector,
          )
          groups_result = self.openapi_client.search_api.search_point_groups(
              search_groups_request=search_groups_request,
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
          ).result
          # Match the other REST branches: never hand back None as a GroupsResult.
          assert groups_result is not None, "Search groups API returned None"
          return groups_result
  def search_matrix_pairs(
      self,
      collection_name: str,
      query_filter: Optional[types.Filter] = None,
      limit: int = 3,
      sample: int = 10,
      using: Optional[str] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.SearchMatrixPairsResponse:
      """Request the search-matrix "pairs" view for a collection.

      Goes through the gRPC ``SearchMatrixPairs`` endpoint when
      ``self._prefer_grpc`` is set, and the REST ``search_matrix_pairs``
      endpoint otherwise.

      Args:
          collection_name: Collection to sample from.
          query_filter: Filter, as a REST model or gRPC message.
          limit, sample, using: Forwarded as the request fields of the same name.
          consistency: Read consistency; converted to gRPC form on gRPC only.
          shard_key_selector: Shard key selector; converted on gRPC only.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          The pairs response in REST form.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if not self._prefer_grpc:
          rest_filter = (
              GrpcToRest.convert_filter(model=query_filter)
              if isinstance(query_filter, grpc.Filter)
              else query_filter
          )
          pairs = self.openapi_client.search_api.search_matrix_pairs(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              search_matrix_request=models.SearchMatrixRequest(
                  shard_key=shard_key_selector,
                  limit=limit,
                  sample=sample,
                  using=using,
                  filter=rest_filter,
              ),
          ).result
          assert pairs is not None, "Search matrix pairs returned None result"
          return pairs

      grpc_filter = (
          RestToGrpc.convert_filter(model=query_filter)
          if isinstance(query_filter, models.Filter)
          else query_filter
      )
      grpc_shard_key = (
          RestToGrpc.convert_shard_key_selector(shard_key_selector)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector))
          else shard_key_selector
      )
      grpc_consistency = (
          RestToGrpc.convert_read_consistency(consistency)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency))
          else consistency
      )
      grpc_response = self.grpc_points.SearchMatrixPairs(
          grpc.SearchMatrixPoints(
              collection_name=collection_name,
              filter=grpc_filter,
              sample=sample,
              limit=limit,
              using=using,
              timeout=timeout,
              read_consistency=grpc_consistency,
              shard_key_selector=grpc_shard_key,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      return GrpcToRest.convert_search_matrix_pairs(grpc_response.result)
  def search_matrix_offsets(
      self,
      collection_name: str,
      query_filter: Optional[types.Filter] = None,
      limit: int = 3,
      sample: int = 10,
      using: Optional[str] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.SearchMatrixOffsetsResponse:
      """Request the search-matrix "offsets" view for a collection.

      Goes through the gRPC ``SearchMatrixOffsets`` endpoint when
      ``self._prefer_grpc`` is set, and the REST ``search_matrix_offsets``
      endpoint otherwise.

      Args:
          collection_name: Collection to sample from.
          query_filter: Filter, as a REST model or gRPC message.
          limit, sample, using: Forwarded as the request fields of the same name.
          consistency: Read consistency; converted to gRPC form on gRPC only.
          shard_key_selector: Shard key selector; converted on gRPC only.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          The offsets response in REST form.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if not self._prefer_grpc:
          rest_filter = (
              GrpcToRest.convert_filter(model=query_filter)
              if isinstance(query_filter, grpc.Filter)
              else query_filter
          )
          offsets = self.openapi_client.search_api.search_matrix_offsets(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              search_matrix_request=models.SearchMatrixRequest(
                  shard_key=shard_key_selector,
                  limit=limit,
                  sample=sample,
                  using=using,
                  filter=rest_filter,
              ),
          ).result
          assert offsets is not None, "Search matrix offsets returned None result"
          return offsets

      grpc_filter = (
          RestToGrpc.convert_filter(model=query_filter)
          if isinstance(query_filter, models.Filter)
          else query_filter
      )
      grpc_shard_key = (
          RestToGrpc.convert_shard_key_selector(shard_key_selector)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector))
          else shard_key_selector
      )
      grpc_consistency = (
          RestToGrpc.convert_read_consistency(consistency)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency))
          else consistency
      )
      grpc_response = self.grpc_points.SearchMatrixOffsets(
          grpc.SearchMatrixPoints(
              collection_name=collection_name,
              filter=grpc_filter,
              sample=sample,
              limit=limit,
              using=using,
              timeout=timeout,
              read_consistency=grpc_consistency,
              shard_key_selector=grpc_shard_key,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      return GrpcToRest.convert_search_matrix_offsets(grpc_response.result)
  def recommend_batch(
      self,
      collection_name: str,
      requests: Sequence[types.RecommendRequest],
      consistency: Optional[types.ReadConsistency] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[list[types.ScoredPoint]]:
      """Run several recommend requests against one collection in a single call.

      Requests may be REST ``RecommendRequest`` models or gRPC
      ``RecommendPoints`` messages. They are converted for the transport in
      use (gRPC when ``self._prefer_grpc`` is set, REST otherwise); anything
      else is passed on unchanged.

      Args:
          collection_name: Collection the requests run against.
          requests: The individual recommend requests.
          consistency: Read consistency for the batch.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          One list of scored points per entry of the server's batch result.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      if self._prefer_grpc:
          requests = [
              (
                  RestToGrpc.convert_recommend_request(r, collection_name)
                  if isinstance(r, models.RecommendRequest)
                  else r
              )
              for r in requests
          ]
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          grpc_res: grpc.RecommendBatchResponse = self.grpc_points.RecommendBatch(
              grpc.RecommendBatchPoints(
                  collection_name=collection_name,
                  recommend_points=requests,
                  read_consistency=consistency,
                  timeout=timeout,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          )
          return [
              [GrpcToRest.convert_scored_point(hit) for hit in r.result] for r in grpc_res.result
          ]
      else:
          requests = [
              (
                  GrpcToRest.convert_recommend_points(r)
                  if isinstance(r, grpc.RecommendPoints)
                  else r
              )
              for r in requests
          ]
          http_res: Optional[list[list[models.ScoredPoint]]] = (
              self.http.search_api.recommend_batch_points(
                  collection_name=collection_name,
                  consistency=consistency,
                  timeout=timeout,
                  recommend_request_batch=models.RecommendRequestBatch(searches=requests),
              ).result
          )
          # Same guard as query_batch_points: never return None for a list result.
          assert http_res is not None, "Recommend batch returned None"
          return http_res
  def recommend(
      self,
      collection_name: str,
      positive: Optional[Sequence[types.RecommendExample]] = None,
      negative: Optional[Sequence[types.RecommendExample]] = None,
      query_filter: Optional[types.Filter] = None,
      search_params: Optional[types.SearchParams] = None,
      limit: int = 10,
      offset: int = 0,
      with_payload: Union[bool, list[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, list[str]] = False,
      score_threshold: Optional[float] = None,
      using: Optional[str] = None,
      lookup_from: Optional[types.LookupLocation] = None,
      strategy: Optional[types.RecommendStrategy] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[types.ScoredPoint]:
      """Run a recommend request built from positive and negative examples.

      On gRPC (``self._prefer_grpc``) each example list is split into point
      ids and vectors for ``grpc.RecommendPoints``. On REST any
      ``grpc.PointId`` examples are converted to REST ids and the lists are
      sent as-is.

      Args:
          collection_name: Collection to query.
          positive: Positive examples; ``None`` means none.
          negative: Negative examples; ``None`` means none.
          query_filter, search_params, limit, offset, with_payload,
          with_vectors, score_threshold, using, lookup_from, strategy,
          consistency, shard_key_selector: Converted if needed and forwarded.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          Scored points in REST form.

      Raises:
          AssertionError: If the REST endpoint returns no result.
      """
      positive = [] if positive is None else positive
      negative = [] if negative is None else negative

      if not self._prefer_grpc:

          def _as_rest_examples(examples):
              # gRPC point ids become REST ids; other examples pass through.
              return [
                  GrpcToRest.convert_point_id(item) if isinstance(item, grpc.PointId) else item
                  for item in examples
              ]

          positive = _as_rest_examples(positive)
          negative = _as_rest_examples(negative)
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          if isinstance(lookup_from, grpc.LookupLocation):
              lookup_from = GrpcToRest.convert_lookup_location(lookup_from)
          scored = self.openapi_client.search_api.recommend_points(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              recommend_request=models.RecommendRequest(
                  filter=query_filter,
                  positive=positive,
                  negative=negative,
                  params=search_params,
                  limit=limit,
                  offset=offset,
                  with_payload=with_payload,
                  with_vector=with_vectors,
                  score_threshold=score_threshold,
                  lookup_from=lookup_from,
                  using=using,
                  strategy=strategy,
                  shard_key=shard_key_selector,
              ),
          ).result
          assert scored is not None, "Recommend points API returned None"
          return scored

      positive_ids = RestToGrpc.convert_recommend_examples_to_ids(positive)
      positive_vectors = RestToGrpc.convert_recommend_examples_to_vectors(positive)
      negative_ids = RestToGrpc.convert_recommend_examples_to_ids(negative)
      negative_vectors = RestToGrpc.convert_recommend_examples_to_vectors(negative)
      if isinstance(query_filter, models.Filter):
          query_filter = RestToGrpc.convert_filter(model=query_filter)
      if isinstance(search_params, models.SearchParams):
          search_params = RestToGrpc.convert_search_params(search_params)
      if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
          with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
      if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
          with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
      if isinstance(lookup_from, models.LookupLocation):
          lookup_from = RestToGrpc.convert_lookup_location(lookup_from)
      if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
          consistency = RestToGrpc.convert_read_consistency(consistency)
      if isinstance(strategy, (str, models.RecommendStrategy)):
          strategy = RestToGrpc.convert_recommend_strategy(strategy)
      if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
          shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
      grpc_response = self.grpc_points.Recommend(
          grpc.RecommendPoints(
              collection_name=collection_name,
              positive=positive_ids,
              negative=negative_ids,
              filter=query_filter,
              limit=limit,
              offset=offset,
              with_vectors=with_vectors,
              with_payload=with_payload,
              params=search_params,
              score_threshold=score_threshold,
              using=using,
              lookup_from=lookup_from,
              read_consistency=consistency,
              strategy=strategy,
              positive_vectors=positive_vectors,
              negative_vectors=negative_vectors,
              shard_key_selector=shard_key_selector,
              timeout=timeout,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      return [GrpcToRest.convert_scored_point(hit) for hit in grpc_response.result]
  def recommend_groups(
      self,
      collection_name: str,
      group_by: str,
      positive: Optional[Sequence[Union[types.PointId, list[float]]]] = None,
      negative: Optional[Sequence[Union[types.PointId, list[float]]]] = None,
      query_filter: Optional[models.Filter] = None,
      search_params: Optional[models.SearchParams] = None,
      limit: int = 10,
      group_size: int = 1,
      score_threshold: Optional[float] = None,
      with_payload: Union[bool, Sequence[str], models.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      using: Optional[str] = None,
      lookup_from: Optional[models.LookupLocation] = None,
      with_lookup: Optional[types.WithLookupInterface] = None,
      strategy: Optional[types.RecommendStrategy] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.GroupsResult:
      """Run a recommend request and return hits grouped by a payload field.

      Uses the gRPC ``RecommendGroups`` endpoint when ``self._prefer_grpc``
      is set, and the REST ``recommend_point_groups`` endpoint otherwise.

      Args:
          collection_name: Collection to query.
          group_by: Forwarded as the request's ``group_by``.
          positive: Positive examples (ids or vectors); ``None`` means none.
          negative: Negative examples (ids or vectors); ``None`` means none.
          with_lookup: A ``WithLookup`` (REST or gRPC). On gRPC a plain ``str``
              becomes ``grpc.WithLookup(collection=...)``.
          query_filter, search_params, limit, group_size, score_threshold,
          with_payload, with_vectors, using, lookup_from, strategy,
          consistency, shard_key_selector: Converted if needed and forwarded.
          timeout: Sent with the request. On gRPC it is also the call
              timeout, falling back to ``self._timeout`` when ``None``.
          **kwargs: Accepted for interface compatibility; not used here.

      Returns:
          The groups result in REST form.

      Raises:
          AssertionError: If the endpoint returns no result.
      """
      positive = positive if positive is not None else []
      negative = negative if negative is not None else []
      if self._prefer_grpc:
          if isinstance(with_lookup, models.WithLookup):
              with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
          if isinstance(with_lookup, str):
              with_lookup = grpc.WithLookup(collection=with_lookup)
          # gRPC carries id examples and vector examples in separate fields.
          positive_ids = RestToGrpc.convert_recommend_examples_to_ids(positive)
          positive_vectors = RestToGrpc.convert_recommend_examples_to_vectors(positive)
          negative_ids = RestToGrpc.convert_recommend_examples_to_ids(negative)
          negative_vectors = RestToGrpc.convert_recommend_examples_to_vectors(negative)
          if isinstance(query_filter, models.Filter):
              query_filter = RestToGrpc.convert_filter(model=query_filter)
          if isinstance(search_params, models.SearchParams):
              search_params = RestToGrpc.convert_search_params(search_params)
          if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
              with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
          if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
              with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
          if isinstance(lookup_from, models.LookupLocation):
              lookup_from = RestToGrpc.convert_lookup_location(lookup_from)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          if isinstance(strategy, (str, models.RecommendStrategy)):
              strategy = RestToGrpc.convert_recommend_strategy(strategy)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          res: grpc.GroupsResult = self.grpc_points.RecommendGroups(
              grpc.RecommendPointGroups(
                  collection_name=collection_name,
                  positive=positive_ids,
                  negative=negative_ids,
                  filter=query_filter,
                  group_by=group_by,
                  limit=limit,
                  group_size=group_size,
                  with_vectors=with_vectors,
                  with_payload=with_payload,
                  params=search_params,
                  score_threshold=score_threshold,
                  using=using,
                  lookup_from=lookup_from,
                  read_consistency=consistency,
                  with_lookup=with_lookup,
                  strategy=strategy,
                  positive_vectors=positive_vectors,
                  negative_vectors=negative_vectors,
                  shard_key_selector=shard_key_selector,
                  timeout=timeout,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          ).result
          assert res is not None, "Recommend groups API returned None"
          return GrpcToRest.convert_groups_result(res)
      else:
          if isinstance(with_lookup, grpc.WithLookup):
              with_lookup = GrpcToRest.convert_with_lookup(with_lookup)
          positive = [
              (
                  GrpcToRest.convert_point_id(point_id)
                  if isinstance(point_id, grpc.PointId)
                  else point_id
              )
              for point_id in positive
          ]
          negative = [
              (
                  GrpcToRest.convert_point_id(point_id)
                  if isinstance(point_id, grpc.PointId)
                  else point_id
              )
              for point_id in negative
          ]
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          if isinstance(lookup_from, grpc.LookupLocation):
              lookup_from = GrpcToRest.convert_lookup_location(lookup_from)
          result = self.openapi_client.search_api.recommend_point_groups(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              recommend_groups_request=construct(
                  models.RecommendGroupsRequest,
                  positive=positive,
                  negative=negative,
                  filter=query_filter,
                  group_by=group_by,
                  limit=limit,
                  group_size=group_size,
                  params=search_params,
                  with_payload=with_payload,
                  with_vector=with_vectors,
                  score_threshold=score_threshold,
                  lookup_from=lookup_from,
                  using=using,
                  with_lookup=with_lookup,
                  strategy=strategy,
                  shard_key=shard_key_selector,
              ),
          ).result
          # Name the endpoint that actually failed (this is the groups API).
          assert result is not None, "Recommend groups API returned None"
          return result
  def discover(
      self,
      collection_name: str,
      target: Optional[types.TargetVector] = None,
      context: Optional[Sequence[types.ContextExamplePair]] = None,
      query_filter: Optional[types.Filter] = None,
      search_params: Optional[types.SearchParams] = None,
      limit: int = 10,
      offset: int = 0,
      with_payload: Union[bool, list[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, list[str]] = False,
      using: Optional[str] = None,
      lookup_from: Optional[types.LookupLocation] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[types.ScoredPoint]:
      """Run a discovery search against ``collection_name``.

      Arguments may be supplied as REST models or as gRPC messages; they are
      converted to the representation of the active transport (gRPC when
      ``self._prefer_grpc`` is set, REST otherwise). ``context=None`` is treated as
      an empty list of context pairs.

      Returns:
          The hits as REST ``ScoredPoint`` models.

      Raises:
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      pairs = context if context is not None else []

      if not self._prefer_grpc:
          # REST transport: turn any gRPC inputs back into REST models.
          if target is not None and isinstance(target, grpc.TargetVector):
              target = GrpcToRest.convert_target_vector(target)
          rest_pairs = []
          for pair in pairs:
              if isinstance(pair, grpc.ContextExamplePair):
                  pair = GrpcToRest.convert_context_example_pair(pair)
              rest_pairs.append(pair)
          if isinstance(query_filter, grpc.Filter):
              query_filter = GrpcToRest.convert_filter(model=query_filter)
          if isinstance(search_params, grpc.SearchParams):
              search_params = GrpcToRest.convert_search_params(search_params)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          if isinstance(lookup_from, grpc.LookupLocation):
              lookup_from = GrpcToRest.convert_lookup_location(lookup_from)

          request = models.DiscoverRequest(
              target=target,
              context=rest_pairs,
              filter=query_filter,
              params=search_params,
              limit=limit,
              offset=offset,
              with_payload=with_payload,
              with_vector=with_vectors,
              lookup_from=lookup_from,
              using=using,
              shard_key=shard_key_selector,
          )
          hits = self.openapi_client.search_api.discover_points(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              discover_request=request,
          ).result
          assert hits is not None, "Discover points API returned None"
          return hits

      # gRPC transport: turn any REST inputs into protobuf messages.
      if target is not None and isinstance(
          target, get_args_subscribed(models.RecommendExample)
      ):
          target = RestToGrpc.convert_target_vector(target)
      grpc_pairs = [
          RestToGrpc.convert_context_example_pair(pair)
          if isinstance(pair, models.ContextExamplePair)
          else pair
          for pair in pairs
      ]
      if isinstance(query_filter, models.Filter):
          query_filter = RestToGrpc.convert_filter(model=query_filter)
      if isinstance(search_params, models.SearchParams):
          search_params = RestToGrpc.convert_search_params(search_params)
      if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
          with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
      if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
          with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
      if isinstance(lookup_from, models.LookupLocation):
          lookup_from = RestToGrpc.convert_lookup_location(lookup_from)
      if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
          consistency = RestToGrpc.convert_read_consistency(consistency)
      if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
          shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

      response: grpc.SearchResponse = self.grpc_points.Discover(
          grpc.DiscoverPoints(
              collection_name=collection_name,
              target=target,
              context=grpc_pairs,
              filter=query_filter,
              limit=limit,
              offset=offset,
              with_vectors=with_vectors,
              with_payload=with_payload,
              params=search_params,
              using=using,
              lookup_from=lookup_from,
              read_consistency=consistency,
              shard_key_selector=shard_key_selector,
              timeout=timeout,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      return list(map(GrpcToRest.convert_scored_point, response.result))
  def discover_batch(
      self,
      collection_name: str,
      requests: Sequence[types.DiscoverRequest],
      consistency: Optional[types.ReadConsistency] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[list[types.ScoredPoint]]:
      """Execute several discovery requests against ``collection_name`` in one call.

      Each request may be a REST ``DiscoverRequest`` or a gRPC ``DiscoverPoints``; it
      is converted to the active transport. On the gRPC path a REST read-consistency
      value is converted as well, the same way the single-request read methods do.

      Returns:
          One list of REST ``ScoredPoint`` hits for each entry in the batch response.
      """
      if self._prefer_grpc:
          requests = [
              (
                  RestToGrpc.convert_discover_request(r, collection_name)
                  if isinstance(r, models.DiscoverRequest)
                  else r
              )
              for r in requests
          ]
          # The protobuf message expects a gRPC ReadConsistency; a REST value must be
          # converted first (previously it was passed through raw).
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          grpc_res: grpc.SearchBatchResponse = self.grpc_points.DiscoverBatch(
              grpc.DiscoverBatchPoints(
                  collection_name=collection_name,
                  discover_points=requests,
                  read_consistency=consistency,
                  timeout=timeout,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          )
          return [
              [GrpcToRest.convert_scored_point(hit) for hit in r.result] for r in grpc_res.result
          ]
      else:
          requests = [
              (
                  GrpcToRest.convert_discover_points(r)
                  if isinstance(r, grpc.DiscoverPoints)
                  else r
              )
              for r in requests
          ]
          http_res: list[list[models.ScoredPoint]] = self.http.search_api.discover_batch_points(
              collection_name=collection_name,
              discover_request_batch=models.DiscoverRequestBatch(searches=requests),
              consistency=consistency,
              timeout=timeout,
          ).result
          return http_res
  def scroll(
      self,
      collection_name: str,
      scroll_filter: Optional[types.Filter] = None,
      limit: int = 10,
      order_by: Optional[types.OrderBy] = None,
      offset: Optional[types.PointId] = None,
      with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> tuple[list[types.Record], Optional[types.PointId]]:
      """Fetch one page of points from ``collection_name``.

      Inputs may be REST models or gRPC messages and are converted to the active
      transport.

      Returns:
          ``(records, next_offset)``: the page as REST ``Record`` models, and the
          offset reported for the next page, or ``None`` when the response carries
          no next-page offset.

      Raises:
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if not self._prefer_grpc:
          if isinstance(offset, grpc.PointId):
              offset = GrpcToRest.convert_point_id(offset)
          if isinstance(scroll_filter, grpc.Filter):
              scroll_filter = GrpcToRest.convert_filter(model=scroll_filter)
          if isinstance(order_by, grpc.OrderBy):
              order_by = GrpcToRest.convert_order_by(order_by)
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

          scroll_request = models.ScrollRequest(
              filter=scroll_filter,
              limit=limit,
              order_by=order_by,
              offset=offset,
              with_payload=with_payload,
              with_vector=with_vectors,
              shard_key=shard_key_selector,
          )
          page: Optional[models.ScrollResult] = self.openapi_client.points_api.scroll_points(
              collection_name=collection_name,
              consistency=consistency,
              scroll_request=scroll_request,
              timeout=timeout,
          ).result
          assert page is not None, "Scroll points API returned None result"
          return page.points, page.next_page_offset

      if isinstance(offset, get_args_subscribed(models.ExtendedPointId)):
          offset = RestToGrpc.convert_extended_point_id(offset)
      if isinstance(scroll_filter, models.Filter):
          scroll_filter = RestToGrpc.convert_filter(model=scroll_filter)
      if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
          with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
      if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
          with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
      if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
          consistency = RestToGrpc.convert_read_consistency(consistency)
      if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
          shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
      if isinstance(order_by, get_args_subscribed(models.OrderByInterface)):
          order_by = RestToGrpc.convert_order_by_interface(order_by)

      response: grpc.ScrollResponse = self.grpc_points.Scroll(
          grpc.ScrollPoints(
              collection_name=collection_name,
              filter=scroll_filter,
              order_by=order_by,
              offset=offset,
              with_vectors=with_vectors,
              with_payload=with_payload,
              limit=limit,
              read_consistency=consistency,
              shard_key_selector=shard_key_selector,
              timeout=timeout,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      records = list(map(GrpcToRest.convert_retrieved_point, response.result))
      # ``next_page_offset`` is an optional protobuf field; HasField tells set from unset.
      if response.HasField("next_page_offset"):
          next_offset = GrpcToRest.convert_point_id(response.next_page_offset)
      else:
          next_offset = None
      return records, next_offset
  def count(
      self,
      collection_name: str,
      count_filter: Optional[types.Filter] = None,
      exact: bool = True,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> types.CountResult:
      """Count the points of ``collection_name`` matching ``count_filter``.

      ``exact`` is forwarded to the server unchanged. The filter may be a REST model
      or a gRPC message; it is converted to the active transport.

      Returns:
          A REST ``CountResult``.

      Raises:
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if not self._prefer_grpc:
          if isinstance(count_filter, grpc.Filter):
              count_filter = GrpcToRest.convert_filter(model=count_filter)
          request = models.CountRequest(
              filter=count_filter,
              exact=exact,
              shard_key=shard_key_selector,
          )
          counted = self.openapi_client.points_api.count_points(
              collection_name=collection_name,
              count_request=request,
              timeout=timeout,
          ).result
          assert counted is not None, "Count points returned None result"
          return counted

      if isinstance(count_filter, models.Filter):
          count_filter = RestToGrpc.convert_filter(model=count_filter)
      if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
          shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
      grpc_count = self.grpc_points.Count(
          grpc.CountPoints(
              collection_name=collection_name,
              filter=count_filter,
              exact=exact,
              shard_key_selector=shard_key_selector,
              timeout=timeout,
          ),
          timeout=self._timeout if timeout is None else timeout,
      ).result
      return GrpcToRest.convert_count_result(grpc_count)
  def facet(
      self,
      collection_name: str,
      key: str,
      facet_filter: Optional[types.Filter] = None,
      limit: int = 10,
      exact: bool = False,
      timeout: Optional[int] = None,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.FacetResponse:
      """Request facet counts for payload ``key`` in ``collection_name``.

      ``facet_filter`` may be a REST model or a gRPC message. On the gRPC path the
      REST shard key selector and read consistency are converted as well.

      Returns:
          A ``FacetResponse`` whose ``hits`` are REST facet-value hits.

      Raises:
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if not self._prefer_grpc:
          if isinstance(facet_filter, grpc.Filter):
              facet_filter = GrpcToRest.convert_filter(model=facet_filter)
          request = models.FacetRequest(
              key=key,
              limit=limit,
              filter=facet_filter,
              exact=exact,
              shard_key=shard_key_selector,
          )
          facet_result = self.openapi_client.points_api.facet(
              collection_name=collection_name,
              consistency=consistency,
              timeout=timeout,
              facet_request=request,
          ).result
          assert facet_result is not None, "Facet points returned None result"
          return facet_result

      if isinstance(facet_filter, models.Filter):
          facet_filter = RestToGrpc.convert_filter(model=facet_filter)
      if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
          shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
      if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
          consistency = RestToGrpc.convert_read_consistency(consistency)
      response = self.grpc_points.Facet(
          grpc.FacetCounts(
              collection_name=collection_name,
              key=key,
              filter=facet_filter,
              limit=limit,
              exact=exact,
              timeout=timeout,
              read_consistency=consistency,
              shard_key_selector=shard_key_selector,
          ),
          timeout=self._timeout if timeout is None else timeout,
      )
      hits = [GrpcToRest.convert_facet_value_hit(hit) for hit in response.hits]
      return types.FacetResponse(hits=hits)
  def upsert(
      self,
      collection_name: str,
      points: types.Points,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Upsert ``points`` into ``collection_name``.

      ``points`` is either a REST ``Batch`` or a sequence (list, tuple, ...) whose
      items are REST or gRPC ``PointStruct`` objects. Everything is converted to the
      active transport's representation before sending.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          AssertionError: if the server responds with a ``None`` result.
      """
      # The branches below only recognise ``list``; materialise other sequences
      # (e.g. tuples) so their items are converted instead of being sent as-is.
      if not isinstance(points, (list, models.Batch)) and isinstance(points, Sequence):
          points = list(points)

      if self._prefer_grpc:
          if isinstance(points, models.Batch):
              vectors_batch: list[grpc.Vectors] = RestToGrpc.convert_batch_vector_struct(
                  points.vectors, len(points.ids)
              )
              # Explode the columnar batch into one gRPC PointStruct per id.
              points = [
                  grpc.PointStruct(
                      id=RestToGrpc.convert_extended_point_id(point_id),
                      vectors=vectors_batch[idx],
                      payload=(
                          RestToGrpc.convert_payload(points.payloads[idx])
                          if points.payloads is not None
                          else None
                      ),
                  )
                  for idx, point_id in enumerate(points.ids)
              ]
          if isinstance(points, list):
              points = [
                  (
                      RestToGrpc.convert_point_struct(point)
                      if isinstance(point, models.PointStruct)
                      else point
                  )
                  for point in points
              ]
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          grpc_result = self.grpc_points.Upsert(
              grpc.UpsertPoints(
                  collection_name=collection_name,
                  wait=wait,
                  points=points,
                  ordering=ordering,
                  shard_key_selector=shard_key_selector,
              ),
              timeout=self._timeout,
          ).result
          assert grpc_result is not None, "Upsert returned None result"
          return GrpcToRest.convert_update_result(grpc_result)
      else:
          if isinstance(points, list):
              points = [
                  (
                      GrpcToRest.convert_point_struct(point)
                      if isinstance(point, grpc.PointStruct)
                      else point
                  )
                  for point in points
              ]
              points = models.PointsList(points=points, shard_key=shard_key_selector)
          if isinstance(points, models.Batch):
              points = models.PointsBatch(batch=points, shard_key=shard_key_selector)
          http_result = self.openapi_client.points_api.upsert_points(
              collection_name=collection_name,
              wait=wait,
              point_insert_operations=points,
              ordering=ordering,
          ).result
          assert http_result is not None, "Upsert returned None result"
          return http_result
  def update_vectors(
      self,
      collection_name: str,
      points: Sequence[types.PointVectors],
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Send the vectors in ``points`` to the server's update-vectors operation.

      On the gRPC path, REST ``PointVectors`` items are converted and gRPC items are
      passed through unchanged; the write ordering and shard key selector are
      converted from REST models as well.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          AssertionError: if the server responds with a ``None`` result.
      """
      if self._prefer_grpc:
          # Only REST models need converting; already-gRPC items must not be fed to
          # the REST->gRPC converter.
          points = [
              (
                  RestToGrpc.convert_point_vectors(point)
                  if isinstance(point, models.PointVectors)
                  else point
              )
              for point in points
          ]
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          grpc_result = self.grpc_points.UpdateVectors(
              grpc.UpdatePointVectors(
                  collection_name=collection_name,
                  wait=wait,
                  points=points,
                  ordering=ordering,
                  shard_key_selector=shard_key_selector,
              ),
              timeout=self._timeout,
          ).result
          assert grpc_result is not None, "Update vectors returned None result"
          return GrpcToRest.convert_update_result(grpc_result)
      else:
          http_result = self.openapi_client.points_api.update_vectors(
              collection_name=collection_name,
              wait=wait,
              update_vectors=models.UpdateVectors(
                  points=points,
                  shard_key=shard_key_selector,
              ),
              ordering=ordering,
          ).result
          assert http_result is not None, "Update vectors returned None result"
          return http_result
  def delete_vectors(
      self,
      collection_name: str,
      vectors: Sequence[str],
      points: types.PointsSelector,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Delete the named ``vectors`` from the points selected by ``points``.

      ``points`` may be a list of ids, a REST/gRPC filter or a REST/gRPC points
      selector. An explicit ``shard_key_selector`` takes precedence; otherwise a
      shard key embedded in a REST selector model is used, on both transports.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
          AssertionError: if the server responds with a ``None`` result.
      """
      if self._prefer_grpc:
          points_selector, embedded_shard_key = self._try_argument_to_grpc_selector(points)
          # Compare with None rather than truthiness: a falsy but valid selector
          # (e.g. shard key ``0``) must not be replaced by the embedded one.
          if shard_key_selector is None:
              shard_key_selector = embedded_shard_key
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          grpc_result = self.grpc_points.DeleteVectors(
              grpc.DeletePointVectors(
                  collection_name=collection_name,
                  wait=wait,
                  vectors=grpc.VectorsSelector(
                      names=vectors,
                  ),
                  points_selector=points_selector,
                  ordering=ordering,
                  shard_key_selector=shard_key_selector,
              ),
              timeout=self._timeout,
          ).result
          assert grpc_result is not None, "Delete vectors returned None result"
          return GrpcToRest.convert_update_result(grpc_result)
      else:
          _points, _filter = self._try_argument_to_rest_points_and_filter(points)
          # The ids/filter split drops the selector's own shard key; recover it so the
          # REST path matches the gRPC path's fallback.
          if shard_key_selector is None and isinstance(points, get_args(models.PointsSelector)):
              shard_key_selector = points.shard_key
          http_result = self.openapi_client.points_api.delete_vectors(
              collection_name=collection_name,
              wait=wait,
              ordering=ordering,
              delete_vectors=construct(
                  models.DeleteVectors,
                  vector=vectors,
                  points=_points,
                  filter=_filter,
                  shard_key=shard_key_selector,
              ),
          ).result
          assert http_result is not None, "Delete vectors returned None result"
          return http_result
  def retrieve(
      self,
      collection_name: str,
      ids: Sequence[types.PointId],
      with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
      with_vectors: Union[bool, Sequence[str]] = False,
      consistency: Optional[types.ReadConsistency] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      timeout: Optional[int] = None,
      **kwargs: Any,
  ) -> list[types.Record]:
      """Fetch the points with the given ``ids`` from ``collection_name``.

      Ids and selectors may be REST models or gRPC messages; they are converted to
      the active transport.

      Returns:
          The fetched points as REST ``Record`` models.

      Raises:
          AssertionError: if the server responds with a ``None`` result.
      """
      if self._prefer_grpc:
          if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
              with_payload = RestToGrpc.convert_with_payload_interface(with_payload)
          ids = [
              (
                  RestToGrpc.convert_extended_point_id(idx)
                  if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
                  else idx
              )
              for idx in ids
          ]
          # An already-gRPC selector must be passed through, not re-converted as if it
          # were a REST value.
          if not isinstance(with_vectors, grpc.WithVectorsSelector):
              with_vectors = RestToGrpc.convert_with_vectors(with_vectors)
          if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
              consistency = RestToGrpc.convert_read_consistency(consistency)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          result = self.grpc_points.Get(
              grpc.GetPoints(
                  collection_name=collection_name,
                  ids=ids,
                  with_payload=with_payload,
                  with_vectors=with_vectors,
                  read_consistency=consistency,
                  shard_key_selector=shard_key_selector,
                  timeout=timeout,
              ),
              timeout=timeout if timeout is not None else self._timeout,
          ).result
          assert result is not None, "Retrieve returned None result"
          return [GrpcToRest.convert_retrieved_point(record) for record in result]
      else:
          if isinstance(with_payload, grpc.WithPayloadSelector):
              with_payload = GrpcToRest.convert_with_payload_selector(with_payload)
          ids = [
              (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx)
              for idx in ids
          ]
          http_result = self.openapi_client.points_api.get_points(
              collection_name=collection_name,
              consistency=consistency,
              point_request=models.PointRequest(
                  ids=ids,
                  with_payload=with_payload,
                  with_vector=with_vectors,
                  shard_key=shard_key_selector,
              ),
              timeout=timeout,
          ).result
          assert http_result is not None, "Retrieve API returned None result"
          return http_result
  @classmethod
  def _try_argument_to_grpc_selector(
      cls, points: types.PointsSelector
  ) -> tuple[grpc.PointsSelector, Optional[grpc.ShardKeySelector]]:
      """Normalise any supported points-selector argument into a gRPC selector.

      Returns:
          ``(selector, shard_key_selector)``. The second element is the converted
          ``shard_key`` of a REST points-selector model (``None`` if it has none),
          and ``None`` for every other input form.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
      """
      if isinstance(points, list):
          grpc_ids = []
          for point_id in points:
              if isinstance(point_id, get_args_subscribed(models.ExtendedPointId)):
                  point_id = RestToGrpc.convert_extended_point_id(point_id)
              grpc_ids.append(point_id)
          return grpc.PointsSelector(points=grpc.PointsIdsList(ids=grpc_ids)), None

      if isinstance(points, grpc.PointsSelector):
          return points, None

      if isinstance(points, get_args(models.PointsSelector)):
          embedded_shard_key = None
          if points.shard_key is not None:
              embedded_shard_key = RestToGrpc.convert_shard_key_selector(points.shard_key)
          return RestToGrpc.convert_points_selector(points), embedded_shard_key

      if isinstance(points, models.Filter):
          wrapped = construct(models.FilterSelector, filter=points)
          return RestToGrpc.convert_points_selector(wrapped), None

      if isinstance(points, grpc.Filter):
          return grpc.PointsSelector(filter=points), None

      raise ValueError(f"Unsupported points selector type: {type(points)}")
  @classmethod
  def _try_argument_to_rest_selector(
      cls,
      points: types.PointsSelector,
      shard_key_selector: Optional[types.ShardKeySelector],
  ) -> models.PointsSelector:
      """Normalise a points-selector argument into a REST ``PointsSelector`` model.

      ``shard_key_selector`` is attached to the result. For a REST selector model
      that already carries a shard key, that key is kept when ``shard_key_selector``
      is ``None`` (mirroring the gRPC path's fallback), and the caller's model is
      copied rather than mutated.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
      """
      import copy

      if isinstance(points, list):
          _points = [
              (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx)
              for idx in points
          ]
          points_selector = construct(
              models.PointIdsList,
              points=_points,
              shard_key=shard_key_selector,
          )
      elif isinstance(points, grpc.PointsSelector):
          # Freshly converted object; safe to set the attribute directly.
          points_selector = GrpcToRest.convert_points_selector(points)
          points_selector.shard_key = shard_key_selector
      elif isinstance(points, get_args(models.PointsSelector)):
          # Shallow copy so the caller's selector object is not modified as a side
          # effect; keep its own shard key unless an explicit one was given.
          points_selector = copy.copy(points)
          if shard_key_selector is not None:
              points_selector.shard_key = shard_key_selector
      elif isinstance(points, models.Filter):
          points_selector = construct(
              models.FilterSelector, filter=points, shard_key=shard_key_selector
          )
      elif isinstance(points, grpc.Filter):
          points_selector = construct(
              models.FilterSelector,
              filter=GrpcToRest.convert_filter(points),
              shard_key=shard_key_selector,
          )
      else:
          raise ValueError(f"Unsupported points selector type: {type(points)}")
      return points_selector
  @classmethod
  def _points_selector_to_points_list(
      cls, points_selector: grpc.PointsSelector
  ) -> list[grpc.PointId]:
      """Return the explicit point ids held by a gRPC ``PointsSelector``.

      An empty list is returned when no variant of the ``points_selector_one_of``
      oneof is set, or when the set variant is not the ``points`` id list.
      """
      if points_selector.WhichOneof("points_selector_one_of") != "points":
          return []
      return list(points_selector.points.ids)
  @classmethod
  def _try_argument_to_rest_points_and_filter(
      cls, points: types.PointsSelector
  ) -> tuple[Optional[list[models.ExtendedPointId]], Optional[models.Filter]]:
      """Split a points-selector argument into ``(point_ids, filter)`` for REST bodies.

      At most one element of the returned pair is non-``None``. Any shard key
      carried by a selector model is not part of the result.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
      """
      if isinstance(points, list):
          rest_ids = [
              GrpcToRest.convert_point_id(point_id)
              if isinstance(point_id, grpc.PointId)
              else point_id
              for point_id in points
          ]
          return rest_ids, None

      if isinstance(points, grpc.PointsSelector):
          converted = GrpcToRest.convert_points_selector(points)
          if isinstance(converted, models.PointIdsList):
              return converted.points, None
          if isinstance(converted, models.FilterSelector):
              return None, converted.filter
          return None, None

      if isinstance(points, models.PointIdsList):
          return points.points, None
      if isinstance(points, models.FilterSelector):
          return None, points.filter
      if isinstance(points, models.Filter):
          return None, points
      if isinstance(points, grpc.Filter):
          return None, GrpcToRest.convert_filter(points)

      raise ValueError(f"Unsupported points selector type: {type(points)}")
  def delete(
      self,
      collection_name: str,
      points_selector: types.PointsSelector,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Delete the points selected by ``points_selector`` from ``collection_name``.

      ``points_selector`` may be a list of ids, a REST/gRPC filter or a REST/gRPC
      points selector. On the gRPC path an explicit ``shard_key_selector`` takes
      precedence over a shard key embedded in a REST selector model.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          ValueError: if ``points_selector`` is of an unsupported type.
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if self._prefer_grpc:
          points_selector, embedded_shard_key = self._try_argument_to_grpc_selector(
              points_selector
          )
          # Compare with None rather than truthiness: a falsy but valid selector
          # (e.g. shard key ``0``) must not be replaced by the embedded one.
          if shard_key_selector is None:
              shard_key_selector = embedded_shard_key
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          return GrpcToRest.convert_update_result(
              self.grpc_points.Delete(
                  grpc.DeletePoints(
                      collection_name=collection_name,
                      wait=wait,
                      points=points_selector,
                      ordering=ordering,
                      shard_key_selector=shard_key_selector,
                  ),
                  timeout=self._timeout,
              ).result
          )
      else:
          points_selector = self._try_argument_to_rest_selector(
              points_selector, shard_key_selector
          )
          result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_points(
              collection_name=collection_name,
              wait=wait,
              points_selector=points_selector,
              ordering=ordering,
          ).result
          assert result is not None, "Delete points returned None"
          return result
  def set_payload(
      self,
      collection_name: str,
      payload: types.Payload,
      points: types.PointsSelector,
      key: Optional[str] = None,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Set ``payload`` on the points selected by ``points`` (optionally under ``key``).

      ``points`` may be a list of ids, a REST/gRPC filter or a REST/gRPC points
      selector. An explicit ``shard_key_selector`` takes precedence; otherwise a
      shard key embedded in a REST selector model is used, on both transports.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if self._prefer_grpc:
          points_selector, embedded_shard_key = self._try_argument_to_grpc_selector(points)
          # Compare with None rather than truthiness: a falsy but valid selector
          # (e.g. shard key ``0``) must not be replaced by the embedded one.
          if shard_key_selector is None:
              shard_key_selector = embedded_shard_key
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          return GrpcToRest.convert_update_result(
              self.grpc_points.SetPayload(
                  grpc.SetPayloadPoints(
                      collection_name=collection_name,
                      wait=wait,
                      payload=RestToGrpc.convert_payload(payload),
                      points_selector=points_selector,
                      ordering=ordering,
                      shard_key_selector=shard_key_selector,
                      key=key,
                  ),
                  timeout=self._timeout,
              ).result
          )
      else:
          _points, _filter = self._try_argument_to_rest_points_and_filter(points)
          # The ids/filter split drops the selector's own shard key; recover it so the
          # REST path matches the gRPC path's fallback.
          if shard_key_selector is None and isinstance(points, get_args(models.PointsSelector)):
              shard_key_selector = points.shard_key
          result: Optional[types.UpdateResult] = self.openapi_client.points_api.set_payload(
              collection_name=collection_name,
              wait=wait,
              ordering=ordering,
              set_payload=models.SetPayload(
                  payload=payload,
                  points=_points,
                  filter=_filter,
                  shard_key=shard_key_selector,
                  key=key,
              ),
          ).result
          assert result is not None, "Set payload returned None"
          return result
  def overwrite_payload(
      self,
      collection_name: str,
      payload: types.Payload,
      points: types.PointsSelector,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Send ``payload`` to the overwrite-payload operation for the selected points.

      ``points`` may be a list of ids, a REST/gRPC filter or a REST/gRPC points
      selector. An explicit ``shard_key_selector`` takes precedence; otherwise a
      shard key embedded in a REST selector model is used, on both transports.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if self._prefer_grpc:
          points_selector, embedded_shard_key = self._try_argument_to_grpc_selector(points)
          # Compare with None rather than truthiness: a falsy but valid selector
          # (e.g. shard key ``0``) must not be replaced by the embedded one.
          if shard_key_selector is None:
              shard_key_selector = embedded_shard_key
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          return GrpcToRest.convert_update_result(
              self.grpc_points.OverwritePayload(
                  grpc.SetPayloadPoints(
                      collection_name=collection_name,
                      wait=wait,
                      payload=RestToGrpc.convert_payload(payload),
                      points_selector=points_selector,
                      ordering=ordering,
                      shard_key_selector=shard_key_selector,
                  ),
                  timeout=self._timeout,
              ).result
          )
      else:
          _points, _filter = self._try_argument_to_rest_points_and_filter(points)
          # The ids/filter split drops the selector's own shard key; recover it so the
          # REST path matches the gRPC path's fallback.
          if shard_key_selector is None and isinstance(points, get_args(models.PointsSelector)):
              shard_key_selector = points.shard_key
          result: Optional[types.UpdateResult] = (
              self.openapi_client.points_api.overwrite_payload(
                  collection_name=collection_name,
                  wait=wait,
                  ordering=ordering,
                  set_payload=models.SetPayload(
                      payload=payload,
                      points=_points,
                      filter=_filter,
                      shard_key=shard_key_selector,
                  ),
              ).result
          )
          assert result is not None, "Overwrite payload returned None"
          return result
  def delete_payload(
      self,
      collection_name: str,
      keys: Sequence[str],
      points: types.PointsSelector,
      wait: bool = True,
      ordering: Optional[types.WriteOrdering] = None,
      shard_key_selector: Optional[types.ShardKeySelector] = None,
      **kwargs: Any,
  ) -> types.UpdateResult:
      """Delete the payload ``keys`` from the points selected by ``points``.

      ``points`` may be a list of ids, a REST/gRPC filter or a REST/gRPC points
      selector. An explicit ``shard_key_selector`` takes precedence; otherwise a
      shard key embedded in a REST selector model is used, on both transports.

      Returns:
          The REST ``UpdateResult`` reported by the server.

      Raises:
          ValueError: if ``points`` is of an unsupported type.
          AssertionError: if the REST endpoint responds with a ``None`` result.
      """
      if self._prefer_grpc:
          points_selector, embedded_shard_key = self._try_argument_to_grpc_selector(points)
          # Compare with None rather than truthiness: a falsy but valid selector
          # (e.g. shard key ``0``) must not be replaced by the embedded one.
          if shard_key_selector is None:
              shard_key_selector = embedded_shard_key
          if isinstance(ordering, models.WriteOrdering):
              ordering = RestToGrpc.convert_write_ordering(ordering)
          if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
              shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
          return GrpcToRest.convert_update_result(
              self.grpc_points.DeletePayload(
                  grpc.DeletePayloadPoints(
                      collection_name=collection_name,
                      wait=wait,
                      keys=keys,
                      points_selector=points_selector,
                      ordering=ordering,
                      shard_key_selector=shard_key_selector,
                  ),
                  timeout=self._timeout,
              ).result
          )
      else:
          _points, _filter = self._try_argument_to_rest_points_and_filter(points)
          # The ids/filter split drops the selector's own shard key; recover it so the
          # REST path matches the gRPC path's fallback.
          if shard_key_selector is None and isinstance(points, get_args(models.PointsSelector)):
              shard_key_selector = points.shard_key
          result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_payload(
              collection_name=collection_name,
              wait=wait,
              ordering=ordering,
              delete_payload=models.DeletePayload(
                  keys=keys,
                  points=_points,
                  filter=_filter,
                  shard_key=shard_key_selector,
              ),
          ).result
          assert result is not None, "Delete payload returned None"
          return result
  2091. def clear_payload(
  2092. self,
  2093. collection_name: str,
  2094. points_selector: types.PointsSelector,
  2095. wait: bool = True,
  2096. ordering: Optional[types.WriteOrdering] = None,
  2097. shard_key_selector: Optional[types.ShardKeySelector] = None,
  2098. **kwargs: Any,
  2099. ) -> types.UpdateResult:
  2100. if self._prefer_grpc:
  2101. points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
  2102. points_selector
  2103. )
  2104. shard_key_selector = shard_key_selector or opt_shard_key_selector
  2105. if isinstance(ordering, models.WriteOrdering):
  2106. ordering = RestToGrpc.convert_write_ordering(ordering)
  2107. if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
  2108. shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
  2109. return GrpcToRest.convert_update_result(
  2110. self.grpc_points.ClearPayload(
  2111. grpc.ClearPayloadPoints(
  2112. collection_name=collection_name,
  2113. wait=wait,
  2114. points=points_selector,
  2115. ordering=ordering,
  2116. shard_key_selector=shard_key_selector,
  2117. ),
  2118. timeout=self._timeout,
  2119. ).result
  2120. )
  2121. else:
  2122. points_selector = self._try_argument_to_rest_selector(
  2123. points_selector, shard_key_selector
  2124. )
  2125. result: Optional[types.UpdateResult] = self.openapi_client.points_api.clear_payload(
  2126. collection_name=collection_name,
  2127. wait=wait,
  2128. ordering=ordering,
  2129. points_selector=points_selector,
  2130. ).result
  2131. assert result is not None, "Clear payload returned None"
  2132. return result
  2133. def batch_update_points(
  2134. self,
  2135. collection_name: str,
  2136. update_operations: Sequence[types.UpdateOperation],
  2137. wait: bool = True,
  2138. ordering: Optional[types.WriteOrdering] = None,
  2139. **kwargs: Any,
  2140. ) -> list[types.UpdateResult]:
  2141. if self._prefer_grpc:
  2142. update_operations = [
  2143. RestToGrpc.convert_update_operation(operation) for operation in update_operations
  2144. ]
  2145. if isinstance(ordering, models.WriteOrdering):
  2146. ordering = RestToGrpc.convert_write_ordering(ordering)
  2147. return [
  2148. GrpcToRest.convert_update_result(result)
  2149. for result in self.grpc_points.UpdateBatch(
  2150. grpc.UpdateBatchPoints(
  2151. collection_name=collection_name,
  2152. wait=wait,
  2153. operations=update_operations,
  2154. ordering=ordering,
  2155. ),
  2156. timeout=self._timeout,
  2157. ).result
  2158. ]
  2159. else:
  2160. result: Optional[list[types.UpdateResult]] = (
  2161. self.openapi_client.points_api.batch_update(
  2162. collection_name=collection_name,
  2163. wait=wait,
  2164. ordering=ordering,
  2165. update_operations=models.UpdateOperations(operations=update_operations),
  2166. ).result
  2167. )
  2168. assert result is not None, "Batch update points returned None"
  2169. return result
  2170. def update_collection_aliases(
  2171. self,
  2172. change_aliases_operations: Sequence[types.AliasOperations],
  2173. timeout: Optional[int] = None,
  2174. **kwargs: Any,
  2175. ) -> bool:
  2176. if self._prefer_grpc:
  2177. change_aliases_operation = [
  2178. (
  2179. RestToGrpc.convert_alias_operations(operation)
  2180. if not isinstance(operation, grpc.AliasOperations)
  2181. else operation
  2182. )
  2183. for operation in change_aliases_operations
  2184. ]
  2185. return self.grpc_collections.UpdateAliases(
  2186. grpc.ChangeAliases(
  2187. timeout=timeout,
  2188. actions=change_aliases_operation,
  2189. ),
  2190. timeout=timeout if timeout is not None else self._timeout,
  2191. ).result
  2192. change_aliases_operation = [
  2193. (
  2194. GrpcToRest.convert_alias_operations(operation)
  2195. if isinstance(operation, grpc.AliasOperations)
  2196. else operation
  2197. )
  2198. for operation in change_aliases_operations
  2199. ]
  2200. result: Optional[bool] = self.http.aliases_api.update_aliases(
  2201. timeout=timeout,
  2202. change_aliases_operation=models.ChangeAliasesOperation(
  2203. actions=change_aliases_operation
  2204. ),
  2205. ).result
  2206. assert result is not None, "Update aliases returned None"
  2207. return result
  2208. def get_collection_aliases(
  2209. self, collection_name: str, **kwargs: Any
  2210. ) -> types.CollectionsAliasesResponse:
  2211. if self._prefer_grpc:
  2212. response = self.grpc_collections.ListCollectionAliases(
  2213. grpc.ListCollectionAliasesRequest(collection_name=collection_name),
  2214. timeout=self._timeout,
  2215. ).aliases
  2216. return types.CollectionsAliasesResponse(
  2217. aliases=[
  2218. GrpcToRest.convert_alias_description(description) for description in response
  2219. ]
  2220. )
  2221. result: Optional[types.CollectionsAliasesResponse] = (
  2222. self.http.aliases_api.get_collection_aliases(collection_name=collection_name).result
  2223. )
  2224. assert result is not None, "Get collection aliases returned None"
  2225. return result
  2226. def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
  2227. if self._prefer_grpc:
  2228. response = self.grpc_collections.ListAliases(
  2229. grpc.ListAliasesRequest(), timeout=self._timeout
  2230. ).aliases
  2231. return types.CollectionsAliasesResponse(
  2232. aliases=[
  2233. GrpcToRest.convert_alias_description(description) for description in response
  2234. ]
  2235. )
  2236. result: Optional[types.CollectionsAliasesResponse] = (
  2237. self.http.aliases_api.get_collections_aliases().result
  2238. )
  2239. assert result is not None, "Get aliases returned None"
  2240. return result
  2241. def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
  2242. if self._prefer_grpc:
  2243. response = self.grpc_collections.List(
  2244. grpc.ListCollectionsRequest(), timeout=self._timeout
  2245. ).collections
  2246. return types.CollectionsResponse(
  2247. collections=[
  2248. GrpcToRest.convert_collection_description(description)
  2249. for description in response
  2250. ]
  2251. )
  2252. result: Optional[types.CollectionsResponse] = (
  2253. self.http.collections_api.get_collections().result
  2254. )
  2255. assert result is not None, "Get collections returned None"
  2256. return result
  2257. def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
  2258. if self._prefer_grpc:
  2259. return GrpcToRest.convert_collection_info(
  2260. self.grpc_collections.Get(
  2261. grpc.GetCollectionInfoRequest(collection_name=collection_name),
  2262. timeout=self._timeout,
  2263. ).result
  2264. )
  2265. result: Optional[types.CollectionInfo] = self.http.collections_api.get_collection(
  2266. collection_name=collection_name
  2267. ).result
  2268. assert result is not None, "Get collection returned None"
  2269. return result
  2270. def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
  2271. if self._prefer_grpc:
  2272. return self.grpc_collections.CollectionExists(
  2273. grpc.CollectionExistsRequest(collection_name=collection_name),
  2274. timeout=self._timeout,
  2275. ).result.exists
  2276. result: Optional[models.CollectionExistence] = self.http.collections_api.collection_exists(
  2277. collection_name=collection_name
  2278. ).result
  2279. assert result is not None, "Collection exists returned None"
  2280. return result.exists
  2281. def update_collection(
  2282. self,
  2283. collection_name: str,
  2284. optimizers_config: Optional[types.OptimizersConfigDiff] = None,
  2285. collection_params: Optional[types.CollectionParamsDiff] = None,
  2286. vectors_config: Optional[types.VectorsConfigDiff] = None,
  2287. hnsw_config: Optional[types.HnswConfigDiff] = None,
  2288. quantization_config: Optional[types.QuantizationConfigDiff] = None,
  2289. timeout: Optional[int] = None,
  2290. sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
  2291. strict_mode_config: Optional[types.StrictModeConfig] = None,
  2292. **kwargs: Any,
  2293. ) -> bool:
  2294. if self._prefer_grpc:
  2295. if isinstance(optimizers_config, models.OptimizersConfigDiff):
  2296. optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config)
  2297. if isinstance(collection_params, models.CollectionParamsDiff):
  2298. collection_params = RestToGrpc.convert_collection_params_diff(collection_params)
  2299. if isinstance(vectors_config, dict):
  2300. vectors_config = RestToGrpc.convert_vectors_config_diff(vectors_config)
  2301. if isinstance(hnsw_config, models.HnswConfigDiff):
  2302. hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config)
  2303. if isinstance(quantization_config, get_args(models.QuantizationConfigDiff)):
  2304. quantization_config = RestToGrpc.convert_quantization_config_diff(
  2305. quantization_config
  2306. )
  2307. if isinstance(sparse_vectors_config, dict):
  2308. sparse_vectors_config = RestToGrpc.convert_sparse_vector_config(
  2309. sparse_vectors_config
  2310. )
  2311. if isinstance(strict_mode_config, models.StrictModeConfig):
  2312. strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)
  2313. return self.grpc_collections.Update(
  2314. grpc.UpdateCollection(
  2315. collection_name=collection_name,
  2316. optimizers_config=optimizers_config,
  2317. params=collection_params,
  2318. vectors_config=vectors_config,
  2319. hnsw_config=hnsw_config,
  2320. quantization_config=quantization_config,
  2321. sparse_vectors_config=sparse_vectors_config,
  2322. strict_mode_config=strict_mode_config,
  2323. timeout=timeout,
  2324. ),
  2325. timeout=timeout if timeout is not None else self._timeout,
  2326. ).result
  2327. if isinstance(optimizers_config, grpc.OptimizersConfigDiff):
  2328. optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config)
  2329. if isinstance(collection_params, grpc.CollectionParamsDiff):
  2330. collection_params = GrpcToRest.convert_collection_params_diff(collection_params)
  2331. if isinstance(vectors_config, grpc.VectorsConfigDiff):
  2332. vectors_config = GrpcToRest.convert_vectors_config_diff(vectors_config)
  2333. if isinstance(hnsw_config, grpc.HnswConfigDiff):
  2334. hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config)
  2335. if isinstance(quantization_config, grpc.QuantizationConfigDiff):
  2336. quantization_config = GrpcToRest.convert_quantization_config_diff(quantization_config)
  2337. result: Optional[bool] = self.http.collections_api.update_collection(
  2338. collection_name,
  2339. update_collection=models.UpdateCollection(
  2340. optimizers_config=optimizers_config,
  2341. params=collection_params,
  2342. vectors=vectors_config,
  2343. hnsw_config=hnsw_config,
  2344. quantization_config=quantization_config,
  2345. sparse_vectors=sparse_vectors_config,
  2346. strict_mode_config=strict_mode_config,
  2347. ),
  2348. timeout=timeout,
  2349. ).result
  2350. assert result is not None, "Update collection returned None"
  2351. return result
  2352. def delete_collection(
  2353. self, collection_name: str, timeout: Optional[int] = None, **kwargs: Any
  2354. ) -> bool:
  2355. if self._prefer_grpc:
  2356. return self.grpc_collections.Delete(
  2357. grpc.DeleteCollection(collection_name=collection_name, timeout=timeout),
  2358. timeout=timeout if timeout is not None else self._timeout,
  2359. ).result
  2360. result: Optional[bool] = self.http.collections_api.delete_collection(
  2361. collection_name, timeout=timeout
  2362. ).result
  2363. assert result is not None, "Delete collection returned None"
  2364. return result
  2365. def create_collection(
  2366. self,
  2367. collection_name: str,
  2368. vectors_config: Optional[
  2369. Union[types.VectorParams, Mapping[str, types.VectorParams]]
  2370. ] = None,
  2371. shard_number: Optional[int] = None,
  2372. replication_factor: Optional[int] = None,
  2373. write_consistency_factor: Optional[int] = None,
  2374. on_disk_payload: Optional[bool] = None,
  2375. hnsw_config: Optional[types.HnswConfigDiff] = None,
  2376. optimizers_config: Optional[types.OptimizersConfigDiff] = None,
  2377. wal_config: Optional[types.WalConfigDiff] = None,
  2378. quantization_config: Optional[types.QuantizationConfig] = None,
  2379. init_from: Optional[types.InitFrom] = None,
  2380. timeout: Optional[int] = None,
  2381. sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
  2382. sharding_method: Optional[types.ShardingMethod] = None,
  2383. strict_mode_config: Optional[types.StrictModeConfig] = None,
  2384. **kwargs: Any,
  2385. ) -> bool:
  2386. if init_from is not None:
  2387. show_warning_once(
  2388. message="init_from is deprecated",
  2389. category=DeprecationWarning,
  2390. stacklevel=5,
  2391. idx="create-collection-init-from",
  2392. )
  2393. if self._prefer_grpc:
  2394. if isinstance(vectors_config, (models.VectorParams, dict)):
  2395. vectors_config = RestToGrpc.convert_vectors_config(vectors_config)
  2396. if isinstance(hnsw_config, models.HnswConfigDiff):
  2397. hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config)
  2398. if isinstance(optimizers_config, models.OptimizersConfigDiff):
  2399. optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config)
  2400. if isinstance(wal_config, models.WalConfigDiff):
  2401. wal_config = RestToGrpc.convert_wal_config_diff(wal_config)
  2402. if isinstance(
  2403. quantization_config,
  2404. get_args(models.QuantizationConfig),
  2405. ):
  2406. quantization_config = RestToGrpc.convert_quantization_config(quantization_config)
  2407. if isinstance(init_from, models.InitFrom):
  2408. init_from = RestToGrpc.convert_init_from(init_from)
  2409. if isinstance(sparse_vectors_config, dict):
  2410. sparse_vectors_config = RestToGrpc.convert_sparse_vector_config(
  2411. sparse_vectors_config
  2412. )
  2413. if isinstance(sharding_method, models.ShardingMethod):
  2414. sharding_method = RestToGrpc.convert_sharding_method(sharding_method)
  2415. if isinstance(strict_mode_config, models.StrictModeConfig):
  2416. strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)
  2417. create_collection = grpc.CreateCollection(
  2418. collection_name=collection_name,
  2419. hnsw_config=hnsw_config,
  2420. wal_config=wal_config,
  2421. optimizers_config=optimizers_config,
  2422. shard_number=shard_number,
  2423. on_disk_payload=on_disk_payload,
  2424. timeout=timeout,
  2425. vectors_config=vectors_config,
  2426. replication_factor=replication_factor,
  2427. write_consistency_factor=write_consistency_factor,
  2428. init_from_collection=init_from,
  2429. quantization_config=quantization_config,
  2430. sparse_vectors_config=sparse_vectors_config,
  2431. sharding_method=sharding_method,
  2432. strict_mode_config=strict_mode_config,
  2433. )
  2434. return self.grpc_collections.Create(create_collection, timeout=self._timeout).result
  2435. if isinstance(hnsw_config, grpc.HnswConfigDiff):
  2436. hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config)
  2437. if isinstance(optimizers_config, grpc.OptimizersConfigDiff):
  2438. optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config)
  2439. if isinstance(wal_config, grpc.WalConfigDiff):
  2440. wal_config = GrpcToRest.convert_wal_config_diff(wal_config)
  2441. if isinstance(quantization_config, grpc.QuantizationConfig):
  2442. quantization_config = GrpcToRest.convert_quantization_config(quantization_config)
  2443. if isinstance(init_from, str):
  2444. init_from = GrpcToRest.convert_init_from(init_from)
  2445. create_collection_request = models.CreateCollection(
  2446. vectors=vectors_config,
  2447. shard_number=shard_number,
  2448. replication_factor=replication_factor,
  2449. write_consistency_factor=write_consistency_factor,
  2450. on_disk_payload=on_disk_payload,
  2451. hnsw_config=hnsw_config,
  2452. optimizers_config=optimizers_config,
  2453. wal_config=wal_config,
  2454. quantization_config=quantization_config,
  2455. init_from=init_from,
  2456. sparse_vectors=sparse_vectors_config,
  2457. sharding_method=sharding_method,
  2458. strict_mode_config=strict_mode_config,
  2459. )
  2460. result: Optional[bool] = self.http.collections_api.create_collection(
  2461. collection_name=collection_name,
  2462. create_collection=create_collection_request,
  2463. timeout=timeout,
  2464. ).result
  2465. assert result is not None, "Create collection returned None"
  2466. return result
  2467. def recreate_collection(
  2468. self,
  2469. collection_name: str,
  2470. vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
  2471. shard_number: Optional[int] = None,
  2472. replication_factor: Optional[int] = None,
  2473. write_consistency_factor: Optional[int] = None,
  2474. on_disk_payload: Optional[bool] = None,
  2475. hnsw_config: Optional[types.HnswConfigDiff] = None,
  2476. optimizers_config: Optional[types.OptimizersConfigDiff] = None,
  2477. wal_config: Optional[types.WalConfigDiff] = None,
  2478. quantization_config: Optional[types.QuantizationConfig] = None,
  2479. init_from: Optional[types.InitFrom] = None,
  2480. timeout: Optional[int] = None,
  2481. sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
  2482. sharding_method: Optional[types.ShardingMethod] = None,
  2483. strict_mode_config: Optional[types.StrictModeConfig] = None,
  2484. **kwargs: Any,
  2485. ) -> bool:
  2486. self.delete_collection(collection_name, timeout=timeout)
  2487. return self.create_collection(
  2488. collection_name=collection_name,
  2489. vectors_config=vectors_config,
  2490. shard_number=shard_number,
  2491. replication_factor=replication_factor,
  2492. write_consistency_factor=write_consistency_factor,
  2493. on_disk_payload=on_disk_payload,
  2494. hnsw_config=hnsw_config,
  2495. optimizers_config=optimizers_config,
  2496. wal_config=wal_config,
  2497. quantization_config=quantization_config,
  2498. init_from=init_from,
  2499. timeout=timeout,
  2500. sparse_vectors_config=sparse_vectors_config,
  2501. sharding_method=sharding_method,
  2502. strict_mode_config=strict_mode_config,
  2503. )
  2504. @property
  2505. def _updater_class(self) -> Type[BaseUploader]:
  2506. if self._prefer_grpc:
  2507. return GrpcBatchUploader
  2508. else:
  2509. return RestBatchUploader
  2510. def _upload_collection(
  2511. self,
  2512. batches_iterator: Iterable,
  2513. collection_name: str,
  2514. max_retries: int,
  2515. parallel: int = 1,
  2516. method: Optional[str] = None,
  2517. wait: bool = False,
  2518. shard_key_selector: Optional[types.ShardKeySelector] = None,
  2519. ) -> None:
  2520. if method is not None:
  2521. if method in get_all_start_methods():
  2522. start_method = method
  2523. else:
  2524. raise ValueError(
  2525. f"Start methods {method} is not available, available methods: {get_all_start_methods()}"
  2526. )
  2527. else:
  2528. start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
  2529. if self._prefer_grpc:
  2530. updater_kwargs = {
  2531. "collection_name": collection_name,
  2532. "host": self._host,
  2533. "port": self._grpc_port,
  2534. "max_retries": max_retries,
  2535. "ssl": self._https,
  2536. "metadata": self._grpc_headers,
  2537. "wait": wait,
  2538. "shard_key_selector": shard_key_selector,
  2539. "options": self._grpc_options,
  2540. "timeout": self._timeout,
  2541. }
  2542. else:
  2543. updater_kwargs = {
  2544. "collection_name": collection_name,
  2545. "uri": self.rest_uri,
  2546. "max_retries": max_retries,
  2547. "wait": wait,
  2548. "shard_key_selector": shard_key_selector,
  2549. **self._rest_args,
  2550. }
  2551. if parallel == 1:
  2552. updater = self._updater_class.start(**updater_kwargs)
  2553. for _ in updater.process(batches_iterator):
  2554. pass
  2555. else:
  2556. pool = ParallelWorkerPool(parallel, self._updater_class, start_method=start_method)
  2557. for _ in pool.unordered_map(batches_iterator, **updater_kwargs):
  2558. pass
  2559. def upload_records(
  2560. self,
  2561. collection_name: str,
  2562. records: Iterable[types.Record],
  2563. batch_size: int = 64,
  2564. parallel: int = 1,
  2565. method: Optional[str] = None,
  2566. max_retries: int = 3,
  2567. wait: bool = False,
  2568. shard_key_selector: Optional[types.ShardKeySelector] = None,
  2569. **kwargs: Any,
  2570. ) -> None:
  2571. batches_iterator = self._updater_class.iterate_records_batches(
  2572. records=records, batch_size=batch_size
  2573. )
  2574. self._upload_collection(
  2575. batches_iterator=batches_iterator,
  2576. collection_name=collection_name,
  2577. max_retries=max_retries,
  2578. parallel=parallel,
  2579. method=method,
  2580. shard_key_selector=shard_key_selector,
  2581. wait=wait,
  2582. )
  2583. def upload_points(
  2584. self,
  2585. collection_name: str,
  2586. points: Iterable[types.PointStruct],
  2587. batch_size: int = 64,
  2588. parallel: int = 1,
  2589. method: Optional[str] = None,
  2590. max_retries: int = 3,
  2591. wait: bool = False,
  2592. shard_key_selector: Optional[types.ShardKeySelector] = None,
  2593. **kwargs: Any,
  2594. ) -> None:
  2595. batches_iterator = self._updater_class.iterate_records_batches(
  2596. records=points, batch_size=batch_size
  2597. )
  2598. self._upload_collection(
  2599. batches_iterator=batches_iterator,
  2600. collection_name=collection_name,
  2601. max_retries=max_retries,
  2602. parallel=parallel,
  2603. method=method,
  2604. wait=wait,
  2605. shard_key_selector=shard_key_selector,
  2606. )
  2607. def upload_collection(
  2608. self,
  2609. collection_name: str,
  2610. vectors: Union[
  2611. dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
  2612. ],
  2613. payload: Optional[Iterable[dict[Any, Any]]] = None,
  2614. ids: Optional[Iterable[types.PointId]] = None,
  2615. batch_size: int = 64,
  2616. parallel: int = 1,
  2617. method: Optional[str] = None,
  2618. max_retries: int = 3,
  2619. wait: bool = False,
  2620. shard_key_selector: Optional[types.ShardKeySelector] = None,
  2621. **kwargs: Any,
  2622. ) -> None:
  2623. batches_iterator = self._updater_class.iterate_batches(
  2624. vectors=vectors,
  2625. payload=payload,
  2626. ids=ids,
  2627. batch_size=batch_size,
  2628. )
  2629. self._upload_collection(
  2630. batches_iterator=batches_iterator,
  2631. collection_name=collection_name,
  2632. max_retries=max_retries,
  2633. parallel=parallel,
  2634. method=method,
  2635. wait=wait,
  2636. shard_key_selector=shard_key_selector,
  2637. )
  2638. def create_payload_index(
  2639. self,
  2640. collection_name: str,
  2641. field_name: str,
  2642. field_schema: Optional[types.PayloadSchemaType] = None,
  2643. field_type: Optional[types.PayloadSchemaType] = None,
  2644. wait: bool = True,
  2645. ordering: Optional[types.WriteOrdering] = None,
  2646. **kwargs: Any,
  2647. ) -> types.UpdateResult:
  2648. if field_type is not None:
  2649. show_warning_once(
  2650. message="field_type is deprecated, use field_schema instead",
  2651. category=DeprecationWarning,
  2652. stacklevel=5,
  2653. idx="payload-index-field-type",
  2654. )
  2655. field_schema = field_type
  2656. if self._prefer_grpc:
  2657. field_index_params = None
  2658. if isinstance(field_schema, models.PayloadSchemaType):
  2659. field_schema = RestToGrpc.convert_payload_schema_type(field_schema)
  2660. if isinstance(field_schema, str):
  2661. field_schema = RestToGrpc.convert_payload_schema_type(
  2662. models.PayloadSchemaType(field_schema)
  2663. )
  2664. if isinstance(field_schema, int):
  2665. # There are no means to distinguish grpc.PayloadSchemaType and grpc.FieldType,
  2666. # as both of them are just ints
  2667. # method signature assumes that grpc.PayloadSchemaType is passed,
  2668. # otherwise the value will be corrupted
  2669. field_schema = grpc_payload_schema_to_field_type(field_schema)
  2670. if isinstance(field_schema, get_args(models.PayloadSchemaParams)):
  2671. field_schema = RestToGrpc.convert_payload_schema_params(field_schema)
  2672. if isinstance(field_schema, grpc.PayloadIndexParams):
  2673. field_index_params = field_schema
  2674. name = field_index_params.WhichOneof("index_params")
  2675. index_params = getattr(field_index_params, name)
  2676. if isinstance(index_params, grpc.TextIndexParams):
  2677. field_schema = grpc.FieldType.FieldTypeText
  2678. if isinstance(index_params, grpc.IntegerIndexParams):
  2679. field_schema = grpc.FieldType.FieldTypeInteger
  2680. if isinstance(index_params, grpc.KeywordIndexParams):
  2681. field_schema = grpc.FieldType.FieldTypeKeyword
  2682. if isinstance(index_params, grpc.FloatIndexParams):
  2683. field_schema = grpc.FieldType.FieldTypeFloat
  2684. if isinstance(index_params, grpc.GeoIndexParams):
  2685. field_schema = grpc.FieldType.FieldTypeGeo
  2686. if isinstance(index_params, grpc.BoolIndexParams):
  2687. field_schema = grpc.FieldType.FieldTypeBool
  2688. if isinstance(index_params, grpc.DatetimeIndexParams):
  2689. field_schema = grpc.FieldType.FieldTypeDatetime
  2690. if isinstance(index_params, grpc.UuidIndexParams):
  2691. field_schema = grpc.FieldType.FieldTypeUuid
  2692. request = grpc.CreateFieldIndexCollection(
  2693. collection_name=collection_name,
  2694. field_name=field_name,
  2695. field_type=field_schema,
  2696. field_index_params=field_index_params,
  2697. wait=wait,
  2698. ordering=ordering,
  2699. )
  2700. return GrpcToRest.convert_update_result(
  2701. self.grpc_points.CreateFieldIndex(request, timeout=self._timeout).result
  2702. )
  2703. if isinstance(field_schema, int): # type(grpc.PayloadSchemaType) == int
  2704. field_schema = GrpcToRest.convert_payload_schema_type(field_schema)
  2705. if isinstance(field_schema, grpc.PayloadIndexParams):
  2706. field_schema = GrpcToRest.convert_payload_schema_params(field_schema)
  2707. result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.create_field_index(
  2708. collection_name=collection_name,
  2709. create_field_index=models.CreateFieldIndex(
  2710. field_name=field_name, field_schema=field_schema
  2711. ),
  2712. wait=wait,
  2713. ordering=ordering,
  2714. ).result
  2715. assert result is not None, "Create field index returned None"
  2716. return result
  2717. def delete_payload_index(
  2718. self,
  2719. collection_name: str,
  2720. field_name: str,
  2721. wait: bool = True,
  2722. ordering: Optional[types.WriteOrdering] = None,
  2723. **kwargs: Any,
  2724. ) -> types.UpdateResult:
  2725. if self._prefer_grpc:
  2726. request = grpc.DeleteFieldIndexCollection(
  2727. collection_name=collection_name,
  2728. field_name=field_name,
  2729. wait=wait,
  2730. ordering=ordering,
  2731. )
  2732. return GrpcToRest.convert_update_result(
  2733. self.grpc_points.DeleteFieldIndex(request, timeout=self._timeout).result
  2734. )
  2735. result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.delete_field_index(
  2736. collection_name=collection_name,
  2737. field_name=field_name,
  2738. wait=wait,
  2739. ordering=ordering,
  2740. ).result
  2741. assert result is not None, "Delete field index returned None"
  2742. return result
  2743. def list_snapshots(
  2744. self, collection_name: str, **kwargs: Any
  2745. ) -> list[types.SnapshotDescription]:
  2746. if self._prefer_grpc:
  2747. snapshots = self.grpc_snapshots.List(
  2748. grpc.ListSnapshotsRequest(collection_name=collection_name), timeout=self._timeout
  2749. ).snapshot_descriptions
  2750. return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]
  2751. snapshots = self.openapi_client.snapshots_api.list_snapshots(
  2752. collection_name=collection_name
  2753. ).result
  2754. assert snapshots is not None, "List snapshots API returned None result"
  2755. return snapshots
  2756. def create_snapshot(
  2757. self, collection_name: str, wait: bool = True, **kwargs: Any
  2758. ) -> Optional[types.SnapshotDescription]:
  2759. if self._prefer_grpc:
  2760. snapshot = self.grpc_snapshots.Create(
  2761. grpc.CreateSnapshotRequest(collection_name=collection_name), timeout=self._timeout
  2762. ).snapshot_description
  2763. return GrpcToRest.convert_snapshot_description(snapshot)
  2764. return self.openapi_client.snapshots_api.create_snapshot(
  2765. collection_name=collection_name, wait=wait
  2766. ).result
  2767. def delete_snapshot(
  2768. self, collection_name: str, snapshot_name: str, wait: bool = True, **kwargs: Any
  2769. ) -> Optional[bool]:
  2770. if self._prefer_grpc:
  2771. self.grpc_snapshots.Delete(
  2772. grpc.DeleteSnapshotRequest(
  2773. collection_name=collection_name, snapshot_name=snapshot_name
  2774. ),
  2775. timeout=self._timeout,
  2776. )
  2777. return True
  2778. return self.openapi_client.snapshots_api.delete_snapshot(
  2779. collection_name=collection_name,
  2780. snapshot_name=snapshot_name,
  2781. wait=wait,
  2782. ).result
  2783. def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
  2784. if self._prefer_grpc:
  2785. snapshots = self.grpc_snapshots.ListFull(
  2786. grpc.ListFullSnapshotsRequest(),
  2787. timeout=self._timeout,
  2788. ).snapshot_descriptions
  2789. return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]
  2790. snapshots = self.openapi_client.snapshots_api.list_full_snapshots().result
  2791. assert snapshots is not None, "List full snapshots API returned None result"
  2792. return snapshots
  2793. def create_full_snapshot(self, wait: bool = True, **kwargs: Any) -> types.SnapshotDescription:
  2794. if self._prefer_grpc:
  2795. snapshot_description = self.grpc_snapshots.CreateFull(
  2796. grpc.CreateFullSnapshotRequest(), timeout=self._timeout
  2797. ).snapshot_description
  2798. return GrpcToRest.convert_snapshot_description(snapshot_description)
  2799. return self.openapi_client.snapshots_api.create_full_snapshot(wait=wait).result
  2800. def delete_full_snapshot(
  2801. self, snapshot_name: str, wait: bool = True, **kwargs: Any
  2802. ) -> Optional[bool]:
  2803. if self._prefer_grpc:
  2804. self.grpc_snapshots.DeleteFull(
  2805. grpc.DeleteFullSnapshotRequest(snapshot_name=snapshot_name),
  2806. timeout=self._timeout,
  2807. )
  2808. return True
  2809. return self.openapi_client.snapshots_api.delete_full_snapshot(
  2810. snapshot_name=snapshot_name, wait=wait
  2811. ).result
  2812. def recover_snapshot(
  2813. self,
  2814. collection_name: str,
  2815. location: str,
  2816. api_key: Optional[str] = None,
  2817. checksum: Optional[str] = None,
  2818. priority: Optional[types.SnapshotPriority] = None,
  2819. wait: bool = True,
  2820. **kwargs: Any,
  2821. ) -> Optional[bool]:
  2822. return self.openapi_client.snapshots_api.recover_from_snapshot(
  2823. collection_name=collection_name,
  2824. wait=wait,
  2825. snapshot_recover=models.SnapshotRecover(
  2826. location=location,
  2827. priority=priority,
  2828. checksum=checksum,
  2829. api_key=api_key,
  2830. ),
  2831. ).result
  2832. def list_shard_snapshots(
  2833. self, collection_name: str, shard_id: int, **kwargs: Any
  2834. ) -> list[types.SnapshotDescription]:
  2835. snapshots = self.openapi_client.snapshots_api.list_shard_snapshots(
  2836. collection_name=collection_name,
  2837. shard_id=shard_id,
  2838. ).result
  2839. assert snapshots is not None, "List snapshots API returned None result"
  2840. return snapshots
  2841. def create_shard_snapshot(
  2842. self, collection_name: str, shard_id: int, wait: bool = True, **kwargs: Any
  2843. ) -> Optional[types.SnapshotDescription]:
  2844. return self.openapi_client.snapshots_api.create_shard_snapshot(
  2845. collection_name=collection_name,
  2846. shard_id=shard_id,
  2847. wait=wait,
  2848. ).result
  2849. def delete_shard_snapshot(
  2850. self,
  2851. collection_name: str,
  2852. shard_id: int,
  2853. snapshot_name: str,
  2854. wait: bool = True,
  2855. **kwargs: Any,
  2856. ) -> Optional[bool]:
  2857. return self.openapi_client.snapshots_api.delete_shard_snapshot(
  2858. collection_name=collection_name,
  2859. shard_id=shard_id,
  2860. snapshot_name=snapshot_name,
  2861. wait=wait,
  2862. ).result
  2863. def recover_shard_snapshot(
  2864. self,
  2865. collection_name: str,
  2866. shard_id: int,
  2867. location: str,
  2868. api_key: Optional[str] = None,
  2869. checksum: Optional[str] = None,
  2870. priority: Optional[types.SnapshotPriority] = None,
  2871. wait: bool = True,
  2872. **kwargs: Any,
  2873. ) -> Optional[bool]:
  2874. return self.openapi_client.snapshots_api.recover_shard_from_snapshot(
  2875. collection_name=collection_name,
  2876. shard_id=shard_id,
  2877. wait=wait,
  2878. shard_snapshot_recover=models.ShardSnapshotRecover(
  2879. location=location,
  2880. priority=priority,
  2881. checksum=checksum,
  2882. api_key=api_key,
  2883. ),
  2884. ).result
  2885. def lock_storage(self, reason: str, **kwargs: Any) -> types.LocksOption:
  2886. result: Optional[types.LocksOption] = self.openapi_client.service_api.post_locks(
  2887. models.LocksOption(error_message=reason, write=True)
  2888. ).result
  2889. assert result is not None, "Lock storage returned None"
  2890. return result
  2891. def unlock_storage(self, **kwargs: Any) -> types.LocksOption:
  2892. result: Optional[types.LocksOption] = self.openapi_client.service_api.post_locks(
  2893. models.LocksOption(write=False)
  2894. ).result
  2895. assert result is not None, "Post locks returned None"
  2896. return result
  2897. def get_locks(self, **kwargs: Any) -> types.LocksOption:
  2898. result: Optional[types.LocksOption] = self.openapi_client.service_api.get_locks().result
  2899. assert result is not None, "Get locks returned None"
  2900. return result
  2901. def create_shard_key(
  2902. self,
  2903. collection_name: str,
  2904. shard_key: types.ShardKey,
  2905. shards_number: Optional[int] = None,
  2906. replication_factor: Optional[int] = None,
  2907. placement: Optional[list[int]] = None,
  2908. timeout: Optional[int] = None,
  2909. **kwargs: Any,
  2910. ) -> bool:
  2911. if self._prefer_grpc:
  2912. if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
  2913. shard_key = RestToGrpc.convert_shard_key(shard_key)
  2914. return self.grpc_collections.CreateShardKey(
  2915. grpc.CreateShardKeyRequest(
  2916. collection_name=collection_name,
  2917. timeout=timeout,
  2918. request=grpc.CreateShardKey(
  2919. shard_key=shard_key,
  2920. shards_number=shards_number,
  2921. replication_factor=replication_factor,
  2922. placement=placement or [],
  2923. ),
  2924. ),
  2925. timeout=timeout if timeout is not None else self._timeout,
  2926. ).result
  2927. else:
  2928. result = self.openapi_client.distributed_api.create_shard_key(
  2929. collection_name=collection_name,
  2930. timeout=timeout,
  2931. create_sharding_key=models.CreateShardingKey(
  2932. shard_key=shard_key,
  2933. shards_number=shards_number,
  2934. replication_factor=replication_factor,
  2935. placement=placement,
  2936. ),
  2937. ).result
  2938. assert result is not None, "Create shard key returned None"
  2939. return result
  2940. def delete_shard_key(
  2941. self,
  2942. collection_name: str,
  2943. shard_key: types.ShardKey,
  2944. timeout: Optional[int] = None,
  2945. **kwargs: Any,
  2946. ) -> bool:
  2947. if self._prefer_grpc:
  2948. if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
  2949. shard_key = RestToGrpc.convert_shard_key(shard_key)
  2950. return self.grpc_collections.DeleteShardKey(
  2951. grpc.DeleteShardKeyRequest(
  2952. collection_name=collection_name,
  2953. timeout=timeout,
  2954. request=grpc.DeleteShardKey(
  2955. shard_key=shard_key,
  2956. ),
  2957. ),
  2958. timeout=timeout if timeout is not None else self._timeout,
  2959. ).result
  2960. else:
  2961. result = self.openapi_client.distributed_api.delete_shard_key(
  2962. collection_name=collection_name,
  2963. timeout=timeout,
  2964. drop_sharding_key=models.DropShardingKey(
  2965. shard_key=shard_key,
  2966. ),
  2967. ).result
  2968. assert result is not None, "Delete shard key returned None"
  2969. return result
  2970. def info(self) -> types.VersionInfo:
  2971. if self._prefer_grpc:
  2972. version_info = self.grpc_root.HealthCheck(
  2973. grpc.HealthCheckRequest(), timeout=self._timeout
  2974. )
  2975. return GrpcToRest.convert_health_check_reply(version_info)
  2976. version_info = self.rest.service_api.root()
  2977. assert version_info is not None, "Healthcheck returned None"
  2978. return version_info