Coverage for mindsdb / interfaces / skills / skill_tool.py: 11%
247 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import enum
2import inspect
4from dataclasses import dataclass
5from collections import defaultdict
6from typing import List, Dict, Optional
8from langchain_core.embeddings import Embeddings
9from langchain_core.language_models import BaseChatModel
10from mindsdb_sql_parser.ast import Select, BinaryOperation, Identifier, Constant, Star
12from mindsdb.utilities import log
13from mindsdb.utilities.cache import get_cache
14from mindsdb.utilities.config import config
15from mindsdb.interfaces.storage import db
16from mindsdb.interfaces.skills.sql_agent import SQLAgent
17from mindsdb.integrations.libs.vectordatabase_handler import TableField
18from mindsdb.interfaces.agents.constants import DEFAULT_TEXT2SQL_DATABASE
# Row LIMIT used for the knowledge-base SELECT in SkillToolController._get_rag_query_function.
_DEFAULT_TOP_K_SIMILARITY_SEARCH = 5
# Passed as ``max_size`` to get_cache("agent", ...) when the SQLAgent is created.
_MAX_CACHE_SIZE = 1_000

logger = log.getLogger(__name__)
class SkillType(enum.Enum):
    """Skill kinds that SkillToolController can turn into agent tools.

    The enum is built from the skill's ``type`` string, so the member *values*
    are the spellings accepted from callers.
    """

    # Older spelling of the SQL skill; get_tools_from_skills treats it as TEXT2SQL.
    TEXT2SQL_LEGACY = "text2sql"
    # Text-to-SQL skill.
    TEXT2SQL = "sql"
    # Deprecated knowledge-base skill (a warning is logged when it is used).
    KNOWLEDGE_BASE = "knowledge_base"
    # Retrieval (RAG) skill.
    RETRIEVAL = "retrieval"
@dataclass
class SkillData:
    """Storage for skill's data

    Attributes:
        name (str): name of the skill
        type (str): skill's type (SkillType)
        params (dict): skill's attributes; ``params["tables"]`` optionally restricts the skill's tables
        project_id (int): id of the project
        agent_tables_list (Optional[List[str]]): the restriction on available tables for an agent using the skill.
            Entries are table names, or dicts with a ``"table"`` key and an optional ``"schema"`` key.
    """

    name: str
    type: str
    params: dict
    project_id: int
    agent_tables_list: Optional[List[str]]

    @property
    def restriction_on_tables(self) -> Optional[Dict[str, set]]:
        """Schemas and tables which agent+skill may use.

        When both the agent and the skill restrict tables, the result is the
        intersection of the two lists, keeping only schemas that actually share
        at least one table. When only one side restricts, that side's map is
        returned. Plain string entries are grouped under the schema key ``None``.

        Returns:
            Optional[Dict[str, set]]: allowed schemas and tables. Schemas are the dict keys,
            sets of table names are the values. ``None`` means there are no restrictions.

        Raises:
            ValueError: if there is no intersection between the skill's and the agent's list
                (no shared schema, or no shared table in any shared schema), i.e. all
                tables are restricted. Also raised for entries that are neither str nor dict.
        """

        def list_to_map(tables: Optional[List]) -> Dict[Optional[str], set]:
            # Group table names by schema; a missing/None list means "no entries".
            tables_map = defaultdict(set)
            for item in tables or []:
                if isinstance(item, str):
                    table_name, schema_name = item, None
                elif isinstance(item, dict):
                    table_name, schema_name = item["table"], item.get("schema")
                else:
                    raise ValueError(f"Unexpected value in tables list: {item}")
                tables_map[schema_name].add(table_name)
            return tables_map

        agent_tables_map = list_to_map(self.agent_tables_list)
        # ``params["tables"]`` may be explicitly None; treat it like an absent key.
        skill_tables_map = list_to_map(self.params.get("tables"))

        if agent_tables_map and skill_tables_map:
            if not (agent_tables_map.keys() & skill_tables_map.keys()):
                raise ValueError("Skill's and agent's allowed tables list have no shared schemas.")

            intersection_tables_map = defaultdict(set)
            for schema_name, agent_tables in agent_tables_map.items():
                if schema_name not in skill_tables_map:
                    continue
                shared_tables = agent_tables & skill_tables_map[schema_name]
                # Keep only schemas that really have allowed tables; an empty set would
                # otherwise look like an allowed schema with an ambiguous meaning.
                if shared_tables:
                    intersection_tables_map[schema_name] = shared_tables
            if not intersection_tables_map:
                raise ValueError("Skill's and agent's allowed tables list have no shared tables.")
            return intersection_tables_map
        if skill_tables_map:
            return skill_tables_map
        if agent_tables_map:
            return agent_tables_map
        return None
