Coverage for mindsdb / integrations / utilities / rag / pipelines / rag.py: 18%
139 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from copy import copy
2from typing import Optional, Any, List
4from langchain_core.output_parsers import StrOutputParser
5from langchain.retrievers import ContextualCompressionRetriever
6from langchain_core.documents import Document
8from langchain_core.prompts import ChatPromptTemplate
9from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable
11from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
12from mindsdb.integrations.libs.vectordatabase_handler import DistanceFunction
13from mindsdb.integrations.utilities.rag.chains.map_reduce_summarizer_chain import MapReduceSummarizerChain
14from mindsdb.integrations.utilities.rag.retrievers.auto_retriever import AutoRetriever
15from mindsdb.integrations.utilities.rag.retrievers.multi_vector_retriever import MultiVectorRetriever
16from mindsdb.integrations.utilities.rag.retrievers.sql_retriever import SQLRetriever
17from mindsdb.integrations.utilities.rag.rerankers.reranker_compressor import LLMReranker
18from mindsdb.integrations.utilities.rag.settings import (RAGPipelineModel,
19 DEFAULT_AUTO_META_PROMPT_TEMPLATE,
20 SearchKwargs, SearchType,
21 RerankerConfig,
22 SummarizationConfig, VectorStoreConfig)
23from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKER_FLAG
25from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator
26from mindsdb.interfaces.agents.langchain_agent import create_chat_model
class LangChainRAGPipeline:
    """
    Builds a RAG pipeline using langchain LCEL components

    Args:
        retriever_runnable: Base retriever component
        prompt_template: Template for generating responses
        llm: Language model for generating responses
        reranker (bool): Whether to use reranking (default: False)
        reranker_config (RerankerConfig): Configuration for the reranker, including:
            - model: Model to use for reranking
            - filtering_threshold: Minimum score to keep a document
            - num_docs_to_keep: Maximum number of documents to keep
            - max_concurrent_requests: Maximum concurrent API requests
            - max_retries: Number of retry attempts for failed requests
            - retry_delay: Delay between retries
            - early_stop (bool): Whether to enable early stopping
            - early_stop_threshold: Confidence threshold for early stopping
        vector_store_config (VectorStoreConfig): Vector store configuration
        summarization_config (SummarizationConfig): Summarization configuration
    """

    def __init__(
            self,
            retriever_runnable,
            prompt_template,
            llm,
            reranker: bool = DEFAULT_RERANKER_FLAG,
            reranker_config: Optional[RerankerConfig] = None,
            vector_store_config: Optional[VectorStoreConfig] = None,
            summarization_config: Optional[SummarizationConfig] = None
    ):
        self.retriever_runnable = retriever_runnable
        self.prompt_template = prompt_template
        self.llm = llm

        # Only instantiate the reranker when explicitly enabled, falling back
        # to a default RerankerConfig when none was provided.
        if not reranker:
            self.reranker = None
        else:
            effective_config = reranker_config if reranker_config is not None else RerankerConfig()
            self.reranker = LLMReranker(**effective_config.model_dump(exclude_none=True))

        self.summarizer = None
        self.vector_store_config = vector_store_config

        # Summarization requires a knowledge-base table to resolve the
        # original documents that chunks were taken from.
        kb_table = vector_store_config.kb_table if vector_store_config is not None else None
        if summarization_config is not None and kb_table is not None:
            self.summarizer = MapReduceSummarizerChain(
                vector_store_handler=kb_table.get_vector_db(),
                table_name=kb_table.get_vector_db_table_name(),
                summarization_config=summarization_config
            )
