Coverage for mindsdb / integrations / handlers / rag_handler / settings.py: 0%

257 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2from dataclasses import dataclass 

3from functools import lru_cache, partial 

4from typing import Any, Dict, List, Union, Optional 

5 

6import html2text 

7import openai 

8import pandas as pd 

9import requests 

10import writer 

11from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings 

12from langchain_community.llms import Writer 

13from langchain_community.document_loaders import DataFrameLoader 

14from langchain_community.vectorstores import FAISS, Chroma 

15from pydantic import BaseModel, Extra, Field, field_validator, ValidationInfo 

16 

17 

18from mindsdb.integrations.handlers.chromadb_handler.chromadb_handler import get_chromadb 

19from mindsdb.integrations.handlers.rag_handler.exceptions import ( 

20 InvalidOpenAIModel, 

21 InvalidPromptTemplate, 

22 InvalidWriterModel, 

23 UnsupportedLLM, 

24 UnsupportedVectorStore, 

25) 

26from langchain_core.callbacks import StreamingStdOutCallbackHandler 

27from langchain_core.documents import Document 

28from langchain_core.embeddings import Embeddings 

29from langchain_core.vectorstores import VectorStore 

30 

# Default Hugging Face embeddings model used when the user does not supply one.
DEFAULT_EMBEDDINGS_MODEL = "BAAI/bge-base-en"

# Vector store backends this handler knows how to save/load.
SUPPORTED_VECTOR_STORES = ("chromadb", "faiss")

# LLM providers this handler can dispatch to.
SUPPORTED_LLMS = ("writer", "openai")

# Default parameters for RAG Handler

# this is the default prompt template for qa
DEFAULT_QA_PROMPT_TEMPLATE = """
Use the following pieces of context to answer the question at the end. If you do not know the answer,
just say that you do not know, do not try to make up an answer.
Context: {context}
Question: {question}
Helpful Answer:"""

# this is the default prompt template for if the user wants to summarize the context before qa prompt
DEFAULT_SUMMARIZATION_PROMPT_TEMPLATE = """
Summarize the following texts for me:
{context}

When summarizing, please keep the following in mind the following question:
{question}
"""

# Text-splitter defaults: characters per chunk and overlap between chunks.
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 250
DEFAULT_VECTOR_STORE_NAME = "chromadb"
DEFAULT_VECTOR_STORE_COLLECTION_NAME = "collection"
# Kept only for backward compatibility of the parameter surface (see
# RAGBaseParameters.embeddings_batch_size).
MAX_EMBEDDINGS_BATCH_SIZE = 2000

# chromadb is imported indirectly via the chromadb handler helper.
chromadb = get_chromadb()

63 

64 

def is_valid_store(name) -> bool:
    """Return True when *name* is one of the supported vector stores."""
    return any(name == store for store in SUPPORTED_VECTOR_STORES)

67 

68 

class VectorStoreFactory:
    """Factory class for vector stores"""

    @staticmethod
    def get_vectorstore_class(name) -> Union[FAISS, Chroma, VectorStore]:
        """Resolve a vector-store name to its langchain class.

        Raises:
            TypeError: if ``name`` is not a string.
            ValueError: if ``name`` is not a supported store.
        """
        if not isinstance(name, str):
            raise TypeError("name must be a string")

        if not is_valid_store(name):
            raise ValueError(f"Invalid vector store {name}")

        # Dispatch table keeps the name -> class mapping in one place.
        store_classes = {"faiss": FAISS, "chromadb": Chroma}
        return store_classes[name]

86 

87 

def get_chroma_client(persist_directory: str) -> chromadb.PersistentClient:
    """Create a persistent Chroma client rooted at *persist_directory*."""
    client = chromadb.PersistentClient(path=persist_directory)
    return client

91 

92 

def get_available_writer_model_ids(args: ValidationInfo) -> list:
    """Return the ids of the Writer LLM models available to the account."""
    data = args.data
    org_id = data["writer_org_id"]

    client = writer.Writer(
        api_key=data["writer_api_key"],
        organization_id=org_id,
    )

    response = client.models.list(organization_id=org_id)

    # The SDK response body is raw JSON text; parse it ourselves.
    payload = json.loads(response.raw_response.text)

    return [entry["id"] for entry in payload["models"]]

