Coverage for mindsdb / integrations / handlers / lightwood_handler / functions.py: 0%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import dataclasses 

2import json 

3import os 

4import tempfile 

5import traceback 

6from datetime import datetime 

7from pathlib import Path 

8 

9import lightwood 

10import pandas as pd 

11import requests 

12from lightwood.api.types import JsonAI 

13from pandas.core.frame import DataFrame 

14 

15import mindsdb.utilities.profiler as profiler 

16from mindsdb.integrations.libs.const import PREDICTOR_STATUS 

17from mindsdb.integrations.utilities.utils import format_exception_error 

18from mindsdb.interfaces.storage import db 

19from mindsdb.interfaces.storage.fs import RESOURCE_GROUP, FileStorage 

20from mindsdb.interfaces.storage.json import get_json_storage 

21from mindsdb.utilities import log 

22from mindsdb.utilities.functions import mark_process 

23 

24from .utils import brack_to_mod, rep_recur, unpack_jsonai_old_args 

25 

26logger = log.getLogger(__name__) 

27 

28 

def create_learn_mark():
    """Create a per-process marker file for a running learn process.

    Touches ``<tmpdir>/mindsdb/learn_processes/<pid>`` so other components
    can see which PIDs are currently learning. POSIX-only; a no-op elsewhere.
    """
    if os.name != "posix":
        return
    marks_dir = Path(tempfile.gettempdir()) / "mindsdb" / "learn_processes"
    marks_dir.mkdir(parents=True, exist_ok=True)
    (marks_dir / str(os.getpid())).touch()

34 

35 

def delete_learn_mark():
    """Remove this process's learn-mark file created by ``create_learn_mark``.

    POSIX-only; a no-op on other platforms or when the mark is already gone.
    """
    if os.name != "posix":
        return
    mark = Path(tempfile.gettempdir()) / "mindsdb" / "learn_processes" / str(os.getpid())
    # missing_ok avoids the exists()/unlink() race of the previous version:
    # another cleanup path could remove the file between the two calls.
    mark.unlink(missing_ok=True)

41 

42 

@mark_process(name="learn")
@profiler.profile()
def run_generate(df: DataFrame, predictor_id: int, model_storage, args: dict = None):
    """Generate and persist lightwood predictor code for ``predictor_id``.

    Builds a ``ProblemDefinition`` from ``args`` (with overrides taken from
    the ``using`` sub-dict), derives a JsonAI spec from ``df``, applies the
    overrides, then stores the generated code on the predictor db record and
    the final JsonAI dict in JSON storage.

    Args:
        df: training data used to infer the JsonAI specification.
        predictor_id: id of the predictor record to update.
        model_storage: storage handle used to report training progress.
        args: problem-definition arguments; may contain a ``using`` dict of
            JsonAI overrides. Defaults to an empty dict.
    """
    # Guard the default: the previous version called args.pop(...) directly,
    # which raised AttributeError when args was left as None.
    if args is None:
        args = {}

    model_storage.training_state_set(current_state_num=1, total_states=5, state_name="Generating problem definition")
    json_ai_override = args.pop("using", {})

    if "dtype_dict" in json_ai_override:
        args["dtype_dict"] = json_ai_override.pop("dtype_dict")

    if "problem_definition" in json_ai_override:
        args = {**args, **json_ai_override["problem_definition"]}

    if "timeseries_settings" in args:
        # Dotted override keys such as "timeseries_settings.window" patch
        # individual fields of the timeseries settings dict.
        for tss_key in [f.name for f in dataclasses.fields(lightwood.api.TimeseriesSettings)]:
            k = f"timeseries_settings.{tss_key}"
            if k in json_ai_override:
                args["timeseries_settings"][tss_key] = json_ai_override.pop(k)

    problem_definition = lightwood.ProblemDefinition.from_dict(args)

    model_storage.training_state_set(current_state_num=2, total_states=5, state_name="Generating JsonAI")
    json_ai = lightwood.json_ai_from_problem(df, problem_definition)
    json_ai = json_ai.to_dict()
    # Normalize old-style override args, convert bracket notation to module
    # form, then recursively merge the overrides into the JsonAI dict.
    unpack_jsonai_old_args(json_ai_override)
    json_ai_override = brack_to_mod(json_ai_override)
    rep_recur(json_ai, json_ai_override)
    json_ai = JsonAI.from_dict(json_ai)

    model_storage.training_state_set(current_state_num=3, total_states=5, state_name="Generating code")
    code = lightwood.code_from_json_ai(json_ai)

    # Row-lock the record while writing the generated code.
    predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
    predictor_record.code = code
    db.session.commit()

    json_storage = get_json_storage(resource_id=predictor_id)
    json_storage.set("json_ai", json_ai.to_dict())

