Coverage for mindsdb / integrations / utilities / rag / chains / map_reduce_summarizer_chain.py: 26%

136 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import asyncio 

2from collections import namedtuple 

3from typing import Any, Dict, List, Optional 

4 

5from mindsdb.interfaces.agents.langchain_agent import create_chat_model 

6from langchain.chains.base import Chain 

7from langchain.chains.combine_documents.stuff import StuffDocumentsChain 

8from langchain.chains.llm import LLMChain 

9from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain 

10from langchain_core.callbacks import dispatch_custom_event 

11from langchain_core.callbacks.manager import CallbackManagerForChainRun 

12from langchain_core.documents import Document 

13from langchain_core.prompts import PromptTemplate 

14from pandas import DataFrame 

15 

16from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler 

17from mindsdb.integrations.utilities.rag.settings import SummarizationConfig 

18from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator 

19from mindsdb.utilities import log 

20 

21logger = log.getLogger(__name__) 

22 

23 

# Lightweight record pairing a source document's ID with its generated summary text.
Summary = namedtuple('Summary', ['source_id', 'content'])

25 

26 

def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
    '''Creates a chain that reduces a document's chunks into a single consolidated summary.

    Args:
        summarization_config (SummarizationConfig): Configuration for how to perform summarization
        input (str): The user question; substituted into prompt templates that declare an 'input' variable

    Returns:
        chain (ReduceDocumentsChain): Chain that combines & iteratively reduces documents into one summary.
    '''
    summarization_llm = create_chat_model({
        'model_name': summarization_config.llm_config.model_name,
        'provider': summarization_config.llm_config.provider,
        **summarization_config.llm_config.params
    })

    # NOTE: the per-chunk "map" step is currently disabled — chunks are stuffed and
    # reduced directly — so summarization_config.map_prompt_template is not used here.

    reduce_prompt = PromptTemplate.from_template(summarization_config.reduce_prompt_template)
    # Langchain needs a template with only a variable for docs so we bind the question via a partial.
    if 'input' in reduce_prompt.input_variables:
        reduce_prompt = reduce_prompt.partial(input=input)
    # Combines summarizations from multiple chunks into a consolidated summary.
    reduce_chain = LLMChain(llm=summarization_llm, prompt=reduce_prompt)

    # Takes a list of docs, combines them into a single string, then passes to an LLMChain.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='docs'
    )

    # Combines & iteratively reduces documents until they fit within the token limit.
    return ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        # Max number of tokens to group documents into.
        token_max=summarization_config.max_summarization_tokens
    )

70 

71 

