| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387 |
- from collections import defaultdict
- from typing import Optional, Sequence, Any, TypeVar, Generic
- from pydantic import BaseModel
- from qdrant_client.http import models
- from qdrant_client.embed.models import NumericVector
- from qdrant_client.fastembed_common import (
- OnnxProvider,
- ImageInput,
- TextEmbedding,
- SparseTextEmbedding,
- LateInteractionTextEmbedding,
- LateInteractionMultimodalEmbedding,
- ImageEmbedding,
- FastEmbedMisc,
- )
# Type parameter standing for the concrete embedding-model class held by a ``ModelInstance``.
T = TypeVar("T")


class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True):  # type: ignore[call-arg]
    """Cache record that pairs a model object with the options stored alongside it.

    ``arbitrary_types_allowed=True`` is passed as pydantic model configuration so that
    ``model`` may be typed with a class that is not itself a pydantic type.
    """

    # The wrapped model object.
    model: T
    # Keyword options recorded for this instance; compared against requested options
    # by code that looks instances up.
    options: dict[str, Any]
    # Marker flag stored with the instance; defaults to False.
    deprecated: bool = False
class Embedder:
    """Create, cache and run FastEmbed models for text and image inputs.

    One registry is kept per model family (dense text, sparse text, late-interaction
    text, image, late-interaction multimodal). Each registry maps a model name to the
    list of ``ModelInstance`` records created for that name, so a model is only
    constructed again when it is requested with options that differ from every
    cached record.
    """

    def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None:
        """Set up empty model registries.

        Args:
            threads: Fallback thread count used whenever a ``get_or_init_*`` call is
                made with a falsy ``threads`` value.
            **kwargs: Accepted for call-site compatibility; not used by this constructor.
        """
        self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list)
        self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = (
            defaultdict(list)
        )
        self.late_interaction_embedding_models: dict[
            str, list[ModelInstance[LateInteractionTextEmbedding]]
        ] = defaultdict(list)
        self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict(
            list
        )
        self.late_interaction_multimodal_embedding_models: dict[
            str, list[ModelInstance[LateInteractionMultimodalEmbedding]]
        ] = defaultdict(list)
        self._threads = threads

    # ------------------------------------------------------------------ helpers

    def _build_options(
        self,
        cache_dir: Optional[str],
        threads: Optional[int],
        providers: Optional[Sequence["OnnxProvider"]],
        cuda: bool,
        device_ids: Optional[list[int]],
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """Assemble the constructor options dict that also serves as the cache key.

        A falsy ``threads`` falls back to the instance-wide default given to ``__init__``.
        Entries of ``extra`` are merged last.
        """
        return {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **extra,
        }

    @staticmethod
    def _find_cached(
        instances: list[ModelInstance[T]],
        options: dict[str, Any],
        deprecated: bool = False,
    ) -> Optional[T]:
        """Return the first cached model matching the request, or ``None``.

        With ``deprecated`` set, the first record flagged deprecated matches and its
        options are not compared. Otherwise a record matches when its stored options
        equal ``options``.
        """
        for instance in instances:
            if deprecated:
                if instance.deprecated:
                    return instance.model
            elif instance.options == options:
                return instance.model
        return None

    # ------------------------------------------------------- model acquisition

    def get_or_init_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> TextEmbedding:
        """Return a cached dense text model or construct and cache a new one.

        Args:
            model_name: Name checked against ``FastEmbedMisc.is_supported_text_model``.
            cache_dir: Forwarded to the ``TextEmbedding`` constructor.
            threads: Forwarded to the constructor; a falsy value falls back to the
                default passed to ``__init__``.
            providers: Forwarded to the constructor.
            cuda: Forwarded to the constructor.
            device_ids: Forwarded to the constructor.
            deprecated: When true, reuse the first cached instance flagged deprecated
                regardless of its options; a newly created instance is stored with
                this flag.
            **kwargs: Extra constructor options; they take part in cache matching.

        Returns:
            The matching or newly created ``TextEmbedding``.

        Raises:
            ValueError: If ``model_name`` is not a supported dense text model.
        """
        if not FastEmbedMisc.is_supported_text_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_text_models()}"
            )

        options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
        instances = self.embedding_models[model_name]
        cached = self._find_cached(instances, options, deprecated)
        if cached is not None:
            return cached

        model = TextEmbedding(model_name=model_name, **options)
        record: ModelInstance[TextEmbedding] = ModelInstance(
            model=model, options=options, deprecated=deprecated
        )
        instances.append(record)
        return model

    def get_or_init_sparse_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> SparseTextEmbedding:
        """Return a cached sparse text model or construct and cache a new one.

        Arguments other than ``model_name`` and ``deprecated`` are forwarded to the
        ``SparseTextEmbedding`` constructor (``threads`` falls back to the default
        passed to ``__init__`` when falsy). ``deprecated`` has the same meaning as in
        ``get_or_init_model``: when true, the first cached instance flagged deprecated
        is reused regardless of options.

        Raises:
            ValueError: If ``model_name`` is not a supported sparse model.
        """
        if not FastEmbedMisc.is_supported_sparse_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_sparse_models()}"
            )

        options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
        instances = self.sparse_embedding_models[model_name]
        cached = self._find_cached(instances, options, deprecated)
        if cached is not None:
            return cached

        model = SparseTextEmbedding(model_name=model_name, **options)
        record: ModelInstance[SparseTextEmbedding] = ModelInstance(
            model=model, options=options, deprecated=deprecated
        )
        instances.append(record)
        return model

    def get_or_init_late_interaction_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> LateInteractionTextEmbedding:
        """Return a cached late-interaction text model or construct and cache a new one.

        Arguments other than ``model_name`` are forwarded to the
        ``LateInteractionTextEmbedding`` constructor (``threads`` falls back to the
        default passed to ``__init__`` when falsy). A cached instance is reused only
        when its stored options equal the requested ones.

        Raises:
            ValueError: If ``model_name`` is not a supported late-interaction text model.
        """
        if not FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. "
                f"Supported models: {FastEmbedMisc.list_late_interaction_text_models()}"
            )

        options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
        instances = self.late_interaction_embedding_models[model_name]
        cached = self._find_cached(instances, options)
        if cached is not None:
            return cached

        model = LateInteractionTextEmbedding(model_name=model_name, **options)
        record: ModelInstance[LateInteractionTextEmbedding] = ModelInstance(
            model=model, options=options
        )
        instances.append(record)
        return model

    def get_or_init_late_interaction_multimodal_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> LateInteractionMultimodalEmbedding:
        """Return a cached late-interaction multimodal model or construct a new one.

        Arguments other than ``model_name`` are forwarded to the
        ``LateInteractionMultimodalEmbedding`` constructor (``threads`` falls back to
        the default passed to ``__init__`` when falsy). A cached instance is reused
        only when its stored options equal the requested ones.

        Raises:
            ValueError: If ``model_name`` is not a supported late-interaction
                multimodal model.
        """
        if not FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. "
                f"Supported models: {FastEmbedMisc.list_late_interaction_multimodal_models()}"
            )

        options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
        instances = self.late_interaction_multimodal_embedding_models[model_name]
        cached = self._find_cached(instances, options)
        if cached is not None:
            return cached

        model = LateInteractionMultimodalEmbedding(model_name=model_name, **options)
        record: ModelInstance[LateInteractionMultimodalEmbedding] = ModelInstance(
            model=model, options=options
        )
        instances.append(record)
        return model

    def get_or_init_image_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> ImageEmbedding:
        """Return a cached image model or construct and cache a new one.

        Arguments other than ``model_name`` are forwarded to the ``ImageEmbedding``
        constructor (``threads`` falls back to the default passed to ``__init__`` when
        falsy). A cached instance is reused only when its stored options equal the
        requested ones.

        Raises:
            ValueError: If ``model_name`` is not a supported image model.
        """
        if not FastEmbedMisc.is_supported_image_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_image_models()}"
            )

        options = self._build_options(cache_dir, threads, providers, cuda, device_ids, kwargs)
        instances = self.image_embedding_models[model_name]
        cached = self._find_cached(instances, options)
        if cached is not None:
            return cached

        model = ImageEmbedding(model_name=model_name, **options)
        record: ModelInstance[ImageEmbedding] = ModelInstance(model=model, options=options)
        instances.append(record)
        return model

    # ---------------------------------------------------------------- embedding

    def embed(
        self,
        model_name: str,
        texts: Optional[list[str]] = None,
        images: Optional[list[ImageInput]] = None,
        options: Optional[dict[str, Any]] = None,
        is_query: bool = False,
        batch_size: int = 8,
    ) -> NumericVector:
        """Embed either texts or images with the model family ``model_name`` belongs to.

        Exactly one of ``texts`` and ``images`` must be given. For texts the model
        name is tried as dense, then sparse, then late-interaction text, then
        late-interaction multimodal; for images as dense image, then late-interaction
        multimodal.

        Args:
            model_name: Name of the model to use.
            texts: Texts to embed.
            images: Images to embed.
            options: Keyword options forwarded to the matching ``get_or_init_*`` call.
            is_query: Use the model's query-embedding path. Only passed on for dense,
                sparse and late-interaction text models.
            batch_size: Batch size passed to the model's document/image embedding call.

        Returns:
            The embeddings produced by the selected family-specific helper.

        Raises:
            ValueError: If both or neither of ``texts``/``images`` are given, or if
                ``model_name`` is not supported for the supplied input kind.
        """
        # Both-None and both-given are rejected alike: the two flags must differ.
        if (texts is None) is (images is None):
            raise ValueError("Either documents or images should be provided")

        if texts is not None:
            if FastEmbedMisc.is_supported_text_model(model_name):
                return self._embed_dense_text(texts, model_name, options, is_query, batch_size)
            if FastEmbedMisc.is_supported_sparse_model(model_name):
                return self._embed_sparse_text(texts, model_name, options, is_query, batch_size)
            if FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
                return self._embed_late_interaction_text(
                    texts, model_name, options, is_query, batch_size
                )
            if FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
                return self._embed_late_interaction_multimodal_text(
                    texts, model_name, options, batch_size
                )
            raise ValueError(f"Unsupported embedding model: {model_name}")

        # Guaranteed by the check above; stated only so mypy can narrow the type.
        assert images is not None
        if FastEmbedMisc.is_supported_image_model(model_name):
            return self._embed_dense_image(images, model_name, options, batch_size)
        if FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
            return self._embed_late_interaction_multimodal_image(
                images, model_name, options, batch_size
            )
        raise ValueError(f"Unsupported embedding model: {model_name}")

    def _embed_dense_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[list[float]]:
        """Embed texts with a dense text model and convert each result via ``tolist()``.

        ``batch_size`` is only passed on the document path; the query path calls
        ``query_embed`` without it.
        """
        model = self.get_or_init_model(model_name=model_name, **(options or {}))
        raw = (
            model.query_embed(query=texts)
            if is_query
            else model.embed(documents=texts, batch_size=batch_size)
        )
        return [vector.tolist() for vector in raw]

    def _embed_sparse_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[models.SparseVector]:
        """Embed texts with a sparse model and wrap each result in ``models.SparseVector``.

        ``batch_size`` is only passed on the document path; the query path calls
        ``query_embed`` without it.
        """
        model = self.get_or_init_sparse_model(model_name=model_name, **(options or {}))
        raw = (
            model.query_embed(query=texts)
            if is_query
            else model.embed(documents=texts, batch_size=batch_size)
        )
        return [
            models.SparseVector(
                indices=sparse.indices.tolist(),
                values=sparse.values.tolist(),
            )
            for sparse in raw
        ]

    def _embed_late_interaction_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[list[list[float]]]:
        """Embed texts with a late-interaction text model; results go through ``tolist()``.

        ``batch_size`` is only passed on the document path; the query path calls
        ``query_embed`` without it.
        """
        model = self.get_or_init_late_interaction_model(
            model_name=model_name, **(options or {})
        )
        raw = (
            model.query_embed(query=texts)
            if is_query
            else model.embed(documents=texts, batch_size=batch_size)
        )
        return [matrix.tolist() for matrix in raw]

    def _embed_late_interaction_multimodal_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[list[float]]]:
        """Embed texts through a late-interaction multimodal model's ``embed_text``."""
        model = self.get_or_init_late_interaction_multimodal_model(
            model_name=model_name, **(options or {})
        )
        raw = model.embed_text(documents=texts, batch_size=batch_size)
        return [matrix.tolist() for matrix in raw]

    def _embed_late_interaction_multimodal_image(
        self,
        images: list[ImageInput],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[list[float]]]:
        """Embed images through a late-interaction multimodal model's ``embed_image``."""
        model = self.get_or_init_late_interaction_multimodal_model(
            model_name=model_name, **(options or {})
        )
        raw = model.embed_image(images=images, batch_size=batch_size)
        return [matrix.tolist() for matrix in raw]

    def _embed_dense_image(
        self,
        images: list[ImageInput],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[float]]:
        """Embed images with a dense image model and convert each result via ``tolist()``."""
        model = self.get_or_init_image_model(model_name=model_name, **(options or {}))
        raw = model.embed(images=images, batch_size=batch_size)
        return [vector.tolist() for vector in raw]
|