Coverage for mindsdb / interfaces / skills / custom / text2sql / mindsdb_kb_tools.py: 0%

151 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Type, List, Any 

2import re 

3import json 

4from pydantic import BaseModel, Field 

5from langchain_core.tools import BaseTool 

6from mindsdb_sql_parser.ast import Describe, Select, Identifier, Constant, Star 

7 

8 

def llm_str_strip(s):
    """Peel code-fence markers, newlines and stray quotes off both ends of ``s``.

    The clean-up passes below are repeated until one complete round no longer
    shortens the string; the cleaned string is then returned.
    """
    while True:
        size_before = len(s)

        # Drop a leading and/or trailing ``` fence.
        if s.startswith("```"):
            s = s[3:]
        if s.endswith("```"):
            s = s[:-3]

        # Drop newlines at both ends.
        s = s.strip("\n")

        # A quote character that occurs exactly once cannot belong to a
        # matched pair, so remove it when it sits at either end.
        for quote in "\"'`":
            if s.count(quote) == 1:
                s = s.strip(quote)

        # Each pass can only remove characters, so an unchanged length means
        # there is nothing left to strip.
        if len(s) == size_before:
            return s

28 

29 

class KnowledgeBaseListToolInput(BaseModel):
    # Argument schema of the listing tool: one string field that defaults to
    # an empty string.
    tool_input: str = Field(
        default="",
        description="An empty string to list all knowledge bases.",
    )

32 

33 

class KnowledgeBaseListTool(BaseTool):
    """Tool for listing knowledge bases in MindsDB."""

    name: str = "kb_list_tool"
    description: str = "List all knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseListToolInput
    # Object asked for the knowledge base names; `_run` calls its
    # `get_usable_knowledge_base_names()` method.
    db: Any = None

    def _run(self, tool_input: str = "") -> str:
        """Return the usable knowledge base names as a JSON string.

        Args:
            tool_input: Ignored. It defaults to "" so that the tool can also
                be invoked without any argument instead of failing on a
                missing parameter.

        Returns:
            The value returned by ``self.db.get_usable_knowledge_base_names()``
            serialized with ``json.dumps``, or a plain message when that
            value is empty.
        """
        kb_names = self.db.get_usable_knowledge_base_names()
        if not kb_names:
            return "No knowledge bases found."
        return json.dumps(kb_names)

49 

50 

class KnowledgeBaseInfoToolInput(BaseModel):
    # Argument schema of the info tool: one required string field.
    tool_input: str = Field(
        default=...,
        description="A comma-separated list of knowledge base names enclosed between $START$ and $STOP$.",
    )

56 

57 

