Coverage for mindsdb / integrations / utilities / rag / retrievers / multi_hop_retriever.py: 34%

41 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Optional 

2 

3import json 

4from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun 

5from langchain_core.documents import Document 

6from langchain_core.language_models import BaseChatModel 

7from langchain_core.retrievers import BaseRetriever 

8from pydantic import Field, PrivateAttr 

9 

10from mindsdb.integrations.utilities.rag.settings import ( 

11 RAGPipelineModel, 

12 DEFAULT_QUESTION_REFORMULATION_TEMPLATE 

13) 

14from mindsdb.integrations.utilities.rag.retrievers.retriever_factory import create_retriever 

15 

16 

class MultiHopRetriever(BaseRetriever):
    """A retriever that implements a multi-hop question reformulation strategy.

    This retriever takes a base retriever and uses an LLM to generate follow-up
    questions based on the initial results. It then retrieves documents for each
    follow-up question (recursively) and combines all results.

    A question is only expanded with LLM follow-ups while fewer than
    ``max_hops`` distinct questions have been asked during the current
    top-level retrieval. Duplicate questions are skipped.
    """

    base_retriever: BaseRetriever = Field(description="Base retriever to use for document lookup")
    llm: BaseChatModel = Field(description="LLM to use for generating follow-up questions")
    max_hops: int = Field(default=3, description="Maximum number of follow-up questions to generate")
    reformulation_template: str = Field(
        default=DEFAULT_QUESTION_REFORMULATION_TEMPLATE,
        description="Template for reformulating questions"
    )

    # Distinct questions asked during the most recent top-level retrieval.
    # It is replaced with a fresh set on every call, so state does not leak
    # between unrelated queries.
    _asked_questions: set = PrivateAttr(default_factory=set)

    @classmethod
    def from_config(cls, config: RAGPipelineModel) -> "MultiHopRetriever":
        """Create a MultiHopRetriever from a RAGPipelineModel config.

        Args:
            config: Pipeline config; ``config.multi_hop_config`` must be set.

        Returns:
            A retriever wrapping the base retriever built by ``create_retriever``
            for ``config.multi_hop_config.base_retriever_type``.

        Raises:
            ValueError: If ``config.multi_hop_config`` is None.
        """
        if config.multi_hop_config is None:
            raise ValueError("multi_hop_config must be set for MultiHopRetriever")

        # Create base retriever based on type
        base_retriever = create_retriever(config, config.multi_hop_config.base_retriever_type)

        return cls(
            base_retriever=base_retriever,
            llm=config.llm,
            max_hops=config.multi_hop_config.max_hops,
            reformulation_template=config.multi_hop_config.reformulation_template
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        """Get relevant documents using multi-hop retrieval.

        Args:
            query: The user question.
            run_manager: Optional callback manager; child callbacks are passed
                to the base retriever.

        Returns:
            Documents for ``query`` followed by documents for each follow-up
            question, in retrieval order (duplicates are not removed).
        """
        # Start each top-level retrieval with fresh bookkeeping. The set used
        # to persist for the instance's lifetime: a repeated query returned []
        # and the hop budget was shared by unrelated queries.
        asked: set = set()
        self._asked_questions = asked
        return self._retrieve(query, asked, run_manager)

    def _retrieve(
        self,
        query: str,
        asked: set,
        run_manager: Optional[CallbackManagerForRetrieverRun],
    ) -> List[Document]:
        """Retrieve documents for ``query`` and recurse into LLM follow-ups.

        ``asked`` is shared across the recursion for one top-level call. It is
        used to skip repeated questions and to enforce ``max_hops``.
        """
        if query in asked:
            return []
        asked.add(query)

        # Copy so extending below never mutates a list owned by the base retriever.
        docs = list(self._retrieve_from_base(query, run_manager))
        if not docs or len(asked) >= self.max_hops:
            return docs

        for question in self._generate_follow_up_questions(query, docs):
            docs.extend(self._retrieve(question, asked, run_manager))

        return docs

    def _retrieve_from_base(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForRetrieverRun],
    ) -> List[Document]:
        """Query the base retriever through its public ``invoke`` API.

        The private ``_get_relevant_documents`` declares ``run_manager`` as a
        required keyword, so calling it without one can fail. ``invoke`` also
        wires up callbacks correctly.
        """
        config = {"callbacks": run_manager.get_child()} if run_manager is not None else None
        return self.base_retriever.invoke(query, config=config)

    def _generate_follow_up_questions(self, query: str, docs: List[Document]) -> List[str]:
        """Ask the LLM for follow-up questions given ``query`` and retrieved ``docs``.

        Returns:
            The string entries of the JSON list the LLM replied with, or an
            empty list if the reply is not a JSON list.
        """
        context = "\n".join(doc.page_content for doc in docs)
        prompt = self.reformulation_template.format(
            question=query,
            context=context
        )

        response = self.llm.invoke(prompt)
        # Chat models return a message object; the generated text is in
        # `.content`. Passing the message itself to json.loads raised
        # TypeError, which silently disabled every follow-up hop.
        content = getattr(response, "content", response)
        if isinstance(content, str):
            content = content.strip()
            if content.startswith("```"):
                # Tolerate replies wrapped in a Markdown code fence (```json ... ```).
                content = content.strip("`")
                if content.startswith("json"):
                    content = content[len("json"):]

        try:
            follow_up_questions = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return []

        if not isinstance(follow_up_questions, list):
            return []

        return [question for question in follow_up_questions if isinstance(question, str)]