Coverage for mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py: 28%
96 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
2import importlib
3from typing import Dict, Union
5import pandas as pd
6from pandas import DataFrame
7from pydantic import BaseModel
9from mindsdb.integrations.libs.base import BaseMLEngine
10from mindsdb.utilities import log
11from langchain_core.embeddings import Embeddings
12from mindsdb.integrations.handlers.langchain_embedding_handler.vllm_embeddings import VLLMEmbeddings
13from mindsdb.integrations.handlers.langchain_embedding_handler.fastapi_embeddings import FastAPIEmbeddings
15logger = log.getLogger(__name__)
# construct the embedding model name to the class mapping
# we try to import all embedding models from langchain_community.embeddings
# for each class, we get a more user friendly name for it
# E.g. OpenAIEmbeddings -> OpenAI
# This is used for the user to select the embedding model
# Pre-seeded with the two custom in-repo handlers (VLLM / FastAPI), which do
# not live in langchain_community; the try-block below adds the rest.
EMBEDDING_MODELS: Dict[str, str] = {
    'VLLM': 'VLLMEmbeddings',
    'vllm': 'VLLMEmbeddings',
    'FastAPI': 'FastAPIEmbeddings',
    'fastapi': 'FastAPIEmbeddings'
}
# Dynamically extend EMBEDDING_MODELS with every Embeddings subclass that
# langchain_community.embeddings exports, keyed by a user-friendly name
# (e.g. OpenAIEmbeddings -> OpenAI) and its lowercase variant.
try:
    module = importlib.import_module("langchain_community.embeddings")
    # iterate __all__ to get all the exported classes
    for class_name in module.__all__:
        class_ = getattr(module, class_name)
        # skip exports that are not embedding model classes
        if not issubclass(class_, Embeddings):
            continue
        # convert the class name to a more user friendly name
        # e.g. OpenAIEmbeddings -> OpenAI
        user_friendly_name = class_name.replace("Embeddings", "")
        EMBEDDING_MODELS[user_friendly_name] = class_name
        EMBEDDING_MODELS[user_friendly_name.lower()] = class_name
except ImportError as import_error:
    # BUG FIX: the message previously read "The langchain is not installed",
    # but the package actually imported here is langchain-community. Chain
    # the original ImportError so the root cause stays in the traceback.
    raise Exception(
        "langchain-community is not installed. Please install it with `pip install langchain-community`."
    ) from import_error
def get_langchain_class(class_name: str) -> "Type[Embeddings]":
    """Resolve an embeddings class name to the class object itself.

    Args:
        class_name (str): Name of the embeddings class, e.g. "OpenAIEmbeddings".

    Returns:
        The class object (not an instance) implementing the langchain
        Embeddings interface. NOTE: the annotation is quoted and corrected
        to ``Type[Embeddings]`` — the function returns a class, not an
        instance, which the old ``-> Embeddings`` annotation misstated.

    Raises:
        Exception: if langchain-community is not installed, or if no class
            with this name exists in langchain_community.embeddings.
    """
    # First check the custom in-repo implementations, which are not part of
    # langchain_community and must be short-circuited here.
    if class_name == "VLLMEmbeddings":
        return VLLMEmbeddings
    if class_name == "FastAPIEmbeddings":
        return FastAPIEmbeddings

    # Then fall back to langchain_community.embeddings.
    try:
        module = importlib.import_module("langchain_community.embeddings")
        class_ = getattr(module, class_name)
    except ImportError as import_error:
        # BUG FIX: the message previously suggested `pip install langchain`,
        # inconsistent with the langchain-community module imported here
        # (and with the install hint at module import time).
        raise Exception(
            "langchain-community is not installed. Please install it with `pip install langchain-community`."
        ) from import_error
    except AttributeError as attr_error:
        raise Exception(
            f"Could not find the class {class_name} in langchain_community.embeddings. Please check the class name."
        ) from attr_error
    return class_
def construct_model_from_args(args: Dict) -> Embeddings:
    """
    Instantiate an embeddings model from a stored-arguments dict.

    Pops "target" and "class" out of ``args``, resolves the (possibly
    user-friendly) class name, builds the model from the remaining entries,
    then writes "target"/"class" back so the caller's dict keeps its
    persisted shape — with "class" holding the resolved class name.
    """
    target = args.pop("target", None)
    class_name = args.pop("class", LangchainEmbeddingHandler.DEFAULT_EMBEDDING_CLASS)

    # Translate a user-friendly alias (e.g. "openai") into the real class name.
    resolved_name = EMBEDDING_MODELS.get(class_name)
    if resolved_name is not None:
        logger.info(
            f"Mapping the user friendly name {class_name} to the class name: {resolved_name}"
        )
        class_name = resolved_name

    model_cls = get_langchain_class(class_name)

    # Work on a deep copy so filtering below never touches the caller's dict.
    init_kwargs = copy.deepcopy(args)
    # Pydantic models reject unknown keys, so keep only declared fields.
    if issubclass(model_cls, BaseModel):
        init_kwargs = {
            key: value
            for key, value in init_kwargs.items()
            if key in model_cls.model_fields
        }

    model = model_cls(**init_kwargs)

    # Restore the popped entries for the caller.
    if target is not None:
        args["target"] = target
    args["class"] = class_name
    return model
