Coverage for mindsdb / interfaces / skills / sql_agent.py: 8%
381 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import re
2import csv
3import inspect
4from io import StringIO
5from typing import Iterable, List, Optional, Any, Tuple
6from collections import defaultdict
7import fnmatch
9import pandas as pd
10from mindsdb_sql_parser import parse_sql
11from mindsdb_sql_parser.ast import Select, Show, Describe, Explain, Identifier
13from mindsdb.utilities import log
14from mindsdb.utilities.context import context as ctx
15from mindsdb.integrations.utilities.query_traversal import query_traversal
16from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES
17from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
18from mindsdb.utilities.config import config
19from mindsdb.interfaces.data_catalog.data_catalog_retriever import DataCatalogRetriever
21logger = log.getLogger(__name__)
def list_to_csv_str(array: List[List[Any]]) -> str:
    """Serialize a 2D list of arbitrary values into CSV text.

    Args:
        array (List[List[Any]]): rows of values; every cell is stringified

    Returns:
        str: the rows rendered as CSV using the Excel dialect
    """
    buffer = StringIO()
    writer = csv.writer(buffer, dialect="excel")
    writer.writerows([str(cell) for cell in row] for row in array)
    return buffer.getvalue()
def split_table_name(table_name: str) -> List[str]:
    """Split a table name produced by an LLM into identifier parts.

    Dots act as separators only outside backticks; backticks themselves are
    dropped from the output.

    Args:
        table_name (str): input table name, possibly backtick-quoted

    Returns:
        List[str]: parts of table identifier like ['database', 'schema', 'table']

    Example:
        'input': '`aaa`.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
        'input': '`aaa`.`bbb`.`ccc`', 'output': ['aaa', 'bbb', 'ccc']
        'input': 'aaa.bbb', 'output': ['aaa', 'bbb']
        'input': '`aaa.bbb`', 'output': ['aaa.bbb']
        'input': '`aaa.bbb.ccc`', 'output': ['aaa.bbb.ccc']
        'input': 'aaa.`bbb`', 'output': ['aaa', 'bbb']
        'input': 'aaa.bbb.ccc', 'output': ['aaa', 'bbb', 'ccc']
        'input': 'aaa.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
    """
    parts: List[str] = []
    buffer: List[str] = []
    quoted = False

    for char in table_name:
        if char == "`":
            # Backticks toggle quoting and never reach the output.
            quoted = not quoted
        elif char == "." and not quoted:
            if buffer:
                parts.append("".join(buffer).strip("`"))
                buffer = []
        else:
            buffer.append(char)

    if buffer:
        parts.append("".join(buffer).strip("`"))

    return parts
class TablesCollection:
    """
    Collection of identifiers.
    Supports wildcard in tables name.

    Names are indexed three ways depending on how many parts they have:
    bare table names, db-qualified names, and db.schema-qualified names.
    Matching is case-insensitive (all parts are lower-cased) and uses
    fnmatch-style wildcards on the table-name part.
    """

    def __init__(self, items: Optional[List[Identifier | str]] = None, default_db=None):
        if items is None:
            items = []

        self.items = items
        # tables indexed by database name: {db: {tbl_pattern, ...}}
        self._dbs = defaultdict(set)
        # tables indexed by database and schema: {db: {schema: {tbl_pattern, ...}}}
        self._schemas = defaultdict(dict)
        # bare table names given without any database qualifier
        self._no_db_tables = set()
        self.has_wildcard = False
        self.databases = set()
        self._default_db = default_db

        for name in items:
            if not isinstance(name, Identifier):
                name = Identifier(name)
            db, schema, tbl = self._get_paths(name)
            if db is None:
                self._no_db_tables.add(tbl)
            elif schema is None:
                self._dbs[db].add(tbl)
            else:
                if schema not in self._schemas[db]:
                    self._schemas[db][schema] = set()
                self._schemas[db][schema].add(tbl)

            if "*" in tbl:
                self.has_wildcard = True
            # NOTE(review): db may be None here for unqualified names, so
            # `databases` can contain None — confirm callers tolerate that.
            self.databases.add(db)

    def _get_paths(self, table: Identifier) -> Tuple:
        # split identifier to db, schema, table name
        schema = None
        db = None

        # Capture patterns rebind db/schema for the 2- and 3-part cases;
        # a 1-part name leaves both as None.
        match [x.lower() for x in table.parts]:
            case [tbl]:
                pass
            case [db, tbl]:
                pass
            case [db, schema, tbl]:
                pass
            case _:
                raise NotImplementedError
        return db, schema, tbl.lower()

    def match(self, table: Identifier) -> bool:
        # Check if input table matches to tables in collection

        db, schema, tbl = self._get_paths(table)
        if db is None:
            if tbl in self._no_db_tables:
                return True
            # Retry with the configured default database prepended.
            if self._default_db is not None:
                return self.match(Identifier(parts=[self._default_db, tbl]))
            # Otherwise fall through: the defaultdict lookups below are keyed
            # by None and yield empty sets, so the result is False.

        if schema is not None:
            if any([fnmatch.fnmatch(tbl, pattern) for pattern in self._schemas[db].get(schema, [])]):
                return True

        # table might be specified without schema
        return any([fnmatch.fnmatch(tbl, pattern) for pattern in self._dbs[db]])

    def __bool__(self):
        # Truthy only when the collection was built from at least one item.
        return len(self.items) > 0

    def __repr__(self):
        return f"Tables({self.items})"
