Coverage for mindsdb / integrations / utilities / rag / pipelines / rag.py: 18%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

from copy import copy
from typing import Any, List, Optional

from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable

from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
from mindsdb.integrations.libs.vectordatabase_handler import DistanceFunction
from mindsdb.integrations.utilities.rag.chains.map_reduce_summarizer_chain import MapReduceSummarizerChain
from mindsdb.integrations.utilities.rag.retrievers.auto_retriever import AutoRetriever
from mindsdb.integrations.utilities.rag.retrievers.multi_vector_retriever import MultiVectorRetriever
from mindsdb.integrations.utilities.rag.retrievers.sql_retriever import SQLRetriever
from mindsdb.integrations.utilities.rag.rerankers.reranker_compressor import LLMReranker
from mindsdb.integrations.utilities.rag.settings import (
    DEFAULT_AUTO_META_PROMPT_TEMPLATE,
    DEFAULT_RERANKER_FLAG,
    RAGPipelineModel,
    RerankerConfig,
    SearchKwargs,
    SearchType,
    SummarizationConfig,
    VectorStoreConfig,
)
from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator
from mindsdb.interfaces.agents.langchain_agent import create_chat_model

27 

28 

class LangChainRAGPipeline:
    """
    Builds a RAG pipeline using LangChain LCEL components.

    Args:
        retriever_runnable: Base retriever component.
        prompt_template: Template used to generate the final response.
        llm: Language model used to generate the final response.
        reranker (bool): Whether to rerank retrieved documents (default: False).
        reranker_config (RerankerConfig): Configuration for the reranker, including:
            - model: Model to use for reranking
            - filtering_threshold: Minimum score to keep a document
            - num_docs_to_keep: Maximum number of documents to keep
            - max_concurrent_requests: Maximum concurrent API requests
            - max_retries: Number of retry attempts for failed requests
            - retry_delay: Delay between retries
            - early_stop (bool): Whether to enable early stopping
            - early_stop_threshold: Confidence threshold for early stopping
        vector_store_config (VectorStoreConfig): Vector store configuration.
        summarization_config (SummarizationConfig): Summarization configuration.
    """

    def __init__(
            self,
            retriever_runnable,
            prompt_template,
            llm,
            reranker: bool = DEFAULT_RERANKER_FLAG,
            reranker_config: Optional[RerankerConfig] = None,
            vector_store_config: Optional[VectorStoreConfig] = None,
            summarization_config: Optional[SummarizationConfig] = None
    ):
        self.retriever_runnable = retriever_runnable
        self.prompt_template = prompt_template
        self.llm = llm
        if reranker:
            # Fall back to default reranker settings when none are provided.
            if reranker_config is None:
                reranker_config = RerankerConfig()
            # Convert config to kwargs (dropping unset fields) and build the reranker.
            reranker_kwargs = reranker_config.model_dump(exclude_none=True)
            self.reranker = LLMReranker(**reranker_kwargs)
        else:
            self.reranker = None
        self.summarizer = None
        self.vector_store_config = vector_store_config
        knowledge_base_table = self.vector_store_config.kb_table if self.vector_store_config is not None else None
        # Summarization needs access to the knowledge base's vector store, so it is
        # only enabled when both a summarization config and a KB table are available.
        if summarization_config is not None and knowledge_base_table is not None:
            self.summarizer = MapReduceSummarizerChain(
                vector_store_handler=knowledge_base_table.get_vector_db(),
                table_name=knowledge_base_table.get_vector_db_table_name(),
                summarization_config=summarization_config
            )

81 

82 def with_returned_sources(self) -> RunnableSerializable: 

83 """ 

84 Builds a RAG pipeline with returned sources 

85 :return: 

86 """ 

87 

88 def format_docs(docs): 

89 if isinstance(docs, str): 

90 # this is to handle the case where the retriever returns a string 

91 # instead of a list of documents e.g. SQLRetriever 

92 return docs 

93 if not docs: 

94 return '' 

95 # Sort by original document so we can group source summaries together. 

96 docs.sort(key=lambda d: d.metadata.get('original_row_id') if d.metadata else 0) 

97 original_document_id = None 

98 summary_prepended_text = 'Summary of the original document that the below context was taken from:\n' 

99 document_content = '' 

100 for d in docs: 

101 metadata = d.metadata or {} 

102 if metadata.get('original_row_id') != original_document_id and metadata.get('summary'): 

103 # We have a summary of a new document to prepend. 

104 original_document_id = metadata.get('original_row_id') 