80 

81 

@mark_process(name="learn")
@profiler.profile()
def run_fit(predictor_id: int, df: pd.DataFrame, model_storage) -> None:
    """Train the predictor whose generated code is stored on the db record.

    Loads the code previously produced by ``run_generate``, trains it on
    ``df``, saves the trained state to file storage and records the model
    analysis (including per-mixer training times) on the predictor record.

    Args:
        predictor_id: id of the predictor record holding the generated code.
        df: training data.
        model_storage: storage handle used to report training progress.

    Raises:
        Exception: re-raises any training failure after persisting the error
            text on the predictor record.
    """
    predictor_record = None  # pre-bound so the except block can test it safely
    try:
        predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
        if predictor_record is None:
            # Explicit raise instead of assert: asserts are stripped under -O.
            raise Exception(f"Predictor with id={predictor_id} does not exist")

        predictor_record.data = {"training_log": "training"}
        predictor_record.status = PREDICTOR_STATUS.TRAINING
        db.session.commit()

        model_storage.training_state_set(current_state_num=4, total_states=5, state_name="Training model")
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        # Training can take a long time; refresh to pick up concurrent changes.
        db.session.refresh(predictor_record)

        fs = FileStorage(resource_group=RESOURCE_GROUP.PREDICTOR, resource_id=predictor_id, sync=True)
        predictor.save(fs.folder_path / fs.folder_name)
        fs.push(compression_level=0)

        predictor_record.data = predictor.model_analysis.to_dict()

        # getting training time for each tried model. it is possible to do
        # after training only
        fit_mixers = list(
            predictor.runtime_log[x] for x in predictor.runtime_log if isinstance(x, tuple) and x[0] == "fit_mixer"
        )
        submodel_data = predictor_record.data.get("submodel_data", [])
        # add training time to other mixers info
        if submodel_data and fit_mixers and len(submodel_data) == len(fit_mixers):
            for i, tr_time in enumerate(fit_mixers):
                submodel_data[i]["training_time"] = tr_time
            predictor_record.data["submodel_data"] = submodel_data

        model_storage.training_state_set(current_state_num=5, total_states=5, state_name="Complete")
        predictor_record.dtype_dict = predictor.dtype_dict
        db.session.commit()
    except Exception as e:
        # Guard: the record lookup itself may have failed, in which case
        # predictor_record is still None and there is nothing to update
        # (the previous version raised NameError here).
        if predictor_record is not None:
            db.session.refresh(predictor_record)
            predictor_record.data = {"error": f"{traceback.format_exc()}\nMain error: {e}"}
            db.session.commit()
        raise e

125 

126 

@mark_process(name="learn")
def run_learn_remote(df: DataFrame, predictor_id: int) -> None:
    """Train a predictor via a remote training endpoint and record the outcome.

    Posts the JSON-serialized dataframe to the ``train_url`` stored on the
    predictor record; marks the record 'complete' on HTTP 200, otherwise
    'error' with the response text (or the local traceback when the request
    never completed).

    Args:
        df: training data, serialized to JSON for the remote service.
        predictor_id: id of the predictor record holding the train url.
    """
    # Bind the record before the try block: the previous version referenced
    # both `predictor_record` and `resp` in the except path, which raised
    # NameError whenever the failure happened before they were assigned.
    predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
    resp = None
    try:
        serialized_df = json.dumps(df.to_dict())
        resp = requests.post(
            predictor_record.data["train_url"],
            json={"df": serialized_df, "target": predictor_record.to_predict[0]},
        )

        if resp.status_code != 200:
            # Explicit raise instead of assert: asserts are stripped under -O.
            raise Exception(f"Remote training endpoint returned status {resp.status_code}")
        predictor_record.data["status"] = "complete"
    except Exception:
        predictor_record.data["status"] = "error"
        # Prefer the server's response text; fall back to the local traceback
        # when the request itself failed before a response arrived.
        predictor_record.data["error"] = str(resp.text) if resp is not None else traceback.format_exc()

    db.session.commit()

