Coverage for mindsdb / interfaces / model / model_controller.py: 46%

279 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import copy 

2import datetime as dt 

3from copy import deepcopy 

4from multiprocessing.pool import ThreadPool 

5 

6import pandas as pd 

7from dateutil.parser import parse as parse_datetime 

8 

9from sqlalchemy import func 

10import numpy as np 

11 

12import mindsdb.interfaces.storage.db as db 

13from mindsdb.utilities.config import Config 

14from mindsdb.interfaces.model.functions import get_model_record, get_model_records, get_project_record 

15from mindsdb.interfaces.storage.json import get_json_storage 

16from mindsdb.interfaces.storage.model_fs import ModelStorage 

17from mindsdb.utilities.config import config 

18from mindsdb.utilities.context import context as ctx 

19from mindsdb.utilities.functions import resolve_model_identifier 

20import mindsdb.utilities.profiler as profiler 

21from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError 

22from mindsdb.utilities import log 

23 

# Module-level logger, named after this module.
logger = log.getLogger(__name__)


# Value of the "default_project" config entry. It is the default `project_name`
# of ModelController.delete_model and the fallback project in
# ModelController.prepare_finetune_statement when the statement names none.
default_project = config.get("default_project")

28 

29 

def delete_model_storage(model_id, ctx_dump):
    """Delete the storage of one model, logging (not raising) any failure.

    Used as a thread-pool worker by ``ModelController.delete_model``: the
    dumped context is loaded first, then the model's ``ModelStorage`` is
    deleted.

    Args:
        model_id: id of the model whose storage is removed.
        ctx_dump: context dump (``ctx.dump()``) to load before deleting.
    """
    try:
        ctx.load(ctx_dump)
        ModelStorage(model_id).delete()
    except Exception:
        # Best effort: one failed storage must not abort the other deletions.
        logger.exception(f"Something went wrong during deleting storage of model {model_id}:")

37 

38 