105 summary = f"{summary_prepended_text}{metadata.get('summary')}\n" 

106 document_content += summary 

107 document_content += f'{d.page_content}\n\n' 

108 return document_content 

109 

110 prompt = ChatPromptTemplate.from_template(self.prompt_template) 

111 

112 # Ensure all the required components are not None 

113 if prompt is None: 

114 raise ValueError("One of the required components (prompt) is None") 

115 if self.llm is None: 

116 raise ValueError("One of the required components (llm) is None") 

117 

118 if self.reranker: 

119 # Create a custom retriever that handles async operations properly 

120 class AsyncRerankerRetriever(ContextualCompressionRetriever): 

121 """Async-aware retriever that properly handles concurrent reranking operations.""" 

122 

123 def __init__(self, base_retriever, reranker): 

124 super().__init__( 

125 base_compressor=reranker, 

126 base_retriever=base_retriever 

127 ) 

128 

129 async def ainvoke(self, query: str) -> List[Document]: 

130 """Async retrieval with proper concurrency handling.""" 

131 # Get initial documents 

132 if hasattr(self.base_retriever, 'ainvoke'): 

133 docs = await self.base_retriever.ainvoke(query) 

134 else: 

135 docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query) 

136 

137 # Rerank documents 

138 if docs: 

139 docs = await self.base_compressor.acompress_documents(docs, query) 

140 return docs 

141 

142 def get_relevant_documents(self, query: str) -> List[Document]: 

143 """Sync wrapper for async retrieval.""" 

144 import asyncio 

145 return asyncio.run(self.ainvoke(query)) 

146 

147 # Use our custom async-aware retriever 

148 self.retriever_runnable = AsyncRerankerRetriever( 

149 base_retriever=copy(self.retriever_runnable), 

150 reranker=self.reranker 

151 ) 

152 

153 rag_chain_from_docs = ( 

154 RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"]))) 

155 | prompt 

156 | self.llm 

157 | StrOutputParser() 

158 ) 

159 

160 retrieval_chain = RunnableParallel( 

161 context=self.retriever_runnable, 

162 question=RunnablePassthrough() 

163 ) 

164 if self.summarizer is not None: 

165 retrieval_chain = retrieval_chain | self.summarizer 

166 

167 rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs) 

168 return rag_chain_with_source 

169 

170 async def ainvoke(self, input_dict: dict) -> dict: 

171 """Async invocation of the RAG pipeline.""" 

172 chain = self.with_returned_sources() 

173 return await chain.ainvoke(input_dict) 

174 

175 def invoke(self, input_dict: dict) -> dict: 

176 """Sync invocation of the RAG pipeline.""" 

177 import asyncio 

178 return asyncio.run(self.ainvoke(input_dict)) 

179 

180 @classmethod 

181 def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any: 

182 """Apply search kwargs and search type to the retriever if they exist""" 

183 if hasattr(retriever, 'search_kwargs') and search_kwargs: 

184 # Convert search kwargs to dict, excluding None values 

185 kwargs_dict = search_kwargs.model_dump(exclude_none=True) 

186 

187 # Only include relevant parameters based on search type 

188 if search_type == SearchType.SIMILARITY: 

189 # Remove MMR and similarity threshold specific params 

190 kwargs_dict.pop('fetch_k', None) 

191 kwargs_dict.pop('lambda_mult', None) 

192 kwargs_dict.pop('score_threshold', None) 

193 elif search_type == SearchType.MMR: 

194 # Remove similarity threshold specific params 

195 kwargs_dict.pop('score_threshold', None) 

196 elif search_type == SearchType.SIMILARITY_SCORE_THRESHOLD: 

197 # Remove MMR specific params 

198 kwargs_dict.pop('fetch_k', None) 

199 kwargs_dict.pop('lambda_mult', None) 

200 

201 retriever.search_kwargs.update(kwargs_dict) 

202 

203 # Set search type if supported by the retriever 

204 if hasattr(retriever, 'search_type') and search_type: 

205 retriever.search_type = search_type.value 

206 

207 return retriever 

208 

209 @classmethod 

210 def from_retriever(cls, config: RAGPipelineModel): 

211 """ 

212 Builds a RAG pipeline with returned sources using a simple vector store retriever 

213 :param config: RAGPipelineModel 

214 :return: 

215 """ 

216 vector_store_operator = VectorStoreOperator( 

217 vector_store=config.vector_store, 

218 documents=config.documents, 

219 embedding_model=config.embedding_model, 

220 vector_store_config=config.vector_store_config 

221 ) 