144 

145 

@mark_process(name="learn")
def run_learn(df: DataFrame, args: dict, model_storage) -> None:
    """End-to-end learn flow: generate predictor code, then train it.

    Stamps training start/stop timestamps and the COMPLETE status on the
    predictor record around the generate + fit steps.

    Args:
        df: training data; must be a non-empty dataframe.
        args: problem-definition arguments forwarded to ``run_generate``.
        model_storage: storage handle providing the predictor id and
            progress reporting.

    Raises:
        Exception: when no input data is provided.
    """
    if df is None or len(df) == 0:
        raise Exception("No input data. Ensure the data source is healthy and try again.")

    predictor_id = model_storage.predictor_id

    record = db.Predictor.query.with_for_update().get(predictor_id)
    record.training_start_at = datetime.now()
    db.session.commit()

    run_generate(df, predictor_id, model_storage, args)
    run_fit(predictor_id, df, model_storage)

    record.status = PREDICTOR_STATUS.COMPLETE
    record.training_stop_at = datetime.now()
    db.session.commit()

163 

164 

@mark_process(name="finetune")
def run_finetune(df: DataFrame, args: dict, model_storage):
    """Finetune an existing trained predictor on new data.

    Loads the base predictor state (``args['base_model_id']``), adjusts it on
    ``df``, saves the adjusted state under this model's own storage and marks
    the record complete. On failure the record is put into ERROR status with
    a formatted error message and the exception is re-raised.

    Args:
        df: new training data; must be a non-empty dataframe.
        args: must contain ``base_model_id``; an optional ``using`` dict is
            forwarded to the predictor's adjust step.
        model_storage: storage handle providing this model's predictor id.

    Raises:
        Exception: when input data is missing, the base model is not in
            status 'complete', or any step of the finetune fails.
    """
    try:
        if df is None or df.shape[0] == 0:
            raise Exception("No input data. Ensure the data source is healthy and try again.")

        base_predictor_id = args["base_model_id"]
        base_predictor_record = db.Predictor.query.get(base_predictor_id)
        if base_predictor_record.status != PREDICTOR_STATUS.COMPLETE:
            raise Exception("Base model must be in status 'complete'")

        predictor_id = model_storage.predictor_id
        predictor_record = db.Predictor.query.get(predictor_id)

        # TODO move this to ModelStorage (don't work with database directly)
        predictor_record.data = {"training_log": "training"}
        predictor_record.training_start_at = datetime.now()
        predictor_record.status = PREDICTOR_STATUS.FINETUNING  # TODO: parallel execution block
        db.session.commit()

        # Restore the base model from its file storage and adjust it in place.
        base_fs = FileStorage(
            resource_group=RESOURCE_GROUP.PREDICTOR,
            resource_id=base_predictor_id,
            sync=True,
        )
        predictor = lightwood.predictor_from_state(
            base_fs.folder_path / base_fs.folder_name, base_predictor_record.code
        )
        predictor.adjust(df, adjust_args=args.get("using", {}))

        # Persist the adjusted state under this model's own storage.
        fs = FileStorage(resource_group=RESOURCE_GROUP.PREDICTOR, resource_id=predictor_id, sync=True)
        predictor.save(fs.folder_path / fs.folder_name)
        fs.push(compression_level=0)

        predictor_record.data = predictor.model_analysis.to_dict()  # todo: update accuracy in LW as post-finetune hook
        predictor_record.code = base_predictor_record.code
        predictor_record.update_status = "up_to_date"
        predictor_record.status = PREDICTOR_STATUS.COMPLETE
        predictor_record.training_stop_at = datetime.now()
        db.session.commit()

    except Exception as e:
        logger.error("Unexpected error during Lightwood model finetune:", exc_info=True)
        # Re-query with a row lock so the error status is written even if the
        # failure happened before predictor_record was bound above.
        predictor_id = model_storage.predictor_id
        predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
        error_message = format_exception_error(e)
        predictor_record.data = {"error": error_message}
        predictor_record.status = PREDICTOR_STATUS.ERROR
        db.session.commit()
        raise
    finally:
        # Ensure a stop timestamp is always recorded, success or failure.
        if predictor_record.training_stop_at is None:
            predictor_record.training_stop_at = datetime.now()
            db.session.commit()