Coverage for mindsdb / integrations / handlers / langchain_embedding_handler / vllm_embeddings.py: 19%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Any, List 

2from langchain_core.embeddings import Embeddings 

3from openai import AsyncOpenAI 

4import asyncio 

5 

6 

class VLLMEmbeddings(Embeddings):
    """VLLMEmbeddings uses a vLLM server (OpenAI-compatible API) to generate embeddings."""

    def __init__(
        self,
        openai_api_base: str,
        model: str,
        batch_size: int = 32,
        **kwargs: Any,
    ):
        """Initialize the embeddings class.

        Args:
            openai_api_base: Base URL for the VLLM server
            model: Model name/path to use for embeddings
            batch_size: Batch size for generating embeddings
            **kwargs: Extra keyword arguments forwarded to AsyncOpenAI
                (``input_columns`` is stripped out first — it is a
                handler-side argument the client does not accept).
        """
        super().__init__()
        self.model = model
        self.batch_size = batch_size
        # nomic-embed-text models require task-instruction prefixes on inputs
        # (see _format_text).
        self.is_nomic = "nomic-embed-text" in model.lower()

        # Initialize OpenAI client; drop the handler-only kwarg before forwarding.
        openai_kwargs = kwargs.copy()
        openai_kwargs.pop("input_columns", None)

        self.client = AsyncOpenAI(
            api_key="EMPTY",  # vLLM doesn't need an API key
            base_url=openai_api_base,
            **openai_kwargs,
        )

    def _format_text(self, text: str, is_query: bool = False) -> str:
        """
        Format text according to nomic-embed requirements if using nomic model.
        e.g. see here for more details: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#task-instruction-prefixes
        """
        if not self.is_nomic:
            return text
        prefix = "search_query: " if is_query else "search_document: "
        return prefix + text

    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a batch of texts.

        Texts are split into chunks of ``self.batch_size``; up to 512 chunk
        requests are issued concurrently per ``asyncio.gather`` round.

        Returns:
            One embedding vector per input text, in input order.
        """
        max_concurrency = 512

        async def await_openai_call(batch):
            return await self.client.embeddings.create(model=self.model, input=batch)

        def run_pending(coroutines, out_embeddings):
            # Gather all pending request coroutines in one event loop run and
            # append their embeddings (order preserved by asyncio.gather).
            # NOTE(review): asyncio.run spins up a fresh loop per flush while
            # reusing one AsyncOpenAI client — appears to work, but confirm
            # the client tolerates loop recreation across flushes.
            async def gather_all():
                return await asyncio.gather(*coroutines)

            for response in asyncio.run(gather_all()):
                out_embeddings.extend([data.embedding for data in response.data])

        embeddings: List[List[float]] = []
        pending = []
        for i in range(0, len(texts), self.batch_size):
            pending.append(await_openai_call(texts[i: i + self.batch_size]))
            # At max concurrency, run the in-flight requests before queuing more.
            if len(pending) == max_concurrency:
                run_pending(pending, embeddings)
                pending = []

        # BUG FIX: the original flushed only when the pending count equaled the
        # *total* chunk count, so after any 512-sized flush the remaining
        # coroutines were never awaited and their embeddings silently dropped.
        # Always flush whatever is left after the loop.
        if pending:
            run_pending(pending, embeddings)

        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using vLLM.

        Args:
            texts: List of documents to embed

        Returns:
            List of embeddings, one for each document
        """
        formatted_texts = [self._format_text(text) for text in texts]
        return self._get_embeddings(formatted_texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text using vLLM.

        Args:
            text: Query text to embed

        Returns:
            Query embedding
        """
        formatted_text = self._format_text(text, is_query=True)
        return self._get_embeddings([formatted_text])[0]