class MapReduceSummarizerChain(Chain):
    '''Chain to summarize the source documents for document chunks & return as context.

    For every retrieved context chunk, fetches ALL chunks belonging to the same source
    document from the vector store, reconstructs the document in original chunk order,
    summarizes it with a map-reduce documents chain, and attaches the resulting summary
    to each chunk's metadata under the 'summary' key. Progress is streamed to callbacks
    via 'summary_begin' / 'summary' / 'summary_end' custom events.
    '''

    # Keys used to read inputs from / write outputs to the chain's input dict.
    context_key: str = 'context'
    metadata_key: str = 'metadata'
    doc_id_key: str = 'original_row_id'
    question_key: str = 'question'

    # Vector store access configuration.
    vector_store_handler: VectorStoreHandler
    table_name: str = 'embeddings'
    id_column_name: str = 'id'
    content_column_name: str = 'content'
    metadata_column_name: str = 'metadata'

    summarization_config: SummarizationConfig
    # Optional pre-built chain; built lazily from summarization_config when None.
    map_reduce_documents_chain: Optional[MapReduceDocumentsChain] = None

    @property
    def input_keys(self) -> List[str]:
        return [self.context_key, self.question_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.context_key, self.question_key]

    def _get_document_ids_from_chunks(self, chunks: List[Document]) -> List[str]:
        '''Extracts the ordered, de-duplicated list of source document IDs from chunk metadata.

        Args:
            chunks (List[Document]): Retrieved context chunks whose metadata may carry doc_id_key.

        Returns:
            List[str]: Unique document IDs in first-seen order (empty IDs are skipped).
        '''
        unique_document_ids = set()
        document_ids = []
        logger.debug(f"Processing {len(chunks)} chunks to extract document IDs")
        for chunk in chunks:
            if not chunk.metadata:
                chunk.metadata = {}
                logger.warning("Chunk metadata was empty, creating new metadata dictionary")
            metadata = chunk.metadata
            doc_id = str(metadata.get(self.doc_id_key, ''))
            logger.debug(f"Processing chunk with metadata: {metadata}, extracted doc_id: {doc_id}")
            if doc_id and doc_id not in unique_document_ids:
                # Sets in Python don't guarantee preserved order, so we use a list to make testing easier.
                document_ids.append(doc_id)
                unique_document_ids.add(doc_id)
        logger.debug(f"Found {len(document_ids)} unique document IDs: {document_ids}")
        return document_ids

    def _select_chunks_from_vector_store(self, conditions: List[FilterCondition]) -> DataFrame:
        '''Runs a blocking SELECT of content + metadata columns against the vector store.'''
        return self.vector_store_handler.select(
            self.table_name,
            columns=[self.content_column_name, self.metadata_column_name],
            conditions=conditions
        )

    async def _get_all_chunks_for_document(self, document_id: str) -> List[Document]:
        '''Fetches every chunk of a source document and returns them in original order.

        Args:
            document_id (str): Value of doc_id_key identifying the source document.

        Returns:
            List[Document]: The document's chunks sorted by their 'chunk_index' metadata.
        '''
        logger.debug(f"Fetching all chunks for document ID: {document_id}")
        id_filter_condition = FilterCondition(
            f"{self.metadata_column_name}->>'{self.doc_id_key}'",
            FilterOperator.EQUAL,
            document_id
        )
        # The vector store select is blocking, so run it in the default executor to
        # avoid stalling the event loop while summaries for other documents proceed.
        all_source_chunks = await asyncio.get_running_loop().run_in_executor(
            None, self._select_chunks_from_vector_store, [id_filter_condition]
        )
        document_chunks = []
        for _, row in all_source_chunks.iterrows():
            metadata = row.get(self.metadata_column_name, {})
            if row.get('chunk_id', None) is not None:
                metadata['chunk_index'] = row.get('chunk_id', 0)
            document_chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
        # Sort by chunk index if present in metadata so the full document is in its original order.
        document_chunks.sort(key=lambda doc: doc.metadata.get('chunk_index', 0) if doc.metadata else 0)
        logger.debug(f"Found {len(document_chunks)} chunks for document ID {document_id}")
        return document_chunks

    async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary:
        '''Summarizes one source document and streams the result as a 'summary' event.

        Args:
            source_id (str): Document ID to summarize; empty IDs yield an empty Summary.
            map_reduce_documents_chain (MapReduceDocumentsChain): Chain that produces the summary.

        Returns:
            Summary: (source_id, summary text); content is '' when no chunks are found.
        '''
        if not source_id:
            logger.warning("Received empty source_id, returning empty summary")
            return Summary(source_id='', content='')

        logger.debug(f"Getting summary for source ID: {source_id}")
        source_chunks = await self._get_all_chunks_for_document(source_id)

        if not source_chunks:
            logger.warning(f"No chunks found for source ID: {source_id}")
            return Summary(source_id=source_id, content='')

        logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}")
        summary = await map_reduce_documents_chain.ainvoke(source_chunks)
        content = summary.get('output_text', '')
        logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...")

        # Stream summarization update.
        dispatch_custom_event('summary', {'source_id': source_id, 'content': content})

        return Summary(source_id=source_id, content=content)

    async def _get_source_summaries(self, source_ids: List[str], map_reduce_documents_chain: MapReduceDocumentsChain) -> List[Summary]:
        '''Summarizes all source documents concurrently.'''
        summaries = await asyncio.gather(
            *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids]
        )
        return summaries

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, Any]:
        '''Attaches a per-source-document summary to each context chunk's metadata.

        Args:
            inputs (Dict[str, Any]): Must contain context_key (List[Document]) and question_key (str).
            run_manager (Optional[CallbackManagerForChainRun]): Unused; accepted for the Chain interface.

        Returns:
            Dict[str, Any]: The same inputs dict, with each chunk's metadata['summary'] populated.
        '''
        # Explicitly connect to make sure vectors are registered.
        _ = self.vector_store_handler.connect()
        logger.debug(f"Processing inputs with keys: {list(inputs.keys())}")
        context_chunks = inputs.get(self.context_key, [])
        logger.debug(f"Found {len(context_chunks)} context chunks")

        unique_document_ids = self._get_document_ids_from_chunks(context_chunks)
        logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs")

        question = inputs.get(self.question_key, '')
        map_reduce_documents_chain = self.map_reduce_documents_chain
        if map_reduce_documents_chain is None:
            map_reduce_documents_chain = create_map_reduce_documents_chain(self.summarization_config, question)
        # For each document ID associated with one or more chunks, build the full document by
        # getting ALL chunks associated with that ID. Then, map reduce summarize the complete document.
        dispatch_custom_event('summary_begin', {'num_documents': len(unique_document_ids)})
        try:
            logger.debug("Starting async summary generation")
            summaries = asyncio.get_event_loop().run_until_complete(self._get_source_summaries(unique_document_ids, map_reduce_documents_chain))
        except RuntimeError:
            # get_event_loop/run_until_complete raise RuntimeError when no usable loop
            # exists in this thread, so fall back to creating and installing a fresh one.
            logger.info("No event loop available, creating new one")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            summaries = loop.run_until_complete(self._get_source_summaries(unique_document_ids, map_reduce_documents_chain))

        source_id_to_summary = {}
        for summary in summaries:
            source_id_to_summary[summary.source_id] = summary.content
        logger.debug(f"Generated {len(source_id_to_summary)} summaries")

        # Update context chunks with document summaries.
        for chunk in context_chunks:
            if not chunk.metadata:
                chunk.metadata = {}
                logger.warning("Chunk metadata was empty, creating new metadata dictionary")

            metadata = chunk.metadata
            doc_id = str(metadata.get(self.doc_id_key, ''))
            logger.debug(f"Updating chunk with doc_id {doc_id}")
            if doc_id in source_id_to_summary:
                chunk.metadata['summary'] = source_id_to_summary[doc_id]
            else:
                logger.warning(f"No summary found for doc_id: {doc_id}")
                chunk.metadata['summary'] = ''

        # Stream summarization update.
        dispatch_custom_event('summary_end', {'num_documents': len(source_id_to_summary)})

        return inputs