Coverage for mindsdb / interfaces / skills / skill_tool.py: 11%

247 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import enum 

2import inspect 

3 

4from dataclasses import dataclass 

5from collections import defaultdict 

6from typing import List, Dict, Optional 

7 

8from langchain_core.embeddings import Embeddings 

9from langchain_core.language_models import BaseChatModel 

10from mindsdb_sql_parser.ast import Select, BinaryOperation, Identifier, Constant, Star 

11 

12from mindsdb.utilities import log 

13from mindsdb.utilities.cache import get_cache 

14from mindsdb.utilities.config import config 

15from mindsdb.interfaces.storage import db 

16from mindsdb.interfaces.skills.sql_agent import SQLAgent 

17from mindsdb.integrations.libs.vectordatabase_handler import TableField 

18from mindsdb.interfaces.agents.constants import DEFAULT_TEXT2SQL_DATABASE 

19 

# Row limit of the knowledge-base SELECT issued by the (deprecated) knowledge_base skill tool.
_DEFAULT_TOP_K_SIMILARITY_SEARCH = 5
# Passed as `max_size` to the "agent" cache given to SQLAgent by the text-to-SQL tools.
_MAX_CACHE_SIZE = 1000

logger = log.getLogger(__name__)

24 

25 

class SkillType(enum.Enum):
    """Kinds of skills that can be turned into agent tools.

    Member values are the strings stored in a skill's ``type`` field.
    """

    # Old name of the SQL skill; treated the same as TEXT2SQL when building tools.
    TEXT2SQL_LEGACY = "text2sql"
    # Text-to-SQL skill.
    TEXT2SQL = "sql"
    # Deprecated knowledge base skill (superseded by RETRIEVAL).
    KNOWLEDGE_BASE = "knowledge_base"
    # Retrieval (RAG) skill.
    RETRIEVAL = "retrieval"

31 

32 

@dataclass
class SkillData:
    """Storage for skill's data

    Attributes:
        name (str): name of the skill
        type (str): skill's type (a SkillType value)
        params (dict): skill's attributes
        project_id (int): id of the project
        agent_tables_list (Optional[List[str]]): the restriction on available tables for an agent using the skill.
            Items are table names (str) or dicts with a "table" key and an optional "schema" key.
    """

    name: str
    type: str
    params: dict
    project_id: int
    agent_tables_list: Optional[List[str]]

    @property
    def restriction_on_tables(self) -> Optional[Dict[str, set]]:
        """Schemas and tables which agent+skill may use. The result is intersections of skill's and agent's tables lists.

        The skill's list is read from ``params["tables"]`` (missing or None means "no list"),
        the agent's from ``agent_tables_list``. Plain-string items have no schema and are
        grouped under the ``None`` key.

        Returns:
            Optional[Dict[str, set]]: allowed schemas and tables. Schemas - are keys in dict, tables - are values.
                If only one side defines a list, that side's map is returned unchanged.
                If result is None, then there are no restrictions.

        Raises:
            ValueError: if there is no intersection between skill's and agent's list
                (this means that all tables are restricted for use), or if a list item
                is neither a str nor a dict.
        """

        def list_to_map(tables: Optional[List]) -> Dict:
            # Group table names by schema name (None for plain string items).
            tables_map = defaultdict(set)
            for item in tables or []:
                if isinstance(item, str):
                    table_name, schema_name = item, None
                elif isinstance(item, dict):
                    table_name, schema_name = item["table"], item.get("schema")
                else:
                    raise ValueError(f"Unexpected value in tables list: {item}")
                tables_map[schema_name].add(table_name)
            return tables_map

        agent_tables_map = list_to_map(self.agent_tables_list)
        # `list_to_map` also accepts an explicit None stored under "tables".
        skill_tables_map = list_to_map(self.params.get("tables"))

        if agent_tables_map and skill_tables_map:
            if not (agent_tables_map.keys() & skill_tables_map.keys()):
                raise ValueError("Skill's and agent's allowed tables list have no shared schemas.")

            # Iterate in the agent's schema order; `in` does not insert into the defaultdict.
            intersection_tables_map = defaultdict(set)
            for schema_name, agent_tables in agent_tables_map.items():
                if schema_name in skill_tables_map:
                    intersection_tables_map[schema_name] = agent_tables & skill_tables_map[schema_name]
            if not any(intersection_tables_map.values()):
                raise ValueError("Skill's and agent's allowed tables list have no shared tables.")
            return intersection_tables_map
        if skill_tables_map:
            return skill_tables_map
        if agent_tables_map:
            return agent_tables_map
        return None

101 

102 