class KnowledgeBaseInfoTool(BaseTool):
    """Tool for getting information about knowledge bases in MindsDB."""

    name: str = "kb_info_tool"
    description: str = "Get information about knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseInfoToolInput
    # Object that checks permissions (`check_knowledge_base_permission`) and
    # executes the generated statements (`run_no_throw`).
    db: Any = None

    def _split_kb_names(self, text: str) -> List[str]:
        """Split comma-separated ``text`` into cleaned, non-empty names."""
        # Whitespace is trimmed first so that llm_str_strip sees fence or
        # quote characters at the very ends of each name.
        cleaned = (llm_str_strip(part.strip()).strip() for part in text.split(","))
        return [name for name in cleaned if name]

    def _extract_kb_names(self, tool_input: str) -> List[str]:
        """Extract knowledge base names from the tool input.

        Accepted forms, tried in this order:

        1. a list, returned unchanged;
        2. a string holding a JSON list, returned as parsed;
        3. text between ``$START$`` and ``$STOP$`` containing names wrapped in
           backticks, or, when no backticked name is present, a
           comma-separated list of names;
        4. a bare comma-separated list or a single name.

        Returns:
            The extracted names; an empty list when nothing usable was found
            or when ``tool_input`` is neither a list nor a string.
        """
        # The input may already be a list (passed directly from include_knowledge_bases).
        if isinstance(tool_input, list):
            return tool_input
        # The regex-based parsing below only works on strings.
        if not isinstance(tool_input, str):
            return []

        # The list may have been serialized as a JSON string.
        try:
            parsed_input = json.loads(tool_input)
        except (json.JSONDecodeError, TypeError):
            parsed_input = None
        if isinstance(parsed_input, list):
            return parsed_input

        match = re.search(r"\$START\$(.*?)\$STOP\$", tool_input, re.DOTALL)
        if not match:
            return self._split_kb_names(tool_input)

        kb_names_str = match.group(1).strip()
        backticked = re.findall(r"`([^`]+)`", kb_names_str)
        if backticked:
            return [llm_str_strip(n) for n in backticked]
        # The argument description asks for a plain comma-separated list
        # between the markers, so names without backticks must work too.
        return self._split_kb_names(kb_names_str)

    def _format_schema(self, schema_result: Any) -> str:
        """Render the result of the DESCRIBE statement as plain text lines."""
        if isinstance(schema_result, str):
            return f"{schema_result}\n"
        if isinstance(schema_result, list):
            return "".join(
                f"{json.dumps(row, indent=2)}\n" if isinstance(row, dict) else f"{str(row)}\n"
                for row in schema_result
            )
        return f"{str(schema_result)}\n"

    def _format_sample_data(self, sample_data: Any) -> str:
        """Render sample rows as a markdown table, or as a fenced text block."""
        if not sample_data:
            return "No sample data available.\n"
        if isinstance(sample_data, str):
            return f"```\n{sample_data}\n```\n"
        if not isinstance(sample_data, list):
            return f"```\n{str(sample_data)}\n```\n"
        if not isinstance(sample_data[0], dict):
            # A list of something other than dictionaries is shown as text.
            lines = "".join(f"{str(item)}\n" for item in sample_data)
            return f"```\n{lines}```\n"

        # The columns of the table are the keys of the first row.
        columns = list(sample_data[0].keys())
        lines = [
            "| " + " | ".join(columns) + " |\n",
            "| " + " | ".join("---" for _ in columns) + " |\n",
        ]
        for row in sample_data:
            cells = []
            for col in columns:
                cell_value = row[col]
                if isinstance(cell_value, dict):
                    cell_value = json.dumps(cell_value, ensure_ascii=False)
                # A raw pipe would be read as a column separator.
                cells.append(str(cell_value).replace("|", "\\|"))
            lines.append("| " + " | ".join(cells) + " |\n")
        return "".join(lines)

    def _describe_kb(self, kb_name: str) -> str:
        """Build the markdown report (schema and sample rows) for one knowledge base."""
        self.db.check_knowledge_base_permission(Identifier(kb_name))

        schema_result = self.db.run_no_throw(str(Describe(kb_name, type="knowledge_base")))
        if not schema_result:
            return f"Knowledge base `{kb_name}` not found or has no schema information."

        parts = [
            f"## Knowledge Base: `{kb_name}`\n\n",
            "### Schema Information:\n",
            "```\n",
            self._format_schema(schema_result),
            "```\n\n",
        ]

        sample_data = self.db.run_no_throw(
            str(Select(targets=[Star()], from_table=Identifier(kb_name), limit=Constant(20)))
        )
        parts.append("### Sample Data:\n")
        parts.append(self._format_sample_data(sample_data))
        return "".join(parts)

    def _run(self, tool_input: str) -> str:
        """Get information about specified knowledge bases.

        Args:
            tool_input: The knowledge base names in any form understood by
                ``_extract_kb_names``.

        Returns:
            One markdown section per knowledge base, separated by blank
            lines. A knowledge base whose lookup raises contributes an error
            line instead of aborting the whole call.
        """
        kb_names = self._extract_kb_names(tool_input)

        if not kb_names:
            return "No valid knowledge base names provided. Please provide knowledge base names as a list, comma-separated string, or enclosed in backticks between $START$ and $STOP$."

        results = []
        for kb_name in kb_names:
            # Deliberately broad: a failure for one knowledge base is reported
            # in the output and the remaining ones are still processed.
            try:
                results.append(self._describe_kb(kb_name))
            except Exception as e:
                results.append(f"Error getting information for knowledge base `{kb_name}`: {str(e)}")

        return "\n\n".join(results)

