Coverage for mindsdb/integrations/handlers/rag_handler/rag.py: 0%
72 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
2from collections import defaultdict
3from typing import List
5from openai.types import Completion
7from mindsdb.integrations.handlers.rag_handler.settings import (
8 LLMLoader,
9 PersistedVectorStoreLoader,
10 PersistedVectorStoreLoaderConfig,
11 RAGBaseParameters,
12 RAGHandlerParameters,
13 load_embeddings_model,
14)
15from mindsdb.utilities import log
17logger = log.getLogger(__name__)
class RAGQuestionAnswerer:
    """Answer questions with retrieval-augmented generation (RAG).

    Retrieves the top-k most similar documents from a persisted vector
    store, optionally summarizes them with the LLM, formats them into the
    configured prompt template, and generates an answer.
    """

    def __init__(self, args: RAGBaseParameters):
        """Load the embeddings model, the persisted vector store and,
        when full handler parameters are given, the LLM.

        Args:
            args: RAG configuration. When it is a ``RAGHandlerParameters``
                instance, an LLM is also loaded from ``args.llm_params``.
        """
        self.output_data = defaultdict(list)
        self.args = args

        # Fall back to loading the embeddings model by name when no
        # pre-built model instance was supplied.
        self.embeddings_model = args.embeddings_model
        if self.embeddings_model is None:
            self.embeddings_model = load_embeddings_model(args.embeddings_model_name)

        self.persist_directory = args.vector_store_storage_path
        self.collection_name = args.collection_name

        vector_store_config = PersistedVectorStoreLoaderConfig(
            vector_store_name=args.vector_store_name,
            embeddings_model=self.embeddings_model,
            persist_directory=self.persist_directory,
            collection_name=self.collection_name,
        )
        self.vector_store_loader = PersistedVectorStoreLoader(vector_store_config)
        self.persisted_vector_store = self.vector_store_loader.load_vector_store()

        self.prompt_template = args.prompt_template

        # Only full handler parameters carry LLM config; with bare
        # RAGBaseParameters self.llm stays unset, so the query and
        # summarization paths require RAGHandlerParameters.
        if isinstance(args, RAGHandlerParameters):
            llm_config = {"llm_config": args.llm_params.model_dump()}
            llm_loader = LLMLoader(**llm_config)
            self.llm = llm_loader.load_llm()

    def __call__(self, question: str) -> defaultdict:
        """Alias for :meth:`query`."""
        return self.query(question)

    def _prepare_prompt(self, vector_store_response, question) -> str:
        """Build the final LLM prompt from the retrieved documents.

        Concatenates document contents (optionally summarized first) and
        substitutes them into the prompt template.
        """
        context = [doc.page_content for doc in vector_store_response]
        combined_context = "\n\n".join(context)

        if self.args.summarize_context:
            return self.summarize_context(combined_context, question)

        return self.prompt_template.format(question=question, context=combined_context)

    def summarize_context(self, combined_context: str, question: str) -> str:
        """Summarize the retrieved context with the LLM and format the
        final prompt using the summary instead of the raw context."""
        summarization_prompt_template = self.args.summarization_prompt_template

        summarization_prompt = summarization_prompt_template.format(
            context=combined_context, question=question
        )

        summarized_context = self.llm(prompt=summarization_prompt)

        return self.prompt_template.format(
            question=question, context=self.extract_generated_text(summarized_context)
        )

    def query_vector_store(self, question: str) -> List:
        """Return the top-k documents most similar to the question."""
        return self.persisted_vector_store.similarity_search(
            query=question,
            k=self.args.top_k,
        )

    @staticmethod
    def extract_generated_text(response) -> str:
        """Extract generated text from an LLM response.

        Accepts a JSON string, an OpenAI-style dict with a "choices" key,
        or an ``openai.types.Completion`` object. Unrecognized payloads are
        logged and returned unchanged.

        Raises:
            Exception: if the response looks structured but cannot be parsed.
        """
        data = json.loads(response) if isinstance(response, str) else response

        try:
            # Check the concrete Completion type BEFORE the dict-style
            # membership test: `"choices" in <Completion>` raises TypeError
            # (pydantic models do not support `in`), which previously made
            # the Completion branch unreachable and turned every Completion
            # response into an exception.
            if isinstance(data, Completion):
                return data.choices[0].text
            if isinstance(data, dict) and "choices" in data:
                return data["choices"][0]["text"]

            logger.info(
                f"Error extracting generated text: failed to parse response {response}"
            )
            return response

        except Exception as e:
            # Preserve the original cause for debugging.
            raise Exception(
                f"{e} Error extracting generated text: failed to parse response {response}"
            ) from e

    def query(self, question: str) -> defaultdict:
        """Run RAG for *question* and post-process the LLM response.

        Returns:
            defaultdict(list) with:
              - "answer": single-element list containing the generated text
              - "source_documents": single-element list containing a dict of
                parallel lists (content, source, column, row) for each
                retrieved document
        """
        llm_response, vector_store_response = self._query(question)

        result = defaultdict(list)
        result["answer"].append(self.extract_generated_text(llm_response))

        sources = defaultdict(list)
        # Plain iteration: the index from the original enumerate was unused.
        for document in vector_store_response:
            sources["sources_content"].append(document.page_content)
            sources["sources_document"].append(document.metadata.get("source", None))
            sources["column"].append(document.metadata.get("column", None))
            sources["sources_row"].append(document.metadata.get("row", None))

        result["source_documents"].append(dict(sources))

        return result

    def _query(self, question: str):
        """Retrieve documents, build the prompt, and call the LLM.

        Returns:
            (llm_response, vector_store_response) tuple.
        """
        logger.debug(f"Querying: {question}")

        vector_store_response = self.query_vector_store(question)

        formatted_prompt = self._prepare_prompt(vector_store_response, question)

        llm_response = self.llm(prompt=formatted_prompt)

        return llm_response, vector_store_response