Coverage for mindsdb / integrations / utilities / rag / chains / map_reduce_summarizer_chain.py: 26%
136 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import asyncio
2from collections import namedtuple
3from typing import Any, Dict, List, Optional
5from mindsdb.interfaces.agents.langchain_agent import create_chat_model
6from langchain.chains.base import Chain
7from langchain.chains.combine_documents.stuff import StuffDocumentsChain
8from langchain.chains.llm import LLMChain
9from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain, ReduceDocumentsChain
10from langchain_core.callbacks import dispatch_custom_event
11from langchain_core.callbacks.manager import CallbackManagerForChainRun
12from langchain_core.documents import Document
13from langchain_core.prompts import PromptTemplate
14from pandas import DataFrame
16from mindsdb.integrations.libs.vectordatabase_handler import VectorStoreHandler
17from mindsdb.integrations.utilities.rag.settings import SummarizationConfig
18from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
19from mindsdb.utilities import log
# Module-level logger for this summarization chain.
logger = log.getLogger(__name__)

# Lightweight record pairing a source document ID with its generated summary text.
Summary = namedtuple('Summary', ['source_id', 'content'])
def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
    '''Creates a chain that iteratively reduces document summaries into a single consolidated summary.

    Args:
        summarization_config (SummarizationConfig): Configuration for how to perform summarization
        input (str): The user question, substituted into the reduce prompt template when it
            declares an 'input' variable

    Returns:
        chain (ReduceDocumentsChain): Chain that combines & iteratively reduces documents
            into a single summary, collapsing groups of at most
            `summarization_config.max_summarization_tokens` tokens at a time.
    '''
    summarization_llm = create_chat_model({
        'model_name': summarization_config.llm_config.model_name,
        'provider': summarization_config.llm_config.provider,
        **summarization_config.llm_config.params
    })

    reduce_prompt_template = summarization_config.reduce_prompt_template
    reduce_prompt = PromptTemplate.from_template(reduce_prompt_template)
    # Langchain needs a template with only a variable for docs so we use a partial
    # to pre-bind the user question.
    if 'input' in reduce_prompt.input_variables:
        reduce_prompt = reduce_prompt.partial(input=input)
    # Combines summarizations from multiple chunks into a consolidated summary.
    reduce_chain = LLMChain(llm=summarization_llm, prompt=reduce_prompt)

    # Takes a list of docs, combines them into a single string, then passes to an LLMChain.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='docs'
    )

    # Combines & iteratively reduces mapped documents.
    return ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        # Max number of tokens to group documents into.
        token_max=summarization_config.max_summarization_tokens
    )
