Coverage for mindsdb / integrations / handlers / neuralforecast_handler / neuralforecast_handler.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from sklearn.metrics import r2_score 

2import dill 

3import pandas as pd 

4import tempfile 

5from mindsdb.integrations.libs.base import BaseMLEngine 

6from mindsdb.integrations.utilities.time_series_utils import ( 

7 transform_to_nixtla_df, 

8 get_results_from_nixtla_df, 

9 infer_frequency, 

10 get_model_accuracy_dict, 

11 get_hierarchy_from_df, 

12 reconcile_forecasts 

13) 

14from neuralforecast import NeuralForecast 

15from neuralforecast.models import NHITS 

16from neuralforecast.auto import AutoNHITS 

17from ray.tune.search.hyperopt import HyperOptSearch 

18 

19# hierarchicalforecast is an optional dependency 

20try: 

21 from hierarchicalforecast.core import HierarchicalReconciliation 

22except ImportError: 

23 HierarchicalReconciliation = None 

24 

25 

class NeuralForecastHandler(BaseMLEngine):
    """Integration with the Nixtla NeuralForecast library for
    time series forecasting with neural networks.

    ``create`` trains a single NHITS model (or an ``AutoNHITS`` model with a
    HyperOpt search when ``n_auto_trials`` is set). ``predict`` reloads it and
    forecasts the configured horizon.
    """

    name = "neuralforecast"

    @staticmethod
    def _get_model_name(model_args):
        """Return the name of the NeuralForecast model class used for these args."""
        return "AutoNHITS" if model_args.get("n_auto_trials") else "NHITS"

    def create(self, target, df, args=None):
        """Create the NeuralForecast Handler.

        Requires specifying the target column to predict and time series arguments for
        prediction horizon, time column (order by) and grouping column(s).

        Args:
            target: name of the column to forecast.
            df: training data.
            args: engine arguments. Reads ``args["timeseries_settings"]``
                (``is_timeseries``, ``horizon``, ``order_by``, ``group_by``,
                ``window``) and the optional ``args["using"]`` options
                ``frequency``, ``exogenous_vars``, ``max_steps``,
                ``val_check_steps``, ``n_auto_trials``, ``hierarchy`` and
                ``crossval``.

        Raises:
            ValueError: if the query did not provide time series settings.

        Saves the fitted model to disk and the model params to model storage;
        both are read later by predict().
        """
        args = {} if args is None else args
        time_settings = args.get("timeseries_settings") or {}
        if not time_settings.get("is_timeseries"):
            raise ValueError("Specify time series settings in your query")
        using_args = args.get("using") or {}

        # Store model args and time series settings in the model storage.
        model_args = {
            "target": target,
            "horizon": time_settings["horizon"],
            "order_by": time_settings["order_by"],
            "group_by": time_settings["group_by"],
            # Only infer the frequency from the data when none was supplied.
            "frequency": (
                using_args["frequency"]
                if "frequency" in using_args
                else infer_frequency(df, time_settings["order_by"])
            ),
            "exog_vars": using_args.get("exogenous_vars", []),
            "max_steps": using_args.get("max_steps", 20),
            "val_check_steps": using_args.get("val_check_steps", 10),
            "n_auto_trials": using_args.get("n_auto_trials", 0),
            "crossval": using_args.get("crossval", False),
            # NOTE(review): the fitted network is written to a local temp dir,
            # not into model_storage; confirm it survives restarts/other nodes.
            "model_folder": tempfile.mkdtemp(),
            "hierarchy": using_args.get("hierarchy", False),
        }
        model_args["model_name"] = self._get_model_name(model_args)

        # Deal with hierarchy. Reconciliation needs the optional
        # hierarchicalforecast package; without it the flat path is used.
        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            training_df, hier_df, hier_dict = get_hierarchy_from_df(df, model_args)
            self.model_storage.file_set("hier_dict", dill.dumps(hier_dict))
            self.model_storage.file_set("hier_df", dill.dumps(hier_df))
            self.model_storage.file_set("training_df", dill.dumps(training_df))
        else:
            training_df = transform_to_nixtla_df(df, model_args, model_args["exog_vars"])

        # Train model
        if model_args["n_auto_trials"]:
            # Hyperparameter search over `n_auto_trials` sampled configurations.
            model = AutoNHITS(
                time_settings["horizon"],
                gpus=0,
                num_samples=model_args["n_auto_trials"],
                search_alg=HyperOptSearch(),
            )
        else:
            # Faster implementation without auto parameter tuning.
            model = NHITS(
                time_settings["horizon"],
                time_settings["window"],
                hist_exog_list=model_args["exog_vars"],
                max_steps=model_args["max_steps"],
                val_check_steps=model_args["val_check_steps"],
            )
        neural = NeuralForecast(models=[model], freq=model_args["frequency"])

        if model_args["crossval"]:
            # Opt-in (slow): backtest forecasts are scored with R^2 and reported
            # by describe('info'). Presumably cross_validation also leaves the
            # models fitted for the save below; verify against the library docs.
            results_df = neural.cross_validation(training_df)
            model_args["accuracies"] = get_model_accuracy_dict(results_df, r2_score)
        else:
            neural.fit(training_df)

        # Persist changes to handler folder and model storage.
        neural.save(model_args["model_folder"], overwrite=True)
        self.model_storage.json_set("model_args", model_args)

    def predict(self, df, args=None):
        """Makes forecasts with the NeuralForecast Handler.

        NeuralForecast is set up to predict for all groups, so it won't handle
        a dataframe that's been filtered to one group very well. Instead, we make
        the prediction for all groups, then keep only the groups present in
        ``df`` after forecasting. Prediction is nearly instant.

        Args:
            df: rows whose groups (``unique_id`` after transformation) select
                which forecasts are returned.
            args: unused; accepted for interface compatibility.
        """
        # Load model arguments
        model_args = self.model_storage.json_get("model_args")

        prediction_df = transform_to_nixtla_df(df, model_args)
        groups_to_keep = prediction_df["unique_id"].unique()

        neural = NeuralForecast.load(model_args["model_folder"])
        forecast_df = neural.predict()
        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            training_df = dill.loads(self.model_storage.file_get("training_df"))
            hier_df = dill.loads(self.model_storage.file_get("hier_df"))
            hier_dict = dill.loads(self.model_storage.file_get("hier_dict"))
            reconciled_df = reconcile_forecasts(training_df, forecast_df, hier_df, hier_dict)
            results_df = reconciled_df[reconciled_df.index.isin(groups_to_keep)]
        else:
            # Forecast columns are named after the model; map whichever one is
            # present onto the target column (missing keys are ignored).
            results_df = forecast_df[forecast_df.index.isin(groups_to_keep)].rename(
                {
                    "y": model_args["target"],
                    "NHITS": model_args["target"],  # non-auto mode
                    "AutoNHITS": model_args["target"],  # auto mode
                },
                axis=1,
            )
        return get_results_from_nixtla_df(results_df, model_args)

    def describe(self, attribute=None):
        """Describe the trained model.

        Args:
            attribute: one of ``"model"``, ``"features"`` or ``"info"``; any
                other value returns the list of available tables.

        Returns:
            A one-row ``pd.DataFrame`` for the requested attribute.
        """
        model_args = self.model_storage.json_get("model_args")

        if attribute == "model":
            # Models trained before "model_name" was stored lack the key.
            model_name = model_args.get("model_name", self._get_model_name(model_args))
            return pd.DataFrame(
                {
                    "model_name": [model_name],
                    "frequency": [model_args["frequency"]],
                    "hierarchy": [model_args["hierarchy"]],
                }
            )

        elif attribute == "features":
            return pd.DataFrame(
                {
                    "ds": [model_args["order_by"]],
                    "y": model_args["target"],
                    "unique_id": [model_args["group_by"]],
                    "exog_vars": [model_args["exog_vars"]],
                }
            )

        elif attribute == "info":
            outputs = model_args["target"]
            inputs = [model_args["target"], model_args["order_by"], model_args["group_by"]] + model_args["exog_vars"]
            accuracies = list(model_args.get("accuracies", {}).items())
            return pd.DataFrame({"accuracies": [accuracies], "outputs": outputs, "inputs": [inputs]})

        else:
            tables = ["info", "features", "model"]
            return pd.DataFrame(tables, columns=["tables"])