Coverage for mindsdb / integrations / handlers / langchain_embedding_handler / langchain_embedding_handler.py: 28%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import copy 

2import importlib 

3from typing import Dict, Union 

4 

5import pandas as pd 

6from pandas import DataFrame 

7from pydantic import BaseModel 

8 

9from mindsdb.integrations.libs.base import BaseMLEngine 

10from mindsdb.utilities import log 

11from langchain_core.embeddings import Embeddings 

12from mindsdb.integrations.handlers.langchain_embedding_handler.vllm_embeddings import VLLMEmbeddings 

13from mindsdb.integrations.handlers.langchain_embedding_handler.fastapi_embeddings import FastAPIEmbeddings 

14 

logger = log.getLogger(__name__)

# Mapping from user-friendly embedding model names to langchain class names.
# E.g. OpenAIEmbeddings is exposed to the user as "OpenAI" (and "openai").
# This is used for the user to select the embedding model.
# The VLLM/FastAPI entries are registered explicitly because they are
# MindsDB-local classes, not part of langchain_community.
EMBEDDING_MODELS = {
    'VLLM': 'VLLMEmbeddings',
    'vllm': 'VLLMEmbeddings',
    'FastAPI': 'FastAPIEmbeddings',
    'fastapi': 'FastAPIEmbeddings',
}

try:
    module = importlib.import_module("langchain_community.embeddings")
    # iterate __all__ to register every exported embedding class
    for class_name in module.__all__:
        class_ = getattr(module, class_name)
        # Guard with isinstance: __all__ may export non-class objects, and
        # issubclass() raises TypeError when given a non-class argument.
        if not isinstance(class_, type) or not issubclass(class_, Embeddings):
            continue
        # convert the class name to a more user friendly name
        # e.g. OpenAIEmbeddings -> OpenAI; register both the alias and
        # its lowercase form for case-insensitive selection
        user_friendly_name = class_name.replace("Embeddings", "")
        EMBEDDING_MODELS[user_friendly_name] = class_name
        EMBEDDING_MODELS[user_friendly_name.lower()] = class_name

except ImportError as e:
    # Chain the original ImportError so the real cause stays visible.
    raise Exception(
        "The langchain is not installed. Please install it with `pip install langchain-community`."
    ) from e

47 

48 

def get_langchain_class(class_name: str) -> Embeddings:
    """Returns the class object of the handler class.

    Resolution order: MindsDB-local classes (VLLM, FastAPI) first, then
    ``langchain_community.embeddings``.

    Args:
        class_name (str): Name of the class, e.g. "OpenAIEmbeddings"

    Returns:
        langchain.embeddings.BaseEmbedding: The class object (not an instance)

    Raises:
        Exception: If langchain_community is not installed, or the class
            name does not exist in langchain_community.embeddings.
    """
    # First check our custom classes, which are not part of langchain_community
    if class_name == "VLLMEmbeddings":
        return VLLMEmbeddings

    if class_name == "FastAPIEmbeddings":
        return FastAPIEmbeddings

    # Then try langchain_community.embeddings
    try:
        module = importlib.import_module("langchain_community.embeddings")
        class_ = getattr(module, class_name)
    except ImportError as e:
        # The required package is langchain-community (consistent with the
        # module-level install hint); chain the cause for debuggability.
        raise Exception(
            "The langchain is not installed. Please install it with `pip install langchain-community`."
        ) from e
    except AttributeError as e:
        raise Exception(
            f"Could not find the class {class_name} in langchain_community.embeddings. Please check the class name."
        ) from e
    return class_

78 

79 

def construct_model_from_args(args: Dict) -> Embeddings:
    """
    Build an embeddings model instance from a stored argument dict.

    Pops "target" and "class" out of ``args`` before constructing the model,
    then writes them back afterwards ("class" resolved to its canonical
    langchain class name) so the caller's dict remains storage-ready.
    """
    target = args.pop("target", None)
    class_name = args.pop("class", LangchainEmbeddingHandler.DEFAULT_EMBEDDING_CLASS)

    # Resolve a user-friendly alias (e.g. "OpenAI") to the real class name.
    if class_name in EMBEDDING_MODELS:
        logger.info(
            f"Mapping the user friendly name {class_name} to the class name: {EMBEDDING_MODELS[class_name]}"
        )
        class_name = EMBEDDING_MODELS[class_name]

    model_class = get_langchain_class(class_name)
    ctor_kwargs = copy.deepcopy(args)

    # Pydantic models reject unknown keyword arguments, so pass only the
    # fields the model class actually declares.
    if issubclass(model_class, BaseModel):
        ctor_kwargs = {
            key: value
            for key, value in ctor_kwargs.items()
            if key in model_class.model_fields
        }

    model = model_class(**ctor_kwargs)

    # Restore the popped keys for the caller; "class" is now canonical.
    if target is not None:
        args["target"] = target
    args["class"] = class_name
    return model