class MapReduceSummarizerChain(Chain):
    '''Chain to summarize the source documents for document chunks & return as context.

    For each unique source document referenced by the retrieved context chunks, this chain
    fetches ALL of that document's chunks from the vector store, summarizes the full document
    with a map-reduce chain, and attaches the summary to each context chunk's metadata under
    the 'summary' key. Inputs are passed through unchanged otherwise.
    '''

    # Keys used to read values from the chain's input dict.
    context_key: str = 'context'
    metadata_key: str = 'metadata'
    # Metadata key on each chunk that holds the ID of the source document it came from.
    doc_id_key: str = 'original_row_id'
    question_key: str = 'question'

    # Vector store handler used to fetch all chunks belonging to a source document.
    vector_store_handler: VectorStoreHandler
    table_name: str = 'embeddings'
    id_column_name: str = 'id'
    content_column_name: str = 'content'
    metadata_column_name: str = 'metadata'

    summarization_config: SummarizationConfig
    # Optional pre-built chain; when None, one is created per call from summarization_config.
    map_reduce_documents_chain: Optional[MapReduceDocumentsChain] = None

    @property
    def input_keys(self) -> List[str]:
        '''Keys this chain expects in its input dict.'''
        return [self.context_key, self.question_key]

    @property
    def output_keys(self) -> List[str]:
        '''Keys this chain produces (inputs are passed through with enriched metadata).'''
        return [self.context_key, self.question_key]

    def _get_document_ids_from_chunks(self, chunks: List[Document]) -> List[str]:
        '''Extracts the unique source-document IDs referenced by the given chunks.

        Returns the IDs in first-seen order; chunks with missing/empty IDs are skipped.
        '''
        unique_document_ids = set()
        document_ids = []
        logger.debug(f"Processing {len(chunks)} chunks to extract document IDs")
        for chunk in chunks:
            if not chunk.metadata:
                # Normalize missing metadata to an empty dict so downstream access is safe.
                chunk.metadata = {}
                logger.warning("Chunk metadata was empty, creating new metadata dictionary")
            metadata = chunk.metadata
            doc_id = str(metadata.get(self.doc_id_key, ''))
            logger.debug(f"Processing chunk with metadata: {metadata}, extracted doc_id: {doc_id}")
            if doc_id and doc_id not in unique_document_ids:
                # Sets in Python don't guarantee preserved order, so we use a list to make testing easier.
                document_ids.append(doc_id)
                unique_document_ids.add(doc_id)
        logger.debug(f"Found {len(document_ids)} unique document IDs: {document_ids}")
        return document_ids

    def _select_chunks_from_vector_store(self, conditions: List[FilterCondition]) -> DataFrame:
        '''Selects content + metadata columns from the embeddings table matching `conditions`.'''
        return self.vector_store_handler.select(
            self.table_name,
            columns=[self.content_column_name, self.metadata_column_name],
            conditions=conditions
        )

    async def _get_all_chunks_for_document(self, id: str) -> List[Document]:
        '''Fetches every chunk of the source document `id`, ordered by chunk index.

        Args:
            id (str): Source-document ID to match against the metadata doc_id_key field.

        Returns:
            List[Document]: Chunks as Documents, sorted by 'chunk_index' when present.
        '''
        logger.debug(f"Fetching all chunks for document ID: {id}")
        # JSON path filter: metadata->>'<doc_id_key>' == id (assumes a Postgres-style
        # ->> operator in the vector store's filter syntax — TODO confirm for other backends).
        id_filter_condition = FilterCondition(
            f"{self.metadata_column_name}->>'{self.doc_id_key}'",
            FilterOperator.EQUAL,
            id
        )
        # The handler's select is synchronous; run it in a thread so we don't block the loop.
        all_source_chunks = await asyncio.get_event_loop().run_in_executor(None, self._select_chunks_from_vector_store, [id_filter_condition])
        document_chunks = []
        for _, row in all_source_chunks.iterrows():
            metadata = row.get(self.metadata_column_name, {})
            if row.get('chunk_id', None) is not None:
                metadata['chunk_index'] = row.get('chunk_id', 0)
            document_chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
        # Sort by chunk index if present in metadata so the full document is in its original order.
        document_chunks.sort(key=lambda doc: doc.metadata.get('chunk_index', 0) if doc.metadata else 0)
        logger.debug(f"Found {len(document_chunks)} chunks for document ID {id}")
        return document_chunks

    async def _get_source_summary(self, source_id: str, map_reduce_documents_chain: MapReduceDocumentsChain) -> Summary:
        '''Builds the full source document for `source_id` and summarizes it.

        Returns a Summary with empty content when the ID is empty or no chunks are found.
        Dispatches a 'summary' custom event so callers can stream progress.
        '''
        if not source_id:
            logger.warning("Received empty source_id, returning empty summary")
            return Summary(source_id='', content='')

        logger.debug(f"Getting summary for source ID: {source_id}")
        source_chunks = await self._get_all_chunks_for_document(source_id)

        if not source_chunks:
            logger.warning(f"No chunks found for source ID: {source_id}")
            return Summary(source_id=source_id, content='')

        logger.debug(f"Summarizing {len(source_chunks)} chunks for source ID: {source_id}")
        summary = await map_reduce_documents_chain.ainvoke(source_chunks)
        content = summary.get('output_text', '')
        logger.debug(f"Generated summary for source ID {source_id}: {content[:100]}...")

        # Stream summarization update.
        dispatch_custom_event('summary', {'source_id': source_id, 'content': content})

        return Summary(source_id=source_id, content=content)

    async def _get_source_summaries(self, source_ids: List[str], map_reduce_documents_chain: MapReduceDocumentsChain) -> List[Summary]:
        '''Summarizes all source documents concurrently (one task per source ID).'''
        summaries = await asyncio.gather(
            *[self._get_source_summary(source_id, map_reduce_documents_chain) for source_id in source_ids]
        )
        return summaries

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, Any]:
        '''Attaches a per-source-document summary to each context chunk's metadata.

        Args:
            inputs (Dict[str, Any]): Must contain `context_key` (list of Documents) and
                `question_key` (the user question used in summarization prompts).
            run_manager: Standard LangChain callback manager (unused here).

        Returns:
            Dict[str, Any]: The same `inputs` dict, with each chunk's metadata['summary'] set.
        '''
        # Explicitly connect to make sure vectors are registered.
        _ = self.vector_store_handler.connect()
        logger.debug(f"Processing inputs with keys: {list(inputs.keys())}")
        context_chunks = inputs.get(self.context_key, [])
        logger.debug(f"Found {len(context_chunks)} context chunks")

        unique_document_ids = self._get_document_ids_from_chunks(context_chunks)
        logger.debug(f"Extracted {len(unique_document_ids)} unique document IDs")

        question = inputs.get(self.question_key, '')
        map_reduce_documents_chain = self.map_reduce_documents_chain
        if map_reduce_documents_chain is None:
            map_reduce_documents_chain = create_map_reduce_documents_chain(self.summarization_config, question)
        # For each document ID associated with one or more chunks, build the full document by
        # getting ALL chunks associated with that ID. Then, map reduce summarize the complete document.
        dispatch_custom_event('summary_begin', {'num_documents': len(unique_document_ids)})
        try:
            logger.debug("Starting async summary generation")
            # NOTE(review): get_event_loop() in a thread without a loop raises RuntimeError
            # on newer Python versions — handled by the fallback below. The new loop is not
            # closed afterwards; presumably reused by later calls — verify.
            summaries = asyncio.get_event_loop().run_until_complete(self._get_source_summaries(unique_document_ids, map_reduce_documents_chain))
        except RuntimeError:
            logger.info("No event loop available, creating new one")
            # If no event loop is available, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            summaries = loop.run_until_complete(self._get_source_summaries(unique_document_ids, map_reduce_documents_chain))

        # Index summaries by source document ID for the metadata-update pass below.
        source_id_to_summary = {}
        for summary in summaries:
            source_id_to_summary[summary.source_id] = summary.content
        logger.debug(f"Generated {len(source_id_to_summary)} summaries")

        # Update context chunks with document summaries.
        for chunk in context_chunks:
            if not chunk.metadata:
                chunk.metadata = {}
                logger.warning("Chunk metadata was empty, creating new metadata dictionary")

            metadata = chunk.metadata
            doc_id = str(metadata.get(self.doc_id_key, ''))
            logger.debug(f"Updating chunk with doc_id {doc_id}")
            if doc_id in source_id_to_summary:
                chunk.metadata['summary'] = source_id_to_summary[doc_id]
            else:
                logger.warning(f"No summary found for doc_id: {doc_id}")
                chunk.metadata['summary'] = ''

        # Stream summarization update.
        dispatch_custom_event('summary_end', {'num_documents': len(source_id_to_summary)})

        return inputs