Coverage for mindsdb / integrations / utilities / rag / rag_pipeline_builder.py: 44%

42 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import pandas as pd 

2from langchain.storage import InMemoryByteStore 

3from langchain_core.runnables import RunnableSerializable 

4from mindsdb.integrations.utilities.rag.pipelines.rag import LangChainRAGPipeline 

5from mindsdb.integrations.utilities.rag.settings import ( 

6 RetrieverType, 

7 RAGPipelineModel 

8) 

9from mindsdb.integrations.utilities.rag.utils import documents_to_df 

10from mindsdb.integrations.utilities.rag.retrievers.multi_hop_retriever import MultiHopRetriever 

11from mindsdb.utilities.log import getLogger 

12from langchain_text_splitters import RecursiveCharacterTextSplitter 

13 

14logger = getLogger(__name__) 

15 

# Dispatch table: retriever type -> pipeline factory.
# Lambdas (rather than bare function references) are required here because
# the factory functions are defined further down in this module.
_retriever_strategies = {
    RetrieverType.VECTOR_STORE: lambda cfg: _create_pipeline_from_vector_store(cfg),
    RetrieverType.AUTO: lambda cfg: _create_pipeline_from_auto_retriever(cfg),
    RetrieverType.MULTI: lambda cfg: _create_pipeline_from_multi_retriever(cfg),
    RetrieverType.SQL: lambda cfg: _create_pipeline_from_sql_retriever(cfg),
    RetrieverType.MULTI_HOP: lambda cfg: _create_pipeline_from_multi_hop_retriever(cfg),
}

23 

24 

def _create_pipeline_from_vector_store(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline backed by a plain vector-store retriever."""
    return LangChainRAGPipeline.from_retriever(config=config)

30 

31 

def _create_pipeline_from_auto_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline using the auto retriever."""
    pipeline = LangChainRAGPipeline.from_auto_retriever(config=config)
    return pipeline

36 

37 

def _create_pipeline_from_multi_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline using a multi-vector retriever.

    NOTE: mutates *config* in place to fill missing defaults — an in-memory
    byte store for the parent document store, and a recursive character
    splitter sized by ``config.chunk_size`` / ``config.chunk_overlap``.
    """
    if config.parent_store is None:
        config.parent_store = InMemoryByteStore()

    if config.text_splitter is None:
        config.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
        )

    return LangChainRAGPipeline.from_multi_vector_retriever(config=config)

50 

51 

def _create_pipeline_from_sql_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline backed by a SQL retriever."""
    return LangChainRAGPipeline.from_sql_retriever(config=config)

56 

57 

def _create_pipeline_from_multi_hop_retriever(config: RAGPipelineModel) -> LangChainRAGPipeline:
    """Build a RAG pipeline around a multi-hop retriever.

    Unlike the other factories, this constructs the retriever explicitly and
    passes it to the pipeline constructor rather than using a classmethod.
    """
    multi_hop = MultiHopRetriever.from_config(config)
    pipeline = LangChainRAGPipeline(
        retriever_runnable=multi_hop,
        prompt_template=config.rag_prompt_template,
        llm=config.llm,
        reranker_config=config.reranker_config,
        reranker=config.reranker,
        vector_store_config=config.vector_store_config,
        summarization_config=config.summarization_config,
    )
    return pipeline

69 

70 

def _process_documents_to_df(config: RAGPipelineModel) -> pd.DataFrame:
    """Convert the configured documents into a DataFrame, including embeddings."""
    return documents_to_df(
        config.content_column_name,
        config.documents,
        embedding_model=config.embedding_model,
        with_embeddings=True,
    )

76 

77 

def get_pipeline_from_retriever(config: RAGPipelineModel) -> RunnableSerializable:
    """Build a runnable RAG pipeline for ``config.retriever_type``.

    Dispatches through the module-level strategy table and wraps the result
    so that retrieved source documents are returned alongside the answer.

    Raises:
        ValueError: if the retriever type has no registered strategy.
    """
    strategy = _retriever_strategies.get(config.retriever_type)
    if not strategy:
        # Keep the full list of supported types in the message for debuggability.
        raise ValueError(
            f'Invalid retriever type, must be one of: {list(_retriever_strategies.keys())}. Got {config.retriever_type}')
    return strategy(config).with_returned_sources()

85 

86 

class RAG:
    """Callable question-answering wrapper around a configured RAG pipeline."""

    def __init__(self, config: RAGPipelineModel):
        # Build the pipeline once up front; each call then only invokes it.
        self.pipeline = get_pipeline_from_retriever(config)

    def __call__(self, question: str) -> dict:
        """Run *question* through the pipeline and return its result dict.

        The result is expected to carry the retrieved documents under the
        'context' key; their contents are logged for traceability.
        """
        logger.info(f"Processing question using rag pipeline: {question}")
        answer = self.pipeline.invoke(question)

        context_texts = [doc.page_content for doc in answer['context']]
        logger.info(f"retrieved context used to answer question: {context_texts}")

        return answer

97 

98 return result