108 

109 

def get_available_openai_model_ids(args: ValidationInfo) -> list:
    """Return the ids of all models available to the given OpenAI account.

    Args:
        args: pydantic ``ValidationInfo`` whose ``data`` contains
            ``openai_api_key`` and, optionally, ``base_url``.

    Returns:
        List of model id strings as reported by the OpenAI models API.
    """
    data = args.data

    client = openai.OpenAI(api_key=data["openai_api_key"], base_url=data.get("base_url"))
    models = client.models.list().data

    # Use a distinct loop variable: the original comprehension shadowed the
    # `models` list name itself.
    return [model.id for model in models]

118 

119 

@dataclass
class PersistedVectorStoreSaverConfig:
    """Configuration for PersistedVectorStoreSaver."""

    # "chromadb" or "faiss"; selects the save_<name> method used
    vector_store_name: str
    # directory on disk where the store is persisted
    persist_directory: str
    # collection (chromadb) / index name (faiss) within the store
    collection_name: str
    # the in-memory langchain vector store instance to persist
    vector_store: VectorStore

126 

127 

@dataclass
class PersistedVectorStoreLoaderConfig:
    """Configuration for PersistedVectorStoreLoader."""

    # "chromadb" or "faiss"; selects the load_<name> method used
    vector_store_name: str
    # embeddings model attached to the loaded store for query embedding
    embeddings_model: Embeddings
    # directory on disk where the store was persisted
    persist_directory: str
    # collection (chromadb) / index name (faiss) within the store
    collection_name: str

134 

135 

class PersistedVectorStoreSaver:
    """Persists a vector store to disk, dispatching on the configured name."""

    def __init__(self, config: PersistedVectorStoreSaverConfig):
        self.config = config

    def save_vector_store(self, vector_store: VectorStore):
        """Dispatch to ``save_<vector_store_name>`` for the configured backend."""
        save_method = getattr(self, f"save_{self.config.vector_store_name}")
        save_method(vector_store)

    def save_chromadb(self, vector_store: Chroma):
        """Chroma persists itself automatically; nothing to do here."""
        pass

    def save_faiss(self, vector_store: FAISS):
        """Write the FAISS index under the configured directory and name."""
        vector_store.save_local(
            index_name=self.config.collection_name,
            folder_path=self.config.persist_directory,
        )

156 

157 

class PersistedVectorStoreLoader:
    """Loads a persisted vector store from disk."""

    def __init__(self, config: PersistedVectorStoreLoaderConfig):
        self.config = config

    def load_vector_store_client(
        self,
        vector_store: str,
    ):
        """Build a vector store client for the named backend.

        Raises:
            NotImplementedError: for backends other than chromadb/faiss.
        """
        if vector_store == "chromadb":
            chroma_client = get_chroma_client(self.config.persist_directory)
            return Chroma(
                client=chroma_client,
                collection_name=self.config.collection_name,
                embedding_function=self.config.embeddings_model,
            )

        if vector_store == "faiss":
            return FAISS.load_local(
                folder_path=self.config.persist_directory,
                embeddings=self.config.embeddings_model,
                index_name=self.config.collection_name,
                # index was written locally by this handler, so deserializing
                # it is trusted
                allow_dangerous_deserialization=True,
            )

        raise NotImplementedError(f"{vector_store} client is not yet supported")

    def load_vector_store(self) -> VectorStore:
        """Dispatch to ``load_<vector_store_name>`` for the configured backend."""
        loader = getattr(self, f"load_{self.config.vector_store_name}")
        return loader()

    def load_chromadb(self) -> Chroma:
        """Load a persisted Chroma vector store."""
        return self.load_vector_store_client(vector_store="chromadb")

    def load_faiss(self) -> FAISS:
        """Load a persisted FAISS vector store."""
        return self.load_vector_store_client(vector_store="faiss")

