Coverage for mindsdb/integrations/handlers/lightwood_handler/lightwood_handler.py: 0%
354 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import copy
import json
from datetime import datetime
from functools import lru_cache
from typing import Dict, Optional

import lightwood
import numpy as np
import pandas as pd
from type_infer.dtype import dtype

import mindsdb.interfaces.storage.db as db
import mindsdb.utilities.profiler as profiler
from mindsdb.integrations.libs.base import BaseMLEngine

# from mindsdb.utilities.hooks import after_predict as after_predict_hook
from mindsdb.interfaces.model.functions import get_model_record
from mindsdb.interfaces.storage.json import get_json_storage
from mindsdb.utilities.functions import cast_row_types

from .functions import run_finetune, run_learn


class NumpyJSONEncoder(json.JSONEncoder):
    """
    Use this encoder to avoid
    "TypeError: Object of type float32 is not JSON serializable"

    Example:
    x = np.float32(5)
    json.dumps(x, cls=NumpyJSONEncoder)
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.floating):
            # np.floating covers np.float32/np.float64; the old np.float alias was
            # removed in NumPy 1.24 and would raise AttributeError here.
            return float(obj)
        else:
            return super().default(obj)
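
# NumpyJSONEncoder is used further down in predict(), which serializes the per-row
# explanation dicts via json.dumps(..., cls=NumpyJSONEncoder, ensure_ascii=False);
# those dicts may contain numpy scalars and arrays.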


class LightwoodHandler(BaseMLEngine):
    name = 'lightwood'

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        if 'df' not in kwargs:
            return
        df = kwargs['df']
        columns = [x.lower() for x in df.columns]
        if target.lower() not in columns:
            raise Exception(f"There is no column '{target}' in dataframe")

        if (
            args is not None
            and 'timeseries_settings' in args
            and args['timeseries_settings'].get('is_timeseries') is True
        ):
            tss = args['timeseries_settings']
            if 'order_by' in tss and tss['order_by'].lower() not in columns:
                raise Exception(f"There is no column '{tss['order_by']}' in dataframe")
            if isinstance(tss.get('group_by'), list):
                for column in tss['group_by']:
                    if column.lower() not in columns:
                        raise Exception(f"There is no column '{column}' in dataframe")

    @profiler.profile('LightwoodHandler.create')
    def create(
        self,
        target: str,
        df: Optional[pd.DataFrame] = None,
        args: Optional[Dict] = None,
    ) -> None:
        args['target'] = target
        run_learn(
            df, args, self.model_storage  # Problem definition and JsonAI override
        )

    @profiler.profile('LightwoodHandler.finetune')
    def finetune(
        self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None
    ) -> None:
        run_finetune(df, args, self.model_storage)

    @staticmethod
    @lru_cache(maxsize=5)
    def get_predictor(predictor_path, predictor_code):
        predictor = lightwood.predictor_from_state(predictor_path, predictor_code)
        return predictor
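
    # Note: the lru_cache above memoizes loaded predictors keyed on
    # (predictor_path, predictor_code), so up to five distinct predictor states stay
    # resident in memory and repeated predict() calls for the same model skip the
    # relatively expensive lightwood.predictor_from_state() deserialization.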

    @profiler.profile('LightwoodHandler.predict')
    def predict(self, df, args=None):
        pred_format = args['pred_format']
        predictor_code = args['code']
        learn_args = args['learn_args']
        pred_args = args.get('predict_params', {})
        self.model_storage.fileStorage.pull()

        with profiler.Context('load model'):
            predictor_path = (
                self.model_storage.fileStorage.folder_path
                / self.model_storage.fileStorage.folder_name
            )
            predictor = LightwoodHandler.get_predictor(predictor_path, predictor_code)

        dtype_dict = predictor.dtype_dict

        if hasattr(predictor.problem_definition, 'embedding_only'):
            embedding_mode = (
                predictor.problem_definition.embedding_only
                or pred_args.get('return_embedding', False)
            )
        else:
            embedding_mode = False
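
        # If the model was trained as embedding-only, or the caller passed
        # return_embedding in the predict params, the 'prediction' column below carries
        # the raw encoded representation (one vector per input row) rather than a
        # decoded target value.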

        with profiler.Context('predict'):
            predictions = predictor.predict(df, args=pred_args)

        with profiler.Context('predict-postprocessing'):
            if embedding_mode:
                predictions['prediction'] = predictions.values.tolist()
                # note: return here once ml engine executor supports non-target named outputs
                predictions = predictions[['prediction']]

            predictions = predictions.to_dict(orient='records')

            # TODO!!!
            # after_predict_hook(
            #     company_id=self.company_id,
            #     predictor_id=predictor_record.id,
            #     rows_in_count=df.shape[0],
            #     columns_in_count=df.shape[1],
            #     rows_out_count=len(predictions)
            # )
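
            # From here on, `predictions` is a list with one dict per input row; the
            # formatting below reads keys such as 'prediction', 'confidence',
            # 'lower'/'upper' bounds, 'anomaly' and 'truth' when the predictor emits them.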
            # region format result
            target = args['target']
            explain_arr = []
            pred_dicts = []
            for i, row in enumerate(predictions):
                values = {
                    'predicted_value': row['prediction'],
                    'confidence': row.get('confidence', None),
                    'anomaly': row.get('anomaly', None),
                    'truth': row.get('truth', None),
                }

                if predictor.supports_proba:
                    for cls in predictor.statistical_analysis.train_observed_classes:
                        if row.get(f'__mdb_proba_{cls}', False):
                            values[f'probability_class_{cls}'] = round(
                                row[f'__mdb_proba_{cls}'], 4
                            )

                for block in predictor.analysis_blocks:
                    if type(block).__name__ == 'ShapleyValues':
                        cols = block.columns
                        values['shap_base_response'] = round(
                            row['shap_base_response'], 4
                        )
                        values['shap_final_response'] = round(
                            row['shap_final_response'], 4
                        )
                        for col in cols:
                            values[f'shap_contribution_{col}'] = round(
                                row[f'shap_contribution_{col}'], 4
                            )

                if 'lower' in row:
                    values['confidence_lower_bound'] = row.get('lower', None)
                    values['confidence_upper_bound'] = row.get('upper', None)

                obj = {target: values}
                explain_arr.append(obj)

                td = {'predicted_value': row['prediction']}
                for col in df.columns:
                    if col in row:
                        td[col] = row[col]
                    elif f'order_{col}' in row:
                        td[col] = row[f'order_{col}']
                    elif f'group_{col}' in row:
                        td[col] = row[f'group_{col}']
                    else:
                        original_index = row.get('original_index')
                        if original_index is None:
                            original_index = i
                        td[col] = df.iloc[original_index][col]
                pred_dicts.append({target: td})
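
            # At this point each explain_arr entry has the shape
            # {target: {'predicted_value': ..., 'confidence': ..., 'anomaly': ..., 'truth': ..., ...}}
            # and each pred_dicts entry is {target: {<input columns>, 'predicted_value': ...}}.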
            new_pred_dicts = []
            for row in pred_dicts:
                new_row = {}
                for key in row:
                    new_row.update(row[key])
                    new_row[key] = new_row['predicted_value']
                    del new_row['predicted_value']
                new_pred_dicts.append(new_row)
            pred_dicts = new_pred_dicts

            columns = list(dtype_dict.keys())
            predicted_columns = target
            if not isinstance(predicted_columns, list):
                predicted_columns = [predicted_columns]
            # endregion

            original_target_values = {}
            for col in predicted_columns:
                df = df.reset_index()
                original_target_values[col + '_original'] = []
                for _index, row in df.iterrows():
                    original_target_values[col + '_original'].append(row.get(col))

            # region transform ts predictions
            timeseries_settings = learn_args.get(
                'timeseries_settings', {'is_timeseries': False}
            )

            if timeseries_settings['is_timeseries'] is True:
                # offset the forecast if any row has __mdb_forecast_offset > 0
                forecast_offset = any(
                    [
                        row.get('__mdb_forecast_offset') is not None
                        and row['__mdb_forecast_offset'] > 0
                        for row in pred_dicts
                    ]
                )

                group_by = timeseries_settings.get('group_by', [])
                order_by_column = timeseries_settings['order_by']
                if isinstance(order_by_column, list):
                    order_by_column = order_by_column[0]
                horizon = timeseries_settings['horizon']

                # region convert values to lists in case of horizon==1,
                # so the processing below is uniform for every case.
                if horizon == 1:
                    for row in pred_dicts:
                        if isinstance(row[order_by_column], list) is False:
                            row[order_by_column] = [row[order_by_column]]
                        if isinstance(row[target], list) is False:
                            row[target] = [row[target]]
                    for row in explain_arr:
                        for col in (
                            'predicted_value',
                            'confidence',
                            'confidence_lower_bound',
                            'confidence_upper_bound',
                        ):
                            if isinstance(row[target][col], list) is False:
                                row[target][col] = [row[target][col]]
                # endregion

                if len(group_by) == 0:
                    rows_by_groups = {
                        (): {'rows': pred_dicts, 'explanations': explain_arr}
                    }
                else:
                    groups = set()
                    for row in pred_dicts:
                        groups.add(tuple([row[x] for x in group_by]))

                    # split rows by groups
                    rows_by_groups = {}
                    for group in groups:
                        rows_by_groups[group] = {'rows': [], 'explanations': []}
                        for row_index, row in enumerate(pred_dicts):
                            is_wrong_group = False
                            for i, group_by_key in enumerate(group_by):
                                if row[group_by_key] != group[i]:
                                    is_wrong_group = True
                                    break
                            if not is_wrong_group:
                                rows_by_groups[group]['rows'].append(row)
                                rows_by_groups[group]['explanations'].append(
                                    explain_arr[row_index]
                                )
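
                # rows_by_groups now maps each group-by value tuple (or the empty tuple
                # when there is no grouping) to its own 'rows' and the parallel
                # 'explanations' list.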
                for group, data in rows_by_groups.items():
                    rows = data['rows']
                    explanations = data['explanations']

                    if len(rows) == 0:
                        break

                    for row in rows:
                        predictions = row[target]
                        if isinstance(predictions, list) is False:
                            predictions = [predictions]

                        date_values = row[order_by_column]
                        if isinstance(date_values, list) is False:
                            date_values = [date_values]

                    if pred_args.get('force_ts_infer') is True:
                        # The last row contains one additional prediction (used for cases like date > '2020-10-10').
                        # Extract that prediction and join it to the previous row.
                        rows[-2][order_by_column] = rows[-2][order_by_column].copy()
                        rows[-2][target] = rows[-2][target].copy()

                        rows[-2][order_by_column].append(rows[-1][order_by_column][-1])
                        rows[-2][target].append(rows[-1][target][-1])
                        for col in (
                            'predicted_value',
                            'confidence',
                            'confidence_lower_bound',
                            'confidence_upper_bound',
                        ):
                            explanations[-2][target][col].append(
                                explanations[-1][target][col][-1]
                            )
                        rows.pop()
                        explanations.pop()
                        # horizon = horizon + 1

                    for i in range(len(rows) - 1):
                        row_horizon = len(rows[i][target])
                        if row_horizon > 1:
                            rows[i][target] = rows[i][target][0]
                            if isinstance(rows[i][order_by_column], list):
                                rows[i][order_by_column] = rows[i][order_by_column][0]
                        for col in (
                            'predicted_value',
                            'confidence',
                            'confidence_lower_bound',
                            'confidence_upper_bound',
                        ):
                            if row_horizon > 1 and col in explanations[i][target]:
                                explanations[i][target][col] = explanations[i][target][
                                    col
                                ][0]

                    last_row = rows.pop()
                    last_explanation = explanations.pop()
                    for i in range(len(last_row[target])):
                        new_row = copy.deepcopy(last_row)
                        new_row[target] = new_row[target][i]
                        if isinstance(new_row[order_by_column], list):
                            new_row[order_by_column] = new_row[order_by_column][i]
                        if '__mindsdb_row_id' in new_row and (i > 0 or forecast_offset):
                            new_row['__mindsdb_row_id'] = None

                        new_explanation = copy.deepcopy(last_explanation)
                        for col in (
                            'predicted_value',
                            'confidence',
                            'confidence_lower_bound',
                            'confidence_upper_bound',
                        ):
                            if col in new_explanation[target]:
                                new_explanation[target][col] = new_explanation[target][
                                    col
                                ][i]
                        if i != 0:
                            new_explanation[target]['anomaly'] = None
                            new_explanation[target]['truth'] = None

                        rows.append(new_row)
                        explanations.append(new_explanation)
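
                # Net effect: within every group, each row except the last keeps only its
                # first forecast step, while the last row has been fanned out into one row
                # per step of the forecast horizon.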
                pred_dicts = []
                explanations = []
                for group, data in rows_by_groups.items():
                    pred_dicts.extend(data['rows'])
                    explanations.extend(data['explanations'])

                original_target_values[f'{target}_original'] = []
                for i in range(len(pred_dicts)):
                    original_target_values[f'{target}_original'].append(
                        explanations[i][target].get('truth', None)
                    )

                if dtype_dict[order_by_column] == dtype.date:
                    for row in pred_dicts:
                        if isinstance(row[order_by_column], (int, float)):
                            row[order_by_column] = datetime.fromtimestamp(
                                row[order_by_column]
                            ).date()
                elif dtype_dict[order_by_column] == dtype.datetime:
                    for row in pred_dicts:
                        if isinstance(row[order_by_column], (int, float)):
                            row[order_by_column] = datetime.fromtimestamp(
                                row[order_by_column]
                            )

                explain_arr = explanations
            # endregion

            if pred_format == 'explain':
                return explain_arr

            keys = [x for x in pred_dicts[0] if x in columns]
            min_max_keys = []
            for col in predicted_columns:
                if dtype_dict[col] in (dtype.integer, dtype.float, dtype.num_tsarray):
                    min_max_keys.append(col)

            data = []
            explains = []
            keys_to_save = [*keys, '__mindsdb_row_id', 'select_data_query', 'when_data']
            for i, el in enumerate(pred_dicts):
                data.append({key: el.get(key) for key in keys_to_save})
                explains.append(explain_arr[i])
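
            # keys_to_save keeps the model's own columns plus MindsDB service columns
            # ('__mindsdb_row_id', 'select_data_query', 'when_data') that must survive
            # into the returned DataFrame.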
            for i, row in enumerate(data):
                cast_row_types(row, dtype_dict)

                for k in original_target_values:
                    try:
                        row[k] = original_target_values[k][i]
                    except Exception:
                        row[k] = None

                for column_name in columns:
                    if column_name not in row:
                        row[column_name] = None

                explanation = explains[i]
                for key in predicted_columns:
                    row[key + '_confidence'] = explanation[key]['confidence']
                    row[key + '_explain'] = json.dumps(
                        explanation[key], cls=NumpyJSONEncoder, ensure_ascii=False
                    )
                    if 'anomaly' in explanation[key]:
                        row[key + '_anomaly'] = explanation[key]['anomaly']
                for key in min_max_keys:
                    if 'confidence_lower_bound' in explanation[key]:
                        row[key + '_min'] = explanation[key]['confidence_lower_bound']
                    if 'confidence_upper_bound' in explanation[key]:
                        row[key + '_max'] = explanation[key]['confidence_upper_bound']

            return pd.DataFrame(data)

    def edit_json_ai(self, name: str, json_ai: dict):
        predictor_record = get_model_record(name=name, ml_handler_name='lightwood')
        assert predictor_record is not None

        json_ai = lightwood.JsonAI.from_dict(json_ai)
        predictor_record.code = lightwood.code_from_json_ai(json_ai)
        db.session.commit()

        json_storage = get_json_storage(resource_id=predictor_record.id)
        json_storage.set('json_ai', json_ai.to_dict())

    def code_from_json_ai(self, json_ai: dict):
        json_ai = lightwood.JsonAI.from_dict(json_ai)
        code = lightwood.code_from_json_ai(json_ai)
        return code

    def edit_code(self, name: str, code: str):
        """Edit an existing predictor's code"""
        if self.config.get('cloud', False):
            raise Exception('Code editing prohibited on cloud')

        predictor_record = get_model_record(name=name, ml_handler_name='lightwood')
        assert predictor_record is not None

        lightwood.predictor_from_code(code)
        predictor_record.code = code
        db.session.commit()

        json_storage = get_json_storage(resource_id=predictor_record.id)
        json_storage.delete('json_ai')

    def _get_features_info(self):
        ai_info = self.model_storage.json_get('json_ai')
        if ai_info == {}:
            raise Exception(
                "predictor doesn't contain enough data to generate 'feature' attribute."
            )
        data = []
        dtype_dict = ai_info["dtype_dict"]
        for column in dtype_dict:
            c_data = []
            c_data.append(column)
            c_data.append(dtype_dict[column])
            c_data.append(ai_info["encoders"][column]["module"])
            if ai_info["encoders"][column]["args"].get("is_target", "False") == "True":
                c_data.append("target")
            else:
                c_data.append("feature")
            data.append(c_data)

        return pd.DataFrame(data, columns=['column', 'type', 'encoder', 'role'])

    def _get_model_info(self):
        json_ai = self.model_storage.json_get('json_ai')
        model_info = self.model_storage.get_info()
        model_data = model_info['data']

        accuracy_functions = json_ai.get('accuracy_functions')
        if accuracy_functions:
            accuracy_functions = str(accuracy_functions)

        models_data = model_data.get("submodel_data", [])
        if models_data == []:
            raise Exception(
                "predictor doesn't contain enough data to generate 'model' attribute"
            )
        data = []

        for model in models_data:
            m_data = []
            m_data.append(model["name"])
            m_data.append(model["accuracy"])
            m_data.append(model.get("training_time", "unknown"))
            m_data.append(1 if model["is_best"] else 0)
            m_data.append(accuracy_functions)
            data.append(m_data)

        return pd.DataFrame(
            data,
            columns=[
                'name',
                'performance',
                'training_time',
                'selected',
                'accuracy_functions',
            ],
        )

    def _get_ensemble_data(self):
        ai_info = self.model_storage.json_get('json_ai')
        if ai_info == {}:
            raise Exception(
                "predictor doesn't contain enough data to generate 'ensemble' attribute. Please wait until the predictor is complete."
            )
        ai_info_str = json.dumps(ai_info, indent=2)

        return pd.DataFrame([[ai_info_str]], columns=['ensemble'])

    def _get_progress_data(self):
        progress_info = self.model_storage.training_state_get()
        return pd.DataFrame([progress_info], columns=["current", "total", "name"])

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:

        if attribute == 'info':

            model_description = {}

            model_info = self.model_storage.get_info()
            model_data = model_info['data']
            to_predict = model_info['to_predict'][0]

            if model_data.get('accuracies', None) is not None:
                if len(model_data['accuracies']) > 0:
                    model_data['accuracy'] = float(
                        np.mean(list(model_data['accuracies'].values()))
                    )

            model_columns = self.model_storage.columns_get()

            model_description['accuracies'] = model_data['accuracies']
            model_description['column_importances'] = model_data['column_importances']
            model_description['outputs'] = [to_predict]
            model_description['inputs'] = [
                col for col in model_columns if col not in model_description['outputs']
            ]

            return pd.DataFrame([model_description])

        elif attribute == "features":
            return self._get_features_info()

        elif attribute == "model":
            return self._get_model_info()

        elif attribute == "jsonai":
            return self._get_ensemble_data()

        elif attribute == "progress":
            # todo remove?
            return self._get_progress_data()

        else:
            tables = ['info', 'features', 'model', 'jsonai']
            return pd.DataFrame(tables, columns=['tables'])