Coverage for mindsdb / integrations / libs / base.py: 62%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2import concurrent.futures 

3import inspect 

4import textwrap 

5from _ast import AnnAssign, AugAssign 

6from typing import Any, Dict, List, Optional 

7 

8import pandas as pd 

9from mindsdb_sql_parser.ast.base import ASTNode 

10from mindsdb.utilities import log 

11 

12from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse, RESPONSE_TYPE 

13 

14logger = log.getLogger(__name__) 

15 

16 

class BaseHandler:
    """Base class for database handlers.

    A handler connects a source of information with the broader MindsDB
    ecosystem so that it can be queried via SQL commands.
    """

    def __init__(self, name: str):
        """constructor
        Args:
            name (str): the handler name
        """
        self.name = name
        self.is_connected: bool = False

    def connect(self):
        """Set up any connections required by the handler.

        Should return the established connection.
        """
        raise NotImplementedError()

    def disconnect(self):
        """Close any existing connections.

        Implementations should flip `self.is_connected` accordingly.
        """
        self.is_connected = False

    def check_connection(self) -> HandlerStatusResponse:
        """Check the connection to the handler.

        Returns:
            HandlerStatusResponse
        """
        raise NotImplementedError()

    def native_query(self, query: Any) -> HandlerResponse:
        """Receive a raw query and act upon it somehow.

        Args:
            query (Any): query in native format (str for sql databases, etc.)

        Returns:
            HandlerResponse
        """
        raise NotImplementedError()

    def query(self, query: ASTNode) -> HandlerResponse:
        """Receive a query as an AST (abstract syntax tree) and act upon it somehow.

        Args:
            query (ASTNode): sql query represented as AST. May be any kind
                of query: SELECT, INSERT, DELETE, etc.

        Returns:
            HandlerResponse
        """
        raise NotImplementedError()

    def get_tables(self) -> HandlerResponse:
        """Return the list of entities that will be accessible as tables.

        Returns:
            HandlerResponse: should have the same columns as information_schema.tables
                (https://dev.mysql.com/doc/refman/8.0/en/information-schema-tables-table.html)
                Column 'TABLE_NAME' is mandatory; the others are optional.
        """
        raise NotImplementedError()

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the list of columns of an entity.

        Args:
            table_name (str): name of one of the tables returned by self.get_tables()

        Returns:
            HandlerResponse: should have the same columns as information_schema.columns
                (https://dev.mysql.com/doc/refman/8.0/en/information-schema-columns-table.html)
                Column 'COLUMN_NAME' is mandatory; the others are optional. It is highly
                recommended to also define 'DATA_TYPE': it should be one of the
                python data types (str by default).
        """
        raise NotImplementedError()

106 

107 

class DatabaseHandler(BaseHandler):
    """
    Base class for handlers that connect to data storage systems
    (e.g. databases, data warehouses, streaming services, etc.).
    """

    def __init__(self, name: str):
        super().__init__(name)

115 

116 

class MetaDatabaseHandler(DatabaseHandler):
    """
    Base class for handlers associated to data storage systems (e.g. databases, data warehouses, streaming services, etc.)

    This class is used when the handler is also needed to store information in the data catalog.
    This information is typically available in the information schema or system tables of the database.
    """

    def __init__(self, name: str):
        super().__init__(name)

    def meta_get_tables(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - TABLE_TYPE (str): Type of the table, e.g. 'BASE TABLE', 'VIEW', etc. (optional).
                - TABLE_SCHEMA (str): Schema of the table (optional).
                - TABLE_DESCRIPTION (str): Description of the table (optional).
                - ROW_COUNT (int): Estimated number of rows in the table (optional).
        """
        raise NotImplementedError()

    def meta_get_columns(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the columns in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - DATA_TYPE (str): Data type of the column, e.g. 'VARCHAR', 'INT', etc.
                - COLUMN_DESCRIPTION (str): Description of the column (optional).
                - IS_NULLABLE (bool): Whether the column can contain NULL values (optional).
                - COLUMN_DEFAULT (str): Default value of the column (optional).
        """
        raise NotImplementedError()

    def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata statistical information about the columns in the tables to be stored in the data catalog.
        Either this method should be overridden in the handler or `meta_get_column_statistics_for_table` should be implemented.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
                - MINIMUM_VALUE (str): Minimum value in the column (optional).
                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
        """
        # Guard clause: if the subclass did not override the per-table hook,
        # there is no default implementation to fall back on.
        method = getattr(self, "meta_get_column_statistics_for_table")
        if method.__func__ is MetaDatabaseHandler.meta_get_column_statistics_for_table:
            raise NotImplementedError()

        meta_columns = self.meta_get_columns(table_names)
        # NOTE(review): this assumes meta_get_columns returns lowercase
        # 'table_name'/'column_name' columns, although the docstrings above
        # advertise upper-case names — confirm against handler implementations.
        grouped_columns = (
            meta_columns.data_frame.groupby("table_name")
            .agg(
                {
                    "column_name": list,
                }
            )
            .reset_index()
        )

        results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            # Map each future back to its table name so that errors are
            # attributed to the table that actually failed (previously the
            # stale loop variable always named the last submitted table).
            future_to_table = {
                executor.submit(
                    self.meta_get_column_statistics_for_table,
                    row["table_name"],
                    row["column_name"],
                ): row["table_name"]
                for _, row in grouped_columns.iterrows()
            }

            for future in concurrent.futures.as_completed(future_to_table):
                table_name = future_to_table[future]
                try:
                    result = future.result(timeout=120)
                    if result.resp_type == RESPONSE_TYPE.TABLE:
                        results.append(result.data_frame)
                    else:
                        logger.error(
                            f"Error retrieving column statistics for table {table_name}: {result.error_message}"
                        )
                except Exception:
                    logger.exception(f"Exception occurred while retrieving column statistics for table {table_name}:")

        if not results:
            logger.warning("No column statistics could be retrieved for the specified tables.")
            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message="No column statistics could be retrieved.")
        # `results` is guaranteed non-empty here, so concat directly.
        return HandlerResponse(RESPONSE_TYPE.TABLE, pd.concat(results, ignore_index=True))

    def meta_get_column_statistics_for_table(
        self, table_name: str, column_names: Optional[List[str]] = None
    ) -> HandlerResponse:
        """
        Returns metadata statistical information about the columns in a specific table to be stored in the data catalog.
        Either this method should be implemented in the handler or `meta_get_column_statistics` should be overridden.

        Args:
            table_name (str): Name of the table.
            column_names (Optional[List[str]]): List of column names to retrieve statistics for. If None, statistics for all columns will be returned.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
                - MINIMUM_VALUE (str): Minimum value in the column (optional).
                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
        """
        # Intentionally a no-op: meta_get_column_statistics detects whether a
        # subclass overrides this hook by comparing function identity.
        pass

    def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the primary keys in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column that is part of the primary key.
                - ORDINAL_POSITION (int): Position of the column in the primary key (optional).
                - CONSTRAINT_NAME (str): Name of the primary key constraint (optional).
        """
        raise NotImplementedError()

    def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the foreign keys in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - PARENT_TABLE_NAME (str): Name of the parent table.
                - PARENT_COLUMN_NAME (str): Name of the parent column that is part of the foreign key.
                - CHILD_TABLE_NAME (str): Name of the child table.
                - CHILD_COLUMN_NAME (str): Name of the child column that is part of the foreign key.
                - CONSTRAINT_NAME (str): Name of the foreign key constraint (optional).
        """
        raise NotImplementedError()

    def meta_get_handler_info(self, **kwargs) -> str:
        """
        Retrieves information about the design and implementation of the database handler.
        This should include, but not be limited to, the following:
        - The type of SQL queries and operations that the handler supports.
        - etc.

        Args:
            kwargs: Additional keyword arguments that may be used in generating the handler information.

        Returns:
            str: A string containing information about the database handler's design and implementation.
        """
        pass

283 

284 

class ArgProbeMixin:
    """
    A mixin class that probes which arguments a handler needs during
    creation and prediction time, by running static analysis on the
    source code of the handler.
    """

    class ArgProbeVisitor(ast.NodeVisitor):
        """AST visitor that records string-key accesses on tracked dict variables.

        Starts by tracking the variable ``args``; any variable assigned from
        ``args['using']`` is tracked as well. Each ``var["key"]`` access is
        recorded as required and each ``var.get("key", ...)`` as optional.
        """

        def __init__(self):
            # each record: {"name": <key str>, "required": <bool>}
            self.arg_keys: List[Dict] = []
            # names of variables whose key accesses should be recorded
            self.var_names_to_track = {"args"}

        def visit_Assign(self, node):
            # Track aliases of args['using']: e.g. `using_args = args['using']`
            # means accesses on `using_args` must be recorded too.
            if (
                isinstance(node.value, ast.Subscript)
                and isinstance(node.value.value, ast.Name)
                and node.value.value.id == "args"
            ):
                # On Python 3.9+ the subscript slice is the expression itself;
                # ast.Index/ast.Str were deprecated and removed in 3.12, so the
                # old isinstance checks silently never matched.
                if (
                    isinstance(node.value.slice, ast.Constant)
                    and node.value.slice.value == "using"
                    and isinstance(node.targets[0], ast.Name)
                ):
                    self.var_names_to_track.add(node.targets[0].id)

            # For an assignment like `self.args['name'] = 'value'` the left
            # side must be ignored, so only the value expression is visited.
            self.visit(node.value)

        def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
            # Annotated assignments may carry no value at all (e.g. `x: int`).
            if node.value is not None:
                self.visit(node.value)

        def visit_AugAssign(self, node: ast.AugAssign) -> Any:
            self.visit(node.value)

        def visit_Subscript(self, node):
            # var["key"] -> "key" is required (no default provided)
            if isinstance(node.value, ast.Name) and node.value.id in self.var_names_to_track:
                if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str):
                    self.arg_keys.append({"name": node.slice.value, "required": True})
            self.generic_visit(node)

        def visit_Call(self, node):
            # var.get("key", default) -> "key" is optional
            if isinstance(node.func, ast.Attribute) and node.func.attr == "get":
                if isinstance(node.func.value, ast.Name) and node.func.value.id in self.var_names_to_track:
                    # guard against a bare `.get()` call with no arguments
                    if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
                        self.arg_keys.append({"name": node.args[0].value, "required": False})
            self.generic_visit(node)

    @classmethod
    def probe_function(cls, method_name: str) -> List[Dict]:
        """
        Probe the source code of the method with name method_name.
        Specifically, trace how the argument `args`, which is a dict, is used in the method.

        Find all places where a key of the dict is used, and return a list of all keys that are used.
        E.g.,
            args["key1"] -> "key1" is accessed, and it is required
            args.get("key2", "default_value") -> "key2" is accessed, and it is optional (default value is provided)

        Return a list of dicts where each dict looks like
            {
                "name": "key1",
                "required": True
            }
        """
        try:
            source_code = cls.get_source_code(method_name)
        except Exception:
            logger.exception(f"Failed to get source code of method {method_name} in {cls.__name__}. Reason:")
            return []

        # fix the indentation so that a method body parses as a module
        source_code = textwrap.dedent(source_code)
        tree = ast.parse(source_code)

        # find all places where a key in args is accessed, either via
        # args["key"] or args.get("key", "default_value")
        visitor = cls.ArgProbeVisitor()
        visitor.visit(tree)

        # deduplicate the keys: if the same key is recorded both as required
        # and as optional, keep it as required
        unique_arg_keys: Dict[str, bool] = {}
        for record in visitor.arg_keys:
            if record["required"] or record["name"] not in unique_arg_keys:
                unique_arg_keys[record["name"]] = record["required"]

        # 'using' is the container of the user args, not an arg itself
        return [{"name": k, "required": v} for k, v in unique_arg_keys.items() if k != "using"]

    @classmethod
    def get_source_code(cls, method_name: str):
        """
        Get the source code of the method specified by method_name.

        Raises:
            Exception: if the method does not exist on this class.
        """
        # Default to None: a bare getattr would raise AttributeError and make
        # the explicit not-found error below unreachable.
        method = getattr(cls, method_name, None)
        if method is None:
            # cls.__name__ (not cls.__class__.__name__, which would be the
            # metaclass name) identifies the handler class in the message.
            raise Exception(f"Method {method_name} does not exist in {cls.__name__}")
        source_code = inspect.getsource(method)
        return source_code

    @classmethod
    def prediction_args(cls):
        """
        Get the arguments that are needed by the prediction method
        """
        return cls.probe_function("predict")

    @classmethod
    def creation_args(cls):
        """
        Get the arguments that are needed by the creation method
        """
        return cls.probe_function("create")