82 def with_returned_sources(self) -> RunnableSerializable:
83 """
84 Builds a RAG pipeline with returned sources
85 :return:
86 """
88 def format_docs(docs):
89 if isinstance(docs, str):
90 # this is to handle the case where the retriever returns a string
91 # instead of a list of documents e.g. SQLRetriever
92 return docs
93 if not docs:
94 return ''
95 # Sort by original document so we can group source summaries together.
96 docs.sort(key=lambda d: d.metadata.get('original_row_id') if d.metadata else 0)
97 original_document_id = None
98 summary_prepended_text = 'Summary of the original document that the below context was taken from:\n'
99 document_content = ''
100 for d in docs:
101 metadata = d.metadata or {}
102 if metadata.get('original_row_id') != original_document_id and metadata.get('summary'):
103 # We have a summary of a new document to prepend.
104 original_document_id = metadata.get('original_row_id')
105 summary = f"{summary_prepended_text}{metadata.get('summary')}\n"
106 document_content += summary
107 document_content += f'{d.page_content}\n\n'
108 return document_content
110 prompt = ChatPromptTemplate.from_template(self.prompt_template)
112 # Ensure all the required components are not None
113 if prompt is None:
114 raise ValueError("One of the required components (prompt) is None")
115 if self.llm is None:
116 raise ValueError("One of the required components (llm) is None")
118 if self.reranker:
119 # Create a custom retriever that handles async operations properly
120 class AsyncRerankerRetriever(ContextualCompressionRetriever):
121 """Async-aware retriever that properly handles concurrent reranking operations."""
123 def __init__(self, base_retriever, reranker):
124 super().__init__(
125 base_compressor=reranker,
126 base_retriever=base_retriever
127 )
129 async def ainvoke(self, query: str) -> List[Document]:
130 """Async retrieval with proper concurrency handling."""
131 # Get initial documents
132 if hasattr(self.base_retriever, 'ainvoke'):
133 docs = await self.base_retriever.ainvoke(query)
134 else:
135 docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query)
137 # Rerank documents
138 if docs:
139 docs = await self.base_compressor.acompress_documents(docs, query)
140 return docs
142 def get_relevant_documents(self, query: str) -> List[Document]:
143 """Sync wrapper for async retrieval."""
144 import asyncio
145 return asyncio.run(self.ainvoke(query))
147 # Use our custom async-aware retriever
148 self.retriever_runnable = AsyncRerankerRetriever(
149 base_retriever=copy(self.retriever_runnable),
150 reranker=self.reranker
151 )
153 rag_chain_from_docs = (
154 RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
155 | prompt
156 | self.llm
157 | StrOutputParser()
158 )
160 retrieval_chain = RunnableParallel(
161 context=self.retriever_runnable,
162 question=RunnablePassthrough()
163 )
164 if self.summarizer is not None:
165 retrieval_chain = retrieval_chain | self.summarizer
167 rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs)
168 return rag_chain_with_source
170 async def ainvoke(self, input_dict: dict) -> dict:
171 """Async invocation of the RAG pipeline."""
172 chain = self.with_returned_sources()
173 return await chain.ainvoke(input_dict)
175 def invoke(self, input_dict: dict) -> dict:
176 """Sync invocation of the RAG pipeline."""
177 import asyncio
178 return asyncio.run(self.ainvoke(input_dict))
180 @classmethod
181 def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any:
182 """Apply search kwargs and search type to the retriever if they exist"""
183 if hasattr(retriever, 'search_kwargs') and search_kwargs:
184 # Convert search kwargs to dict, excluding None values
185 kwargs_dict = search_kwargs.model_dump(exclude_none=True)
187 # Only include relevant parameters based on search type
188 if search_type == SearchType.SIMILARITY:
189 # Remove MMR and similarity threshold specific params
190 kwargs_dict.pop('fetch_k', None)
191 kwargs_dict.pop('lambda_mult', None)
192 kwargs_dict.pop('score_threshold', None)
193 elif search_type == SearchType.MMR:
194 # Remove similarity threshold specific params
195 kwargs_dict.pop('score_threshold', None)
196 elif search_type == SearchType.SIMILARITY_SCORE_THRESHOLD:
197 # Remove MMR specific params
198 kwargs_dict.pop('fetch_k', None)
199 kwargs_dict.pop('lambda_mult', None)
201 retriever.search_kwargs.update(kwargs_dict)
203 # Set search type if supported by the retriever
204 if hasattr(retriever, 'search_type') and search_type:
205 retriever.search_type = search_type.value
207 return retriever
209 @classmethod
210 def from_retriever(cls, config: RAGPipelineModel):
211 """
212 Builds a RAG pipeline with returned sources using a simple vector store retriever
213 :param config: RAGPipelineModel
214 :return:
215 """
216 vector_store_operator = VectorStoreOperator(
217 vector_store=config.vector_store,
218 documents=config.documents,
219 embedding_model=config.embedding_model,
220 vector_store_config=config.vector_store_config
221 )
222 retriever = vector_store_operator.vector_store.as_retriever()
223 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type)
225 return cls(
226 retriever,
227 config.rag_prompt_template,
228 config.llm,
229 vector_store_config=config.vector_store_config,
230 reranker=config.reranker,
231 reranker_config=config.reranker_config,
232 summarization_config=config.summarization_config
233 )
235 @classmethod
236 def from_auto_retriever(cls, config: RAGPipelineModel):
237 if not config.retriever_prompt_template:
238 config.retriever_prompt_template = DEFAULT_AUTO_META_PROMPT_TEMPLATE
240 retriever = AutoRetriever(config=config).as_runnable()
241 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type)
242 return cls(
243 retriever,
244 config.rag_prompt_template,
245 config.llm,
246 reranker_config=config.reranker_config,
247 reranker=config.reranker,
248 vector_store_config=config.vector_store_config,
249 summarization_config=config.summarization_config
250 )
252 @classmethod
253 def from_multi_vector_retriever(cls, config: RAGPipelineModel):
254 retriever = MultiVectorRetriever(config=config).as_runnable()
255 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type)
256 return cls(
257 retriever,
258 config.rag_prompt_template,
259 config.llm,
260 reranker_config=config.reranker_config,
261 reranker=config.reranker,
262 vector_store_config=config.vector_store_config,
263 summarization_config=config.summarization_config
264 )
266 @classmethod
267 def from_sql_retriever(cls, config: RAGPipelineModel):
268 retriever_config = config.sql_retriever_config
269 if retriever_config is None:
270 raise ValueError('Must provide "sql_retriever_config" for RAG pipeline config')
271 vector_store_config = config.vector_store_config
272 knowledge_base_table = vector_store_config.kb_table if vector_store_config is not None else None
273 if knowledge_base_table is None:
274 raise ValueError('Must provide valid "vector_store_config" for RAG pipeline config')
275 embedding_args = knowledge_base_table._kb.embedding_model.learn_args.get('using', {})
276 embeddings = construct_model_from_args(embedding_args)
277 sql_llm = create_chat_model({
278 'model_name': retriever_config.llm_config.model_name,
279 'provider': retriever_config.llm_config.provider,
280 **retriever_config.llm_config.params
281 })
282 vector_store_operator = VectorStoreOperator(
283 vector_store=config.vector_store,
284 documents=config.documents,
285 embedding_model=config.embedding_model,
286 vector_store_config=config.vector_store_config
287 )
288 vector_store_retriever = vector_store_operator.vector_store.as_retriever()
289 vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type)
290 distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE
291 if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None:
292 # Use negative dot product for sparse retrieval.
293 distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT
294 retriever = SQLRetriever(
295 fallback_retriever=vector_store_retriever,
296 vector_store_handler=knowledge_base_table.get_vector_db(),
297 min_k=retriever_config.min_k,
298 max_filters=retriever_config.max_filters,
299 filter_threshold=retriever_config.filter_threshold,
300 database_schema=retriever_config.database_schema,
301 embeddings_model=embeddings,
302 search_kwargs=config.search_kwargs,
303 rewrite_prompt_template=retriever_config.rewrite_prompt_template,
304 table_prompt_template=retriever_config.table_prompt_template,
305 column_prompt_template=retriever_config.column_prompt_template,
306 value_prompt_template=retriever_config.value_prompt_template,
307 boolean_system_prompt=retriever_config.boolean_system_prompt,
308 generative_system_prompt=retriever_config.generative_system_prompt,
309 num_retries=retriever_config.num_retries,
310 embeddings_table=knowledge_base_table._kb.vector_database_table,
311 source_table=retriever_config.source_table,
312 source_id_column=retriever_config.source_id_column,
313 distance_function=distance_function,
314 llm=sql_llm
315 )
316 return cls(
317 retriever,
318 config.rag_prompt_template,
319 config.llm,
320 reranker_config=config.reranker_config,
321 reranker=config.reranker,
322 vector_store_config=config.vector_store_config,
323 summarization_config=config.summarization_config
324 )