Coverage for mindsdb / integrations / utilities / rag / chains / local_context_summarizer_chain.py: 0%
121 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import asyncio
2from collections import namedtuple
3from typing import Any, Dict, List, Optional
5from mindsdb.interfaces.agents.langchain_agent import create_chat_model
6from langchain.chains.base import Chain
7from langchain.chains.combine_documents.stuff import StuffDocumentsChain
8from langchain.chains.llm import LLMChain
9from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain
10from langchain_core.callbacks import dispatch_custom_event
11from langchain_core.callbacks.manager import CallbackManagerForChainRun
12from langchain_core.documents import Document
13from langchain_core.prompts import PromptTemplate
14from pandas import DataFrame
16from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler
17from mindsdb.integrations.utilities.rag.settings import SummarizationConfig
18from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
19from mindsdb.utilities import log
# Module-level logger following the project's `log.getLogger(__name__)` convention.
logger = log.getLogger(__name__)

# Lightweight record pairing a source document ID with its generated summary text.
Summary = namedtuple('Summary', ['source_id', 'content'])
def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
    """Build a chain that reduces a list of documents into one consolidated summary.

    Args:
        input (str): Value substituted for the prompt's ``input`` variable,
            when the configured reduce template declares one.

    Returns:
        ReduceDocumentsChain: Stuffs documents into the reduce prompt, collapsing
        recursively when they exceed the configured token budget.
    """
    llm = create_chat_model({
        'model_name': summarization_config.llm_config.model_name,
        'provider': summarization_config.llm_config.provider,
        **summarization_config.llm_config.params
    })

    prompt = PromptTemplate.from_template(summarization_config.reduce_prompt_template)
    # Pre-bind the caller-supplied input when the template expects it.
    if 'input' in prompt.input_variables:
        prompt = prompt.partial(input=input)

    stuff_chain = StuffDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=prompt),
        document_variable_name='docs'
    )

    # The same stuffing chain handles both the final combine and any
    # intermediate collapse passes.
    return ReduceDocumentsChain(
        combine_documents_chain=stuff_chain,
        collapse_documents_chain=stuff_chain,
        token_max=summarization_config.max_summarization_tokens
    )
