Coverage for mindsdb / integrations / libs / ml_handler_process / learn_process.py: 17%
103 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import importlib
3import datetime as dt
5from sqlalchemy.orm.attributes import flag_modified
7from mindsdb_sql_parser import parse_sql
8from mindsdb_sql_parser.ast import Identifier, Select, Star, NativeQuery
10from mindsdb.api.executor.sql_query import SQLQuery
11import mindsdb.utilities.profiler as profiler
12from mindsdb.utilities.functions import mark_process
13from mindsdb.utilities.config import Config
14from mindsdb.utilities.context import context as ctx
15from mindsdb.utilities import log
16import mindsdb.interfaces.storage.db as db
17from mindsdb.interfaces.storage.model_fs import ModelStorage, HandlerStorage
18from mindsdb.interfaces.model.functions import get_model_records
19from mindsdb.integrations.utilities.utils import format_exception_error
20from mindsdb.integrations.utilities.sql_utils import make_sql_session
21from mindsdb.integrations.libs.const import PREDICTOR_STATUS
22from mindsdb.integrations.libs.ml_handler_process.handlers_cacher import handlers_cacher
24logger = log.getLogger(__name__)
def _fetch_training_data(data_integration_ref: dict, fetch_data_query: str, project_name: str):
    """Fetch the training dataframe described by ``data_integration_ref``.

    Builds the appropriate query for the referenced data source
    ('integration', 'system', 'view' or 'project'), executes it through
    SQLQuery and returns the result as a dataframe.

    Raises:
        ValueError: if the data integration type is unknown (previously this
            surfaced as a confusing NameError on an unbound variable).
    """
    # deferred import, as in the original code (avoids import cycles at module load)
    from mindsdb.interfaces.database.database import DatabaseController

    database_controller = DatabaseController()
    sql_session = make_sql_session()

    ref_type = data_integration_ref["type"]
    if ref_type == "integration":
        integration_name = database_controller.get_integration(data_integration_ref["id"])["name"]
        query = Select(
            targets=[Star()],
            from_table=NativeQuery(integration=Identifier(integration_name), query=fetch_data_query),
        )
        sqlquery = SQLQuery(query, session=sql_session)
    elif ref_type == "system":
        # system data lives in the internal 'log' integration
        query = Select(
            targets=[Star()], from_table=NativeQuery(integration=Identifier("log"), query=fetch_data_query)
        )
        sqlquery = SQLQuery(query, session=sql_session)
    elif ref_type == "view":
        project = database_controller.get_project(project_name)
        query_ast = parse_sql(fetch_data_query)
        view_meta = project.get_view_meta(query_ast)
        sqlquery = SQLQuery(view_meta["query_ast"], session=sql_session)
    elif ref_type == "project":
        query_ast = parse_sql(fetch_data_query)
        sqlquery = SQLQuery(query_ast, session=sql_session)
    else:
        raise ValueError(f"Unknown data integration type: {ref_type}")

    return sqlquery.fetched_data.to_df()


def _check_target_column(training_data_df, target: str) -> None:
    """Ensure ``target`` is a column of ``training_data_df``.

    If the column exists only with a different case, it is renamed in place to
    match ``target``. If it is missing entirely, an exception is raised.

    Raises:
        Exception: when the target column cannot be found in the dataframe.
    """
    if training_data_df is None or target in training_data_df.columns:
        return
    # is the case different? convert column case in input dataframe
    col_names = {c.lower(): c for c in training_data_df.columns}
    target_found = col_names.get(target.lower())
    if target_found:
        training_data_df.rename(columns={target_found: target}, inplace=True)
    else:
        raise Exception(
            f'Prediction target "{target}" not found in training dataframe: {list(training_data_df.columns)}'
        )