class SQLAgent:
    """
    SQLAgent is a class that handles SQL queries for agents.

    It restricts queries to read-only statements over configured
    include/ignore lists of tables and knowledge bases, and provides
    helpers to describe tables (schema + sample rows) for LLM prompts.
    """
    def __init__(
        self,
        command_executor,
        databases: List[str],
        databases_struct: dict,
        knowledge_base_database: str = "mindsdb",
        include_tables: Optional[List[str]] = None,
        ignore_tables: Optional[List[str]] = None,
        include_knowledge_bases: Optional[List[str]] = None,
        ignore_knowledge_bases: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        cache: Optional[dict] = None,
    ):
        """
        Initialize SQLAgent.

        Args:
            command_executor: Executor for SQL commands
            databases (List[str]): List of databases to use
            databases_struct (dict): Dictionary of database structures
            knowledge_base_database (str): Project name where knowledge bases are stored (defaults to 'mindsdb')
            include_tables (List[str]): Tables to include
            ignore_tables (List[str]): Tables to ignore
            include_knowledge_bases (List[str]): Knowledge bases to include
            ignore_knowledge_bases (List[str]): Knowledge bases to ignore
            sample_rows_in_table_info (int): Number of sample rows to include in table info
            cache (Optional[dict]): Cache for query results
        """
        self._command_executor = command_executor
        self._mindsdb_db_struct = databases_struct
        self.knowledge_base_database = knowledge_base_database  # This is a project name, not a database connection
        self._databases = databases
        self._sample_rows_in_table_info = int(sample_rows_in_table_info)

        self._tables_to_include = TablesCollection(include_tables)
        if self._tables_to_include:
            # ignore_tables and include_tables should not be used together.
            # include_tables takes priority if it's set.
            ignore_tables = []
        self._tables_to_ignore = TablesCollection(ignore_tables)

        # Knowledge-base names without a database default to the project name.
        self._knowledge_bases_to_include = TablesCollection(include_knowledge_bases, default_db=knowledge_base_database)
        if self._knowledge_bases_to_include:
            # ignore_knowledge_bases and include_knowledge_bases should not be used together.
            # include_knowledge_bases takes priority if it's set.
            ignore_knowledge_bases = []
        self._knowledge_bases_to_ignore = TablesCollection(ignore_knowledge_bases, default_db=knowledge_base_database)

        self._cache = cache

        # Local import — presumably avoids a circular dependency; confirm.
        from mindsdb.interfaces.skills.skill_tool import SkillToolController

        # Initialize the skill tool controller from MindsDB
        self.skill_tool = SkillToolController()
