Coverage for mindsdb / integrations / utilities / rag / chains / local_context_summarizer_chain.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import asyncio 

2from collections import namedtuple 

3from typing import Any, Dict, List, Optional 

4 

5from mindsdb.interfaces.agents.langchain_agent import create_chat_model 

6from langchain.chains.base import Chain 

7from langchain.chains.combine_documents.stuff import StuffDocumentsChain 

8from langchain.chains.llm import LLMChain 

9from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain 

10from langchain_core.callbacks import dispatch_custom_event 

11from langchain_core.callbacks.manager import CallbackManagerForChainRun 

12from langchain_core.documents import Document 

13from langchain_core.prompts import PromptTemplate 

14from pandas import DataFrame 

15 

16from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler 

17from mindsdb.integrations.utilities.rag.settings import SummarizationConfig 

18from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator 

19from mindsdb.utilities import log 

20 

21logger = log.getLogger(__name__) 

22 

# Lightweight result record pairing a source document ID with the summary
# text generated for it.
Summary = namedtuple('Summary', ['source_id', 'content'])

24 

25 

def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
    """Creates a chain that map-reduces documents into a single consolidated summary.

    Args:
        summarization_config (SummarizationConfig): LLM settings, reduce prompt
            template, and token budget for the summarization chain.
        input (str): Text substituted into the prompt's optional 'input'
            variable when the template declares one.

    Returns:
        ReduceDocumentsChain: Chain that stuffs documents into the reduce
            prompt, collapsing recursively when the token budget is exceeded.
    """
    llm_params = {
        'model_name': summarization_config.llm_config.model_name,
        'provider': summarization_config.llm_config.provider,
    }
    # Extra provider params are applied last so they can override the basics.
    llm_params.update(summarization_config.llm_config.params)
    summarization_llm = create_chat_model(llm_params)

    prompt = PromptTemplate.from_template(summarization_config.reduce_prompt_template)
    # Pre-fill the optional 'input' placeholder if the template uses it.
    if 'input' in prompt.input_variables:
        prompt = prompt.partial(input=input)

    stuff_documents_chain = StuffDocumentsChain(
        llm_chain=LLMChain(llm=summarization_llm, prompt=prompt),
        document_variable_name='docs'
    )

    # The same stuff chain both combines the final set of documents and
    # collapses intermediate batches that exceed the token budget.
    return ReduceDocumentsChain(
        combine_documents_chain=stuff_documents_chain,
        collapse_documents_chain=stuff_documents_chain,
        token_max=summarization_config.max_summarization_tokens
    )

51 

52 

class LocalContextSummarizerChain(Chain):
    """Summarizes M chunks before and after a given chunk in a document."""

    # Metadata keys identifying a chunk's parent document and its position
    # within that document.
    doc_id_key: str = 'original_row_id'
    chunk_index_key: str = 'chunk_index'

    # Input/output dict keys for the retrieved context documents and the
    # user question. These are read by input_keys/output_keys and _call but
    # were never declared, which raised AttributeError at runtime.
    context_key: str = 'context'
    question_key: str = 'question'

    # Number of neighboring chunks on each side to include in a local
    # summary window (read by _call, previously undeclared). Default of 2 is
    # a conservative window — TODO(review): confirm the intended default.
    neighbor_window: int = 2

    # Vector store access settings.
    vector_store_handler: VectorStoreHandler
    table_name: str = 'embeddings'
    content_column_name: str = 'content'
    metadata_column_name: str = 'metadata'

    # Summarization configuration and an optional pre-built chain; when left
    # as None a chain is created lazily from summarization_config.
    summarization_config: SummarizationConfig
    map_reduce_documents_chain: Optional[ReduceDocumentsChain] = None

66 

67 def _select_chunks_from_vector_store(self, doc_id: str) -> DataFrame: 

68 condition = FilterCondition( 

69 f"{self.metadata_column_name}->>'{self.doc_id_key}'", 

70 FilterOperator.EQUAL, 

71 doc_id 

72 ) 

73 return self.vector_store_handler.select( 

74 self.table_name, 

75 columns=[self.content_column_name, self.metadata_column_name], 

76 conditions=[condition] 

77 ) 

78 

79 async def _get_all_chunks_for_document(self, doc_id: str) -> List[Document]: 

80 df = await asyncio.get_event_loop().run_in_executor( 

81 None, self._select_chunks_from_vector_store, doc_id 

82 ) 

83 chunks = [] 

84 for _, row in df.iterrows(): 

85 metadata = row.get(self.metadata_column_name, {}) 

86 metadata[self.chunk_index_key] = row.get('chunk_id', 0) 

87 chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata)) 

88 

89 return sorted(chunks, key=lambda x: x.metadata.get(self.chunk_index_key, 0)) 

90 

91 async def summarize_local_context(self, doc_id: str, target_chunk_index: int, M: int) -> Summary: 

92 """ 

93 Summarizes M chunks before and after the given chunk. 

94 

95 Args: 

96 doc_id (str): Document ID. 

97 target_chunk_index (int): Index of the chunk to summarize around. 

98 M (int): Number of chunks before and after to include. 

99 

100 Returns: 

101 Summary: Summary object containing source_id and summary content. 

102 """ 

103 logger.debug(f"Fetching chunks for document {doc_id}") 

104 all_chunks = await self._get_all_chunks_for_document(doc_id) 

105 

106 if not all_chunks: 

107 logger.warning(f"No chunks found for document {doc_id}") 

