Coverage for mindsdb / interfaces / skills / retrieval_tool.py: 13%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from langchain_core.documents import Document 

2from langchain_core.tools import Tool 

3 

4from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG 

5from mindsdb.integrations.utilities.rag.config_loader import load_rag_config 

6from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel 

7from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator 

8from mindsdb.integrations.libs.response import RESPONSE_TYPE 

9from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import ( 

10 construct_model_from_args, 

11) 

12from mindsdb.interfaces.agents.constants import get_default_embeddings_model_class 

13from mindsdb.interfaces.skills.skill_tool import skill_tool 

14from mindsdb.interfaces.storage import db 

15from mindsdb.interfaces.storage.db import KnowledgeBase 

16from mindsdb.utilities import log 

17 

# Module-level logger, scoped to this module's dotted path.
logger = log.getLogger(__name__)

19 

20 

def _load_rag_config(tool: dict, pred_args: dict, skill: db.Skills) -> RAGPipelineModel:
    """Build and validate a RAG pipeline configuration for a retrieval tool.

    Merges the tool's config with predictor args, resolves the backing
    knowledge base (if ``tool["source"]`` is set) into vector-store params,
    and selects an embeddings model (the KB's own model, or the default).

    Args:
        tool: Tool configuration dictionary (may contain "source" and "config").
        pred_args: Predictor arguments; override keys in tool["config"].
        skill: Skills database object (provides project_id for KB lookup).

    Returns:
        RAGPipelineModel: Validated RAG configuration.

    Raises:
        ValueError: If the referenced knowledge base cannot be found.
    """
    # Merge into a fresh dict instead of mutating tool["config"] in place.
    # This function is called more than once per tool (once for validation in
    # build_retrieval_tools, again inside each builder); the previous in-place
    # update/pop meant later calls saw a config with is_sparse/vector_size
    # already stripped out.
    tools_config = {**tool["config"], **pred_args}

    kb_params = {}
    embeddings_model = None

    if "source" in tool:
        kb_name = tool["source"]
        executor = skill_tool.get_command_executor()
        kb = _get_knowledge_base(kb_name, skill.project_id, executor)

        if not kb:
            raise ValueError(f"Knowledge base not found: {kb_name}")

        kb_table = executor.session.kb_controller.get_table(kb.name, kb.project_id)
        vector_store_config = {"kb_table": kb_table}
        # Sparse-vector settings are consumed here and must not be forwarded
        # to load_rag_config, hence pop (on our local copy only).
        is_sparse = tools_config.pop("is_sparse", None)
        vector_size = tools_config.pop("vector_size", None)
        if is_sparse is not None:
            vector_store_config["is_sparse"] = is_sparse
        if vector_size is not None:
            vector_store_config["vector_size"] = vector_size
        kb_params = {"vector_store_config": vector_store_config}

        # Get embedding model from knowledge base table
        if kb_table._kb.embedding_model:
            # Extract embedding model args from knowledge base table
            embedding_args = kb_table._kb.embedding_model.learn_args.get("using", {})
            # Construct the embedding model directly
            embeddings_model = construct_model_from_args(embedding_args)
            logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
        else:
            embeddings_model_class = get_default_embeddings_model_class()
            embeddings_model = embeddings_model_class()
            logger.debug("Using default embedding model as knowledge base has no embedding model")
    elif "embedding_model" not in tools_config:
        embeddings_model_class = get_default_embeddings_model_class()
        embeddings_model = embeddings_model_class()
        logger.debug("Using default embedding model as no knowledge base provided")

    # Load and validate config
    return load_rag_config(tools_config, kb_params, embeddings_model)

64 

65 

def _build_rag_pipeline_tool(tool: dict, pred_args: dict, skill: db.Skills):
    """Create a langchain Tool that answers queries via the RAG pipeline.

    Args:
        tool: Tool configuration dictionary ("name", "description", "config").
        pred_args: Predictor arguments; override keys in tool["config"].
        skill: Skills database object.

    Returns:
        Tool: A tool whose func runs the query through the RAG pipeline and
        returns the generated answer (or an error string on failure).
    """
    rag_config = _load_rag_config(tool, pred_args, skill)
    # build retriever
    rag_pipeline = RAG(rag_config)
    logger.debug(f"RAG pipeline created with config: {rag_config}")

    def rag_wrapper(query: str) -> str:
        # Best-effort wrapper: surface pipeline errors to the agent as text
        # rather than crashing the whole tool-calling loop.
        try:
            result = rag_pipeline(query)
            logger.debug(f"RAG pipeline result: {result}")
            return result["answer"]
        except Exception as e:
            logger.exception("Error in RAG pipeline:")
            return f"Error in retrieval: {str(e)}"

    # Create RAG tool. Merge without mutating the caller-owned tool["config"]
    # dict (the previous in-place update was a hidden side effect on `tool`).
    tools_config = {**tool["config"], **pred_args}
    return Tool(
        func=rag_wrapper,
        name=tool["name"],
        description=tool["description"],
        response_format="content",
        # Return directly by default since we already use an LLM against retrieved context to generate a response.
        return_direct=tools_config.get("return_direct", True),
    )

92 

93 