217 def _call_engine(self, query: str, database=None):
218 # switch database
219 ast_query = parse_sql(query.strip("`"))
220 self._check_permissions(ast_query)
222 if database is None:
223 # if we use tables with prefixes it should work for any database
224 if self._databases is not None:
225 # if we have multiple databases, we need to check which one to use
226 # for now, we will just use the first one
227 database = self._databases[0] if self._databases else "mindsdb"
229 ret = self._command_executor.execute_command(ast_query, database_name=database)
230 return ret
    def _check_permissions(self, ast_query):
        """Validate that the statement is read-only and touches only allowed objects.

        Raises:
            ValueError: if the statement type is not Select/Show/Describe/Explain,
                or if a referenced table / knowledge base fails the
                include/ignore checks.
        """
        # check type of query
        if not isinstance(ast_query, (Select, Show, Describe, Explain)):
            raise ValueError(f"Query is not allowed: {ast_query.to_string()}")

        kb_names = self.get_all_knowledge_base_names()

        # Check tables
        if self._tables_to_include:

            def _check_f(node, is_table=None, **kwargs):
                if is_table and isinstance(node, Identifier):
                    table_name = ".".join(node.parts)

                    # Check if this table is a knowledge base
                    if table_name in kb_names or node.parts[-1] in kb_names:
                        # If it's a knowledge base and we have knowledge base restrictions
                        self.check_knowledge_base_permission(node)
                    else:
                        try:
                            # Regular table check
                            self.check_table_permission(node)
                        except ValueError as origin_exc:
                            # was it badly quoted by llm?
                            #
                            if "." in node.parts[0]:
                                # extract quoted parts (with dots) to sub-parts
                                parts = []
                                for i, item in enumerate(node.parts):
                                    if node.is_quoted[i] and "." in item:
                                        parts.extend(Identifier(item).parts)
                                    else:
                                        parts.append(item)
                                node2 = Identifier(parts=parts)
                                try:
                                    # Re-check the re-split identifier recursively;
                                    # returning it — presumably query_traversal
                                    # substitutes the corrected node in the AST;
                                    # confirm against query_traversal docs.
                                    _check_f(node2, is_table=True)
                                    return node2
                                except ValueError:
                                    ...
                            # Re-raise the original failure, not the retry's.
                            raise origin_exc

            query_traversal(ast_query, _check_f)
275 def check_knowledge_base_permission(self, node):
276 if self._knowledge_bases_to_include and not self._knowledge_bases_to_include.match(node):
277 raise ValueError(
278 f"Knowledge base {str(node)} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include.items)}"
279 )
280 # Check if it's a restricted knowledge base
281 if self._knowledge_bases_to_ignore and self._knowledge_bases_to_ignore.match(node):
282 raise ValueError(f"Knowledge base {str(node)} is not allowed.")
284 def check_table_permission(self, node):
285 if self._tables_to_include and not self._tables_to_include.match(node):
286 raise ValueError(
287 f"Table {str(node)} not found. Available tables: {', '.join(self._tables_to_include.items)}"
288 )
289 # Check if it's a restricted table
290 if self._tables_to_ignore and self._tables_to_ignore.match(node):
291 raise ValueError(f"Table {str(node)} is not allowed.")
    def get_usable_table_names(self) -> Iterable[str]:
        """Get a list of tables that the agent has access to.

        Returns:
            Iterable[str]: list with table names
        """
        cache_key = f"{ctx.company_id}_{','.join(self._databases)}_tables"

        # first check cache and return if found
        if self._cache:
            cached_tables = self._cache.get(cache_key)
            if cached_tables:
                return cached_tables

        if not self._tables_to_include:
            # no tables allowed
            return []
        if not self._tables_to_include.has_wildcard:
            # No wildcards: the include list itself is the answer.
            return self._tables_to_include.items

        # Wildcards present: expand them against the actual tables of every
        # database mentioned in the include list.
        result_tables = []

        for db_name in self._tables_to_include.databases:
            handler = self._command_executor.session.integration_controller.get_data_handler(db_name)

            # Some handlers accept all=True to list tables across schemas.
            if "all" in inspect.signature(handler.get_tables).parameters:
                response = handler.get_tables(all=True)
            else:
                response = handler.get_tables()
            df = response.data_frame
            col_name = "table_name"
            if col_name not in df.columns:
                # get first column if not found
                col_name = df.columns[0]

            for _, row in df.iterrows():
                if "table_schema" in row:
                    parts = [db_name, row["table_schema"], row[col_name]]
                else:
                    parts = [db_name, row[col_name]]
                # Keep tables matching the include patterns and not ignored.
                if self._tables_to_include.match(Identifier(parts=parts)):
                    if not self._tables_to_ignore.match(Identifier(parts=parts)):
                        result_tables.append(parts)

        result_tables = [".".join(x) for x in result_tables]
        if self._cache:
            self._cache.set(cache_key, set(result_tables))
        return result_tables
