Coverage for mindsdb / integrations / libs / ml_exec_base.py: 44%
150 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1"""
2This module defines the wrapper for ML engines which abstracts away a lot of complexity.
4In particular, three big components are included:
6 - `BaseMLEngineExec` class: this class wraps any object that inherits from `BaseMLEngine` and exposes some endpoints
7 normally associated with a DB handler (e.g. `native_query`, `get_tables`), as well as other ML-specific behaviors,
8 like `learn()` or `predict()`. Note that while these still have to be implemented at the engine level, the burden
9 on that class is lesser given that it only needs to return a pandas DataFrame. It's this class that will take said
10 output and format it into the HandlerResponse instance that MindsDB core expects.
12 - `learn_process` method: handles async dispatch of the `learn` method in an engine, as well as registering all
13 models inside of the internal MindsDB registry.
15 - `predict_process` method: handles async dispatch of the `predict` method in an engine.
17"""
19import socket
20import contextlib
21import datetime as dt
22from types import ModuleType
23from typing import Optional, Union
25import pandas as pd
26from sqlalchemy import func, null
27from sqlalchemy.sql.functions import coalesce
29from mindsdb.utilities.config import Config
30import mindsdb.interfaces.storage.db as db
31from mindsdb.__about__ import __version__ as mindsdb_version
32from mindsdb.utilities.hooks import after_predict as after_predict_hook
33from mindsdb.interfaces.model.functions import (
34 get_model_record
35)
36from mindsdb.integrations.libs.const import PREDICTOR_STATUS
37from mindsdb.interfaces.database.database import DatabaseController
38from mindsdb.utilities.context import context as ctx
39from mindsdb.interfaces.model.functions import get_model_records
40from mindsdb.utilities.functions import mark_process
41import mindsdb.utilities.profiler as profiler
42from mindsdb.utilities.ml_task_queue.producer import MLTaskProducer
43from mindsdb.utilities.ml_task_queue.const import ML_TASK_TYPE
44from mindsdb.integrations.libs.process_cache import process_cache, empty_callback, MLProcessException
# Prefer torch's multiprocessing wrapper when torch is installed (NOTE(review):
# presumably for torch-aware tensor sharing between processes — confirm);
# otherwise fall back to the stdlib module with the same API.
try:
    import torch.multiprocessing as mp
except Exception:
    import multiprocessing as mp
# Use the 'spawn' start method explicitly so child-process behavior is
# consistent across platforms (fork is the default only on POSIX).
mp_ctx = mp.get_context('spawn')
class MLEngineException(Exception):
    """Raised when an ML engine task fails.

    Wraps the underlying engine error with an ``[engine/model]`` prefix
    (see ``BaseMLEngineExec._catch_exception``).
    """
    pass
