Coverage for mindsdb / interfaces / model / model_controller.py: 46%
279 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import copy
2import datetime as dt
3from copy import deepcopy
4from multiprocessing.pool import ThreadPool
6import pandas as pd
7from dateutil.parser import parse as parse_datetime
9from sqlalchemy import func
10import numpy as np
12import mindsdb.interfaces.storage.db as db
13from mindsdb.utilities.config import Config
14from mindsdb.interfaces.model.functions import get_model_record, get_model_records, get_project_record
15from mindsdb.interfaces.storage.json import get_json_storage
16from mindsdb.interfaces.storage.model_fs import ModelStorage
17from mindsdb.utilities.config import config
18from mindsdb.utilities.context import context as ctx
19from mindsdb.utilities.functions import resolve_model_identifier
20import mindsdb.utilities.profiler as profiler
21from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
22from mindsdb.utilities import log
# Module-level logger for this controller.
logger = log.getLogger(__name__)

# Project name used when a caller does not specify one explicitly.
default_project = config.get("default_project")
def delete_model_storage(model_id, ctx_dump):
    """Restore the execution context and delete one model's storage.

    Runs inside ThreadPool workers (see ``ModelController.delete_model``),
    so any failure is logged and swallowed rather than aborting the batch.
    """
    try:
        ctx.load(ctx_dump)
        storage = ModelStorage(model_id)
        storage.delete()
    except Exception:
        logger.exception(f"Something went wrong during deleting storage of model {model_id}:")
class ModelController:
    """Controller for model lifecycle operations: create, retrain, finetune,
    describe, rename, version management and deletion."""

    # Application configuration, loaded once per controller instance.
    config: Config

    def __init__(self) -> None:
        self.config = Config()