202 

203 

class LLMParameters(BaseModel):
    """Model parameters for the LLM API interface"""

    llm_name: str = Field(default_factory=str, title="LLM API name")
    max_tokens: int = Field(default=100, title="max tokens in response")
    temperature: float = Field(default=0.0, title="temperature")
    # optional override for the provider's API base URL
    base_url: Optional[str] = None
    # nucleus-sampling probability mass
    top_p: float = 1
    # candidate completions considered provider-side — see provider docs
    best_of: int = 5
    # stop sequences passed through to the completion API
    stop: Optional[List[str]] = None

    class Config:
        # reject unknown fields so typos in user-supplied args fail loudly
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True
        # allow field names starting with "model_" (e.g. model_id)
        protected_namespaces = ()

220 

221 

class OpenAIParameters(LLMParameters):
    """Model parameters for the OpenAI completion API."""

    openai_api_key: str
    model_id: str = Field(default="gpt-3.5-turbo-instruct", title="model name")
    n: int = Field(default=1, title="number of responses to return")

    @field_validator("model_id", mode="after")
    def openai_model_must_be_supported(cls, v, values):
        """Reject model ids the configured OpenAI account cannot access.

        Raises:
            InvalidOpenAIModel: if ``v`` is not in the account's model list.
        """
        supported_models = get_available_openai_model_ids(values)
        if v not in supported_models:
            raise InvalidOpenAIModel(
                f"'model_id' must be one of {supported_models}, got {v}"
            )
        return v

237 

238 

class WriterLLMParameters(LLMParameters):
    """Model parameters for the Writer LLM API interface"""

    writer_api_key: str
    writer_org_id: Optional[str] = None
    model_id: str = "palmyra-x"
    # NOTE(review): mutable default — pydantic copies field defaults per
    # instance, so this should not be shared state; confirm against the
    # pinned pydantic version.
    callbacks: List[StreamingStdOutCallbackHandler] = [StreamingStdOutCallbackHandler()]
    verbose: bool = False

    @field_validator("model_id")
    def writer_model_must_be_supported(cls, v, values):
        """Reject model ids the configured Writer account cannot access.

        Raises:
            InvalidWriterModel: if ``v`` is not in the account's model list.
        """
        supported_models = get_available_writer_model_ids(values)
        if v not in supported_models:
            raise InvalidWriterModel(
                f"'model_id' must be one of {supported_models}, got {v}"
            )
        return v

256 

257 

class LLMLoader(BaseModel):
    """Instantiates an LLM client from a config dict.

    ``llm_config['llm_name']`` selects the ``load_<name>_llm`` method used.
    """

    llm_config: dict
    # working copy of llm_config (minus llm_name); populated by load_llm()
    config_dict: Optional[dict] = None

    def load_llm(self) -> Union[Writer, partial]:
        """Load the LLM named by ``llm_config['llm_name']``.

        Returns:
            A Writer client, or a ``partial`` over OpenAI ``completions.create``
            pre-bound with the remaining config.
        """
        method_name = f"load_{self.llm_config['llm_name']}_llm"
        self.config_dict = self.llm_config.copy()
        self.config_dict.pop("llm_name")
        return getattr(self, method_name)()

    def load_writer_llm(self) -> Writer:
        """Load Writer LLM API interface"""
        return Writer(**self.config_dict)

    def load_openai_llm(self) -> partial:
        """Load OpenAI LLM API interface"""
        client = openai.OpenAI(
            api_key=self.config_dict["openai_api_key"],
            # base_url may be absent from caller-supplied config; default None
            base_url=self.config_dict.get("base_url"),
        )
        config = self.config_dict.copy()
        # Drop client-level keys; everything left is a per-request kwarg.
        # pop(..., None) keeps this robust when a key was never supplied.
        for key in ("openai_api_key", "base_url"):
            config.pop(key, None)
        config["model"] = config.pop("model_id")

        return partial(client.completions.create, **config)

283 

284 

