connection.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. import asyncio
  2. import collections
  3. from typing import Any, Awaitable, Callable, Optional, Union
  4. import grpc
  5. from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
  6. # type: ignore # noqa: F401
  7. # Source <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py>
  8. class _GenericClientInterceptor(
  9. grpc.UnaryUnaryClientInterceptor,
  10. grpc.UnaryStreamClientInterceptor,
  11. grpc.StreamUnaryClientInterceptor,
  12. grpc.StreamStreamClientInterceptor,
  13. ):
  14. def __init__(self, interceptor_function: Callable):
  15. self._fn = interceptor_function
  16. def intercept_unary_unary(
  17. self, continuation: Any, client_call_details: Any, request: Any
  18. ) -> Any:
  19. new_details, new_request_iterator, postprocess = self._fn(
  20. client_call_details, iter((request,)), False, False
  21. )
  22. response = continuation(new_details, next(new_request_iterator))
  23. return postprocess(response) if postprocess else response
  24. def intercept_unary_stream(
  25. self, continuation: Any, client_call_details: Any, request: Any
  26. ) -> Any:
  27. new_details, new_request_iterator, postprocess = self._fn(
  28. client_call_details, iter((request,)), False, True
  29. )
  30. response_it = continuation(new_details, next(new_request_iterator))
  31. return postprocess(response_it) if postprocess else response_it
  32. def intercept_stream_unary(
  33. self, continuation: Any, client_call_details: Any, request_iterator: Any
  34. ) -> Any:
  35. new_details, new_request_iterator, postprocess = self._fn(
  36. client_call_details, request_iterator, True, False
  37. )
  38. response = continuation(new_details, new_request_iterator)
  39. return postprocess(response) if postprocess else response
  40. def intercept_stream_stream(
  41. self, continuation: Any, client_call_details: Any, request_iterator: Any
  42. ) -> Any:
  43. new_details, new_request_iterator, postprocess = self._fn(
  44. client_call_details, request_iterator, True, True
  45. )
  46. response_it = continuation(new_details, new_request_iterator)
  47. return postprocess(response_it) if postprocess else response_it
  48. class _GenericAsyncClientInterceptor(
  49. grpc.aio.UnaryUnaryClientInterceptor,
  50. grpc.aio.UnaryStreamClientInterceptor,
  51. grpc.aio.StreamUnaryClientInterceptor,
  52. grpc.aio.StreamStreamClientInterceptor,
  53. ):
  54. def __init__(self, interceptor_function: Callable):
  55. self._fn = interceptor_function
  56. async def intercept_unary_unary(
  57. self, continuation: Any, client_call_details: Any, request: Any
  58. ) -> Any:
  59. new_details, new_request_iterator, postprocess = await self._fn(
  60. client_call_details, iter((request,)), False, False
  61. )
  62. next_request = next(new_request_iterator)
  63. response = await continuation(new_details, next_request)
  64. return await postprocess(response) if postprocess else response
  65. async def intercept_unary_stream(
  66. self, continuation: Any, client_call_details: Any, request: Any
  67. ) -> Any:
  68. new_details, new_request_iterator, postprocess = await self._fn(
  69. client_call_details, iter((request,)), False, True
  70. )
  71. response_it = await continuation(new_details, next(new_request_iterator))
  72. return await postprocess(response_it) if postprocess else response_it
  73. async def intercept_stream_unary(
  74. self, continuation: Any, client_call_details: Any, request_iterator: Any
  75. ) -> Any:
  76. new_details, new_request_iterator, postprocess = await self._fn(
  77. client_call_details, request_iterator, True, False
  78. )
  79. response = await continuation(new_details, new_request_iterator)
  80. return await postprocess(response) if postprocess else response
  81. async def intercept_stream_stream(
  82. self, continuation: Any, client_call_details: Any, request_iterator: Any
  83. ) -> Any:
  84. new_details, new_request_iterator, postprocess = await self._fn(
  85. client_call_details, request_iterator, True, True
  86. )
  87. response_it = await continuation(new_details, new_request_iterator)
  88. return await postprocess(response_it) if postprocess else response_it
  89. def create_generic_client_interceptor(intercept_call: Any) -> _GenericClientInterceptor:
  90. return _GenericClientInterceptor(intercept_call)
  91. def create_generic_async_client_interceptor(
  92. intercept_call: Any,
  93. ) -> _GenericAsyncClientInterceptor:
  94. return _GenericAsyncClientInterceptor(intercept_call)
  95. # Source:
  96. # <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/header_manipulator_client_interceptor.py>
  97. class _ClientCallDetails(
  98. collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
  99. grpc.ClientCallDetails,
  100. ):
  101. pass
  102. class _ClientAsyncCallDetails(
  103. collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
  104. grpc.aio.ClientCallDetails,
  105. ):
  106. pass
  107. def header_adder_interceptor(
  108. new_metadata: list[tuple[str, str]],
  109. auth_token_provider: Optional[Callable[[], str]] = None,
  110. ) -> _GenericClientInterceptor:
  111. def process_response(response: Any) -> Any:
  112. if response.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
  113. retry_after = None
  114. for item in response.trailing_metadata():
  115. if item.key == "retry-after":
  116. try:
  117. retry_after = int(item.value)
  118. except Exception:
  119. retry_after = None
  120. break
  121. reason_phrase = response.details() if response.details() else ""
  122. if retry_after:
  123. raise ResourceExhaustedResponse(message=reason_phrase, retry_after_s=retry_after)
  124. return response
  125. def intercept_call(
  126. client_call_details: _ClientCallDetails,
  127. request_iterator: Any,
  128. _request_streaming: Any,
  129. _response_streaming: Any,
  130. ) -> tuple[_ClientCallDetails, Any, Any]:
  131. metadata = []
  132. if client_call_details.metadata is not None:
  133. metadata = list(client_call_details.metadata)
  134. for header, value in new_metadata:
  135. metadata.append(
  136. (
  137. header,
  138. value,
  139. )
  140. )
  141. if auth_token_provider:
  142. if not asyncio.iscoroutinefunction(auth_token_provider):
  143. metadata.append(("authorization", f"Bearer {auth_token_provider()}"))
  144. else:
  145. raise ValueError("Synchronous channel requires synchronous auth token provider.")
  146. client_call_details = _ClientCallDetails(
  147. client_call_details.method,
  148. client_call_details.timeout,
  149. metadata,
  150. client_call_details.credentials,
  151. )
  152. return client_call_details, request_iterator, process_response
  153. return create_generic_client_interceptor(intercept_call)
  154. def header_adder_async_interceptor(
  155. new_metadata: list[tuple[str, str]],
  156. auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
  157. ) -> _GenericAsyncClientInterceptor:
  158. async def process_response(call: Any) -> Any:
  159. try:
  160. return await call
  161. except grpc.aio.AioRpcError as er:
  162. if er.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
  163. retry_after = None
  164. for item in er.trailing_metadata():
  165. if item[0] == "retry-after":
  166. try:
  167. retry_after = int(item[1])
  168. except Exception:
  169. retry_after = None
  170. break
  171. reason_phrase = er.details() if er.details() else ""
  172. if retry_after:
  173. raise ResourceExhaustedResponse(
  174. message=reason_phrase, retry_after_s=retry_after
  175. ) from er
  176. raise
  177. async def intercept_call(
  178. client_call_details: grpc.aio.ClientCallDetails,
  179. request_iterator: Any,
  180. _request_streaming: Any,
  181. _response_streaming: Any,
  182. ) -> tuple[_ClientAsyncCallDetails, Any, Any]:
  183. metadata = []
  184. if client_call_details.metadata is not None:
  185. metadata = list(client_call_details.metadata)
  186. for header, value in new_metadata:
  187. metadata.append(
  188. (
  189. header,
  190. value,
  191. )
  192. )
  193. if auth_token_provider:
  194. if asyncio.iscoroutinefunction(auth_token_provider):
  195. token = await auth_token_provider()
  196. else:
  197. token = auth_token_provider()
  198. metadata.append(("authorization", f"Bearer {token}"))
  199. client_call_details = client_call_details._replace(metadata=metadata)
  200. return client_call_details, request_iterator, process_response
  201. return create_generic_async_client_interceptor(intercept_call)
  202. def parse_channel_options(options: Optional[dict[str, Any]] = None) -> list[tuple[str, Any]]:
  203. default_options: list[tuple[str, Any]] = [
  204. ("grpc.max_send_message_length", -1),
  205. ("grpc.max_receive_message_length", -1),
  206. ]
  207. if options is None:
  208. return default_options
  209. _options = [(option_name, option_value) for option_name, option_value in options.items()]
  210. for option_name, option_value in default_options:
  211. if option_name not in options:
  212. _options.append((option_name, option_value))
  213. return _options
  214. def get_channel(
  215. host: str,
  216. port: int,
  217. ssl: bool,
  218. metadata: Optional[list[tuple[str, str]]] = None,
  219. options: Optional[dict[str, Any]] = None,
  220. compression: Optional[grpc.Compression] = None,
  221. auth_token_provider: Optional[Callable[[], str]] = None,
  222. ) -> grpc.Channel:
  223. # Parse gRPC client options
  224. _options = parse_channel_options(options)
  225. metadata_interceptor = header_adder_interceptor(
  226. new_metadata=metadata or [], auth_token_provider=auth_token_provider
  227. )
  228. if ssl:
  229. ssl_creds = grpc.ssl_channel_credentials()
  230. channel = grpc.secure_channel(f"{host}:{port}", ssl_creds, _options, compression)
  231. return grpc.intercept_channel(channel, metadata_interceptor)
  232. else:
  233. channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
  234. return grpc.intercept_channel(channel, metadata_interceptor)
  235. def get_async_channel(
  236. host: str,
  237. port: int,
  238. ssl: bool,
  239. metadata: Optional[list[tuple[str, str]]] = None,
  240. options: Optional[dict[str, Any]] = None,
  241. compression: Optional[grpc.Compression] = None,
  242. auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
  243. ) -> grpc.aio.Channel:
  244. # Parse gRPC client options
  245. _options = parse_channel_options(options)
  246. # Create metadata interceptor
  247. metadata_interceptor = header_adder_async_interceptor(
  248. new_metadata=metadata or [], auth_token_provider=auth_token_provider
  249. )
  250. if ssl:
  251. ssl_creds = grpc.ssl_channel_credentials()
  252. return grpc.aio.secure_channel(
  253. f"{host}:{port}",
  254. ssl_creds,
  255. _options,
  256. compression,
  257. interceptors=[metadata_interceptor],
  258. )
  259. else:
  260. return grpc.aio.insecure_channel(
  261. f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
  262. )