45 def get_model_data(self, name: str = None, predictor_record=None, ml_handler_name="lightwood") -> dict:
46 if predictor_record is None: 46 ↛ 47line 46 didn't jump to line 47 because the condition on line 46 was never true
47 predictor_record = get_model_record(except_absent=True, name=name, ml_handler_name=ml_handler_name)
49 data = deepcopy(predictor_record.data)
50 data["dtype_dict"] = predictor_record.dtype_dict
51 data["created_at"] = str(parse_datetime(str(predictor_record.created_at).split(".")[0]))
52 data["updated_at"] = str(parse_datetime(str(predictor_record.updated_at).split(".")[0]))
53 data["training_start_at"] = predictor_record.training_start_at
54 data["training_stop_at"] = predictor_record.training_stop_at
55 data["predict"] = predictor_record.to_predict[0]
56 data["update"] = predictor_record.update_status
57 data["mindsdb_version"] = predictor_record.mindsdb_version
58 data["name"] = predictor_record.name
59 data["code"] = predictor_record.code
60 data["problem_definition"] = predictor_record.learn_args
61 data["fetch_data_query"] = predictor_record.fetch_data_query
62 data["active"] = predictor_record.active
63 data["status"] = predictor_record.status
64 data["id"] = predictor_record.id
65 data["version"] = predictor_record.version
67 json_storage = get_json_storage(resource_id=predictor_record.id)
68 data["json_ai"] = json_storage.get("json_ai")
70 if data.get("accuracies", None) is not None: 70 ↛ 71line 70 didn't jump to line 71 because the condition on line 70 was never true
71 if len(data["accuracies"]) > 0:
72 data["accuracy"] = float(np.mean(list(data["accuracies"].values())))
73 return data
75 def get_reduced_model_data(self, name: str = None, predictor_record=None, ml_handler_name="lightwood") -> dict:
76 full_model_data = self.get_model_data(
77 name=name, predictor_record=predictor_record, ml_handler_name=ml_handler_name
78 )
79 reduced_model_data = {}
80 for k in [
81 "id",
82 "name",
83 "version",
84 "is_active",
85 "predict",
86 "status",
87 "problem_definition",
88 "current_phase",
89 "accuracy",
90 "data_source",
91 "update",
92 "active",
93 "mindsdb_version",
94 "error",
95 "created_at",
96 "fetch_data_query",
97 ]:
98 reduced_model_data[k] = full_model_data.get(k, None)
100 reduced_model_data["training_time"] = None
101 if full_model_data.get("training_start_at") is not None: 101 ↛ 114line 101 didn't jump to line 114 because the condition on line 101 was always true
102 if full_model_data.get("training_stop_at") is not None: 102 ↛ 106line 102 didn't jump to line 106 because the condition on line 102 was always true
103 reduced_model_data["training_time"] = full_model_data.get("training_stop_at") - full_model_data.get(
104 "training_start_at"
105 )
106 elif full_model_data.get("status") == "training":
107 reduced_model_data["training_time"] = dt.datetime.now() - full_model_data.get("training_start_at")
108 if reduced_model_data["training_time"] is not None: 108 ↛ 114line 108 didn't jump to line 114 because the condition on line 108 was always true
109 reduced_model_data["training_time"] = (
110 reduced_model_data["training_time"]
111 - dt.timedelta(microseconds=reduced_model_data["training_time"].microseconds)
112 ).total_seconds()
114 return reduced_model_data
116 def describe_model(self, session, project_name, model_name, attribute, version=None):
117 args = {"name": model_name, "version": version, "project_name": project_name, "except_absent": True}
118 if version is not None:
119 args["active"] = None
121 model_record = get_model_record(**args)
123 integration_record = db.Integration.query.get(model_record.integration_id)
125 ml_handler_base = session.integration_controller.get_ml_handler(integration_record.name)
127 return ml_handler_base.describe(model_record.id, attribute)
129 def get_model(self, name, version=None, ml_handler_name=None, project_name=None):
130 show_active = True if version is None else None
131 model_record = get_model_record(
132 active=show_active,
133 version=version,
134 name=name,
135 ml_handler_name=ml_handler_name,
136 except_absent=True,
137 project_name=project_name,
138 )
139 data = self.get_reduced_model_data(predictor_record=model_record)
140 integration_record = db.Integration.query.get(model_record.integration_id)
141 if integration_record is not None: 141 ↛ 144line 141 didn't jump to line 144 because the condition on line 141 was always true
142 data["engine"] = integration_record.engine
143 data["engine_name"] = integration_record.name
144 return data
146 def get_models(self, with_versions=False, ml_handler_name=None, integration_id=None, project_name=None):
147 models = []
148 show_active = True if with_versions is False else None
149 for model_record in get_model_records(
150 active=show_active,
151 ml_handler_name=ml_handler_name,
152 integration_id=integration_id,
153 project_name=project_name,
154 ):
155 model_data = self.get_reduced_model_data(predictor_record=model_record)
156 models.append(model_data)
157 return models
    def delete_model(self, model_name: str, project_name: str = default_project, version=None):
        """Delete a model and its storage.

        :param model_name: name of the model to delete
        :param project_name: project that owns the model
        :param version: a specific version to delete; when None, ALL versions
            of the model are deleted
        :raises EntityNotExistsError: if no matching model records are found
        """
        # Local import — presumably avoids a circular import; TODO confirm.
        from mindsdb.interfaces.database.database import DatabaseController

        project_record = get_project_record(func.lower(project_name))
        if project_record is None:
            raise Exception(f"Project '{project_name}' does not exists")

        database_controller = DatabaseController()

        project = database_controller.get_project(project_name)

        if version is None:
            # No version given: fetch every version (active filter disabled).
            predictors_records = get_model_records(
                name=model_name,
                project_id=project.id,
                active=None,
            )
        else:
            predictors_records = get_model_records(
                name=model_name,
                project_id=project.id,
                version=version,
            )
        if len(predictors_records) == 0:
            raise EntityNotExistsError("Model does not exist", model_name)

        is_cloud = self.config.get("cloud", False)
        if is_cloud:
            # On cloud, refuse to delete models that started generating or
            # training within the last hour.
            for predictor_record in predictors_records:
                model_data = self.get_model_data(predictor_record=predictor_record)
                if (
                    model_data.get("status") in ["generating", "training"]
                    and isinstance(model_data.get("created_at"), str) is True
                    and (dt.datetime.now() - parse_datetime(model_data.get("created_at"))) < dt.timedelta(hours=1)
                ):
                    raise Exception(
                        "You are unable to delete models currently in progress, please wait before trying again"
                    )

        for predictor_record in predictors_records:
            if is_cloud:
                # Cloud: soft delete — keep the row, mark it deleted.
                predictor_record.deleted_at = dt.datetime.now()
            else:
                db.session.delete(predictor_record)
        db.session.commit()

        # region delete storages
        if len(predictors_records) > 1:
            # Several versions: delete their storages in parallel (<= 100 threads).
            ctx_dump = ctx.dump()
            with ThreadPool(min(len(predictors_records), 100)) as pool:
                pool.starmap(delete_model_storage, [(record.id, ctx_dump) for record in predictors_records])
        else:
            modelStorage = ModelStorage(predictors_records[0].id)
            modelStorage.delete()
        # endregion