222 retriever = vector_store_operator.vector_store.as_retriever() 

223 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type) 

224 

225 return cls( 

226 retriever, 

227 config.rag_prompt_template, 

228 config.llm, 

229 vector_store_config=config.vector_store_config, 

230 reranker=config.reranker, 

231 reranker_config=config.reranker_config, 

232 summarization_config=config.summarization_config 

233 ) 

234 

235 @classmethod 

236 def from_auto_retriever(cls, config: RAGPipelineModel): 

237 if not config.retriever_prompt_template: 

238 config.retriever_prompt_template = DEFAULT_AUTO_META_PROMPT_TEMPLATE 

239 

240 retriever = AutoRetriever(config=config).as_runnable() 

241 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type) 

242 return cls( 

243 retriever, 

244 config.rag_prompt_template, 

245 config.llm, 

246 reranker_config=config.reranker_config, 

247 reranker=config.reranker, 

248 vector_store_config=config.vector_store_config, 

249 summarization_config=config.summarization_config 

250 ) 

251 

252 @classmethod 

253 def from_multi_vector_retriever(cls, config: RAGPipelineModel): 

254 retriever = MultiVectorRetriever(config=config).as_runnable() 

255 retriever = cls._apply_search_kwargs(retriever, config.search_kwargs, config.search_type) 

256 return cls( 

257 retriever, 

258 config.rag_prompt_template, 

259 config.llm, 

260 reranker_config=config.reranker_config, 

261 reranker=config.reranker, 

262 vector_store_config=config.vector_store_config, 

263 summarization_config=config.summarization_config 

264 ) 

265 

266 @classmethod 

267 def from_sql_retriever(cls, config: RAGPipelineModel): 

268 retriever_config = config.sql_retriever_config 

269 if retriever_config is None: 

270 raise ValueError('Must provide "sql_retriever_config" for RAG pipeline config') 

271 vector_store_config = config.vector_store_config 

272 knowledge_base_table = vector_store_config.kb_table if vector_store_config is not None else None 

273 if knowledge_base_table is None: 

274 raise ValueError('Must provide valid "vector_store_config" for RAG pipeline config') 

275 embedding_args = knowledge_base_table._kb.embedding_model.learn_args.get('using', {}) 

276 embeddings = construct_model_from_args(embedding_args) 

277 sql_llm = create_chat_model({ 

278 'model_name': retriever_config.llm_config.model_name, 

279 'provider': retriever_config.llm_config.provider, 

280 **retriever_config.llm_config.params 

281 }) 

282 vector_store_operator = VectorStoreOperator( 

283 vector_store=config.vector_store, 

284 documents=config.documents, 

285 embedding_model=config.embedding_model, 

286 vector_store_config=config.vector_store_config 

287 ) 

288 vector_store_retriever = vector_store_operator.vector_store.as_retriever() 

289 vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type) 

290 distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE 

291 if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None: 

292 # Use negative dot product for sparse retrieval. 

293 distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT 

294 retriever = SQLRetriever( 

295 fallback_retriever=vector_store_retriever, 

296 vector_store_handler=knowledge_base_table.get_vector_db(), 

297 min_k=retriever_config.min_k, 

298 max_filters=retriever_config.max_filters, 

299 filter_threshold=retriever_config.filter_threshold, 

300 database_schema=retriever_config.database_schema, 

301 embeddings_model=embeddings, 

302 search_kwargs=config.search_kwargs, 

303 rewrite_prompt_template=retriever_config.rewrite_prompt_template, 

304 table_prompt_template=retriever_config.table_prompt_template, 

305 column_prompt_template=retriever_config.column_prompt_template, 

306 value_prompt_template=retriever_config.value_prompt_template, 

307 boolean_system_prompt=retriever_config.boolean_system_prompt, 

308 generative_system_prompt=retriever_config.generative_system_prompt, 

309 num_retries=retriever_config.num_retries, 

310 embeddings_table=knowledge_base_table._kb.vector_database_table, 

311 source_table=retriever_config.source_table, 

312 source_id_column=retriever_config.source_id_column, 

313 distance_function=distance_function, 

314 llm=sql_llm 

315 ) 

316 return cls( 

317 retriever, 

318 config.rag_prompt_template, 

319 config.llm, 

320 reranker_config=config.reranker_config, 

321 reranker=config.reranker, 

322 vector_store_config=config.vector_store_config, 

323 summarization_config=config.summarization_config 

324 )