Coverage for mindsdb / interfaces / skills / custom / text2sql / mindsdb_kb_tools.py: 0%
151 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Type, List, Any
2import re
3import json
4from pydantic import BaseModel, Field
5from langchain_core.tools import BaseTool
6from mindsdb_sql_parser.ast import Describe, Select, Identifier, Constant, Star
def llm_str_strip(s):
    """Strip formatting artifacts that LLMs commonly wrap around a value.

    Repeats the following until the string stops changing:

    * remove a leading Markdown code fence (an opening fence tagged ``sql``,
      e.g. "```sql\\nSELECT ...", loses the tag too) and a trailing fence;
    * remove surrounding whitespace (including ``\\r\\n`` line endings);
    * remove a quote character (", ' or `) that occurs exactly once in the
      string, i.e. an unbalanced quote, from both ends.

    Args:
        s: the raw string produced by the model.

    Returns:
        The cleaned string (possibly empty).
    """
    length = -1
    while length != len(s):
        length = len(s)

        # remove ``` fences; an "sql" tag is only treated as a language tag when
        # followed by whitespace, so a bare "```sql```" still yields "sql"
        fence = re.match(r"```(?:sql(?=\s))?", s, re.IGNORECASE)
        if fence:
            s = s[fence.end():]
        if s.endswith("```"):
            s = s[:-3]

        # remove surrounding whitespace / new lines
        s = s.strip()

        # remove extra (unbalanced) quotes
        for q in ('"', "'", "`"):
            if s.count(q) == 1:
                s = s.strip(q)
    return s
class KnowledgeBaseListToolInput(BaseModel):
    # Argument schema for kb_list_tool. The value is ignored by
    # KnowledgeBaseListTool._run; it defaults to an empty string.
    tool_input: str = Field(
        default="",
        description="An empty string to list all knowledge bases.",
    )
class KnowledgeBaseListTool(BaseTool):
    """Tool for listing knowledge bases in MindsDB."""

    name: str = "kb_list_tool"
    description: str = "List all knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseListToolInput
    db: Any = None  # database wrapper providing get_usable_knowledge_base_names()

    def _run(self, tool_input: str) -> str:
        """Return the usable knowledge base names serialized as a JSON array.

        ``tool_input`` is accepted only to satisfy the tool interface and is not used.
        When ``self.db`` reports no (or an empty collection of) knowledge bases, a fixed
        human-readable message is returned instead of JSON.
        """
        available = self.db.get_usable_knowledge_base_names()
        return json.dumps(available) if available else "No knowledge bases found."
class KnowledgeBaseInfoToolInput(BaseModel):
    # Argument schema for kb_info_tool; the field is required (default=...).
    tool_input: str = Field(
        default=...,
        description="A comma-separated list of knowledge base names enclosed between $START$ and $STOP$.",
    )