def _escape_table_name(name: str) -> str:
    """Wrap one table name in backticks, first stripping surrounding spaces/backticks."""
    name = name.strip(" `")
    return f"`{name}`"


def _split_names(value) -> list:
    """Normalize a skill parameter holding names into a list.

    A string is split on commas, each part is stripped and empty parts (for
    example from a trailing comma) are dropped. A falsy value gives ``[]``;
    any other value is returned as ``list(value)``.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [part.strip() for part in value.split(",") if part.strip()]
    return list(value)


def _format_dotted_table_name(table: str) -> str:
    """Escape only the table part of an already-qualified name.

    ``db.table`` -> ``db.`table```, ``db.schema.table`` -> ``db.schema.`table```.
    Names that already contain a backtick are returned unchanged; names with an
    unexpected number of parts are escaped as a whole.
    """
    if "`" in table:
        return table
    parts = table.split(".")
    if len(parts) in (2, 3):
        return ".".join([*parts[:-1], _escape_table_name(parts[-1])])
    return _escape_table_name(table)


def _qualify_table_name(table: str, database: str) -> str:
    """Return ``table`` qualified with ``database`` unless it already carries a prefix."""
    if "." in table:
        return _format_dotted_table_name(table)
    return f"{database}.{_escape_table_name(table)}"


def _qualify_knowledge_base_name(knowledge_base: str, database: str) -> str:
    """Prefix a knowledge base name with ``database`` unless it is already dotted."""
    return knowledge_base if "." in knowledge_base else f"{database}.{knowledge_base}"


class SkillToolController:
    """Turns skills into agent tools, grouped by SkillType."""

    def __init__(self):
        self.command_executor = None

    def get_command_executor(self):
        """Return the controller's ExecuteCommands instance, creating it on first use."""
        if self.command_executor is None:
            from mindsdb.api.executor.command_executor import ExecuteCommands
            from mindsdb.api.executor.controllers import (
                SessionController,
            )  # Top-level import produces circular import in some cases TODO: figure out a fix without losing runtime improvements (context: see #9304) # noqa

            sql_session = SessionController()
            sql_session.database = config.get("default_project")

            self.command_executor = ExecuteCommands(sql_session)
        return self.command_executor

    @staticmethod
    def _collect_database_hints(skills) -> tuple:
        """First pass over the text2sql skills.

        Returns:
            tuple: ``(knowledge_base_database, extracted_databases)``.
            ``knowledge_base_database`` starts as DEFAULT_TEXT2SQL_DATABASE, is overridden
            by a skill's ``knowledge_base_database`` param and then by the database part
            of any dotted ``include_knowledge_bases`` entry (the last one seen wins).
            ``extracted_databases`` lists database prefixes of dotted ``include_tables``
            entries in first-seen order.
        """
        knowledge_base_database = DEFAULT_TEXT2SQL_DATABASE
        extracted_databases = {}  # dict used as an insertion-ordered set
        for skill in skills:
            if skill.params.get("knowledge_base_database"):
                knowledge_base_database = skill.params["knowledge_base_database"]

            for table in _split_names(skill.params.get("include_tables")):
                if "." in table:
                    extracted_databases.setdefault(table.split(".")[0], None)

            for knowledge_base in _split_names(skill.params.get("include_knowledge_bases")):
                if "." in knowledge_base:
                    knowledge_base_database = knowledge_base.split(".")[0]
        return knowledge_base_database, list(extracted_databases)

    @staticmethod
    def _list_database_tables(command_executor, database: str) -> List[str]:
        """Return every table of ``database`` as a qualified name with an escaped table part.

        Calls the handler's ``get_tables`` (with ``all=True`` when the handler accepts it).
        The ``table_name``/``table_schema`` columns are located case-insensitively; without
        a ``table_name`` column the first column is taken as the table name.
        """
        handler = command_executor.session.integration_controller.get_data_handler(database)
        if "all" in inspect.signature(handler.get_tables).parameters:
            response = handler.get_tables(all=True)
        else:
            response = handler.get_tables()

        data_frame = response.data_frame
        columns = [str(column).lower() for column in data_frame.columns]
        name_idx = columns.index("table_name") if "table_name" in columns else 0

        if "table_schema" in columns:
            schema_idx = columns.index("table_schema")
            # Plain tuples give true positional access; integer keys on a label-indexed
            # pandas Series (the previous ``row[name_idx]``) rely on a deprecated fallback.
            return [
                f"{database}.{row[schema_idx]}.{_escape_table_name(row[name_idx])}"
                for row in data_frame.itertuples(index=False, name=None)
            ]
        return [f"{database}.{_escape_table_name(name)}" for name in data_frame.iloc[:, name_idx]]

    @staticmethod
    def _restricted_table_names(database: str, restriction_on_tables: Dict) -> List[str]:
        """Qualify the tables of a ``SkillData.restriction_on_tables`` map with ``database``/schema."""
        names = []
        for schema_name, tables in restriction_on_tables.items():
            for table in tables:
                if "." in table:
                    # Already prefixed (e.g. 'postgresql_conn.home_rentals'): escape only the
                    # table part, the same way include_tables entries are handled.
                    names.append(_format_dotted_table_name(table))
                elif schema_name is None:
                    names.append(f"{database}.{_escape_table_name(table)}")
                else:
                    names.append(f"{database}.{schema_name}.{_escape_table_name(table)}")
        return names

    def _make_text_to_sql_tools(self, skills: List[db.Skills], llm) -> List:
        """Build one SQL toolkit (via SQLAgent) covering all given text2sql skills.

        Per skill, tables come from ``include_tables`` when present; otherwise from
        ``SkillData.restriction_on_tables``, or every table of the skill's database
        when there is no restriction. Knowledge bases come from
        ``include_knowledge_bases`` / ``ignore_knowledge_bases``; include wins over ignore.

        Raises:
            ImportError: if the langchain extras are not installed.
            ValueError: propagated from ``restriction_on_tables`` when skill and agent share no tables.
        """
        # To prevent dependency on Langchain unless an actual tool uses it.
        try:
            from mindsdb.interfaces.agents.mindsdb_database_agent import MindsDBSQL
            from mindsdb.interfaces.skills.custom.text2sql.mindsdb_sql_toolkit import MindsDBSQLToolkit
        except ImportError as exc:
            raise ImportError(
                "To use the text-to-SQL skill, please install langchain with `pip install mindsdb[langchain]`"
            ) from exc

        command_executor = self.get_command_executor()
        knowledge_base_database, extracted_databases = self._collect_database_hints(skills)

        tables_list = []
        knowledge_bases_list = []
        ignore_knowledge_bases_list = []

        # Second pass: collect tables and knowledge base restrictions.
        for skill in skills:
            # Database for tables (an actual database connection).
            database = skill.params.get("database", DEFAULT_TEXT2SQL_DATABASE)
            # Only happens when "database" is present but falsy: fall back to the first
            # database seen in dotted include_tables, and store it on the skill so that
            # databases_struct below picks it up (as before).
            if not database and extracted_databases:
                database = extracted_databases[0]
                skill.params["database"] = database

            knowledge_bases_list.extend(
                _qualify_knowledge_base_name(kb, knowledge_base_database)
                for kb in _split_names(skill.params.get("include_knowledge_bases"))
            )
            ignore_knowledge_bases_list.extend(
                _qualify_knowledge_base_name(kb, knowledge_base_database)
                for kb in _split_names(skill.params.get("ignore_knowledge_bases"))
            )

            if not database:
                continue

            include_tables = skill.params.get("include_tables")
            if include_tables:
                # An explicit include list replaces table discovery/restrictions for this skill.
                tables_list.extend(_qualify_table_name(table, database) for table in _split_names(include_tables))
                continue

            restriction_on_tables = skill.restriction_on_tables
            if restriction_on_tables is None:
                # No restrictions: expose every table of the database (best effort).
                try:
                    tables_list.extend(self._list_database_tables(command_executor, database))
                except Exception:
                    logger.warning(f"Could not get tables from database {database}:", exc_info=True)
                continue

            tables_list.extend(self._restricted_table_names(database, restriction_on_tables))

        # Remove duplicates while keeping first-seen order, so the result is deterministic.
        tables_list = list(dict.fromkeys(tables_list))
        knowledge_bases_list = list(dict.fromkeys(knowledge_bases_list))
        ignore_knowledge_bases_list = list(dict.fromkeys(ignore_knowledge_bases_list))

        include_knowledge_bases = knowledge_bases_list or None
        # If both include and ignore lists exist, include takes precedence.
        ignore_knowledge_bases = None if include_knowledge_bases else (ignore_knowledge_bases_list or None)

        # Skills with an explicit database carry their table restrictions;
        # databases only seen in dotted names are added without restrictions.
        databases_struct = {}
        for skill in skills:
            if skill.params.get("database"):
                databases_struct[skill.params["database"]] = skill.restriction_on_tables
        for db_name in extracted_databases:
            databases_struct.setdefault(db_name, None)

        sql_agent = SQLAgent(
            command_executor=command_executor,
            databases=[],  # intentionally empty, as before; databases are described via databases_struct
            databases_struct=databases_struct,
            include_tables=tables_list,
            ignore_tables=None,
            include_knowledge_bases=include_knowledge_bases,
            ignore_knowledge_bases=ignore_knowledge_bases,
            knowledge_base_database=knowledge_base_database,
            sample_rows_in_table_info=3,
            cache=get_cache("agent", max_size=_MAX_CACHE_SIZE),
        )
        # Named sql_database to avoid shadowing the module-level ``db`` import.
        sql_database = MindsDBSQL.custom_init(sql_agent=sql_agent)
        toolkit = MindsDBSQLToolkit(
            db=sql_database,
            llm=llm,
            include_tables_tools=bool(databases_struct) or bool(tables_list),
            include_knowledge_base_tools=bool(include_knowledge_bases),
        )
        return toolkit.get_tools()

    def _make_retrieval_tools(self, skill: db.Skills, llm, embedding_model):
        """Create the advanced retrieval (RAG) tools for one retrieval skill.

        Args:
            skill: retrieval skill; its params provide ``name``, ``source``, ``config``
                and the required ``description``.
            llm: LLM used when ``config`` does not name one; also passed in ``pred_args``.
            embedding_model: not used by this method; kept for interface compatibility.
        """
        params = skill.params
        # Copy so the LLM object is not written back into the skill's own params;
        # ``or {}`` also tolerates an explicit ``config: None``.
        retrieval_config = dict(params.get("config") or {})
        # Set LLM if not explicitly provided in configs.
        retrieval_config.setdefault("llm", llm)
        tool = dict(
            name=params.get("name", skill.name),
            source=params.get("source", None),
            config=retrieval_config,
            description=f"You must use this tool to get more context or information "
            f"to answer a question about {params['description']}. "
            f"The input should be the exact question the user is asking.",
            type=skill.type,
        )
        pred_args = {"llm": llm}

        from .retrieval_tool import build_retrieval_tools

        return build_retrieval_tools(tool, pred_args, skill)

    def _get_rag_query_function(self, skill: db.Skills):
        """Return a function that answers a question from the skill's knowledge base (``params["source"]``)."""
        session_controller = self.get_command_executor().session

        def _answer_question(question: str) -> str:
            knowledge_base_name = skill.params["source"]

            # make select in KB table
            query = Select(
                targets=[Star()],
                where=BinaryOperation(op="=", args=[Identifier(TableField.CONTENT.value), Constant(question)]),
                limit=Constant(_DEFAULT_TOP_K_SIMILARITY_SEARCH),
            )
            kb_table = session_controller.kb_controller.get_table(knowledge_base_name, skill.project_id)

            res = kb_table.select_query(query)
            # Handle both chunk_content and content column names
            if hasattr(res, "chunk_content"):
                return "\n".join(res.chunk_content)
            if hasattr(res, "content"):
                return "\n".join(res.content)
            return "No content or chunk_content found in knowledge base response"

        return _answer_question

    def _make_knowledge_base_tools(self, skill: db.Skills) -> dict:
        """Build the tool description for a (deprecated) knowledge_base skill."""
        description = skill.params.get("description", "")

        logger.warning(
            "This skill is deprecated and will be removed in the future. Please use `retrieval` skill instead "
        )

        return dict(
            name="Knowledge Base Retrieval",
            func=self._get_rag_query_function(skill),
            description=f"Use this tool to get more context or information to answer a question about {description}. The input should be the exact question the user is asking.",
            type=skill.type,
        )

    def get_tools_from_skills(
        self, skills_data: List[SkillData], llm: BaseChatModel, embedding_model: Embeddings
    ) -> dict:
        """Create agent tools for the given skills, grouped by skill type.

        Args:
            skills_data (List[SkillData]): Skills to make tools from
            llm (BaseChatModel): LLM which will be used by skills
            embedding_model (Embeddings): passed to the retrieval skill builder

        Returns:
            dict: maps SkillType (legacy ``text2sql`` folded into TEXT2SQL) to the tools built
            for that type: one shared toolkit's tools for TEXT2SQL, one tool dict per
            KNOWLEDGE_BASE skill, and the concatenated retrieval tools for RETRIEVAL.

        Raises:
            NotImplementedError: if a skill's type is not a SkillType value.
        """
        skills_group = defaultdict(list)
        for skill in skills_data:
            try:
                skill_type = SkillType(skill.type)
            except ValueError:
                # List the accepted *values* (what callers pass), not the enum member names.
                supported_types = [member.value for member in SkillType]
                raise NotImplementedError(
                    f"skill of type {skill.type} is not supported as a tool, supported types are: {supported_types}"
                ) from None

            if skill_type == SkillType.TEXT2SQL_LEGACY:
                skill_type = SkillType.TEXT2SQL
            skills_group[skill_type].append(skill)

        tools = {}
        for skill_type, skills in skills_group.items():
            if skill_type == SkillType.TEXT2SQL:
                tools[skill_type] = self._make_text_to_sql_tools(skills, llm)
            elif skill_type == SkillType.KNOWLEDGE_BASE:
                tools[skill_type] = [self._make_knowledge_base_tools(skill) for skill in skills]
            elif skill_type == SkillType.RETRIEVAL:
                tools[skill_type] = [
                    tool for skill in skills for tool in self._make_retrieval_tools(skill, llm, embedding_model)
                ]
        return tools


skill_tool = SkillToolController()