# mindsdb/integrations/handlers/lightwood_handler/functions.py
import dataclasses
import json
import os
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import lightwood
import pandas as pd
import requests
from lightwood.api.types import JsonAI
from pandas.core.frame import DataFrame

import mindsdb.utilities.profiler as profiler
from mindsdb.integrations.libs.const import PREDICTOR_STATUS
from mindsdb.integrations.utilities.utils import format_exception_error
from mindsdb.interfaces.storage import db
from mindsdb.interfaces.storage.fs import RESOURCE_GROUP, FileStorage
from mindsdb.interfaces.storage.json import get_json_storage
from mindsdb.utilities import log
from mindsdb.utilities.functions import mark_process

from .utils import brack_to_mod, rep_recur, unpack_jsonai_old_args

logger = log.getLogger(__name__)


def create_learn_mark():
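    """Touch a marker file named after the current PID so that a running
    learn process can be detected. Only applies on POSIX systems; the marker
    lives under the system temp directory in mindsdb/learn_processes/.
    """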
    if os.name == "posix":
        p = Path(tempfile.gettempdir()).joinpath("mindsdb/learn_processes/")
        p.mkdir(parents=True, exist_ok=True)
        p.joinpath(f"{os.getpid()}").touch()


def delete_learn_mark():
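    """Remove the marker file created by create_learn_mark() for this process, if present."""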
    if os.name == "posix":
        p = Path(tempfile.gettempdir()).joinpath("mindsdb/learn_processes/").joinpath(f"{os.getpid()}")
        if p.exists():
            p.unlink()


@mark_process(name="learn")
@profiler.profile()
def run_generate(df: DataFrame, predictor_id: int, model_storage, args: dict = None):
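    """Build a Lightwood problem definition and JsonAI from the input data,
    apply any 'using' overrides, generate predictor code, and persist both the
    code (on the predictor record) and the JsonAI (in JSON storage).
    """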
    model_storage.training_state_set(current_state_num=1, total_states=5, state_name="Generating problem definition")
    json_ai_override = args.pop("using", {})

    if "dtype_dict" in json_ai_override:
        args["dtype_dict"] = json_ai_override.pop("dtype_dict")

    if "problem_definition" in json_ai_override:
        args = {**args, **json_ai_override["problem_definition"]}

    if "timeseries_settings" in args:
        for tss_key in [f.name for f in dataclasses.fields(lightwood.api.TimeseriesSettings)]:
            k = f"timeseries_settings.{tss_key}"
            if k in json_ai_override:
                args["timeseries_settings"][tss_key] = json_ai_override.pop(k)

    problem_definition = lightwood.ProblemDefinition.from_dict(args)

    model_storage.training_state_set(current_state_num=2, total_states=5, state_name="Generating JsonAI")
    json_ai = lightwood.json_ai_from_problem(df, problem_definition)
    json_ai = json_ai.to_dict()
    unpack_jsonai_old_args(json_ai_override)
    json_ai_override = brack_to_mod(json_ai_override)
    rep_recur(json_ai, json_ai_override)
    json_ai = JsonAI.from_dict(json_ai)

    model_storage.training_state_set(current_state_num=3, total_states=5, state_name="Generating code")
    code = lightwood.code_from_json_ai(json_ai)

    predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
    predictor_record.code = code
    db.session.commit()

    json_storage = get_json_storage(resource_id=predictor_id)
    json_storage.set("json_ai", json_ai.to_dict())


@mark_process(name="learn")
@profiler.profile()
def run_fit(predictor_id: int, df: pd.DataFrame, model_storage) -> None:
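    """Train the predictor from its previously generated code, save the trained
    model to file storage, and record model analysis (including per-mixer
    training times) on the predictor record. On failure, the error is stored
    on the record and the exception is re-raised.
    """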
    try:
        predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
        assert predictor_record is not None

        predictor_record.data = {"training_log": "training"}
        predictor_record.status = PREDICTOR_STATUS.TRAINING
        db.session.commit()

        model_storage.training_state_set(current_state_num=4, total_states=5, state_name="Training model")
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        db.session.refresh(predictor_record)

        fs = FileStorage(resource_group=RESOURCE_GROUP.PREDICTOR, resource_id=predictor_id, sync=True)
        predictor.save(fs.folder_path / fs.folder_name)
        fs.push(compression_level=0)

        predictor_record.data = predictor.model_analysis.to_dict()

        # Collect the training time of each mixer that was tried; this
        # information only becomes available after training has finished.
        fit_mixers = [
            predictor.runtime_log[k] for k in predictor.runtime_log if isinstance(k, tuple) and k[0] == "fit_mixer"
        ]
        submodel_data = predictor_record.data.get("submodel_data", [])
        # Attach the training time to the rest of each mixer's info.
        if submodel_data and fit_mixers and len(submodel_data) == len(fit_mixers):
            for i, tr_time in enumerate(fit_mixers):
                submodel_data[i]["training_time"] = tr_time
            predictor_record.data["submodel_data"] = submodel_data

        model_storage.training_state_set(current_state_num=5, total_states=5, state_name="Complete")
        predictor_record.dtype_dict = predictor.dtype_dict
        db.session.commit()
    except Exception as e:
        db.session.refresh(predictor_record)
        predictor_record.data = {"error": f"{traceback.format_exc()}\nMain error: {e}"}
        db.session.commit()
        raise


@mark_process(name="learn")
def run_learn_remote(df: DataFrame, predictor_id: int) -> None:
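    """Send the training data to the remote training endpoint stored on the
    predictor record and mark the record as 'complete' or 'error' accordingly.
    """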
    # Fetch the record before the try block so the error path below cannot
    # reference an unbound name if the query itself fails.
    predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
    resp = None
    try:
        serialized_df = json.dumps(df.to_dict())
        resp = requests.post(
            predictor_record.data["train_url"],
            json={"df": serialized_df, "target": predictor_record.to_predict[0]},
        )

        assert resp.status_code == 200
        predictor_record.data["status"] = "complete"
    except Exception as e:
        predictor_record.data["status"] = "error"
        # Fall back to the raised exception when no HTTP response was received.
        predictor_record.data["error"] = str(resp.text) if resp is not None else str(e)

    db.session.commit()


@mark_process(name="learn")
def run_learn(df: DataFrame, args: dict, model_storage) -> None:
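    """End-to-end training: generate predictor code, fit the model, and stamp
    training start/stop times and the final status on the predictor record.

    Raises if the input DataFrame is missing or empty.
    """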
    if df is None or df.shape[0] == 0:
        raise Exception("No input data. Ensure the data source is healthy and try again.")

    predictor_id = model_storage.predictor_id

    predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
    predictor_record.training_start_at = datetime.now()
    db.session.commit()

    run_generate(df, predictor_id, model_storage, args)
    run_fit(predictor_id, df, model_storage)

    predictor_record.status = PREDICTOR_STATUS.COMPLETE
    predictor_record.training_stop_at = datetime.now()
    db.session.commit()


@mark_process(name="finetune")
def run_finetune(df: DataFrame, args: dict, model_storage):
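    """Finetune an existing model: load the base predictor from file storage,
    adjust it on the new data, save the result under this predictor's storage,
    and update the predictor record's status and timestamps. On failure the
    record is marked as errored and the exception is re-raised.
    """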
    try:
        if df is None or df.shape[0] == 0:
            raise Exception("No input data. Ensure the data source is healthy and try again.")

        base_predictor_id = args["base_model_id"]
        base_predictor_record = db.Predictor.query.get(base_predictor_id)
        if base_predictor_record.status != PREDICTOR_STATUS.COMPLETE:
            raise Exception("Base model must be in status 'complete'")

        predictor_id = model_storage.predictor_id
        predictor_record = db.Predictor.query.get(predictor_id)

        # TODO: move this to ModelStorage (avoid working with the database directly)
        predictor_record.data = {"training_log": "training"}
        predictor_record.training_start_at = datetime.now()
        predictor_record.status = PREDICTOR_STATUS.FINETUNING  # TODO: block parallel execution
        db.session.commit()

        base_fs = FileStorage(
            resource_group=RESOURCE_GROUP.PREDICTOR,
            resource_id=base_predictor_id,
            sync=True,
        )
        predictor = lightwood.predictor_from_state(
            base_fs.folder_path / base_fs.folder_name, base_predictor_record.code
        )
        predictor.adjust(df, adjust_args=args.get("using", {}))

        fs = FileStorage(resource_group=RESOURCE_GROUP.PREDICTOR, resource_id=predictor_id, sync=True)
        predictor.save(fs.folder_path / fs.folder_name)
        fs.push(compression_level=0)

        predictor_record.data = predictor.model_analysis.to_dict()  # TODO: update accuracy in LW as a post-finetune hook
        predictor_record.code = base_predictor_record.code
        predictor_record.update_status = "up_to_date"
        predictor_record.status = PREDICTOR_STATUS.COMPLETE
        predictor_record.training_stop_at = datetime.now()
        db.session.commit()

    except Exception as e:
        logger.error("Unexpected error during Lightwood model finetune:", exc_info=True)
        predictor_id = model_storage.predictor_id
        predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
        error_message = format_exception_error(e)
        predictor_record.data = {"error": error_message}
        predictor_record.status = PREDICTOR_STATUS.ERROR
        db.session.commit()
        raise
    finally:
        if predictor_record.training_stop_at is None:
            predictor_record.training_stop_at = datetime.now()
            db.session.commit()