Coverage for mindsdb / integrations / libs / ml_exec_base.py: 44%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1""" 

2This module defines the wrapper for ML engines which abstracts away a lot of complexity. 

3 

4In particular, three big components are included: 

5 

6 - `BaseMLEngineExec` class: this class wraps any object that inherits from `BaseMLEngine` and exposes some endpoints 

7 normally associated with a DB handler (e.g. `native_query`, `get_tables`), as well as other ML-specific behaviors, 

8 like `learn()` or `predict()`. Note that while these still have to be implemented at the engine level, the burden 

9 on that class is lesser given that it only needs to return a pandas DataFrame. It's this class that will take said 

10 output and format it into the HandlerResponse instance that MindsDB core expects. 

11 

12 - `learn_process` method: handles async dispatch of the `learn` method in an engine, as well as registering all 

13 models inside of the internal MindsDB registry. 

14 

15 - `predict_process` method: handles async dispatch of the `predict` method in an engine. 

16 

17""" 

18 

19import socket 

20import contextlib 

21import datetime as dt 

22from types import ModuleType 

23from typing import Optional, Union 

24 

25import pandas as pd 

26from sqlalchemy import func, null 

27from sqlalchemy.sql.functions import coalesce 

28 

29from mindsdb.utilities.config import Config 

30import mindsdb.interfaces.storage.db as db 

31from mindsdb.__about__ import __version__ as mindsdb_version 

32from mindsdb.utilities.hooks import after_predict as after_predict_hook 

33from mindsdb.interfaces.model.functions import ( 

34 get_model_record 

35) 

36from mindsdb.integrations.libs.const import PREDICTOR_STATUS 

37from mindsdb.interfaces.database.database import DatabaseController 

38from mindsdb.utilities.context import context as ctx 

39from mindsdb.interfaces.model.functions import get_model_records 

40from mindsdb.utilities.functions import mark_process 

41import mindsdb.utilities.profiler as profiler 

42from mindsdb.utilities.ml_task_queue.producer import MLTaskProducer 

43from mindsdb.utilities.ml_task_queue.const import ML_TASK_TYPE 

44from mindsdb.integrations.libs.process_cache import process_cache, empty_callback, MLProcessException 

45 

# Prefer torch's multiprocessing wrapper when torch is installed (it is a
# drop-in replacement for the stdlib module); fall back to the standard
# library otherwise. A broad `except` is used deliberately: any failure to
# import torch (missing package, broken CUDA install, ...) must not prevent
# the ML executor from working.
try:
    import torch.multiprocessing as mp
except Exception:
    import multiprocessing as mp
# 'spawn' start method: child processes get a fresh interpreter rather than a
# fork of the parent (presumably chosen for safety with threads/CUDA — the
# reason is not visible here).
mp_ctx = mp.get_context('spawn')

51 

52 

class MLEngineException(Exception):
    """Raised by `BaseMLEngineExec` when an ML engine operation fails.

    Wraps the underlying engine error with the engine/model identifier
    (see `BaseMLEngineExec._catch_exception`).
    """
    pass

55 

56 