class LocalContextSummarizerChain(Chain):
    """Summarizes M chunks before and after a given chunk in a document.

    For every document referenced by the input context chunks, each chunk is
    summarized together with its ``neighbor_window`` neighbors on each side,
    and the result is written into the chunk's metadata under ``'summary'``.
    The inputs dict is returned with its chunks annotated in place.
    """

    # Dict keys used to read the context chunks and the user question from the
    # chain inputs. NOTE(review): these were referenced via self.context_key /
    # self.question_key but never declared (pydantic would raise); declared
    # here with conventional defaults — confirm against callers.
    context_key: str = 'context'
    question_key: str = 'question'

    # Metadata keys identifying a chunk's parent document and its position.
    doc_id_key: str = 'original_row_id'
    chunk_index_key: str = 'chunk_index'

    # Vector store access configuration.
    vector_store_handler: VectorStoreHandler
    table_name: str = 'embeddings'
    content_column_name: str = 'content'
    metadata_column_name: str = 'metadata'

    summarization_config: SummarizationConfig
    # Optional pre-built summarization chain; built lazily when absent.
    map_reduce_documents_chain: Optional[ReduceDocumentsChain] = None

    # Number of neighboring chunks before AND after each chunk to include in
    # its local-context window. NOTE(review): previously read as
    # self.neighbor_window without being declared; default chosen here —
    # confirm intended window size.
    neighbor_window: int = 2

    def _select_chunks_from_vector_store(self, doc_id: str) -> DataFrame:
        """Fetch all embedding rows whose metadata doc-id matches ``doc_id``."""
        condition = FilterCondition(
            f"{self.metadata_column_name}->>'{self.doc_id_key}'",
            FilterOperator.EQUAL,
            doc_id
        )
        return self.vector_store_handler.select(
            self.table_name,
            columns=[self.content_column_name, self.metadata_column_name],
            conditions=[condition]
        )

    async def _get_all_chunks_for_document(self, doc_id: str) -> List[Document]:
        """Load every chunk of a document, sorted by chunk index.

        The blocking vector-store query runs in the default executor so the
        event loop is not blocked.
        """
        df = await asyncio.get_event_loop().run_in_executor(
            None, self._select_chunks_from_vector_store, doc_id
        )
        chunks = []
        for _, row in df.iterrows():
            metadata = row.get(self.metadata_column_name, {})
            # The stored 'chunk_id' column provides the ordering index.
            metadata[self.chunk_index_key] = row.get('chunk_id', 0)
            chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))

        return sorted(chunks, key=lambda x: x.metadata.get(self.chunk_index_key, 0))

    async def summarize_local_context(self, doc_id: str, target_chunk_index: int, M: int) -> Summary:
        """
        Summarizes M chunks before and after the given chunk.

        Args:
            doc_id (str): Document ID.
            target_chunk_index (int): Index of the chunk to summarize around.
            M (int): Number of chunks before and after to include.

        Returns:
            Summary: Summary object containing source_id and summary content.
        """
        logger.debug(f"Fetching chunks for document {doc_id}")
        all_chunks = await self._get_all_chunks_for_document(doc_id)

        if not all_chunks:
            logger.warning(f"No chunks found for document {doc_id}")
            return Summary(source_id=doc_id, content='')

        # Determine window boundaries, clamped to the document's extent.
        start_idx = max(0, target_chunk_index - M)
        end_idx = min(len(all_chunks), target_chunk_index + M + 1)
        local_chunks = all_chunks[start_idx:end_idx]

        logger.debug(f"Summarizing chunks {start_idx} to {end_idx - 1} for document {doc_id}")

        # Build and cache the summarization chain lazily.
        if not self.map_reduce_documents_chain:
            self.map_reduce_documents_chain = create_map_reduce_documents_chain(
                self.summarization_config, input="Summarize these chunks."
            )

        summary_result = await self.map_reduce_documents_chain.ainvoke(local_chunks)
        summary_text = summary_result.get('output_text', '')

        logger.debug(f"Generated summary: {summary_text[:100]}...")

        return Summary(source_id=doc_id, content=summary_text)

    @property
    def input_keys(self) -> List[str]:
        return [self.context_key, self.question_key]

    @property
    def output_keys(self) -> List[str]:
        # Inputs are passed through with chunk metadata annotated in place.
        return [self.context_key, self.question_key]

    def _get_document_ids_from_chunks(self, chunks: List[Document]) -> List[str]:
        """Return unique document IDs (as strings, order-preserving) from chunk metadata.

        NOTE(review): this helper was called by _call but missing from the
        class; implemented from its call site.
        """
        unique_ids: List[str] = []
        for chunk in chunks:
            doc_id = str(chunk.metadata.get(self.doc_id_key, ''))
            if doc_id and doc_id not in unique_ids:
                unique_ids.append(doc_id)
        return unique_ids

    async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary:
        """Summarize all chunks of one source document and stream the result as a custom event."""
        if not source_id:
            logger.warning("Received empty source_id, returning empty summary")
            return Summary(source_id='', content='')

        logger.debug(f"Getting summary for source ID: {source_id}")
        source_chunks = await self._get_all_chunks_for_document(source_id)

        if not source_chunks:
            logger.warning(f"No chunks found for source ID: {source_id}")
            return Summary(source_id=source_id, content='')

        logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}")
        summary = await map_reduce_documents_chain.ainvoke(source_chunks)
        content = summary.get('output_text', '')
        logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...")

        # Stream summarization update.
        dispatch_custom_event('summary', {'source_id': source_id, 'content': content})

        return Summary(source_id=source_id, content=content)

    async def _get_source_summaries(self, source_ids: List[str],
                                    map_reduce_documents_chain: MapReduceDocumentsChain) -> List[Summary]:
        """Summarize multiple sources concurrently."""
        summaries = await asyncio.gather(
            *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids]
        )
        return summaries

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, Any]:
        """Annotate each context chunk with a summary of its local neighborhood.

        Args:
            inputs: Chain inputs; ``self.context_key`` maps to a list of
                Documents and ``self.question_key`` to the user question.
            run_manager: Optional callback manager for progress text events.

        Returns:
            The same ``inputs`` dict, with each context chunk's metadata
            carrying a ``'summary'`` entry ('' when no match was found).
        """
        # Step 1: Connect to vector store to ensure embeddings are accessible
        self.vector_store_handler.connect()

        context_chunks: List[Document] = inputs.get(self.context_key, [])
        logger.debug(f"Found {len(context_chunks)} context chunks.")

        # Step 2: Extract unique document IDs from the provided chunks
        unique_document_ids = self._get_document_ids_from_chunks(context_chunks)
        logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs: {unique_document_ids}")

        # Step 3: Initialize the summarization chain if not provided
        question = inputs.get(self.question_key, '')
        map_reduce_documents_chain = self.map_reduce_documents_chain or create_map_reduce_documents_chain(
            self.summarization_config, question
        )

        # Step 4: Dispatch event to signal summarization start
        if run_manager:
            run_manager.on_text("Starting summarization for documents.", verbose=True)

        loop = asyncio.get_event_loop()
        M = self.neighbor_window

        # (doc_id, chunk_index) -> summary text, accumulated across ALL
        # documents. Fixes a bug where step 6 consulted only the leaked loop
        # variable `chunks`, i.e. the LAST processed document's chunks.
        summaries_by_key: Dict[Any, str] = {}

        # Step 5: Process each document ID to summarize chunks with local context
        for doc_id in unique_document_ids:
            logger.debug(f"Fetching and summarizing chunks for document ID: {doc_id}")

            # Fetch all chunks for the document
            chunks = loop.run_until_complete(self._get_all_chunks_for_document(doc_id))
            if not chunks:
                logger.warning(f"No chunks found for document ID: {doc_id}")
                continue

            # Summarize each chunk together with M neighbors on each side
            for i, chunk in enumerate(chunks):
                window_chunks = chunks[max(0, i - M): min(len(chunks), i + M + 1)]
                local_summary = loop.run_until_complete(
                    map_reduce_documents_chain.ainvoke(window_chunks)
                )
                summary_text = local_summary.get('output_text', '')
                chunk.metadata['summary'] = summary_text
                summaries_by_key[(doc_id, chunk.metadata.get(self.chunk_index_key))] = summary_text
                logger.debug(f"Chunk {i} summary: {summary_text[:100]}...")

        # Step 6: Update the original context chunks with the newly generated summaries
        for chunk in context_chunks:
            doc_id = str(chunk.metadata.get(self.doc_id_key, ''))
            key = (doc_id, chunk.metadata.get(self.chunk_index_key))
            summary_text = summaries_by_key.get(key)
            if summary_text is not None:
                chunk.metadata['summary'] = summary_text
            else:
                chunk.metadata['summary'] = ''
                logger.warning(f"No matching chunk found for doc_id: {doc_id}")

        # Step 7: Signal summarization end
        if run_manager:
            run_manager.on_text("Summarization completed.", verbose=True)

        logger.debug(f"Updated {len(context_chunks)} context chunks with summaries.")
        return inputs