Coverage for mindsdb / integrations / utilities / rag / retrievers / multi_vector_retriever.py: 28%

55 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Tuple 

2import uuid 

3 

4from langchain.retrievers.multi_vector import MultiVectorRetriever as LangChainMultiVectorRetriever 

5from langchain_core.documents import Document 

6from langchain_core.prompts import ChatPromptTemplate 

7from langchain_openai import ChatOpenAI 

8 

9from mindsdb.integrations.utilities.rag.retrievers.base import BaseRetriever 

10from mindsdb.integrations.utilities.rag.settings import DEFAULT_LLM_MODEL, \ 

11 MultiVectorRetrieverMode, RAGPipelineModel 

12from mindsdb.integrations.utilities.rag.vector_store import VectorStoreOperator 

13from mindsdb.interfaces.agents.safe_output_parser import SafeOutputParser 

14 

15 

class MultiVectorRetriever(BaseRetriever):
    """
    MultiVectorRetriever stores multiple vectors per document.

    Depending on ``config.multi_retriever_mode`` each parent document is indexed
    in the vector store by its split chunks (SPLIT), by an LLM-generated summary
    (SUMMARIZE), or by both (BOTH). Every indexed vector carries the parent
    document's generated id in its metadata under ``id_key``; the parent
    documents themselves are stored in the retriever's docstore under the same
    ids, so vector hits can be resolved back to the full parent document.
    """

    def __init__(self, config: RAGPipelineModel):
        self.vectorstore = config.vector_store
        self.parent_store = config.parent_store
        self.id_key = config.id_key
        self.documents = config.documents
        self.text_splitter = config.text_splitter
        self.embedding_model = config.embedding_model
        self.max_concurrency = config.max_concurrency
        self.mode = config.multi_retriever_mode

    def _generate_id_and_split_document(self, doc: Document) -> Tuple[str, List[Document]]:
        """
        Generate a unique id for the document and split it into sub-documents.

        :param doc: the parent document to split with ``self.text_splitter``.
        :return: a ``(doc_id, sub_docs)`` tuple; every sub-document has ``doc_id``
            written into its metadata under ``self.id_key``.
        """
        doc_id = str(uuid.uuid4())
        sub_docs = self.text_splitter.split_documents([doc])
        for sub_doc in sub_docs:
            sub_doc.metadata[self.id_key] = doc_id
        return doc_id, sub_docs

    def _split_documents(self) -> Tuple[List[Document], List[str]]:
        """
        Split the documents into sub-documents and generate unique ids for each document.

        :return: ``(split_docs, doc_ids)`` where ``split_docs`` is the flat list of
            all sub-documents and ``doc_ids`` holds one id per entry of
            ``self.documents``, in the same order. Both lists are empty when there
            are no documents.
        """
        split_docs: List[Document] = []
        doc_ids: List[str] = []
        # An explicit loop (rather than zip(*...) unpacking) keeps the empty
        # document list from raising "not enough values to unpack".
        for doc in self.documents:
            doc_id, sub_docs = self._generate_id_and_split_document(doc)
            doc_ids.append(doc_id)
            split_docs.extend(sub_docs)
        return split_docs, doc_ids

    def _create_retriever_and_vs_operator(self, docs: List[Document]) \
            -> Tuple[LangChainMultiVectorRetriever, VectorStoreOperator]:
        """
        Load ``docs`` into the vector store and wrap it in a LangChain MultiVectorRetriever.

        :param docs: the documents (chunks and/or summaries) to index.
        :return: ``(retriever, vstore_operator)``; the operator is returned so
            callers can add further documents to the same vector store.
        """
        vstore_operator = VectorStoreOperator(
            vector_store=self.vectorstore,
            documents=docs,
            embedding_model=self.embedding_model,
        )
        retriever = LangChainMultiVectorRetriever(
            vectorstore=vstore_operator.vector_store,
            byte_store=self.parent_store,
            id_key=self.id_key
        )
        return retriever, vstore_operator

    def _get_document_summaries(self) -> List[str]:
        """
        Summarize every document in ``self.documents`` with an OpenAI chat model.

        :return: one summary per document, in the order of ``self.documents``
            (``Runnable.batch`` returns outputs in input order).
        """
        chain = (
            {"doc": lambda x: x.page_content}  # noqa: E126, E122
            | ChatPromptTemplate.from_template("Summarize the following document:\n\n{doc}")
            | ChatOpenAI(max_retries=0, model_name=DEFAULT_LLM_MODEL)
            | SafeOutputParser()
        )
        return chain.batch(self.documents, {"max_concurrency": self.max_concurrency})

    def _build_summary_docs(self, doc_ids: List[str]) -> List[Document]:
        """
        Summarize the parent documents and wrap each summary in a Document tagged with its parent's id.

        :param doc_ids: one id per entry of ``self.documents``, in the same order.
        :return: summary documents whose metadata maps ``self.id_key`` to the
            corresponding parent id.
        """
        summaries = self._get_document_summaries()
        return [
            Document(page_content=summary, metadata={self.id_key: doc_id})
            for doc_id, summary in zip(doc_ids, summaries)
        ]

    def as_runnable(self) -> BaseRetriever:
        """
        Build and populate the retriever for the configured mode.

        :return: a LangChain ``MultiVectorRetriever`` whose vector store holds the
            split chunks and/or document summaries and whose docstore maps each
            generated id to its full parent document.
        :raises ValueError: if ``self.mode`` is not SPLIT, SUMMARIZE or BOTH.
        """
        if self.mode in {MultiVectorRetrieverMode.SPLIT, MultiVectorRetrieverMode.BOTH}:
            split_docs, doc_ids = self._split_documents()
            retriever, vstore_operator = self._create_retriever_and_vs_operator(split_docs)
            # SPLIT indexes the chunks only; BOTH additionally indexes an LLM
            # summary of each parent document under the same parent id.
            if self.mode == MultiVectorRetrieverMode.BOTH:
                vstore_operator.add_documents(self._build_summary_docs(doc_ids))

        elif self.mode == MultiVectorRetrieverMode.SUMMARIZE:
            doc_ids = [str(uuid.uuid4()) for _ in self.documents]
            retriever, _ = self._create_retriever_and_vs_operator(self._build_summary_docs(doc_ids))

        else:
            raise ValueError(f"Invalid mode: {self.mode}")

        # Store the full parent documents under their ids so that vector hits
        # (chunks or summaries) resolve back to the original document.
        retriever.docstore.mset(list(zip(doc_ids, self.documents)))
        return retriever