216 def rename_model(self, old_name, new_name):
217 model_record = get_model_record(name=new_name)
218 if model_record is None:
219 raise Exception(f"Model with name '{new_name}' already exists")
221 for model_record in get_model_records(name=old_name):
222 model_record.name = new_name
223 db.session.commit()
225 @staticmethod
226 def _get_data_integration_ref(statement, database_controller):
227 # TODO use database_controller handler_controller internally
228 data_integration_ref = None
229 fetch_data_query = None
230 if statement.integration_name is not None: 230 ↛ 245line 230 didn't jump to line 245 because the condition on line 230 was always true
231 fetch_data_query = statement.query_str
232 integration_name = statement.integration_name.parts[0].lower()
234 databases_meta = database_controller.get_dict()
235 if integration_name not in databases_meta: 235 ↛ 236line 235 didn't jump to line 236 because the condition on line 235 was never true
236 raise EntityNotExistsError("Database does not exist", integration_name)
237 data_integration_meta = databases_meta[integration_name]
238 # TODO improve here. Suppose that it is view
239 if data_integration_meta["type"] == "project": 239 ↛ 240line 239 didn't jump to line 240 because the condition on line 239 was never true
240 data_integration_ref = {"type": "project"}
241 elif data_integration_meta["type"] == "system": 241 ↛ 242line 241 didn't jump to line 242 because the condition on line 241 was never true
242 data_integration_ref = {"type": "system"}
243 else:
244 data_integration_ref = {"type": "integration", "id": data_integration_meta["id"]}
245 return data_integration_ref, fetch_data_query
247 def prepare_create_statement(self, statement, database_controller):
248 # extract data from Create model or Retrain statement and prepare it for using in crate and retrain functions
249 project_name = statement.name.parts[0]
250 model_name = statement.name.parts[1]
252 sql_task = None
253 if statement.task is not None: 253 ↛ 254line 253 didn't jump to line 254 because the condition on line 253 was never true
254 sql_task = statement.task.to_string()
255 problem_definition = {"__mdb_sql_task": sql_task}
256 if statement.targets is not None: 256 ↛ 259line 256 didn't jump to line 259 because the condition on line 256 was always true
257 problem_definition["target"] = statement.targets[0].parts[-1]
259 data_integration_ref, fetch_data_query = self._get_data_integration_ref(statement, database_controller)
261 label = None
262 if statement.using is not None: 262 ↛ 267line 262 didn't jump to line 267 because the condition on line 262 was always true
263 label = statement.using.pop("tag", None)
265 problem_definition["using"] = statement.using
267 if statement.order_by is not None: 267 ↛ 268line 267 didn't jump to line 268 because the condition on line 267 was never true
268 problem_definition["timeseries_settings"] = {
269 "is_timeseries": True,
270 "order_by": getattr(statement, "order_by")[0].field.parts[-1],
271 }
272 for attr in ["horizon", "window"]:
273 if getattr(statement, attr) is not None:
274 problem_definition["timeseries_settings"][attr] = getattr(statement, attr)
276 if statement.group_by is not None:
277 problem_definition["timeseries_settings"]["group_by"] = [col.parts[-1] for col in statement.group_by]
279 join_learn_process = False
280 if "join_learn_process" in problem_definition.get("using", {}): 280 ↛ 284line 280 didn't jump to line 284 because the condition on line 280 was always true
281 join_learn_process = problem_definition["using"]["join_learn_process"]
282 del problem_definition["using"]["join_learn_process"]
284 return dict(
285 model_name=model_name,
286 project_name=project_name,
287 data_integration_ref=data_integration_ref,
288 fetch_data_query=fetch_data_query,
289 problem_definition=problem_definition,
290 join_learn_process=join_learn_process,
291 label=label,
292 )
294 def create_model(self, statement, ml_handler):
295 params = self.prepare_create_statement(statement, ml_handler.database_controller)
297 existing_projects_meta = ml_handler.database_controller.get_dict(filter_type="project", lowercase=False)
298 if params["project_name"] not in existing_projects_meta: 298 ↛ 299line 298 didn't jump to line 299 because the condition on line 298 was never true
299 raise EntityNotExistsError("Project does not exist", params["project_name"])
301 project = ml_handler.database_controller.get_project(name=params["project_name"], strict_case=True)
302 project_tables = project.get_tables()
303 if params["model_name"] in project_tables: 303 ↛ 304line 303 didn't jump to line 304 because the condition on line 303 was never true
304 raise EntityExistsError("Model already exists", f"{params['project_name']}.{params['model_name']}")
305 predictor_record = ml_handler.learn(**params)
307 return ModelController.get_model_info(predictor_record)
309 def retrain_model(self, statement, ml_handler):
310 # active setting
311 set_active = True
312 if statement.using is not None:
313 set_active = statement.using.pop("active", True)
314 if set_active in ("0", 0, None):
315 set_active = False
317 params = self.prepare_create_statement(statement, ml_handler.database_controller)
319 base_predictor_record = get_model_record(
320 name=params["model_name"], project_name=params["project_name"], active=True
321 )
323 model_name = params["model_name"]
324 if base_predictor_record is None:
325 raise Exception(f"Error: model '{model_name}' does not exist")
327 if params["data_integration_ref"] is None:
328 params["data_integration_ref"] = base_predictor_record.data_integration_ref
329 if params["fetch_data_query"] is None:
330 params["fetch_data_query"] = base_predictor_record.fetch_data_query
332 problem_definition = base_predictor_record.learn_args.copy()
333 problem_definition.update(params["problem_definition"])
334 params["problem_definition"] = problem_definition
336 params["is_retrain"] = True
337 params["set_active"] = set_active
338 predictor_record = ml_handler.learn(**params)
340 return ModelController.get_model_info(predictor_record)
342 def prepare_finetune_statement(self, statement, database_controller):
343 project_name, model_name, model_version = resolve_model_identifier(statement.name)
344 if project_name is None:
345 project_name = default_project
346 data_integration_ref, fetch_data_query = self._get_data_integration_ref(statement, database_controller)
348 set_active = True
349 if statement.using is not None:
350 set_active = statement.using.pop("active", True)
351 if set_active in ("0", 0, None):
352 set_active = False
354 label = None
355 args = {}
356 if statement.using is not None:
357 label = statement.using.pop("tag", None)
358 args = statement.using
360 join_learn_process = args.pop("join_learn_process", False)
362 base_predictor_record = get_model_record(
363 name=model_name,
364 project_name=project_name,
365 version=model_version,
366 active=True if model_version is None else None,
367 )
369 if data_integration_ref is None:
370 data_integration_ref = base_predictor_record.data_integration_ref
371 if fetch_data_query is None:
372 fetch_data_query = base_predictor_record.fetch_data_query
374 return dict(
375 model_name=model_name,
376 project_name=project_name,
377 data_integration_ref=data_integration_ref,
378 fetch_data_query=fetch_data_query,
379 base_model_version=model_version,
380 args=args,
381 join_learn_process=join_learn_process,
382 label=label,
383 set_active=set_active,
384 )
386 @profiler.profile()
387 def finetune_model(self, statement, ml_handler):
388 params = self.prepare_finetune_statement(statement, ml_handler.database_controller)
389 predictor_record = ml_handler.finetune(**params)
390 return ModelController.get_model_info(predictor_record)
    def update_model(self, session, project_name: str, model_name: str, problem_definition, version=None):
        """Push updated options to the owning ML engine and mirror them on the record.

        :param session: server session providing the integration controller
        :param problem_definition: new options; the ``using`` part is merged
            into the stored ``learn_args``
        """
        model_record = get_model_record(name=model_name, version=version, project_name=project_name, except_absent=True)
        integration_record = db.Integration.query.get(model_record.integration_id)

        ml_handler_base = session.integration_controller.get_ml_handler(integration_record.name)
        ml_handler_base.update(args=problem_definition, model_id=model_record.id)

        # update model record
        if "using" in problem_definition:
            # NOTE(review): learn_args is reassigned via a deep copy rather than
            # mutated in place — presumably so the ORM detects the change; confirm.
            learn_args = copy.deepcopy(model_record.learn_args)
            learn_args["using"].update(problem_definition["using"])
            model_record.learn_args = learn_args
            db.session.commit()
