Coverage for mindsdb / interfaces / skills / retrieval_tool.py: 13%
115 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from langchain_core.documents import Document
2from langchain_core.tools import Tool
4from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
5from mindsdb.integrations.utilities.rag.config_loader import load_rag_config
6from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel
7from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
8from mindsdb.integrations.libs.response import RESPONSE_TYPE
9from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import (
10 construct_model_from_args,
11)
12from mindsdb.interfaces.agents.constants import get_default_embeddings_model_class
13from mindsdb.interfaces.skills.skill_tool import skill_tool
14from mindsdb.interfaces.storage import db
15from mindsdb.interfaces.storage.db import KnowledgeBase
16from mindsdb.utilities import log
18logger = log.getLogger(__name__)
def _load_rag_config(tool: dict, pred_args: dict, skill: db.Skills) -> RAGPipelineModel:
    """
    Build and validate the RAG pipeline configuration for a retrieval tool.

    The tool's config is overlaid with ``pred_args`` (predictor arguments win on
    key collisions). When the tool names a knowledge base via ``tool["source"]``,
    that knowledge base's table and embedding model are wired into the config.

    Args:
        tool: Tool configuration dictionary. Must contain "config"; may contain
            "source" with the name of a knowledge base.
        pred_args: Predictor arguments dictionary, merged over the tool config
        skill: Skills database object; its project_id is used for the knowledge base lookup

    Returns:
        RAGPipelineModel: Value returned by load_rag_config for the merged settings

    Raises:
        ValueError: If the knowledge base named in tool["source"] is not found
    """
    # Work on a merged copy instead of mutating tool["config"]. This function is
    # called more than once for the same tool (see build_retrieval_tools); popping
    # "is_sparse" / "vector_size" from the shared dict would hide them from every
    # call after the first one.
    tools_config = {**tool["config"], **pred_args}

    kb_params = {}
    embeddings_model = None

    if "source" in tool:
        kb_name = tool["source"]
        executor = skill_tool.get_command_executor()
        kb = _get_knowledge_base(kb_name, skill.project_id, executor)

        if not kb:
            raise ValueError(f"Knowledge base not found: {kb_name}")

        kb_table = executor.session.kb_controller.get_table(kb.name, kb.project_id)
        vector_store_config = {"kb_table": kb_table}
        # These two keys are moved out of the pipeline config and into the vector store config.
        is_sparse = tools_config.pop("is_sparse", None)
        vector_size = tools_config.pop("vector_size", None)
        if is_sparse is not None:
            vector_store_config["is_sparse"] = is_sparse
        if vector_size is not None:
            vector_store_config["vector_size"] = vector_size
        kb_params = {"vector_store_config": vector_store_config}

        # Get embedding model from knowledge base table
        if kb_table._kb.embedding_model:
            # Extract embedding model args from knowledge base table
            embedding_args = kb_table._kb.embedding_model.learn_args.get("using", {})
            # Construct the embedding model directly
            embeddings_model = construct_model_from_args(embedding_args)
            logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
        else:
            embeddings_model_class = get_default_embeddings_model_class()
            embeddings_model = embeddings_model_class()
            logger.debug("Using default embedding model as knowledge base has no embedding model")
    elif "embedding_model" not in tools_config:
        # No knowledge base and no explicit embedding model in the config: use the default.
        embeddings_model_class = get_default_embeddings_model_class()
        embeddings_model = embeddings_model_class()
        logger.debug("Using default embedding model as no knowledge base provided")

    # Load and validate config
    return load_rag_config(tools_config, kb_params, embeddings_model)
def _build_rag_pipeline_tool(tool: dict, pred_args: dict, skill: db.Skills):
    """
    Create the Tool that answers a query by running the RAG pipeline.

    Args:
        tool: Tool configuration dictionary (reads "config", "name" and "description")
        pred_args: Predictor arguments dictionary, merged into tool["config"]
        skill: Skills database object

    Returns:
        Tool: Tool whose function returns the pipeline's "answer", or an error
        message string if the pipeline raises
    """
    pipeline_config = _load_rag_config(tool, pred_args, skill)
    # build retriever
    pipeline = RAG(pipeline_config)
    logger.debug(f"RAG pipeline created with config: {pipeline_config}")

    def _answer_query(query: str) -> str:
        # Failures are reported back as text rather than raised.
        try:
            output = pipeline(query)
            logger.debug(f"RAG pipeline result: {output}")
            return output["answer"]
        except Exception as e:
            logger.exception("Error in RAG pipeline:")
            return f"Error in retrieval: {str(e)}"

    # Create RAG tool
    merged_config = tool["config"]
    merged_config.update(pred_args)
    tool_kwargs = {
        "func": _answer_query,
        "name": tool["name"],
        "description": tool["description"],
        "response_format": "content",
        # Return directly by default since we already use an LLM against retrieved context to generate a response.
        "return_direct": merged_config.get("return_direct", True),
    }
    return Tool(**tool_kwargs)
def _escape_sql_string_literal(value: str) -> str:
    """Double single quotes so the value can be embedded inside a '...' SQL string literal."""
    return value.replace("'", "''")