def _build_name_lookup_tool(tool: dict, pred_args: dict, skill: db.Skills):
    """Create a Tool that fetches one full document by (approximate) name.

    Looks the document up in the knowledge base's metadata table (full-text
    search if a name index is configured, otherwise ILIKE), then reassembles
    its content from the embeddings/chunks table up to the configured max
    context size.

    Args:
        tool: Tool configuration dictionary (must contain "source").
        pred_args: Predictor arguments; override keys in tool["config"].
        skill: Skills database object.

    Returns:
        Tool: A name-lookup tool (non-direct return; feeds the agent context).

    Raises:
        ValueError: If no knowledge base is configured or it cannot be found.
    """
    if "source" not in tool:
        raise ValueError("Knowledge base for tool not found")
    kb_name = tool["source"]
    executor = skill_tool.get_command_executor()
    kb = _get_knowledge_base(kb_name, skill.project_id, executor)
    if not kb:
        raise ValueError(f"Knowledge base not found: {kb_name}")
    kb_table = executor.session.kb_controller.get_table(kb.name, kb.project_id)
    vector_db_handler = kb_table.get_vector_db()

    rag_config = _load_rag_config(tool, pred_args, skill)
    metadata_config = rag_config.metadata_config

    def _get_document_by_name(name: str):
        # `name` comes from the LLM/user, so escape single quotes before
        # interpolating into SQL to avoid breaking (or injecting into) the
        # query. NOTE(review): string-built SQL remains; a parameterized
        # query would be safer if the handler API supports it.
        safe_name = name.replace("'", "''")
        if metadata_config.name_column_index is not None:
            # Full-text search: AND together the words of the name.
            tsquery_str = " & ".join(safe_name.split(" "))
            documents_response = vector_db_handler.native_query(
                f"SELECT * FROM {metadata_config.table} WHERE {metadata_config.name_column_index} @@ to_tsquery('{tsquery_str}') LIMIT 1;"
            )
        else:
            documents_response = vector_db_handler.native_query(
                f"SELECT * FROM {metadata_config.table} WHERE \"{metadata_config.name_column}\" ILIKE '%{safe_name}%' LIMIT 1;"
            )
        if documents_response.resp_type == RESPONSE_TYPE.ERROR:
            raise RuntimeError(f"There was an error looking up documents: {documents_response.error_message}")
        if documents_response.data_frame.empty:
            return None
        document_row = documents_response.data_frame.head(1)
        # Restore document from chunks, keeping in mind max context.
        id_filter_condition = FilterCondition(
            f"{metadata_config.embeddings_metadata_column}->>'{metadata_config.doc_id_key}'",
            FilterOperator.EQUAL,
            str(document_row.get(metadata_config.id_column).item()),
        )
        document_chunks_df = vector_db_handler.select(
            metadata_config.embeddings_table, conditions=[id_filter_condition]
        )
        if document_chunks_df.empty:
            return None
        sort_col = "chunk_id" if "chunk_id" in document_chunks_df.columns else "id"
        # sort_values is NOT in-place; the previous code discarded the sorted
        # frame, so chunks were concatenated in arbitrary order.
        document_chunks_df = document_chunks_df.sort_values(by=sort_col)
        content = ""
        for _, chunk in document_chunks_df.iterrows():
            if len(content) > metadata_config.max_document_context:
                break
            content += chunk.get(metadata_config.content_column, "")

        return Document(page_content=content, metadata=document_row.to_dict(orient="records")[0])

    def _lookup_document_by_name(name: str):
        found_document = _get_document_by_name(name)
        if found_document is None:
            return (
                f"I could not find any document with name {name}. Please make sure the document name matches exactly."
            )
        return f"I found document {found_document.metadata.get(metadata_config.id_column)} with name {found_document.metadata.get(metadata_config.name_column)}. Here is the full document to use as context:\n\n{found_document.page_content}"

    return Tool(
        func=_lookup_document_by_name,
        name=tool.get("name", "") + "_name_lookup",
        description="You must use this tool ONLY when the user is asking about a specific document by name or title. The input should be the exact name of the document the user is looking for.",
        return_direct=False,
    )

158 

159 

def build_retrieval_tools(tool: dict, pred_args: dict, skill: db.Skills):
    """
    Builds a list of tools for retrieval i.e RAG

    Args:
        tool: Tool configuration dictionary
        pred_args: Predictor arguments dictionary
        skill: Skills database object

    Returns:
        list[Tool]: Configured list of retrieval tools (RAG pipeline tool,
        plus a name-lookup tool when metadata config is present)

    Raises:
        ValueError: If knowledge base is not found or configuration is invalid
    """
    # Catch configuration errors before creating tools.
    try:
        rag_config = _load_rag_config(tool, pred_args, skill)
    except Exception as e:
        logger.exception("Error building RAG pipeline:")
        # Chain the original exception so the root cause is preserved in
        # tracebacks instead of being flattened into a bare ValueError.
        raise ValueError(f"Failed to build RAG pipeline: {str(e)}") from e
    tools = [_build_rag_pipeline_tool(tool, pred_args, skill)]
    # The name-lookup tool only makes sense when a metadata table is configured.
    if rag_config.metadata_config is None:
        return tools
    tools.append(_build_name_lookup_tool(tool, pred_args, skill))
    return tools

186 

187 

def _get_knowledge_base(knowledge_base_name: str, project_id, executor) -> KnowledgeBase:
    """
    Get knowledge base by name and project ID

    Args:
        knowledge_base_name: Name of the knowledge base
        project_id: Project ID
        executor: Command executor instance

    Returns:
        KnowledgeBase: Knowledge base instance if found, None otherwise
    """
    # Thin delegation to the session's knowledge-base controller.
    return executor.session.kb_controller.get(knowledge_base_name, project_id)