class BaseMLEngineExec:
    """Wrapper around an ML engine module that dispatches ML tasks.

    Operations (learn / predict / finetune / describe / ...) are submitted
    either to the in-process `process_cache` or, when the config sets
    `ml_task_queue.type == 'redis'`, to a Redis-backed `MLTaskProducer`.
    Engine errors are re-raised as `MLEngineException`.
    """

    def __init__(self, name: str, integration_id: int, handler_module: ModuleType):
        """ML handler interface

        Args:
            name (str): name of the ml_engine
            integration_id (int): id of the ml_engine
            handler_module (ModuleType): module of the ml_engine
        """
        self.name = name
        self.config = Config()
        self.integration_id = integration_id
        self.engine = handler_module.name
        self.handler_module = handler_module

        self.database_controller = DatabaseController()

        self.base_ml_executor = process_cache
        if self.config['ml_task_queue']['type'] == 'redis':
            self.base_ml_executor = MLTaskProducer()

    def _handler_meta(self, integration_id: Optional[int] = None) -> dict:
        """Build the `handler_meta` payload section identifying this engine.

        Args:
            integration_id: override for `self.integration_id` (used by
                `create_engine`, which may target another integration).

        Returns:
            dict with `module_path`, `engine` and `integration_id` keys.
        """
        return {
            'module_path': self.handler_module.__package__,
            'engine': self.engine,
            'integration_id': self.integration_id if integration_id is None else integration_id
        }

    @profiler.profile()
    def learn(
        self, model_name, project_name,
        data_integration_ref=None,
        fetch_data_query=None,
        problem_definition=None,
        join_learn_process=False,
        label=None,
        is_retrain=False,
        set_active=True,
    ):
        """ Trains a model given some data-gathering SQL statement.

        Creates the `db.Predictor` record (status GENERATING) and dispatches a
        LEARN task to the executor.

        Args:
            model_name: name of the model to create/retrain.
            project_name: project the model belongs to.
            data_integration_ref: reference to the data integration.
            fetch_data_query: SQL used to gather training data.
            problem_definition: engine-specific learn arguments; `target` is
                read from here (may be absent, e.g. 0-shot models).
            join_learn_process: if True, block until training finishes and
                return the refreshed record.
            label: optional user label for the model version.
            is_retrain: True when retraining an existing model (bumps version,
                new version starts inactive).
            set_active: whether the new version becomes active on success.

        Returns:
            db.Predictor: the (possibly refreshed) predictor record.
        """
        if problem_definition is None:
            # guard: default is None but the dict is dereferenced below
            problem_definition = {}

        # may or may not be provided (e.g. 0-shot models do not need it), so engine will handle it
        target = problem_definition.get('target', [''])  # db.Predictor expects Column(Array(String))

        project = self.database_controller.get_project(name=project_name)

        self.create_validation(target, problem_definition, self.integration_id)

        predictor_record = db.Predictor(
            company_id=ctx.company_id,
            name=model_name,
            integration_id=self.integration_id,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            mindsdb_version=mindsdb_version,
            to_predict=target,
            learn_args=problem_definition,
            data={'name': model_name},
            project_id=project.id,
            training_data_columns_count=None,
            training_data_rows_count=None,
            training_start_at=dt.datetime.now(),
            status=PREDICTOR_STATUS.GENERATING,
            label=label,
            version=(
                # next version on retrain, otherwise version 1
                db.session.query(
                    coalesce(func.max(db.Predictor.version), 1) + (1 if is_retrain else 0)
                ).filter_by(
                    company_id=ctx.company_id,
                    name=model_name,
                    project_id=project.id,
                    deleted_at=null()
                ).scalar_subquery()),
            active=(not is_retrain),  # if create then active
            training_metadata={
                'hostname': socket.gethostname(),
                'reason': 'retrain' if is_retrain else 'learn'
            }
        )

        db.serializable_insert(predictor_record)

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.LEARN,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'problem_definition': problem_definition,
                    'set_active': set_active,
                    'data_integration_ref': data_integration_ref,
                    'fetch_data_query': fetch_data_query,
                    'project_name': project_name
                }
            )

            if join_learn_process is True:
                task.result()
                predictor_record = db.Predictor.query.get(predictor_record.id)
                db.session.refresh(predictor_record)
            else:
                # to prevent memory leak need to add any callback
                task.add_done_callback(empty_callback)

        return predictor_record

    def describe(self, model_id: int, attribute: Optional[str] = None) -> pd.DataFrame:
        """Dispatch a DESCRIBE task for a model.

        Args:
            model_id: id of the model to describe.
            attribute: optional attribute to describe.

        Returns:
            pd.DataFrame with the engine's description output.
        """
        with self._catch_exception(model_id):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.DESCRIBE,
                model_id=model_id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'attribute': attribute,
                    'context': ctx.dump()
                }
            )
            result = task.result()
            return result

    def function_call(self, func_name, args):
        """Dispatch a FUNC_CALL task (engine-provided function) and return its result."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.FUNC_CALL,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'name': func_name,
                    'args': args,
                    'handler_meta': self._handler_meta(),
                }
            )
            result = task.result()
            return result

    @profiler.profile()
    @mark_process(name='predict')
    def predict(self, model_name: str, df: pd.DataFrame, pred_format: str = 'dict',
                project_name: str = None, version=None, params: dict = None):
        """ Generates predictions with some model and input data.

        Args:
            model_name: name of the model.
            df: input data to predict on.
            pred_format: output format requested from the engine.
            project_name: project the model belongs to.
            version: specific model version; active version if None.
            params: extra `USING` parameters passed to the engine.

        Returns:
            pd.DataFrame with predictions (keeps `__mindsdb_row_id` from `df`
            if the engine does not emit it).

        Raises:
            Exception: if the model does not exist or is not trained yet.
        """
        kwargs = {
            'name': model_name,
            'ml_handler_name': self.name,
            'project_name': project_name
        }
        if version is None:
            kwargs['active'] = True
        else:
            kwargs['active'] = None
            kwargs['version'] = version
        predictor_record = get_model_record(**kwargs)
        if predictor_record is None:
            if version is not None:
                model_name = f'{model_name}.{version}'
            raise Exception(f"Error: model '{model_name}' does not exists!")
        if predictor_record.status != PREDICTOR_STATUS.COMPLETE:
            raise Exception("Error: model creation not completed")

        using = {} if params is None else params
        args = {
            'pred_format': pred_format,
            'predict_params': using,
            'using': using
        }

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.PREDICT,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'predictor_record': predictor_record,
                    'args': args
                },
                dataframe=df
            )
            predictions = task.result()

        # mdb indexes
        if '__mindsdb_row_id' not in predictions.columns and '__mindsdb_row_id' in df.columns:
            predictions['__mindsdb_row_id'] = df['__mindsdb_row_id']

        after_predict_hook(
            company_id=ctx.company_id,
            predictor_id=predictor_record.id,
            rows_in_count=df.shape[0],
            columns_in_count=df.shape[1],
            rows_out_count=len(predictions)
        )
        return predictions

    def create_validation(self, target, args, integration_id):
        """Dispatch a CREATE_VALIDATION task so the engine can validate learn args."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.CREATE_VALIDATION,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'target': target,
                    'args': args,
                    'handler_meta': self._handler_meta(integration_id),
                }
            )
            result = task.result()
            return result

    def update(self, args: dict, model_id: int):
        """Dispatch an UPDATE task for an existing model and return its result."""
        with self._catch_exception(model_id):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.UPDATE,
                model_id=model_id,
                payload={
                    'context': ctx.dump(),
                    'args': args,
                    'handler_meta': self._handler_meta(),
                }
            )
            result = task.result()
            return result

    def update_engine(self, connection_args):
        """Dispatch an UPDATE_ENGINE task with new connection args."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.UPDATE_ENGINE,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'connection_args': connection_args,
                    'handler_meta': self._handler_meta(),
                }
            )
            result = task.result()
            return result

    def create_engine(self, connection_args: dict, integration_id: int) -> None:
        """Dispatch a CREATE_ENGINE task for the given integration."""
        with self._catch_exception():
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.CREATE_ENGINE,
                model_id=0,  # can not be None
                payload={
                    'context': ctx.dump(),
                    'connection_args': connection_args,
                    'handler_meta': self._handler_meta(integration_id),
                }
            )
            result = task.result()
            return result

    @profiler.profile()
    def finetune(
        self, model_name, project_name,
        base_model_version: int,
        data_integration_ref=None,
        fetch_data_query=None,
        join_learn_process=False,
        label=None,
        set_active=True,
        args: Optional[dict] = None
    ):
        """Finetune an existing model: create a new (inactive) version based on
        the latest complete base version and dispatch a FINETUNE task.

        Args:
            model_name: name of the model to finetune.
            project_name: project the model belongs to.
            base_model_version: version to finetune from; active version if None.
            data_integration_ref: reference to the data integration.
            fetch_data_query: SQL used to gather finetune data.
            join_learn_process: if True, block until finetuning finishes.
            label: optional user label for the new version.
            set_active: whether the new version becomes active on success.
            args: extra `USING` args merged over the base model's learn args.

        Returns:
            db.Predictor: the new predictor record.

        Raises:
            Exception: when no suitable (complete, finished) base model exists.
        """
        # generate new record from latest version as starting point
        project = self.database_controller.get_project(name=project_name)

        search_args = {
            'active': None,
            'name': model_name,
            'status': PREDICTOR_STATUS.COMPLETE
        }
        if base_model_version is not None:
            search_args['version'] = base_model_version
        else:
            search_args['active'] = True
        predictor_records = get_model_records(**search_args)

        # BUGFIX: filter out records without a training_stop_at BEFORE sorting;
        # sorting with None keys raises TypeError (None vs datetime), and the
        # filtered list must be re-checked for emptiness before indexing.
        predictor_records = [x for x in predictor_records if x.training_stop_at is not None]
        if len(predictor_records) == 0:
            raise Exception("Can't find suitable base model")
        predictor_records.sort(key=lambda x: x.training_stop_at, reverse=True)
        base_predictor_record = predictor_records[0]

        # BUGFIX: `args` defaults to None; merging `{**using, **None}` would
        # raise TypeError, so normalize to an empty dict first.
        args = args or {}
        learn_args = base_predictor_record.learn_args
        learn_args['using'] = args if not learn_args.get('using', False) else {**learn_args['using'], **args}

        self.create_validation(
            target=base_predictor_record.to_predict,
            args=learn_args,
            integration_id=self.integration_id
        )

        predictor_record = db.Predictor(
            company_id=ctx.company_id,
            name=model_name,
            integration_id=self.integration_id,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            mindsdb_version=mindsdb_version,
            to_predict=base_predictor_record.to_predict,
            learn_args=learn_args,
            data={'name': model_name},
            project_id=project.id,
            training_data_columns_count=None,
            training_data_rows_count=None,
            training_start_at=dt.datetime.now(),
            status=PREDICTOR_STATUS.GENERATING,
            label=label,
            version=(
                db.session.query(
                    coalesce(func.max(db.Predictor.version), 1) + 1
                ).filter_by(
                    company_id=ctx.company_id,
                    name=model_name,
                    project_id=project.id,
                    deleted_at=null()
                ).scalar_subquery()
            ),
            active=False,
            training_metadata={
                'hostname': socket.gethostname(),
                'reason': 'finetune'
            }
        )
        db.serializable_insert(predictor_record)

        with self._catch_exception(model_name):
            task = self.base_ml_executor.apply_async(
                task_type=ML_TASK_TYPE.FINETUNE,
                model_id=predictor_record.id,
                payload={
                    'handler_meta': self._handler_meta(),
                    'context': ctx.dump(),
                    'model_id': predictor_record.id,
                    'problem_definition': predictor_record.learn_args,
                    'set_active': set_active,
                    'base_model_id': base_predictor_record.id,
                    'data_integration_ref': data_integration_ref,
                    'fetch_data_query': fetch_data_query,
                    'project_name': project_name
                }
            )

            if join_learn_process is True:
                task.result()
                predictor_record = db.Predictor.query.get(predictor_record.id)
                db.session.refresh(predictor_record)
            else:
                # to prevent memory leak need to add any callback
                task.add_done_callback(empty_callback)

        return predictor_record

    @contextlib.contextmanager
    def _catch_exception(self, model_identifier: Optional[Union[int, str]] = None):
        """Context manager that rewraps engine errors as `MLEngineException`.

        Import errors propagate unchanged (they indicate a broken install, not
        an engine failure). `MLProcessException` is unwrapped to its base
        exception before formatting the message `[engine/model]: msg`.
        """
        try:
            yield
        except (ImportError, ModuleNotFoundError):
            raise
        except Exception as e:
            if isinstance(e, MLProcessException):
                # unwrap: the real error was serialized across the process boundary
                e = e.base_exception
            msg = str(e).strip()
            if msg == '':
                msg = e.__class__.__name__
            model_identifier = '' if model_identifier is None else f'/{model_identifier}'
            msg = f'[{self.name}{model_identifier}]: {msg}'
            raise MLEngineException(msg) from e