def _build_name_lookup_tool(tool: dict, pred_args: dict, skill: db.Skills):
    """
    Create the Tool that fetches a whole document by its name or title.

    Args:
        tool: Tool configuration dictionary. Must contain "source" (knowledge base name)
        pred_args: Predictor arguments dictionary
        skill: Skills database object; its project_id is used for the knowledge base lookup

    Returns:
        Tool: Tool named "<tool name>_name_lookup" whose function returns a text
        message containing the document content, or a "could not find" message

    Raises:
        ValueError: If the tool has no "source" or the knowledge base is not found
    """
    if "source" not in tool:
        raise ValueError("Knowledge base for tool not found")
    kb_name = tool["source"]
    executor = skill_tool.get_command_executor()
    kb = _get_knowledge_base(kb_name, skill.project_id, executor)
    if not kb:
        raise ValueError(f"Knowledge base not found: {kb_name}")
    kb_table = executor.session.kb_controller.get_table(kb.name, kb.project_id)
    vector_db_handler = kb_table.get_vector_db()

    rag_config = _load_rag_config(tool, pred_args, skill)
    metadata_config = rag_config.metadata_config

    def _get_document_by_name(name: str):
        # `name` is free text supplied by the agent, so quotes are escaped before it is
        # placed in a SQL string literal.
        # NOTE(review): escaping assumes standard SQL quoting ('' for a quote); tsquery
        # operator characters inside `name` are not sanitised - confirm whether needed.
        if metadata_config.name_column_index is not None:
            # split() without an argument drops empty tokens from repeated whitespace,
            # which would otherwise yield an invalid "a &  & b" tsquery.
            tsquery_str = " & ".join(_escape_sql_string_literal(term) for term in name.split())
            lookup_query = f"SELECT * FROM {metadata_config.table} WHERE {metadata_config.name_column_index} @@ to_tsquery('{tsquery_str}') LIMIT 1;"
        else:
            safe_name = _escape_sql_string_literal(name)
            lookup_query = f"SELECT * FROM {metadata_config.table} WHERE \"{metadata_config.name_column}\" ILIKE '%{safe_name}%' LIMIT 1;"
        documents_response = vector_db_handler.native_query(lookup_query)
        if documents_response.resp_type == RESPONSE_TYPE.ERROR:
            raise RuntimeError(f"There was an error looking up documents: {documents_response.error_message}")
        if documents_response.data_frame.empty:
            return None
        document_row = documents_response.data_frame.head(1)
        # Restore document from chunks, keeping in mind max context.
        id_filter_condition = FilterCondition(
            f"{metadata_config.embeddings_metadata_column}->>'{metadata_config.doc_id_key}'",
            FilterOperator.EQUAL,
            str(document_row.get(metadata_config.id_column).item()),
        )
        document_chunks_df = vector_db_handler.select(
            metadata_config.embeddings_table, conditions=[id_filter_condition]
        )
        if document_chunks_df.empty:
            return None
        sort_col = "chunk_id" if "chunk_id" in document_chunks_df.columns else "id"
        # sort_values returns a new frame; keep the result so chunks are joined in order.
        document_chunks_df = document_chunks_df.sort_values(by=sort_col)
        content = ""
        for _, chunk in document_chunks_df.iterrows():
            # Stop once the limit is exceeded; the chunk that crosses it is kept whole.
            if len(content) > metadata_config.max_document_context:
                break
            content += chunk.get(metadata_config.content_column, "")

        return Document(page_content=content, metadata=document_row.to_dict(orient="records")[0])

    def _lookup_document_by_name(name: str):
        found_document = _get_document_by_name(name)
        if found_document is None:
            return (
                f"I could not find any document with name {name}. Please make sure the document name matches exactly."
            )
        return f"I found document {found_document.metadata.get(metadata_config.id_column)} with name {found_document.metadata.get(metadata_config.name_column)}. Here is the full document to use as context:\n\n{found_document.page_content}"

    return Tool(
        func=_lookup_document_by_name,
        name=tool.get("name", "") + "_name_lookup",
        description="You must use this tool ONLY when the user is asking about a specific document by name or title. The input should be the exact name of the document the user is looking for.",
        return_direct=False,
    )
def build_retrieval_tools(tool: dict, pred_args: dict, skill: db.Skills):
    """
    Builds a list of tools for retrieval i.e RAG

    Args:
        tool: Tool configuration dictionary
        pred_args: Predictor arguments dictionary
        skill: Skills database object

    Returns:
        list: The RAG pipeline tool, followed by the name lookup tool when the
        loaded config has a metadata_config

    Raises:
        ValueError: If knowledge base is not found or configuration is invalid
    """
    # Catch configuration errors before creating tools.
    try:
        rag_config = _load_rag_config(tool, pred_args, skill)
    except Exception as e:
        logger.exception("Error building RAG pipeline:")
        # Chain explicitly so the original failure is kept as the cause.
        raise ValueError(f"Failed to build RAG pipeline: {str(e)}") from e
    tools = [_build_rag_pipeline_tool(tool, pred_args, skill)]
    if rag_config.metadata_config is None:
        return tools
    # The name lookup tool is only built when metadata is configured.
    tools.append(_build_name_lookup_tool(tool, pred_args, skill))
    return tools
188def _get_knowledge_base(knowledge_base_name: str, project_id, executor) -> KnowledgeBase:
189 """
190 Get knowledge base by name and project ID
192 Args:
193 knowledge_base_name: Name of the knowledge base
194 project_id: Project ID
195 executor: Command executor instance
197 Returns:
198 KnowledgeBase: Knowledge base instance if found, None otherwise
199 """
200 kb = executor.session.kb_controller.get(knowledge_base_name, project_id)
201 return kb