@mark_process(name="learn")
def learn_process(
    data_integration_ref: dict,
    problem_definition: dict,
    fetch_data_query: str,
    project_name: str,
    model_id: int,
    integration_id: int,
    base_model_id: int,
    set_active: bool,
    module_path: str,
):
    """Run model training (or fine-tuning) as a standalone process.

    Fetches the training data (if any), instantiates the ML handler from
    ``module_path``, then either creates a new model (``base_model_id is
    None``) or fine-tunes an existing one. On success the predictor record is
    marked COMPLETE (and optionally made the active version); on any failure
    the error is recorded on the predictor record instead of being re-raised.

    Args:
        data_integration_ref: describes the training-data source
            ({'type': 'integration'|'system'|'view'|'project', ...}) or
            None when the model is trained without data.
        problem_definition: learn arguments; may contain 'target'.
        fetch_data_query: SQL used to fetch the training data.
        project_name: name of the project that owns the model.
        model_id: id of the predictor record being trained.
        integration_id: id of the ML engine integration.
        base_model_id: id of the model version to fine-tune from, or None
            to create a new model.
        set_active: whether to make this version active on success.
        module_path: import path of the ML handler module.
    """
    ctx.profiling = {"level": 0, "enabled": True, "pointer": None, "tree": None}
    profiler.set_meta(query="learn_process", api="http", environment=Config().get("environment"))

    with profiler.Context("learn_process"):
        try:
            # Record which OS process performs the training so it can be tracked.
            predictor_record = db.Predictor.query.with_for_update().get(model_id)
            predictor_record.training_metadata["process_id"] = os.getpid()
            # in-place mutation of a JSON column is invisible to SQLAlchemy; flag it
            flag_modified(predictor_record, "training_metadata")
            db.session.commit()

            target = problem_definition.get("target", None)

            training_data_df = None
            if data_integration_ref is not None:
                training_data_df = _fetch_training_data(data_integration_ref, fetch_data_query, project_name)

            # Persist dataset size for observability (0/0 when there is no data).
            training_data_columns_count, training_data_rows_count = 0, 0
            if training_data_df is not None:
                training_data_columns_count = len(training_data_df.columns)
                training_data_rows_count = len(training_data_df)
            predictor_record.training_data_columns_count = training_data_columns_count
            predictor_record.training_data_rows_count = training_data_rows_count
            db.session.commit()

            module = importlib.import_module(module_path)
            # handler modules record their import failure instead of raising it at
            # import time; surface it here so it lands on the predictor record
            if module.import_error is not None:
                raise module.import_error

            handlerStorage = HandlerStorage(integration_id)
            modelStorage = ModelStorage(model_id)
            modelStorage.fileStorage.push()  # FIXME

            kwargs = {}
            if base_model_id is not None:
                # fine-tuning: give the handler access to the previous version's files
                kwargs["base_model_storage"] = ModelStorage(base_model_id)
                kwargs["base_model_storage"].fileStorage.pull()
            ml_handler = module.Handler(engine_storage=handlerStorage, model_storage=modelStorage, **kwargs)
            handlers_cacher[predictor_record.id] = ml_handler

            if not ml_handler.generative and target is not None:
                _check_target_column(training_data_df, target)

            if base_model_id is None:
                # create new model
                with profiler.Context("create"):
                    ml_handler.create(target, df=training_data_df, args=problem_definition)
            else:
                # fine-tune (partially train) existing model, using the previous
                # version as the starting point
                with profiler.Context("finetune"):
                    problem_definition["base_model_id"] = base_model_id
                    ml_handler.finetune(df=training_data_df, args=problem_definition)

            predictor_record.status = PREDICTOR_STATUS.COMPLETE
            predictor_record.active = set_active
            db.session.commit()

            # if retrain and set_active after successful creation: deactivate every
            # version, then activate the newest successfully-trained one
            if set_active is True:
                models = get_model_records(
                    name=predictor_record.name, project_id=predictor_record.project_id, active=None
                )
                for model in models:
                    model.active = False
                completed = [x for x in models if x.status == PREDICTOR_STATUS.COMPLETE]
                if completed:  # guard against IndexError if no version completed
                    completed.sort(key=lambda x: x.created_at)
                    completed[-1].active = True
        except Exception as e:
            logger.exception("Error during 'learn' process:")
            error_message = format_exception_error(e)

            # re-fetch under a row lock: the earlier instance may be stale
            predictor_record = db.Predictor.query.with_for_update().get(model_id)
            predictor_record.data = {"error": error_message}
            predictor_record.status = PREDICTOR_STATUS.ERROR
            db.session.commit()

        # runs for both success and failure paths
        predictor_record.training_stop_at = dt.datetime.now()
        db.session.commit()