Coverage for mindsdb / integrations / libs / base.py: 62%
155 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import ast
2import concurrent.futures
3import inspect
4import textwrap
5from _ast import AnnAssign, AugAssign
6from typing import Any, Dict, List, Optional
8import pandas as pd
9from mindsdb_sql_parser.ast.base import ASTNode
10from mindsdb.utilities import log
12from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse, RESPONSE_TYPE
14logger = log.getLogger(__name__)
class BaseHandler:
    """Base class for database handlers.

    Associates a source of information with the broader MindsDB ecosystem
    via SQL commands.
    """

    def __init__(self, name: str):
        """Initialize the handler.

        Args:
            name (str): the handler name
        """
        self.is_connected: bool = False
        self.name = name

    def connect(self):
        """Set up any connections required by the handler.

        Should return the connection.
        """
        raise NotImplementedError()

    def disconnect(self):
        """Close any existing connections.

        Should switch self.is_connected.
        """
        self.is_connected = False

    def check_connection(self) -> HandlerStatusResponse:
        """Check the connection to the handler.

        Returns:
            HandlerStatusResponse
        """
        raise NotImplementedError()

    def native_query(self, query: Any) -> HandlerResponse:
        """Receive a raw query and act upon it somehow.

        Args:
            query (Any): query in native format (str for sql databases, etc.)

        Returns:
            HandlerResponse
        """
        raise NotImplementedError()

    def query(self, query: ASTNode) -> HandlerResponse:
        """Receive a query as an AST (abstract syntax tree) and act upon it somehow.

        Args:
            query (ASTNode): sql query represented as AST. May be any kind
                of query: SELECT, INSERT, DELETE, etc.

        Returns:
            HandlerResponse
        """
        raise NotImplementedError()

    def get_tables(self) -> HandlerResponse:
        """Return the list of entities that will be accessible as tables.

        Returns:
            HandlerResponse: should have the same columns as information_schema.tables
                (https://dev.mysql.com/doc/refman/8.0/en/information-schema-tables-table.html)
                Column 'TABLE_NAME' is mandatory, the others are optional.
        """
        raise NotImplementedError()

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return a list of entity columns.

        Args:
            table_name (str): name of one of the tables returned by self.get_tables()

        Returns:
            HandlerResponse: should have the same columns as information_schema.columns
                (https://dev.mysql.com/doc/refman/8.0/en/information-schema-columns-table.html)
                Column 'COLUMN_NAME' is mandatory, the others are optional. It is highly
                recommended to also define 'DATA_TYPE': it should be one of the
                python data types (str by default).
        """
        raise NotImplementedError()
class DatabaseHandler(BaseHandler):
    """
    Base class for handlers associated with data storage systems
    (e.g. databases, data warehouses, streaming services, etc.)
    """

    def __init__(self, name: str):
        """
        Args:
            name (str): the handler name
        """
        super().__init__(name)
class MetaDatabaseHandler(DatabaseHandler):
    """
    Base class for handlers associated with data storage systems (e.g. databases, data warehouses, streaming services, etc.)

    This class is used when the handler is also needed to store information in the data catalog.
    This information is typically available in the information schema or system tables of the database.
    """

    def __init__(self, name: str):
        super().__init__(name)

    def meta_get_tables(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - TABLE_TYPE (str): Type of the table, e.g. 'BASE TABLE', 'VIEW', etc. (optional).
                - TABLE_SCHEMA (str): Schema of the table (optional).
                - TABLE_DESCRIPTION (str): Description of the table (optional).
                - ROW_COUNT (int): Estimated number of rows in the table (optional).
        """
        raise NotImplementedError()

    def meta_get_columns(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the columns in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - DATA_TYPE (str): Data type of the column, e.g. 'VARCHAR', 'INT', etc.
                - COLUMN_DESCRIPTION (str): Description of the column (optional).
                - IS_NULLABLE (bool): Whether the column can contain NULL values (optional).
                - COLUMN_DEFAULT (str): Default value of the column (optional).
        """
        raise NotImplementedError()

    def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata statistical information about the columns in the tables to be stored in the data catalog.
        Either this method should be overridden in the handler or `meta_get_column_statistics_for_table` should be implemented.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
                - MINIMUM_VALUE (str): Minimum value in the column (optional).
                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
        """
        # The fan-out below only makes sense when a subclass actually overrides
        # the per-table method; otherwise there is nothing to delegate to.
        method = getattr(self, "meta_get_column_statistics_for_table")
        if method.__func__ is MetaDatabaseHandler.meta_get_column_statistics_for_table:
            raise NotImplementedError()

        meta_columns = self.meta_get_columns(table_names)
        grouped_columns = (
            meta_columns.data_frame.groupby("table_name")
            .agg(
                {
                    "column_name": list,
                }
            )
            .reset_index()
        )

        # Map each future back to its table name so failures are attributed to
        # the correct table. (Reusing the submission-loop variable would always
        # report the last-submitted table.)
        futures: Dict[concurrent.futures.Future, str] = {}
        results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            for _, row in grouped_columns.iterrows():
                table_name = row["table_name"]
                columns = row["column_name"]
                future = executor.submit(self.meta_get_column_statistics_for_table, table_name, columns)
                futures[future] = table_name

            for future in concurrent.futures.as_completed(futures):
                table_name = futures[future]
                try:
                    result = future.result(timeout=120)
                    if result.resp_type == RESPONSE_TYPE.TABLE:
                        results.append(result.data_frame)
                    else:
                        logger.error(
                            f"Error retrieving column statistics for table {table_name}: {result.error_message}"
                        )
                except Exception:
                    logger.exception(f"Exception occurred while retrieving column statistics for table {table_name}:")

        if not results:
            logger.warning("No column statistics could be retrieved for the specified tables.")
            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message="No column statistics could be retrieved.")
        return HandlerResponse(RESPONSE_TYPE.TABLE, pd.concat(results, ignore_index=True))

    def meta_get_column_statistics_for_table(
        self, table_name: str, column_names: Optional[List[str]] = None
    ) -> HandlerResponse:
        """
        Returns metadata statistical information about the columns in a specific table to be stored in the data catalog.
        Either this method should be implemented in the handler or `meta_get_column_statistics` should be overridden.

        Args:
            table_name (str): Name of the table.
            column_names (Optional[List[str]]): List of column names to retrieve statistics for. If None, statistics for all columns will be returned.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column.
                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
                - MINIMUM_VALUE (str): Minimum value in the column (optional).
                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
        """
        pass

    def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the primary keys in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - TABLE_NAME (str): Name of the table.
                - COLUMN_NAME (str): Name of the column that is part of the primary key.
                - ORDINAL_POSITION (int): Position of the column in the primary key (optional).
                - CONSTRAINT_NAME (str): Name of the primary key constraint (optional).
        """
        raise NotImplementedError()

    def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
        """
        Returns metadata information about the foreign keys in the tables to be stored in the data catalog.

        Returns:
            HandlerResponse: The response should consist of the following columns:
                - PARENT_TABLE_NAME (str): Name of the parent table.
                - PARENT_COLUMN_NAME (str): Name of the parent column that is part of the foreign key.
                - CHILD_TABLE_NAME (str): Name of the child table.
                - CHILD_COLUMN_NAME (str): Name of the child column that is part of the foreign key.
                - CONSTRAINT_NAME (str): Name of the foreign key constraint (optional).
        """
        raise NotImplementedError()

    def meta_get_handler_info(self, **kwargs) -> str:
        """
        Retrieves information about the design and implementation of the database handler.
        This should include, but not be limited to, the following:
        - The type of SQL queries and operations that the handler supports.
        - etc.

        Args:
            kwargs: Additional keyword arguments that may be used in generating the handler information.

        Returns:
            str: A string containing information about the database handler's design and implementation.
        """
        pass
class ArgProbeMixin:
    """
    A mixin class that provides probing of arguments that
    are needed by a handler during creation and prediction time
    by running static analysis on the source code of the handler.
    """

    class ArgProbeVisitor(ast.NodeVisitor):
        """AST visitor that records which keys of the `args` dict are accessed."""

        def __init__(self):
            # Each record: {"name": <key>, "required": <bool>}
            self.arg_keys: List[Dict] = []
            # Names of variables that (transitively) hold args['using']
            self.var_names_to_track = {"args"}

        @staticmethod
        def _constant_str(node) -> Optional[str]:
            """Return the string value of a constant AST node, else None.

            Handles the modern ast.Constant node (Python 3.8+) and, for
            backward compatibility, the deprecated ast.Index wrapper and
            ast.Str node used before Python 3.9. The legacy classes are
            matched by name so this still imports on versions where they
            have been removed.
            """
            if node.__class__.__name__ == "Index":
                node = node.value
            if isinstance(node, ast.Constant) and isinstance(node.value, str):
                return node.value
            if node.__class__.__name__ == "Str":
                return node.s
            return None

        def visit_Assign(self, node):
            # Track if args['using'] gets assigned to any variable; if so,
            # track the variable too by adding it to self.var_names_to_track.
            # E.g. for `using_args = args['using']` we should track using_args.
            if (
                isinstance(node.value, ast.Subscript)
                and isinstance(node.value.value, ast.Name)
                and node.value.value.id == "args"
            ):
                if self._constant_str(node.value.slice) == "using" and isinstance(node.targets[0], ast.Name):
                    self.var_names_to_track.add(node.targets[0].id)

            # For an assignment like `self.args['name'] = 'value'`, we should
            # ignore the left side of the assignment.
            self.visit(node.value)

        def visit_AnnAssign(self, node: AnnAssign) -> Any:
            # A bare annotation (`x: int`) has no value; guard against it.
            if node.value is not None:
                self.visit(node.value)

        def visit_AugAssign(self, node: AugAssign) -> Any:
            self.visit(node.value)

        def visit_Subscript(self, node):
            # args["key"] -> "key" is required
            if isinstance(node.value, ast.Name) and node.value.id in self.var_names_to_track:
                key = self._constant_str(node.slice)
                if key is not None:
                    self.arg_keys.append({"name": key, "required": True})
            self.generic_visit(node)

        def visit_Call(self, node):
            # args.get("key", default) -> "key" is optional
            if isinstance(node.func, ast.Attribute) and node.func.attr == "get":
                if isinstance(node.func.value, ast.Name) and node.func.value.id in self.var_names_to_track:
                    if node.args:  # guard against a bare .get() call
                        key = self._constant_str(node.args[0])
                        if key is not None:
                            self.arg_keys.append({"name": key, "required": False})
            self.generic_visit(node)

    @classmethod
    def probe_function(cls, method_name: str) -> List[Dict]:
        """
        Probe the source code of the method with name method_name.
        Specifically, trace how the argument `args`, which is a dict, is used in the method.

        Find all places where a key of the dict is used, and return a list of all keys that are used.
        E.g.,
            args["key1"] -> "key1" is accessed, and it is required
            args.get("key2", "default_value") -> "key2" is accessed, and it is optional (default value is provided)

        Return a list of dicts where each dict looks like
            {
                "name": "key1",
                "required": True
            }
        """
        try:
            source_code = cls.get_source_code(method_name)
        except Exception:
            logger.exception(f"Failed to get source code of method {method_name} in {cls.__name__}. Reason:")
            return []

        # fix the indentation before parsing
        source_code = textwrap.dedent(source_code)
        tree = ast.parse(source_code)

        # find all places where a key in args is accessed, either via
        # args["key"] or args.get("key", "default_value")
        visitor = cls.ArgProbeVisitor()
        visitor.visit(tree)

        # Deduplicate the keys: if two records share a name but differ in
        # required status, keep required == True.
        unique_arg_keys: Dict[str, bool] = {}
        for record in visitor.arg_keys:
            if record["name"] not in unique_arg_keys or record["required"]:
                unique_arg_keys[record["name"]] = record["required"]

        # filter out the special "using" key
        return [{"name": k, "required": v} for k, v in unique_arg_keys.items() if k != "using"]

    @classmethod
    def get_source_code(cls, method_name: str):
        """
        Get the source code of the method specified by method_name.

        Raises:
            Exception: if the method does not exist on this class.
        """
        # Pass a default so a missing method raises our error below rather
        # than an AttributeError from getattr itself.
        method = getattr(cls, method_name, None)
        if method is None:
            raise Exception(f"Method {method_name} does not exist in {cls.__name__}")
        return inspect.getsource(method)

    @classmethod
    def prediction_args(cls):
        """
        Get the arguments that are needed by the prediction method.
        """
        return cls.probe_function("predict")

    @classmethod
    def creation_args(cls):
        """
        Get the arguments that are needed by the creation method.
        """
        return cls.probe_function("create")
class BaseMLEngine(ArgProbeMixin):
    """
    Base class for integration engines to connect with other machine learning libraries/frameworks.

    This class will be instanced when interacting with the underlying framework. For compliance with the interface
    that MindsDB core expects, instances of this class will be wrapped with the `BaseMLEngineExec` class defined
    in `libs/ml_exec_base`.

    Broadly speaking, the flow is as follows:
        - A SQL statement is sent to the MindsDB executor
        - The statement is parsed, and a sequential plan is generated by `mindsdb_sql`
        - If any step in the plan involves an ML framework, a wrapped engine that inherits from this class will be called for the respective action
            - For example, creating a new model would call `create()`
        - Any output produced by the ML engine is then formatted by the wrapper and passed back into the MindsDB executor, which can then morph the data to comply with the original SQL query
    """  # noqa

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        """
        Warning: This method should not be overridden.

        Initialize the storage objects required by the ML engine.

        - engine_storage: persists global engine-related internals or artifacts that may be used by all models from the engine.
        - model_storage: stores artifacts for any single given model.
        """
        self.model_storage = model_storage
        self.engine_storage = engine_storage
        # When True, the target column name does not have to be specified at creation time.
        self.generative = False
        # Only available when updating a model; otherwise None.
        base_storage = kwargs.get("base_model_storage")
        self.base_model_storage = base_storage if base_storage else None

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Saves a model inside the engine registry for later usage.

        Normally, an input dataframe is required to train the model.
        However, some integrations may merely require registering the model instead of training, in which case `df` can be omitted.

        Any other arguments required to register the model can be passed in an `args` dictionary.
        """
        raise NotImplementedError

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Calls a model with some input dataframe `df`, and optionally some arguments `args` that may modify the model behavior.

        The expected output is a dataframe with the predicted values in the target-named column.
        Additional columns can be present, and will be considered row-wise explanations if their names finish with `_explain`.
        """
        raise NotImplementedError

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Optional.

        Used to fine-tune a pre-existing model without resetting its internal state (e.g. weights).

        Availability will depend on underlying integration support, as not all ML models can be partially updated.
        """
        raise NotImplementedError

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Optional.

        When called, this method provides global model insights, e.g. framework-level parameters used in training.
        """
        raise NotImplementedError

    def update(self, args: dict) -> None:
        """Optional.

        Update the model.
        """
        raise NotImplementedError

    def create_engine(self, connection_args: dict):
        """Optional.

        Used to connect with external sources (e.g. a REST API) that the engine will require to use any other methods.
        """
        raise NotImplementedError

    def update_engine(self, connection_args: dict):
        """Optional.

        Used when connection args need to change, or to make any other changes to the engine.
        """
        raise NotImplementedError

    def close(self):
        """Optional. Release any resources held by the engine; a no-op by default."""
        pass