105 

106 

def row_to_document(row: pd.Series) -> str:
    """
    Render a dataframe row as a single text document.

    Each column becomes one "field: value" line, concatenated with
    newlines, e.g.::

        field1: value1
        field2: value2
    """
    lines = [f"{field}: {value}" for field, value in row.items()]
    return "\n".join(lines)

121 

122 

class LangchainEmbeddingHandler(BaseMLEngine):
    """
    Bridge class to connect langchain.embeddings module to mindsDB
    """

    # Used when the user does not specify a "class" argument.
    DEFAULT_EMBEDDING_CLASS = "OpenAIEmbeddings"

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        super().__init__(model_storage, engine_storage, **kwargs)
        self.generative = True

    def create(
        self,
        target: str,
        df: Union[DataFrame, None] = None,
        args: Union[Dict, None] = None,
    ) -> None:
        """Validate the embedding model arguments and persist them.

        Args:
            target: Name of the output column to store embeddings in.
            df: Optional dataframe, used only to infer ``input_columns``
                when the user does not provide them.
            args: CREATE MODEL arguments; the relevant part is args["using"].
        """
        # get the class name from the args; tolerate args=None, which the
        # signature explicitly allows (the original crashed on args.get)
        user_args = (args or {}).get("using", {})

        # infer the input columns arg if user did not provide it
        # from the columns of the input dataframe if it is provided
        if "input_columns" not in user_args and df is not None:
            # ignore private columns starting with __mindsdb
            # ignore target column in the input dataframe
            user_args["input_columns"] = [
                col
                for col in df.columns.tolist()
                if not col.startswith("__mindsdb") and col != target
            ]
            # unquote the column names -- removing surrounding `
            user_args["input_columns"] = [
                col.strip("`") for col in user_args["input_columns"]
            ]

        elif "input_columns" not in user_args:
            # set as empty list if the input_columns is not provided
            user_args["input_columns"] = []

        # this may raise an exception if
        # the arguments are not sufficient to create such a class
        # due to e.g., lack of API key.
        # The validation logic is handled by langchain and pydantic.
        construct_model_from_args(user_args)

        # save the model args to the model storage
        target = target or "embeddings"
        user_args["target"] = target  # name of the column to store the embeddings
        self.model_storage.json_set("args", user_args)

    def predict(self, df: DataFrame, args) -> DataFrame:
        """Compute embeddings for each row of ``df``.

        Returns a copy of ``df`` with one extra column (named by the stored
        "target" arg) holding an embedding vector per row.
        """
        # reconstruct the model from the model storage
        user_args = self.model_storage.json_get("args")
        model = construct_model_from_args(user_args)

        # get the target from the model storage
        target = user_args["target"]

        # Work on a copy so the caller's dataframe is not mutated when we
        # unquote (strip surrounding `) the column names below.
        df = df.copy()
        cols_dfs = [col.strip("`") for col in df.columns.tolist()]
        df.columns = cols_dfs

        # if input_columns is an empty list, use all the columns
        input_columns = user_args.get("input_columns") or df.columns.tolist()

        # check all the input columns are in the df
        # (ignoring surrounding ` in the column names)
        if not all(col in cols_dfs for col in input_columns):
            raise Exception(
                f"Input columns {input_columns} not found in the input dataframe. Available columns are {df.columns}"
            )

        # convert each row into a document and embed the whole batch
        df_texts = df[input_columns].apply(row_to_document, axis=1)
        embeddings = model.embed_documents(df_texts.tolist())

        # create a new dataframe with the embeddings attached
        df_embeddings = df.copy().assign(**{target: embeddings})

        return df_embeddings

    def finetune(
        self, df: Union[DataFrame, None] = None, args: Union[Dict, None] = None
    ) -> None:
        """Finetuning is not applicable to embedding model wrappers."""
        raise NotImplementedError(
            "Finetuning is not supported for langchain embeddings"
        )

    def describe(self, attribute: Union[str, None] = None) -> DataFrame:
        """Describe the stored model.

        Supported attributes: "args" and "metadata"; any other value lists
        the available attribute tables.
        """
        args = self.model_storage.json_get("args")

        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])
        elif attribute == "metadata":
            # NOTE(review): "model_class" is never written by create() in
            # this file, so this lookup likely returns None — confirm
            # against the storage layer before relying on it.
            return pd.DataFrame(
                [
                    ("model_class", self.model_storage.json_get("model_class")),
                ],
                columns=["key", "value"],
            )
        else:
            tables = ("args", "metadata")
            return pd.DataFrame(tables, columns=["tables"])