108 return Summary(source_id=doc_id, content='') 

109 

110 # Determine window boundaries 

111 start_idx = max(0, target_chunk_index - M) 

112 end_idx = min(len(all_chunks), target_chunk_index + M + 1) 

113 local_chunks = all_chunks[start_idx:end_idx] 

114 

115 logger.debug(f"Summarizing chunks {start_idx} to {end_idx - 1} for document {doc_id}") 

116 

117 if not self.map_reduce_documents_chain: 

118 self.map_reduce_documents_chain = create_map_reduce_documents_chain( 

119 self.summarization_config, input="Summarize these chunks." 

120 ) 

121 

122 summary_result = await self.map_reduce_documents_chain.ainvoke(local_chunks) 

123 summary_text = summary_result.get('output_text', '') 

124 

125 logger.debug(f"Generated summary: {summary_text[:100]}...") 

126 

127 return Summary(source_id=doc_id, content=summary_text) 

128 

    @property
    def input_keys(self) -> List[str]:
        # The chain consumes the retrieved context documents and the question.
        return [self.context_key, self.question_key]

132 

    @property
    def output_keys(self) -> List[str]:
        # Same keys as the input: _call enriches the context documents'
        # metadata in place and returns the inputs dict unchanged.
        return [self.context_key, self.question_key]

136 

137 async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary: 

138 if not source_id: 

139 logger.warning("Received empty source_id, returning empty summary") 

140 return Summary(source_id='', content='') 

141 

142 logger.debug(f"Getting summary for source ID: {source_id}") 

143 source_chunks = await self._get_all_chunks_for_document(source_id) 

144 

145 if not source_chunks: 

146 logger.warning(f"No chunks found for source ID: {source_id}") 

147 return Summary(source_id=source_id, content='') 

148 

149 logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}") 

150 summary = await map_reduce_documents_chain.ainvoke(source_chunks) 

151 content = summary.get('output_text', '') 

152 logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...") 

153 

154 # Stream summarization update. 

155 dispatch_custom_event('summary', {'source_id': source_id, 'content': content}) 

156 

157 return Summary(source_id=source_id, content=content) 

158 

159 async def _get_source_summaries(self, source_ids: List[str], map_reduce_documents_chain: MapReduceDocumentsChain) -> \ 

160 List[Summary]: 

161 summaries = await asyncio.gather( 

162 *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids] 

163 ) 

164 return summaries 

165 

166 def _call( 

167 self, 

168 inputs: Dict[str, Any], 

169 run_manager: Optional[CallbackManagerForChainRun] = None 

170 ) -> Dict[str, Any]: 

171 # Step 1: Connect to vector store to ensure embeddings are accessible 

172 self.vector_store_handler.connect() 

173 

174 context_chunks: List[Document] = inputs.get(self.context_key, []) 

175 logger.debug(f"Found {len(context_chunks)} context chunks.") 

176 

177 # Step 2: Extract unique document IDs from the provided chunks 

178 unique_document_ids = self._get_document_ids_from_chunks(context_chunks) 

179 logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs: {unique_document_ids}") 

180 

181 # Step 3: Initialize the summarization chain if not provided 

182 question = inputs.get(self.question_key, '') 

183 map_reduce_documents_chain = self.map_reduce_documents_chain or create_map_reduce_documents_chain( 

184 self.summarization_config, question 

185 ) 

186 

187 # Step 4: Dispatch event to signal summarization start 

188 if run_manager: 

189 run_manager.on_text("Starting summarization for documents.", verbose=True) 

190 

191 # Step 5: Process each document ID to summarize chunks with local context 

192 for doc_id in unique_document_ids: 

193 logger.debug(f"Fetching and summarizing chunks for document ID: {doc_id}") 

194 

195 # Fetch all chunks for the document 

196 chunks = asyncio.get_event_loop().run_until_complete(self._get_all_chunks_for_document(doc_id)) 

197 if not chunks: 

198 logger.warning(f"No chunks found for document ID: {doc_id}") 

199 continue 

200 

201 # Summarize each chunk with M neighboring chunks 

202 M = self.neighbor_window 

203 for i, chunk in enumerate(chunks): 

204 window_chunks = chunks[max(0, i - M): min(len(chunks), i + M + 1)] 

205 local_summary = asyncio.get_event_loop().run_until_complete( 

206 map_reduce_documents_chain.ainvoke(window_chunks) 

207 ) 

208 chunk.metadata['summary'] = local_summary.get('output_text', '') 

209 logger.debug(f"Chunk {i} summary: {chunk.metadata['summary'][:100]}...") 

210 

211 # Step 6: Update the original context chunks with the newly generated summaries 

212 for chunk in context_chunks: 

213 doc_id = str(chunk.metadata.get(self.doc_id_key, '')) 

214 matching_chunk = next((c for c in chunks if c.metadata.get(self.doc_id_key) == doc_id and c.metadata.get( 

215 'chunk_index') == chunk.metadata.get('chunk_index')), None) 

216 if matching_chunk: 

217 chunk.metadata['summary'] = matching_chunk.metadata.get('summary', '') 

218 else: 

219 chunk.metadata['summary'] = '' 

220 logger.warning(f"No matching chunk found for doc_id: {doc_id}") 

221 

222 # Step 7: Signal summarization end 

223 if run_manager: 

224 run_manager.on_text("Summarization completed.", verbose=True) 

225 

226 logger.debug(f"Updated {len(context_chunks)} context chunks with summaries.") 

227 return inputs