class KnowledgeBaseInfoTool(BaseTool):
    """Tool for getting information about knowledge bases in MindsDB."""

    name: str = "kb_info_tool"
    description: str = "Get information about knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseInfoToolInput
    db: Any = None

    def _split_kb_names(self, text: str) -> List[str]:
        """Split a comma-separated string into cleaned, non-empty knowledge base names.

        Fences / unbalanced quotes around the whole string are removed with
        ``llm_str_strip``; whitespace, quotes and backticks around each individual
        name are stripped, and empty entries (e.g. from a trailing comma) are dropped.
        """
        names = []
        for part in llm_str_strip(text).split(","):
            name = part.strip(" \t\r\n`'\"")
            if name:
                names.append(name)
        return names

    def _extract_kb_names(self, tool_input: str) -> List[str]:
        """Extract knowledge base names from the tool input.

        Accepted forms, tried in order:

        1. a Python list (returned as-is);
        2. a JSON array string, e.g. '["kb1", "kb2"]' (returned as parsed);
        3. names between $START$ and $STOP$, either wrapped in backticks
           ("$START$ `kb1`, `kb2` $STOP$") or as a plain comma-separated list
           ("$START$ kb1, kb2 $STOP$");
        4. a bare comma-separated string or a single name.

        Returns:
            The extracted names; an empty list when none can be found.
        """
        # First, check if the input is already a list (passed directly from include_knowledge_bases)
        if isinstance(tool_input, list):
            return tool_input

        # Next, try to parse it as JSON in case it was serialized as a JSON string
        try:
            parsed_input = json.loads(tool_input)
        except (json.JSONDecodeError, TypeError):
            parsed_input = None
        if isinstance(parsed_input, list):
            return parsed_input

        match = re.search(r"\$START\$(.*?)\$STOP\$", tool_input, re.DOTALL)
        if not match:
            # no markers: comma-separated names or a single name
            return self._split_kb_names(tool_input)

        kb_names_str = match.group(1).strip()
        kb_names = re.findall(r"`([^`]+)`", kb_names_str)
        if not kb_names:
            # the input schema asks for a plain comma-separated list between the
            # markers, so un-backticked names must be accepted as well
            return self._split_kb_names(kb_names_str)
        return [llm_str_strip(n) for n in kb_names]

    def _format_markdown_table(self, rows: List[dict]) -> str:
        """Render dict rows as a markdown table; column order is taken from the first row."""

        def cell(value: Any) -> str:
            if isinstance(value, dict):
                # default=str: non-JSON values (dates, decimals, ...) must not abort rendering
                value = json.dumps(value, ensure_ascii=False, default=str)
            # escape pipes and flatten line breaks so multi-line values stay inside
            # their cell instead of breaking the table row
            return re.sub(r"\r\n?|\n", " ", str(value).replace("|", "\\|"))

        columns = list(rows[0].keys())
        lines = [
            "| " + " | ".join(cell(col) for col in columns) + " |",
            "| " + " | ".join("---" for _ in columns) + " |",
        ]
        lines.extend("| " + " | ".join(cell(row[col]) for col in columns) + " |" for row in rows)
        return "".join(f"{line}\n" for line in lines)

    def _format_schema(self, schema_result: Any) -> str:
        """Render a DESCRIBE result as text: one line (or indented JSON object) per row."""
        if isinstance(schema_result, list):
            rendered = [
                json.dumps(row, indent=2, default=str) if isinstance(row, dict) else str(row)
                for row in schema_result
            ]
            return "".join(f"{item}\n" for item in rendered)
        # strings and any other type are rendered via str()
        return f"{str(schema_result)}\n"

    def _format_sample_data(self, sample_data: Any) -> str:
        """Render sample rows: a markdown table for dict rows, a code block otherwise."""
        if not sample_data:
            return "No sample data available.\n"
        if isinstance(sample_data, list):
            if isinstance(sample_data[0], dict):
                return self._format_markdown_table(sample_data)
            # a list, but not of dictionaries: one item per line
            return "```\n" + "".join(f"{str(item)}\n" for item in sample_data) + "```\n"
        return f"```\n{str(sample_data)}\n```\n"

    def _describe_kb(self, kb_name: str) -> str:
        """Build the markdown section (schema + up to 20 sample rows) for one knowledge base.

        Checks permission first, then runs DESCRIBE and ``SELECT * ... LIMIT 20`` through
        ``self.db.run_no_throw``. Returns a short message instead of a section when
        DESCRIBE yields nothing. Exceptions propagate to the caller.
        """
        self.db.check_knowledge_base_permission(Identifier(kb_name))

        schema_result = self.db.run_no_throw(str(Describe(kb_name, type="knowledge_base")))
        if not schema_result:
            return f"Knowledge base `{kb_name}` not found or has no schema information."
        schema_text = self._format_schema(schema_result)

        sample_data = self.db.run_no_throw(
            str(Select(targets=[Star()], from_table=Identifier(kb_name), limit=Constant(20)))
        )

        return (
            f"## Knowledge Base: `{kb_name}`\n\n"
            "### Schema Information:\n"
            "```\n"
            f"{schema_text}"
            "```\n\n"
            "### Sample Data:\n"
            f"{self._format_sample_data(sample_data)}"
        )

    def _run(self, tool_input: str) -> str:
        """Get information about the knowledge bases named in ``tool_input``.

        Each knowledge base is described independently; an error for one is reported
        in its own section instead of aborting the others. Sections are separated by
        blank lines.
        """
        kb_names = self._extract_kb_names(tool_input)

        if not kb_names:
            return "No valid knowledge base names provided. Please provide knowledge base names as a list, comma-separated string, or enclosed in backticks between $START$ and $STOP$."

        results = []
        for kb_name in kb_names:
            try:
                results.append(self._describe_kb(kb_name))
            except Exception as e:
                results.append(f"Error getting information for knowledge base `{kb_name}`: {str(e)}")

        return "\n\n".join(results)
