Coverage for mindsdb / integrations / utilities / rag / rerankers / reranker_compressor.py: 22%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import logging 

5from typing import Any, Dict, Optional, Sequence 

6 

7from langchain.retrievers.document_compressors.base import BaseDocumentCompressor 

8from langchain_core.callbacks import Callbacks, dispatch_custom_event 

9from langchain_core.documents import Document 

10 

11from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker 

12 

13log = logging.getLogger(__name__) 

14 

15 

class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
    """Rerank retrieved documents with an LLM and drop low-relevance ones.

    Combines langchain's ``BaseDocumentCompressor`` interface with the
    project's ``BaseLLMReranker`` scoring (``self._rank``). Documents are
    scored against the query, sorted by score descending, filtered by
    ``self.filtering_threshold``, and optionally truncated to
    ``self.num_docs_to_keep``.
    """

    # Flag to control removal of irrelevant documents.
    remove_irrelevant: bool = True

    def _dispatch_rerank_event(self, data) -> None:
        """Forward per-document rerank progress as a custom 'rerank' event."""
        dispatch_custom_event("rerank", data)

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Async compress documents using reranking with proper error handling.

        Args:
            documents: Candidate documents to rerank.
            query: The user query the documents are scored against.
            callbacks: Optional callbacks notified of progress.
                NOTE(review): ``Callbacks`` is typically a list of handlers or
                a callback manager; awaiting ``on_retriever_start``/``on_text``
                directly on it assumes a manager-like object — confirm against
                actual callers.

        Returns:
            Documents above the filtering threshold, highest score first, each
            with a ``relevance_score`` added to its metadata (mutated in
            place). On any error the original ``documents`` are returned
            unchanged.
        """
        if callbacks:
            await callbacks.on_retriever_start({"query": query}, "Reranking documents")

        log.info(f"Async compressing documents. Initial count: {len(documents)}")
        if not documents:
            if callbacks:
                await callbacks.on_retriever_end({"documents": []})
            return []

        # Stream reranking update.
        dispatch_custom_event("rerank_begin", {"num_documents": len(documents)})

        try:
            # Prepare query-document pairs
            query_document_pairs = [(query, doc.page_content) for doc in documents]

            # Map page content back to its Document once, instead of an O(n)
            # `next()` scan per ranked result (O(n^2) overall). setdefault
            # keeps the first document for duplicate contents, matching the
            # original first-match semantics.
            content_to_doc: Dict[str, Document] = {}
            for doc in documents:
                content_to_doc.setdefault(doc.page_content, doc)

            if callbacks:
                await callbacks.on_text("Starting document reranking...")

            # Get ranked results: (page_content, score) pairs.
            ranked_results = await self._rank(query_document_pairs, rerank_callback=self._dispatch_rerank_event)

            # Sort by score in descending order
            ranked_results.sort(key=lambda x: x[1], reverse=True)

            # Filter based on threshold and num_docs_to_keep
            filtered_docs = []
            for content, score in ranked_results:
                if score >= self.filtering_threshold:
                    matching_doc = content_to_doc[content]
                    # Record the score on the document (mutates the caller's
                    # Document in place).
                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
                    filtered_docs.append(matching_doc)

                    if callbacks:
                        await callbacks.on_text(f"Document scored {score:.2f}")

                    # Stop early once enough documents have been kept.
                    if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
                        break

            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")

            if callbacks:
                await callbacks.on_retriever_end({"documents": filtered_docs})

            return filtered_docs

        except Exception as e:
            error_msg = "Error during async document compression:"
            log.exception(error_msg)
            if callbacks:
                await callbacks.on_retriever_error(f"{error_msg} {e}")
            return documents  # Return original documents on error

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Sync wrapper for async compression.

        NOTE(review): ``asyncio.run`` raises ``RuntimeError`` when invoked
        from within a running event loop — async callers should use
        :meth:`acompress_documents` directly.
        """
        return asyncio.run(self.acompress_documents(documents, query, callbacks))

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "remove_irrelevant": self.remove_irrelevant,
            "method": self.method,
        }