def row_to_document(row: pd.Series) -> str:
    """
    Render one dataframe row as a single text document.

    Each column becomes a "field: value" line, joined with newlines:
    field1: value1\nfield2: value2\n...
    """
    field_names = row.index.tolist()
    field_values = row.values.tolist()
    lines = [
        f"{name}: {value}"
        for name, value in zip(field_names, field_values)
    ]
    return "\n".join(lines)
class LangchainEmbeddingHandler(BaseMLEngine):
    """
    Bridge class to connect the langchain embeddings module to MindsDB.

    ``create`` validates and persists the model-construction arguments;
    ``predict`` rebuilds the model from storage and embeds each input row
    into a new target column.
    """

    # Used when the user does not pass a "class" argument.
    DEFAULT_EMBEDDING_CLASS = "OpenAIEmbeddings"

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        super().__init__(model_storage, engine_storage, **kwargs)
        # embeddings are generated content, so mark the engine as generative
        self.generative = True

    def create(
        self,
        target: str,
        df: Union[DataFrame, None] = None,
        args: Union[Dict, None] = None,
    ) -> None:
        """Validate the embedding-model args and persist them to storage.

        Args:
            target: Name of the output column for the embeddings;
                defaults to "embeddings" when falsy.
            df: Optional input dataframe, used only to infer
                ``input_columns`` when the user did not provide them.
            args: Statement args; the model options live under "using".
        """
        # BUG FIX: ``args`` defaults to None per the signature, but the
        # original called args.get(...) unconditionally, which raised
        # AttributeError when create() was invoked without args.
        user_args = (args or {}).get("using", {})

        # infer the input_columns arg from the dataframe columns
        # if the user did not provide it
        if "input_columns" not in user_args and df is not None:
            # ignore private columns starting with __mindsdb
            # and the target column itself
            user_args["input_columns"] = [
                col
                for col in df.columns.tolist()
                if not col.startswith("__mindsdb") and col != target
            ]
            # unquote the column names -- removing surrounding `
            user_args["input_columns"] = [
                col.strip("`") for col in user_args["input_columns"]
            ]
        elif "input_columns" not in user_args:
            # an empty list means "use all columns" at predict time
            user_args["input_columns"] = []

        # This may raise if the arguments are insufficient to construct the
        # class (e.g. a missing API key); the validation logic itself is
        # handled by langchain and pydantic.
        construct_model_from_args(user_args)

        # save the args; "target" names the column that stores the embeddings
        target = target or "embeddings"
        user_args["target"] = target
        self.model_storage.json_set("args", user_args)

    def predict(self, df: DataFrame, args) -> DataFrame:
        """Embed each row of ``df``; return a copy with the target column added.

        Raises:
            Exception: if any configured input column is missing from ``df``.
        """
        # reconstruct the model from the stored args
        user_args = self.model_storage.json_get("args")
        model = construct_model_from_args(user_args)

        # the name of the embeddings output column
        target = user_args["target"]

        # TODO: need a better way to handle this
        # unquote the column names -- removing surrounding `
        # NOTE(review): this rebinds the caller's dataframe columns in place
        cols_dfs = [col.strip("`") for col in df.columns.tolist()]
        df.columns = cols_dfs

        # an empty/missing input_columns list means "use every column"
        input_columns = user_args.get("input_columns") or df.columns.tolist()

        # check all the input columns are present in the dataframe
        missing = [col for col in input_columns if col not in cols_dfs]
        if missing:
            raise Exception(
                f"Input columns {input_columns} not found in the input dataframe. Available columns are {df.columns}"
            )

        # convert each row into a document and embed the whole batch at once
        df_texts = df[input_columns].apply(row_to_document, axis=1)
        embeddings = model.embed_documents(df_texts.tolist())

        # return a copy of the input with the embeddings column appended
        return df.copy().assign(**{target: embeddings})

    def finetune(
        self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
    ) -> None:
        """Finetuning is not applicable to embedding models."""
        raise NotImplementedError(
            "Finetuning is not supported for langchain embeddings"
        )

    def describe(self, attribute: Union[str, None] = None) -> DataFrame:
        """Return model details as a dataframe.

        Args:
            attribute: "args" for the stored creation args, "metadata" for
                the stored model class, anything else lists the available
                describe tables.
        """
        args = self.model_storage.json_get("args")

        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])
        elif attribute == "metadata":
            return pd.DataFrame(
                [
                    ("model_class", self.model_storage.json_get("model_class")),
                ],
                columns=["key", "value"],
            )
        else:
            tables = ("args", "metadata")
            return pd.DataFrame(tables, columns=["tables"])