342 def get_usable_knowledge_base_names(self) -> Iterable[str]:
343 """Get a list of knowledge bases that the agent has access to.
345 Returns:
346 Iterable[str]: list with knowledge base names
347 """
349 if not self._knowledge_bases_to_include and not self._knowledge_bases_to_ignore:
350 # white or black list have to be set
351 return []
353 # Filter knowledge bases based on ignore list
354 kb_names = []
355 for kb_name in self.get_all_knowledge_base_names():
356 kb = Identifier(parts=[self.knowledge_base_database, kb_name])
357 if self._knowledge_bases_to_include and not self._knowledge_bases_to_include.match(kb):
358 continue
359 if not self._knowledge_bases_to_ignore.match(kb):
360 kb_names.append(kb_name)
361 return kb_names
363 def get_all_knowledge_base_names(self) -> Iterable[str]:
364 """Get a list of all knowledge bases
366 Returns:
367 Iterable[str]: list with knowledge base names
368 """
369 # cache_key = f"{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases"
371 # todo we need to fix the cache, file cache can potentially store out of data information
372 # # first check cache and return if found
373 # if self._cache:
374 # cached_kbs = self._cache.get(cache_key)
375 # if cached_kbs:
376 # return cached_kbs
378 try:
379 # Query to get all knowledge bases
380 ast_query = Show(category="Knowledge Bases")
381 result = self._command_executor.execute_command(ast_query, database_name=self.knowledge_base_database)
383 # Filter knowledge bases based on ignore list
384 kb_names = []
385 for row in result.data.records:
386 kb_names.append(row["NAME"])
388 # if self._cache:
389 # self._cache.set(cache_key, set(kb_names))
391 return kb_names
392 except Exception:
393 # If there's an error, log it and return an empty list
394 logger.exception("Error in get_usable_knowledge_base_names")
395 return []
397 def _resolve_table_names(self, table_names: List[str], all_tables: List[Identifier]) -> List[Identifier]:
398 """
399 Tries to find table (which comes directly from an LLM) by its name
400 Handles backticks (`) and tables without databases
401 """
403 # index to lookup table
404 tables_idx = {}
405 for table in all_tables:
406 # by name
407 if len(table.parts) == 3:
408 tables_idx[tuple(table.parts[1:])] = table
409 else:
410 tables_idx[(table.parts[-1],)] = table
411 # by path
412 tables_idx[tuple(table.parts)] = table
414 tables = []
415 not_found = []
416 for table_name in table_names:
417 if not table_name.strip():
418 continue
420 # Some LLMs (e.g. gpt-4o) may include backticks or quotes when invoking tools.
421 table_parts = split_table_name(table_name)
422 if len(table_parts) == 1:
423 # most likely LLM enclosed all table name in backticks `database.table`
424 table_parts = split_table_name(table_name)
426 # resolved table
427 table_identifier = tables_idx.get(tuple(table_parts))
429 if table_identifier is None:
430 not_found.append(table_name)
431 else:
432 tables.append(table_identifier)
434 if not_found:
435 raise ValueError(f"Tables: {', '.join(not_found)} not found in the database")
436 return tables
438 def get_knowledge_base_info(self, kb_names: Optional[List[str]] = None) -> str:
439 """Get information about specified knowledge bases.
440 Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
441 If `sample_rows_in_table_info`, the specified number of sample rows will be
442 appended to each table description. This can increase performance as demonstrated in the paper.
443 """
445 kbs_info = []
446 for kb in kb_names:
447 key = f"{ctx.company_id}_{kb}_info"
448 kb_info = self._cache.get(key) if self._cache else None
449 if True or kb_info is None:
450 kb_info = self.get_kb_sample_rows(kb)
451 if self._cache:
452 self._cache.set(key, kb_info)
454 kbs_info.append(kb_info)
456 return "\n\n".join(kbs_info)
458 def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
459 """Get information about specified tables.
460 Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
461 If `sample_rows_in_table_info`, the specified number of sample rows will be
462 appended to each table description. This can increase performance as demonstrated in the paper.
463 """
464 if config.get("data_catalog", {}).get("enabled", False):
465 database_table_map = {}
466 for name in table_names or self.get_usable_table_names():
467 name = name.replace("`", "")
469 parts = name.split(".", 1)
470 # TODO: Will there be situations where parts has more than 2 elements? Like a schema?
471 # This is unlikely given that we default to a single schema per database.
472 if len(parts) == 1:
473 raise ValueError(f"Invalid table name: {name}. Expected format is 'database.table'.")
475 database_table_map.setdefault(parts[0], []).append(parts[1])
477 data_catalog_str = ""
478 # TODO: Introduce caching mechanism to avoid repeated retrievals?
479 for database_name, table_names in database_table_map.items():
480 data_catalog_retriever = DataCatalogRetriever(database_name=database_name, table_names=table_names)
482 result = data_catalog_retriever.retrieve_metadata_as_string()
483 data_catalog_str += str(result or "")
485 return data_catalog_str
487 else:
488 # TODO: Improve old logic without data catalog
489 all_tables = []
490 for name in self.get_usable_table_names():
491 # remove backticks
492 name = name.replace("`", "")
494 parts = name.split(".")
495 if len(parts) > 1:
496 all_tables.append(Identifier(parts=parts))
497 else:
498 all_tables.append(Identifier(name))
500 if table_names is not None:
501 all_tables = self._resolve_table_names(table_names, all_tables)
503 tables_info = []
504 for table in all_tables:
505 key = f"{ctx.company_id}_{table}_info"
506 table_info = self._cache.get(key) if self._cache else None
507 if True or table_info is None:
508 table_info = self._get_single_table_info(table)
509 if self._cache:
510 self._cache.set(key, table_info)
512 tables_info.append(table_info)
514 return "\n\n".join(tables_info)
516 def get_kb_sample_rows(self, kb_name: str) -> str:
517 """Get sample rows from a knowledge base.
519 Args:
520 kb_name (str): The name of the knowledge base.
522 Returns:
523 str: A string containing the sample rows from the knowledge base.
524 """
525 logger.info(f"_get_sample_rows: knowledge base={kb_name}")
526 command = f"select * from {kb_name} limit 10;"
527 try:
528 ret = self._call_engine(command)
529 sample_rows = ret.data.to_lists()
531 def truncate_value(val):
532 str_val = str(val)
533 return str_val if len(str_val) < 100 else (str_val[:100] + "...")
535 sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
536 sample_rows_str = "\n" + f"{kb_name}:" + list_to_csv_str(sample_rows)
537 except Exception:
538 logger.info("_get_sample_rows error:", exc_info=True)
539 sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
541 return sample_rows_str
    def _get_single_table_info(self, table: Identifier) -> str:
        """Build a textual description of one table: columns, dtypes, sample rows.

        Raises:
            ValueError: if the table has no database qualifier, or if column
                info cannot be processed.
        """
        if len(table.parts) < 2:
            raise ValueError(f"Database is required for table: {table}")
        if len(table.parts) == 3:
            integration, schema_name, table_name = table.parts[-3:]
        else:
            schema_name = None
            integration, table_name = table.parts[-2:]

        table_str = str(table)

        dn = self._command_executor.session.datahub.get(integration)

        fields, dtypes = [], []
        try:
            df = dn.get_table_columns_df(table_name, schema_name)
            if not isinstance(df, pd.DataFrame) or df.empty:
                logger.warning(f"Received empty or invalid DataFrame for table columns of {table_str}")
                return f"Table named `{table_str}`:\n [No column information available]"

            fields = df[INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME].to_list()
            # Prefer the MySQL enum value when available, else the raw data
            # type string, else the literal "UNKNOWN".
            dtypes = [
                mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or "UNKNOWN")
                for mysql_data_type, data_type in zip(
                    df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE], df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
                )
            ]
        except Exception as e:
            logger.exception(f"Failed processing column info for {table_str}:")
            raise ValueError(f"Failed to process column info for {table_str}") from e

        if not fields:
            logger.error(f"Could not extract column fields for {table_str}.")
            return f"Table named `{table_str}`:\n [Could not extract column information]"

        # Sample-row failures are non-fatal: fall back to an error marker.
        try:
            sample_rows_info = self._get_sample_rows(table_str, fields)
        except Exception:
            logger.warning(f"Could not get sample rows for {table_str}:", exc_info=True)
            sample_rows_info = "\n\t [error] Couldn't retrieve sample rows!"

        info = f"Table named `{table_str}`:\n"
        info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str} in CSV format (dialect is 'excel'):\n"
        info += sample_rows_info + "\n"
        info += (
            "\nColumn data types: "
            + ",\t".join([f"\n`{field}` : `{dtype}`" for field, dtype in zip(fields, dtypes)])
            + "\n"
        )
        return info