class RAGBaseParameters(BaseModel):
    """Base model parameters for RAG Handler"""

    # LLM parameter object; typed Any so subclasses can narrow it
    llm_params: Any
    vector_store_folder_name: str
    # column of the input dataframe whose text is embedded
    input_column: str
    use_gpu: bool = False
    embeddings_batch_size: int = MAX_EMBEDDINGS_BATCH_SIZE  # not used, leaving in place to prevent breaking changes
    prompt_template: str = DEFAULT_QA_PROMPT_TEMPLATE
    chunk_size: int = DEFAULT_CHUNK_SIZE
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
    # optional URL(s) whose fetched content is added to the corpus
    url: Optional[Union[str, List[str]]] = None
    # column whose cells are treated as URLs rather than literal text
    url_column_name: Optional[str] = None
    run_embeddings: Optional[bool] = True
    # number of documents retrieved per query
    top_k: int = 4
    embeddings_model: Optional[Embeddings] = None
    embeddings_model_name: str = DEFAULT_EMBEDDINGS_MODEL
    context_columns: Optional[Union[List[str], str]] = None
    vector_store_name: str = DEFAULT_VECTOR_STORE_NAME
    vector_store: Optional[VectorStore] = None
    collection_name: str = DEFAULT_VECTOR_STORE_COLLECTION_NAME
    # whether to summarize retrieved context (using
    # summarization_prompt_template) before the QA prompt
    summarize_context: bool = True
    summarization_prompt_template: str = DEFAULT_SUMMARIZATION_PROMPT_TEMPLATE
    vector_store_storage_path: Optional[str] = Field(
        default=None, title="don't use this field, it's for internal use only"
    )

    class Config:
        # reject unknown fields so typos in user-supplied args fail loudly
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True

    @field_validator("prompt_template")
    def prompt_format_must_be_valid(cls, v):
        """Ensure the template has both placeholders the QA chain will fill."""
        if "{context}" not in v or "{question}" not in v:
            raise InvalidPromptTemplate(
                "prompt_template must contain {context} and {question}"
                f"\n For example, {DEFAULT_QA_PROMPT_TEMPLATE}"
            )
        return v

    @field_validator("vector_store_name")
    def name_must_be_lower(cls, v):
        """Normalize the vector store name to lowercase."""
        return v.lower()

    @field_validator("vector_store_name")
    def vector_store_must_be_supported(cls, v):
        """Reject vector stores this handler does not support."""
        if not is_valid_store(v):
            raise UnsupportedVectorStore(
                f"we don't support {v}. currently we only support {', '.join(str(v) for v in SUPPORTED_VECTOR_STORES)} vector store"
            )
        return v

337 

338 

class RAGHandlerParameters(RAGBaseParameters):
    """Model parameters for create model"""

    # one of SUPPORTED_LLMS ("writer" or "openai")
    llm_type: str
    # narrowed from Any in the base class
    llm_params: LLMParameters

    @field_validator("llm_type")
    def llm_type_must_be_supported(cls, v):
        """Reject LLM providers outside SUPPORTED_LLMS."""
        if v not in SUPPORTED_LLMS:
            raise UnsupportedLLM(f"'llm_type' must be one of {SUPPORTED_LLMS}, got {v}")
        return v

350 

351 

class DfLoader(DataFrameLoader):
    """
    override the load method of langchain.document_loaders.DataFrameLoaders to ignore rows with 'None' values
    """

    def __init__(self, data_frame: pd.DataFrame, page_content_column: str):
        super().__init__(data_frame=data_frame, page_content_column=page_content_column)
        self._data_frame = data_frame
        self._page_content_column = page_content_column

    def load(self) -> List[Document]:
        """Return one Document per non-null row of the content column."""
        column = self._page_content_column
        series = self._data_frame[column]
        # Null cells are skipped rather than producing empty documents.
        return [
            Document(
                page_content=content,
                metadata={
                    "source": "dataframe",
                    "row": row_idx,
                    "column": column,
                },
            )
            for row_idx, content in series.items()
            if pd.notnull(content)
        ]

383 

384 