187 

188 

class KnowledgeBaseQueryToolInput(BaseModel):
    # Argument schema of the query tool: one required string field.
    tool_input: str = Field(
        default=...,
        description="A SQL query for knowledge bases. Can be provided directly or enclosed between $START$ and $STOP$.",
    )

194 

195 

class KnowledgeBaseQueryTool(BaseTool):
    """Tool for querying knowledge bases in MindsDB."""

    name: str = "kb_query_tool"
    description: str = "Query knowledge bases in MindsDB."
    args_schema: Type[BaseModel] = KnowledgeBaseQueryToolInput
    # Object whose `run_no_throw` method executes the extracted query.
    db: Any = None

    def _extract_query(self, tool_input: str) -> str:
        """Extract the SQL query from the tool input.

        Text between ``$START$`` and ``$STOP$`` wins. Otherwise the whole
        input is used, provided it contains at least one of the keywords
        SELECT, FROM, WHERE, LIMIT or ORDER BY.

        Returns:
            The stripped query, or "" when no query could be recognized or
            ``tool_input`` is not a string.
        """
        # The regex searches below only work on strings.
        if not isinstance(tool_input, str):
            return ""

        match = re.search(r"\$START\$(.*?)\$STOP\$", tool_input, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Without delimiters, require a SQL keyword as a sanity check.
        if re.search(r"\b(SELECT|FROM|WHERE|LIMIT|ORDER BY)\b", tool_input, re.IGNORECASE):
            return tool_input.strip()

        return ""

    def _format_table(self, rows: List[dict]) -> str:
        """Render a non-empty list of dictionaries as a markdown table."""
        # The columns of the table are the keys of the first row.
        columns = list(rows[0].keys())
        lines = [
            "| " + " | ".join(columns) + " |\n",
            "| " + " | ".join("---" for _ in columns) + " |\n",
        ]
        for row in rows:
            cells = []
            for col in columns:
                cell_value = row[col]
                if isinstance(cell_value, dict):
                    cell_value = json.dumps(cell_value, ensure_ascii=False)
                # A raw pipe would be read as a column separator.
                cells.append(str(cell_value).replace("|", "\\|"))
            lines.append("| " + " | ".join(cells) + " |\n")
        return "".join(lines)

    def _run(self, tool_input: str) -> str:
        """Execute a knowledge base query.

        Args:
            tool_input: The SQL query, optionally wrapped in ``$START$`` /
                ``$STOP$`` and/or code fences.

        Returns:
            A markdown table when the result is a list of dictionaries,
            otherwise the result as JSON or plain text. Problems are
            reported as a message string rather than raised.
        """
        # Clean the query before the emptiness check so that input consisting
        # only of fences or quotes is rejected instead of executed.
        query = llm_str_strip(self._extract_query(tool_input))

        if not query:
            return "No valid SQL query provided. Please provide a SQL query that includes SELECT, FROM, or other SQL keywords."

        # Deliberately broad: any failure is reported back as text.
        try:
            result = self.db.run_no_throw(query)

            if not result:
                return "Query executed successfully, but no results were returned."

            # Only rows that are dictionaries have column names to build a
            # table from; other lists fall through to the generic output.
            if isinstance(result, list) and all(isinstance(row, dict) for row in result):
                return self._format_table(result)

            # Ensure we always return a string.
            if isinstance(result, (list, dict)):
                try:
                    return json.dumps(result, indent=2)
                except (TypeError, ValueError):
                    # Values JSON cannot encode: fall back to the plain repr.
                    return str(result)
            return str(result)
        except Exception as e:
            return f"Error executing query: {str(e)}"