class ModelController:
    """Operations on model (predictor) records: read, create, retrain, finetune,
    update, rename, activate and delete.

    Records are looked up through ``get_model_record`` / ``get_model_records``
    and persisted through ``db.session``; training work is delegated to the
    ML handler passed in by the caller.
    """

    config: Config

    def __init__(self) -> None:
        self.config = Config()

    @staticmethod
    def _pop_set_active(using):
        """Pop the "active" option from a USING dict and normalise "off" values.

        Args:
            using: the statement's USING dict, or None. It is mutated: the
                "active" key is removed when present.

        Returns:
            True when ``using`` is None or has no "active" key; False when the
            value is one of ``"0"``, ``0`` or ``None`` (``False`` also matches,
            because ``False == 0``); otherwise the popped value unchanged.
        """
        if using is None:
            return True
        set_active = using.pop("active", True)
        if set_active in ("0", 0, None):
            return False
        return set_active

    def get_model_data(self, name: str = None, predictor_record=None, ml_handler_name="lightwood") -> dict:
        """Return the full data dictionary of a model.

        Args:
            name: model name; only used for the lookup when ``predictor_record``
                is not given.
            predictor_record: already loaded record. When None it is fetched with
                ``get_model_record(except_absent=True, ...)``.
            ml_handler_name: handler name passed to that lookup.

        Returns:
            A deep copy of ``predictor_record.data`` extended with attributes of
            the record, the stored "json_ai" and, when a non-empty "accuracies"
            mapping is present, the mean of its values under "accuracy".
        """
        if predictor_record is None:
            predictor_record = get_model_record(except_absent=True, name=name, ml_handler_name=ml_handler_name)

        data = deepcopy(predictor_record.data)
        data["dtype_dict"] = predictor_record.dtype_dict
        # Timestamps are rendered as strings with the fractional seconds cut off.
        data["created_at"] = str(parse_datetime(str(predictor_record.created_at).split(".")[0]))
        data["updated_at"] = str(parse_datetime(str(predictor_record.updated_at).split(".")[0]))
        data["training_start_at"] = predictor_record.training_start_at
        data["training_stop_at"] = predictor_record.training_stop_at
        data["predict"] = predictor_record.to_predict[0]
        data["update"] = predictor_record.update_status
        data["mindsdb_version"] = predictor_record.mindsdb_version
        data["name"] = predictor_record.name
        data["code"] = predictor_record.code
        data["problem_definition"] = predictor_record.learn_args
        data["fetch_data_query"] = predictor_record.fetch_data_query
        data["active"] = predictor_record.active
        data["status"] = predictor_record.status
        data["id"] = predictor_record.id
        data["version"] = predictor_record.version

        json_storage = get_json_storage(resource_id=predictor_record.id)
        data["json_ai"] = json_storage.get("json_ai")

        if data.get("accuracies", None) is not None:
            if len(data["accuracies"]) > 0:
                data["accuracy"] = float(np.mean(list(data["accuracies"].values())))
        return data

    def get_reduced_model_data(self, name: str = None, predictor_record=None, ml_handler_name="lightwood") -> dict:
        """Return a summary of ``get_model_data`` plus the training time.

        Arguments are forwarded unchanged to ``get_model_data``.

        Returns:
            A dict with a fixed set of keys (missing ones are None) and
            "training_time": whole seconds between training start and stop, or
            between training start and now while the status is "training";
            None when it cannot be computed.
        """
        full_model_data = self.get_model_data(
            name=name, predictor_record=predictor_record, ml_handler_name=ml_handler_name
        )
        reduced_keys = (
            "id",
            "name",
            "version",
            "is_active",
            "predict",
            "status",
            "problem_definition",
            "current_phase",
            "accuracy",
            "data_source",
            "update",
            "active",
            "mindsdb_version",
            "error",
            "created_at",
            "fetch_data_query",
        )
        reduced_model_data = {k: full_model_data.get(k, None) for k in reduced_keys}

        reduced_model_data["training_time"] = None
        training_start_at = full_model_data.get("training_start_at")
        if training_start_at is not None:
            training_stop_at = full_model_data.get("training_stop_at")
            if training_stop_at is not None:
                reduced_model_data["training_time"] = training_stop_at - training_start_at
            elif full_model_data.get("status") == "training":
                reduced_model_data["training_time"] = dt.datetime.now() - training_start_at
            if reduced_model_data["training_time"] is not None:
                # Drop the microseconds before converting the timedelta to seconds.
                reduced_model_data["training_time"] = (
                    reduced_model_data["training_time"]
                    - dt.timedelta(microseconds=reduced_model_data["training_time"].microseconds)
                ).total_seconds()

        return reduced_model_data

    def describe_model(self, session, project_name, model_name, attribute, version=None):
        """Delegate a describe request to the ML handler that owns the model.

        Args:
            session: object exposing ``integration_controller.get_ml_handler``.
            project_name: project containing the model.
            model_name: name of the model.
            attribute: passed through to the handler's ``describe``.
            version: specific version; when given, the lookup is not restricted
                to the active version.

        Returns:
            Whatever the handler's ``describe(model_id, attribute)`` returns.
        """
        args = {"name": model_name, "version": version, "project_name": project_name, "except_absent": True}
        if version is not None:
            args["active"] = None

        model_record = get_model_record(**args)

        integration_record = db.Integration.query.get(model_record.integration_id)

        ml_handler_base = session.integration_controller.get_ml_handler(integration_record.name)

        return ml_handler_base.describe(model_record.id, attribute)

    def get_model(self, name, version=None, ml_handler_name=None, project_name=None):
        """Return the reduced data of one model, with its engine info when known.

        When ``version`` is None only the active version is looked up; otherwise
        the lookup is not restricted by the active flag. "engine" and
        "engine_name" are added when the integration record is found.
        """
        show_active = True if version is None else None
        model_record = get_model_record(
            active=show_active,
            version=version,
            name=name,
            ml_handler_name=ml_handler_name,
            except_absent=True,
            project_name=project_name,
        )
        data = self.get_reduced_model_data(predictor_record=model_record)
        integration_record = db.Integration.query.get(model_record.integration_id)
        if integration_record is not None:
            data["engine"] = integration_record.engine
            data["engine_name"] = integration_record.name
        return data

    def get_models(self, with_versions=False, ml_handler_name=None, integration_id=None, project_name=None):
        """Return the reduced data of every matching model as a list.

        Only active versions are listed unless ``with_versions`` is something
        other than ``False``.
        """
        show_active = True if with_versions is False else None
        return [
            self.get_reduced_model_data(predictor_record=model_record)
            for model_record in get_model_records(
                active=show_active,
                ml_handler_name=ml_handler_name,
                integration_id=integration_id,
                project_name=project_name,
            )
        ]

    def delete_model(self, model_name: str, project_name: str = default_project, version=None):
        """Delete a model (all versions, or one version) and its storage.

        Args:
            model_name: name of the model.
            project_name: project containing the model.
            version: when None every version of the model is deleted, otherwise
                only the given one.

        Raises:
            Exception: the project does not exist, or (cloud mode) a matching
                record is still "generating"/"training" and was created less
                than an hour ago.
            EntityNotExistsError: no matching model record was found.
        """
        from mindsdb.interfaces.database.database import DatabaseController

        project_record = get_project_record(func.lower(project_name))
        if project_record is None:
            raise Exception(f"Project '{project_name}' does not exists")

        database_controller = DatabaseController()

        project = database_controller.get_project(project_name)

        if version is None:
            # Delete all versions: active=None lifts the "active only" restriction.
            predictors_records = get_model_records(
                name=model_name,
                project_id=project.id,
                active=None,
            )
        else:
            predictors_records = get_model_records(
                name=model_name,
                project_id=project.id,
                version=version,
            )
        if len(predictors_records) == 0:
            raise EntityNotExistsError("Model does not exist", model_name)

        is_cloud = self.config.get("cloud", False)
        if is_cloud:
            # Refuse to delete models that started being built within the last hour.
            for predictor_record in predictors_records:
                model_data = self.get_model_data(predictor_record=predictor_record)
                if (
                    model_data.get("status") in ["generating", "training"]
                    and isinstance(model_data.get("created_at"), str) is True
                    and (dt.datetime.now() - parse_datetime(model_data.get("created_at"))) < dt.timedelta(hours=1)
                ):
                    raise Exception(
                        "You are unable to delete models currently in progress, please wait before trying again"
                    )

        for predictor_record in predictors_records:
            if is_cloud:
                # Cloud mode keeps the row and only marks it as deleted.
                predictor_record.deleted_at = dt.datetime.now()
            else:
                db.session.delete(predictor_record)
        db.session.commit()

        # region delete storages
        if len(predictors_records) > 1:
            # Several storages: delete them concurrently, handing the current
            # context to each worker thread.
            ctx_dump = ctx.dump()
            with ThreadPool(min(len(predictors_records), 100)) as pool:
                pool.starmap(delete_model_storage, [(record.id, ctx_dump) for record in predictors_records])
        else:
            ModelStorage(predictors_records[0].id).delete()
        # endregion

    def rename_model(self, old_name, new_name):
        """Rename every record called ``old_name`` to ``new_name``.

        Raises:
            Exception: a model named ``new_name`` already exists.
        """
        # The lookup returns None when the name is free, so a hit means a clash.
        if get_model_record(name=new_name) is not None:
            raise Exception(f"Model with name '{new_name}' already exists")

        for model_record in get_model_records(name=old_name):
            model_record.name = new_name
        db.session.commit()

    @staticmethod
    def _get_data_integration_ref(statement, database_controller):
        """Resolve the data source named in a statement.

        Returns:
            ``(data_integration_ref, fetch_data_query)``. Both are None when the
            statement has no ``integration_name``. Otherwise the reference is
            ``{"type": "project"}``, ``{"type": "system"}`` or
            ``{"type": "integration", "id": ...}`` and the query is
            ``statement.query_str``.

        Raises:
            EntityNotExistsError: the named database is unknown.
        """
        # TODO use database_controller handler_controller internally
        data_integration_ref = None
        fetch_data_query = None
        if statement.integration_name is not None:
            fetch_data_query = statement.query_str
            integration_name = statement.integration_name.parts[0].lower()

            databases_meta = database_controller.get_dict()
            if integration_name not in databases_meta:
                raise EntityNotExistsError("Database does not exist", integration_name)
            data_integration_meta = databases_meta[integration_name]
            # TODO improve here. Suppose that it is view
            if data_integration_meta["type"] == "project":
                data_integration_ref = {"type": "project"}
            elif data_integration_meta["type"] == "system":
                data_integration_ref = {"type": "system"}
            else:
                data_integration_ref = {"type": "integration", "id": data_integration_meta["id"]}
        return data_integration_ref, fetch_data_query

    def prepare_create_statement(self, statement, database_controller):
        """Extract the arguments for ``ml_handler.learn`` from a create or retrain statement.

        Note: ``statement.using`` is mutated - the "tag" and
        "join_learn_process" entries are popped from it.

        Returns:
            dict with the keys model_name, project_name, data_integration_ref,
            fetch_data_query, problem_definition, join_learn_process and label.
        """
        project_name = statement.name.parts[0]
        model_name = statement.name.parts[1]

        sql_task = None
        if statement.task is not None:
            sql_task = statement.task.to_string()
        problem_definition = {"__mdb_sql_task": sql_task}
        if statement.targets is not None:
            problem_definition["target"] = statement.targets[0].parts[-1]

        data_integration_ref, fetch_data_query = self._get_data_integration_ref(statement, database_controller)

        label = None
        if statement.using is not None:
            label = statement.using.pop("tag", None)

            problem_definition["using"] = statement.using

        if statement.order_by is not None:
            # An ORDER BY clause marks the model as a time-series one.
            problem_definition["timeseries_settings"] = {
                "is_timeseries": True,
                "order_by": getattr(statement, "order_by")[0].field.parts[-1],
            }
            for attr in ["horizon", "window"]:
                if getattr(statement, attr) is not None:
                    problem_definition["timeseries_settings"][attr] = getattr(statement, attr)

            if statement.group_by is not None:
                problem_definition["timeseries_settings"]["group_by"] = [col.parts[-1] for col in statement.group_by]

        # Remove the flag from the USING options (if any) and return it separately.
        join_learn_process = problem_definition.get("using", {}).pop("join_learn_process", False)

        return dict(
            model_name=model_name,
            project_name=project_name,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            problem_definition=problem_definition,
            join_learn_process=join_learn_process,
            label=label,
        )

    def create_model(self, statement, ml_handler):
        """Create a model from a statement and return its info row.

        Raises:
            EntityNotExistsError: the target project does not exist.
            EntityExistsError: the project already has a table with that name.

        Returns:
            Single-row DataFrame produced by ``get_model_info``.
        """
        params = self.prepare_create_statement(statement, ml_handler.database_controller)

        existing_projects_meta = ml_handler.database_controller.get_dict(filter_type="project", lowercase=False)
        if params["project_name"] not in existing_projects_meta:
            raise EntityNotExistsError("Project does not exist", params["project_name"])

        project = ml_handler.database_controller.get_project(name=params["project_name"], strict_case=True)
        project_tables = project.get_tables()
        if params["model_name"] in project_tables:
            raise EntityExistsError("Model already exists", f"{params['project_name']}.{params['model_name']}")
        predictor_record = ml_handler.learn(**params)

        return ModelController.get_model_info(predictor_record)

    def retrain_model(self, statement, ml_handler):
        """Retrain the active version of a model and return its info row.

        Values missing from the statement (data source, fetch query, learn
        arguments) are taken from the active record; the statement's problem
        definition overrides the stored learn arguments.

        Raises:
            Exception: no active model with that name exists in the project.
        """
        # The "active" option must be popped before USING becomes part of the
        # problem definition.
        set_active = self._pop_set_active(statement.using)

        params = self.prepare_create_statement(statement, ml_handler.database_controller)

        base_predictor_record = get_model_record(
            name=params["model_name"], project_name=params["project_name"], active=True
        )

        model_name = params["model_name"]
        if base_predictor_record is None:
            raise Exception(f"Error: model '{model_name}' does not exist")

        if params["data_integration_ref"] is None:
            params["data_integration_ref"] = base_predictor_record.data_integration_ref
        if params["fetch_data_query"] is None:
            params["fetch_data_query"] = base_predictor_record.fetch_data_query

        problem_definition = base_predictor_record.learn_args.copy()
        problem_definition.update(params["problem_definition"])
        params["problem_definition"] = problem_definition

        params["is_retrain"] = True
        params["set_active"] = set_active
        predictor_record = ml_handler.learn(**params)

        return ModelController.get_model_info(predictor_record)

    def prepare_finetune_statement(self, statement, database_controller):
        """Extract the arguments for ``ml_handler.finetune`` from a statement.

        Note: ``statement.using`` is mutated - "active", "tag" and
        "join_learn_process" are popped; the remainder is returned as ``args``.

        Raises:
            EntityNotExistsError: the named database is unknown, or the
                statement gives no data source and the base model from which
                it would be inherited does not exist.

        Returns:
            dict with the keys model_name, project_name, data_integration_ref,
            fetch_data_query, base_model_version, args, join_learn_process,
            label and set_active.
        """
        project_name, model_name, model_version = resolve_model_identifier(statement.name)
        if project_name is None:
            project_name = default_project
        data_integration_ref, fetch_data_query = self._get_data_integration_ref(statement, database_controller)

        set_active = self._pop_set_active(statement.using)

        label = None
        args = {}
        if statement.using is not None:
            label = statement.using.pop("tag", None)
            args = statement.using

        join_learn_process = args.pop("join_learn_process", False)

        base_predictor_record = get_model_record(
            name=model_name,
            project_name=project_name,
            version=model_version,
            active=True if model_version is None else None,
        )

        if data_integration_ref is None or fetch_data_query is None:
            # The missing values are inherited from the base model, so it must
            # exist; fail clearly instead of with an AttributeError on None.
            if base_predictor_record is None:
                raise EntityNotExistsError("Model does not exist", model_name)
            if data_integration_ref is None:
                data_integration_ref = base_predictor_record.data_integration_ref
            if fetch_data_query is None:
                fetch_data_query = base_predictor_record.fetch_data_query

        return dict(
            model_name=model_name,
            project_name=project_name,
            data_integration_ref=data_integration_ref,
            fetch_data_query=fetch_data_query,
            base_model_version=model_version,
            args=args,
            join_learn_process=join_learn_process,
            label=label,
            set_active=set_active,
        )

    @profiler.profile()
    def finetune_model(self, statement, ml_handler):
        """Finetune a model from a statement and return its info row."""
        params = self.prepare_finetune_statement(statement, ml_handler.database_controller)
        predictor_record = ml_handler.finetune(**params)
        return ModelController.get_model_info(predictor_record)

    def update_model(self, session, project_name: str, model_name: str, problem_definition, version=None):
        """Apply new arguments to a model through its handler and store them.

        The handler's ``update`` is called first; then, when
        ``problem_definition`` has a "using" entry, it is merged into the
        "using" part of the record's learn arguments and committed.
        """
        model_record = get_model_record(name=model_name, version=version, project_name=project_name, except_absent=True)
        integration_record = db.Integration.query.get(model_record.integration_id)

        ml_handler_base = session.integration_controller.get_ml_handler(integration_record.name)
        ml_handler_base.update(args=problem_definition, model_id=model_record.id)

        # update model record
        if "using" in problem_definition:
            # Work on a copy and assign it back so the column is seen as changed.
            # Stored learn_args may be empty or lack "using"; start from {} then.
            learn_args = copy.deepcopy(model_record.learn_args) or {}
            using = learn_args.get("using") or {}
            using.update(problem_definition["using"])
            learn_args["using"] = using
            model_record.learn_args = learn_args
            db.session.commit()

    @staticmethod
    def get_model_info(predictor_record):
        """Build a single-row DataFrame describing a model.

        The row is assembled from the project's ``get_model_by_id`` result for
        ``predictor_record.id`` (its "name" and "metadata" entries).
        """
        from mindsdb.interfaces.database.projects import ProjectController

        projects_controller = ProjectController()
        project = projects_controller.get(id=predictor_record.project_id)

        columns = [
            "NAME",
            "ENGINE",
            "PROJECT",
            "ACTIVE",
            "VERSION",
            "STATUS",
            "ACCURACY",
            "PREDICT",
            "UPDATE_STATUS",
            "MINDSDB_VERSION",
            "ERROR",
            "SELECT_DATA_QUERY",
            "TRAINING_OPTIONS",
            "TAG",
        ]

        project_name = project.name
        model = project.get_model_by_id(model_id=predictor_record.id)
        table_name = model["name"]
        table_meta = model["metadata"]
        record = [
            table_name,
            table_meta["engine"],
            project_name,
            table_meta["active"],
            table_meta["version"],
            table_meta["status"],
            table_meta["accuracy"],
            table_meta["predict"],
            table_meta["update_status"],
            table_meta["mindsdb_version"],
            table_meta["error"],
            table_meta["select_data_query"],
            str(table_meta["training_options"]),
            table_meta["label"],
        ]

        return pd.DataFrame([record], columns=columns)

    def set_model_active_version(self, project_name, model_name, version):
        """Make ``version`` the active version of a model.

        Every other active record with the same name, project and company is
        deactivated.

        Raises:
            EntityNotExistsError: the requested version does not exist.
        """
        model_record = get_model_record(name=model_name, project_name=project_name, version=version, active=None)

        if model_record is None:
            raise EntityNotExistsError(f"Model {model_name} with version {version} is not found in {project_name}")

        model_record.active = True

        # deactivate current active version
        model_records = db.Predictor.query.filter(
            db.Predictor.name == model_record.name,
            db.Predictor.project_id == model_record.project_id,
            db.Predictor.active == True,  # noqa
            db.Predictor.company_id == ctx.company_id,
            db.Predictor.id != model_record.id,
        )
        for p in model_records:
            p.active = False

        db.session.commit()

    def delete_model_version(self, project_name, model_name, version):
        """Delete one non-active version of a model.

        In cloud mode the record is only marked with ``deleted_at``; otherwise
        the record and its model storage are removed.

        Raises:
            EntityNotExistsError: the requested version does not exist.
            Exception: the requested version is the active one.
        """
        model_record = get_model_record(name=model_name, project_name=project_name, version=version, active=None)
        if model_record is None:
            raise EntityNotExistsError(f"Model {model_name} with version {version} is not found in {project_name}")

        if model_record.active:
            raise Exception(f"Can't remove active version: {project_name}.{model_name}.{version}")

        is_cloud = self.config.get("cloud", False)
        if is_cloud:
            model_record.deleted_at = dt.datetime.now()
        else:
            db.session.delete(model_record)
            ModelStorage(model_record.id).delete()

        db.session.commit()