class KnowledgeBaseQueryToolInput(BaseModel):
    # Argument schema for kb_query_tool; the field is required (default=...).
    tool_input: str = Field(
        default=...,
        description="A SQL query for knowledge bases. Can be provided directly or enclosed between $START$ and $STOP$.",
    )
class KnowledgeBaseQueryTool(BaseTool):
    """Tool for querying knowledge bases in MindsDB."""

    name: str = "kb_query_tool"
    description: str = "Query knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseQueryToolInput
    db: Any = None

    def _extract_query(self, tool_input: str) -> str:
        """Extract the SQL query from the tool input.

        Returns the text between $START$ and $STOP$ when present; otherwise the whole
        (stripped) input if it contains a common SQL keyword; otherwise "".
        """
        # First check if the input is wrapped in $START$ and $STOP$
        match = re.search(r"\$START\$(.*?)\$STOP\$", tool_input, re.DOTALL)
        if match:
            return match.group(1).strip()

        # If not wrapped in delimiters, use the input directly, but only when it
        # contains SQL keywords (a cheap check that it is likely a query)
        if re.search(r"\b(SELECT|FROM|WHERE|LIMIT|ORDER BY)\b", tool_input, re.IGNORECASE):
            return tool_input.strip()

        return ""

    def _format_markdown_table(self, rows: List[dict]) -> str:
        """Render dict rows as a markdown table; column order is taken from the first row."""

        def cell(value: Any) -> str:
            if isinstance(value, dict):
                # default=str: non-JSON values (dates, decimals, ...) must not abort rendering
                value = json.dumps(value, ensure_ascii=False, default=str)
            # escape pipes and flatten line breaks so multi-line values stay inside
            # their cell instead of breaking the table row
            return re.sub(r"\r\n?|\n", " ", str(value).replace("|", "\\|"))

        columns = list(rows[0].keys())
        lines = [
            "| " + " | ".join(cell(col) for col in columns) + " |",
            "| " + " | ".join("---" for _ in columns) + " |",
        ]
        lines.extend("| " + " | ".join(cell(row[col]) for col in columns) + " |" for row in rows)
        return "".join(f"{line}\n" for line in lines)

    def _run(self, tool_input: str) -> str:
        """Execute a knowledge base query and return the result as text.

        Lists of dict rows are rendered as a markdown table; other lists/dicts as
        indented JSON; anything else via str(). Errors are returned as a message
        rather than raised.
        """
        query = self._extract_query(tool_input)

        if not query:
            return "No valid SQL query provided. Please provide a SQL query that includes SELECT, FROM, or other SQL keywords."

        try:
            # Remove fences/quotes the model may have wrapped around the query
            query = llm_str_strip(query)
            result = self.db.run_no_throw(query)

            if not result:
                return "Query executed successfully, but no results were returned."

            # `not result` above guarantees a non-empty list here, so result[0] exists;
            # only dict rows have column names to build a table from
            if isinstance(result, list) and isinstance(result[0], dict):
                return self._format_markdown_table(result)

            # Ensure we always return a string
            if isinstance(result, (list, dict)):
                return json.dumps(result, indent=2, default=str)
            return str(result)
        except Exception as e:
            return f"Error executing query: {str(e)}"