Coverage for mindsdb/integrations/utilities/rag/config_loader.py: 91%
33 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1"""Utility functions for RAG pipeline configuration"""
3from typing import Dict, Any, Optional
5from mindsdb.utilities.log import getLogger
6from mindsdb.integrations.utilities.rag.settings import (
7 RetrieverType,
8 MultiVectorRetrieverMode,
9 SearchType,
10 SearchKwargs,
11 SummarizationConfig,
12 VectorStoreConfig,
13 RerankerConfig,
14 RAGPipelineModel,
15 DEFAULT_COLLECTION_NAME,
16)
18logger = getLogger(__name__)
def load_rag_config(
    base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None
) -> RAGPipelineModel:
    """
    Load and validate RAG configuration parameters.

    Merges the base configuration with the optional knowledge base parameters, converts
    plain values into the enum and config types imported from the RAG settings module,
    and builds a RAGPipelineModel from the result.

    Args:
        base_config: Base configuration dictionary containing RAG pipeline settings.
            It is shallow-copied, so the caller's dictionary itself is not modified.
        kb_params: Optional knowledge base parameters; its keys override base_config
        embedding_model: Optional embedding model instance; when not None it replaces
            any "embedding_model" entry of the merged configuration

    Returns:
        RAGPipelineModel: Validated RAG configuration model ready for pipeline creation

    Raises:
        ValueError: If building RAGPipelineModel fails (the original error is chained).
            The enum and nested-config conversions run before that step, so whatever
            they raise propagates unchanged.
    """
    # Shallow copy to avoid modifying the original mapping. deepcopy is avoided because
    # some objects (like embedding_model) may contain unpickleable objects.
    rag_params = base_config.copy()

    # Merge with knowledge base params if provided
    if kb_params:
        rag_params.update(kb_params)

    # Set embedding model if provided
    if embedding_model is not None:
        rag_params["embedding_model"] = embedding_model

    # Convert raw values into their enum types
    enum_fields = (
        ("retriever_type", RetrieverType),
        ("multi_retriever_mode", MultiVectorRetrieverMode),
        ("search_type", SearchType),
    )
    for key, enum_cls in enum_fields:
        if key in rag_params:
            rag_params[key] = enum_cls(rag_params[key])

    # Nested configs are only converted when given as plain dicts; any other value
    # is handed to RAGPipelineModel untouched.
    search_kwargs = rag_params.get("search_kwargs")
    if isinstance(search_kwargs, dict):
        rag_params["search_kwargs"] = SearchKwargs(**search_kwargs)

    summarization_config = rag_params.get("summarization_config")
    if isinstance(summarization_config, dict):
        rag_params["summarization_config"] = SummarizationConfig(**summarization_config)

    # Handle vector store config
    if "vector_store_config" not in rag_params:
        # The empty dict is left as-is here (not converted to VectorStoreConfig).
        rag_params["vector_store_config"] = {}
        logger.warning(
            "No collection_name specified for the retrieval tool, "
            f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
            "\nWarning: If this collection does not exist, no data will be retrieved"
        )
    elif isinstance(rag_params["vector_store_config"], dict):
        rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"])

    # Guarded like the other nested configs: unpacking a non-dict value (None, or an
    # already constructed config object) would raise a TypeError before validation.
    reranker_config = rag_params.get("reranker_config")
    if isinstance(reranker_config, dict):
        rag_params["reranker_config"] = RerankerConfig(**reranker_config)

    # Convert to RAGPipelineModel with validation
    try:
        return RAGPipelineModel(**rag_params)
    except Exception as e:
        logger.exception("Invalid RAG configuration:")
        raise ValueError(f"Configuration validation failed: {str(e)}") from e