Coverage for mindsdb / integrations / utilities / rag / retrievers / multi_vector_retriever.py: 28%
55 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List, Tuple
2import uuid
4from langchain.retrievers.multi_vector import MultiVectorRetriever as LangChainMultiVectorRetriever
5from langchain_core.documents import Document
6from langchain_core.prompts import ChatPromptTemplate
7from langchain_openai import ChatOpenAI
9from mindsdb.integrations.utilities.rag.retrievers.base import BaseRetriever
10from mindsdb.integrations.utilities.rag.settings import DEFAULT_LLM_MODEL, \
11 MultiVectorRetrieverMode, RAGPipelineModel
12from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator
13from mindsdb.interfaces.agents.safe_output_parser import SafeOutputParser
class MultiVectorRetriever(BaseRetriever):
    """
    MultiVectorRetriever stores multiple vectors per document.

    Depending on ``multi_retriever_mode``, the vector store is populated with
    document splits (``SPLIT``), LLM-generated summaries (``SUMMARIZE``), or
    both (``BOTH``).  In every mode the parent byte store maps the generated
    document ids back to the original, unsplit documents so retrieval returns
    full documents.
    """

    def __init__(self, config: RAGPipelineModel):
        self.vectorstore = config.vector_store
        self.parent_store = config.parent_store
        self.id_key = config.id_key
        self.documents = config.documents
        self.text_splitter = config.text_splitter
        self.embedding_model = config.embedding_model
        self.max_concurrency = config.max_concurrency
        self.mode = config.multi_retriever_mode

    def _generate_id_and_split_document(self, doc: Document) -> Tuple[str, List[Document]]:
        """
        Generate a unique id for a document and split it into sub-documents.

        :param doc: source document to split.
        :return: the generated id and the sub-documents; each sub-document is
            tagged with the id under ``self.id_key`` so it can be traced back
            to its parent document.
        """
        doc_id = str(uuid.uuid4())
        sub_docs = self.text_splitter.split_documents([doc])
        for sub_doc in sub_docs:
            sub_doc.metadata[self.id_key] = doc_id
        return doc_id, sub_docs

    def _split_documents(self) -> Tuple[List[Document], List[str]]:
        """
        Split all configured documents and generate a unique id per document.

        :return: the flattened list of sub-documents and the parallel list of
            parent-document ids (one id per original document).
        """
        split_info = list(map(self._generate_id_and_split_document, self.documents))
        # Guard the unpack below: zip(*[]) would raise ValueError on an
        # empty document list.
        if not split_info:
            return [], []
        doc_ids, split_docs_lists = zip(*split_info)
        split_docs = [doc for sublist in split_docs_lists for doc in sublist]
        return split_docs, list(doc_ids)

    def _create_retriever_and_vs_operator(self, docs: List[Document]) \
            -> Tuple[LangChainMultiVectorRetriever, VectorStoreOperator]:
        """
        Index ``docs`` into the vector store and build the LangChain retriever.

        :param docs: documents to embed and add to the vector store.
        :return: the retriever (backed by the populated vector store and the
            parent byte store) and the vector-store operator for further adds.
        """
        vstore_operator = VectorStoreOperator(
            vector_store=self.vectorstore,
            documents=docs,
            embedding_model=self.embedding_model,
        )
        retriever = LangChainMultiVectorRetriever(
            vectorstore=vstore_operator.vector_store,
            byte_store=self.parent_store,
            id_key=self.id_key
        )
        return retriever, vstore_operator

    def _get_document_summaries(self) -> List[str]:
        """
        Summarize every configured document with the default LLM.

        Summaries are produced in parallel, bounded by ``self.max_concurrency``.

        :return: one summary string per document, in input order.
        """
        chain = (
            {"doc": lambda x: x.page_content}  # noqa: E126, E122
            | ChatPromptTemplate.from_template("Summarize the following document:\n\n{doc}")
            | ChatOpenAI(max_retries=0, model_name=DEFAULT_LLM_MODEL)
            | SafeOutputParser()
        )
        return chain.batch(self.documents, {"max_concurrency": self.max_concurrency})

    def _build_summary_documents(self, doc_ids: List[str]) -> List[Document]:
        """
        Summarize the configured documents and wrap each summary in a Document
        tagged with its parent-document id.

        :param doc_ids: parent-document ids, parallel to ``self.documents``.
        :return: summary documents ready to be indexed.
        """
        summaries = self._get_document_summaries()
        return [
            Document(page_content=summary, metadata={self.id_key: doc_id})
            for doc_id, summary in zip(doc_ids, summaries)
        ]

    def as_runnable(self) -> BaseRetriever:
        """
        Build and return the LangChain multi-vector retriever for the
        configured mode.

        :raises ValueError: if ``self.mode`` is not a known retriever mode.
        """
        if self.mode in {MultiVectorRetrieverMode.SPLIT, MultiVectorRetrieverMode.BOTH}:
            split_docs, doc_ids = self._split_documents()
            retriever, vstore_operator = self._create_retriever_and_vs_operator(split_docs)
            if self.mode == MultiVectorRetrieverMode.BOTH:
                # Only BOTH layers LLM-generated summaries on top of the
                # splits; SPLIT must not trigger summarization LLM calls.
                vstore_operator.add_documents(self._build_summary_documents(doc_ids))
            retriever.docstore.mset(list(zip(doc_ids, self.documents)))
            return retriever

        elif self.mode == MultiVectorRetrieverMode.SUMMARIZE:
            doc_ids = [str(uuid.uuid4()) for _ in self.documents]
            summary_docs = self._build_summary_documents(doc_ids)
            retriever, vstore_operator = self._create_retriever_and_vs_operator(summary_docs)
            retriever.docstore.mset(list(zip(doc_ids, self.documents)))
            return retriever

        else:
            raise ValueError(f"Invalid mode: {self.mode}")