conversion.py 192 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761
  1. from datetime import date, datetime, timezone
  2. from typing import Any, Mapping, Optional, Sequence, Union, get_args
  3. from google.protobuf.internal.containers import MessageMap
  4. from google.protobuf.json_format import MessageToDict
  5. from google.protobuf.timestamp_pb2 import Timestamp
  6. try:
  7. from google.protobuf.pyext._message import MessageMapContainer # type: ignore
  8. except ImportError:
  9. pass
  10. from qdrant_client import grpc
  11. from qdrant_client.grpc import ListValue, NullValue, Struct, Value
  12. from qdrant_client.http.models import models as rest
  13. from qdrant_client._pydantic_compat import construct, to_jsonable_python
  14. from qdrant_client.conversions.common_types import get_args_subscribed
  def has_field(message: Any, field: str) -> bool:
      """
      Same as protobuf HasField, but also works for primitive values
      (https://stackoverflow.com/questions/51918871/check-if-a-field-has-been-set-in-protocol-buffer-3)

      ``HasField`` raises ``ValueError`` for fields it cannot track presence for
      (the primitive values mentioned above). In that case fall back to scanning
      ``ListFields()`` for a descriptor with the requested name.

      Args:
          message (Any): protobuf message
          field (str): name of the field

      Returns:
          bool: True if ``field`` is reported as set on ``message``.
      """
      try:
          return message.HasField(field)
      except ValueError:
          # Stop at the first matching descriptor instead of materialising a list and a set.
          return any(descriptor.name == field for descriptor, _value in message.ListFields())
  def json_to_value(payload: Any) -> Value:
      """Convert a JSON-like Python value into a gRPC ``Value`` message.

      Handles None, bool, int, float, str, list/tuple (element-wise, recursively),
      dict (as a ``Struct``, recursively) and date/datetime (stored as the string
      returned by ``to_jsonable_python``).

      Raises:
          ValueError: for any other payload type.
      """
      if payload is None:
          return Value(null_value=NullValue.NULL_VALUE)
      # Order matters: bool is a subclass of int, so it must be matched first.
      scalar_kinds = (
          (bool, "bool_value"),
          (int, "integer_value"),
          (float, "double_value"),
          (str, "string_value"),
      )
      for python_type, value_field in scalar_kinds:
          if isinstance(payload, python_type):
              return Value(**{value_field: payload})
      if isinstance(payload, (list, tuple)):
          items = [json_to_value(item) for item in payload]
          return Value(list_value=ListValue(values=items))
      if isinstance(payload, dict):
          members = {key: json_to_value(item) for key, item in payload.items()}
          return Value(struct_value=Struct(fields=members))
      if isinstance(payload, (datetime, date)):
          return Value(string_value=to_jsonable_python(payload))
      raise ValueError(f"Not supported json value: {payload}")  # pragma: no cover
  def value_to_json(value: Union[Value, dict[str, Any]]) -> Any:
      """Convert a gRPC ``Value`` into the equivalent plain Python object.

      The message is first turned into a dict with ``MessageToDict`` (camelCase keys,
      e.g. ``{"integerValue": "5"}``). Nested struct/list members are already dicts,
      so the recursive calls receive that dict form directly.

      Args:
          value: a ``Value`` message, or its ``MessageToDict`` representation.

      Returns:
          int, float, str, bool, dict, list or None, mirroring the ``Value`` kind.

      Raises:
          ValueError: if none of the known kinds is present.
      """
      if isinstance(value, Value):
          value_ = MessageToDict(value, preserving_proto_field_name=False)
      else:
          value_ = value
      if "integerValue" in value_:
          # by default int are represented as string for precision
          # But in python it is OK to just use `int`
          return int(value_["integerValue"])
      if "doubleValue" in value_:
          # MessageToDict encodes non-finite doubles as the strings "NaN", "Infinity"
          # and "-Infinity"; float() restores them and leaves finite values unchanged.
          return float(value_["doubleValue"])
      if "stringValue" in value_:
          return value_["stringValue"]
      if "boolValue" in value_:
          return value_["boolValue"]
      if "structValue" in value_:
          fields = value_["structValue"].get("fields", {})
          return {key: value_to_json(val) for key, val in fields.items()}
      if "listValue" in value_:
          values = value_["listValue"].get("values", [])
          return [value_to_json(val) for val in values]
      if "nullValue" in value_:
          return None
      raise ValueError(f"Not supported value: {value_}")  # pragma: no cover
  def payload_to_grpc(payload: dict[str, Any]) -> dict[str, Value]:
      """Convert each value of a JSON-like payload into a gRPC ``Value``, keeping the keys."""
      return {name: json_to_value(item) for name, item in payload.items()}
  def grpc_to_payload(grpc_: MessageMap[str, Value]) -> dict[str, Any]:
      """Convert a map of gRPC ``Value`` messages into a plain dict, keeping the keys."""
      return {name: value_to_json(item) for name, item in grpc_.items()}
  def grpc_payload_schema_to_field_type(model: grpc.PayloadSchemaType) -> grpc.FieldType:
      """Map a ``PayloadSchemaType`` enum value to the corresponding ``FieldType`` value.

      Raises:
          ValueError: if ``model`` matches none of the schema types listed below.
      """
      schema_enum = grpc.PayloadSchemaType
      field_enum = grpc.FieldType
      correspondence = (
          (schema_enum.Keyword, field_enum.FieldTypeKeyword),
          (schema_enum.Integer, field_enum.FieldTypeInteger),
          (schema_enum.Float, field_enum.FieldTypeFloat),
          (schema_enum.Bool, field_enum.FieldTypeBool),
          (schema_enum.Geo, field_enum.FieldTypeGeo),
          (schema_enum.Text, field_enum.FieldTypeText),
          (schema_enum.Datetime, field_enum.FieldTypeDatetime),
          (schema_enum.Uuid, field_enum.FieldTypeUuid),
      )
      for schema_type, field_type in correspondence:
          if model == schema_type:
              return field_type
      raise ValueError(f"invalid PayloadSchemaType model: {model}")  # pragma: no cover
  def grpc_field_type_to_payload_schema(model: grpc.FieldType) -> grpc.PayloadSchemaType:
      """Map a ``FieldType`` enum value back to the corresponding ``PayloadSchemaType`` value.

      Raises:
          ValueError: if ``model`` matches none of the field types listed below.
      """
      field_enum = grpc.FieldType
      schema_enum = grpc.PayloadSchemaType
      correspondence = (
          (field_enum.FieldTypeKeyword, schema_enum.Keyword),
          (field_enum.FieldTypeInteger, schema_enum.Integer),
          (field_enum.FieldTypeFloat, schema_enum.Float),
          (field_enum.FieldTypeBool, schema_enum.Bool),
          (field_enum.FieldTypeGeo, schema_enum.Geo),
          (field_enum.FieldTypeText, schema_enum.Text),
          (field_enum.FieldTypeDatetime, schema_enum.Datetime),
          (field_enum.FieldTypeUuid, schema_enum.Uuid),
      )
      for field_type, schema_type in correspondence:
          if model == field_type:
              return schema_type
      raise ValueError(f"invalid FieldType model: {model}")  # pragma: no cover
  117. class GrpcToRest:
  @classmethod
  def convert_condition(cls, model: grpc.Condition) -> rest.Condition:
      """Convert a gRPC ``Condition`` into its REST counterpart.

      Dispatches on the ``condition_one_of`` oneof and delegates the selected
      sub-message to the matching ``convert_*`` classmethod.

      Raises:
          ValueError: if no variant is set, or the variant is not handled here.
      """
      variant = model.WhichOneof("condition_one_of")
      if variant is None:
          raise ValueError(f"invalid Condition model: {model}")  # pragma: no cover
      inner = getattr(model, variant)
      # Converter names are resolved lazily so only the selected one is looked up.
      converter_name = {
          "field": "convert_field_condition",
          "filter": "convert_filter",
          "has_id": "convert_has_id_condition",
          "has_vector": "convert_has_vector_condition",
          "is_empty": "convert_is_empty_condition",
          "is_null": "convert_is_null_condition",
          "nested": "convert_nested_condition",
      }.get(variant)
      if converter_name is None:
          raise ValueError(f"invalid Condition model: {model}")  # pragma: no cover
      return getattr(cls, converter_name)(inner)
  @classmethod
  def convert_filter(cls, model: grpc.Filter) -> rest.Filter:
      """Convert a gRPC ``Filter`` by converting each condition in every clause.

      ``min_should`` is only converted when it is present on ``model``; otherwise it is None.
      """
      convert = cls.convert_condition
      must = [convert(cond) for cond in model.must]
      should = [convert(cond) for cond in model.should]
      must_not = [convert(cond) for cond in model.must_not]
      min_should = None
      if model.HasField("min_should"):
          min_should = rest.MinShould(
              conditions=[convert(cond) for cond in model.min_should.conditions],
              min_count=model.min_should.min_count,
          )
      return rest.Filter(must=must, should=should, must_not=must_not, min_should=min_should)
  @classmethod
  def convert_range(cls, model: grpc.Range) -> rest.Range:
      """Convert a gRPC numeric ``Range``; any bound not set on ``model`` becomes None."""
      bounds = {
          bound: getattr(model, bound) if model.HasField(bound) else None
          for bound in ("gt", "gte", "lt", "lte")
      }
      return rest.Range(**bounds)
  @classmethod
  def convert_timestamp(cls, model: Timestamp) -> datetime:
      """Convert a protobuf ``Timestamp`` into a ``datetime`` with ``tzinfo`` set to UTC."""
      utc_moment = model.ToDatetime(tzinfo=timezone.utc)
      return utc_moment
  @classmethod
  def convert_datetime_range(cls, model: grpc.DatetimeRange) -> rest.DatetimeRange:
      """Convert a gRPC ``DatetimeRange``; set bounds go through ``convert_timestamp``.

      Bounds that are not set on ``model`` become None.
      """
      bounds = {}
      for bound in ("gt", "gte", "lt", "lte"):
          if model.HasField(bound):
              bounds[bound] = cls.convert_timestamp(getattr(model, bound))
          else:
              bounds[bound] = None
      return rest.DatetimeRange(**bounds)
  @classmethod
  def convert_geo_radius(cls, model: grpc.GeoRadius) -> rest.GeoRadius:
      """Convert a gRPC ``GeoRadius`` (center point plus radius)."""
      center = cls.convert_geo_point(model.center)
      return rest.GeoRadius(center=center, radius=model.radius)
  @classmethod
  def convert_geo_line_string(cls, model: grpc.GeoLineString) -> rest.GeoLineString:
      """Convert a gRPC ``GeoLineString`` point by point, preserving order."""
      converted = [cls.convert_geo_point(point_) for point_ in model.points]
      return rest.GeoLineString(points=converted)
  @classmethod
  def convert_geo_polygon(cls, model: grpc.GeoPolygon) -> rest.GeoPolygon:
      """Convert a gRPC ``GeoPolygon``; an empty ``interiors`` list becomes None."""
      exterior = cls.convert_geo_line_string(model.exterior)
      # An empty comprehension result is falsy, so `or None` maps "no interiors" to None.
      interiors = [cls.convert_geo_line_string(ring) for ring in model.interiors] or None
      return rest.GeoPolygon(exterior=exterior, interiors=interiors)
  @classmethod
  def convert_collection_description(
      cls, model: grpc.CollectionDescription
  ) -> rest.CollectionDescription:
      """Convert a gRPC ``CollectionDescription``; only the collection name is carried over."""
      collection_name = model.name
      return rest.CollectionDescription(name=collection_name)
  @classmethod
  def convert_collection_info(cls, model: grpc.CollectionInfo) -> rest.CollectionInfo:
      """Convert gRPC ``CollectionInfo`` into the REST model.

      ``vectors_count`` is only forwarded when it is present on ``model``; a falsy
      ``indexed_vectors_count`` is reported as 0.
      """
      config = cls.convert_collection_config(model.config)
      optimizer_status = cls.convert_optimizer_status(model.optimizer_status)
      payload_schema = cls.convert_payload_schema(model.payload_schema)
      status = cls.convert_collection_status(model.status)
      vectors_count = model.vectors_count if model.HasField("vectors_count") else None
      return rest.CollectionInfo(
          config=config,
          optimizer_status=optimizer_status,
          payload_schema=payload_schema,
          segments_count=model.segments_count,
          status=status,
          vectors_count=vectors_count,
          points_count=model.points_count,
          indexed_vectors_count=model.indexed_vectors_count or 0,
      )
  @classmethod
  def convert_optimizer_status(cls, model: grpc.OptimizerStatus) -> rest.OptimizersStatus:
      """Return ``OK`` when ``model.ok`` is truthy, otherwise an error status carrying ``model.error``."""
      return (
          rest.OptimizersStatusOneOf.OK
          if model.ok
          else rest.OptimizersStatusOneOf1(error=model.error)
      )
  @classmethod
  def convert_collection_config(cls, model: grpc.CollectionConfig) -> rest.CollectionConfig:
      """Convert a gRPC ``CollectionConfig``.

      The quantization and strict-mode sections are only converted when present on
      ``model``; otherwise they are passed as None.
      """
      hnsw_config = cls.convert_hnsw_config(model.hnsw_config)
      optimizer_config = cls.convert_optimizer_config(model.optimizer_config)
      params = cls.convert_collection_params(model.params)
      wal_config = cls.convert_wal_config(model.wal_config)
      quantization_config = None
      if model.HasField("quantization_config"):
          quantization_config = cls.convert_quantization_config(model.quantization_config)
      strict_mode_config = None
      if model.HasField("strict_mode_config"):
          strict_mode_config = cls.convert_strict_mode_config_output(model.strict_mode_config)
      return rest.CollectionConfig(
          hnsw_config=hnsw_config,
          optimizer_config=optimizer_config,
          params=params,
          wal_config=wal_config,
          quantization_config=quantization_config,
          strict_mode_config=strict_mode_config,
      )
  @classmethod
  def convert_hnsw_config_diff(cls, model: grpc.HnswConfigDiff) -> rest.HnswConfigDiff:
      """Convert a gRPC ``HnswConfigDiff``; fields not set on ``model`` become None."""
      optional_fields = (
          "ef_construct",
          "m",
          "full_scan_threshold",
          "max_indexing_threads",
          "on_disk",
          "payload_m",
      )
      diff = {
          name: getattr(model, name) if model.HasField(name) else None
          for name in optional_fields
      }
      return rest.HnswConfigDiff(**diff)
  @classmethod
  def convert_hnsw_config(cls, model: grpc.HnswConfigDiff) -> rest.HnswConfig:
      """Build a REST ``HnswConfig`` from a gRPC ``HnswConfigDiff``.

      Fields not set on ``model`` are passed as None.
      """
      settings = {}
      for name in (
          "ef_construct",
          "m",
          "full_scan_threshold",
          "max_indexing_threads",
          "on_disk",
          "payload_m",
      ):
          settings[name] = getattr(model, name) if model.HasField(name) else None
      return rest.HnswConfig(**settings)
  @classmethod
  def convert_max_optimization_threads(
      cls, model: grpc.MaxOptimizationThreads
  ) -> rest.MaxOptimizationThreads:
      """Convert the ``variant`` oneof of ``MaxOptimizationThreads``.

      Returns the integer for ``value``, or ``AUTO`` for ``setting == Auto``.

      Raises:
          ValueError: if no variant is set or the setting is not ``Auto``.
      """
      variant = model.WhichOneof("variant")
      if variant == "value":
          return model.value
      if variant == "setting" and model.setting == grpc.MaxOptimizationThreads.Setting.Auto:
          return rest.MaxOptimizationThreadsSetting.AUTO
      raise ValueError(f"invalid MaxOptimizationThreads model: {model}")  # pragma: no cover
  @classmethod
  def convert_optimizer_config(cls, model: grpc.OptimizersConfigDiff) -> rest.OptimizersConfig:
      """Convert a gRPC ``OptimizersConfigDiff`` into a REST ``OptimizersConfig``.

      ``deprecated_max_optimization_threads`` takes precedence over
      ``max_optimization_threads`` when both are set. An ``AUTO`` thread setting is
      passed on as None. Other fields not set on ``model`` become None.
      """
      if model.HasField("deprecated_max_optimization_threads"):
          threads = model.deprecated_max_optimization_threads
      elif model.HasField("max_optimization_threads"):
          threads = cls.convert_max_optimization_threads(model.max_optimization_threads)
      else:
          threads = None
      if threads == rest.MaxOptimizationThreadsSetting.AUTO:
          threads = None

      optional_fields = (
          "default_segment_number",
          "deleted_threshold",
          "flush_interval_sec",
          "indexing_threshold",
          "max_segment_size",
          "memmap_threshold",
          "vacuum_min_vector_number",
      )
      settings = {
          name: getattr(model, name) if model.HasField(name) else None
          for name in optional_fields
      }
      settings["max_optimization_threads"] = threads
      return rest.OptimizersConfig(**settings)
  316. @classmethod
  317. def convert_distance(cls, model: grpc.Distance) -> rest.Distance:
  318. if model == grpc.Distance.Cosine:
  319. return rest.Distance.COSINE
  320. elif model == grpc.Distance.Euclid:
  321. return rest.Distance.EUCLID
  322. elif model == grpc.Distance.Manhattan:
  323. return rest.Distance.MANHATTAN
  324. elif model == grpc.Distance.Dot:
  325. return rest.Distance.DOT
  326. else:
  327. raise ValueError(f"invalid Distance model: {model}") # pragma: no cover
  328. @classmethod
  329. def convert_wal_config(cls, model: grpc.WalConfigDiff) -> rest.WalConfig:
  330. return rest.WalConfig(
  331. wal_capacity_mb=model.wal_capacity_mb if model.HasField("wal_capacity_mb") else None,
  332. wal_segments_ahead=(
  333. model.wal_segments_ahead if model.HasField("wal_segments_ahead") else None
  334. ),
  335. )
  336. @classmethod
  337. def convert_payload_schema(
  338. cls, model: dict[str, grpc.PayloadSchemaInfo]
  339. ) -> dict[str, rest.PayloadIndexInfo]:
  340. return {key: cls.convert_payload_schema_info(info) for key, info in model.items()}
  341. @classmethod
  342. def convert_payload_schema_info(cls, model: grpc.PayloadSchemaInfo) -> rest.PayloadIndexInfo:
  343. return rest.PayloadIndexInfo(
  344. data_type=cls.convert_payload_schema_type(model.data_type),
  345. params=(
  346. cls.convert_payload_schema_params(model.params)
  347. if model.HasField("params")
  348. else None
  349. ),
  350. points=model.points,
  351. )
  352. @classmethod
  353. def convert_payload_schema_params(
  354. cls, model: grpc.PayloadIndexParams
  355. ) -> rest.PayloadSchemaParams:
  356. if model.HasField("text_index_params"):
  357. text_index_params = model.text_index_params
  358. return cls.convert_text_index_params(text_index_params)
  359. if model.HasField("integer_index_params"):
  360. integer_index_params = model.integer_index_params
  361. return cls.convert_integer_index_params(integer_index_params)
  362. if model.HasField("keyword_index_params"):
  363. keyword_index_params = model.keyword_index_params
  364. return cls.convert_keyword_index_params(keyword_index_params)
  365. if model.HasField("float_index_params"):
  366. float_index_params = model.float_index_params
  367. return cls.convert_float_index_params(float_index_params)
  368. if model.HasField("geo_index_params"):
  369. geo_index_params = model.geo_index_params
  370. return cls.convert_geo_index_params(geo_index_params)
  371. if model.HasField("bool_index_params"):
  372. bool_index_params = model.bool_index_params
  373. return cls.convert_bool_index_params(bool_index_params)
  374. if model.HasField("datetime_index_params"):
  375. datetime_index_params = model.datetime_index_params
  376. return cls.convert_datetime_index_params(datetime_index_params)
  377. if model.HasField("uuid_index_params"):
  378. uuid_index_params = model.uuid_index_params
  379. return cls.convert_uuid_index_params(uuid_index_params)
  380. raise ValueError(f"invalid PayloadIndexParams model: {model}") # pragma: no cover
  381. @classmethod
  382. def convert_payload_schema_type(cls, model: grpc.PayloadSchemaType) -> rest.PayloadSchemaType:
  383. if model == grpc.PayloadSchemaType.Float:
  384. return rest.PayloadSchemaType.FLOAT
  385. elif model == grpc.PayloadSchemaType.Geo:
  386. return rest.PayloadSchemaType.GEO
  387. elif model == grpc.PayloadSchemaType.Integer:
  388. return rest.PayloadSchemaType.INTEGER
  389. elif model == grpc.PayloadSchemaType.Keyword:
  390. return rest.PayloadSchemaType.KEYWORD
  391. elif model == grpc.PayloadSchemaType.Bool:
  392. return rest.PayloadSchemaType.BOOL
  393. elif model == grpc.PayloadSchemaType.Text:
  394. return rest.PayloadSchemaType.TEXT
  395. elif model == grpc.PayloadSchemaType.Datetime:
  396. return rest.PayloadSchemaType.DATETIME
  397. elif model == grpc.PayloadSchemaType.Uuid:
  398. return rest.PayloadSchemaType.UUID
  399. else:
  400. raise ValueError(f"invalid PayloadSchemaType model: {model}") # pragma: no cover
  401. @classmethod
  402. def convert_collection_status(cls, model: grpc.CollectionStatus) -> rest.CollectionStatus:
  403. if model == grpc.CollectionStatus.Green:
  404. return rest.CollectionStatus.GREEN
  405. elif model == grpc.CollectionStatus.Yellow:
  406. return rest.CollectionStatus.YELLOW
  407. elif model == grpc.CollectionStatus.Red:
  408. return rest.CollectionStatus.RED
  409. elif model == grpc.CollectionStatus.Grey:
  410. return rest.CollectionStatus.GREY
  411. raise ValueError(f"invalid CollectionStatus model: {model}") # pragma: no cover
  412. @classmethod
  413. def convert_update_result(cls, model: grpc.UpdateResult) -> rest.UpdateResult:
  414. return rest.UpdateResult(
  415. operation_id=model.operation_id,
  416. status=cls.convert_update_status(model.status),
  417. )
  418. @classmethod
  419. def convert_update_status(cls, model: grpc.UpdateStatus) -> rest.UpdateStatus:
  420. if model == grpc.UpdateStatus.Acknowledged:
  421. return rest.UpdateStatus.ACKNOWLEDGED
  422. elif model == grpc.UpdateStatus.Completed:
  423. return rest.UpdateStatus.COMPLETED
  424. else:
  425. raise ValueError(f"invalid UpdateStatus model: {model}") # pragma: no cover
  426. @classmethod
  427. def convert_has_id_condition(cls, model: grpc.HasIdCondition) -> rest.HasIdCondition:
  428. return rest.HasIdCondition(has_id=[cls.convert_point_id(idx) for idx in model.has_id])
  429. @classmethod
  430. def convert_has_vector_condition(
  431. cls, model: grpc.HasVectorCondition
  432. ) -> rest.HasVectorCondition:
  433. return rest.HasVectorCondition(has_vector=model.has_vector)
  434. @classmethod
  435. def convert_point_id(cls, model: grpc.PointId) -> rest.ExtendedPointId:
  436. name = model.WhichOneof("point_id_options")
  437. if name == "num":
  438. return model.num
  439. if name == "uuid":
  440. return model.uuid
  441. raise ValueError(f"invalid PointId model: {model}") # pragma: no cover
  442. @classmethod
  443. def convert_delete_alias(cls, model: grpc.DeleteAlias) -> rest.DeleteAlias:
  444. return rest.DeleteAlias(alias_name=model.alias_name)
  445. @classmethod
  446. def convert_rename_alias(cls, model: grpc.RenameAlias) -> rest.RenameAlias:
  447. return rest.RenameAlias(
  448. old_alias_name=model.old_alias_name, new_alias_name=model.new_alias_name
  449. )
  450. @classmethod
  451. def convert_is_empty_condition(cls, model: grpc.IsEmptyCondition) -> rest.IsEmptyCondition:
  452. return rest.IsEmptyCondition(is_empty=rest.PayloadField(key=model.key))
  453. @classmethod
  454. def convert_is_null_condition(cls, model: grpc.IsNullCondition) -> rest.IsNullCondition:
  455. return rest.IsNullCondition(is_null=rest.PayloadField(key=model.key))
  456. @classmethod
  457. def convert_nested_condition(cls, model: grpc.NestedCondition) -> rest.NestedCondition:
  458. return rest.NestedCondition(
  459. nested=rest.Nested(
  460. key=model.key,
  461. filter=cls.convert_filter(model.filter),
  462. )
  463. )
  464. @classmethod
  465. def convert_search_params(cls, model: grpc.SearchParams) -> rest.SearchParams:
  466. return rest.SearchParams(
  467. hnsw_ef=model.hnsw_ef if model.HasField("hnsw_ef") else None,
  468. exact=model.exact if model.HasField("exact") else None,
  469. quantization=(
  470. cls.convert_quantization_search_params(model.quantization)
  471. if model.HasField("quantization")
  472. else None
  473. ),
  474. indexed_only=model.indexed_only if model.HasField("indexed_only") else None,
  475. )
  476. @classmethod
  477. def convert_create_alias(cls, model: grpc.CreateAlias) -> rest.CreateAlias:
  478. return rest.CreateAlias(collection_name=model.collection_name, alias_name=model.alias_name)
  479. @classmethod
  480. def convert_order_value(cls, model: grpc.OrderValue) -> rest.OrderValue:
  481. name = model.WhichOneof("variant")
  482. if name is None:
  483. raise ValueError(f"invalid OrderValue model: {model}") # pragma: no cover
  484. val = getattr(model, name)
  485. if name == "int":
  486. return val
  487. if name == "float":
  488. return val
  489. raise ValueError(f"invalid OrderValue model: {model}") # pragma: no cover
  490. @classmethod
  491. def convert_scored_point(cls, model: grpc.ScoredPoint) -> rest.ScoredPoint:
  492. return construct(
  493. rest.ScoredPoint,
  494. id=cls.convert_point_id(model.id),
  495. payload=cls.convert_payload(model.payload) if has_field(model, "payload") else None,
  496. score=model.score,
  497. vector=(
  498. cls.convert_vectors_output(model.vectors) if model.HasField("vectors") else None
  499. ),
  500. version=model.version,
  501. shard_key=(
  502. cls.convert_shard_key(model.shard_key) if model.HasField("shard_key") else None
  503. ),
  504. order_value=(
  505. cls.convert_order_value(model.order_value)
  506. if model.HasField("order_value")
  507. else None
  508. ),
  509. )
  510. @classmethod
  511. def convert_payload(cls, model: "MessageMapContainer") -> rest.Payload:
  512. return dict((key, value_to_json(model[key])) for key in model)
  513. @classmethod
  514. def convert_values_count(cls, model: grpc.ValuesCount) -> rest.ValuesCount:
  515. return rest.ValuesCount(
  516. gt=model.gt if model.HasField("gt") else None,
  517. gte=model.gte if model.HasField("gte") else None,
  518. lt=model.lt if model.HasField("lt") else None,
  519. lte=model.lte if model.HasField("lte") else None,
  520. )
  521. @classmethod
  522. def convert_geo_bounding_box(cls, model: grpc.GeoBoundingBox) -> rest.GeoBoundingBox:
  523. return rest.GeoBoundingBox(
  524. bottom_right=cls.convert_geo_point(model.bottom_right),
  525. top_left=cls.convert_geo_point(model.top_left),
  526. )
  527. @classmethod
  528. def convert_point_struct(cls, model: grpc.PointStruct) -> rest.PointStruct:
  529. return rest.PointStruct(
  530. id=cls.convert_point_id(model.id),
  531. payload=cls.convert_payload(model.payload),
  532. vector=cls.convert_vectors(model.vectors) if model.HasField("vectors") else None,
  533. )
  534. @classmethod
  535. def convert_field_condition(cls, model: grpc.FieldCondition) -> rest.FieldCondition:
  536. geo_bounding_box = (
  537. cls.convert_geo_bounding_box(model.geo_bounding_box)
  538. if model.HasField("geo_bounding_box")
  539. else None
  540. )
  541. geo_radius = (
  542. cls.convert_geo_radius(model.geo_radius) if model.HasField("geo_radius") else None
  543. )
  544. geo_polygon = (
  545. cls.convert_geo_polygon(model.geo_polygon) if model.HasField("geo_polygon") else None
  546. )
  547. match = cls.convert_match(model.match) if model.HasField("match") else None
  548. range_: Optional[rest.RangeInterface] = None
  549. if model.HasField("range"):
  550. range_ = cls.convert_range(model.range)
  551. elif model.HasField("datetime_range"):
  552. range_ = cls.convert_datetime_range(model.datetime_range)
  553. values_count = (
  554. cls.convert_values_count(model.values_count)
  555. if model.HasField("values_count")
  556. else None
  557. )
  558. is_empty = model.is_empty if model.HasField("is_empty") else None
  559. is_null = model.is_null if model.HasField("is_null") else None
  560. return rest.FieldCondition(
  561. key=model.key,
  562. geo_bounding_box=geo_bounding_box,
  563. geo_radius=geo_radius,
  564. geo_polygon=geo_polygon,
  565. match=match,
  566. range=range_,
  567. values_count=values_count,
  568. is_empty=is_empty,
  569. is_null=is_null,
  570. )
  571. @classmethod
  572. def convert_match(cls, model: grpc.Match) -> rest.Match:
  573. name = model.WhichOneof("match_value")
  574. if name is None:
  575. raise ValueError(f"invalid Match model: {model}") # pragma: no cover
  576. val = getattr(model, name)
  577. if name == "integer":
  578. return rest.MatchValue(value=val)
  579. if name == "boolean":
  580. return rest.MatchValue(value=val)
  581. if name == "keyword":
  582. return rest.MatchValue(value=val)
  583. if name == "text":
  584. return rest.MatchText(text=val)
  585. if name == "keywords":
  586. return rest.MatchAny(any=list(val.strings))
  587. if name == "integers":
  588. return rest.MatchAny(any=list(val.integers))
  589. if name == "except_keywords":
  590. return rest.MatchExcept(**{"except": list(val.strings)})
  591. if name == "except_integers":
  592. return rest.MatchExcept(**{"except": list(val.integers)})
  593. if name == "phrase":
  594. return rest.MatchPhrase(phrase=val)
  595. raise ValueError(f"invalid Match model: {model}") # pragma: no cover
  596. @classmethod
  597. def convert_wal_config_diff(cls, model: grpc.WalConfigDiff) -> rest.WalConfigDiff:
  598. return rest.WalConfigDiff(
  599. wal_capacity_mb=model.wal_capacity_mb if model.HasField("wal_capacity_mb") else None,
  600. wal_segments_ahead=(
  601. model.wal_segments_ahead if model.HasField("wal_segments_ahead") else None
  602. ),
  603. )
  604. @classmethod
  605. def convert_collection_params(cls, model: grpc.CollectionParams) -> rest.CollectionParams:
  606. return rest.CollectionParams(
  607. vectors=(
  608. cls.convert_vectors_config(model.vectors_config)
  609. if model.HasField("vectors_config")
  610. else None
  611. ),
  612. shard_number=model.shard_number,
  613. on_disk_payload=model.on_disk_payload,
  614. replication_factor=(
  615. model.replication_factor if model.HasField("replication_factor") else None
  616. ),
  617. read_fan_out_factor=(
  618. model.read_fan_out_factor if model.HasField("read_fan_out_factor") else None
  619. ),
  620. write_consistency_factor=(
  621. model.write_consistency_factor
  622. if model.HasField("write_consistency_factor")
  623. else None
  624. ),
  625. sparse_vectors=(
  626. cls.convert_sparse_vector_config(model.sparse_vectors_config)
  627. if model.HasField("sparse_vectors_config")
  628. else None
  629. ),
  630. sharding_method=(
  631. cls.convert_sharding_method(model.sharding_method)
  632. if model.HasField("sharding_method")
  633. else None
  634. ),
  635. )
  636. @classmethod
  637. def convert_optimizers_config_diff(
  638. cls, model: grpc.OptimizersConfigDiff
  639. ) -> rest.OptimizersConfigDiff:
  640. max_optimization_threads = None
  641. if model.HasField("deprecated_max_optimization_threads"):
  642. max_optimization_threads = model.deprecated_max_optimization_threads
  643. elif model.HasField("max_optimization_threads"):
  644. max_optimization_threads = cls.convert_max_optimization_threads(
  645. model.max_optimization_threads
  646. )
  647. return rest.OptimizersConfigDiff(
  648. default_segment_number=(
  649. model.default_segment_number if model.HasField("default_segment_number") else None
  650. ),
  651. deleted_threshold=(
  652. model.deleted_threshold if model.HasField("deleted_threshold") else None
  653. ),
  654. flush_interval_sec=(
  655. model.flush_interval_sec if model.HasField("flush_interval_sec") else None
  656. ),
  657. indexing_threshold=(
  658. model.indexing_threshold if model.HasField("indexing_threshold") else None
  659. ),
  660. max_optimization_threads=max_optimization_threads,
  661. max_segment_size=(
  662. model.max_segment_size if model.HasField("max_segment_size") else None
  663. ),
  664. memmap_threshold=(
  665. model.memmap_threshold if model.HasField("memmap_threshold") else None
  666. ),
  667. vacuum_min_vector_number=(
  668. model.vacuum_min_vector_number
  669. if model.HasField("vacuum_min_vector_number")
  670. else None
  671. ),
  672. )
  673. @classmethod
  674. def convert_update_collection(cls, model: grpc.UpdateCollection) -> rest.UpdateCollection:
  675. return rest.UpdateCollection(
  676. vectors=(
  677. cls.convert_vectors_config_diff(model.vectors_config)
  678. if model.HasField("vectors_config")
  679. else None
  680. ),
  681. optimizers_config=(
  682. cls.convert_optimizers_config_diff(model.optimizers_config)
  683. if model.HasField("optimizers_config")
  684. else None
  685. ),
  686. params=(
  687. cls.convert_collection_params_diff(model.params)
  688. if model.HasField("params")
  689. else None
  690. ),
  691. hnsw_config=(
  692. cls.convert_hnsw_config_diff(model.hnsw_config)
  693. if model.HasField("hnsw_config")
  694. else None
  695. ),
  696. quantization_config=(
  697. cls.convert_quantization_config_diff(model.quantization_config)
  698. if model.HasField("quantization_config")
  699. else None
  700. ),
  701. )
  702. @classmethod
  703. def convert_geo_point(cls, model: grpc.GeoPoint) -> rest.GeoPoint:
  704. return rest.GeoPoint(
  705. lat=model.lat,
  706. lon=model.lon,
  707. )
  708. @classmethod
  709. def convert_alias_operations(cls, model: grpc.AliasOperations) -> rest.AliasOperations:
  710. name = model.WhichOneof("action")
  711. if name is None:
  712. raise ValueError(f"invalid AliasOperations model: {model}") # pragma: no cover
  713. val = getattr(model, name)
  714. if name == "rename_alias":
  715. return rest.RenameAliasOperation(rename_alias=cls.convert_rename_alias(val))
  716. if name == "create_alias":
  717. return rest.CreateAliasOperation(create_alias=cls.convert_create_alias(val))
  718. if name == "delete_alias":
  719. return rest.DeleteAliasOperation(delete_alias=cls.convert_delete_alias(val))
  720. raise ValueError(f"invalid AliasOperations model: {model}") # pragma: no cover
  721. @classmethod
  722. def convert_alias_description(cls, model: grpc.AliasDescription) -> rest.AliasDescription:
  723. return rest.AliasDescription(
  724. alias_name=model.alias_name,
  725. collection_name=model.collection_name,
  726. )
  727. @classmethod
  728. def convert_points_selector(
  729. cls,
  730. model: grpc.PointsSelector,
  731. shard_key_selector: Optional[grpc.ShardKeySelector] = None,
  732. ) -> rest.PointsSelector:
  733. name = model.WhichOneof("points_selector_one_of")
  734. if name is None:
  735. raise ValueError(f"invalid PointsSelector model: {model}") # pragma: no cover
  736. val = getattr(model, name)
  737. if name == "points":
  738. return rest.PointIdsList(
  739. points=[cls.convert_point_id(point) for point in val.ids],
  740. shard_key=shard_key_selector,
  741. )
  742. if name == "filter":
  743. return rest.FilterSelector(
  744. filter=cls.convert_filter(val),
  745. shard_key=shard_key_selector,
  746. )
  747. raise ValueError(f"invalid PointsSelector model: {model}") # pragma: no cover
  748. @classmethod
  749. def convert_with_payload_selector(
  750. cls, model: grpc.WithPayloadSelector
  751. ) -> rest.WithPayloadInterface:
  752. name = model.WhichOneof("selector_options")
  753. if name is None:
  754. raise ValueError(f"invalid WithPayloadSelector model: {model}") # pragma: no cover
  755. val = getattr(model, name)
  756. if name == "enable":
  757. return val
  758. if name == "include":
  759. return list(val.fields)
  760. if name == "exclude":
  761. return rest.PayloadSelectorExclude(exclude=list(val.fields))
  762. raise ValueError(f"invalid WithPayloadSelector model: {model}") # pragma: no cover
  763. @classmethod
  764. def convert_with_payload_interface(
  765. cls, model: grpc.WithPayloadSelector
  766. ) -> rest.WithPayloadInterface:
  767. return cls.convert_with_payload_selector(model)
  768. @classmethod
  769. def convert_retrieved_point(cls, model: grpc.RetrievedPoint) -> rest.Record:
  770. return rest.Record(
  771. id=cls.convert_point_id(model.id),
  772. payload=cls.convert_payload(model.payload),
  773. vector=(
  774. cls.convert_vectors_output(model.vectors) if model.HasField("vectors") else None
  775. ),
  776. shard_key=(
  777. cls.convert_shard_key(model.shard_key) if model.HasField("shard_key") else None
  778. ),
  779. order_value=(
  780. cls.convert_order_value(model.order_value)
  781. if model.HasField("order_value")
  782. else None
  783. ),
  784. )
  785. @classmethod
  786. def convert_record(cls, model: grpc.RetrievedPoint) -> rest.Record:
  787. return cls.convert_retrieved_point(model)
  788. @classmethod
  789. def convert_count_result(cls, model: grpc.CountResult) -> rest.CountResult:
  790. return rest.CountResult(count=model.count)
  791. @classmethod
  792. def convert_snapshot_description(
  793. cls, model: grpc.SnapshotDescription
  794. ) -> rest.SnapshotDescription:
  795. return rest.SnapshotDescription(
  796. name=model.name,
  797. creation_time=(
  798. model.creation_time.ToDatetime().isoformat()
  799. if model.HasField("creation_time")
  800. else None
  801. ),
  802. size=model.size,
  803. )
  804. @classmethod
  805. def convert_datatype(cls, model: grpc.Datatype) -> rest.Datatype:
  806. if model == grpc.Datatype.Float32:
  807. return rest.Datatype.FLOAT32
  808. elif model == grpc.Datatype.Uint8:
  809. return rest.Datatype.UINT8
  810. elif model == grpc.Datatype.Float16:
  811. return rest.Datatype.FLOAT16
  812. else:
  813. raise ValueError(f"invalid Datatype model: {model}") # pragma: no cover
  814. @classmethod
  815. def convert_vector_params(cls, model: grpc.VectorParams) -> rest.VectorParams:
  816. return rest.VectorParams(
  817. size=model.size,
  818. distance=cls.convert_distance(model.distance),
  819. hnsw_config=(
  820. cls.convert_hnsw_config_diff(model.hnsw_config)
  821. if model.HasField("hnsw_config")
  822. else None
  823. ),
  824. quantization_config=(
  825. cls.convert_quantization_config(model.quantization_config)
  826. if model.HasField("quantization_config")
  827. else None
  828. ),
  829. on_disk=model.on_disk if model.HasField("on_disk") else None,
  830. datatype=cls.convert_datatype(model.datatype) if model.HasField("datatype") else None,
  831. multivector_config=(
  832. cls.convert_multivector_config(model.multivector_config)
  833. if model.HasField("multivector_config")
  834. else None
  835. ),
  836. )
  837. @classmethod
  838. def convert_multivector_config(cls, model: grpc.MultiVectorConfig) -> rest.MultiVectorConfig:
  839. return rest.MultiVectorConfig(
  840. comparator=cls.convert_multivector_comparator(model.comparator)
  841. )
  842. @classmethod
  843. def convert_multivector_comparator(
  844. cls, model: grpc.MultiVectorComparator
  845. ) -> rest.MultiVectorComparator:
  846. if model == grpc.MultiVectorComparator.MaxSim:
  847. return rest.MultiVectorComparator.MAX_SIM
  848. raise ValueError(f"invalid MultiVectorComparator model: {model}") # pragma: no cover
  849. @classmethod
  850. def convert_vectors_config(cls, model: grpc.VectorsConfig) -> rest.VectorsConfig:
  851. name = model.WhichOneof("config")
  852. if name is None:
  853. raise ValueError(f"invalid VectorsConfig model: {model}") # pragma: no cover
  854. val = getattr(model, name)
  855. if name == "params":
  856. return cls.convert_vector_params(val)
  857. if name == "params_map":
  858. return dict(
  859. (key, cls.convert_vector_params(vec_params)) for key, vec_params in val.map.items()
  860. )
  861. raise ValueError(f"invalid VectorsConfig model: {model}") # pragma: no cover
  862. @classmethod
  863. def _convert_vector(
  864. cls, model: Union[grpc.Vector, grpc.VectorOutput]
  865. ) -> tuple[
  866. Optional[str],
  867. Union[
  868. list[float],
  869. list[list[float]],
  870. rest.SparseVector,
  871. grpc.Document,
  872. grpc.Image,
  873. grpc.InferenceObject,
  874. ],
  875. ]:
  876. """Parse common parts of vector structs
  877. Args:
  878. model: Vector or VectorOutput
  879. Returns:
  880. Tuple of name and value, name is None if the struct was parsed and returned with the converted value,
  881. otherwise it's propagated for further processing along with the raw value
  882. """
  883. name = model.WhichOneof("vector")
  884. if name is None:
  885. if model.HasField("indices"):
  886. return None, rest.SparseVector(indices=model.indices.data[:], values=model.data[:])
  887. if model.HasField("vectors_count"):
  888. vectors_count = model.vectors_count
  889. vectors = model.data
  890. step = len(vectors) // vectors_count
  891. return None, [vectors[i : i + step] for i in range(0, len(vectors), step)]
  892. return None, model.data[:]
  893. val = getattr(model, name)
  894. if name == "dense":
  895. return None, cls.convert_dense_vector(val)
  896. if name == "sparse":
  897. return None, cls.convert_sparse_vector(val)
  898. if name == "multi_dense":
  899. return None, cls.convert_multi_dense_vector(val)
  900. return name, val
  901. @classmethod
  902. def convert_vector(
  903. cls, model: grpc.Vector
  904. ) -> Union[
  905. list[float],
  906. list[list[float]],
  907. rest.SparseVector,
  908. rest.Document,
  909. rest.Image,
  910. rest.InferenceObject,
  911. ]:
  912. name, val = cls._convert_vector(model)
  913. if name is None:
  914. return val
  915. if name == "document":
  916. return cls.convert_document(val)
  917. if name == "image":
  918. return cls.convert_image(val)
  919. if name == "object":
  920. return cls.convert_inference_object(val)
  921. raise ValueError(f"invalid Vector model: {model}") # pragma: no cover
  922. @classmethod
  923. def convert_vector_output(
  924. cls, model: grpc.VectorOutput
  925. ) -> Union[list[float], list[list[float]], rest.SparseVector]:
  926. name, val = cls._convert_vector(model)
  927. if name is None:
  928. return val
  929. raise ValueError(f"invalid Vector model: {model}") # pragma: no cover
  930. @classmethod
  931. def convert_named_vectors(cls, model: grpc.NamedVectors) -> dict[str, rest.Vector]:
  932. vectors = {}
  933. for name, vector in model.vectors.items():
  934. vectors[name] = cls.convert_vector(vector)
  935. return vectors
  936. @classmethod
  937. def convert_named_vectors_output(
  938. cls, model: grpc.NamedVectorsOutput
  939. ) -> dict[str, rest.VectorOutput]:
  940. vectors = {}
  941. for name, vector in model.vectors.items():
  942. vectors[name] = cls.convert_vector_output(vector)
  943. return vectors
  944. @classmethod
  945. def convert_vectors(cls, model: grpc.Vectors) -> rest.VectorStruct:
  946. name = model.WhichOneof("vectors_options")
  947. if name is None:
  948. raise ValueError(f"invalid Vectors model: {model}") # pragma: no cover
  949. val = getattr(model, name)
  950. if name == "vector":
  951. return cls.convert_vector(val)
  952. if name == "vectors":
  953. return cls.convert_named_vectors(val)
  954. raise ValueError(f"invalid Vectors model: {model}") # pragma: no cover
  955. @classmethod
  956. def convert_vectors_output(cls, model: grpc.VectorsOutput) -> rest.VectorStructOutput:
  957. name = model.WhichOneof("vectors_options")
  958. if name is None:
  959. raise ValueError(f"invalid VectorsOutput model: {model}") # pragma: no cover
  960. val = getattr(model, name)
  961. if name == "vector":
  962. return cls.convert_vector_output(val)
  963. if name == "vectors":
  964. return cls.convert_named_vectors_output(val)
  965. raise ValueError(f"invalid VectorsOutput model: {model}") # pragma: no cover
  966. @classmethod
  967. def convert_dense_vector(cls, model: grpc.DenseVector) -> list[float]:
  968. return model.data[:]
  969. @classmethod
  970. def convert_sparse_vector(cls, model: grpc.SparseVector) -> rest.SparseVector:
  971. return rest.SparseVector(indices=model.indices[:], values=model.values[:])
  972. @classmethod
  973. def convert_multi_dense_vector(cls, model: grpc.MultiDenseVector) -> list[list[float]]:
  974. return [cls.convert_dense_vector(vector) for vector in model.vectors]
  975. @classmethod
  976. def convert_document(cls, model: grpc.Document) -> rest.Document:
  977. return rest.Document(
  978. text=model.text,
  979. model=model.model,
  980. options=grpc_to_payload(model.options),
  981. )
  982. @classmethod
  983. def convert_image(cls, model: grpc.Image) -> rest.Image:
  984. return rest.Image(
  985. image=value_to_json(model.image),
  986. model=model.model,
  987. options=grpc_to_payload(model.options),
  988. )
  989. @classmethod
  990. def convert_inference_object(cls, model: grpc.InferenceObject) -> rest.InferenceObject:
  991. return rest.InferenceObject(
  992. object=value_to_json(model.object),
  993. model=model.model,
  994. options=grpc_to_payload(model.options),
  995. )
  996. @classmethod
  997. def convert_vector_input(cls, model: grpc.VectorInput) -> rest.VectorInput:
  998. name = model.WhichOneof("variant")
  999. if name is None:
  1000. raise ValueError(f"invalid VectorInput model: {model}") # pragma: no cover
  1001. val = getattr(model, name)
  1002. if name == "id":
  1003. return cls.convert_point_id(val)
  1004. if name == "dense":
  1005. return cls.convert_dense_vector(val)
  1006. if name == "sparse":
  1007. return cls.convert_sparse_vector(val)
  1008. if name == "multi_dense":
  1009. return cls.convert_multi_dense_vector(val)
  1010. if name == "document":
  1011. return cls.convert_document(val)
  1012. if name == "image":
  1013. return cls.convert_image(val)
  1014. if name == "object":
  1015. return cls.convert_inference_object(val)
  1016. raise ValueError(f"invalid VectorInput model: {model}") # pragma: no cover
  1017. @classmethod
  1018. def convert_recommend_input(cls, model: grpc.RecommendInput) -> rest.RecommendInput:
  1019. return rest.RecommendInput(
  1020. positive=[cls.convert_vector_input(vector) for vector in model.positive],
  1021. negative=[cls.convert_vector_input(vector) for vector in model.negative],
  1022. strategy=(
  1023. cls.convert_recommend_strategy(model.strategy)
  1024. if model.HasField("strategy")
  1025. else None
  1026. ),
  1027. )
  1028. @classmethod
  1029. def convert_context_input_pair(cls, model: grpc.ContextInputPair) -> rest.ContextPair:
  1030. return rest.ContextPair(
  1031. positive=cls.convert_vector_input(model.positive),
  1032. negative=cls.convert_vector_input(model.negative),
  1033. )
  1034. @classmethod
  1035. def convert_context_input(cls, model: grpc.ContextInput) -> rest.ContextInput:
  1036. return [cls.convert_context_input_pair(pair) for pair in model.pairs]
  1037. @classmethod
  1038. def convert_discover_input(cls, model: grpc.DiscoverInput) -> rest.DiscoverInput:
  1039. return rest.DiscoverInput(
  1040. target=cls.convert_vector_input(model.target),
  1041. context=cls.convert_context_input(model.context),
  1042. )
  1043. @classmethod
  1044. def convert_fusion(cls, model: grpc.Fusion) -> rest.Fusion:
  1045. if model == grpc.Fusion.RRF:
  1046. return rest.Fusion.RRF
  1047. if model == grpc.Fusion.DBSF:
  1048. return rest.Fusion.DBSF
  1049. raise ValueError(f"invalid Fusion model: {model}") # pragma: no cover
  1050. @classmethod
  1051. def convert_sample(cls, model: grpc.Sample) -> rest.Sample:
  1052. if model == grpc.Sample.Random:
  1053. return rest.Sample.RANDOM
  1054. raise ValueError(f"invalid Sample model: {model}") # pragma: no cover
  1055. @classmethod
  1056. def convert_formula_query(cls, model: grpc.Formula) -> rest.FormulaQuery:
  1057. defaults = grpc_to_payload(model.defaults)
  1058. return rest.FormulaQuery(
  1059. formula=cls.convert_expression(model.expression), defaults=defaults
  1060. )
  1061. @classmethod
  1062. def convert_expression(cls, model: grpc.Expression) -> rest.Expression:
  1063. name = model.WhichOneof("variant")
  1064. if name is None:
  1065. raise ValueError(f"invalid Query model: {model}") # pragma: no cover
  1066. if name == "constant":
  1067. return model.constant
  1068. if name == "variable":
  1069. return model.variable
  1070. if name == "condition":
  1071. return cls.convert_condition(model.condition)
  1072. if name == "datetime":
  1073. return rest.DatetimeExpression(datetime=model.datetime)
  1074. if name == "datetime_key":
  1075. return rest.DatetimeKeyExpression(datetime_key=model.datetime_key)
  1076. if name == "sum":
  1077. return cls.convert_sum_expression(model.sum)
  1078. if name == "mult":
  1079. return cls.convert_mult_expression(model.mult)
  1080. if name == "div":
  1081. return cls.convert_div_expression(model.div)
  1082. if name == "abs":
  1083. return rest.AbsExpression(abs=cls.convert_expression(model.abs))
  1084. if name == "neg":
  1085. return rest.NegExpression(neg=cls.convert_expression(model.neg))
  1086. if name == "log10":
  1087. return rest.Log10Expression(log10=cls.convert_expression(model.log10))
  1088. if name == "ln":
  1089. return rest.LnExpression(ln=cls.convert_expression(model.ln))
  1090. if name == "sqrt":
  1091. return rest.SqrtExpression(sqrt=cls.convert_expression(model.sqrt))
  1092. if name == "exp":
  1093. return rest.ExpExpression(exp=cls.convert_expression(model.exp))
  1094. if name == "pow":
  1095. return cls.convert_pow_expression(model.pow)
  1096. if name == "geo_distance":
  1097. return cls.convert_geo_distance(model.geo_distance)
  1098. if name == "lin_decay":
  1099. return rest.LinDecayExpression(
  1100. lin_decay=cls.convert_decay_params_expression(model.lin_decay)
  1101. )
  1102. if name == "exp_decay":
  1103. return rest.ExpDecayExpression(
  1104. exp_decay=cls.convert_decay_params_expression(model.exp_decay)
  1105. )
  1106. if name == "gauss_decay":
  1107. return rest.GaussDecayExpression(
  1108. gauss_decay=cls.convert_decay_params_expression(model.gauss_decay)
  1109. )
  1110. raise ValueError(f"Unknown function name: {name}")
  1111. @classmethod
  1112. def convert_sum_expression(cls, model: grpc.SumExpression) -> rest.SumExpression:
  1113. return rest.SumExpression(sum=[cls.convert_expression(expr) for expr in model.sum])
  1114. @classmethod
  1115. def convert_mult_expression(cls, model: grpc.MultExpression) -> rest.MultExpression:
  1116. return rest.MultExpression(mult=[cls.convert_expression(expr) for expr in model.mult])
  1117. @classmethod
  1118. def convert_div_expression(cls, model: grpc.DivExpression) -> rest.DivExpression:
  1119. left = cls.convert_expression(model.left)
  1120. right = cls.convert_expression(model.right)
  1121. by_zero_default = model.by_zero_default if model.HasField("by_zero_default") else None
  1122. params = rest.DivParams(left=left, right=right, by_zero_default=by_zero_default)
  1123. return rest.DivExpression(div=params)
  1124. @classmethod
  1125. def convert_pow_expression(cls, model: grpc.PowExpression) -> rest.PowExpression:
  1126. base = cls.convert_expression(model.base)
  1127. exponent = cls.convert_expression(model.exponent)
  1128. params = rest.PowParams(base=base, exponent=exponent)
  1129. return rest.PowExpression(pow=params)
  1130. @classmethod
  1131. def convert_geo_distance(cls, model: grpc.GeoDistance) -> rest.GeoDistance:
  1132. origin = cls.convert_geo_point(model.origin)
  1133. params = rest.GeoDistanceParams(origin=origin, to=model.to)
  1134. return rest.GeoDistance(geo_distance=params)
  1135. @classmethod
  1136. def convert_decay_params_expression(
  1137. cls, model: grpc.DecayParamsExpression
  1138. ) -> rest.DecayParamsExpression:
  1139. return rest.DecayParamsExpression(
  1140. x=cls.convert_expression(model.x),
  1141. target=cls.convert_expression(model.target) if model.HasField("target") else None,
  1142. midpoint=model.midpoint if model.HasField("midpoint") else None,
  1143. scale=model.scale if model.HasField("scale") else None,
  1144. )
  1145. @classmethod
  1146. def convert_mmr(cls, model: grpc.Mmr) -> rest.Mmr:
  1147. return rest.Mmr(
  1148. diversity=model.diversity if model.HasField("diversity") else None,
  1149. candidates_limit=model.candidates_limit
  1150. if model.HasField("candidates_limit")
  1151. else None,
  1152. )
  1153. @classmethod
  1154. def convert_query(cls, model: grpc.Query) -> rest.Query:
  1155. name = model.WhichOneof("variant")
  1156. if name is None:
  1157. raise ValueError(f"invalid Query model: {model}") # pragma: no cover
  1158. val = getattr(model, name)
  1159. if name == "nearest":
  1160. return rest.NearestQuery(nearest=cls.convert_vector_input(val))
  1161. if name == "recommend":
  1162. return rest.RecommendQuery(recommend=cls.convert_recommend_input(val))
  1163. if name == "discover":
  1164. return rest.DiscoverQuery(discover=cls.convert_discover_input(val))
  1165. if name == "context":
  1166. return rest.ContextQuery(context=cls.convert_context_input(val))
  1167. if name == "order_by":
  1168. return rest.OrderByQuery(order_by=cls.convert_order_by(val))
  1169. if name == "fusion":
  1170. return rest.FusionQuery(fusion=cls.convert_fusion(val))
  1171. if name == "sample":
  1172. return rest.SampleQuery(sample=cls.convert_sample(val))
  1173. if name == "formula":
  1174. return cls.convert_formula_query(val)
  1175. if name == "nearest_with_mmr":
  1176. val = model.nearest_with_mmr
  1177. return rest.NearestQuery(
  1178. nearest=cls.convert_vector_input(val.nearest), mmr=cls.convert_mmr(val.mmr)
  1179. )
  1180. raise ValueError(f"invalid Query model: {model}") # pragma: no cover
  1181. @classmethod
  1182. def convert_prefetch_query(cls, model: grpc.PrefetchQuery) -> rest.Prefetch:
  1183. return rest.Prefetch(
  1184. prefetch=(
  1185. [cls.convert_prefetch_query(prefetch) for prefetch in model.prefetch]
  1186. if len(model.prefetch) != 0
  1187. else None
  1188. ),
  1189. query=cls.convert_query(model.query) if model.HasField("query") else None,
  1190. using=model.using if model.HasField("using") else None,
  1191. filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
  1192. params=cls.convert_search_params(model.params) if model.HasField("params") else None,
  1193. score_threshold=model.score_threshold if model.HasField("score_threshold") else None,
  1194. limit=model.limit if model.HasField("limit") else None,
  1195. lookup_from=(
  1196. cls.convert_lookup_location(model.lookup_from)
  1197. if model.HasField("lookup_from")
  1198. else None
  1199. ),
  1200. )
  1201. @classmethod
  1202. def convert_vectors_selector(cls, model: grpc.VectorsSelector) -> list[str]:
  1203. return model.names[:]
  1204. @classmethod
  1205. def convert_with_vectors_selector(cls, model: grpc.WithVectorsSelector) -> rest.WithVector:
  1206. name = model.WhichOneof("selector_options")
  1207. if name is None:
  1208. raise ValueError(f"invalid WithVectorsSelector model: {model}") # pragma: no cover
  1209. val = getattr(model, name)
  1210. if name == "enable":
  1211. return val
  1212. if name == "include":
  1213. return cls.convert_vectors_selector(val)
  1214. raise ValueError(f"invalid WithVectorsSelector model: {model}") # pragma: no cover
  1215. @classmethod
  1216. def convert_search_points(cls, model: grpc.SearchPoints) -> rest.SearchRequest:
  1217. vector = (
  1218. rest.NamedVector(name=model.vector_name, vector=model.vector[:])
  1219. if not model.HasField("sparse_indices")
  1220. else (
  1221. rest.NamedSparseVector(
  1222. name=model.vector_name,
  1223. vector=rest.SparseVector(
  1224. indices=model.sparse_indices.data[:], values=model.vector[:]
  1225. ),
  1226. )
  1227. )
  1228. )
  1229. return rest.SearchRequest(
  1230. vector=vector,
  1231. filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
  1232. limit=model.limit,
  1233. with_payload=(
  1234. cls.convert_with_payload_interface(model.with_payload)
  1235. if model.HasField("with_payload")
  1236. else None
  1237. ),
  1238. params=cls.convert_search_params(model.params) if model.HasField("params") else None,
  1239. score_threshold=model.score_threshold if model.HasField("score_threshold") else None,
  1240. offset=model.offset if model.HasField("offset") else None,
  1241. with_vector=(
  1242. cls.convert_with_vectors_selector(model.with_vectors)
  1243. if model.HasField("with_vectors")
  1244. else None
  1245. ),
  1246. shard_key=(
  1247. cls.convert_shard_key_selector(model.shard_key_selector)
  1248. if model.HasField("shard_key_selector")
  1249. else None
  1250. ),
  1251. )
  1252. @classmethod
  1253. def convert_query_points(cls, model: grpc.QueryPoints) -> rest.QueryRequest:
  1254. return rest.QueryRequest(
  1255. shard_key=(
  1256. cls.convert_shard_key_selector(model.shard_key_selector)
  1257. if model.HasField("shard_key_selector")
  1258. else None
  1259. ),
  1260. prefetch=(
  1261. [cls.convert_prefetch_query(prefetch) for prefetch in model.prefetch]
  1262. if len(model.prefetch) != 0
  1263. else None
  1264. ),
  1265. query=cls.convert_query(model.query) if model.HasField("query") else None,
  1266. using=model.using if model.HasField("using") else None,
  1267. filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
  1268. params=cls.convert_search_params(model.params) if model.HasField("params") else None,
  1269. score_threshold=model.score_threshold if model.HasField("score_threshold") else None,
  1270. limit=model.limit if model.HasField("limit") else None,
  1271. offset=model.offset if model.HasField("offset") else None,
  1272. with_vector=(
  1273. cls.convert_with_vectors_selector(model.with_vectors)
  1274. if model.HasField("with_vectors")
  1275. else None
  1276. ),
  1277. with_payload=(
  1278. cls.convert_with_payload_interface(model.with_payload)
  1279. if model.HasField("with_payload")
  1280. else None
  1281. ),
  1282. lookup_from=(
  1283. cls.convert_lookup_location(model.lookup_from)
  1284. if model.HasField("lookup_from")
  1285. else None
  1286. ),
  1287. )
  1288. @classmethod
  1289. def convert_recommend_points(cls, model: grpc.RecommendPoints) -> rest.RecommendRequest:
  1290. positive_ids = [cls.convert_point_id(point_id) for point_id in model.positive]
  1291. negative_ids = [cls.convert_point_id(point_id) for point_id in model.negative]
  1292. positive_vectors = [cls.convert_vector(vector) for vector in model.positive_vectors]
  1293. negative_vectors = [cls.convert_vector(vector) for vector in model.negative_vectors]
  1294. return rest.RecommendRequest(
  1295. positive=positive_ids + positive_vectors,
  1296. negative=negative_ids + negative_vectors,
  1297. filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
  1298. limit=model.limit,
  1299. with_payload=(
  1300. cls.convert_with_payload_interface(model.with_payload)
  1301. if model.HasField("with_payload")
  1302. else None
  1303. ),
  1304. params=cls.convert_search_params(model.params) if model.HasField("params") else None,
  1305. score_threshold=model.score_threshold if model.HasField("score_threshold") else None,
  1306. offset=model.offset if model.HasField("offset") else None,
  1307. with_vector=(
  1308. cls.convert_with_vectors_selector(model.with_vectors)
  1309. if model.HasField("with_vectors")
  1310. else None
  1311. ),
  1312. using=model.using,
  1313. lookup_from=(
  1314. cls.convert_lookup_location(model.lookup_from)
  1315. if model.HasField("lookup_from")
  1316. else None
  1317. ),
  1318. strategy=(
  1319. cls.convert_recommend_strategy(model.strategy)
  1320. if model.HasField("strategy")
  1321. else None
  1322. ),
  1323. shard_key=(
  1324. cls.convert_shard_key_selector(model.shard_key_selector)
  1325. if model.HasField("shard_key_selector")
  1326. else None
  1327. ),
  1328. )
  1329. @classmethod
  1330. def convert_discover_points(cls, model: grpc.DiscoverPoints) -> rest.DiscoverRequest:
  1331. target = cls.convert_target_vector(model.target) if model.HasField("target") else None
  1332. context = [cls.convert_context_example_pair(pair) for pair in model.context]
  1333. return rest.DiscoverRequest(
  1334. target=target,
  1335. context=context,
  1336. filter=cls.convert_filter(model.filter) if model.HasField("filter") else None,
  1337. limit=model.limit,
  1338. with_payload=(
  1339. cls.convert_with_payload_interface(model.with_payload)
  1340. if model.HasField("with_payload")
  1341. else None
  1342. ),
  1343. params=cls.convert_search_params(model.params) if model.HasField("params") else None,
  1344. offset=model.offset if model.HasField("offset") else None,
  1345. with_vector=(
  1346. cls.convert_with_vectors_selector(model.with_vectors)
  1347. if model.HasField("with_vectors")
  1348. else None
  1349. ),
  1350. using=model.using,
  1351. lookup_from=(
  1352. cls.convert_lookup_location(model.lookup_from)
  1353. if model.HasField("lookup_from")
  1354. else None
  1355. ),
  1356. shard_key=(
  1357. cls.convert_shard_key_selector(model.shard_key_selector)
  1358. if model.HasField("shard_key_selector")
  1359. else None
  1360. ),
  1361. )
  1362. @classmethod
  1363. def convert_vector_example(cls, model: grpc.VectorExample) -> rest.RecommendExample:
  1364. if model.HasField("vector"):
  1365. return cls.convert_vector(model.vector)
  1366. if model.HasField("id"):
  1367. return cls.convert_point_id(model.id)
  1368. raise ValueError(f"invalid VectorExample model: {model}") # pragma: no cover
  1369. @classmethod
  1370. def convert_target_vector(cls, model: grpc.TargetVector) -> rest.RecommendExample:
  1371. if model.HasField("single"):
  1372. return cls.convert_vector_example(model.single)
  1373. raise ValueError(f"invalid TargetVector model: {model}") # pragma: no cover
  1374. @classmethod
  1375. def convert_context_example_pair(
  1376. cls, model: grpc.ContextExamplePair
  1377. ) -> rest.ContextExamplePair:
  1378. return rest.ContextExamplePair(
  1379. positive=cls.convert_vector_example(model.positive),
  1380. negative=cls.convert_vector_example(model.negative),
  1381. )
  1382. @classmethod
  1383. def convert_tokenizer_type(cls, model: grpc.TokenizerType) -> rest.TokenizerType:
  1384. if model == grpc.Unknown:
  1385. return None
  1386. if model == grpc.Prefix:
  1387. return rest.TokenizerType.PREFIX
  1388. if model == grpc.Whitespace:
  1389. return rest.TokenizerType.WHITESPACE
  1390. if model == grpc.Word:
  1391. return rest.TokenizerType.WORD
  1392. if model == grpc.Multilingual:
  1393. return rest.TokenizerType.MULTILINGUAL
  1394. raise ValueError(f"invalid TokenizerType model: {model}") # pragma: no cover
  1395. @classmethod
  1396. def convert_text_index_params(cls, model: grpc.TextIndexParams) -> rest.TextIndexParams:
  1397. return rest.TextIndexParams(
  1398. type="text",
  1399. tokenizer=cls.convert_tokenizer_type(model.tokenizer),
  1400. min_token_len=model.min_token_len if model.HasField("min_token_len") else None,
  1401. max_token_len=model.max_token_len if model.HasField("max_token_len") else None,
  1402. lowercase=model.lowercase if model.HasField("lowercase") else None,
  1403. phrase_matching=model.phrase_matching if model.HasField("phrase_matching") else None,
  1404. stopwords=cls.convert_stopwords(model.stopwords)
  1405. if model.HasField("stopwords")
  1406. else None,
  1407. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1408. stemmer=cls.convert_stemmer(model.stemmer) if model.HasField("stemmer") else None,
  1409. )
  1410. @classmethod
  1411. def convert_stopwords(cls, model: grpc.StopwordsSet) -> rest.StopwordsInterface:
  1412. languages = model.languages[:]
  1413. custom = model.custom[:]
  1414. if len(languages) == 1 and not custom:
  1415. return rest.Language(languages[0])
  1416. return rest.StopwordsSet(languages=languages, custom=custom)
  1417. @classmethod
  1418. def convert_stemmer(cls, model: grpc.StemmingAlgorithm) -> rest.StemmingAlgorithm:
  1419. name = model.WhichOneof("stemming_params")
  1420. if name is None:
  1421. raise ValueError(f"invalid StemmingAlgorithm model: {model}") # pragma: no cover
  1422. val = getattr(model, name)
  1423. if name == "snowball":
  1424. return cls.convert_snowball_parameters(val)
  1425. raise ValueError(f"invalid StemmingAlgorithm model: {model}") # pragma: no cover
  1426. @classmethod
  1427. def convert_snowball_parameters(cls, model: grpc.SnowballParams) -> rest.SnowballParams:
  1428. return rest.SnowballParams(
  1429. type=rest.Snowball.SNOWBALL, language=rest.SnowballLanguage(model.language)
  1430. )
  1431. @classmethod
  1432. def convert_integer_index_params(
  1433. cls, model: grpc.IntegerIndexParams
  1434. ) -> rest.IntegerIndexParams:
  1435. return rest.IntegerIndexParams(
  1436. type=rest.IntegerIndexType.INTEGER,
  1437. range=model.range,
  1438. lookup=model.lookup,
  1439. is_principal=model.is_principal if model.HasField("is_principal") else None,
  1440. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1441. )
  1442. @classmethod
  1443. def convert_keyword_index_params(
  1444. cls, model: grpc.KeywordIndexParams
  1445. ) -> rest.KeywordIndexParams:
  1446. return rest.KeywordIndexParams(
  1447. type=rest.KeywordIndexType.KEYWORD,
  1448. is_tenant=model.is_tenant if model.HasField("is_tenant") else None,
  1449. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1450. )
  1451. @classmethod
  1452. def convert_float_index_params(cls, model: grpc.FloatIndexParams) -> rest.FloatIndexParams:
  1453. return rest.FloatIndexParams(
  1454. type=rest.FloatIndexType.FLOAT,
  1455. is_principal=model.is_principal if model.HasField("is_principal") else None,
  1456. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1457. )
  1458. @classmethod
  1459. def convert_geo_index_params(cls, model: grpc.GeoIndexParams) -> rest.GeoIndexParams:
  1460. return rest.GeoIndexParams(
  1461. type=rest.GeoIndexType.GEO,
  1462. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1463. )
  1464. @classmethod
  1465. def convert_bool_index_params(cls, model: grpc.BoolIndexParams) -> rest.BoolIndexParams:
  1466. return rest.BoolIndexParams(
  1467. type=rest.BoolIndexType.BOOL,
  1468. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1469. )
  1470. @classmethod
  1471. def convert_datetime_index_params(
  1472. cls, model: grpc.DatetimeIndexParams
  1473. ) -> rest.DatetimeIndexParams:
  1474. return rest.DatetimeIndexParams(
  1475. type=rest.DatetimeIndexType.DATETIME,
  1476. is_principal=model.is_principal if model.HasField("is_principal") else None,
  1477. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1478. )
  1479. @classmethod
  1480. def convert_uuid_index_params(cls, model: grpc.UuidIndexParams) -> rest.UuidIndexParams:
  1481. return rest.UuidIndexParams(
  1482. type=rest.UuidIndexType.UUID,
  1483. is_tenant=model.is_tenant if model.HasField("is_tenant") else None,
  1484. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1485. )
  1486. @classmethod
  1487. def convert_collection_params_diff(
  1488. cls, model: grpc.CollectionParamsDiff
  1489. ) -> rest.CollectionParamsDiff:
  1490. return rest.CollectionParamsDiff(
  1491. replication_factor=(
  1492. model.replication_factor if model.HasField("replication_factor") else None
  1493. ),
  1494. write_consistency_factor=(
  1495. model.write_consistency_factor
  1496. if model.HasField("write_consistency_factor")
  1497. else None
  1498. ),
  1499. read_fan_out_factor=(
  1500. model.read_fan_out_factor if model.HasField("read_fan_out_factor") else None
  1501. ),
  1502. on_disk_payload=model.on_disk_payload if model.HasField("on_disk_payload") else None,
  1503. )
  1504. @classmethod
  1505. def convert_lookup_location(cls, model: grpc.LookupLocation) -> rest.LookupLocation:
  1506. return rest.LookupLocation(
  1507. collection=model.collection_name,
  1508. vector=model.vector_name if model.HasField("vector_name") else None,
  1509. )
  1510. @classmethod
  1511. def convert_write_ordering(cls, model: grpc.WriteOrdering) -> rest.WriteOrdering:
  1512. if model.type == grpc.WriteOrderingType.Weak:
  1513. return rest.WriteOrdering.WEAK
  1514. if model.type == grpc.WriteOrderingType.Medium:
  1515. return rest.WriteOrdering.MEDIUM
  1516. if model.type == grpc.WriteOrderingType.Strong:
  1517. return rest.WriteOrdering.STRONG
  1518. raise ValueError(f"invalid WriteOrdering model: {model}") # pragma: no cover
  1519. @classmethod
  1520. def convert_read_consistency(cls, model: grpc.ReadConsistency) -> rest.ReadConsistency:
  1521. name = model.WhichOneof("value")
  1522. if name is None:
  1523. raise ValueError(f"invalid ReadConsistency model: {model}") # pragma: no cover
  1524. val = getattr(model, name)
  1525. if name == "factor":
  1526. return val
  1527. if name == "type":
  1528. return cls.convert_read_consistency_type(val)
  1529. raise ValueError(f"invalid ReadConsistency model: {model}") # pragma: no cover
  1530. @classmethod
  1531. def convert_read_consistency_type(
  1532. cls, model: grpc.ReadConsistencyType
  1533. ) -> rest.ReadConsistencyType:
  1534. if model == grpc.All:
  1535. return rest.ReadConsistencyType.ALL
  1536. if model == grpc.Majority:
  1537. return rest.ReadConsistencyType.MAJORITY
  1538. if model == grpc.Quorum:
  1539. return rest.ReadConsistencyType.QUORUM
  1540. raise ValueError(f"invalid ReadConsistencyType model: {model}") # pragma: no cover
  1541. @classmethod
  1542. def convert_scalar_quantization_config(
  1543. cls, model: grpc.ScalarQuantization
  1544. ) -> rest.ScalarQuantizationConfig:
  1545. return rest.ScalarQuantizationConfig(
  1546. type=rest.ScalarType.INT8,
  1547. quantile=model.quantile if model.HasField("quantile") else None,
  1548. always_ram=model.always_ram if model.HasField("always_ram") else None,
  1549. )
  1550. @classmethod
  1551. def convert_product_quantization_config(
  1552. cls, model: grpc.ProductQuantization
  1553. ) -> rest.ProductQuantizationConfig:
  1554. return rest.ProductQuantizationConfig(
  1555. compression=cls.convert_compression_ratio(model.compression),
  1556. always_ram=model.always_ram if model.HasField("always_ram") else None,
  1557. )
  1558. @classmethod
  1559. def convert_binary_quantization_config(
  1560. cls, model: grpc.BinaryQuantization
  1561. ) -> rest.BinaryQuantizationConfig:
  1562. return rest.BinaryQuantizationConfig(
  1563. always_ram=model.always_ram if model.HasField("always_ram") else None,
  1564. encoding=cls.convert_binary_quantization_encoding(model.encoding)
  1565. if model.HasField("encoding")
  1566. else None,
  1567. query_encoding=cls.convert_binary_quantization_query_encoding(model.query_encoding)
  1568. if model.HasField("query_encoding")
  1569. else None,
  1570. )
  1571. @classmethod
  1572. def convert_binary_quantization_encoding(
  1573. cls, model: grpc.BinaryQuantizationEncoding
  1574. ) -> rest.BinaryQuantizationEncoding:
  1575. if model == grpc.BinaryQuantizationEncoding.OneBit:
  1576. return rest.BinaryQuantizationEncoding.ONE_BIT
  1577. if model == grpc.BinaryQuantizationEncoding.TwoBits:
  1578. return rest.BinaryQuantizationEncoding.TWO_BITS
  1579. if model == grpc.BinaryQuantizationEncoding.OneAndHalfBits:
  1580. return rest.BinaryQuantizationEncoding.ONE_AND_HALF_BITS
  1581. raise ValueError(f"invalid BinaryQuantizationEncoding model: {model}") # pragma: no cover
  1582. @classmethod
  1583. def convert_binary_quantization_query_encoding(
  1584. cls, model: grpc.BinaryQuantizationQueryEncoding
  1585. ) -> rest.BinaryQuantizationQueryEncoding:
  1586. name = model.WhichOneof("variant")
  1587. if name is None:
  1588. raise ValueError(f"invalid BinaryQuantizationQueryEncoding model: {model}")
  1589. val = getattr(model, name)
  1590. if name == "setting":
  1591. if val == grpc.BinaryQuantizationQueryEncoding.Setting.Default:
  1592. return rest.BinaryQuantizationQueryEncoding.DEFAULT
  1593. if val == grpc.BinaryQuantizationQueryEncoding.Setting.Binary:
  1594. return rest.BinaryQuantizationQueryEncoding.BINARY
  1595. if val == grpc.BinaryQuantizationQueryEncoding.Setting.Scalar4Bits:
  1596. return rest.BinaryQuantizationQueryEncoding.SCALAR4BITS
  1597. if val == grpc.BinaryQuantizationQueryEncoding.Setting.Scalar8Bits:
  1598. return rest.BinaryQuantizationQueryEncoding.SCALAR8BITS
  1599. raise ValueError(
  1600. f"invalid BinaryQuantizationQueryEncoding setting: {val}"
  1601. ) # pragma: no cover
  1602. raise ValueError(
  1603. f"invalid BinaryQuantizationQueryEncoding model: {model}"
  1604. ) # pragma: no cover
  1605. @classmethod
  1606. def convert_compression_ratio(cls, model: grpc.CompressionRatio) -> rest.CompressionRatio:
  1607. if model == grpc.x4:
  1608. return rest.CompressionRatio.X4
  1609. if model == grpc.x8:
  1610. return rest.CompressionRatio.X8
  1611. if model == grpc.x16:
  1612. return rest.CompressionRatio.X16
  1613. if model == grpc.x32:
  1614. return rest.CompressionRatio.X32
  1615. if model == grpc.x64:
  1616. return rest.CompressionRatio.X64
  1617. raise ValueError(f"invalid CompressionRatio model: {model}") # pragma: no cover
  1618. @classmethod
  1619. def convert_quantization_config(
  1620. cls, model: grpc.QuantizationConfig
  1621. ) -> rest.QuantizationConfig:
  1622. name = model.WhichOneof("quantization")
  1623. if name is None:
  1624. raise ValueError(f"invalid QuantizationConfig model: {model}") # pragma: no cover
  1625. val = getattr(model, name)
  1626. if name == "scalar":
  1627. return rest.ScalarQuantization(scalar=cls.convert_scalar_quantization_config(val))
  1628. if name == "product":
  1629. return rest.ProductQuantization(product=cls.convert_product_quantization_config(val))
  1630. if name == "binary":
  1631. return rest.BinaryQuantization(binary=cls.convert_binary_quantization_config(val))
  1632. raise ValueError(f"invalid QuantizationConfig model: {model}") # pragma: no cover
  1633. @classmethod
  1634. def convert_quantization_search_params(
  1635. cls, model: grpc.QuantizationSearchParams
  1636. ) -> rest.QuantizationSearchParams:
  1637. return rest.QuantizationSearchParams(
  1638. ignore=model.ignore if model.HasField("ignore") else None,
  1639. rescore=model.rescore if model.HasField("rescore") else None,
  1640. oversampling=model.oversampling if model.HasField("oversampling") else None,
  1641. )
  1642. @classmethod
  1643. def convert_point_vectors(cls, model: grpc.PointVectors) -> rest.PointVectors:
  1644. return rest.PointVectors(
  1645. id=cls.convert_point_id(model.id),
  1646. vector=cls.convert_vectors(model.vectors),
  1647. )
  1648. @classmethod
  1649. def convert_groups_result(cls, model: grpc.GroupsResult) -> rest.GroupsResult:
  1650. return rest.GroupsResult(
  1651. groups=[cls.convert_point_group(group) for group in model.groups],
  1652. )
  1653. @classmethod
  1654. def convert_point_group(cls, model: grpc.PointGroup) -> rest.PointGroup:
  1655. return rest.PointGroup(
  1656. id=cls.convert_group_id(model.id),
  1657. hits=[cls.convert_scored_point(hit) for hit in model.hits],
  1658. lookup=cls.convert_record(model.lookup) if model.HasField("lookup") else None,
  1659. )
  1660. @classmethod
  1661. def convert_group_id(cls, model: grpc.GroupId) -> rest.GroupId:
  1662. name = model.WhichOneof("kind")
  1663. if name is None:
  1664. raise ValueError(f"invalid GroupId model: {model}") # pragma: no cover
  1665. val = getattr(model, name)
  1666. return val
  1667. @classmethod
  1668. def convert_with_lookup(cls, model: grpc.WithLookup) -> rest.WithLookup:
  1669. return rest.WithLookup(
  1670. collection=model.collection,
  1671. with_payload=(
  1672. cls.convert_with_payload_selector(model.with_payload)
  1673. if model.HasField("with_payload")
  1674. else None
  1675. ),
  1676. with_vectors=(
  1677. cls.convert_with_vectors_selector(model.with_vectors)
  1678. if model.HasField("with_vectors")
  1679. else None
  1680. ),
  1681. )
  1682. @classmethod
  1683. def convert_quantization_config_diff(
  1684. cls, model: grpc.QuantizationConfigDiff
  1685. ) -> rest.QuantizationConfigDiff:
  1686. name = model.WhichOneof("quantization")
  1687. if name is None:
  1688. raise ValueError(f"invalid QuantizationConfigDiff model: {model}") # pragma: no cover
  1689. val = getattr(model, name)
  1690. if name == "scalar":
  1691. return rest.ScalarQuantization(scalar=cls.convert_scalar_quantization_config(val))
  1692. if name == "product":
  1693. return rest.ProductQuantization(product=cls.convert_product_quantization_config(val))
  1694. if name == "binary":
  1695. return rest.BinaryQuantization(binary=cls.convert_binary_quantization_config(val))
  1696. if name == "disabled":
  1697. return rest.Disabled.DISABLED
  1698. raise ValueError(f"invalid QuantizationConfigDiff model: {model}") # pragma: no cover
  1699. @classmethod
  1700. def convert_vector_params_diff(cls, model: grpc.VectorParamsDiff) -> rest.VectorParamsDiff:
  1701. return rest.VectorParamsDiff(
  1702. hnsw_config=(
  1703. cls.convert_hnsw_config_diff(model.hnsw_config)
  1704. if model.HasField("hnsw_config")
  1705. else None
  1706. ),
  1707. quantization_config=(
  1708. cls.convert_quantization_config_diff(model.quantization_config)
  1709. if model.HasField("quantization_config")
  1710. else None
  1711. ),
  1712. on_disk=model.on_disk if model.HasField("on_disk") else None,
  1713. )
  1714. @classmethod
  1715. def convert_vectors_config_diff(cls, model: grpc.VectorsConfigDiff) -> rest.VectorsConfigDiff:
  1716. name = model.WhichOneof("config")
  1717. if name is None:
  1718. raise ValueError(f"invalid VectorsConfigDiff model: {model}") # pragma: no cover
  1719. val = getattr(model, name)
  1720. if name == "params":
  1721. return {"": cls.convert_vector_params_diff(val)}
  1722. if name == "params_map":
  1723. return dict(
  1724. (key, cls.convert_vector_params_diff(vec_params))
  1725. for key, vec_params in val.map.items()
  1726. )
  1727. raise ValueError(f"invalid VectorsConfigDiff model: {model}") # pragma: no cover
  1728. @classmethod
  1729. def convert_points_update_operation(
  1730. cls, model: grpc.PointsUpdateOperation
  1731. ) -> rest.UpdateOperation:
  1732. name = model.WhichOneof("operation")
  1733. if name is None:
  1734. raise ValueError(f"invalid PointsUpdateOperation model: {model}") # pragma: no cover
  1735. val = getattr(model, name)
  1736. if name == "upsert":
  1737. shard_key_selector = (
  1738. cls.convert_shard_key(val.shard_key_selector)
  1739. if val.HasField("shard_key_selector")
  1740. else None
  1741. )
  1742. return rest.UpsertOperation(
  1743. upsert=rest.PointsList(
  1744. points=[cls.convert_point_struct(point) for point in val.points],
  1745. shard_key=shard_key_selector,
  1746. )
  1747. )
  1748. elif name == "delete_points":
  1749. shard_key_selector = (
  1750. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1751. )
  1752. points_selector = cls.convert_points_selector(
  1753. val.points, shard_key_selector=shard_key_selector
  1754. )
  1755. return rest.DeleteOperation(delete=points_selector)
  1756. elif name == "set_payload":
  1757. shard_key_selector = (
  1758. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1759. )
  1760. points_selector = cls.convert_points_selector(
  1761. val.points_selector, shard_key_selector=shard_key_selector
  1762. )
  1763. points = None
  1764. filter_ = None
  1765. if isinstance(points_selector, rest.PointIdsList):
  1766. points = points_selector.points
  1767. elif isinstance(points_selector, rest.FilterSelector):
  1768. filter_ = points_selector.filter
  1769. else:
  1770. raise ValueError(
  1771. f"invalid PointsSelector model: {points_selector}"
  1772. ) # pragma: no cover
  1773. return rest.SetPayloadOperation(
  1774. set_payload=rest.SetPayload(
  1775. payload=cls.convert_payload(val.payload),
  1776. points=points,
  1777. filter=filter_,
  1778. key=val.key if val.HasField("key") else None,
  1779. )
  1780. )
  1781. elif name == "overwrite_payload":
  1782. shard_key_selector = (
  1783. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1784. )
  1785. points_selector = cls.convert_points_selector(
  1786. val.points_selector, shard_key_selector=shard_key_selector
  1787. )
  1788. points = None
  1789. filter_ = None
  1790. if isinstance(points_selector, rest.PointIdsList):
  1791. points = points_selector.points
  1792. elif isinstance(points_selector, rest.FilterSelector):
  1793. filter_ = points_selector.filter
  1794. else:
  1795. raise ValueError(
  1796. f"invalid PointsSelector model: {points_selector}"
  1797. ) # pragma: no cover
  1798. return rest.OverwritePayloadOperation(
  1799. overwrite_payload=rest.SetPayload(
  1800. payload=cls.convert_payload(val.payload),
  1801. points=points,
  1802. filter=filter_,
  1803. key=val.key if val.HasField("key") else None,
  1804. )
  1805. )
  1806. elif name == "delete_payload":
  1807. shard_key_selector = (
  1808. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1809. )
  1810. points_selector = cls.convert_points_selector(
  1811. val.points_selector, shard_key_selector=shard_key_selector
  1812. )
  1813. points = None
  1814. filter_ = None
  1815. if isinstance(points_selector, rest.PointIdsList):
  1816. points = points_selector.points
  1817. elif isinstance(points_selector, rest.FilterSelector):
  1818. filter_ = points_selector.filter
  1819. else:
  1820. raise ValueError(
  1821. f"invalid PointsSelector model: {points_selector}"
  1822. ) # pragma: no cover
  1823. return rest.DeletePayloadOperation(
  1824. delete_payload=rest.DeletePayload(
  1825. keys=[key for key in val.keys],
  1826. points=points,
  1827. filter=filter_,
  1828. )
  1829. )
  1830. elif name == "clear_payload":
  1831. shard_key_selector = (
  1832. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1833. )
  1834. points_selector = cls.convert_points_selector(
  1835. val.points, shard_key_selector=shard_key_selector
  1836. )
  1837. return rest.ClearPayloadOperation(clear_payload=points_selector)
  1838. elif name == "update_vectors":
  1839. shard_key_selector = (
  1840. cls.convert_shard_key(val.shard_key_selector)
  1841. if val.HasField("shard_key_selector")
  1842. else None
  1843. )
  1844. return rest.UpdateVectorsOperation(
  1845. update_vectors=rest.UpdateVectors(
  1846. points=[cls.convert_point_vectors(point) for point in val.points],
  1847. shard_key=shard_key_selector,
  1848. )
  1849. )
  1850. elif name == "delete_vectors":
  1851. shard_key_selector = (
  1852. val.shard_key_selector if val.HasField("shard_key_selector") else None
  1853. )
  1854. points_selector = cls.convert_points_selector(
  1855. val.points_selector, shard_key_selector=shard_key_selector
  1856. )
  1857. points = None
  1858. filter_ = None
  1859. if isinstance(points_selector, rest.PointIdsList):
  1860. points = points_selector.points
  1861. elif isinstance(points_selector, rest.FilterSelector):
  1862. filter_ = points_selector.filter
  1863. else:
  1864. raise ValueError(
  1865. f"invalid PointsSelector model: {points_selector}"
  1866. ) # pragma: no cover
  1867. return rest.DeleteVectorsOperation(
  1868. delete_vectors=rest.DeleteVectors(
  1869. vector=[name for name in val.vectors.names],
  1870. points=points,
  1871. filter=filter_,
  1872. )
  1873. )
  1874. else:
  1875. raise ValueError(f"invalid UpdateOperation model: {model}") # pragma: no cover
  @classmethod
  def convert_init_from(cls, model: str) -> rest.InitFrom:
      """Wrap a source collection name into a REST ``InitFrom``.

      Raises:
          ValueError: if ``model`` is not a string.
      """
      if not isinstance(model, str):
          raise ValueError(f"Invalid InitFrom model: {model}")  # pragma: no cover
      return rest.InitFrom(collection=model)
  @classmethod
  def convert_recommend_strategy(cls, model: grpc.RecommendStrategy) -> rest.RecommendStrategy:
      """Map a gRPC ``RecommendStrategy`` enum value onto its REST counterpart.

      Raises:
          ValueError: if ``model`` matches none of the known strategies.
      """
      # Checked in the same order as the enum declaration.
      for grpc_strategy, rest_strategy in (
          (grpc.RecommendStrategy.AverageVector, rest.RecommendStrategy.AVERAGE_VECTOR),
          (grpc.RecommendStrategy.BestScore, rest.RecommendStrategy.BEST_SCORE),
          (grpc.RecommendStrategy.SumScores, rest.RecommendStrategy.SUM_SCORES),
      ):
          if model == grpc_strategy:
              return rest_strategy
      raise ValueError(f"invalid RecommendStrategy model: {model}")  # pragma: no cover
  @classmethod
  def convert_sparse_index_config(cls, model: grpc.SparseIndexConfig) -> rest.SparseIndexParams:
      """Convert a gRPC ``SparseIndexConfig`` into REST ``SparseIndexParams``.

      Optional fields that are not set on ``model`` become ``None``.
      """

      def optional(field: str):
          # Value of an optional proto field, or None when it was never set.
          return getattr(model, field) if model.HasField(field) else None

      threshold = optional("full_scan_threshold")
      on_disk = optional("on_disk")
      datatype = None
      if model.HasField("datatype"):
          datatype = cls.convert_datatype(model.datatype)
      return rest.SparseIndexParams(
          full_scan_threshold=threshold,
          on_disk=on_disk,
          datatype=datatype,
      )
  @classmethod
  def convert_modifier(cls, model: grpc.Modifier) -> rest.Modifier:
      """Map a gRPC ``Modifier`` enum value onto the REST ``Modifier``.

      Raises:
          ValueError: for an unknown modifier value.
      """
      # ``None`` is a Python keyword, so the enum member must be fetched via getattr.
      for grpc_modifier, rest_modifier in (
          (grpc.Modifier.Idf, rest.Modifier.IDF),
          (getattr(grpc.Modifier, "None"), rest.Modifier.NONE),
      ):
          if model == grpc_modifier:
              return rest_modifier
      raise ValueError(f"invalid Modifier model: {model}")  # pragma: no cover
  @classmethod
  def convert_sparse_vector_params(
      cls, model: grpc.SparseVectorParams
  ) -> rest.SparseVectorParams:
      """Convert gRPC ``SparseVectorParams`` into REST ``SparseVectorParams``.

      ``index`` and ``modifier`` are converted only when they are set on
      ``model``; otherwise the corresponding REST field is ``None``.
      """
      # ``HasField`` already returns a bool. The previous
      # ``HasField("index") is not None`` test was always True, so an unset
      # index was converted from the default message instead of being None.
      index = (
          cls.convert_sparse_index_config(model.index) if model.HasField("index") else None
      )
      modifier = (
          cls.convert_modifier(model.modifier) if model.HasField("modifier") else None
      )
      return rest.SparseVectorParams(index=index, modifier=modifier)
  @classmethod
  def convert_sparse_vector_config(
      cls, model: grpc.SparseVectorConfig
  ) -> dict[str, rest.SparseVectorParams]:
      """Convert every entry of a gRPC ``SparseVectorConfig`` map, keyed by vector name."""
      return {
          vector_name: cls.convert_sparse_vector_params(params)
          for vector_name, params in model.map.items()
      }
  @classmethod
  def convert_shard_key(cls, model: grpc.ShardKey) -> rest.ShardKey:
      """Return the raw value stored in the ``key`` oneof of a gRPC ``ShardKey``.

      Raises:
          ValueError: if no variant of the oneof is set.
      """
      variant = model.WhichOneof("key")
      if variant is not None:
          return getattr(model, variant)
      raise ValueError(f"invalid ShardKey model: {model}")  # pragma: no cover
  @classmethod
  def convert_shard_key_selector(cls, model: grpc.ShardKeySelector) -> rest.ShardKeySelector:
      """Convert a gRPC ``ShardKeySelector`` into its REST form.

      A selector holding exactly one key is unwrapped to that single key;
      any other number of keys is returned as a list.
      """
      keys = [cls.convert_shard_key(shard_key) for shard_key in model.shard_keys]
      return keys[0] if len(keys) == 1 else keys
  @classmethod
  def convert_sharding_method(cls, model: grpc.ShardingMethod) -> rest.ShardingMethod:
      """Map a gRPC sharding-method enum value onto REST ``ShardingMethod``.

      Raises:
          ValueError: for an unknown sharding method.
      """
      for grpc_method, rest_method in (
          (grpc.Auto, rest.ShardingMethod.AUTO),
          (grpc.Custom, rest.ShardingMethod.CUSTOM),
      ):
          if model == grpc_method:
              return rest_method
      raise ValueError(f"invalid ShardingMethod model: {model}")  # pragma: no cover
  @classmethod
  def convert_direction(cls, model: grpc.Direction) -> rest.Direction:
      """Map a gRPC ordering direction onto REST ``Direction``.

      Raises:
          ValueError: for an unknown direction value.
      """
      for grpc_direction, rest_direction in (
          (grpc.Asc, rest.Direction.ASC),
          (grpc.Desc, rest.Direction.DESC),
      ):
          if model == grpc_direction:
              return rest_direction
      raise ValueError(f"invalid Direction model: {model}")  # pragma: no cover
  @classmethod
  def convert_start_from(cls, model: grpc.StartFrom) -> rest.StartFrom:
      """Extract the starting value of an order-by scan from a gRPC ``StartFrom``.

      Variants are checked in this order: ``integer``, ``float``, ``timestamp``
      (converted via ``convert_timestamp``), ``datetime``. The value of the first
      one that is set is returned.

      Raises:
          ValueError: if none of the variants is set. Previously the function
              fell off the end and returned ``None``, which contradicts its
              return type and the other oneof converters in this class.
      """
      if model.HasField("integer"):
          return model.integer
      if model.HasField("float"):
          return model.float
      if model.HasField("timestamp"):
          return cls.convert_timestamp(model.timestamp)
      if model.HasField("datetime"):
          return model.datetime
      raise ValueError(f"invalid StartFrom model: {model}")  # pragma: no cover
  @classmethod
  def convert_order_by(cls, model: grpc.OrderBy) -> rest.OrderBy:
      """Convert a gRPC ``OrderBy`` into REST ``OrderBy``.

      ``direction`` and ``start_from`` are ``None`` unless set on ``model``.
      """
      direction = None
      if model.HasField("direction"):
          direction = cls.convert_direction(model.direction)
      start_from = None
      if model.HasField("start_from"):
          start_from = cls.convert_start_from(model.start_from)
      return rest.OrderBy(key=model.key, direction=direction, start_from=start_from)
  @classmethod
  def convert_facet_value(cls, model: grpc.FacetValue) -> rest.FacetValue:
      """Return the raw value held in the ``variant`` oneof of a gRPC ``FacetValue``.

      Raises:
          ValueError: if no variant is set.
      """
      variant = model.WhichOneof("variant")
      if variant is not None:
          return getattr(model, variant)
      raise ValueError(f"invalid FacetValue model: {model}")  # pragma: no cover
  @classmethod
  def convert_facet_value_hit(cls, model: grpc.FacetHit) -> rest.FacetValueHit:
      """Convert a gRPC ``FacetHit`` (value + count) into REST ``FacetValueHit``."""
      facet_value = cls.convert_facet_value(model.value)
      return rest.FacetValueHit(value=facet_value, count=model.count)
  @classmethod
  def convert_health_check_reply(cls, model: grpc.HealthCheckReply) -> rest.VersionInfo:
      """Convert a gRPC health-check reply into REST ``VersionInfo``.

      ``commit`` is ``None`` when the reply does not carry it.
      """
      commit = model.commit if model.HasField("commit") else None
      return rest.VersionInfo(title=model.title, version=model.version, commit=commit)
  @classmethod
  def convert_search_matrix_pair(cls, model: grpc.SearchMatrixPair) -> rest.SearchMatrixPair:
      """Convert one scored (a, b) point pair of a distance matrix to REST."""
      first_id = cls.convert_point_id(model.a)
      second_id = cls.convert_point_id(model.b)
      return rest.SearchMatrixPair(a=first_id, b=second_id, score=model.score)
  @classmethod
  def convert_search_matrix_pairs(
      cls, model: grpc.SearchMatrixPairs
  ) -> rest.SearchMatrixPairsResponse:
      """Convert the pair-list form of a distance matrix into its REST response."""
      converted_pairs = list(map(cls.convert_search_matrix_pair, model.pairs))
      return rest.SearchMatrixPairsResponse(pairs=converted_pairs)
  @classmethod
  def convert_search_matrix_offsets(
      cls, model: grpc.SearchMatrixOffsets
  ) -> rest.SearchMatrixOffsetsResponse:
      """Convert the offsets form of a distance matrix into its REST response.

      The repeated numeric fields are copied into plain lists; point ids are
      converted one by one.
      """
      point_ids = [cls.convert_point_id(point_id) for point_id in model.ids]
      return rest.SearchMatrixOffsetsResponse(
          offsets_row=[*model.offsets_row],
          offsets_col=[*model.offsets_col],
          scores=[*model.scores],
          ids=point_ids,
      )
  @classmethod
  def convert_strict_mode_multivector(
      cls, model: grpc.StrictModeMultivector
  ) -> rest.StrictModeMultivector:
      """Convert a per-vector multivector strict-mode limit; unset ``max_vectors`` becomes None."""
      limit = None
      if model.HasField("max_vectors"):
          limit = model.max_vectors
      return rest.StrictModeMultivector(max_vectors=limit)
  @classmethod
  def convert_strict_mode_multivector_config(
      cls, model: grpc.StrictModeMultivectorConfig
  ) -> rest.StrictModeMultivectorConfig:
      """Convert the ``multivector_config`` map, keyed by vector name, into a plain dict."""
      return {
          vector_name: cls.convert_strict_mode_multivector(limits)
          for vector_name, limits in model.multivector_config.items()
      }
  @classmethod
  def convert_strict_mode_sparse(cls, model: grpc.StrictModeSparse) -> rest.StrictModeSparse:
      """Convert a per-vector sparse strict-mode limit; unset ``max_length`` becomes None."""
      limit = None
      if model.HasField("max_length"):
          limit = model.max_length
      return rest.StrictModeSparse(max_length=limit)
  @classmethod
  def convert_strict_mode_sparse_config(
      cls, model: grpc.StrictModeSparseConfig
  ) -> rest.StrictModeSparseConfig:
      """Convert the ``sparse_config`` map, keyed by vector name, into a plain dict."""
      return {
          vector_name: cls.convert_strict_mode_sparse(limits)
          for vector_name, limits in model.sparse_config.items()
      }
  @classmethod
  def convert_strict_mode_config(cls, model: grpc.StrictModeConfig) -> rest.StrictModeConfig:
      """Convert a gRPC ``StrictModeConfig`` into the REST input model.

      Each scalar limit is copied only when its optional field is set on
      ``model``; unset fields become ``None``. The multivector and sparse maps
      are converted with their dedicated helpers when present.
      """

      def present(field: str):
          # Value of an optional proto field, or None when it was never set.
          return getattr(model, field) if model.HasField(field) else None

      multivector = None
      if model.HasField("multivector_config"):
          multivector = cls.convert_strict_mode_multivector_config(model.multivector_config)
      sparse = None
      if model.HasField("sparse_config"):
          sparse = cls.convert_strict_mode_sparse_config(model.sparse_config)

      return rest.StrictModeConfig(
          enabled=present("enabled"),
          max_query_limit=present("max_query_limit"),
          max_timeout=present("max_timeout"),
          unindexed_filtering_retrieve=present("unindexed_filtering_retrieve"),
          unindexed_filtering_update=present("unindexed_filtering_update"),
          search_max_hnsw_ef=present("search_max_hnsw_ef"),
          search_allow_exact=present("search_allow_exact"),
          search_max_oversampling=present("search_max_oversampling"),
          upsert_max_batchsize=present("upsert_max_batchsize"),
          max_collection_vector_size_bytes=present("max_collection_vector_size_bytes"),
          read_rate_limit=present("read_rate_limit"),
          write_rate_limit=present("write_rate_limit"),
          max_collection_payload_size_bytes=present("max_collection_payload_size_bytes"),
          max_points_count=present("max_points_count"),
          filter_max_conditions=present("filter_max_conditions"),
          condition_max_size=present("condition_max_size"),
          multivector_config=multivector,
          sparse_config=sparse,
      )
  @classmethod
  def convert_strict_mode_config_output(
      cls, model: grpc.StrictModeConfig
  ) -> rest.StrictModeConfigOutput:
      """Convert a gRPC ``StrictModeConfig`` into the REST *output* model.

      Mirrors ``convert_strict_mode_config`` but produces
      ``StrictModeConfigOutput`` and uses the ``*_output`` helpers for the
      multivector and sparse maps. Unset optional fields become ``None``.
      """

      def present(field: str):
          # Value of an optional proto field, or None when it was never set.
          return getattr(model, field) if model.HasField(field) else None

      multivector = None
      if model.HasField("multivector_config"):
          multivector = cls.convert_strict_mode_multivector_config_output(
              model.multivector_config
          )
      sparse = None
      if model.HasField("sparse_config"):
          sparse = cls.convert_strict_mode_sparse_config_output(model.sparse_config)

      return rest.StrictModeConfigOutput(
          enabled=present("enabled"),
          max_query_limit=present("max_query_limit"),
          max_timeout=present("max_timeout"),
          unindexed_filtering_retrieve=present("unindexed_filtering_retrieve"),
          unindexed_filtering_update=present("unindexed_filtering_update"),
          search_max_hnsw_ef=present("search_max_hnsw_ef"),
          search_allow_exact=present("search_allow_exact"),
          search_max_oversampling=present("search_max_oversampling"),
          upsert_max_batchsize=present("upsert_max_batchsize"),
          max_collection_vector_size_bytes=present("max_collection_vector_size_bytes"),
          read_rate_limit=present("read_rate_limit"),
          write_rate_limit=present("write_rate_limit"),
          max_collection_payload_size_bytes=present("max_collection_payload_size_bytes"),
          max_points_count=present("max_points_count"),
          filter_max_conditions=present("filter_max_conditions"),
          condition_max_size=present("condition_max_size"),
          multivector_config=multivector,
          sparse_config=sparse,
      )
  @classmethod
  def convert_strict_mode_multivector_config_output(
      cls, model: grpc.StrictModeMultivectorConfig
  ) -> rest.StrictModeMultivectorConfigOutput:
      """Convert the ``multivector_config`` map into a dict of REST output models."""
      return {
          vector_name: cls.convert_strict_mode_multivector_output(limits)
          for vector_name, limits in model.multivector_config.items()
      }
  @classmethod
  def convert_strict_mode_sparse_config_output(
      cls, model: grpc.StrictModeSparseConfig
  ) -> rest.StrictModeSparseConfigOutput:
      """Convert the ``sparse_config`` map into a dict of REST output models."""
      return {
          vector_name: cls.convert_strict_mode_sparse_output(limits)
          for vector_name, limits in model.sparse_config.items()
      }
  @classmethod
  def convert_strict_mode_sparse_output(
      cls, model: grpc.StrictModeSparse
  ) -> rest.StrictModeSparseOutput:
      """Convert a sparse strict-mode limit into the REST output model."""
      limit = None
      if model.HasField("max_length"):
          limit = model.max_length
      return rest.StrictModeSparseOutput(max_length=limit)
  @classmethod
  def convert_strict_mode_multivector_output(
      cls, model: grpc.StrictModeMultivector
  ) -> rest.StrictModeMultivectorOutput:
      """Convert a multivector strict-mode limit into the REST output model."""
      limit = None
      if model.HasField("max_vectors"):
          limit = model.max_vectors
      return rest.StrictModeMultivectorOutput(max_vectors=limit)
  2204. # ----------------------------------------
  2205. #
  2206. # ----------- REST TO gRPC ---------------
  2207. #
  2208. # ----------------------------------------
  2209. class RestToGrpc:
  @classmethod
  def convert_filter(cls, model: rest.Filter) -> grpc.Filter:
      """Convert a REST ``Filter`` (must / must_not / should / min_should) to gRPC.

      Each clause may hold a single condition or a list of conditions; a single
      condition is treated as a one-element list. Clauses that are ``None`` stay
      ``None``.
      """

      def as_grpc_conditions(
          conditions: Union[list[rest.Condition], rest.Condition],
      ) -> list[grpc.Condition]:
          items = conditions if isinstance(conditions, list) else [conditions]
          return [cls.convert_condition(item) for item in items]

      must = as_grpc_conditions(model.must) if model.must is not None else None
      must_not = as_grpc_conditions(model.must_not) if model.must_not is not None else None
      should = as_grpc_conditions(model.should) if model.should is not None else None

      min_should = None
      if model.min_should is not None:
          min_should = grpc.MinShould(
              conditions=as_grpc_conditions(model.min_should.conditions),
              min_count=model.min_should.min_count,
          )

      return grpc.Filter(must=must, must_not=must_not, should=should, min_should=min_should)
  @classmethod
  def convert_range(cls, model: rest.Range) -> grpc.Range:
      """Copy the four numeric bounds of a REST ``Range`` into a gRPC ``Range``."""
      bounds = {"lt": model.lt, "gt": model.gt, "gte": model.gte, "lte": model.lte}
      return grpc.Range(**bounds)
  @classmethod
  def convert_datetime(cls, model: Union[datetime, date]) -> Timestamp:
      """Convert a ``datetime`` (or bare ``date``) into a protobuf ``Timestamp``.

      A bare ``date`` is first promoted to a ``datetime`` at midnight
      (``datetime.min.time()``) before being passed to ``FromDatetime``.
      """
      # ``datetime`` subclasses ``date``, so exclude it before promoting.
      if not isinstance(model, datetime) and isinstance(model, date):
          model = datetime.combine(model, datetime.min.time())
      timestamp = Timestamp()
      timestamp.FromDatetime(model)
      return timestamp
  @classmethod
  def convert_datetime_range(cls, model: rest.DatetimeRange) -> grpc.DatetimeRange:
      """Convert a REST ``DatetimeRange``; each bound becomes a ``Timestamp`` or None."""

      def to_timestamp(bound):
          return None if bound is None else cls.convert_datetime(bound)

      return grpc.DatetimeRange(
          lt=to_timestamp(model.lt),
          gt=to_timestamp(model.gt),
          gte=to_timestamp(model.gte),
          lte=to_timestamp(model.lte),
      )
  @classmethod
  def convert_geo_radius(cls, model: rest.GeoRadius) -> grpc.GeoRadius:
      """Convert a REST geo-radius (center point + radius) to gRPC."""
      center = cls.convert_geo_point(model.center)
      return grpc.GeoRadius(center=center, radius=model.radius)
  @classmethod
  def convert_geo_line_string(cls, model: rest.GeoLineString) -> grpc.GeoLineString:
      """Convert a REST line string by converting each of its points."""
      converted = list(map(cls.convert_geo_point, model.points))
      return grpc.GeoLineString(points=converted)
  @classmethod
  def convert_geo_polygon(cls, model: rest.GeoPolygon) -> grpc.GeoPolygon:
      """Convert a REST polygon (exterior ring + optional interior rings) to gRPC.

      ``interiors`` is passed as ``None`` when the REST model has none (``None``
      or empty).
      """
      interiors = None
      if model.interiors:
          interiors = [cls.convert_geo_line_string(ring) for ring in model.interiors]
      return grpc.GeoPolygon(
          exterior=cls.convert_geo_line_string(model.exterior),
          interiors=interiors,
      )
  @classmethod
  def convert_collection_description(
      cls, model: rest.CollectionDescription
  ) -> grpc.CollectionDescription:
      """Convert a REST collection description; only the name is carried over."""
      collection_name = model.name
      return grpc.CollectionDescription(name=collection_name)
  @classmethod
  def convert_collection_info(cls, model: rest.CollectionInfo) -> grpc.CollectionInfo:
      """Convert REST ``CollectionInfo`` into gRPC ``CollectionInfo``.

      ``config`` is converted only when truthy and ``payload_schema`` only when
      not ``None``; counts are copied as-is.
      """
      config = cls.convert_collection_config(model.config) if model.config else None
      optimizer_status = cls.convert_optimizer_status(model.optimizer_status)
      payload_schema = None
      if model.payload_schema is not None:
          payload_schema = cls.convert_payload_schema(model.payload_schema)
      return grpc.CollectionInfo(
          config=config,
          optimizer_status=optimizer_status,
          payload_schema=payload_schema,
          segments_count=model.segments_count,
          status=cls.convert_collection_status(model.status),
          vectors_count=model.vectors_count,
          points_count=model.points_count,
      )
  @classmethod
  def convert_collection_status(cls, model: rest.CollectionStatus) -> grpc.CollectionStatus:
      """Map a REST ``CollectionStatus`` color onto the gRPC enum.

      Raises:
          ValueError: for an unknown status.
      """
      for rest_status, grpc_status in (
          (rest.CollectionStatus.RED, grpc.CollectionStatus.Red),
          (rest.CollectionStatus.YELLOW, grpc.CollectionStatus.Yellow),
          (rest.CollectionStatus.GREEN, grpc.CollectionStatus.Green),
          (rest.CollectionStatus.GREY, grpc.CollectionStatus.Grey),
      ):
          if model == rest_status:
              return grpc_status
      raise ValueError(f"invalid CollectionStatus model: {model}")  # pragma: no cover
  @classmethod
  def convert_optimizer_status(cls, model: rest.OptimizersStatus) -> grpc.OptimizerStatus:
      """Convert the REST optimizer status union to gRPC.

      ``OptimizersStatusOneOf`` maps to ``ok=True``; ``OptimizersStatusOneOf1``
      maps to ``ok=False`` together with its error message.

      Raises:
          ValueError: for any other variant.
      """
      if isinstance(model, rest.OptimizersStatusOneOf):
          return grpc.OptimizerStatus(ok=True)
      elif isinstance(model, rest.OptimizersStatusOneOf1):
          return grpc.OptimizerStatus(ok=False, error=model.error)
      else:
          raise ValueError(f"invalid OptimizersStatus model: {model}")  # pragma: no cover
  @classmethod
  def convert_payload_schema(
      cls, model: dict[str, rest.PayloadIndexInfo]
  ) -> dict[str, grpc.PayloadSchemaInfo]:
      """Convert each payload-index entry of the schema, keyed by payload field name."""
      return {
          field_name: cls.convert_payload_index_info(index_info)
          for field_name, index_info in model.items()
      }
  @classmethod
  def convert_payload_index_info(cls, model: rest.PayloadIndexInfo) -> grpc.PayloadSchemaInfo:
      """Convert one payload index description (type, optional params, point count)."""
      converted_params = None
      if model.params is not None:
          converted_params = cls.convert_payload_schema_params(model.params)
      return grpc.PayloadSchemaInfo(
          data_type=cls.convert_payload_schema_type(model.data_type),
          params=converted_params,
          points=model.points,
      )
  @classmethod
  def convert_payload_schema_params(
      cls, model: rest.PayloadSchemaParams
  ) -> grpc.PayloadIndexParams:
      """Wrap a REST payload-index params object into gRPC ``PayloadIndexParams``.

      The concrete params class selects which oneof field of
      ``PayloadIndexParams`` is filled and which ``convert_*`` helper builds it.
      Types are tested in a fixed order and the first match wins.

      Raises:
          ValueError: if ``model`` is none of the supported params types.
      """
      # (REST params class, PayloadIndexParams field, converter method name)
      dispatch = (
          (rest.TextIndexParams, "text_index_params", "convert_text_index_params"),
          (rest.IntegerIndexParams, "integer_index_params", "convert_integer_index_params"),
          (rest.KeywordIndexParams, "keyword_index_params", "convert_keyword_index_params"),
          (rest.FloatIndexParams, "float_index_params", "convert_float_index_params"),
          (rest.GeoIndexParams, "geo_index_params", "convert_geo_index_params"),
          (rest.BoolIndexParams, "bool_index_params", "convert_bool_index_params"),
          (rest.DatetimeIndexParams, "datetime_index_params", "convert_datetime_index_params"),
          (rest.UuidIndexParams, "uuid_index_params", "convert_uuid_index_params"),
      )
      for params_cls, field_name, converter_name in dispatch:
          if isinstance(model, params_cls):
              converted = getattr(cls, converter_name)(model)
              return grpc.PayloadIndexParams(**{field_name: converted})
      raise ValueError(f"invalid PayloadSchemaParams model: {model}")  # pragma: no cover
  @classmethod
  def convert_payload_schema_type(cls, model: rest.PayloadSchemaType) -> grpc.PayloadSchemaType:
      """Map a REST ``PayloadSchemaType`` onto the gRPC enum.

      Raises:
          ValueError: for an unknown schema type.
      """
      for rest_type, grpc_type in (
          (rest.PayloadSchemaType.KEYWORD, grpc.PayloadSchemaType.Keyword),
          (rest.PayloadSchemaType.INTEGER, grpc.PayloadSchemaType.Integer),
          (rest.PayloadSchemaType.FLOAT, grpc.PayloadSchemaType.Float),
          (rest.PayloadSchemaType.BOOL, grpc.PayloadSchemaType.Bool),
          (rest.PayloadSchemaType.GEO, grpc.PayloadSchemaType.Geo),
          (rest.PayloadSchemaType.TEXT, grpc.PayloadSchemaType.Text),
          (rest.PayloadSchemaType.DATETIME, grpc.PayloadSchemaType.Datetime),
          (rest.PayloadSchemaType.UUID, grpc.PayloadSchemaType.Uuid),
      ):
          if model == rest_type:
              return grpc_type
      raise ValueError(f"invalid PayloadSchemaType model: {model}")  # pragma: no cover
  @classmethod
  def convert_update_result(cls, model: rest.UpdateResult) -> grpc.UpdateResult:
      """Convert a REST ``UpdateResult`` (operation id + status) to gRPC."""
      status = cls.convert_update_stats(model.status)
      return grpc.UpdateResult(operation_id=model.operation_id, status=status)
  @classmethod
  def convert_update_stats(cls, model: rest.UpdateStatus) -> grpc.UpdateStatus:
      """Map a REST ``UpdateStatus`` onto the gRPC enum.

      Raises:
          ValueError: for an unknown status.
      """
      for rest_status, grpc_status in (
          (rest.UpdateStatus.COMPLETED, grpc.UpdateStatus.Completed),
          (rest.UpdateStatus.ACKNOWLEDGED, grpc.UpdateStatus.Acknowledged),
      ):
          if model == rest_status:
              return grpc_status
      raise ValueError(f"invalid UpdateStatus model: {model}")  # pragma: no cover
  @classmethod
  def convert_has_id_condition(cls, model: rest.HasIdCondition) -> grpc.HasIdCondition:
      """Convert a has-id condition by converting each listed point id."""
      point_ids = list(map(cls.convert_extended_point_id, model.has_id))
      return grpc.HasIdCondition(has_id=point_ids)
  @classmethod
  def convert_has_vector_condition(
      cls, model: rest.HasVectorCondition
  ) -> grpc.HasVectorCondition:
      """Convert a has-vector condition; the vector name is copied verbatim."""
      vector_name = model.has_vector
      return grpc.HasVectorCondition(has_vector=vector_name)
  @classmethod
  def convert_delete_alias(cls, model: rest.DeleteAlias) -> grpc.DeleteAlias:
      """Convert a delete-alias operation; only the alias name is carried over."""
      alias = model.alias_name
      return grpc.DeleteAlias(alias_name=alias)
  @classmethod
  def convert_rename_alias(cls, model: rest.RenameAlias) -> grpc.RenameAlias:
      """Convert a rename-alias operation (old name -> new name)."""
      old_name, new_name = model.old_alias_name, model.new_alias_name
      return grpc.RenameAlias(old_alias_name=old_name, new_alias_name=new_name)
  @classmethod
  def convert_is_empty_condition(cls, model: rest.IsEmptyCondition) -> grpc.IsEmptyCondition:
      """Convert an is-empty condition; gRPC carries only the payload key."""
      payload_key = model.is_empty.key
      return grpc.IsEmptyCondition(key=payload_key)
  @classmethod
  def convert_is_null_condition(cls, model: rest.IsNullCondition) -> grpc.IsNullCondition:
      """Convert an is-null condition; gRPC carries only the payload key."""
      payload_key = model.is_null.key
      return grpc.IsNullCondition(key=payload_key)
  @classmethod
  def convert_nested_condition(cls, model: rest.NestedCondition) -> grpc.NestedCondition:
      """Convert a nested condition: its key plus a recursively converted filter."""
      nested = model.nested
      return grpc.NestedCondition(key=nested.key, filter=cls.convert_filter(nested.filter))
  @classmethod
  def convert_search_params(cls, model: rest.SearchParams) -> grpc.SearchParams:
      """Convert REST ``SearchParams``; quantization params are converted when present."""
      quantization = None
      if model.quantization is not None:
          quantization = cls.convert_quantization_search_params(model.quantization)
      return grpc.SearchParams(
          hnsw_ef=model.hnsw_ef,
          exact=model.exact,
          quantization=quantization,
          indexed_only=model.indexed_only,
      )
  @classmethod
  def convert_create_alias(cls, model: rest.CreateAlias) -> grpc.CreateAlias:
      """Convert a create-alias operation (target collection + alias name)."""
      return grpc.CreateAlias(
          collection_name=model.collection_name,
          alias_name=model.alias_name,
      )
  @classmethod
  def convert_order_value(cls, model: rest.OrderValue) -> grpc.OrderValue:
      """Wrap an int or float order value into the matching ``OrderValue`` field.

      Raises:
          ValueError: if ``model`` is neither an int nor a float.
      """
      for value_type, field_name in ((int, "int"), (float, "float")):
          if isinstance(model, value_type):
              return grpc.OrderValue(**{field_name: model})
      raise ValueError(f"invalid OrderValue model: {model}")  # pragma: no cover
  @classmethod
  def convert_scored_point(cls, model: rest.ScoredPoint) -> grpc.ScoredPoint:
      """Convert a REST ``ScoredPoint`` into a gRPC ``ScoredPoint``.

      Optional parts (payload, vectors, shard key, order value) are converted
      only when they are not ``None``.
      """
      # Use explicit ``is not None`` checks. The previous truthiness tests dropped
      # legitimate falsy values such as order_value 0 / 0.0 or shard key 0 / "".
      return grpc.ScoredPoint(
          id=cls.convert_extended_point_id(model.id),
          payload=cls.convert_payload(model.payload) if model.payload is not None else None,
          score=model.score,
          vectors=(
              cls.convert_vector_struct_output(model.vector)
              if model.vector is not None
              else None
          ),
          version=model.version,
          shard_key=(
              cls.convert_shard_key(model.shard_key) if model.shard_key is not None else None
          ),
          order_value=(
              cls.convert_order_value(model.order_value)
              if model.order_value is not None
              else None
          ),
      )
  @classmethod
  def convert_values_count(cls, model: rest.ValuesCount) -> grpc.ValuesCount:
      """Copy the four count bounds of a REST ``ValuesCount`` into gRPC."""
      bounds = {"lt": model.lt, "gt": model.gt, "gte": model.gte, "lte": model.lte}
      return grpc.ValuesCount(**bounds)
  @classmethod
  def convert_geo_bounding_box(cls, model: rest.GeoBoundingBox) -> grpc.GeoBoundingBox:
      """Convert a REST bounding box defined by its top-left and bottom-right corners."""
      top_left = cls.convert_geo_point(model.top_left)
      bottom_right = cls.convert_geo_point(model.bottom_right)
      return grpc.GeoBoundingBox(top_left=top_left, bottom_right=bottom_right)
  @classmethod
  def convert_point_struct(cls, model: rest.PointStruct) -> grpc.PointStruct:
      """Convert a REST ``PointStruct``; the payload is converted only if present."""
      point_id = cls.convert_extended_point_id(model.id)
      vectors = cls.convert_vector_struct(model.vector)
      payload = None if model.payload is None else cls.convert_payload(model.payload)
      return grpc.PointStruct(id=point_id, vectors=vectors, payload=payload)
  @classmethod
  def convert_payload(cls, model: rest.Payload) -> dict[str, grpc.Value]:
      """Convert every payload entry to a gRPC ``Value`` via ``json_to_value``."""
      return {field_name: json_to_value(value) for field_name, value in model.items()}
  @classmethod
  def convert_hnsw_config_diff(cls, model: rest.HnswConfigDiff) -> grpc.HnswConfigDiff:
      """Copy the fields of a REST ``HnswConfigDiff`` into its gRPC counterpart."""
      fields = {
          "ef_construct": model.ef_construct,
          "full_scan_threshold": model.full_scan_threshold,
          "m": model.m,
          "max_indexing_threads": model.max_indexing_threads,
          "on_disk": model.on_disk,
          "payload_m": model.payload_m,
      }
      return grpc.HnswConfigDiff(**fields)
  @classmethod
  def convert_field_condition(cls, model: rest.FieldCondition) -> grpc.FieldCondition:
      """Convert a REST ``FieldCondition`` into gRPC.

      The first non-``None`` criterion wins, checked in this order: match,
      range (numeric or datetime), geo bounding box, geo radius, geo polygon,
      values count, is_empty, is_null.

      Raises:
          ValueError: if no supported criterion is set.
      """
      key = model.key
      if model.match is not None:
          return grpc.FieldCondition(key=key, match=cls.convert_match(model.match))

      value_range = model.range
      if value_range is not None:
          if isinstance(value_range, rest.Range):
              return grpc.FieldCondition(key=key, range=cls.convert_range(value_range))
          if isinstance(value_range, rest.DatetimeRange):
              return grpc.FieldCondition(
                  key=key,
                  datetime_range=cls.convert_datetime_range(value_range),
              )
          # Any other range type falls through to the remaining criteria.

      if model.geo_bounding_box is not None:
          return grpc.FieldCondition(
              key=key,
              geo_bounding_box=cls.convert_geo_bounding_box(model.geo_bounding_box),
          )
      if model.geo_radius is not None:
          return grpc.FieldCondition(key=key, geo_radius=cls.convert_geo_radius(model.geo_radius))
      if model.geo_polygon is not None:
          return grpc.FieldCondition(
              key=key, geo_polygon=cls.convert_geo_polygon(model.geo_polygon)
          )
      if model.values_count is not None:
          return grpc.FieldCondition(
              key=key, values_count=cls.convert_values_count(model.values_count)
          )
      if model.is_empty is not None:
          return grpc.FieldCondition(key=key, is_empty=model.is_empty)
      if model.is_null is not None:
          return grpc.FieldCondition(key=key, is_null=model.is_null)
      raise ValueError(f"invalid FieldCondition model: {model}")  # pragma: no cover
  @classmethod
  def convert_wal_config_diff(cls, model: rest.WalConfigDiff) -> grpc.WalConfigDiff:
      """Copy the WAL capacity and segments-ahead settings into gRPC."""
      capacity, ahead = model.wal_capacity_mb, model.wal_segments_ahead
      return grpc.WalConfigDiff(wal_capacity_mb=capacity, wal_segments_ahead=ahead)
  @classmethod
  def convert_collection_config(cls, model: rest.CollectionConfig) -> grpc.CollectionConfig:
      """Convert a REST ``CollectionConfig`` into gRPC.

      Params, HNSW, optimizer and WAL sections are always converted. The
      quantization and strict-mode sections are converted only when not
      ``None``; strict mode goes through ``convert_strict_mode_config_output``.
      """
      quantization = None
      if model.quantization_config is not None:
          quantization = cls.convert_quantization_config(model.quantization_config)
      strict_mode = None
      if model.strict_mode_config is not None:
          strict_mode = cls.convert_strict_mode_config_output(model.strict_mode_config)
      return grpc.CollectionConfig(
          params=cls.convert_collection_params(model.params),
          hnsw_config=cls.convert_hnsw_config(model.hnsw_config),
          optimizer_config=cls.convert_optimizers_config(model.optimizer_config),
          wal_config=cls.convert_wal_config(model.wal_config),
          quantization_config=quantization,
          strict_mode_config=strict_mode,
      )
  2541. @classmethod
  def convert_hnsw_config(cls, model: rest.HnswConfig) -> grpc.HnswConfigDiff:
      """Convert a REST ``HnswConfig`` into a gRPC ``HnswConfigDiff``.

      Every field is copied verbatim under the same name; no value conversion applies.
      """
      copied_fields = (
          "ef_construct",
          "full_scan_threshold",
          "m",
          "max_indexing_threads",
          "on_disk",
          "payload_m",
      )
      return grpc.HnswConfigDiff(**{name: getattr(model, name) for name in copied_fields})
  2551. @classmethod
  def convert_wal_config(cls, model: rest.WalConfig) -> grpc.WalConfigDiff:
      """Convert a REST ``WalConfig`` into a gRPC ``WalConfigDiff`` (fields copied as-is)."""
      capacity_mb = model.wal_capacity_mb
      segments_ahead = model.wal_segments_ahead
      return grpc.WalConfigDiff(wal_capacity_mb=capacity_mb, wal_segments_ahead=segments_ahead)
  2557. @classmethod
  def convert_distance(cls, model: rest.Distance) -> grpc.Distance:
      """Map a REST ``Distance`` value to the corresponding gRPC ``Distance``.

      Raises:
          ValueError: if ``model`` equals none of DOT, COSINE, EUCLID, MANHATTAN.
      """
      known_distances = (
          (rest.Distance.DOT, grpc.Distance.Dot),
          (rest.Distance.COSINE, grpc.Distance.Cosine),
          (rest.Distance.EUCLID, grpc.Distance.Euclid),
          (rest.Distance.MANHATTAN, grpc.Distance.Manhattan),
      )
      # Compared with ``==`` in declaration order (rather than a dict lookup) so the
      # matching semantics are exactly those of an if-chain.
      for rest_distance, grpc_distance in known_distances:
          if model == rest_distance:
              return grpc_distance
      raise ValueError(f"invalid Distance model: {model}")  # pragma: no cover
  2568. @classmethod
  def convert_collection_params(cls, model: rest.CollectionParams) -> grpc.CollectionParams:
      """Convert REST ``CollectionParams`` into gRPC ``CollectionParams``.

      Optional sub-configs (vectors, sparse vectors, sharding method) are converted
      only when set. A falsy ``on_disk_payload`` is sent as ``False``.
      """
      vectors = model.vectors
      sparse_vectors = model.sparse_vectors
      sharding_method = model.sharding_method
      return grpc.CollectionParams(
          vectors_config=None if vectors is None else cls.convert_vectors_config(vectors),
          shard_number=model.shard_number,
          on_disk_payload=model.on_disk_payload or False,
          write_consistency_factor=model.write_consistency_factor,
          replication_factor=model.replication_factor,
          read_fan_out_factor=model.read_fan_out_factor,
          sparse_vectors_config=(
              None if sparse_vectors is None else cls.convert_sparse_vector_config(sparse_vectors)
          ),
          sharding_method=(
              None if sharding_method is None else cls.convert_sharding_method(sharding_method)
          ),
      )
  2590. @classmethod
  def convert_max_optimization_threads(
      cls, model: rest.MaxOptimizationThreads
  ) -> grpc.MaxOptimizationThreads:
      """Convert a REST max-optimization-threads value into its gRPC message.

      The AUTO setting becomes ``Setting.Auto``; an ``int`` becomes an explicit
      ``value``.

      Raises:
          ValueError: for any other input.
      """
      is_auto = model == rest.MaxOptimizationThreadsSetting.AUTO
      if is_auto:
          auto_setting = grpc.MaxOptimizationThreads.Setting.Auto
          return grpc.MaxOptimizationThreads(setting=auto_setting)
      if not isinstance(model, int):
          raise ValueError(f"invalid MaxOptimizationThreads model: {model}")  # pragma: no cover
      return grpc.MaxOptimizationThreads(value=model)
  2599. @classmethod
  def convert_optimizers_config(cls, model: rest.OptimizersConfig) -> grpc.OptimizersConfigDiff:
      """Convert a REST ``OptimizersConfig`` into a gRPC ``OptimizersConfigDiff``.

      ``max_optimization_threads`` is sent through the structured
      ``max_optimization_threads`` message. When it is a plain ``int``, it is also
      mirrored into ``deprecated_max_optimization_threads``, matching
      ``convert_optimizers_config_diff``. Non-int settings such as
      ``MaxOptimizationThreadsSetting.AUTO`` leave the deprecated field unset.
      """
      threads = model.max_optimization_threads
      # The deprecated field only holds an integer. Passing a setting like AUTO there
      # would make the protobuf constructor reject the value.
      deprecated_threads = threads if isinstance(threads, int) else None
      return grpc.OptimizersConfigDiff(
          default_segment_number=model.default_segment_number,
          deleted_threshold=model.deleted_threshold,
          flush_interval_sec=model.flush_interval_sec,
          indexing_threshold=model.indexing_threshold,
          max_optimization_threads=(
              cls.convert_max_optimization_threads(threads) if threads is not None else None
          ),
          max_segment_size=model.max_segment_size,
          memmap_threshold=model.memmap_threshold,
          vacuum_min_vector_number=model.vacuum_min_vector_number,
          deprecated_max_optimization_threads=deprecated_threads,
      )
  2616. @classmethod
  def convert_optimizers_config_diff(
      cls, model: rest.OptimizersConfigDiff
  ) -> grpc.OptimizersConfigDiff:
      """Convert a REST ``OptimizersConfigDiff`` into a gRPC ``OptimizersConfigDiff``.

      ``max_optimization_threads`` is converted into the structured gRPC field and,
      only when it is a plain ``int``, also copied into
      ``deprecated_max_optimization_threads``.
      """
      threads = model.max_optimization_threads
      structured_threads = (
          None if threads is None else cls.convert_max_optimization_threads(threads)
      )
      legacy_threads = threads if isinstance(threads, int) else None
      return grpc.OptimizersConfigDiff(
          default_segment_number=model.default_segment_number,
          deleted_threshold=model.deleted_threshold,
          flush_interval_sec=model.flush_interval_sec,
          indexing_threshold=model.indexing_threshold,
          max_optimization_threads=structured_threads,
          max_segment_size=model.max_segment_size,
          memmap_threshold=model.memmap_threshold,
          vacuum_min_vector_number=model.vacuum_min_vector_number,
          deprecated_max_optimization_threads=legacy_threads,
      )
  2638. @classmethod
  def convert_update_collection(
      cls, model: rest.UpdateCollection, collection_name: str
  ) -> grpc.UpdateCollection:
      """Build a gRPC ``UpdateCollection`` request for ``collection_name``.

      Each optional section of ``model`` is converted with its ``*_diff`` converter
      when set, and left unset otherwise.
      """

      def when_set(value, converter):
          # Apply ``converter`` only to values that are not None.
          return None if value is None else converter(value)

      return grpc.UpdateCollection(
          collection_name=collection_name,
          optimizers_config=when_set(
              model.optimizers_config, cls.convert_optimizers_config_diff
          ),
          vectors_config=when_set(model.vectors, cls.convert_vectors_config_diff),
          params=when_set(model.params, cls.convert_collection_params_diff),
          hnsw_config=when_set(model.hnsw_config, cls.convert_hnsw_config_diff),
          quantization_config=when_set(
              model.quantization_config, cls.convert_quantization_config_diff
          ),
      )
  2670. @classmethod
  def convert_geo_point(cls, model: rest.GeoPoint) -> grpc.GeoPoint:
      """Convert a REST ``GeoPoint`` into a gRPC ``GeoPoint`` (lat/lon copied as-is)."""
      latitude, longitude = model.lat, model.lon
      return grpc.GeoPoint(lat=latitude, lon=longitude)
  2673. @classmethod
  def convert_match(cls, model: rest.Match) -> grpc.Match:
      """Convert a REST ``Match`` variant into a gRPC ``Match``.

      Supported variants are:

      * ``MatchValue`` with a bool, int or str value;
      * ``MatchText``;
      * ``MatchAny`` / ``MatchExcept`` with str or int elements, where an empty list
        is sent as an empty string list;
      * ``MatchPhrase``.

      Raises:
          ValueError: for unsupported variants or element types.
      """
      if isinstance(model, rest.MatchValue):
          value = model.value
          # bool must be tested before int because bool is a subclass of int.
          if isinstance(value, bool):
              return grpc.Match(boolean=value)
          if isinstance(value, int):
              return grpc.Match(integer=value)
          if isinstance(value, str):
              return grpc.Match(keyword=value)
          # Any other value type falls through to the generic error below.
      elif isinstance(model, rest.MatchText):
          return grpc.Match(text=model.text)
      elif isinstance(model, rest.MatchAny):
          candidates = model.any
          if len(candidates) == 0:
              return grpc.Match(keywords=grpc.RepeatedStrings(strings=[]))
          first_candidate = candidates[0]
          if isinstance(first_candidate, str):
              return grpc.Match(keywords=grpc.RepeatedStrings(strings=candidates))
          if isinstance(first_candidate, int):
              return grpc.Match(integers=grpc.RepeatedIntegers(integers=candidates))
          raise ValueError(f"invalid MatchAny model: {model}")  # pragma: no cover
      elif isinstance(model, rest.MatchExcept):
          excluded = model.except_
          if len(excluded) == 0:
              return grpc.Match(except_keywords=grpc.RepeatedStrings(strings=[]))
          first_excluded = excluded[0]
          if isinstance(first_excluded, str):
              return grpc.Match(except_keywords=grpc.RepeatedStrings(strings=excluded))
          if isinstance(first_excluded, int):
              return grpc.Match(except_integers=grpc.RepeatedIntegers(integers=excluded))
          raise ValueError(f"invalid MatchExcept model: {model}")  # pragma: no cover
      elif isinstance(model, rest.MatchPhrase):
          return grpc.Match(phrase=model.phrase)
      raise ValueError(f"invalid Match model: {model}")  # pragma: no cover
  2703. @classmethod
  def convert_alias_operations(cls, model: rest.AliasOperations) -> grpc.AliasOperations:
      """Convert a REST create/delete/rename alias operation into ``grpc.AliasOperations``.

      Raises:
          ValueError: if ``model`` is none of the three operation types.
      """
      # (operation type, attribute / gRPC field name, converter for that attribute)
      operations = (
          (rest.CreateAliasOperation, "create_alias", cls.convert_create_alias),
          (rest.DeleteAliasOperation, "delete_alias", cls.convert_delete_alias),
          (rest.RenameAliasOperation, "rename_alias", cls.convert_rename_alias),
      )
      for operation_type, field_name, convert in operations:
          if isinstance(model, operation_type):
              return grpc.AliasOperations(**{field_name: convert(getattr(model, field_name))})
      raise ValueError(f"invalid AliasOperations model: {model}")  # pragma: no cover
  2712. @classmethod
  def convert_alias_description(cls, model: rest.AliasDescription) -> grpc.AliasDescription:
      """Convert a REST ``AliasDescription`` into a gRPC ``AliasDescription``."""
      alias, collection = model.alias_name, model.collection_name
      return grpc.AliasDescription(collection_name=collection, alias_name=alias)
  2718. @classmethod
  def convert_recommend_examples_to_ids(
      cls, examples: Sequence[rest.RecommendExample]
  ) -> list[grpc.PointId]:
      """Collect the id-like recommend examples as gRPC ``PointId`` messages.

      REST point ids are converted and ``grpc.PointId`` instances are kept as-is.
      Every other example (e.g. vectors) is skipped. Input order is preserved.
      """
      point_id_types = get_args_subscribed(rest.ExtendedPointId)
      collected: list[grpc.PointId] = []
      for candidate in examples:
          if isinstance(candidate, point_id_types):
              collected.append(cls.convert_extended_point_id(candidate))
          elif isinstance(candidate, grpc.PointId):
              collected.append(candidate)
      return collected
  2732. @classmethod
  def convert_recommend_examples_to_vectors(
      cls, examples: Sequence[rest.RecommendExample]
  ) -> list[grpc.Vector]:
      """Collect the vector-like recommend examples as gRPC ``Vector`` messages.

      Examples are handled as follows:

      * ``grpc.Vector`` instances are kept as-is;
      * plain lists become dense vectors;
      * ``rest.SparseVector`` values are converted;
      * every other example (e.g. point ids) is skipped.

      Input order is preserved.
      """
      collected: list[grpc.Vector] = []
      for candidate in examples:
          if isinstance(candidate, grpc.Vector):
              collected.append(candidate)
          elif isinstance(candidate, list):
              collected.append(grpc.Vector(data=candidate))
          elif isinstance(candidate, rest.SparseVector):
              collected.append(cls.convert_sparse_vector_to_vector(candidate))
      return collected
  2748. @classmethod
  def convert_vector_example(cls, model: rest.RecommendExample) -> grpc.VectorExample:
      """Alias for ``convert_recommend_example``."""
      example = cls.convert_recommend_example(model)
      return example
  2751. @classmethod
  def convert_recommend_example(cls, model: rest.RecommendExample) -> grpc.VectorExample:
      """Convert a REST recommend example into a gRPC ``VectorExample``.

      A point id becomes ``id=...``. A sparse vector or a plain list becomes
      ``vector=...``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, get_args_subscribed(rest.ExtendedPointId)):
          return grpc.VectorExample(id=cls.convert_extended_point_id(model))
      if isinstance(model, rest.SparseVector):
          vector = cls.convert_sparse_vector_to_vector(model)
      elif isinstance(model, list):
          vector = grpc.Vector(data=model)
      else:
          raise ValueError(f"Invalid RecommendExample model: {model}")  # pragma: no cover
      return grpc.VectorExample(vector=vector)
  2760. @classmethod
  def convert_sparse_vector_to_vector(cls, model: rest.SparseVector) -> grpc.Vector:
      """Pack a REST ``SparseVector`` into a gRPC ``Vector`` (values + ``SparseIndices``)."""
      sparse_indices = grpc.SparseIndices(data=model.indices)
      return grpc.Vector(data=model.values, indices=sparse_indices)
  2766. @classmethod
  def convert_sparse_vector_to_vector_output(cls, model: rest.SparseVector) -> grpc.VectorOutput:
      """Pack a REST ``SparseVector`` into a gRPC ``VectorOutput`` (values + ``SparseIndices``)."""
      sparse_indices = grpc.SparseIndices(data=model.indices)
      return grpc.VectorOutput(data=model.values, indices=sparse_indices)
  2772. @classmethod
  def convert_target_vector(cls, model: rest.RecommendExample) -> grpc.TargetVector:
      """Wrap a single converted recommend example as a gRPC ``TargetVector``."""
      single_example = cls.convert_recommend_example(model)
      return grpc.TargetVector(single=single_example)
  2775. @classmethod
  def convert_context_example_pair(
      cls,
      model: rest.ContextExamplePair,
  ) -> grpc.ContextExamplePair:
      """Convert both halves of a REST context pair (positive first, then negative)."""
      positive, negative = (
          cls.convert_recommend_example(example) for example in (model.positive, model.negative)
      )
      return grpc.ContextExamplePair(positive=positive, negative=negative)
  2784. @classmethod
  def convert_extended_point_id(cls, model: rest.ExtendedPointId) -> grpc.PointId:
      """Convert a REST point id into a gRPC ``PointId``.

      A ``str`` id is sent as ``uuid`` and an ``int`` id as ``num``.

      Raises:
          ValueError: for any other type.
      """
      # str and int cannot overlap, so the check order is irrelevant here.
      if isinstance(model, str):
          return grpc.PointId(uuid=model)
      if isinstance(model, int):
          return grpc.PointId(num=model)
      raise ValueError(f"invalid ExtendedPointId model: {model}")  # pragma: no cover
  2791. @classmethod
  def convert_points_selector(cls, model: rest.PointsSelector) -> grpc.PointsSelector:
      """Convert a REST points selector (explicit id list or filter) into gRPC.

      Raises:
          ValueError: if ``model`` is neither a ``PointIdsList`` nor a ``FilterSelector``.
      """
      if isinstance(model, rest.PointIdsList):
          point_ids = list(map(cls.convert_extended_point_id, model.points))
          return grpc.PointsSelector(points=grpc.PointsIdsList(ids=point_ids))
      if isinstance(model, rest.FilterSelector):
          converted_filter = cls.convert_filter(model.filter)
          return grpc.PointsSelector(filter=converted_filter)
      raise ValueError(f"invalid PointsSelector model: {model}")  # pragma: no cover
  2802. @classmethod
  def convert_condition(cls, model: rest.Condition) -> grpc.Condition:
      """Wrap a converted REST condition in the matching ``grpc.Condition`` field.

      Raises:
          ValueError: if ``model`` is not one of the supported condition types.
      """
      # (REST type, gRPC oneof field, converter), checked in this order.
      handlers = (
          (rest.FieldCondition, "field", cls.convert_field_condition),
          (rest.IsEmptyCondition, "is_empty", cls.convert_is_empty_condition),
          (rest.IsNullCondition, "is_null", cls.convert_is_null_condition),
          (rest.HasIdCondition, "has_id", cls.convert_has_id_condition),
          (rest.HasVectorCondition, "has_vector", cls.convert_has_vector_condition),
          (rest.Filter, "filter", cls.convert_filter),
          (rest.NestedCondition, "nested", cls.convert_nested_condition),
      )
      for condition_type, field_name, convert in handlers:
          if isinstance(model, condition_type):
              return grpc.Condition(**{field_name: convert(model)})
      raise ValueError(f"invalid Condition model: {model}")  # pragma: no cover
  2819. @classmethod
  def convert_payload_selector(cls, model: rest.PayloadSelector) -> grpc.WithPayloadSelector:
      """Convert a REST include/exclude payload selector into ``grpc.WithPayloadSelector``.

      Raises:
          ValueError: if ``model`` is neither an include nor an exclude selector.
      """
      if isinstance(model, rest.PayloadSelectorInclude):
          included = grpc.PayloadIncludeSelector(fields=model.include)
          return grpc.WithPayloadSelector(include=included)
      if isinstance(model, rest.PayloadSelectorExclude):
          excluded = grpc.PayloadExcludeSelector(fields=model.exclude)
          return grpc.WithPayloadSelector(exclude=excluded)
      raise ValueError(f"invalid PayloadSelector model: {model}")  # pragma: no cover
  2830. @classmethod
  def convert_with_payload_selector(
      cls, model: rest.PayloadSelector
  ) -> grpc.WithPayloadSelector:
      """Alias for ``convert_with_payload_interface``."""
      selector = cls.convert_with_payload_interface(model)
      return selector
  2835. @classmethod
  def convert_with_payload_interface(
      cls, model: rest.WithPayloadInterface
  ) -> grpc.WithPayloadSelector:
      """Convert a REST ``with_payload`` value into a gRPC ``WithPayloadSelector``.

      The value is handled as follows:

      * a ``bool`` toggles the payload (``enable``);
      * a ``list`` names the fields to include;
      * an explicit payload selector is delegated to ``convert_payload_selector``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, bool):
          return grpc.WithPayloadSelector(enable=model)
      if isinstance(model, list):
          included = grpc.PayloadIncludeSelector(fields=model)
          return grpc.WithPayloadSelector(include=included)
      if not isinstance(model, get_args(rest.PayloadSelector)):
          raise ValueError(f"invalid WithPayloadInterface model: {model}")  # pragma: no cover
      return cls.convert_payload_selector(model)
  2846. @classmethod
  def convert_start_from(cls, model: rest.StartFrom) -> grpc.StartFrom:
      """Convert an order-by ``start_from`` value into a gRPC ``StartFrom``.

      Each input type maps to one field:

      * ``int`` is sent as ``integer``;
      * ``float`` is sent as ``float``;
      * ``datetime`` is sent as ``timestamp`` (via ``convert_datetime``);
      * ``str`` is forwarded as ``datetime`` text.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, int):
          start_kwargs = {"integer": model}
      elif isinstance(model, float):
          start_kwargs = {"float": model}
      elif isinstance(model, datetime):
          start_kwargs = {"timestamp": cls.convert_datetime(model)}
      elif isinstance(model, str):
          # Pydantic also accepts strings as datetime if they are correctly formatted,
          # so the text is forwarded unchanged.
          start_kwargs = {"datetime": model}
      else:
          raise ValueError(f"invalid StartFrom model: {model}")  # pragma: no cover
      return grpc.StartFrom(**start_kwargs)
  2859. @classmethod
  def convert_direction(cls, model: rest.Direction) -> grpc.Direction:
      """Map a REST ``Direction`` (ASC/DESC) to the gRPC ``Direction``.

      Raises:
          ValueError: if ``model`` equals neither ASC nor DESC.
      """
      for rest_direction, grpc_direction in (
          (rest.Direction.ASC, grpc.Direction.Asc),
          (rest.Direction.DESC, grpc.Direction.Desc),
      ):
          if model == rest_direction:
              return grpc_direction
      raise ValueError(f"invalid Direction model: {model}")  # pragma: no cover
  2866. @classmethod
  def convert_order_by(cls, model: rest.OrderBy) -> grpc.OrderBy:
      """Convert a REST ``OrderBy`` into a gRPC ``OrderBy``.

      ``direction`` and ``start_from`` are converted only when set.
      """
      direction = model.direction
      start_from = model.start_from
      return grpc.OrderBy(
          key=model.key,
          direction=None if direction is None else cls.convert_direction(direction),
          start_from=None if start_from is None else cls.convert_start_from(start_from),
      )
  2877. @classmethod
  def convert_order_by_interface(cls, model: rest.OrderByInterface) -> grpc.OrderBy:
      """Convert a REST order-by value (a bare key string or an ``OrderBy``) into gRPC.

      There is no OrderByInterface type in gRPC, so both forms become ``grpc.OrderBy``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, rest.OrderBy):
          return cls.convert_order_by(model)
      if isinstance(model, str):
          # A bare string is just the key, with no direction or start point.
          return grpc.OrderBy(key=model)
      raise ValueError(f"invalid OrderByInterface model: {model}")  # pragma: no cover
  2885. @classmethod
  def convert_facet_value(cls, model: rest.FacetValue) -> grpc.FacetValue:
      """Convert a facet value into a gRPC ``FacetValue``.

      A ``str`` is sent as ``string_value`` and an ``int`` as ``integer_value``.

      Raises:
          ValueError: for any other type.
      """
      # NOTE(review): bool is a subclass of int, so a bool facet value is sent as
      # integer_value. Confirm this is intended.
      if isinstance(model, int):
          return grpc.FacetValue(integer_value=model)
      if isinstance(model, str):
          return grpc.FacetValue(string_value=model)
      raise ValueError(f"invalid FacetValue model: {model}")  # pragma: no cover
  2892. @classmethod
  def convert_facet_value_hit(cls, model: rest.FacetValueHit) -> grpc.FacetHit:
      """Convert a REST facet hit (value + count) into a gRPC ``FacetHit``."""
      converted_value = cls.convert_facet_value(model.value)
      return grpc.FacetHit(count=model.count, value=converted_value)
  2898. @classmethod
  def convert_record(cls, model: rest.Record) -> grpc.RetrievedPoint:
      """Convert a REST ``Record`` into a gRPC ``RetrievedPoint``.

      ``id`` and ``payload`` are always converted. The optional members ``vector``,
      ``shard_key`` and ``order_value`` are converted whenever they are not ``None``.
      """
      # Test optional members with ``is not None`` rather than truthiness. A shard key
      # of 0 / "" or an order value of 0 / 0.0 is falsy but is still a real value
      # that must be forwarded.
      shard_key = model.shard_key
      order_value = model.order_value
      return grpc.RetrievedPoint(
          id=cls.convert_extended_point_id(model.id),
          payload=cls.convert_payload(model.payload),
          vectors=(
              cls.convert_vector_struct_output(model.vector)
              if model.vector is not None
              else None
          ),
          shard_key=cls.convert_shard_key(shard_key) if shard_key is not None else None,
          order_value=(
              cls.convert_order_value(order_value) if order_value is not None else None
          ),
      )
  2911. @classmethod
  def convert_retrieved_point(cls, model: rest.Record) -> grpc.RetrievedPoint:
      """Alias for ``convert_record``."""
      point = cls.convert_record(model)
      return point
  2914. @classmethod
  def convert_count_result(cls, model: rest.CountResult) -> grpc.CountResult:
      """Convert a REST ``CountResult`` into a gRPC ``CountResult``."""
      total = model.count
      return grpc.CountResult(count=total)
  2917. @classmethod
  def convert_snapshot_description(
      cls, model: rest.SnapshotDescription
  ) -> grpc.SnapshotDescription:
      """Convert a REST ``SnapshotDescription`` into a gRPC ``SnapshotDescription``.

      ``creation_time`` is parsed with ``datetime.fromisoformat`` and stored as a
      protobuf ``Timestamp``. When ``creation_time`` is ``None``, the field is left
      unset instead of failing inside ``fromisoformat``.
      """
      creation_time = None
      if model.creation_time is not None:
          creation_time = Timestamp()
          # Timestamp.FromDatetime interprets a naive datetime as UTC.
          creation_time.FromDatetime(datetime.fromisoformat(model.creation_time))
      return grpc.SnapshotDescription(
          name=model.name,
          creation_time=creation_time,
          size=model.size,
      )
  2928. @classmethod
  def convert_datatype(cls, model: rest.Datatype) -> grpc.Datatype:
      """Map a REST vector ``Datatype`` to the gRPC ``Datatype``.

      Raises:
          ValueError: if ``model`` equals none of FLOAT32, UINT8, FLOAT16.
      """
      for rest_datatype, grpc_datatype in (
          (rest.Datatype.FLOAT32, grpc.Datatype.Float32),
          (rest.Datatype.UINT8, grpc.Datatype.Uint8),
          (rest.Datatype.FLOAT16, grpc.Datatype.Float16),
      ):
          if model == rest_datatype:
              return grpc_datatype
      raise ValueError(f"invalid Datatype model: {model}")  # pragma: no cover
  2937. @classmethod
  def convert_vector_params(cls, model: rest.VectorParams) -> grpc.VectorParams:
      """Convert REST ``VectorParams`` into gRPC ``VectorParams``.

      ``size``, ``distance`` and ``on_disk`` are always set. The HNSW,
      quantization, datatype and multivector settings are converted only when
      present.
      """
      hnsw = model.hnsw_config
      quantization = model.quantization_config
      datatype = model.datatype
      multivector = model.multivector_config
      return grpc.VectorParams(
          size=model.size,
          distance=cls.convert_distance(model.distance),
          hnsw_config=None if hnsw is None else cls.convert_hnsw_config_diff(hnsw),
          quantization_config=(
              None if quantization is None else cls.convert_quantization_config(quantization)
          ),
          on_disk=model.on_disk,
          datatype=None if datatype is None else cls.convert_datatype(datatype),
          multivector_config=(
              None if multivector is None else cls.convert_multivector_config(multivector)
          ),
      )
  2960. @classmethod
  def convert_multivector_config(cls, model: rest.MultiVectorConfig) -> grpc.MultiVectorConfig:
      """Convert a REST ``MultiVectorConfig`` (its comparator) into gRPC."""
      comparator = cls.convert_multivector_comparator(model.comparator)
      return grpc.MultiVectorConfig(comparator=comparator)
  2965. @classmethod
  def convert_multivector_comparator(
      cls, model: rest.MultiVectorComparator
  ) -> grpc.MultiVectorComparator:
      """Map a REST multivector comparator to gRPC; only MAX_SIM is supported.

      Raises:
          ValueError: for any other comparator.
      """
      is_max_sim = model == rest.MultiVectorComparator.MAX_SIM
      if is_max_sim:
          return grpc.MultiVectorComparator.MaxSim
      raise ValueError(f"invalid MultiVectorComparator model: {model}")  # pragma: no cover
  2972. @classmethod
  def convert_vectors_config(cls, model: rest.VectorsConfig) -> grpc.VectorsConfig:
      """Convert a REST vectors config into gRPC.

      A single ``VectorParams`` is sent as ``params``. A dict of name to
      ``VectorParams`` is sent as ``params_map``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, rest.VectorParams):
          return grpc.VectorsConfig(params=cls.convert_vector_params(model))
      if not isinstance(model, dict):
          raise ValueError(f"invalid VectorsConfig model: {model}")  # pragma: no cover
      params_by_name = {
          name: cls.convert_vector_params(params) for name, params in model.items()
      }
      return grpc.VectorsConfig(params_map=grpc.VectorParamsMap(map=params_by_name))
  2984. @classmethod
  def convert_vector_struct(cls, model: rest.VectorStruct) -> grpc.Vectors:
      """Convert a REST vector struct (input side) into gRPC ``Vectors``.

      The struct is handled by shape:

      * A plain list becomes an unnamed vector. A list of lists is flattened into a
        multi-dense vector with ``vectors_count`` rows.
      * A dict becomes named vectors. Its values may be lists, sparse vectors,
        ``Document``, ``Image`` or ``InferenceObject``.
      * A single ``Document`` / ``Image`` / ``InferenceObject`` becomes an unnamed
        vector.

      Raises:
          ValueError: for any other input.
      """

      def to_vector(vector: Union[list[float], list[list[float]]]) -> grpc.Vector:
          # An empty list is ambiguous (dense or multi-dense) and is treated as dense.
          is_multi_dense = len(vector) > 0 and isinstance(vector[0], list)
          if not is_multi_dense:
              return grpc.Vector(data=vector)
          flattened = [value for row in vector for value in row]  # type: ignore
          return grpc.Vector(data=flattened, vectors_count=len(vector))

      if isinstance(model, list):
          return grpc.Vectors(vector=to_vector(model))

      if isinstance(model, dict):
          named: dict = {}
          for name, value in model.items():
              if isinstance(value, list):
                  named[name] = to_vector(value)
              elif isinstance(value, rest.SparseVector):
                  named[name] = cls.convert_sparse_vector_to_vector(value)
              elif isinstance(value, rest.Document):
                  named[name] = grpc.Vector(document=cls.convert_document(value))
              elif isinstance(value, rest.Image):
                  named[name] = grpc.Vector(image=cls.convert_image(value))
              elif isinstance(value, rest.InferenceObject):
                  named[name] = grpc.Vector(object=cls.convert_inference_object(value))
              # NOTE(review): values of any other type are silently skipped. Confirm
              # this is intended.
          return grpc.Vectors(vectors=grpc.NamedVectors(vectors=named))

      if isinstance(model, rest.Document):
          single = grpc.Vector(document=cls.convert_document(model))
      elif isinstance(model, rest.Image):
          single = grpc.Vector(image=cls.convert_image(model))
      elif isinstance(model, rest.InferenceObject):
          single = grpc.Vector(object=cls.convert_inference_object(model))
      else:
          raise ValueError(f"invalid VectorStruct model: {model}")  # pragma: no cover
      return grpc.Vectors(vector=single)
  3025. @classmethod
  def convert_vector_struct_output(cls, model: rest.VectorStructOutput) -> grpc.VectorsOutput:
      """Convert a REST vector struct (output side) into gRPC ``VectorsOutput``.

      The struct is handled by shape:

      * A plain list becomes an unnamed vector. A list of lists is flattened into a
        multi-dense vector with ``vectors_count`` rows.
      * A dict becomes named vectors. Its values may be lists or sparse vectors.

      Raises:
          ValueError: for any other input.
      """

      def to_vector_output(vector: Union[list[float], list[list[float]]]) -> grpc.VectorOutput:
          # An empty list is ambiguous (dense or multi-dense) and is treated as dense.
          is_multi_dense = len(vector) > 0 and isinstance(vector[0], list)
          if not is_multi_dense:
              return grpc.VectorOutput(data=vector)
          flattened = [value for row in vector for value in row]  # type: ignore
          return grpc.VectorOutput(data=flattened, vectors_count=len(vector))

      if isinstance(model, list):
          return grpc.VectorsOutput(vector=to_vector_output(model))
      if not isinstance(model, dict):
          raise ValueError(f"invalid VectorStructOutput model: {model}")  # pragma: no cover

      named: dict = {}
      for name, value in model.items():
          if isinstance(value, list):
              named[name] = to_vector_output(value)
          elif isinstance(value, rest.SparseVector):
              named[name] = cls.convert_sparse_vector_to_vector_output(value)
          # NOTE(review): values of any other type are silently skipped. Confirm
          # this is intended.
      return grpc.VectorsOutput(vectors=grpc.NamedVectorsOutput(vectors=named))
  3054. @classmethod
  def convert_with_vectors(cls, model: rest.WithVector) -> grpc.WithVectorsSelector:
      """Convert a REST ``with_vector`` value into a gRPC ``WithVectorsSelector``.

      A ``bool`` toggles vectors (``enable``). A ``list`` names the vectors to
      include.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, bool):
          return grpc.WithVectorsSelector(enable=model)
      if isinstance(model, list):
          selected = grpc.VectorsSelector(names=model)
          return grpc.WithVectorsSelector(include=selected)
      raise ValueError(f"invalid WithVectors model: {model}")  # pragma: no cover
  3062. @classmethod
  def convert_batch_vector_struct(
      cls, model: rest.BatchVectorStruct, num_records: int
  ) -> list[grpc.Vectors]:
      """Convert a batch of vector structs into one gRPC ``Vectors`` per record.

      A list is converted element by element. A dict of name to per-record vectors
      is first transposed into ``num_records`` per-record dicts.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, list):
          return [cls.convert_vector_struct(struct) for struct in model]
      if not isinstance(model, dict):
          raise ValueError(f"invalid BatchVectorStruct model: {model}")  # pragma: no cover

      per_record: list[dict] = [{} for _ in range(num_records)]
      for name, batch in model.items():
          for index, vector in enumerate(batch):
              per_record[index][name] = vector
      return list(map(cls.convert_vector_struct, per_record))
  3076. @classmethod
  def convert_named_vector_struct(
      cls, model: rest.NamedVectorStruct
  ) -> tuple[list[float], Optional[grpc.SparseIndices], Optional[str]]:
      """Split a REST named vector into ``(values, sparse_indices, name)``.

      ``sparse_indices`` is set only for sparse vectors. ``name`` is ``None`` for a
      bare list.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, list):
          return model, None, None
      if isinstance(model, rest.NamedVector):
          return model.vector, None, model.name
      if isinstance(model, rest.NamedSparseVector):
          sparse = model.vector
          return sparse.values, grpc.SparseIndices(data=sparse.indices), model.name
      raise ValueError(f"invalid NamedVectorStruct model: {model}")  # pragma: no cover
  3092. @classmethod
  def convert_dense_vector(cls, model: list[float]) -> grpc.DenseVector:
      """Wrap a list of floats as a gRPC ``DenseVector``."""
      dense = grpc.DenseVector(data=model)
      return dense
  3095. @classmethod
  def convert_sparse_vector(cls, model: rest.SparseVector) -> grpc.SparseVector:
      """Convert a REST ``SparseVector`` into a gRPC ``SparseVector`` (indices + values)."""
      indices, values = model.indices, model.values
      return grpc.SparseVector(indices=indices, values=values)
  3098. @classmethod
  def convert_multi_dense_vector(cls, model: list[list[float]]) -> grpc.MultiDenseVector:
      """Convert a list of dense rows into a gRPC ``MultiDenseVector`` (row order kept)."""
      dense_rows = list(map(cls.convert_dense_vector, model))
      return grpc.MultiDenseVector(vectors=dense_rows)
  3103. @classmethod
  def convert_document(cls, model: rest.Document) -> grpc.Document:
      """Convert a REST ``Document`` into gRPC.

      ``options`` goes through ``payload_to_grpc`` when it is set.
      """
      raw_options = model.options
      options = None if raw_options is None else payload_to_grpc(raw_options)
      return grpc.Document(text=model.text, model=model.model, options=options)
  3110. @classmethod
  def convert_image(cls, model: rest.Image) -> grpc.Image:
      """Convert a REST ``Image`` into gRPC.

      The image goes through ``json_to_value``. ``options`` goes through
      ``payload_to_grpc`` when it is set.
      """
      image_value = json_to_value(model.image)
      raw_options = model.options
      options = None if raw_options is None else payload_to_grpc(raw_options)
      return grpc.Image(image=image_value, model=model.model, options=options)
  3117. @classmethod
  def convert_inference_object(cls, model: rest.InferenceObject) -> grpc.InferenceObject:
      """Convert a REST ``InferenceObject`` into gRPC.

      The object goes through ``json_to_value``. ``options`` goes through
      ``payload_to_grpc`` when it is set.
      """
      object_value = json_to_value(model.object)
      raw_options = model.options
      options = None if raw_options is None else payload_to_grpc(raw_options)
      return grpc.InferenceObject(object=object_value, model=model.model, options=options)
  3124. @classmethod
  def convert_vector_input(cls, model: rest.VectorInput) -> grpc.VectorInput:
      """Convert a REST query vector input into a gRPC ``VectorInput``.

      Inputs are checked in this order:

      1. lists (dense, or multi-dense when the first element is a list);
      2. sparse vectors;
      3. point ids;
      4. ``Document``, ``Image`` and ``InferenceObject``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, list):
          # An empty list is ambiguous (dense or multi-dense) and is treated as dense.
          if model and isinstance(model[0], list):
              return grpc.VectorInput(multi_dense=cls.convert_multi_dense_vector(model))
          return grpc.VectorInput(dense=cls.convert_dense_vector(model))

      # (accepted type(s), gRPC oneof field, converter), checked in this order.
      converters = (
          (rest.SparseVector, "sparse", cls.convert_sparse_vector),
          (get_args_subscribed(rest.ExtendedPointId), "id", cls.convert_extended_point_id),
          (rest.Document, "document", cls.convert_document),
          (rest.Image, "image", cls.convert_image),
          (rest.InferenceObject, "object", cls.convert_inference_object),
      )
      for accepted, field_name, convert in converters:
          if isinstance(model, accepted):
              return grpc.VectorInput(**{field_name: convert(model)})
      raise ValueError(f"invalid VectorInput model: {model}")  # pragma: no cover
  3143. @classmethod
  def convert_recommend_input(cls, model: rest.RecommendInput) -> grpc.RecommendInput:
      """Convert a REST ``RecommendInput`` into gRPC.

      Each positive/negative example is converted with ``convert_vector_input``.
      Unset lists and an unset strategy stay unset.
      """
      positive, negative, strategy = model.positive, model.negative, model.strategy
      return grpc.RecommendInput(
          positive=(
              None if positive is None else [cls.convert_vector_input(v) for v in positive]
          ),
          negative=(
              None if negative is None else [cls.convert_vector_input(v) for v in negative]
          ),
          strategy=None if strategy is None else cls.convert_recommend_strategy(strategy),
      )
  3162. @classmethod
  def convert_context_input_pair(cls, model: rest.ContextPair) -> grpc.ContextInputPair:
      """Convert both halves of a REST ``ContextPair`` (positive first, then negative)."""
      positive, negative = (
          cls.convert_vector_input(side) for side in (model.positive, model.negative)
      )
      return grpc.ContextInputPair(positive=positive, negative=negative)
  3168. @classmethod
  def convert_context_input(cls, model: rest.ContextInput) -> grpc.ContextInput:
      """Convert a REST context input (one pair or a list of pairs) into gRPC.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, list):
          pairs = model
      elif isinstance(model, rest.ContextPair):
          # A single pair is sent as a one-element list of pairs.
          pairs = [model]
      else:
          raise ValueError(f"invalid ContextInput model: {model}")  # pragma: no cover
      return grpc.ContextInput(pairs=[cls.convert_context_input_pair(pair) for pair in pairs])
  3177. @classmethod
  def convert_discover_input(cls, model: rest.DiscoverInput) -> grpc.DiscoverInput:
      """Convert a REST ``DiscoverInput`` (target + context) into gRPC."""
      target = cls.convert_vector_input(model.target)
      context = cls.convert_context_input(model.context)
      return grpc.DiscoverInput(target=target, context=context)
  3183. @classmethod
  def convert_fusion(cls, model: rest.Fusion) -> grpc.Fusion:
      """Map a REST ``Fusion`` method (RRF/DBSF) to the gRPC ``Fusion``.

      Raises:
          ValueError: if ``model`` equals neither RRF nor DBSF.
      """
      for rest_fusion, grpc_fusion in (
          (rest.Fusion.RRF, grpc.Fusion.RRF),
          (rest.Fusion.DBSF, grpc.Fusion.DBSF),
      ):
          if model == rest_fusion:
              return grpc_fusion
      raise ValueError(f"invalid Fusion model: {model}")  # pragma: no cover
  3190. @classmethod
  def convert_sample(cls, model: rest.Sample) -> grpc.Sample:
      """Map a REST ``Sample`` to gRPC; only RANDOM is supported.

      Raises:
          ValueError: for any other value.
      """
      is_random = model == rest.Sample.RANDOM
      if not is_random:
          raise ValueError(f"invalid Sample model: {model}")  # pragma: no cover
      return grpc.Sample.Random
  3195. @classmethod
  def convert_mmr(cls, model: rest.Mmr) -> grpc.Mmr:
      """Convert REST MMR parameters (diversity, candidates_limit) into gRPC ``Mmr``."""
      diversity, candidates_limit = model.diversity, model.candidates_limit
      return grpc.Mmr(candidates_limit=candidates_limit, diversity=diversity)
  3198. @classmethod
  def convert_query(cls, model: rest.Query) -> grpc.Query:
      """Convert a REST query variant into the matching ``grpc.Query`` oneof.

      A ``NearestQuery`` with ``mmr`` set becomes ``nearest_with_mmr``. The wrapper
      query types forward their same-named attribute to the same-named gRPC field.
      ``FormulaQuery`` is converted as a whole.

      Raises:
          ValueError: for unsupported query types.
      """
      if isinstance(model, rest.NearestQuery):
          nearest = cls.convert_vector_input(model.nearest)
          if model.mmr is None:
              return grpc.Query(nearest=nearest)
          with_mmr = grpc.NearestInputWithMmr(nearest=nearest, mmr=cls.convert_mmr(model.mmr))
          return grpc.Query(nearest_with_mmr=with_mmr)

      # (REST query type, attribute == gRPC field name, converter for that attribute)
      wrapped_queries = (
          (rest.RecommendQuery, "recommend", cls.convert_recommend_input),
          (rest.DiscoverQuery, "discover", cls.convert_discover_input),
          (rest.ContextQuery, "context", cls.convert_context_input),
          (rest.OrderByQuery, "order_by", cls.convert_order_by_interface),
          (rest.FusionQuery, "fusion", cls.convert_fusion),
          (rest.SampleQuery, "sample", cls.convert_sample),
      )
      for query_type, field_name, convert in wrapped_queries:
          if isinstance(model, query_type):
              return grpc.Query(**{field_name: convert(getattr(model, field_name))})

      if isinstance(model, rest.FormulaQuery):
          return grpc.Query(formula=cls.convert_formula_query(model))
      raise ValueError(f"invalid Query model: {model}")  # pragma: no cover
  @classmethod
  def convert_formula_query(cls, model: rest.FormulaQuery) -> grpc.Formula:
      """Convert a REST ``FormulaQuery`` into a gRPC ``Formula``.

      ``defaults`` is passed through ``payload_to_grpc`` when present and left
      unset otherwise. ``formula`` is converted by ``convert_expression``.
      """
      if model.defaults is None:
          defaults = None
      else:
          defaults = payload_to_grpc(model.defaults)
      return grpc.Formula(defaults=defaults, expression=cls.convert_expression(model.formula))
  @classmethod
  def convert_expression(cls, model: rest.Expression) -> grpc.Expression:
      """Recursively convert a REST formula ``Expression`` into a gRPC ``Expression``.

      Numbers become ``constant`` and strings become ``variable``. Conditions and
      every structured expression variant are mapped onto the ``Expression``
      field of the same name, and nested sub-expressions are converted
      recursively.

      Raises:
          ValueError: if ``model`` is not a recognised expression variant.
      """
      # Accept plain ints as constants as well as floats. Previously only
      # ``float`` matched, so a literal such as ``2`` fell through to the
      # ValueError. ``bool`` is an ``int`` subclass but is not a numeric
      # constant, so it is excluded.
      if isinstance(model, (int, float)) and not isinstance(model, bool):
          return grpc.Expression(constant=model)
      if isinstance(model, str):
          return grpc.Expression(variable=model)
      if isinstance(model, get_args_subscribed(rest.Condition)):
          return grpc.Expression(condition=cls.convert_condition(model))
      if isinstance(model, rest.DatetimeExpression):
          return grpc.Expression(datetime=model.datetime)
      if isinstance(model, rest.DatetimeKeyExpression):
          return grpc.Expression(datetime_key=model.datetime_key)
      if isinstance(model, rest.NegExpression):
          return grpc.Expression(neg=cls.convert_expression(model.neg))
      if isinstance(model, rest.SumExpression):
          return grpc.Expression(sum=cls.convert_sum_expression(model))
      if isinstance(model, rest.MultExpression):
          return grpc.Expression(mult=cls.convert_mult_expression(model))
      if isinstance(model, rest.DivExpression):
          return grpc.Expression(div=cls.convert_div_expression(model))
      if isinstance(model, rest.PowExpression):
          return grpc.Expression(pow=cls.convert_pow_expression(model))
      if isinstance(model, rest.Log10Expression):
          return grpc.Expression(log10=cls.convert_expression(model.log10))
      if isinstance(model, rest.LnExpression):
          return grpc.Expression(ln=cls.convert_expression(model.ln))
      if isinstance(model, rest.AbsExpression):
          return grpc.Expression(abs=cls.convert_expression(model.abs))
      if isinstance(model, rest.SqrtExpression):
          return grpc.Expression(sqrt=cls.convert_expression(model.sqrt))
      if isinstance(model, rest.ExpExpression):
          return grpc.Expression(exp=cls.convert_expression(model.exp))
      if isinstance(model, rest.GeoDistance):
          return grpc.Expression(geo_distance=cls.convert_geo_distance(model))
      if isinstance(model, rest.LinDecayExpression):
          return grpc.Expression(lin_decay=cls.convert_decay_params_expression(model.lin_decay))
      if isinstance(model, rest.ExpDecayExpression):
          return grpc.Expression(exp_decay=cls.convert_decay_params_expression(model.exp_decay))
      if isinstance(model, rest.GaussDecayExpression):
          return grpc.Expression(
              gauss_decay=cls.convert_decay_params_expression(model.gauss_decay)
          )
      raise ValueError(f"invalid Expression model: {model}")  # pragma: no cover
  @classmethod
  def convert_sum_expression(cls, model: rest.SumExpression) -> grpc.SumExpression:
      """Convert each operand of a REST ``SumExpression`` and wrap them in a gRPC ``SumExpression``."""
      operands = list(map(cls.convert_expression, model.sum))
      return grpc.SumExpression(sum=operands)
  @classmethod
  def convert_mult_expression(cls, model: rest.MultExpression) -> grpc.MultExpression:
      """Convert each factor of a REST ``MultExpression`` and wrap them in a gRPC ``MultExpression``."""
      factors = list(map(cls.convert_expression, model.mult))
      return grpc.MultExpression(mult=factors)
  @classmethod
  def convert_div_expression(cls, model: rest.DivExpression) -> grpc.DivExpression:
      """Convert a REST ``DivExpression`` into a gRPC ``DivExpression``.

      Both operands are converted recursively. ``by_zero_default`` is copied as is.
      """
      params = model.div
      numerator = cls.convert_expression(params.left)
      denominator = cls.convert_expression(params.right)
      return grpc.DivExpression(
          left=numerator, right=denominator, by_zero_default=params.by_zero_default
      )
  @classmethod
  def convert_pow_expression(cls, model: rest.PowExpression) -> grpc.PowExpression:
      """Convert a REST ``PowExpression`` into a gRPC ``PowExpression`` (base and exponent recursively)."""
      params = model.pow
      base = cls.convert_expression(params.base)
      exponent = cls.convert_expression(params.exponent)
      return grpc.PowExpression(base=base, exponent=exponent)
  @classmethod
  def convert_geo_distance(cls, model: rest.GeoDistance) -> grpc.GeoDistance:
      """Convert a REST ``GeoDistance`` expression into its gRPC form.

      The origin point goes through ``convert_geo_point``. ``to`` is copied verbatim.
      """
      params = model.geo_distance
      origin = cls.convert_geo_point(params.origin)
      return grpc.GeoDistance(origin=origin, to=params.to)
  @classmethod
  def convert_decay_params_expression(
      cls, model: rest.DecayParamsExpression
  ) -> grpc.DecayParamsExpression:
      """Convert REST decay-function parameters into gRPC ``DecayParamsExpression``.

      ``x`` is always converted. ``target`` is converted only when present.
      ``midpoint`` and ``scale`` are copied unchanged.
      """
      x = cls.convert_expression(model.x)
      target = None if model.target is None else cls.convert_expression(model.target)
      return grpc.DecayParamsExpression(
          x=x, target=target, midpoint=model.midpoint, scale=model.scale
      )
  @classmethod
  def convert_query_interface(cls, model: rest.QueryInterface) -> grpc.Query:
      """Convert a REST ``QueryInterface`` into a gRPC ``Query``.

      A bare vector input is treated as a ``nearest`` query. An explicit
      ``rest.Query`` variant is delegated to ``convert_query``.

      Raises:
          ValueError: if ``model`` is neither a vector input nor a query variant.
      """
      is_vector_input = isinstance(model, get_args_subscribed(rest.VectorInput))
      if is_vector_input:
          return grpc.Query(nearest=cls.convert_vector_input(model))
      if not isinstance(model, get_args(rest.Query)):
          raise ValueError(f"invalid QueryInterface: {model}")  # pragma: no cover
      return cls.convert_query(model)
  @classmethod
  def convert_prefetch_query(cls, model: rest.Prefetch) -> grpc.PrefetchQuery:
      """Recursively convert a REST ``Prefetch`` into a gRPC ``PrefetchQuery``.

      A nested ``prefetch`` may be a single ``Prefetch`` or a list of them. Both
      are converted into a list. Any other value, including ``None``, leaves
      the nested prefetch unset. Optional sub-objects are converted only when
      present.
      """
      nested = model.prefetch
      if isinstance(nested, rest.Prefetch):
          nested = [nested]
      sub_prefetches = (
          [cls.convert_prefetch_query(item) for item in nested]
          if isinstance(nested, list)
          else None
      )

      query = None if model.query is None else cls.convert_query_interface(model.query)
      query_filter = None if model.filter is None else cls.convert_filter(model.filter)
      params = None if model.params is None else cls.convert_search_params(model.params)
      lookup_from = (
          None if model.lookup_from is None else cls.convert_lookup_location(model.lookup_from)
      )
      return grpc.PrefetchQuery(
          prefetch=sub_prefetches,
          query=query,
          using=model.using,
          filter=query_filter,
          params=params,
          score_threshold=model.score_threshold,
          limit=model.limit,
          lookup_from=lookup_from,
      )
  @classmethod
  def convert_search_request(
      cls, model: rest.SearchRequest, collection_name: str
  ) -> grpc.SearchPoints:
      """Convert a REST ``SearchRequest`` into a gRPC ``SearchPoints`` request.

      Args:
          model: REST search request.
          collection_name: collection name set on the gRPC request.

      Returns:
          ``grpc.SearchPoints``. Optional sub-objects are converted only when the
          corresponding REST field is not ``None``.
      """
      # The named-vector struct yields the vector, sparse indices and vector name.
      vector, sparse_indices, name = cls.convert_named_vector_struct(model.vector)
      return grpc.SearchPoints(
          collection_name=collection_name,
          vector=vector,
          sparse_indices=sparse_indices,
          filter=cls.convert_filter(model.filter) if model.filter is not None else None,
          limit=model.limit,
          with_payload=(
              cls.convert_with_payload_interface(model.with_payload)
              if model.with_payload is not None
              else None
          ),
          params=cls.convert_search_params(model.params) if model.params is not None else None,
          score_threshold=model.score_threshold,
          offset=model.offset,
          vector_name=name,
          with_vectors=(
              cls.convert_with_vectors(model.with_vector)
              if model.with_vector is not None
              else None
          ),
          # Test against None rather than truthiness: an integer shard key of 0
          # is valid and must not be silently dropped (matches the handling in
          # convert_query_request / convert_discover_request).
          shard_key_selector=(
              cls.convert_shard_key_selector(model.shard_key)
              if model.shard_key is not None
              else None
          ),
      )
  @classmethod
  def convert_search_points(
      cls, model: rest.SearchRequest, collection_name: str
  ) -> grpc.SearchPoints:
      """Alias of ``convert_search_request``; forwards both arguments unchanged."""
      search_points = cls.convert_search_request(model, collection_name)
      return search_points
  @classmethod
  def convert_query_request(
      cls, model: rest.QueryRequest, collection_name: str
  ) -> grpc.QueryPoints:
      """Convert a REST ``QueryRequest`` into a gRPC ``QueryPoints`` request.

      ``prefetch`` may be a single ``Prefetch`` or a list. It is always sent as
      a list, or left unset when ``None``. Other optional sub-objects are
      converted only when present.
      """
      raw_prefetch = model.prefetch
      if raw_prefetch is None:
          prefetch = None
      else:
          if isinstance(raw_prefetch, rest.Prefetch):
              raw_prefetch = [raw_prefetch]
          prefetch = [cls.convert_prefetch_query(p) for p in raw_prefetch]

      query = None if model.query is None else cls.convert_query_interface(model.query)
      query_filter = None if model.filter is None else cls.convert_filter(model.filter)
      params = None if model.params is None else cls.convert_search_params(model.params)
      with_vectors = (
          None if model.with_vector is None else cls.convert_with_vectors(model.with_vector)
      )
      with_payload = (
          None
          if model.with_payload is None
          else cls.convert_with_payload_interface(model.with_payload)
      )
      shard_key_selector = (
          None if model.shard_key is None else cls.convert_shard_key_selector(model.shard_key)
      )
      lookup_from = (
          None if model.lookup_from is None else cls.convert_lookup_location(model.lookup_from)
      )
      return grpc.QueryPoints(
          collection_name=collection_name,
          prefetch=prefetch,
          query=query,
          using=model.using,
          filter=query_filter,
          params=params,
          score_threshold=model.score_threshold,
          limit=model.limit,
          offset=model.offset,
          with_vectors=with_vectors,
          with_payload=with_payload,
          shard_key_selector=shard_key_selector,
          lookup_from=lookup_from,
      )
  @classmethod
  def convert_query_points(
      cls, model: rest.QueryRequest, collection_name: str
  ) -> grpc.QueryPoints:
      """Alias of ``convert_query_request``; forwards both arguments unchanged."""
      query_points = cls.convert_query_request(model, collection_name)
      return query_points
  @classmethod
  def convert_recommend_request(
      cls, model: rest.RecommendRequest, collection_name: str
  ) -> grpc.RecommendPoints:
      """Convert a REST ``RecommendRequest`` into a gRPC ``RecommendPoints`` request.

      Positive and negative examples are split into point-id lists and raw
      vector lists by the ``convert_recommend_examples_to_*`` helpers. Optional
      sub-objects are converted only when the REST field is not ``None``.

      Args:
          model: REST recommend request.
          collection_name: collection name set on the gRPC request.
      """
      positive_ids = cls.convert_recommend_examples_to_ids(model.positive)
      negative_ids = cls.convert_recommend_examples_to_ids(model.negative)
      positive_vectors = cls.convert_recommend_examples_to_vectors(model.positive)
      negative_vectors = cls.convert_recommend_examples_to_vectors(model.negative)
      return grpc.RecommendPoints(
          collection_name=collection_name,
          positive=positive_ids,
          negative=negative_ids,
          filter=cls.convert_filter(model.filter) if model.filter is not None else None,
          limit=model.limit,
          with_payload=(
              cls.convert_with_payload_interface(model.with_payload)
              if model.with_payload is not None
              else None
          ),
          params=cls.convert_search_params(model.params) if model.params is not None else None,
          score_threshold=model.score_threshold,
          offset=model.offset,
          with_vectors=(
              cls.convert_with_vectors(model.with_vector)
              if model.with_vector is not None
              else None
          ),
          using=model.using,
          lookup_from=(
              cls.convert_lookup_location(model.lookup_from)
              if model.lookup_from is not None
              else None
          ),
          strategy=(
              cls.convert_recommend_strategy(model.strategy)
              if model.strategy is not None
              else None
          ),
          positive_vectors=positive_vectors,
          negative_vectors=negative_vectors,
          # Test against None rather than truthiness: an integer shard key of 0
          # is valid and must not be silently dropped (matches the handling in
          # convert_query_request / convert_discover_request).
          shard_key_selector=(
              cls.convert_shard_key_selector(model.shard_key)
              if model.shard_key is not None
              else None
          ),
      )
  @classmethod
  def convert_discover_points(
      cls, model: rest.DiscoverRequest, collection_name: str
  ) -> grpc.DiscoverPoints:
      """Alias of ``convert_discover_request``; forwards both arguments unchanged."""
      discover_points = cls.convert_discover_request(model, collection_name)
      return discover_points
  @classmethod
  def convert_discover_request(
      cls, model: rest.DiscoverRequest, collection_name: str
  ) -> grpc.DiscoverPoints:
      """Convert a REST ``DiscoverRequest`` into a gRPC ``DiscoverPoints`` request.

      Every optional REST field is converted only when it is not ``None``;
      otherwise the matching gRPC field is left unset.
      """

      def convert_optional(value, converter):
          # Leave the gRPC field unset (None) when the REST field is missing.
          return None if value is None else converter(value)

      target = convert_optional(model.target, cls.convert_target_vector)
      context = convert_optional(
          model.context,
          lambda pairs: [cls.convert_context_example_pair(pair) for pair in pairs],
      )
      query_filter = convert_optional(model.filter, lambda flt: cls.convert_filter(model=flt))
      search_params = convert_optional(model.params, cls.convert_search_params)
      with_payload = convert_optional(model.with_payload, cls.convert_with_payload_interface)
      with_vectors = convert_optional(model.with_vector, cls.convert_with_vectors)
      lookup_from = convert_optional(model.lookup_from, cls.convert_lookup_location)
      shard_key_selector = convert_optional(model.shard_key, cls.convert_shard_key_selector)

      return grpc.DiscoverPoints(
          collection_name=collection_name,
          target=target,
          context=context,
          filter=query_filter,
          limit=model.limit,
          offset=model.offset,
          with_vectors=with_vectors,
          with_payload=with_payload,
          params=search_params,
          using=model.using,
          lookup_from=lookup_from,
          shard_key_selector=shard_key_selector,
      )
  @classmethod
  def convert_recommend_points(
      cls, model: rest.RecommendRequest, collection_name: str
  ) -> grpc.RecommendPoints:
      """Alias of ``convert_recommend_request``; forwards both arguments unchanged."""
      recommend_points = cls.convert_recommend_request(model, collection_name)
      return recommend_points
  @classmethod
  def convert_tokenizer_type(cls, model: rest.TokenizerType) -> grpc.TokenizerType:
      """Translate a REST ``TokenizerType`` into the gRPC ``TokenizerType`` enum.

      Raises:
          ValueError: for a tokenizer type not listed below.
      """
      # Compared with ``==`` in order (not a dict lookup) so values that compare
      # equal to the enum members behave exactly as before.
      for rest_tokenizer, grpc_tokenizer in (
          (rest.TokenizerType.WORD, grpc.TokenizerType.Word),
          (rest.TokenizerType.WHITESPACE, grpc.TokenizerType.Whitespace),
          (rest.TokenizerType.PREFIX, grpc.TokenizerType.Prefix),
          (rest.TokenizerType.MULTILINGUAL, grpc.TokenizerType.Multilingual),
      ):
          if model == rest_tokenizer:
              return grpc_tokenizer
      raise ValueError(f"invalid TokenizerType model: {model}")  # pragma: no cover
  @classmethod
  def convert_text_index_params(cls, model: rest.TextIndexParams) -> grpc.TextIndexParams:
      """Convert REST full-text index settings into gRPC ``TextIndexParams``.

      ``tokenizer``, ``stopwords`` and ``stemmer`` are converted when set and left
      unset otherwise. The remaining scalar options are copied verbatim.
      """
      tokenizer = None if model.tokenizer is None else cls.convert_tokenizer_type(model.tokenizer)
      stopwords = None if model.stopwords is None else cls.convert_stopwords(model.stopwords)
      stemmer = None if model.stemmer is None else cls.convert_stemmer(model.stemmer)
      return grpc.TextIndexParams(
          tokenizer=tokenizer,
          lowercase=model.lowercase,
          min_token_len=model.min_token_len,
          max_token_len=model.max_token_len,
          on_disk=model.on_disk,
          stopwords=stopwords,
          phrase_matching=model.phrase_matching,
          stemmer=stemmer,
      )
  @classmethod
  def convert_stopwords(cls, model: rest.StopwordsInterface) -> grpc.StopwordsSet:
      """Convert a REST stopwords specification into a gRPC ``StopwordsSet``.

      A single ``Language`` becomes a one-element language list (its ``.value``).
      A ``StopwordsSet`` copies its languages (unset when empty or ``None``) and
      its custom words.

      Raises:
          ValueError: for any other input type.
      """
      if isinstance(model, rest.Language):
          return grpc.StopwordsSet(languages=[model.value])
      if not isinstance(model, rest.StopwordsSet):
          raise ValueError(f"invalid StopwordsInterface model: {model}")  # pragma: no cover
      languages = list(model.languages) if model.languages else None
      return grpc.StopwordsSet(languages=languages, custom=model.custom)
  @classmethod
  def convert_stemmer(cls, model: rest.StemmingAlgorithm) -> grpc.StemmingAlgorithm:
      """Convert a REST stemming algorithm into a gRPC ``StemmingAlgorithm``.

      Only Snowball stemming is supported.

      Raises:
          ValueError: for any other algorithm.
      """
      if not isinstance(model, rest.SnowballParams):
          raise ValueError(f"invalid StemmingAlgorithm model: {model}")  # pragma: no cover
      snowball = grpc.SnowballParams(language=model.language)
      return grpc.StemmingAlgorithm(snowball=snowball)
  @classmethod
  def convert_integer_index_params(
      cls, model: rest.IntegerIndexParams
  ) -> grpc.IntegerIndexParams:
      """Copy the REST integer-index flags onto a gRPC ``IntegerIndexParams``."""
      lookup, value_range = model.lookup, model.range
      return grpc.IntegerIndexParams(
          lookup=lookup, range=value_range, is_principal=model.is_principal, on_disk=model.on_disk
      )
  @classmethod
  def convert_keyword_index_params(
      cls, model: rest.KeywordIndexParams
  ) -> grpc.KeywordIndexParams:
      """Copy the REST keyword-index flags (``is_tenant``, ``on_disk``) onto gRPC."""
      return grpc.KeywordIndexParams(
          is_tenant=model.is_tenant,
          on_disk=model.on_disk,
      )
  @classmethod
  def convert_float_index_params(cls, model: rest.FloatIndexParams) -> grpc.FloatIndexParams:
      """Copy the REST float-index flags (``is_principal``, ``on_disk``) onto gRPC."""
      return grpc.FloatIndexParams(
          is_principal=model.is_principal,
          on_disk=model.on_disk,
      )
  @classmethod
  def convert_geo_index_params(cls, model: rest.GeoIndexParams) -> grpc.GeoIndexParams:
      """Copy the REST geo-index ``on_disk`` flag onto a gRPC ``GeoIndexParams``."""
      on_disk = model.on_disk
      return grpc.GeoIndexParams(on_disk=on_disk)
  @classmethod
  def convert_bool_index_params(cls, model: rest.BoolIndexParams) -> grpc.BoolIndexParams:
      """Copy the REST bool-index ``on_disk`` flag onto a gRPC ``BoolIndexParams``."""
      on_disk = model.on_disk
      return grpc.BoolIndexParams(on_disk=on_disk)
  @classmethod
  def convert_datetime_index_params(
      cls, model: rest.DatetimeIndexParams
  ) -> grpc.DatetimeIndexParams:
      """Copy the REST datetime-index flags (``is_principal``, ``on_disk``) onto gRPC."""
      return grpc.DatetimeIndexParams(
          is_principal=model.is_principal,
          on_disk=model.on_disk,
      )
  @classmethod
  def convert_uuid_index_params(cls, model: rest.UuidIndexParams) -> grpc.UuidIndexParams:
      """Copy the REST UUID-index flags (``is_tenant``, ``on_disk``) onto gRPC."""
      return grpc.UuidIndexParams(
          is_tenant=model.is_tenant,
          on_disk=model.on_disk,
      )
  @classmethod
  def convert_collection_params_diff(
      cls, model: rest.CollectionParamsDiff
  ) -> grpc.CollectionParamsDiff:
      """Copy the REST collection-parameter update fields onto a gRPC ``CollectionParamsDiff``."""
      replication_factor = model.replication_factor
      write_consistency_factor = model.write_consistency_factor
      return grpc.CollectionParamsDiff(
          replication_factor=replication_factor,
          write_consistency_factor=write_consistency_factor,
          on_disk_payload=model.on_disk_payload,
          read_fan_out_factor=model.read_fan_out_factor,
      )
  @classmethod
  def convert_lookup_location(cls, model: rest.LookupLocation) -> grpc.LookupLocation:
      """Convert a REST ``LookupLocation`` into gRPC.

      The REST ``collection``/``vector`` fields are renamed to
      ``collection_name``/``vector_name``.
      """
      collection_name, vector_name = model.collection, model.vector
      return grpc.LookupLocation(collection_name=collection_name, vector_name=vector_name)
  @classmethod
  def convert_read_consistency(cls, model: rest.ReadConsistency) -> grpc.ReadConsistency:
      """Convert a REST read-consistency setting into a gRPC ``ReadConsistency``.

      An ``int`` is sent as ``factor``. A ``ReadConsistencyType`` is sent as
      ``type``.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, int):
          return grpc.ReadConsistency(factor=model)
      if isinstance(model, rest.ReadConsistencyType):
          return grpc.ReadConsistency(type=cls.convert_read_consistency_type(model))
      raise ValueError(f"invalid ReadConsistency model: {model}")  # pragma: no cover
  @classmethod
  def convert_read_consistency_type(
      cls, model: rest.ReadConsistencyType
  ) -> grpc.ReadConsistencyType:
      """Translate a REST ``ReadConsistencyType`` into the gRPC enum.

      Raises:
          ValueError: for a value other than MAJORITY, ALL or QUORUM.
      """
      for rest_type, grpc_type in (
          (rest.ReadConsistencyType.MAJORITY, grpc.ReadConsistencyType.Majority),
          (rest.ReadConsistencyType.ALL, grpc.ReadConsistencyType.All),
          (rest.ReadConsistencyType.QUORUM, grpc.ReadConsistencyType.Quorum),
      ):
          if model == rest_type:
              return grpc_type
      raise ValueError(f"invalid ReadConsistencyType model: {model}")  # pragma: no cover
  @classmethod
  def convert_write_ordering(cls, model: rest.WriteOrdering) -> grpc.WriteOrdering:
      """Wrap a REST ``WriteOrdering`` value in a gRPC ``WriteOrdering`` message.

      Raises:
          ValueError: for a value other than WEAK, MEDIUM or STRONG.
      """
      for rest_ordering, grpc_type in (
          (rest.WriteOrdering.WEAK, grpc.WriteOrderingType.Weak),
          (rest.WriteOrdering.MEDIUM, grpc.WriteOrderingType.Medium),
          (rest.WriteOrdering.STRONG, grpc.WriteOrderingType.Strong),
      ):
          if model == rest_ordering:
              return grpc.WriteOrdering(type=grpc_type)
      raise ValueError(f"invalid WriteOrdering model: {model}")  # pragma: no cover
  @classmethod
  def convert_scalar_quantization_config(
      cls, model: rest.ScalarQuantizationConfig
  ) -> grpc.ScalarQuantization:
      """Convert REST scalar-quantization settings into gRPC ``ScalarQuantization``.

      The quantization type is always sent as ``Int8``. ``quantile`` and
      ``always_ram`` are copied verbatim.
      """
      quantile, always_ram = model.quantile, model.always_ram
      return grpc.ScalarQuantization(
          type=grpc.QuantizationType.Int8, quantile=quantile, always_ram=always_ram
      )
  @classmethod
  def convert_product_quantization_config(
      cls, model: rest.ProductQuantizationConfig
  ) -> grpc.ProductQuantization:
      """Convert REST product-quantization settings into gRPC ``ProductQuantization``."""
      compression = cls.convert_compression_ratio(model.compression)
      return grpc.ProductQuantization(compression=compression, always_ram=model.always_ram)
  @classmethod
  def convert_binary_quantization_config(
      cls, model: rest.BinaryQuantizationConfig
  ) -> grpc.BinaryQuantization:
      """Convert REST binary-quantization settings into gRPC ``BinaryQuantization``.

      ``encoding`` and ``query_encoding`` are converted only when set.
      """
      always_ram = model.always_ram
      encoding = None
      if model.encoding is not None:
          encoding = cls.convert_binary_quantization_encoding(model.encoding)
      query_encoding = None
      if model.query_encoding is not None:
          query_encoding = cls.convert_binary_quantization_query_encoding(model.query_encoding)
      return grpc.BinaryQuantization(
          always_ram=always_ram, encoding=encoding, query_encoding=query_encoding
      )
  @classmethod
  def convert_binary_quantization_encoding(
      cls, model: rest.BinaryQuantizationEncoding
  ) -> grpc.BinaryQuantizationEncoding:
      """Translate a REST ``BinaryQuantizationEncoding`` into the gRPC enum.

      Raises:
          ValueError: for an unsupported encoding.
      """
      for rest_encoding, grpc_encoding in (
          (rest.BinaryQuantizationEncoding.ONE_BIT, grpc.BinaryQuantizationEncoding.OneBit),
          (rest.BinaryQuantizationEncoding.TWO_BITS, grpc.BinaryQuantizationEncoding.TwoBits),
          (
              rest.BinaryQuantizationEncoding.ONE_AND_HALF_BITS,
              grpc.BinaryQuantizationEncoding.OneAndHalfBits,
          ),
      ):
          if model == rest_encoding:
              return grpc_encoding
      raise ValueError(f"invalid BinaryQuantizationEncoding model: {model}")  # pragma: no cover
  @classmethod
  def convert_binary_quantization_query_encoding(
      cls, model: rest.BinaryQuantizationQueryEncoding
  ) -> grpc.BinaryQuantizationQueryEncoding:
      """Wrap a REST query-encoding value in a gRPC ``BinaryQuantizationQueryEncoding``.

      The REST value selects the message's ``setting``.

      Raises:
          ValueError: for an unsupported query encoding.
      """
      setting_enum = grpc.BinaryQuantizationQueryEncoding.Setting
      for rest_encoding, setting in (
          (rest.BinaryQuantizationQueryEncoding.DEFAULT, setting_enum.Default),
          (rest.BinaryQuantizationQueryEncoding.BINARY, setting_enum.Binary),
          (rest.BinaryQuantizationQueryEncoding.SCALAR4BITS, setting_enum.Scalar4Bits),
          (rest.BinaryQuantizationQueryEncoding.SCALAR8BITS, setting_enum.Scalar8Bits),
      ):
          if model == rest_encoding:
              return grpc.BinaryQuantizationQueryEncoding(setting=setting)
      raise ValueError(
          f"invalid BinaryQuantizationQueryEncoding model: {model}"
      )  # pragma: no cover
  @classmethod
  def convert_compression_ratio(cls, model: rest.CompressionRatio) -> grpc.CompressionRatio:
      """Translate a REST ``CompressionRatio`` (X4 to X64) into the gRPC enum.

      Raises:
          ValueError: for an unsupported ratio.
      """
      for rest_ratio, grpc_ratio in (
          (rest.CompressionRatio.X4, grpc.CompressionRatio.x4),
          (rest.CompressionRatio.X8, grpc.CompressionRatio.x8),
          (rest.CompressionRatio.X16, grpc.CompressionRatio.x16),
          (rest.CompressionRatio.X32, grpc.CompressionRatio.x32),
          (rest.CompressionRatio.X64, grpc.CompressionRatio.x64),
      ):
          if model == rest_ratio:
              return grpc_ratio
      raise ValueError(f"invalid CompressionRatio model: {model}")  # pragma: no cover
  @classmethod
  def convert_quantization_config(
      cls, model: rest.QuantizationConfig
  ) -> grpc.QuantizationConfig:
      """Convert a REST quantization config (scalar, product or binary) into gRPC.

      Raises:
          ValueError: if ``model`` is none of the three supported variants.
      """
      if isinstance(model, rest.ScalarQuantization):
          fields = {"scalar": cls.convert_scalar_quantization_config(model.scalar)}
      elif isinstance(model, rest.ProductQuantization):
          fields = {"product": cls.convert_product_quantization_config(model.product)}
      elif isinstance(model, rest.BinaryQuantization):
          fields = {"binary": cls.convert_binary_quantization_config(model.binary)}
      else:
          raise ValueError(f"invalid QuantizationConfig model: {model}")  # pragma: no cover
      return grpc.QuantizationConfig(**fields)
  @classmethod
  def convert_quantization_search_params(
      cls, model: rest.QuantizationSearchParams
  ) -> grpc.QuantizationSearchParams:
      """Copy REST quantization search options (ignore/rescore/oversampling) onto gRPC."""
      ignore, rescore, oversampling = model.ignore, model.rescore, model.oversampling
      return grpc.QuantizationSearchParams(
          ignore=ignore, rescore=rescore, oversampling=oversampling
      )
  @classmethod
  def convert_point_vectors(cls, model: rest.PointVectors) -> grpc.PointVectors:
      """Convert a REST ``PointVectors`` (point id plus vectors) into gRPC."""
      point_id = cls.convert_extended_point_id(model.id)
      vectors = cls.convert_vector_struct(model.vector)
      return grpc.PointVectors(id=point_id, vectors=vectors)
  @classmethod
  def convert_groups_result(cls, model: rest.GroupsResult) -> grpc.GroupsResult:
      """Convert every group of a REST ``GroupsResult`` into a gRPC ``GroupsResult``."""
      groups = list(map(cls.convert_point_group, model.groups))
      return grpc.GroupsResult(groups=groups)
  @classmethod
  def convert_point_group(cls, model: rest.PointGroup) -> grpc.PointGroup:
      """Convert a REST ``PointGroup`` into gRPC.

      Converts the group id, every scored hit, and the optional lookup record.
      """
      group_id = cls.convert_group_id(model.id)
      hits = list(map(cls.convert_scored_point, model.hits))
      lookup = None if model.lookup is None else cls.convert_record(model.lookup)
      return grpc.PointGroup(id=group_id, hits=hits, lookup=lookup)
  @classmethod
  def convert_group_id(cls, model: rest.GroupId) -> grpc.GroupId:
      """Convert a REST group id into a gRPC ``GroupId`` oneof.

      Strings go to ``string_value``. Non-negative ints go to ``unsigned_value``.
      Negative ints go to ``integer_value``.

      Raises:
          ValueError: for any other type.
      """
      if isinstance(model, str):
          return grpc.GroupId(string_value=model)
      if not isinstance(model, int):
          raise ValueError(f"invalid GroupId model: {model}")  # pragma: no cover
      if model < 0:
          return grpc.GroupId(integer_value=model)
      return grpc.GroupId(unsigned_value=model)
  @classmethod
  def convert_with_lookup(cls, model: rest.WithLookup) -> grpc.WithLookup:
      """Convert REST group-lookup options into a gRPC ``WithLookup``.

      ``with_vectors`` and ``with_payload`` are converted only when set.
      """
      with_vectors = None
      if model.with_vectors is not None:
          with_vectors = cls.convert_with_vectors(model.with_vectors)
      with_payload = None
      if model.with_payload is not None:
          with_payload = cls.convert_with_payload_interface(model.with_payload)
      return grpc.WithLookup(
          collection=model.collection, with_vectors=with_vectors, with_payload=with_payload
      )
  @classmethod
  def convert_quantization_config_diff(
      cls, model: rest.QuantizationConfigDiff
  ) -> grpc.QuantizationConfigDiff:
      """Convert a REST quantization-config update into gRPC.

      Scalar, product and binary configs are converted. ``Disabled.DISABLED`` is
      sent as an empty ``disabled`` message.

      Raises:
          ValueError: for any other input.
      """
      if isinstance(model, rest.ScalarQuantization):
          fields = {"scalar": cls.convert_scalar_quantization_config(model.scalar)}
      elif isinstance(model, rest.ProductQuantization):
          fields = {"product": cls.convert_product_quantization_config(model.product)}
      elif isinstance(model, rest.BinaryQuantization):
          fields = {"binary": cls.convert_binary_quantization_config(model.binary)}
      elif model == rest.Disabled.DISABLED:
          fields = {"disabled": grpc.Disabled()}
      else:
          raise ValueError(f"invalid QuantizationConfigDiff model: {model}")  # pragma: no cover
      return grpc.QuantizationConfigDiff(**fields)
  @classmethod
  def convert_vector_params_diff(cls, model: rest.VectorParamsDiff) -> grpc.VectorParamsDiff:
      """Convert a REST per-vector parameter update into gRPC ``VectorParamsDiff``.

      The HNSW and quantization sub-configs are converted only when set.
      ``on_disk`` is copied verbatim.
      """
      hnsw_config = (
          None if model.hnsw_config is None else cls.convert_hnsw_config_diff(model.hnsw_config)
      )
      quantization_config = (
          None
          if model.quantization_config is None
          else cls.convert_quantization_config_diff(model.quantization_config)
      )
      return grpc.VectorParamsDiff(
          hnsw_config=hnsw_config,
          quantization_config=quantization_config,
          on_disk=model.on_disk,
      )
  @classmethod
  def convert_vectors_config_diff(cls, model: rest.VectorsConfigDiff) -> grpc.VectorsConfigDiff:
      """Convert a REST vectors-config update (name -> params diff) into gRPC.

      A dict containing only the unnamed vector ``""`` is sent as a single
      ``params`` entry. Any other dict is sent as ``params_map``.

      Raises:
          ValueError: if ``model`` is not a dict.
      """
      if not isinstance(model, dict):
          raise ValueError(f"invalid VectorsConfigDiff model: {model}")  # pragma: no cover
      if len(model) == 1 and "" in model:
          return grpc.VectorsConfigDiff(params=cls.convert_vector_params_diff(model[""]))
      converted = {key: cls.convert_vector_params_diff(val) for key, val in model.items()}
      return grpc.VectorsConfigDiff(params_map=grpc.VectorParamsDiffMap(map=converted))
  @classmethod
  def convert_point_insert_operation(
      cls, model: rest.PointInsertOperations
  ) -> list[grpc.PointStruct]:
      """Convert a REST point-insert payload into a list of gRPC ``PointStruct``.

      Two input forms are supported:

      * ``PointsBatch``: column-wise ``ids`` / ``vectors`` / optional
        ``payloads``, indexed by position.
      * ``PointsList``: each point is converted with ``convert_point_struct``.

      Raises:
          ValueError: if ``model`` is neither a ``PointsBatch`` nor a ``PointsList``.
      """
      if isinstance(model, rest.PointsBatch):
          batch = model.batch
          vectors_batch: list[grpc.Vectors] = cls.convert_batch_vector_struct(
              batch.vectors, len(batch.ids)
          )
          payloads = batch.payloads
          # Dispatch through ``cls`` rather than the hard-coded ``RestToGrpc`` so
          # subclass overrides of the per-point converters are honoured, as in
          # every other method of this converter.
          return [
              grpc.PointStruct(
                  id=cls.convert_extended_point_id(point_id),
                  vectors=vectors_batch[idx],
                  payload=cls.convert_payload(payloads[idx]) if payloads is not None else None,
              )
              for idx, point_id in enumerate(batch.ids)
          ]
      if isinstance(model, rest.PointsList):
          return [cls.convert_point_struct(point) for point in model.points]
      raise ValueError(f"invalid PointInsertOperations model: {model}")  # pragma: no cover
  @classmethod
  def convert_update_operation(cls, model: rest.UpdateOperation) -> grpc.PointsUpdateOperation:
      """Alias of ``convert_points_update_operation``; forwards ``model`` unchanged."""
      operation = cls.convert_points_update_operation(model)
      return operation
  3868. @classmethod
  3869. def convert_points_update_operation(
  3870. cls, model: rest.UpdateOperation
  3871. ) -> grpc.PointsUpdateOperation:
  3872. if isinstance(model, rest.UpsertOperation):
  3873. shard_key_selector = (
  3874. cls.convert_shard_key_selector(model.upsert.shard_key)
  3875. if model.upsert.shard_key
  3876. else None
  3877. )
  3878. return grpc.PointsUpdateOperation(
  3879. upsert=grpc.PointsUpdateOperation.PointStructList(
  3880. points=cls.convert_point_insert_operation(model.upsert),
  3881. shard_key_selector=shard_key_selector,
  3882. )
  3883. )
  3884. elif isinstance(model, rest.DeleteOperation):
  3885. shard_key_selector = (
  3886. cls.convert_shard_key_selector(model.delete.shard_key)
  3887. if model.delete.shard_key
  3888. else None
  3889. )
  3890. points_selector = cls.convert_points_selector(model.delete)
  3891. delete_points = grpc.PointsUpdateOperation.DeletePoints(
  3892. points=points_selector,
  3893. shard_key_selector=shard_key_selector,
  3894. )
  3895. return grpc.PointsUpdateOperation(
  3896. delete_points=delete_points,
  3897. )
  3898. elif isinstance(model, rest.SetPayloadOperation):
  3899. if model.set_payload.points:
  3900. points_selector = rest.PointIdsList(points=model.set_payload.points)
  3901. elif model.set_payload.filter:
  3902. points_selector = rest.FilterSelector(filter=model.set_payload.filter)
  3903. else:
  3904. raise ValueError(f"invalid SetPayloadOperation model: {model}") # pragma: no cover
  3905. shard_key_selector = (
  3906. cls.convert_shard_key_selector(model.set_payload.shard_key)
  3907. if model.set_payload.shard_key
  3908. else None
  3909. )
  3910. return grpc.PointsUpdateOperation(
  3911. set_payload=grpc.PointsUpdateOperation.SetPayload(
  3912. payload=cls.convert_payload(model.set_payload.payload),
  3913. points_selector=cls.convert_points_selector(points_selector),
  3914. shard_key_selector=shard_key_selector,
  3915. key=model.set_payload.key,
  3916. )
  3917. )
  3918. elif isinstance(model, rest.OverwritePayloadOperation):
  3919. if model.overwrite_payload.points:
  3920. points_selector = rest.PointIdsList(points=model.overwrite_payload.points)
  3921. elif model.overwrite_payload.filter:
  3922. points_selector = rest.FilterSelector(filter=model.overwrite_payload.filter)
  3923. else:
  3924. raise ValueError(
  3925. f"invalid OverwritePayloadOperation model: {model}"
  3926. ) # pragma: no cover
  3927. shard_key_selector = (
  3928. cls.convert_shard_key_selector(model.overwrite_payload.shard_key)
  3929. if model.overwrite_payload.shard_key
  3930. else None
  3931. )
  3932. return grpc.PointsUpdateOperation(
  3933. overwrite_payload=grpc.PointsUpdateOperation.OverwritePayload(
  3934. payload=cls.convert_payload(model.overwrite_payload.payload),
  3935. points_selector=cls.convert_points_selector(points_selector),
  3936. shard_key_selector=shard_key_selector,
  3937. key=model.overwrite_payload.key,
  3938. )
  3939. )
  3940. elif isinstance(model, rest.DeletePayloadOperation):
  3941. if model.delete_payload.points:
  3942. points_selector = rest.PointIdsList(points=model.delete_payload.points)
  3943. elif model.delete_payload.filter:
  3944. points_selector = rest.FilterSelector(filter=model.delete_payload.filter)
  3945. else:
  3946. raise ValueError(
  3947. f"invalid DeletePayloadOperation model: {model}"
  3948. ) # pragma: no cover
  3949. shard_key_selector = (
  3950. cls.convert_shard_key_selector(model.delete_payload.shard_key)
  3951. if model.delete_payload.shard_key
  3952. else None
  3953. )
  3954. return grpc.PointsUpdateOperation(
  3955. delete_payload=grpc.PointsUpdateOperation.DeletePayload(
  3956. keys=model.delete_payload.keys,
  3957. points_selector=cls.convert_points_selector(points_selector),
  3958. shard_key_selector=shard_key_selector,
  3959. )
  3960. )
  3961. elif isinstance(model, rest.ClearPayloadOperation):
  3962. shard_key_selector = (
  3963. cls.convert_shard_key_selector(model.clear_payload.shard_key)
  3964. if model.clear_payload.shard_key
  3965. else None
  3966. )
  3967. points_selector = cls.convert_points_selector(model.clear_payload)
  3968. clear_payload = grpc.PointsUpdateOperation.ClearPayload(
  3969. points=points_selector,
  3970. shard_key_selector=shard_key_selector,
  3971. )
  3972. return grpc.PointsUpdateOperation(
  3973. clear_payload=clear_payload,
  3974. )
  3975. elif isinstance(model, rest.UpdateVectorsOperation):
  3976. shard_key_selector = (
  3977. cls.convert_shard_key_selector(model.update_vectors.shard_key)
  3978. if model.update_vectors.shard_key
  3979. else None
  3980. )
  3981. return grpc.PointsUpdateOperation(
  3982. update_vectors=grpc.PointsUpdateOperation.UpdateVectors(
  3983. points=[
  3984. cls.convert_point_vectors(point) for point in model.update_vectors.points
  3985. ],
  3986. shard_key_selector=shard_key_selector,
  3987. )
  3988. )
  3989. elif isinstance(model, rest.DeleteVectorsOperation):
  3990. if model.delete_vectors.points:
  3991. points_selector = rest.PointIdsList(points=model.delete_vectors.points)
  3992. elif model.delete_vectors.filter:
  3993. points_selector = rest.FilterSelector(filter=model.delete_vectors.filter)
  3994. else:
  3995. raise ValueError(
  3996. f"invalid DeletePayloadOperation model: {model}"
  3997. ) # pragma: no cover
  3998. shard_key_selector = (
  3999. cls.convert_shard_key_selector(model.delete_vectors.shard_key)
  4000. if model.delete_vectors.shard_key
  4001. else None
  4002. )
  4003. return grpc.PointsUpdateOperation(
  4004. delete_vectors=grpc.PointsUpdateOperation.DeleteVectors(
  4005. points_selector=cls.convert_points_selector(points_selector),
  4006. vectors=grpc.VectorsSelector(names=model.delete_vectors.vector),
  4007. shard_key_selector=shard_key_selector,
  4008. )
  4009. )
  4010. else:
  4011. raise ValueError(f"invalid UpdateOperation model: {model}") # pragma: no cover
  4012. @classmethod
  4013. def convert_init_from(cls, model: rest.InitFrom) -> str:
  4014. if isinstance(model, rest.InitFrom):
  4015. return model.collection
  4016. else:
  4017. raise ValueError(f"invalid InitFrom model: {model}") # pragma: no cover
  4018. @classmethod
  4019. def convert_recommend_strategy(cls, model: rest.RecommendStrategy) -> grpc.RecommendStrategy:
  4020. if model == rest.RecommendStrategy.AVERAGE_VECTOR:
  4021. return grpc.RecommendStrategy.AverageVector
  4022. elif model == rest.RecommendStrategy.BEST_SCORE:
  4023. return grpc.RecommendStrategy.BestScore
  4024. elif model == rest.RecommendStrategy.SUM_SCORES:
  4025. return grpc.RecommendStrategy.SumScores
  4026. else:
  4027. raise ValueError(f"invalid RecommendStrategy model: {model}") # pragma: no cover
  4028. @classmethod
  4029. def convert_sparse_index_params(cls, model: rest.SparseIndexParams) -> grpc.SparseIndexConfig:
  4030. return grpc.SparseIndexConfig(
  4031. full_scan_threshold=(
  4032. model.full_scan_threshold if model.full_scan_threshold is not None else None
  4033. ),
  4034. on_disk=model.on_disk if model.on_disk is not None else None,
  4035. datatype=cls.convert_datatype(model.datatype) if model.datatype is not None else None,
  4036. )
  4037. @classmethod
  4038. def convert_modifier(cls, model: rest.Modifier) -> grpc.Modifier:
  4039. if model == rest.Modifier.IDF:
  4040. return grpc.Modifier.Idf
  4041. elif model == rest.Modifier.NONE:
  4042. return getattr(grpc.Modifier, "None")
  4043. else:
  4044. raise ValueError(f"invalid Modifier model: {model}") # pragma: no cover
  4045. @classmethod
  4046. def convert_sparse_vector_params(
  4047. cls, model: rest.SparseVectorParams
  4048. ) -> grpc.SparseVectorParams:
  4049. return grpc.SparseVectorParams(
  4050. index=(
  4051. cls.convert_sparse_index_params(model.index) if model.index is not None else None
  4052. ),
  4053. modifier=(
  4054. cls.convert_modifier(model.modifier) if model.modifier is not None else None
  4055. ),
  4056. )
  4057. @classmethod
  4058. def convert_sparse_vector_config(
  4059. cls, model: Mapping[str, rest.SparseVectorParams]
  4060. ) -> grpc.SparseVectorConfig:
  4061. return grpc.SparseVectorConfig(
  4062. map=dict((key, cls.convert_sparse_vector_params(val)) for key, val in model.items())
  4063. )
  4064. @classmethod
  4065. def convert_shard_key(cls, model: rest.ShardKey) -> grpc.ShardKey:
  4066. if isinstance(model, int):
  4067. return grpc.ShardKey(number=model)
  4068. if isinstance(model, str):
  4069. return grpc.ShardKey(keyword=model)
  4070. raise ValueError(f"invalid ShardKey model: {model}") # pragma: no cover
  4071. @classmethod
  4072. def convert_shard_key_selector(cls, model: rest.ShardKeySelector) -> grpc.ShardKeySelector:
  4073. if isinstance(model, get_args_subscribed(rest.ShardKey)):
  4074. return grpc.ShardKeySelector(shard_keys=[cls.convert_shard_key(model)])
  4075. if isinstance(model, list):
  4076. return grpc.ShardKeySelector(shard_keys=[cls.convert_shard_key(key) for key in model])
  4077. raise ValueError(f"invalid ShardKeySelector model: {model}") # pragma: no cover
  4078. @classmethod
  4079. def convert_sharding_method(cls, model: rest.ShardingMethod) -> grpc.ShardingMethod:
  4080. if model == rest.ShardingMethod.AUTO:
  4081. return grpc.Auto
  4082. elif model == rest.ShardingMethod.CUSTOM:
  4083. return grpc.Custom
  4084. else:
  4085. raise ValueError(f"invalid ShardingMethod model: {model}") # pragma: no cover
  4086. @classmethod
  4087. def convert_health_check_reply(cls, model: rest.VersionInfo) -> grpc.HealthCheckReply:
  4088. return grpc.HealthCheckReply(
  4089. title=model.title,
  4090. version=model.version,
  4091. commit=model.commit,
  4092. )
  4093. @classmethod
  4094. def convert_search_matrix_pair(cls, model: rest.SearchMatrixPair) -> grpc.SearchMatrixPair:
  4095. return grpc.SearchMatrixPair(
  4096. a=cls.convert_extended_point_id(model.a),
  4097. b=cls.convert_extended_point_id(model.b),
  4098. score=model.score,
  4099. )
  4100. @classmethod
  4101. def convert_search_matrix_pairs(
  4102. cls, model: rest.SearchMatrixPairsResponse
  4103. ) -> grpc.SearchMatrixPairs:
  4104. return grpc.SearchMatrixPairs(
  4105. pairs=[cls.convert_search_matrix_pair(pair) for pair in model.pairs],
  4106. )
  4107. @classmethod
  4108. def convert_search_matrix_offsets(
  4109. cls, model: rest.SearchMatrixOffsetsResponse
  4110. ) -> grpc.SearchMatrixOffsets:
  4111. return grpc.SearchMatrixOffsets(
  4112. offsets_row=list(model.offsets_row),
  4113. offsets_col=list(model.offsets_col),
  4114. scores=list(model.scores),
  4115. ids=[cls.convert_extended_point_id(p_id) for p_id in model.ids],
  4116. )
  4117. @classmethod
  4118. def convert_strict_mode_multivector(
  4119. cls, model: rest.StrictModeMultivector
  4120. ) -> grpc.StrictModeMultivector:
  4121. return grpc.StrictModeMultivector(
  4122. max_vectors=model.max_vectors,
  4123. )
  4124. @classmethod
  4125. def convert_strict_mode_multivector_config(
  4126. cls, model: rest.StrictModeMultivectorConfig
  4127. ) -> grpc.StrictModeMultivectorConfig:
  4128. return grpc.StrictModeMultivectorConfig(
  4129. multivector_config=dict(
  4130. (key, cls.convert_strict_mode_multivector(val)) for key, val in model.items()
  4131. )
  4132. )
  4133. @classmethod
  4134. def convert_strict_mode_sparse(cls, model: rest.StrictModeSparse) -> grpc.StrictModeSparse:
  4135. return grpc.StrictModeSparse(
  4136. max_length=model.max_length,
  4137. )
  4138. @classmethod
  4139. def convert_strict_mode_sparse_config(
  4140. cls, model: rest.StrictModeSparseConfig
  4141. ) -> grpc.StrictModeSparseConfig:
  4142. return grpc.StrictModeSparseConfig(
  4143. sparse_config=dict(
  4144. (key, cls.convert_strict_mode_sparse(val)) for key, val in model.items()
  4145. )
  4146. )
  4147. @classmethod
  4148. def convert_strict_mode_config(cls, model: rest.StrictModeConfig) -> grpc.StrictModeConfig:
  4149. return grpc.StrictModeConfig(
  4150. enabled=model.enabled,
  4151. max_query_limit=model.max_query_limit,
  4152. max_timeout=model.max_timeout,
  4153. unindexed_filtering_retrieve=model.unindexed_filtering_retrieve,
  4154. unindexed_filtering_update=model.unindexed_filtering_update,
  4155. search_max_hnsw_ef=model.search_max_hnsw_ef,
  4156. search_allow_exact=model.search_allow_exact,
  4157. search_max_oversampling=model.search_max_oversampling,
  4158. upsert_max_batchsize=model.upsert_max_batchsize,
  4159. max_collection_vector_size_bytes=model.max_collection_vector_size_bytes,
  4160. read_rate_limit=model.read_rate_limit,
  4161. write_rate_limit=model.write_rate_limit,
  4162. max_collection_payload_size_bytes=model.max_collection_payload_size_bytes,
  4163. max_points_count=model.max_points_count,
  4164. filter_max_conditions=model.filter_max_conditions,
  4165. condition_max_size=model.condition_max_size,
  4166. multivector_config=(
  4167. cls.convert_strict_mode_multivector_config(model.multivector_config)
  4168. if model.multivector_config
  4169. else None
  4170. ),
  4171. sparse_config=(
  4172. cls.convert_strict_mode_sparse_config(model.sparse_config)
  4173. if model.sparse_config
  4174. else None
  4175. ),
  4176. )
  4177. @classmethod
  4178. def convert_strict_mode_config_output(
  4179. cls, model: rest.StrictModeConfigOutput
  4180. ) -> grpc.StrictModeConfig:
  4181. return grpc.StrictModeConfig(
  4182. enabled=model.enabled,
  4183. max_query_limit=model.max_query_limit,
  4184. max_timeout=model.max_timeout,
  4185. unindexed_filtering_retrieve=model.unindexed_filtering_retrieve,
  4186. unindexed_filtering_update=model.unindexed_filtering_update,
  4187. search_max_hnsw_ef=model.search_max_hnsw_ef,
  4188. search_allow_exact=model.search_allow_exact,
  4189. search_max_oversampling=model.search_max_oversampling,
  4190. upsert_max_batchsize=model.upsert_max_batchsize,
  4191. max_collection_vector_size_bytes=model.max_collection_vector_size_bytes,
  4192. read_rate_limit=model.read_rate_limit,
  4193. write_rate_limit=model.write_rate_limit,
  4194. max_collection_payload_size_bytes=model.max_collection_payload_size_bytes,
  4195. max_points_count=model.max_points_count,
  4196. filter_max_conditions=model.filter_max_conditions,
  4197. condition_max_size=model.condition_max_size,
  4198. multivector_config=(
  4199. cls.convert_strict_mode_multivector_config_output(model.multivector_config)
  4200. if model.multivector_config
  4201. else None
  4202. ),
  4203. sparse_config=(
  4204. cls.convert_strict_mode_sparse_config_output(model.sparse_config)
  4205. if model.sparse_config
  4206. else None
  4207. ),
  4208. )
  4209. @classmethod
  4210. def convert_strict_mode_multivector_config_output(
  4211. cls, model: rest.StrictModeMultivectorConfigOutput
  4212. ) -> grpc.StrictModeMultivectorConfig:
  4213. return grpc.StrictModeMultivectorConfig(
  4214. multivector_config=dict(
  4215. (key, cls.convert_strict_mode_multivector_output(val))
  4216. for key, val in model.items()
  4217. )
  4218. )
  4219. @classmethod
  4220. def convert_strict_mode_sparse_config_output(
  4221. cls, model: rest.StrictModeSparseConfigOutput
  4222. ) -> grpc.StrictModeSparseConfig:
  4223. return grpc.StrictModeSparseConfig(
  4224. sparse_config=dict(
  4225. (key, cls.convert_strict_mode_sparse_output(val)) for key, val in model.items()
  4226. )
  4227. )
  4228. @classmethod
  4229. def convert_strict_mode_multivector_output(
  4230. cls, model: rest.StrictModeMultivectorOutput
  4231. ) -> grpc.StrictModeMultivector:
  4232. return grpc.StrictModeMultivector(
  4233. max_vectors=model.max_vectors,
  4234. )
  4235. @classmethod
  4236. def convert_strict_mode_sparse_output(
  4237. cls, model: rest.StrictModeSparseOutput
  4238. ) -> grpc.StrictModeSparse:
  4239. return grpc.StrictModeSparse(
  4240. max_length=model.max_length,
  4241. )