class BaseMLEngineExec:
    """Wrapper around a `BaseMLEngine` handler module.

    Exposes ML-engine operations (``learn``, ``predict``, ``finetune``, ...)
    by dispatching them as asynchronous ML tasks — either through the local
    process cache or, when configured, through a Redis-backed task queue —
    and registers model records in the internal MindsDB registry.
    """

    def __init__(self, name: str, integration_id: int, handler_module: ModuleType):
        """ML handler interface

        Args:
            name (str): name of the ml_engine
            integration_id (int): id of the ml_engine
            handler_module (ModuleType): module of the ml_engine
        """
        self.name = name
        self.config = Config()
        self.integration_id = integration_id
        self.engine = handler_module.name
        self.handler_module = handler_module

        self.database_controller = DatabaseController()

        # Default executor runs tasks through the in-process cache; when the
        # Redis ML task queue is configured, dispatch via the queue producer.
        self.base_ml_executor = process_cache
        if self.config['ml_task_queue']['type'] == 'redis':
            self.base_ml_executor = MLTaskProducer()

    def _handler_meta(self, integration_id: Optional[int] = None) -> dict:
        """Build the common 'handler_meta' payload fragment.

        Args:
            integration_id (Optional[int]): override for the integration id;
                defaults to this executor's own integration id.

        Returns:
            dict: module path / engine name / integration id of the handler.
        """
        return {
            'module_path': self.handler_module.__package__,
            'engine': self.engine,
            'integration_id': self.integration_id if integration_id is None else integration_id
        }

    @profiler.profile()
    def learn(
        self, model_name, project_name,
        data_integration_ref=None,
        fetch_data_query=None,
        problem_definition=None,
        join_learn_process=False,
        label=None,
        is_retrain=False,
        set_active=True,
    ):
        """Train a model given some data-gathering SQL statement.

        Creates the predictor record (status GENERATING), then dispatches the
        LEARN task asynchronously. If ``join_learn_process`` is True, blocks
        until training finishes and returns the refreshed record.

        Returns:
            db.Predictor: the created predictor record.
        """
        # may or may not be provided (e.g. 0-shot models do not need it), so engine will handle it
        target = problem_definition.get('target', [''])  # db.Predictor expects Column(Array(String))

        project = self.database_controller.get_project(name=project_name)

        self.create_validation(target, problem_definition, self.integration_id)

        predictor_record = db.Predictor(
            company_id=ctx.company_id,
            name=model_name,
            integration_id=self.integration_id,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            mindsdb_version=mindsdb_version,
            to_predict=target,
            learn_args=problem_definition,
            data={'name': model_name},
            project_id=project.id,
            training_data_columns_count=None,
            training_data_rows_count=None,
            training_start_at=dt.datetime.now(),
            status=PREDICTOR_STATUS.GENERATING,
            label=label,
            # next version number is computed in the DB: max existing version
            # (or 1), bumped by one only on retrain
            version=(
                db.session.query(
                    coalesce(func.max(db.Predictor.version), 1) + (1 if is_retrain else 0)
                ).filter_by(
                    company_id=ctx.company_id,
                    name=model_name,
                    project_id=project.id,
                    deleted_at=null()
                ).scalar_subquery()),
            active=(not is_retrain),  # if create then active
            training_metadata={
                'hostname': socket.gethostname(),
                'reason': 'retrain' if is_retrain else 'learn'
            }
        )
        db.serializable_insert(predictor_record)

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.LEARN,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'problem_definition': problem_definition,
                    'set_active': set_active,
                    'data_integration_ref': data_integration_ref,
                    'fetch_data_query': fetch_data_query,
                    'project_name': project_name
                }
            )

            if join_learn_process is True:
                task.result()
                predictor_record = db.Predictor.query.get(predictor_record.id)
                db.session.refresh(predictor_record)
            else:
                # to prevent memory leak need to add any callback
                task.add_done_callback(empty_callback)

        return predictor_record

    def describe(self, model_id: int, attribute: Optional[str] = None) -> pd.DataFrame:
        """Dispatch a DESCRIBE task for a model and return its result."""
        with self._catch_exception(model_id):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.DESCRIBE,
                model_id=model_id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'attribute': attribute,
                    'context': ctx.dump()
                }
            )
            return task.result()

    def function_call(self, func_name, args):
        """Call an engine-level function by name with the given args."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.FUNC_CALL,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'name': func_name,
                    'args': args,
                    'handler_meta': self._handler_meta(),
                }
            )
            return task.result()

    @profiler.profile()
    @mark_process(name='predict')
    def predict(self, model_name: str, df: pd.DataFrame, pred_format: str = 'dict',
                project_name: str = None, version=None, params: dict = None):
        """ Generates predictions with some model and input data.

        Raises:
            Exception: if the model does not exist or is not COMPLETE.
        """
        kwargs = {
            'name': model_name,
            'ml_handler_name': self.name,
            'project_name': project_name
        }
        # without an explicit version, resolve to the active one
        if version is None:
            kwargs['active'] = True
        else:
            kwargs['active'] = None
            kwargs['version'] = version
        predictor_record = get_model_record(**kwargs)
        if predictor_record is None:
            if version is not None:
                model_name = f'{model_name}.{version}'
            raise Exception(f"Error: model '{model_name}' does not exists!")
        if predictor_record.status != PREDICTOR_STATUS.COMPLETE:
            raise Exception("Error: model creation not completed")

        using = {} if params is None else params
        args = {
            'pred_format': pred_format,
            'predict_params': using,
            'using': using
        }

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.PREDICT,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'predictor_record': predictor_record,
                    'args': args
                },
                dataframe=df
            )
            predictions = task.result()

        # mdb indexes: carry the row-id column through if the engine dropped it
        if '__mindsdb_row_id' not in predictions.columns and '__mindsdb_row_id' in df.columns:
            predictions['__mindsdb_row_id'] = df['__mindsdb_row_id']

        after_predict_hook(
            company_id=ctx.company_id,
            predictor_id=predictor_record.id,
            rows_in_count=df.shape[0],
            columns_in_count=df.shape[1],
            rows_out_count=len(predictions)
        )
        return predictions

    def create_validation(self, target, args, integration_id):
        """Dispatch a CREATE_VALIDATION task (engine-side argument checks)."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.CREATE_VALIDATION,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'target': target,
                    'args': args,
                    'handler_meta': self._handler_meta(integration_id),
                }
            )
            return task.result()

    def update(self, args: dict, model_id: int):
        """Dispatch an UPDATE task for an existing model."""
        with self._catch_exception(model_id):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.UPDATE,
                model_id=model_id,
                payload={
                    'context': ctx.dump(),
                    'args': args,
                    'handler_meta': self._handler_meta(),
                }
            )
            return task.result()

    def update_engine(self, connection_args):
        """Dispatch an UPDATE_ENGINE task with new connection args."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.UPDATE_ENGINE,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'connection_args': connection_args,
                    'handler_meta': self._handler_meta(),
                }
            )
            return task.result()

    def create_engine(self, connection_args: dict, integration_id: int) -> None:
        """Dispatch a CREATE_ENGINE task for a new engine instance."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.CREATE_ENGINE,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'connection_args': connection_args,
                    'handler_meta': self._handler_meta(integration_id),
                }
            )
            return task.result()

    @profiler.profile()
    def finetune(
        self, model_name, project_name,
        base_model_version: int,
        data_integration_ref=None,
        fetch_data_query=None,
        join_learn_process=False,
        label=None,
        set_active=True,
        args: Optional[dict] = None
    ):
        """Finetune an existing model, creating a new (inactive) version.

        Picks the most recently trained COMPLETE record (or the requested
        ``base_model_version``) as the starting point.

        Returns:
            db.Predictor: the new predictor record.
        """
        # generate new record from latest version as starting point
        project = self.database_controller.get_project(name=project_name)

        search_args = {
            'active': None,
            'name': model_name,
            'status': PREDICTOR_STATUS.COMPLETE
        }
        if base_model_version is not None:
            search_args['version'] = base_model_version
        else:
            search_args['active'] = True
        predictor_records = get_model_records(**search_args)
        # BUGFIX: filter out records without a finish timestamp BEFORE sorting;
        # sorting first raised TypeError (None vs datetime comparison), and the
        # empty-check must run on the filtered list to avoid IndexError below.
        predictor_records = [x for x in predictor_records if x.training_stop_at is not None]
        if len(predictor_records) == 0:
            raise Exception("Can't find suitable base model")
        predictor_records.sort(key=lambda x: x.training_stop_at, reverse=True)
        base_predictor_record = predictor_records[0]

        # BUGFIX: args defaults to None — merging `{**using, **None}` crashed;
        # normalize to an empty dict first.
        if args is None:
            args = {}
        # copy so the base record's in-memory learn_args is not mutated
        learn_args = dict(base_predictor_record.learn_args)
        learn_args['using'] = args if not learn_args.get('using', False) else {**learn_args['using'], **args}

        self.create_validation(
            target=base_predictor_record.to_predict,
            args=learn_args,
            integration_id=self.integration_id
        )

        predictor_record = db.Predictor(
            company_id=ctx.company_id,
            name=model_name,
            integration_id=self.integration_id,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            mindsdb_version=mindsdb_version,
            to_predict=base_predictor_record.to_predict,
            learn_args=learn_args,
            data={'name': model_name},
            project_id=project.id,
            training_data_columns_count=None,
            training_data_rows_count=None,
            training_start_at=dt.datetime.now(),
            status=PREDICTOR_STATUS.GENERATING,
            label=label,
            # next version number computed in the DB: max existing version + 1
            version=(
                db.session.query(
                    coalesce(func.max(db.Predictor.version), 1) + 1
                ).filter_by(
                    company_id=ctx.company_id,
                    name=model_name,
                    project_id=project.id,
                    deleted_at=null()
                ).scalar_subquery()
            ),
            active=False,
            training_metadata={
                'hostname': socket.gethostname(),
                'reason': 'finetune'
            }
        )
        db.serializable_insert(predictor_record)

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.FINETUNE,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'model_id': predictor_record.id,
                    'problem_definition': predictor_record.learn_args,
                    'set_active': set_active,
                    'base_model_id': base_predictor_record.id,
                    'data_integration_ref': data_integration_ref,
                    'fetch_data_query': fetch_data_query,
                    'project_name': project_name
                }
            )

            if join_learn_process is True:
                task.result()
                predictor_record = db.Predictor.query.get(predictor_record.id)
                db.session.refresh(predictor_record)
            else:
                # to prevent memory leak need to add any callback
                task.add_done_callback(empty_callback)

        return predictor_record

    @contextlib.contextmanager
    def _catch_exception(self, model_identifier: Optional[Union[int, str]] = None):
        """Re-raise any engine error as MLEngineException with an
        '[engine/model]' prefix; import errors propagate unchanged so callers
        can detect missing handler dependencies.
        """
        try:
            yield
        except (ImportError, ModuleNotFoundError):
            raise
        except Exception as e:
            # unwrap queue/process transport wrapper to report the real error
            if type(e) is MLProcessException:
                e = e.base_exception
            msg = str(e).strip()
            if msg == '':
                msg = e.__class__.__name__
            model_identifier = '' if model_identifier is None else f'/{model_identifier}'
            msg = f'[{self.name}{model_identifier}]: {msg}'
            raise MLEngineException(msg) from e