Coverage for mindsdb / integrations / utilities / rag / config_loader.py: 91%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1"""Utility functions for RAG pipeline configuration""" 

2 

3from typing import Dict, Any, Optional 

4 

5from mindsdb.utilities.log import getLogger 

6from mindsdb.integrations.utilities.rag.settings import ( 

7 RetrieverType, 

8 MultiVectorRetrieverMode, 

9 SearchType, 

10 SearchKwargs, 

11 SummarizationConfig, 

12 VectorStoreConfig, 

13 RerankerConfig, 

14 RAGPipelineModel, 

15 DEFAULT_COLLECTION_NAME, 

16) 

17 

18logger = getLogger(__name__) 

19 

20 

21def load_rag_config( 

22 base_config: Dict[str, Any], kb_params: Optional[Dict[str, Any]] = None, embedding_model: Any = None 

23) -> RAGPipelineModel: 

24 """ 

25 Load and validate RAG configuration parameters. This function handles the conversion of configuration 

26 parameters into their appropriate types and ensures all required settings are properly configured. 

27 

28 Args: 

29 base_config: Base configuration dictionary containing RAG pipeline settings 

30 kb_params: Optional knowledge base parameters to merge with base config 

31 embedding_model: Optional embedding model instance to use in the RAG pipeline 

32 

33 Returns: 

34 RAGPipelineModel: Validated RAG configuration model ready for pipeline creation 

35 

36 Raises: 

37 ValueError: If configuration validation fails or required parameters are missing 

38 """ 

39 # Create a shallow copy of the base config to avoid modifying the original 

40 # We avoid deepcopy because some objects (like embedding_model) may contain unpickleable objects 

41 rag_params = base_config.copy() 

42 

43 # Merge with knowledge base params if provided 

44 if kb_params: 

45 rag_params.update(kb_params) 

46 

47 # Set embedding model if provided 

48 if embedding_model is not None: 

49 rag_params["embedding_model"] = embedding_model 

50 

51 # Handle enums and type conversions 

52 if "retriever_type" in rag_params: 

53 rag_params["retriever_type"] = RetrieverType(rag_params["retriever_type"]) 

54 if "multi_retriever_mode" in rag_params: 

55 rag_params["multi_retriever_mode"] = MultiVectorRetrieverMode(rag_params["multi_retriever_mode"]) 

56 if "search_type" in rag_params: 

57 rag_params["search_type"] = SearchType(rag_params["search_type"]) 

58 

59 # Handle search kwargs if present 

60 if "search_kwargs" in rag_params and isinstance(rag_params["search_kwargs"], dict): 

61 rag_params["search_kwargs"] = SearchKwargs(**rag_params["search_kwargs"]) 

62 

63 # Handle summarization config if present 

64 summarization_config = rag_params.get("summarization_config") 

65 if summarization_config is not None and isinstance(summarization_config, dict): 65 ↛ 66line 65 didn't jump to line 66 because the condition on line 65 was never true

66 rag_params["summarization_config"] = SummarizationConfig(**summarization_config) 

67 

68 # Handle vector store config 

69 if "vector_store_config" in rag_params: 

70 if isinstance(rag_params["vector_store_config"], dict): 70 ↛ 80line 70 didn't jump to line 80 because the condition on line 70 was always true

71 rag_params["vector_store_config"] = VectorStoreConfig(**rag_params["vector_store_config"]) 

72 else: 

73 rag_params["vector_store_config"] = {} 

74 logger.warning( 

75 f"No collection_name specified for the retrieval tool, " 

76 f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'" 

77 f"\nWarning: If this collection does not exist, no data will be retrieved" 

78 ) 

79 

80 if "reranker_config" in rag_params: 80 ↛ 81line 80 didn't jump to line 81 because the condition on line 80 was never true

81 rag_params["reranker_config"] = RerankerConfig(**rag_params["reranker_config"]) 

82 

83 # Convert to RAGPipelineModel with validation 

84 try: 

85 return RAGPipelineModel(**rag_params) 

86 except Exception as e: 

87 logger.exception("Invalid RAG configuration:") 

88 raise ValueError(f"Configuration validation failed: {str(e)}") from e