Coverage for mindsdb/integrations/handlers/rag_handler/settings.py: 0%
257 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

import json
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import Any, Dict, List, Union, Optional

import html2text
import openai
import pandas as pd
import requests
import writer
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.llms import Writer
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.vectorstores import FAISS, Chroma
from pydantic import BaseModel, Extra, Field, field_validator, ValidationInfo

from mindsdb.integrations.handlers.chromadb_handler.chromadb_handler import get_chromadb
from mindsdb.integrations.handlers.rag_handler.exceptions import (
    InvalidOpenAIModel,
    InvalidPromptTemplate,
    InvalidWriterModel,
    UnsupportedLLM,
    UnsupportedVectorStore,
)
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

DEFAULT_EMBEDDINGS_MODEL = "BAAI/bge-base-en"

SUPPORTED_VECTOR_STORES = ("chromadb", "faiss")

SUPPORTED_LLMS = ("writer", "openai")

# Default parameters for RAG Handler

# this is the default prompt template for qa
DEFAULT_QA_PROMPT_TEMPLATE = """
Use the following pieces of context to answer the question at the end. If you do not know the answer,
just say that you do not know, do not try to make up an answer.
Context: {context}
Question: {question}
Helpful Answer:"""

# this is the default prompt template used if the user wants to summarize the context before the qa prompt
DEFAULT_SUMMARIZATION_PROMPT_TEMPLATE = """
Summarize the following texts for me:
{context}

When summarizing, please keep in mind the following question:
{question}
"""

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 250
DEFAULT_VECTOR_STORE_NAME = "chromadb"
DEFAULT_VECTOR_STORE_COLLECTION_NAME = "collection"
MAX_EMBEDDINGS_BATCH_SIZE = 2000

chromadb = get_chromadb()


def is_valid_store(name) -> bool:
    return name in SUPPORTED_VECTOR_STORES


class VectorStoreFactory:
    """Factory class for vector stores"""

    @staticmethod
    def get_vectorstore_class(name) -> Union[FAISS, Chroma, VectorStore]:

        if not isinstance(name, str):
            raise TypeError("name must be a string")

        if not is_valid_store(name):
            raise ValueError(f"Invalid vector store {name}")

        if name == "faiss":
            return FAISS

        if name == "chromadb":
            return Chroma
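
# Example (illustrative only): the factory resolves a store name to the langchain class
# itself (not an instance), e.g.
#     VectorStoreFactory.get_vectorstore_class("faiss")     # -> FAISS
#     VectorStoreFactory.get_vectorstore_class("chromadb")  # -> Chroma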


def get_chroma_client(persist_directory: str) -> chromadb.PersistentClient:
    """Get Chroma client"""
    return chromadb.PersistentClient(path=persist_directory)


def get_available_writer_model_ids(args: ValidationInfo) -> list:
    """Get available writer LLM model ids"""

    args = args.data

    writer_client = writer.Writer(
        api_key=args["writer_api_key"],
        organization_id=args["writer_org_id"],
    )

    res = writer_client.models.list(organization_id=args["writer_org_id"])

    available_models_dict = json.loads(res.raw_response.text)

    return [model["id"] for model in available_models_dict["models"]]


def get_available_openai_model_ids(args: ValidationInfo) -> list:
    """Get available openai LLM model ids"""

    args = args.data

    models = openai.OpenAI(api_key=args["openai_api_key"], base_url=args.get("base_url")).models.list().data

    return [model.id for model in models]


@dataclass
class PersistedVectorStoreSaverConfig:
    vector_store_name: str
    persist_directory: str
    collection_name: str
    vector_store: VectorStore


@dataclass
class PersistedVectorStoreLoaderConfig:
    vector_store_name: str
    embeddings_model: Embeddings
    persist_directory: str
    collection_name: str


class PersistedVectorStoreSaver:
    """Saves vector store to disk"""

    def __init__(self, config: PersistedVectorStoreSaverConfig):
        self.config = config

    def save_vector_store(self, vector_store: VectorStore):
        method_name = f"save_{self.config.vector_store_name}"
        getattr(self, method_name)(vector_store)

    def save_chromadb(self, vector_store: Chroma):
        """Save Chroma vector store to disk"""
        # no need to save chroma vector store to disk, auto save
        pass

    def save_faiss(self, vector_store: FAISS):
        vector_store.save_local(
            folder_path=self.config.persist_directory,
            index_name=self.config.collection_name,
        )
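
# Example (illustrative sketch, hypothetical directory): persisting an existing FAISS store.
#     saver_config = PersistedVectorStoreSaverConfig(
#         vector_store_name="faiss",
#         persist_directory="/tmp/rag_store",   # hypothetical path
#         collection_name="collection",
#         vector_store=faiss_store,             # an existing FAISS instance
#     )
#     PersistedVectorStoreSaver(saver_config).save_vector_store(faiss_store)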


class PersistedVectorStoreLoader:
    """Loads vector store from disk"""

    def __init__(self, config: PersistedVectorStoreLoaderConfig):
        self.config = config

    def load_vector_store_client(
        self,
        vector_store: str,
    ):
        """Load vector store from the persisted vector store"""

        if vector_store == "chromadb":

            return Chroma(
                collection_name=self.config.collection_name,
                embedding_function=self.config.embeddings_model,
                client=get_chroma_client(self.config.persist_directory),
            )

        elif vector_store == "faiss":

            return FAISS.load_local(
                folder_path=self.config.persist_directory,
                embeddings=self.config.embeddings_model,
                index_name=self.config.collection_name,
                allow_dangerous_deserialization=True,
            )

        else:
            raise NotImplementedError(f"{vector_store} client is not yet supported")

    def load_vector_store(self) -> VectorStore:
        """Load vector store from the persisted vector store"""
        method_name = f"load_{self.config.vector_store_name}"
        return getattr(self, method_name)()

    def load_chromadb(self) -> Chroma:
        """Load Chroma vector store from the persisted vector store"""
        return self.load_vector_store_client(vector_store="chromadb")

    def load_faiss(self) -> FAISS:
        """Load FAISS vector store from the persisted vector store"""
        return self.load_vector_store_client(vector_store="faiss")
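
# Example (illustrative sketch, hypothetical directory): reloading the persisted index above.
#     loader_config = PersistedVectorStoreLoaderConfig(
#         vector_store_name="faiss",
#         embeddings_model=load_embeddings_model(DEFAULT_EMBEDDINGS_MODEL),
#         persist_directory="/tmp/rag_store",   # hypothetical path
#         collection_name="collection",
#     )
#     vector_store = PersistedVectorStoreLoader(loader_config).load_vector_store()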


class LLMParameters(BaseModel):
    """Model parameters for the LLM API interface"""

    llm_name: str = Field(default_factory=str, title="LLM API name")
    max_tokens: int = Field(default=100, title="max tokens in response")
    temperature: float = Field(default=0.0, title="temperature")
    base_url: Optional[str] = None
    top_p: float = 1
    best_of: int = 5
    stop: Optional[List[str]] = None

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True
        protected_namespaces = ()


class OpenAIParameters(LLMParameters):
    """Model parameters for the OpenAI LLM API interface"""

    openai_api_key: str
    model_id: str = Field(default="gpt-3.5-turbo-instruct", title="model name")
    n: int = Field(default=1, title="number of responses to return")

    @field_validator("model_id", mode="after")
    def openai_model_must_be_supported(cls, v, values):
        supported_models = get_available_openai_model_ids(values)
        if v not in supported_models:
            raise InvalidOpenAIModel(
                f"'model_id' must be one of {supported_models}, got {v}"
            )
        return v


class WriterLLMParameters(LLMParameters):
    """Model parameters for the Writer LLM API interface"""

    writer_api_key: str
    writer_org_id: Optional[str] = None
    model_id: str = "palmyra-x"
    callbacks: List[StreamingStdOutCallbackHandler] = [StreamingStdOutCallbackHandler()]
    verbose: bool = False

    @field_validator("model_id")
    def writer_model_must_be_supported(cls, v, values):
        supported_models = get_available_writer_model_ids(values)
        if v not in supported_models:
            raise InvalidWriterModel(
                f"'model_id' must be one of {supported_models}, got {v}"
            )
        return v


class LLMLoader(BaseModel):
    llm_config: dict
    config_dict: Optional[dict] = None

    def load_llm(self) -> Union[Writer, partial]:
        """Load LLM"""
        method_name = f"load_{self.llm_config['llm_name']}_llm"
        self.config_dict = self.llm_config.copy()
        self.config_dict.pop("llm_name")
        return getattr(self, method_name)()

    def load_writer_llm(self) -> Writer:
        """Load Writer LLM API interface"""
        return Writer(**self.config_dict)

    def load_openai_llm(self) -> partial:
        """Load OpenAI LLM API interface"""
        client = openai.OpenAI(api_key=self.config_dict["openai_api_key"], base_url=self.config_dict["base_url"])
        config = self.config_dict.copy()
        keys_to_remove = ["openai_api_key", "base_url"]
        for key in keys_to_remove:
            config.pop(key)
        config["model"] = config.pop("model_id")

        return partial(client.completions.create, **config)
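
# Example (illustrative sketch, hypothetical values): for OpenAI the loader returns a
# partial around client.completions.create with the remaining parameters bound.
#     loader = LLMLoader(llm_config={
#         "llm_name": "openai",
#         "openai_api_key": "sk-...",          # hypothetical key
#         "base_url": None,
#         "model_id": "gpt-3.5-turbo-instruct",
#         "max_tokens": 100,
#     })
#     completion_fn = loader.load_llm()
#     response = completion_fn(prompt="What is MindsDB?")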


class RAGBaseParameters(BaseModel):
    """Base model parameters for RAG Handler"""

    llm_params: Any
    vector_store_folder_name: str
    input_column: str
    use_gpu: bool = False
    embeddings_batch_size: int = MAX_EMBEDDINGS_BATCH_SIZE  # not used, leaving in place to prevent breaking changes
    prompt_template: str = DEFAULT_QA_PROMPT_TEMPLATE
    chunk_size: int = DEFAULT_CHUNK_SIZE
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
    url: Optional[Union[str, List[str]]] = None
    url_column_name: Optional[str] = None
    run_embeddings: Optional[bool] = True
    top_k: int = 4
    embeddings_model: Optional[Embeddings] = None
    embeddings_model_name: str = DEFAULT_EMBEDDINGS_MODEL
    context_columns: Optional[Union[List[str], str]] = None
    vector_store_name: str = DEFAULT_VECTOR_STORE_NAME
    vector_store: Optional[VectorStore] = None
    collection_name: str = DEFAULT_VECTOR_STORE_COLLECTION_NAME
    summarize_context: bool = True
    summarization_prompt_template: str = DEFAULT_SUMMARIZATION_PROMPT_TEMPLATE
    vector_store_storage_path: Optional[str] = Field(
        default=None, title="don't use this field, it's for internal use only"
    )

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True

    @field_validator("prompt_template")
    def prompt_format_must_be_valid(cls, v):
        if "{context}" not in v or "{question}" not in v:
            raise InvalidPromptTemplate(
                "prompt_template must contain {context} and {question}"
                f"\n For example, {DEFAULT_QA_PROMPT_TEMPLATE}"
            )
        return v

    @field_validator("vector_store_name")
    def name_must_be_lower(cls, v):
        return v.lower()

    @field_validator("vector_store_name")
    def vector_store_must_be_supported(cls, v):
        if not is_valid_store(v):
            raise UnsupportedVectorStore(
                f"we don't support {v}. Currently we only support the following vector stores: {', '.join(str(v) for v in SUPPORTED_VECTOR_STORES)}"
            )
        return v


class RAGHandlerParameters(RAGBaseParameters):
    """Model parameters for create model"""

    llm_type: str
    llm_params: LLMParameters

    @field_validator("llm_type")
    def llm_type_must_be_supported(cls, v):
        if v not in SUPPORTED_LLMS:
            raise UnsupportedLLM(f"'llm_type' must be one of {SUPPORTED_LLMS}, got {v}")
        return v
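
# Example (illustrative sketch, hypothetical values): a minimal parameter object for an
# OpenAI-backed RAG model. Note that OpenAIParameters validates model_id against the live
# OpenAI model list, so a real API key is required.
#     params = RAGHandlerParameters(
#         llm_type="openai",
#         llm_params=OpenAIParameters(openai_api_key="sk-...", model_id="gpt-3.5-turbo-instruct"),
#         vector_store_folder_name="my_rag_store",
#         input_column="question",
#     )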


class DfLoader(DataFrameLoader):
    """
    Overrides the load method of langchain_community.document_loaders.DataFrameLoader to ignore rows with 'None' values
    """

    def __init__(self, data_frame: pd.DataFrame, page_content_column: str):
        super().__init__(data_frame=data_frame, page_content_column=page_content_column)
        self._data_frame = data_frame
        self._page_content_column = page_content_column

    def load(self) -> List[Document]:
        """Loads the dataframe as a list of documents"""
        documents = []
        for n_row, frame in self._data_frame[self._page_content_column].items():
            if pd.notnull(frame):
                # ignore rows with None values
                column_name = self._page_content_column

                document_contents = frame

                documents.append(
                    Document(
                        page_content=document_contents,
                        metadata={
                            "source": "dataframe",
                            "row": n_row,
                            "column": column_name,
                        },
                    )
                )
        return documents
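
# Example (illustrative sketch): rows whose page-content column is null are skipped.
#     df = pd.DataFrame({"text": ["first chunk", None, "second chunk"]})
#     docs = DfLoader(data_frame=df, page_content_column="text").load()
#     # -> two Documents with metadata {"source": "dataframe", "row": <index>, "column": "text"}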


def df_to_documents(
    df: pd.DataFrame,
    page_content_columns: Union[List[str], str],
    url_column_name: Optional[str] = None,
) -> List[Document]:
390 """Converts a given dataframe to a list of documents"""
391 documents = []
393 if isinstance(page_content_columns, str):
394 page_content_columns = [page_content_columns]
396 for _, page_content_column in enumerate(page_content_columns):
397 if page_content_column not in df.columns.tolist():
398 raise ValueError(
399 f"page_content_column {page_content_column} not in dataframe columns"
400 )
401 if url_column_name is not None and page_content_column == url_column_name:
402 documents.extend(url_to_documents(df[page_content_column].tolist()))
403 continue
405 loader = DfLoader(data_frame=df, page_content_column=page_content_column)
406 documents.extend(loader.load())
408 return documents
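
# Example (illustrative sketch): converting two context columns where one column holds URLs.
#     df = pd.DataFrame({"body": ["some text"], "link": ["https://example.com"]})
#     docs = df_to_documents(df, page_content_columns=["body", "link"], url_column_name="link")
#     # "body" rows become dataframe-sourced Documents; "link" rows are fetched via url_to_documents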


def url_to_documents(urls: Union[List[str], str]) -> List[Document]:
    """Converts the given URL(s) to a list of documents"""
    documents = []
    if isinstance(urls, str):
        urls = [urls]

    for url in urls:
        response = requests.get(url, headers=None).text
        html_to_text = html2text.html2text(response)
        documents.append(Document(page_content=html_to_text, metadata={"source": url}))

    return documents


@lru_cache()
def load_embeddings_model(embeddings_model_name, use_gpu=False):
    """Load embeddings model from Hugging Face Hub"""
    try:
        model_kwargs = dict(device="cuda" if use_gpu else "cpu")
        embedding_model = HuggingFaceEmbeddings(
            model_name=embeddings_model_name, model_kwargs=model_kwargs
        )
    except ValueError as e:
        raise ValueError(
            f"The model {embeddings_model_name} is not supported, please select a valid option from the Hugging Face Hub!"
        ) from e
    return embedding_model
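
# Example (illustrative sketch): lru_cache means repeated calls with the same arguments reuse
# the already-loaded HuggingFace model instead of reloading it.
#     embeddings = load_embeddings_model(DEFAULT_EMBEDDINGS_MODEL)                    # CPU
#     embeddings_gpu = load_embeddings_model(DEFAULT_EMBEDDINGS_MODEL, use_gpu=True)  # CUDA device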


def on_create_build_llm_params(
    args: dict, llm_config_class: Union[WriterLLMParameters, OpenAIParameters]
) -> Dict:
    """build llm params from create args"""

    llm_params = {"llm_name": args["llm_type"]}

    for param in llm_config_class.model_fields.keys():
        if param in args:
            llm_params[param] = args.pop(param)

    return llm_params


def build_llm_params(args: dict, update=False) -> Dict:
    """build llm params from args"""

    if args["llm_type"] == "writer":
        llm_config_class = WriterLLMParameters
    elif args["llm_type"] == "openai":
        llm_config_class = OpenAIParameters
    else:
        raise UnsupportedLLM(
            f"'llm_type' must be one of {SUPPORTED_LLMS}, got {args['llm_type']}"
        )

    if not args.get("llm_params"):
        # for create method only
        llm_params = on_create_build_llm_params(args, llm_config_class)
    else:
        # for predict method only
        llm_params = args.pop("llm_params")
        if update:
            # for update method only
            args["llm_params"] = llm_params
            return args

    args["llm_params"] = llm_config_class(**llm_params)

    return args
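
# Example (illustrative sketch, hypothetical create args): on create, loose args are collected
# into a validated llm_params object; on predict/update, an existing llm_params dict is reused.
#     args = {
#         "llm_type": "openai",
#         "openai_api_key": "sk-...",          # hypothetical key
#         "vector_store_folder_name": "my_rag_store",
#         "input_column": "question",
#     }
#     args = build_llm_params(args)
#     # args["llm_params"] is now an OpenAIParameters instance; the LLM keys were popped from args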