594 def _get_sample_rows(self, table: str, fields: List[str]) -> str:
595 logger.info(f"_get_sample_rows: table={table} fields={fields}")
596 command = f"select * from {table} limit {self._sample_rows_in_table_info};"
597 try:
598 ret = self._call_engine(command)
599 sample_rows = ret.data.to_lists()
601 def truncate_value(val):
602 str_val = str(val)
603 return str_val if len(str_val) < 100 else (str_val[:100] + "...")
605 sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
606 sample_rows_str = "\n" + list_to_csv_str([fields] + sample_rows)
607 except Exception:
608 logger.info("_get_sample_rows error:", exc_info=True)
609 sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
611 return sample_rows_str
613 def _clean_query(self, query: str) -> str:
614 # Sometimes LLM can input markdown into query tools.
615 cmd = re.sub(r"```(sql)?", "", query)
616 return cmd
618 def query(self, command: str, fetch: str = "all") -> str:
619 """Execute a SQL command and return a string representing the results.
620 If the statement returns rows, a string of the results is returned.
621 If the statement returns no rows, an empty string is returned.
622 """
624 def _repr_result(ret):
625 limit_rows = 30
627 columns_str = ", ".join([repr(col.name) for col in ret.columns])
628 res = f"Output columns: {columns_str}\n"
630 data = ret.to_lists()
631 if len(data) > limit_rows:
632 df = pd.DataFrame(data, columns=[col.name for col in ret.columns])
634 res += f"Result has {len(data)} rows. Description of data:\n"
635 res += str(df.describe(include="all")) + "\n\n"
636 res += f"First {limit_rows} rows:\n"
638 else:
639 res += "Result in CSV format (dialect is 'excel'):\n"
640 res += list_to_csv_str([[col.name for col in ret.columns]] + data[:limit_rows])
641 return res
643 ret = self._call_engine(self._clean_query(command))
644 if fetch == "all":
645 result = _repr_result(ret.data)
646 elif fetch == "one":
647 result = "Result in CSV format (dialect is 'excel'):\n"
648 result += list_to_csv_str([[col.name for col in ret.data.columns]] + [ret.data.to_lists()[0]])
649 else:
650 raise ValueError("Fetch parameter must be either 'one' or 'all'")
651 return str(result)
653 def get_table_info_safe(self, table_names: Optional[List[str]] = None) -> str:
654 try:
655 logger.info(f"get_table_info_safe: {table_names}")
656 return self.get_table_info(table_names)
657 except Exception as e:
658 logger.info("get_table_info_safe error:", exc_info=True)
659 return f"Error: {e}"
661 def query_safe(self, command: str, fetch: str = "all") -> str:
662 try:
663 logger.info(f"query_safe (fetch={fetch}): {command}")
664 return self.query(command, fetch)
665 except Exception as e:
666 logger.exception("Error in query_safe:")
667 msg = f"Error: {e}"
668 if "does not exist" in msg and " relation " in msg:
669 msg += "\nAvailable tables: " + ", ".join(self.get_usable_table_names())
670 return msg