Coverage for mindsdb / integrations / utilities / rag / retrievers / multi_hop_retriever.py: 34%
41 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List, Optional
3import json
4from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
5from langchain_core.documents import Document
6from langchain_core.language_models import BaseChatModel
7from langchain_core.retrievers import BaseRetriever
8from pydantic import Field, PrivateAttr
10from mindsdb.integrations.utilities.rag.settings import (
11 RAGPipelineModel,
12 DEFAULT_QUESTION_REFORMULATION_TEMPLATE
13)
14from mindsdb.integrations.utilities.rag.retrievers.retriever_factory import create_retriever
class MultiHopRetriever(BaseRetriever):
    """A retriever that implements a multi-hop question reformulation strategy.

    The retriever fetches documents for the incoming query from a base
    retriever, asks an LLM to propose follow-up questions based on those
    documents, retrieves documents for each follow-up question (recursively),
    and returns all collected documents in retrieval order.

    The LLM is expected to answer with a JSON array of question strings; any
    other answer ends the expansion for that branch and the documents found
    so far are returned.
    """

    base_retriever: BaseRetriever = Field(description="Base retriever to use for document lookup")
    llm: BaseChatModel = Field(description="LLM to use for generating follow-up questions")
    # NOTE: in practice this caps the total number of distinct questions
    # (the original query included) for which follow-ups are still expanded.
    max_hops: int = Field(default=3, description="Maximum number of follow-up questions to generate")
    reformulation_template: str = Field(
        default=DEFAULT_QUESTION_REFORMULATION_TEMPLATE,
        description="Template for reformulating questions"
    )

    # Questions already asked during the current top-level retrieval; used to
    # avoid asking the same question twice and to enforce ``max_hops``.
    _asked_questions: set = PrivateAttr(default_factory=set)

    @classmethod
    def from_config(cls, config: RAGPipelineModel) -> "MultiHopRetriever":
        """Create a MultiHopRetriever from a RAGPipelineModel config.

        Args:
            config: Pipeline configuration; ``config.multi_hop_config`` supplies
                the base retriever type, ``max_hops`` and the reformulation
                template, and ``config.llm`` supplies the LLM.

        Returns:
            A configured ``MultiHopRetriever``.

        Raises:
            ValueError: If ``config.multi_hop_config`` is ``None``.
        """
        if config.multi_hop_config is None:
            raise ValueError("multi_hop_config must be set for MultiHopRetriever")

        # Create base retriever based on type
        base_retriever = create_retriever(config, config.multi_hop_config.base_retriever_type)

        return cls(
            base_retriever=base_retriever,
            llm=config.llm,
            max_hops=config.multi_hop_config.max_hops,
            reformulation_template=config.multi_hop_config.reformulation_template
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        """Get relevant documents using multi-hop retrieval.

        Args:
            query: The user question to retrieve documents for.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Documents for ``query`` followed by the documents retrieved for
            every follow-up question, in retrieval order. Empty list when the
            base retriever finds nothing.
        """
        # Start every top-level retrieval with a clean slate. Without this,
        # repeating a query on the same instance would return nothing and the
        # hop budget would be consumed cumulatively across unrelated calls.
        self._asked_questions.clear()
        return self._retrieve_hop(query)

    def _retrieve_hop(self, query: str) -> List[Document]:
        """Retrieve documents for one question and recurse into its follow-ups."""
        if query in self._asked_questions:
            return []

        self._asked_questions.add(query)

        # Get initial documents
        # NOTE(review): this calls the base retriever's private hook without a
        # run_manager, as the original code did — confirm every configured
        # base retriever accepts that call shape.
        initial_docs = self.base_retriever._get_relevant_documents(query)
        # Copy so that extending below never mutates a list owned by the base
        # retriever; a falsy/None result is normalised to an empty list.
        docs: List[Document] = list(initial_docs) if initial_docs else []
        if not docs or len(self._asked_questions) >= self.max_hops:
            return docs

        # Generate follow-up questions
        context = "\n".join(doc.page_content for doc in docs)
        prompt = self.reformulation_template.format(
            question=query,
            context=context
        )

        try:
            response = self.llm.invoke(prompt)
            # Chat models return a message object whose text lives in
            # ``.content``; plain strings are used as-is.
            raw_answer = getattr(response, "content", response)
            follow_up_questions = json.loads(raw_answer)
        except (json.JSONDecodeError, TypeError):
            # Unparseable answer: fall back to what we already have.
            return docs

        if not isinstance(follow_up_questions, list):
            return docs

        # Get documents for follow-up questions
        for question in follow_up_questions:
            if isinstance(question, str):
                docs.extend(self._retrieve_hop(question))

        return docs