Coverage for mindsdb / integrations / utilities / rag / rerankers / reranker_compressor.py: 22%
53 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from __future__ import annotations
3import asyncio
4import logging
5from typing import Any, Dict, Optional, Sequence
7from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
8from langchain_core.callbacks import Callbacks, dispatch_custom_event
9from langchain_core.documents import Document
11from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
13log = logging.getLogger(__name__)
class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
    """Document compressor that reranks retrieved documents with an LLM reranker.

    Combines langchain's ``BaseDocumentCompressor`` interface with the
    project's ``BaseLLMReranker`` scoring logic: each document is scored
    against the query via ``self._rank``, results are sorted by score,
    filtered by ``self.filtering_threshold``, and optionally capped at
    ``self.num_docs_to_keep``.
    """

    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents

    def _dispatch_rerank_event(self, data):
        """Forward per-document rerank progress as a langchain custom event."""
        dispatch_custom_event("rerank", data)

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Async compress documents using reranking with proper error handling.

        Args:
            documents: Candidate documents to rerank.
            query: Query string the documents are scored against.
            callbacks: Optional callback handler. NOTE(review): langchain's
                ``Callbacks`` type may be a plain list of handlers, which has
                no ``on_retriever_start``/``on_text`` methods — confirm that
                callers always pass a callback-manager-like object here.

        Returns:
            Documents whose relevance score is >= ``self.filtering_threshold``,
            ordered by score descending, truncated to ``self.num_docs_to_keep``
            when that is set. Each kept document gets a ``relevance_score``
            entry added to its metadata. On any reranking error the original
            ``documents`` are returned unchanged (deliberate best-effort).
        """
        if callbacks:
            await callbacks.on_retriever_start({"query": query}, "Reranking documents")

        log.info(f"Async compressing documents. Initial count: {len(documents)}")
        if not documents:
            if callbacks:
                await callbacks.on_retriever_end({"documents": []})
            return []

        # Stream reranking update.
        dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})

        try:
            # Prepare query-document pairs
            query_document_pairs = [(query, doc.page_content) for doc in documents]

            if callbacks:
                await callbacks.on_text("Starting document reranking...")

            # Get ranked results: sequence of (page_content, score) pairs.
            ranked_results = await self._rank(query_document_pairs, rerank_callback=self._dispatch_rerank_event)

            # Sort by score in descending order
            ranked_results.sort(key=lambda x: x[1], reverse=True)

            # Map page_content back to its Document in one pass (first
            # occurrence wins, matching the previous next(...) semantics)
            # instead of an O(n) linear scan per ranked result.
            content_to_doc: Dict[str, Document] = {}
            for doc in documents:
                content_to_doc.setdefault(doc.page_content, doc)

            # Filter based on threshold and num_docs_to_keep
            filtered_docs = []
            for content, score in ranked_results:
                if score >= self.filtering_threshold:
                    matching_doc = content_to_doc[content]
                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
                    filtered_docs.append(matching_doc)

                    if callbacks:
                        await callbacks.on_text(f"Document scored {score:.2f}")

                    # Scores are sorted descending, so once the cap is hit
                    # no better-scoring document remains.
                    if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
                        break

            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")

            if callbacks:
                await callbacks.on_retriever_end({"documents": filtered_docs})

            return filtered_docs

        except Exception as e:
            error_msg = "Error during async document compression:"
            log.exception(error_msg)
            if callbacks:
                await callbacks.on_retriever_error(f"{error_msg} {e}")
            return documents  # Return original documents on error

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Sync wrapper for async compression.

        NOTE(review): ``asyncio.run`` raises ``RuntimeError`` when invoked
        from within a running event loop; async callers should use
        ``acompress_documents`` directly.
        """
        return asyncio.run(self.acompress_documents(documents, query, callbacks))

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "remove_irrelevant": self.remove_irrelevant,
            "method": self.method,
        }