embedder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. from collections import defaultdict
  2. from typing import Optional, Sequence, Any, TypeVar, Generic
  3. from pydantic import BaseModel
  4. from qdrant_client.http import models
  5. from qdrant_client.embed.models import NumericVector
  6. from qdrant_client.fastembed_common import (
  7. OnnxProvider,
  8. ImageInput,
  9. TextEmbedding,
  10. SparseTextEmbedding,
  11. LateInteractionTextEmbedding,
  12. LateInteractionMultimodalEmbedding,
  13. ImageEmbedding,
  14. FastEmbedMisc,
  15. )
  # Type variable standing for the concrete model class stored in a ModelInstance.
  T = TypeVar("T")


  class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True):  # type: ignore[call-arg]
      """Record pairing a model object with the options stored alongside it.

      ``arbitrary_types_allowed=True`` is passed as class configuration so that
      the ``model`` field may hold an object of a non-pydantic type.
      """

      # The wrapped model object; its static type is the generic parameter T.
      model: T
      # Keyword options stored with the model; Embedder compares these to decide on reuse.
      options: dict[str, Any]
      # Flag stored with the entry; False unless given explicitly.
      deprecated: bool = False
  class Embedder:
      """Create, cache and run embedding models on demand.

      Every model family (dense text, sparse text, late interaction text,
      late interaction multimodal, image) has its own cache: a mapping from
      model name to the list of instances created so far, each stored together
      with the options it was built with. A request whose options equal those
      of a cached instance gets that instance back instead of a new one.
      """

      def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None:
          """Create empty caches for every model family.

          Args:
              threads: Default thread count put into a model's options when a
                  ``get_or_init_*`` call does not provide a truthy ``threads``.
              **kwargs: Accepted but not used by this constructor.
          """
          self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list)
          self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = (
              defaultdict(list)
          )
          self.late_interaction_embedding_models: dict[
              str, list[ModelInstance[LateInteractionTextEmbedding]]
          ] = defaultdict(list)
          self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict(
              list
          )
          self.late_interaction_multimodal_embedding_models: dict[
              str, list[ModelInstance[LateInteractionMultimodalEmbedding]]
          ] = defaultdict(list)
          self._threads = threads

      def _build_options(
          self,
          cache_dir: Optional[str],
          threads: Optional[int],
          providers: Optional[Sequence["OnnxProvider"]],
          cuda: bool,
          device_ids: Optional[list[int]],
          extra: dict[str, Any],
      ) -> dict[str, Any]:
          """Assemble the options dict used both to build a model and as its cache key."""
          return {
              "cache_dir": cache_dir,
              # A falsy ``threads`` (None or 0) falls back to the instance-wide default.
              "threads": threads or self._threads,
              "providers": providers,
              "cuda": cuda,
              "device_ids": device_ids,
              **extra,
          }

      @staticmethod
      def _get_or_init_instance(
          instances: "list[ModelInstance[T]]",
          model_cls: Any,
          model_name: str,
          options: dict[str, Any],
          deprecated: bool = False,
      ) -> "T":
          """Return a cached model satisfying the request, building one if none does.

          With ``deprecated`` set, the first cached instance flagged as deprecated
          is returned without comparing options. Otherwise the first cached
          instance whose stored options equal ``options`` is returned. When nothing
          matches, ``model_cls(model_name=model_name, **options)`` is constructed,
          appended to ``instances`` and returned.
          """
          for instance in instances:
              matches = instance.deprecated if deprecated else instance.options == options
              if matches:
                  return instance.model

          model: T = model_cls(model_name=model_name, **options)
          instances.append(ModelInstance(model=model, options=options, deprecated=deprecated))
          return model

      def get_or_init_model(
          self,
          model_name: str,
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          providers: Optional[Sequence["OnnxProvider"]] = None,
          cuda: bool = False,
          device_ids: Optional[list[int]] = None,
          deprecated: bool = False,
          **kwargs: Any,
      ) -> TextEmbedding:
          """Return a cached ``TextEmbedding`` for ``model_name`` or create one.

          Args:
              model_name: Name of the dense text model.
              cache_dir: Stored in the options under ``"cache_dir"``.
              threads: Stored under ``"threads"``; a falsy value falls back to the
                  default given to the constructor.
              providers: Stored under ``"providers"``.
              cuda: Stored under ``"cuda"``.
              device_ids: Stored under ``"device_ids"``.
              deprecated: When True, any cached instance flagged deprecated is
                  reused regardless of its options, and a newly created instance
                  is flagged deprecated. When False, reuse requires equal options.
              **kwargs: Extra options merged into the options dict and forwarded
                  to the model constructor.

          Returns:
              The cached or newly created model.

          Raises:
              ValueError: If ``FastEmbedMisc`` does not report ``model_name`` as a
                  supported text model.
          """
          if not FastEmbedMisc.is_supported_text_model(model_name):
              raise ValueError(
                  f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_text_models()}"
              )
          options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
          return self._get_or_init_instance(
              self.embedding_models[model_name], TextEmbedding, model_name, options, deprecated
          )

      def get_or_init_sparse_model(
          self,
          model_name: str,
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          providers: Optional[Sequence["OnnxProvider"]] = None,
          cuda: bool = False,
          device_ids: Optional[list[int]] = None,
          deprecated: bool = False,
          **kwargs: Any,
      ) -> SparseTextEmbedding:
          """Return a cached ``SparseTextEmbedding`` for ``model_name`` or create one.

          Parameters have the same meaning as in ``get_or_init_model``, including
          the ``deprecated`` reuse rule.

          Raises:
              ValueError: If ``FastEmbedMisc`` does not report ``model_name`` as a
                  supported sparse model.
          """
          if not FastEmbedMisc.is_supported_sparse_model(model_name):
              raise ValueError(
                  f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_sparse_models()}"
              )
          options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
          return self._get_or_init_instance(
              self.sparse_embedding_models[model_name],
              SparseTextEmbedding,
              model_name,
              options,
              deprecated,
          )

      def get_or_init_late_interaction_model(
          self,
          model_name: str,
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          providers: Optional[Sequence["OnnxProvider"]] = None,
          cuda: bool = False,
          device_ids: Optional[list[int]] = None,
          **kwargs: Any,
      ) -> LateInteractionTextEmbedding:
          """Return a cached ``LateInteractionTextEmbedding`` or create one.

          A cached instance is reused only when its stored options equal the
          options assembled from the arguments.

          Raises:
              ValueError: If ``FastEmbedMisc`` does not report ``model_name`` as a
                  supported late interaction text model.
          """
          if not FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
              raise ValueError(
                  f"Unsupported embedding model: {model_name}. "
                  f"Supported models: {FastEmbedMisc.list_late_interaction_text_models()}"
              )
          options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
          return self._get_or_init_instance(
              self.late_interaction_embedding_models[model_name],
              LateInteractionTextEmbedding,
              model_name,
              options,
          )

      def get_or_init_late_interaction_multimodal_model(
          self,
          model_name: str,
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          providers: Optional[Sequence["OnnxProvider"]] = None,
          cuda: bool = False,
          device_ids: Optional[list[int]] = None,
          **kwargs: Any,
      ) -> LateInteractionMultimodalEmbedding:
          """Return a cached ``LateInteractionMultimodalEmbedding`` or create one.

          A cached instance is reused only when its stored options equal the
          options assembled from the arguments.

          Raises:
              ValueError: If ``FastEmbedMisc`` does not report ``model_name`` as a
                  supported late interaction multimodal model.
          """
          if not FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
              raise ValueError(
                  f"Unsupported embedding model: {model_name}. "
                  f"Supported models: {FastEmbedMisc.list_late_interaction_multimodal_models()}"
              )
          options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
          return self._get_or_init_instance(
              self.late_interaction_multimodal_embedding_models[model_name],
              LateInteractionMultimodalEmbedding,
              model_name,
              options,
          )

      def get_or_init_image_model(
          self,
          model_name: str,
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          providers: Optional[Sequence["OnnxProvider"]] = None,
          cuda: bool = False,
          device_ids: Optional[list[int]] = None,
          **kwargs: Any,
      ) -> ImageEmbedding:
          """Return a cached ``ImageEmbedding`` for ``model_name`` or create one.

          A cached instance is reused only when its stored options equal the
          options assembled from the arguments.

          Raises:
              ValueError: If ``FastEmbedMisc`` does not report ``model_name`` as a
                  supported image model.
          """
          if not FastEmbedMisc.is_supported_image_model(model_name):
              raise ValueError(
                  f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_image_models()}"
              )
          options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
          return self._get_or_init_instance(
              self.image_embedding_models[model_name], ImageEmbedding, model_name, options
          )

      def embed(
          self,
          model_name: str,
          texts: Optional[list[str]] = None,
          images: Optional[list[ImageInput]] = None,
          options: Optional[dict[str, Any]] = None,
          is_query: bool = False,
          batch_size: int = 8,
      ) -> NumericVector:
          """Embed texts or images with the model family that supports ``model_name``.

          Exactly one of ``texts`` and ``images`` must be given. For texts the
          families are probed in this order: dense, sparse, late interaction,
          late interaction multimodal. For images: dense image, then late
          interaction multimodal.

          Args:
              model_name: Name of the model to use.
              texts: Texts to embed.
              images: Images to embed.
              options: Keyword arguments forwarded to the matching
                  ``get_or_init_*`` method; ``None`` means no extra options.
              is_query: Use the model's ``query_embed`` instead of ``embed``.
                  Only forwarded for dense, sparse and late interaction text
                  models.
              batch_size: Batch size forwarded to the model's document/image
                  embedding call.

          Returns:
              The embeddings produced by the selected helper.

          Raises:
              ValueError: If both or neither of ``texts`` and ``images`` are
                  given, or if no family supports ``model_name`` for the given
                  kind of input.
          """
          # True when both are None or both are set; exactly one input kind is required.
          if (texts is None) is (images is None):
              raise ValueError("Either documents or images should be provided")

          if texts is not None:
              if FastEmbedMisc.is_supported_text_model(model_name):
                  return self._embed_dense_text(texts, model_name, options, is_query, batch_size)
              if FastEmbedMisc.is_supported_sparse_model(model_name):
                  return self._embed_sparse_text(texts, model_name, options, is_query, batch_size)
              if FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
                  return self._embed_late_interaction_text(
                      texts, model_name, options, is_query, batch_size
                  )
              if FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
                  return self._embed_late_interaction_multimodal_text(
                      texts, model_name, options, batch_size
                  )
              raise ValueError(f"Unsupported embedding model: {model_name}")

          # Only here to let a static type checker narrow ``images``; the check
          # above already guarantees it.
          assert images is not None
          if FastEmbedMisc.is_supported_image_model(model_name):
              return self._embed_dense_image(images, model_name, options, batch_size)
          if FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
              return self._embed_late_interaction_multimodal_image(
                  images, model_name, options, batch_size
              )
          raise ValueError(f"Unsupported embedding model: {model_name}")

      def _embed_dense_text(
          self,
          texts: list[str],
          model_name: str,
          options: Optional[dict[str, Any]],
          is_query: bool,
          batch_size: int,
      ) -> list[list[float]]:
          """Embed texts with a dense text model and convert each result via ``tolist()``.

          ``batch_size`` is forwarded only to ``embed``; ``query_embed`` is called
          with the texts alone.
          """
          model = self.get_or_init_model(model_name=model_name, **(options or {}))
          vectors = (
              model.query_embed(query=texts)
              if is_query
              else model.embed(documents=texts, batch_size=batch_size)
          )
          return [vector.tolist() for vector in vectors]

      def _embed_sparse_text(
          self,
          texts: list[str],
          model_name: str,
          options: Optional[dict[str, Any]],
          is_query: bool,
          batch_size: int,
      ) -> list[models.SparseVector]:
          """Embed texts with a sparse model and wrap each result in ``models.SparseVector``.

          ``batch_size`` is forwarded only to ``embed``; ``query_embed`` is called
          with the texts alone.
          """
          model = self.get_or_init_sparse_model(model_name=model_name, **(options or {}))
          sparse_embeddings = (
              model.query_embed(query=texts)
              if is_query
              else model.embed(documents=texts, batch_size=batch_size)
          )
          return [
              models.SparseVector(
                  indices=sparse_embedding.indices.tolist(),
                  values=sparse_embedding.values.tolist(),
              )
              for sparse_embedding in sparse_embeddings
          ]

      def _embed_late_interaction_text(
          self,
          texts: list[str],
          model_name: str,
          options: Optional[dict[str, Any]],
          is_query: bool,
          batch_size: int,
      ) -> list[list[list[float]]]:
          """Embed texts with a late interaction text model, converting via ``tolist()``.

          ``batch_size`` is forwarded only to ``embed``; ``query_embed`` is called
          with the texts alone.
          """
          model = self.get_or_init_late_interaction_model(
              model_name=model_name, **(options or {})
          )
          matrices = (
              model.query_embed(query=texts)
              if is_query
              else model.embed(documents=texts, batch_size=batch_size)
          )
          return [matrix.tolist() for matrix in matrices]

      def _embed_late_interaction_multimodal_text(
          self,
          texts: list[str],
          model_name: str,
          options: Optional[dict[str, Any]],
          batch_size: int,
      ) -> list[list[list[float]]]:
          """Embed texts via the multimodal model's ``embed_text``, converting via ``tolist()``."""
          model = self.get_or_init_late_interaction_multimodal_model(
              model_name=model_name, **(options or {})
          )
          return [
              matrix.tolist()
              for matrix in model.embed_text(documents=texts, batch_size=batch_size)
          ]

      def _embed_late_interaction_multimodal_image(
          self,
          images: list[ImageInput],
          model_name: str,
          options: Optional[dict[str, Any]],
          batch_size: int,
      ) -> list[list[list[float]]]:
          """Embed images via the multimodal model's ``embed_image``, converting via ``tolist()``."""
          model = self.get_or_init_late_interaction_multimodal_model(
              model_name=model_name, **(options or {})
          )
          return [
              matrix.tolist()
              for matrix in model.embed_image(images=images, batch_size=batch_size)
          ]

      def _embed_dense_image(
          self,
          images: list[ImageInput],
          model_name: str,
          options: Optional[dict[str, Any]],
          batch_size: int,
      ) -> list[list[float]]:
          """Embed images with a dense image model and convert each result via ``tolist()``."""
          model = self.get_or_init_image_model(model_name=model_name, **(options or {}))
          return [
              vector.tolist() for vector in model.embed(images=images, batch_size=batch_size)
          ]