def df_to_documents(
    df: pd.DataFrame,
    page_content_columns: Union[List[str], str],
    url_column_name: Optional[str] = None,
) -> List[Document]:
    """Converts a given dataframe to a list of documents.

    Args:
        df: source dataframe.
        page_content_columns: column name(s) whose cells become document text.
        url_column_name: if set, cells of this column are treated as URLs to
            fetch instead of literal text.

    Returns:
        All documents produced from the requested columns, in column order.

    Raises:
        ValueError: if a requested column is missing from ``df``.
    """
    documents = []

    if isinstance(page_content_columns, str):
        page_content_columns = [page_content_columns]

    # Plain iteration: the index from the original enumerate() was unused.
    for page_content_column in page_content_columns:
        if page_content_column not in df.columns.tolist():
            raise ValueError(
                f"page_content_column {page_content_column} not in dataframe columns"
            )
        if url_column_name is not None and page_content_column == url_column_name:
            # URL column: fetch each URL's content rather than the cell text.
            documents.extend(url_to_documents(df[page_content_column].tolist()))
            continue

        loader = DfLoader(data_frame=df, page_content_column=page_content_column)
        documents.extend(loader.load())

    return documents

409 

410 

def url_to_documents(urls: Union[List[str], str]) -> List[Document]:
    """Fetch each URL and convert its HTML body into a text Document."""
    if isinstance(urls, str):
        urls = [urls]

    documents = []
    for url in urls:
        html = requests.get(url, headers=None).text
        # Strip markup down to readable plain text.
        text = html2text.html2text(html)
        documents.append(Document(page_content=text, metadata={"source": url}))

    return documents

423 

424 

@lru_cache()
def load_embeddings_model(embeddings_model_name, use_gpu=False):
    """Load embeddings model from Hugging Face Hub.

    Cached per (model name, use_gpu) so repeated calls reuse the loaded model.

    Raises:
        ValueError: if the model name is not a valid Hugging Face model.
    """
    # Keep the try body minimal: only the constructor can raise ValueError.
    model_kwargs = dict(device="cuda" if use_gpu else "cpu")
    try:
        embedding_model = HuggingFaceEmbeddings(
            model_name=embeddings_model_name, model_kwargs=model_kwargs
        )
    except ValueError as e:
        # Chain the original error so the root cause stays visible.
        raise ValueError(
            f"The {embeddings_model_name} is not supported, please select a valid option from Hugging Face Hub!"
        ) from e
    return embedding_model

438 

439 

def on_create_build_llm_params(
    args: dict, llm_config_class: Union[WriterLLMParameters, OpenAIParameters]
) -> Dict:
    """Extract LLM-specific params from create-time args.

    Pops every field of ``llm_config_class`` found in ``args`` (mutating
    ``args``) and returns them plus ``llm_name``.
    """
    llm_params = {"llm_name": args["llm_type"]}

    for field_name in llm_config_class.model_fields:
        if field_name in args:
            llm_params[field_name] = args.pop(field_name)

    return llm_params

452 

453 

def build_llm_params(args: dict, update=False) -> Dict:
    """Normalize ``args['llm_params']`` into a validated LLM parameter object.

    On create, harvests LLM fields from the flat args; on predict, reuses the
    stored ``llm_params``; on update, returns the raw params unvalidated.
    """
    llm_type = args["llm_type"]
    if llm_type == "writer":
        llm_config_class = WriterLLMParameters
    elif llm_type == "openai":
        llm_config_class = OpenAIParameters
    else:
        raise UnsupportedLLM(
            f"'llm_type' must be one of {SUPPORTED_LLMS}, got {args['llm_type']}"
        )

    existing = args.pop("llm_params", None)
    if not existing:
        # create path: collect llm fields out of the flat create args
        llm_params = on_create_build_llm_params(args, llm_config_class)
    else:
        # predict path: params were stored at create time
        llm_params = existing
        if update:
            # update path: hand back the raw params without re-validating
            args["llm_params"] = llm_params
            return args

    args["llm_params"] = llm_config_class(**llm_params)

    return args

478 

479 return args