406 @staticmethod
407 def get_model_info(predictor_record):
408 from mindsdb.interfaces.database.projects import ProjectController
410 projects_controller = ProjectController()
411 project = projects_controller.get(id=predictor_record.project_id)
413 columns = [
414 "NAME",
415 "ENGINE",
416 "PROJECT",
417 "ACTIVE",
418 "VERSION",
419 "STATUS",
420 "ACCURACY",
421 "PREDICT",
422 "UPDATE_STATUS",
423 "MINDSDB_VERSION",
424 "ERROR",
425 "SELECT_DATA_QUERY",
426 "TRAINING_OPTIONS",
427 "TAG",
428 ]
430 project_name = project.name
431 model = project.get_model_by_id(model_id=predictor_record.id)
432 table_name = model["name"]
433 table_meta = model["metadata"]
434 record = [
435 table_name,
436 table_meta["engine"],
437 project_name,
438 table_meta["active"],
439 table_meta["version"],
440 table_meta["status"],
441 table_meta["accuracy"],
442 table_meta["predict"],
443 table_meta["update_status"],
444 table_meta["mindsdb_version"],
445 table_meta["error"],
446 table_meta["select_data_query"],
447 str(table_meta["training_options"]),
448 table_meta["label"],
449 ]
451 return pd.DataFrame([record], columns=columns)
453 def set_model_active_version(self, project_name, model_name, version):
454 model_record = get_model_record(name=model_name, project_name=project_name, version=version, active=None)
456 if model_record is None:
457 raise EntityNotExistsError(f"Model {model_name} with version {version} is not found in {project_name}")
459 model_record.active = True
461 # deactivate current active version
462 model_records = db.Predictor.query.filter(
463 db.Predictor.name == model_record.name,
464 db.Predictor.project_id == model_record.project_id,
465 db.Predictor.active == True, # noqa
466 db.Predictor.company_id == ctx.company_id,
467 db.Predictor.id != model_record.id,
468 )
469 for p in model_records:
470 p.active = False
472 db.session.commit()
474 def delete_model_version(self, project_name, model_name, version):
475 model_record = get_model_record(name=model_name, project_name=project_name, version=version, active=None)
476 if model_record is None:
477 raise EntityNotExistsError(f"Model {model_name} with version {version} is not found in {project_name}")
479 if model_record.active:
480 raise Exception(f"Can't remove active version: {project_name}.{model_name}.{version}")
482 is_cloud = self.config.get("cloud", False)
483 if is_cloud:
484 model_record.deleted_at = dt.datetime.now()
485 else:
486 db.session.delete(model_record)
487 modelStorage = ModelStorage(model_record.id)
488 modelStorage.delete()
490 db.session.commit()