Coverage for mindsdb / integrations / utilities / rag / rag_pipeline_builder.py: 44%
42 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import pandas as pd
2from langchain.storage import InMemoryByteStore
3from langchain_core.runnables import RunnableSerializable
4from mindsdb.integrations.utilities.rag.pipelines.rag import LangChainRAGPipeline
5from mindsdb.integrations.utilities.rag.settings import (
6 RetrieverType,
7 RAGPipelineModel
8)
9from mindsdb.integrations.utilities.rag.utils import documents_to_df
10from mindsdb.integrations.utilities.rag.retrievers.multi_hop_retriever import MultiHopRetriever
11from mindsdb.utilities.log import getLogger
12from langchain_text_splitters import RecursiveCharacterTextSplitter
# Module-level logger for this builder.
logger = getLogger(__name__)

# Dispatch table mapping each RetrieverType to a factory for the matching
# pipeline. The lambdas defer name resolution to call time, since the
# factory functions are defined later in this module.
_retriever_strategies = {
    RetrieverType.VECTOR_STORE: lambda cfg: _create_pipeline_from_vector_store(cfg),
    RetrieverType.AUTO: lambda cfg: _create_pipeline_from_auto_retriever(cfg),
    RetrieverType.MULTI: lambda cfg: _create_pipeline_from_multi_retriever(cfg),
    RetrieverType.SQL: lambda cfg: _create_pipeline_from_sql_retriever(cfg),
    RetrieverType.MULTI_HOP: lambda cfg: _create_pipeline_from_multi_hop_retriever(cfg),
}
def _create_pipeline_from_vector_store(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline backed by a plain vector-store retriever."""
    return LangChainRAGPipeline.from_retriever(config=config)
def _create_pipeline_from_auto_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline that uses the auto (self-querying) retriever."""
    return LangChainRAGPipeline.from_auto_retriever(config=config)
def _create_pipeline_from_multi_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a multi-vector retriever pipeline, filling in missing defaults.

    The config is mutated in place: when absent, a recursive character
    splitter (sized from the config's chunk settings) and an in-memory
    byte store for parent documents are supplied.
    """
    if config.text_splitter is None:
        config.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
        )

    if config.parent_store is None:
        config.parent_store = InMemoryByteStore()

    return LangChainRAGPipeline.from_multi_vector_retriever(config=config)
def _create_pipeline_from_sql_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline backed by a SQL retriever."""
    return LangChainRAGPipeline.from_sql_retriever(config=config)
def _create_pipeline_from_multi_hop_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a pipeline around a multi-hop retriever derived from the config."""
    multi_hop = MultiHopRetriever.from_config(config)

    return LangChainRAGPipeline(
        retriever_runnable=multi_hop,
        prompt_template=config.rag_prompt_template,
        llm=config.llm,
        reranker_config=config.reranker_config,
        reranker=config.reranker,
        vector_store_config=config.vector_store_config,
        summarization_config=config.summarization_config,
    )
def _process_documents_to_df(config: RAGPipelineModel) -> pd.DataFrame:
    """Convert the config's documents into a DataFrame, computing embeddings."""
    return documents_to_df(
        config.content_column_name,
        config.documents,
        embedding_model=config.embedding_model,
        with_embeddings=True,
    )
def get_pipeline_from_retriever(config: RAGPipelineModel) -> RunnableSerializable:
    """Resolve the factory for ``config.retriever_type`` and build its runnable.

    Raises:
        ValueError: if no strategy is registered for the retriever type.
    """
    strategy = _retriever_strategies.get(config.retriever_type)
    if strategy is None:
        raise ValueError(
            f'Invalid retriever type, must be one of: {list(_retriever_strategies.keys())}. Got {config.retriever_type}')
    return strategy(config).with_returned_sources()
class RAG:
    """Callable wrapper around a retriever-specific RAG pipeline.

    The pipeline is built once at construction time from the supplied
    config; each call then answers a single question through it.
    """

    def __init__(self, config: RAGPipelineModel):
        # Resolve the retriever strategy up front so a bad config fails fast.
        self.pipeline = get_pipeline_from_retriever(config)

    def __call__(self, question: str) -> dict:
        """Run the pipeline on *question* and return its raw result dict.

        The result is expected to carry retrieved documents under the
        'context' key; the sources are logged for traceability.
        """
        # Lazy %-style args avoid formatting cost when INFO is disabled.
        logger.info("Processing question using rag pipeline: %s", question)
        result = self.pipeline.invoke(question)

        # Use .get so a result without returned sources does not raise
        # KeyError here — source logging is best-effort only.
        returned_sources = [doc.page_content for doc in result.get('context', [])]
        logger.info("retrieved context used to answer question: %s", returned_sources)

        return result