class SkillToolController:
    """Turns skill definitions into tools that agents can call.

    SQL skills are merged into a single text-to-SQL toolkit; knowledge base and retrieval
    skills each produce their own tools.
    """

    def __init__(self):
        # Created lazily by get_command_executor(): importing the executor at module level
        # causes a circular import in some cases.
        self.command_executor = None

    def get_command_executor(self):
        """Return the command executor, creating and caching it on first call.

        The executor wraps a new SessionController whose current database is the
        ``default_project`` value from the config.
        """
        if self.command_executor is None:
            # Top-level import produces circular import in some cases TODO: figure out a fix without losing
            # runtime improvements (context: see #9304)
            from mindsdb.api.executor.command_executor import ExecuteCommands
            from mindsdb.api.executor.controllers import SessionController  # noqa

            sql_session = SessionController()
            sql_session.database = config.get("default_project")

            self.command_executor = ExecuteCommands(sql_session)
        return self.command_executor

    @staticmethod
    def _split_names(value):
        """Split a comma-separated string into stripped names; any other value is returned as is."""
        if isinstance(value, str):
            return [name.strip() for name in value.split(",")]
        return value

    @staticmethod
    def _escape_table_name(name: str) -> str:
        """Wrap a table name in backticks after stripping surrounding spaces and backticks."""
        name = name.strip(" `")
        return f"`{name}`"

    @classmethod
    def _escape_qualified_table_name(cls, table: str) -> str:
        """Escape only the table part of a dotted ``database.table`` / ``database.schema.table`` name.

        Names that already contain backticks are returned unchanged; names with any other
        number of dot-separated parts are escaped as a whole.
        """
        if "`" in table:
            return table
        parts = table.split(".")
        if len(parts) in (2, 3):
            return ".".join(parts[:-1] + [cls._escape_table_name(parts[-1])])
        return cls._escape_table_name(table)

    @classmethod
    def _list_database_tables(cls, command_executor, database: str) -> List[str]:
        """Return qualified, escaped names of every table the database's data handler reports.

        Names are ``database.schema.`table``` when the handler's response has a
        ``table_schema`` column (case-insensitive), otherwise ``database.`table```. The table
        name column is ``table_name`` (case-insensitive) or, failing that, the first column.
        Any error is logged as a warning and an empty list is returned.
        """
        try:
            handler = command_executor.session.integration_controller.get_data_handler(database)
            if "all" in inspect.signature(handler.get_tables).parameters:
                response = handler.get_tables(all=True)
            else:
                response = handler.get_tables()

            data_frame = response.data_frame
            columns = [c.lower() for c in data_frame.columns]
            name_idx = columns.index("table_name") if "table_name" in columns else 0
            # Positional access via iloc: integer keys in Series.__getitem__ are deprecated in pandas.
            table_names = data_frame.iloc[:, name_idx]

            if "table_schema" in columns:
                schema_names = data_frame.iloc[:, columns.index("table_schema")]
                return [
                    f"{database}.{schema_name}.{cls._escape_table_name(table_name)}"
                    for schema_name, table_name in zip(schema_names, table_names)
                ]
            return [f"{database}.{cls._escape_table_name(table_name)}" for table_name in table_names]
        except Exception:
            logger.warning(f"Could not get tables from database {database}:", exc_info=True)
            return []

    def _make_text_to_sql_tools(self, skills: List[db.Skills], llm) -> List:
        """Build the text-to-SQL tools for all SQL skills, backed by a single SQLAgent.

        The skills' ``include_tables``, table restrictions and knowledge base include/ignore
        lists are merged. Table names are fully qualified and the table part is escaped with
        backticks. Knowledge base names without a dot are prefixed with the knowledge base
        database: the ``knowledge_base_database`` param, overridden by the database of any
        dot-qualified ``include_knowledge_bases`` entry, defaulting to DEFAULT_TEXT2SQL_DATABASE.

        Args:
            skills: SQL skills to combine. A skill whose ``database`` param is explicitly empty
                gets ``params["database"]`` set to a database taken from dot-notation
                ``include_tables`` (first one encountered).
            llm: language model passed to the SQL toolkit.

        Returns:
            List: the tools returned by ``MindsDBSQLToolkit.get_tools()``.

        Raises:
            ImportError: if the langchain extras are not installed.
            Exceptions raised by a skill's ``restriction_on_tables`` propagate.
        """
        # To prevent dependency on Langchain unless an actual tool uses it.
        try:
            from mindsdb.interfaces.agents.mindsdb_database_agent import MindsDBSQL
            from mindsdb.interfaces.skills.custom.text2sql.mindsdb_sql_toolkit import MindsDBSQLToolkit
        except ImportError:
            raise ImportError(
                "To use the text-to-SQL skill, please install langchain with `pip install mindsdb[langchain]`"
            )

        command_executor = self.get_command_executor()

        tables_list = []
        knowledge_bases_list = []
        ignore_knowledge_bases_list = []

        # Databases found in dot-notation include_tables, kept in order of appearance
        # (dict used as an ordered set so "first extracted database" is deterministic).
        extracted_databases: Dict[str, None] = {}

        knowledge_base_database = DEFAULT_TEXT2SQL_DATABASE

        # First pass: resolve the knowledge base database and collect dot-notation databases.
        for skill in skills:
            if skill.params.get("knowledge_base_database"):
                knowledge_base_database = skill.params.get("knowledge_base_database")

            if skill.params.get("include_tables"):
                for table in self._split_names(skill.params.get("include_tables")):
                    if "." in table:
                        extracted_databases[table.split(".")[0]] = None

            if skill.params.get("include_knowledge_bases"):
                for kb in self._split_names(skill.params.get("include_knowledge_bases")):
                    if "." in kb:
                        # A dot-qualified knowledge base sets the database used for bare names.
                        knowledge_base_database = kb.split(".")[0]

        # Second pass: collect tables and knowledge base include/ignore lists.
        for skill in skills:
            database = skill.params.get("database", DEFAULT_TEXT2SQL_DATABASE)

            # Only an explicitly empty "database" param falls back to a dot-notation database.
            if not database and extracted_databases:
                database = next(iter(extracted_databases))
                skill.params["database"] = database

            if skill.params.get("include_knowledge_bases"):
                for kb in self._split_names(skill.params.get("include_knowledge_bases")):
                    knowledge_bases_list.append(kb if "." in kb else f"{knowledge_base_database}.{kb}")

            if skill.params.get("ignore_knowledge_bases"):
                for kb in self._split_names(skill.params.get("ignore_knowledge_bases")):
                    ignore_knowledge_bases_list.append(kb if "." in kb else f"{knowledge_base_database}.{kb}")

            if not database:
                continue

            if skill.params.get("include_tables"):
                for table in self._split_names(skill.params.get("include_tables")):
                    if "." in table:
                        tables_list.append(self._escape_qualified_table_name(table))
                    else:
                        tables_list.append(f"{database}.{self._escape_table_name(table)}")
                # Explicit include_tables skip restriction handling and table discovery.
                continue

            restriction_on_tables = skill.restriction_on_tables

            if restriction_on_tables is None:
                # No restrictions: use every table the database reports.
                tables_list.extend(self._list_database_tables(command_executor, database))
                continue

            for schema_name, tables in restriction_on_tables.items():
                for table in tables:
                    if "." in table:
                        # Already prefixed (e.g. 'postgresql_conn.home_rentals'): escape only the
                        # table part, the same way dot-notation include_tables are handled.
                        tables_list.append(self._escape_qualified_table_name(table))
                    elif schema_name is None:
                        tables_list.append(f"{database}.{self._escape_table_name(table)}")
                    else:
                        tables_list.append(f"{database}.{schema_name}.{self._escape_table_name(table)}")

        # Remove duplicates, keeping first-seen order (set() order of str varies between processes).
        tables_list = list(dict.fromkeys(tables_list))
        knowledge_bases_list = list(dict.fromkeys(knowledge_bases_list))
        ignore_knowledge_bases_list = list(dict.fromkeys(ignore_knowledge_bases_list))

        include_knowledge_bases = knowledge_bases_list or None
        # If both include and ignore lists exist, include takes precedence
        ignore_knowledge_bases = None if include_knowledge_bases else (ignore_knowledge_bases_list or None)

        # Databases visible to the agent: explicit skill databases with their restrictions,
        # plus dot-notation databases without restrictions.
        databases_struct = {}
        for skill in skills:
            if skill.params.get("database"):
                databases_struct[skill.params["database"]] = skill.restriction_on_tables
        for db_name in extracted_databases:
            databases_struct.setdefault(db_name, None)

        sql_agent = SQLAgent(
            command_executor=command_executor,
            databases=[],
            databases_struct=databases_struct,
            include_tables=tables_list,
            ignore_tables=None,
            include_knowledge_bases=include_knowledge_bases,
            ignore_knowledge_bases=ignore_knowledge_bases,
            knowledge_base_database=knowledge_base_database,
            sample_rows_in_table_info=3,
            cache=get_cache("agent", max_size=_MAX_CACHE_SIZE),
        )
        # Named sql_database to avoid shadowing the module-level `db` import.
        sql_database = MindsDBSQL.custom_init(sql_agent=sql_agent)
        toolkit = MindsDBSQLToolkit(
            db=sql_database,
            llm=llm,
            include_tables_tools=bool(databases_struct) or bool(tables_list),
            include_knowledge_base_tools=bool(include_knowledge_bases),
        )
        return toolkit.get_tools()

    def _make_retrieval_tools(self, skill: db.Skills, llm, embedding_model):
        """Create the advanced retrieval (RAG) tools for one retrieval skill.

        Args:
            skill: retrieval skill; ``params["description"]`` is required, ``name``, ``source``
                and ``config`` params are optional.
            llm: used as ``config["llm"]`` when the skill config does not set one, and passed
                in ``pred_args``.
            embedding_model: not used by this method.

        Returns:
            The tools built by ``build_retrieval_tools``.
        """
        params = skill.params
        # Work on a copy: writing the LLM into skill.params["config"] would leak it into the
        # skill's stored params and make later calls reuse a stale LLM.
        config = dict(params.get("config") or {})
        if "llm" not in config:
            # Set LLM if not explicitly provided in configs.
            config["llm"] = llm
        tool = dict(
            name=params.get("name", skill.name),
            source=params.get("source", None),
            config=config,
            description=f"You must use this tool to get more context or information "
            f"to answer a question about {params['description']}. "
            f"The input should be the exact question the user is asking.",
            type=skill.type,
        )
        pred_args = {"llm": llm}

        from .retrieval_tool import build_retrieval_tools

        return build_retrieval_tools(tool, pred_args, skill)

    def _get_rag_query_function(self, skill: db.Skills):
        """Return a function answering a question from the knowledge base ``skill.params["source"]``.

        The function runs ``SELECT * ... WHERE <TableField.CONTENT> = <question> LIMIT
        _DEFAULT_TOP_K_SIMILARITY_SEARCH`` on the knowledge base table and joins the
        ``chunk_content`` (or ``content``) values of the result with newlines.
        """
        session_controller = self.get_command_executor().session

        def _answer_question(question: str) -> str:
            knowledge_base_name = skill.params["source"]

            query = Select(
                targets=[Star()],
                where=BinaryOperation(op="=", args=[Identifier(TableField.CONTENT.value), Constant(question)]),
                limit=Constant(_DEFAULT_TOP_K_SIMILARITY_SEARCH),
            )
            kb_table = session_controller.kb_controller.get_table(knowledge_base_name, skill.project_id)

            res = kb_table.select_query(query)
            # Handle both chunk_content and content column names
            if hasattr(res, "chunk_content"):
                return "\n".join(res.chunk_content)
            if hasattr(res, "content"):
                return "\n".join(res.content)
            return "No content or chunk_content found in knowledge base response"

        return _answer_question

    def _make_knowledge_base_tools(self, skill: db.Skills) -> dict:
        """Build the tool description for the deprecated knowledge_base skill.

        Logs a deprecation warning and returns a dict with ``name``, ``func`` (see
        ``_get_rag_query_function``), ``description`` and ``type``.
        """
        description = skill.params.get("description", "")

        logger.warning(
            "This skill is deprecated and will be removed in the future. Please use `retrieval` skill instead "
        )

        return dict(
            name="Knowledge Base Retrieval",
            func=self._get_rag_query_function(skill),
            description=f"Use this tool to get more context or information to answer a question about {description}. The input should be the exact question the user is asking.",
            type=skill.type,
        )

    def get_tools_from_skills(
        self, skills_data: List[SkillData], llm: BaseChatModel, embedding_model: Embeddings
    ) -> dict:
        """Create tools for the given skills, grouped by skill type.

        Args:
            skills_data (List[SkillData]): Skills to make tools from
            llm (BaseChatModel): LLM which will be used by skills
            embedding_model (Embeddings): passed to the retrieval tool builder

        Returns:
            dict: maps SkillType to the list of tools built for that type. Legacy ``text2sql``
                skills are grouped under SkillType.TEXT2SQL; all SQL skills share one toolkit.

        Raises:
            NotImplementedError: if a skill's type is not a SkillType value.
        """

        # group skills by type
        skills_group = defaultdict(list)
        for skill in skills_data:
            try:
                skill_type = SkillType(skill.type)
            except ValueError:
                raise NotImplementedError(
                    f"skill of type {skill.type} is not supported as a tool, supported types are: {list(SkillType._member_names_)}"
                )

            if skill_type == SkillType.TEXT2SQL_LEGACY:
                skill_type = SkillType.TEXT2SQL
            skills_group[skill_type].append(skill)

        tools = {}
        for skill_type, skills in skills_group.items():
            if skill_type == SkillType.TEXT2SQL:
                tools[skill_type] = self._make_text_to_sql_tools(skills, llm)
            elif skill_type == SkillType.KNOWLEDGE_BASE:
                tools[skill_type] = [self._make_knowledge_base_tools(skill) for skill in skills]
            elif skill_type == SkillType.RETRIEVAL:
                tools[skill_type] = []
                for skill in skills:
                    tools[skill_type] += self._make_retrieval_tools(skill, llm, embedding_model)
        return tools

463 

464 

465skill_tool = SkillToolController()