Coverage for mindsdb / integrations / handlers / langchain_embedding_handler / vllm_embeddings.py: 19%
44 statements
Report generated by coverage.py v7.13.1 at 2026-01-21 00:36 +0000
1from typing import Any, List
2from langchain_core.embeddings import Embeddings
3from openai import AsyncOpenAI
4import asyncio
class VLLMEmbeddings(Embeddings):
    """Generate embeddings via a vLLM server exposing the OpenAI-compatible API."""

    # Maximum number of embedding requests awaited concurrently per gather group.
    # (Previously a hard-coded 512 inside _get_embeddings.)
    MAX_CONCURRENCY = 512

    def __init__(
        self,
        openai_api_base: str,
        model: str,
        batch_size: int = 32,
        **kwargs: Any,
    ):
        """Initialize the embeddings class.

        Args:
            openai_api_base: Base URL for the VLLM server
            model: Model name/path to use for embeddings
            batch_size: Batch size for generating embeddings
            **kwargs: Extra arguments forwarded to the AsyncOpenAI client
                (the MindsDB-specific "input_columns" key is stripped first).
        """
        super().__init__()
        self.model = model
        self.batch_size = batch_size
        # nomic-embed models require task-instruction prefixes on inputs.
        self.is_nomic = "nomic-embed-text" in model.lower()

        # Initialize OpenAI client; drop keys the client does not accept.
        openai_kwargs = kwargs.copy()
        openai_kwargs.pop("input_columns", None)

        self.client = AsyncOpenAI(
            api_key="EMPTY",  # vLLM doesn't need an API key
            base_url=openai_api_base,
            **openai_kwargs,
        )

    def _format_text(self, text: str, is_query: bool = False) -> str:
        """
        Format text according to nomic-embed requirements if using nomic model.
        e.g. see here for more details: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#task-instruction-prefixes
        """
        if not self.is_nomic:
            return text
        prefix = "search_query: " if is_query else "search_document: "
        return prefix + text

    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for all texts, batching requests to the server.

        Texts are split into batches of ``self.batch_size``; batches are sent
        concurrently, at most ``MAX_CONCURRENCY`` in flight at a time.

        BUG FIX: the previous implementation reset its coroutine list after
        each 512-wide gather but kept comparing its length against the total
        chunk count, so with more than 512 chunks (not a multiple of 512) the
        final group of coroutines was never awaited and those embeddings were
        silently dropped. It also called asyncio.run once per gather, creating
        a fresh event loop each time. Both are fixed by driving everything
        through a single asyncio.run call.

        Args:
            texts: Already-formatted texts to embed.

        Returns:
            One embedding vector per input text, in input order.
        """
        batches = [
            texts[i: i + self.batch_size]
            for i in range(0, len(texts), self.batch_size)
        ]

        async def _embed_all():
            responses = []
            # Gather in bounded groups so no more than MAX_CONCURRENCY
            # requests are in flight at once.
            for start in range(0, len(batches), self.MAX_CONCURRENCY):
                group = batches[start: start + self.MAX_CONCURRENCY]
                coros = [
                    self.client.embeddings.create(model=self.model, input=batch)
                    for batch in group
                ]
                responses.extend(await asyncio.gather(*coros))
            return responses

        embeddings: List[List[float]] = []
        for response in asyncio.run(_embed_all()):
            embeddings.extend(data.embedding for data in response.data)
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using vLLM.

        Args:
            texts: List of documents to embed

        Returns:
            List of embeddings, one for each document
        """
        formatted_texts = [self._format_text(text) for text in texts]
        return self._get_embeddings(formatted_texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text using vLLM.

        Args:
            text: Query text to embed

        Returns:
            Query embedding
        """
        formatted_text = self._format_text(text, is_query=True)
        return self._get_embeddings([formatted_text])[0]