415 

416 

class BaseMLEngine(ArgProbeMixin):
    """
    Base class for integration engines that bridge MindsDB with external machine learning libraries/frameworks.

    Instances of this class are created when interacting with the underlying framework. For compliance with the
    interface that MindsDB core expects, they are wrapped by the `BaseMLEngineExec` class defined in
    `libs/ml_exec_base`.

    Broadly speaking, the flow is as follows:
    - A SQL statement is sent to the MindsDB executor
    - The statement is parsed, and a sequential plan is generated by `mindsdb_sql`
    - If any step in the plan involves an ML framework, a wrapped engine that inherits from this class will be called for the respective action
        - For example, creating a new model would call `create()`
    - Any output produced by the ML engine is then formatted by the wrapper and passed back into the MindsDB executor, which can then morph the data to comply with the original SQL query
    """  # noqa

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        """
        Warning: This method should not be overridden.

        Initialize storage objects required by the ML engine.

        - engine_storage: persists global engine-related internals or artifacts that may be used by all models from the engine.
        - model_storage: stores artifacts for any single given model.
        """
        self.model_storage = model_storage
        self.engine_storage = engine_storage
        # when True, the target column name does not have to be specified at creation time
        self.generative = False

        # only populated when updating a model; kept None otherwise
        base_storage = kwargs.get("base_model_storage")
        self.base_model_storage = base_storage if base_storage else None

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Saves a model inside the engine registry for later usage.

        Normally, an input dataframe is required to train the model.
        However, some integrations may merely require registering the model instead of training, in which case `df` can be omitted.

        Any other arguments required to register the model can be passed in an `args` dictionary.
        """
        raise NotImplementedError

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Calls a model with some input dataframe `df`, and optionally some arguments `args` that may modify the model behavior.

        The expected output is a dataframe with the predicted values in the target-named column.
        Additional columns can be present, and will be considered row-wise explanations if their names finish with `_explain`.
        """
        raise NotImplementedError

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Optional.

        Used to fine-tune a pre-existing model without resetting its internal state (e.g. weights).

        Availability will depend on underlying integration support, as not all ML models can be partially updated.
        """
        raise NotImplementedError

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Optional.

        When called, this method provides global model insights, e.g. framework-level parameters used in training.
        """
        raise NotImplementedError

    def update(self, args: dict) -> None:
        """Optional.

        Update model.
        """
        raise NotImplementedError

    def create_engine(self, connection_args: dict):
        """Optional.

        Used to connect with external sources (e.g. a REST API) that the engine will require to use any other methods.
        """
        raise NotImplementedError

    def update_engine(self, connection_args: dict):
        """Optional.

        Used when connection args need to change, or when any other changes must be made to the engine.
        """
        raise NotImplementedError

    def close(self):
        # no-op by default; engines holding external resources may override
        pass