Coverage for mindsdb / integrations / handlers / statsforecast_handler / statsforecast_handler.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import pandas as pd 

2import dill 

3from mindsdb.integrations.libs.base import BaseMLEngine 

4from mindsdb.integrations.utilities.time_series_utils import ( 

5 transform_to_nixtla_df, 

6 get_results_from_nixtla_df, 

7 infer_frequency, 

8 get_best_model_from_results_df, 

9 get_model_accuracy_dict, 

10 reconcile_forecasts, 

11 get_hierarchy_from_df 

12) 

13from sklearn.metrics import r2_score 

14from statsforecast import StatsForecast 

15from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta 

16 

17# hierarchicalforecast is an optional dependency 

18try: 

19 from hierarchicalforecast.core import HierarchicalReconciliation 

20except ImportError: 

21 HierarchicalReconciliation = None 

22 

# Model used when the user does not pass `model_name` in the USING clause.
DEFAULT_MODEL_NAME = "AutoARIMA"

# StatsForecast model classes the handler supports, keyed by the name a user
# passes as `model_name`. In "auto" mode every value is instantiated, in this
# iteration order.
model_dict = dict(
    AutoARIMA=AutoARIMA,
    AutoCES=AutoCES,
    AutoETS=AutoETS,
    AutoTheta=AutoTheta,
)

30 

31 

def get_season_length(frequency):
    """Infer a sensible season length from a pandas frequency alias.

    We set a sensible default for seasonality based on the frequency
    parameter. For example, monthly data is assumed to have a season length
    of 12 (months in a year) and hourly data a season length of 24 (hours in
    a day). Anchored aliases such as ``"Q-DEC"`` or ``"W-SUN"`` are reduced
    to their base alias (``"Q"``, ``"W"``) before the lookup.

    Both the legacy aliases (``"H"``, ``"M"``, ``"Q"``, ...) and the ones
    introduced in pandas 2.2 (``"h"``, ``"ME"``, ``"QE"``, ...) are
    recognised. Frequencies inferred from data may use either spelling,
    depending on the installed pandas version.

    Args:
        frequency: pandas offset alias, e.g. ``"MS"`` or ``"QE-DEC"``.

    Returns:
        int: the season length; 1 (no seasonality) when the alias is unknown
        or ``frequency`` is empty/None.
    """
    if not frequency:
        return 1

    season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
        # hourly data -> daily cycle
        "H": 24,
        "h": 24,
        "BH": 24,
        "bh": 24,
        # monthly data -> yearly cycle
        "M": 12,
        "ME": 12,
        "MS": 12,
        "BM": 12,
        "BME": 12,
        "BMS": 12,
        # semi-monthly data -> yearly cycle
        "SM": 24,
        "SME": 24,
        "SMS": 24,
        # quarterly data -> yearly cycle
        "Q": 4,
        "QE": 4,
        "QS": 4,
        "BQ": 4,
        "BQE": 4,
        "BQS": 4,
    }
    # str.split returns the whole string when there is no "-", so this strips
    # anchors like "-DEC" and leaves un-anchored aliases untouched.
    base_freq = frequency.split("-")[0]
    return season_dict.get(base_freq, 1)

53 

54 

def get_insample_cv_results(model_args, df):
    """Run in-sample cross validation for the configured model(s).

    When ``model_args["model_name"]`` is ``"auto"`` every model in
    ``model_dict`` is evaluated; otherwise only the named one is. The season
    length is ``model_args["season_length"]`` when that is set (truthy) and is
    otherwise inferred from ``model_args["frequency"]``.

    Args:
        model_args: handler settings; reads ``frequency``, ``model_name``,
            ``horizon`` and, optionally, ``season_length``.
        df: training dataframe handed to ``StatsForecast.cross_validation``.

    Returns:
        The cross-validation fitted values, with a ``"CES"`` column (if any)
        renamed to ``"AutoCES"``.
    """
    season_length = model_args.get("season_length") or get_season_length(model_args["frequency"])

    chosen_name = model_args["model_name"]
    if chosen_name == "auto":
        model_classes = list(model_dict.values())
    else:
        model_classes = [model_dict[chosen_name]]
    candidates = [cls(season_length=season_length) for cls in model_classes]

    forecaster = StatsForecast(candidates, model_args["frequency"])
    forecaster.cross_validation(model_args["horizon"], df, fitted=True)
    fitted_values = forecaster.cross_validation_fitted_values()
    # Nixtla labels AutoCES output as "CES" (a Nixtla bug); map it back so
    # column names match the keys of model_dict.
    return fitted_values.rename({"CES": "AutoCES"}, axis=1)

67 

68 

def choose_model(model_args, results_df):
    """Instantiate the StatsForecast model described by ``model_args``.

    If ``model_args["model_name"]`` is ``"auto"``, it is replaced with the
    name returned by ``get_best_model_from_results_df`` for the in-sample
    cross-validation results. ``model_args["season_length"]`` is also filled
    in: the existing value if truthy, otherwise one inferred from
    ``model_args["frequency"]``.

    Note: ``model_args`` is modified in place.

    Args:
        model_args: handler settings dictionary (mutated).
        results_df: in-sample cross-validation results.

    Returns:
        A new instance of the selected model class.
    """
    if model_args["model_name"] == "auto":
        best_name = get_best_model_from_results_df(results_df)
        model_args["model_name"] = best_name

    season_length = model_args.get("season_length") or get_season_length(model_args["frequency"])
    model_args["season_length"] = season_length

    model_class = model_dict[model_args["model_name"]]
    return model_class(season_length=season_length)

81 

82 

class StatsForecastHandler(BaseMLEngine):
    """Integration with the Nixtla StatsForecast library for
    time series forecasting with classical methods.
    """

    name = "statsforecast"

    def create(self, target, df, args=None):
        """Create the StatsForecast Handler.

        Requires specifying the target column to predict and time series arguments for
        prediction horizon, time column (order by) and grouping column(s).

        Saves args, models params, and the formatted training df to disk. The training df
        is used later in the predict() method.

        Args:
            target: name of the column to forecast.
            df: training data.
            args: dict with a ``"timeseries_settings"`` entry (reads
                ``is_timeseries``, ``horizon``, ``order_by`` and optionally
                ``group_by``) and a ``"using"`` entry (user options; reads
                ``frequency``, ``hierarchy`` and ``model_name`` when present).

        Raises:
            ValueError: if the time series settings are not enabled.
        """
        # Avoid a shared mutable default; a missing dict still fails on the
        # key lookups below, as before.
        args = {} if args is None else args
        time_settings = args["timeseries_settings"]
        using_args = args["using"]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not time_settings["is_timeseries"]:
            raise ValueError("Specify time series settings in your query")

        # store model args and time series settings in the model folder
        model_args = dict(using_args)
        model_args["target"] = target
        model_args["horizon"] = time_settings["horizon"]
        model_args["order_by"] = time_settings["order_by"]
        if "group_by" not in time_settings:
            # No grouping requested: register a placeholder group column.
            # NOTE(review): this writes back into the caller's settings dict.
            time_settings["group_by"] = ["__group_by"]
        model_args["group_by"] = time_settings["group_by"]

        if "frequency" in using_args:
            model_args["frequency"] = using_args["frequency"]
        else:
            model_args["frequency"] = infer_frequency(df, time_settings["order_by"])

        model_args["hierarchy"] = using_args.get("hierarchy", False)
        # hierarchicalforecast is optional: without it the hierarchy option is
        # ignored and no reconciliation is performed at predict time.
        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            training_df, hier_df, hier_dict = get_hierarchy_from_df(df, model_args)
            self.model_storage.file_set("hier_dict", dill.dumps(hier_dict))
            self.model_storage.file_set("hier_df", dill.dumps(hier_df))
        else:
            training_df = transform_to_nixtla_df(df, model_args)

        model_args["model_name"] = using_args.get("model_name", DEFAULT_MODEL_NAME)

        # Score the candidate model(s) in-sample; choose_model() then resolves
        # "auto" to the best one and records the season length in model_args.
        results_df = get_insample_cv_results(model_args, training_df)
        model_args["accuracies"] = get_model_accuracy_dict(results_df, r2_score)
        model = choose_model(model_args, results_df)
        sf = StatsForecast([model], freq=model_args["frequency"], df=training_df)
        fitted_models = sf.fit().fitted_

        # persist changes to handler folder
        self.model_storage.json_set("model_args", model_args)
        self.model_storage.file_set("training_df", dill.dumps(training_df))
        self.model_storage.file_set("fitted_models", dill.dumps(fitted_models))

    def predict(self, df, args=None):
        """Makes forecasts with the StatsForecast Handler.

        StatsForecast is setup to predict for all groups, so it won't handle
        a dataframe that's been filtered to one group very well. Instead, we make
        the prediction for all groups then take care of the filtering after the
        forecasting. Prediction is nearly instant.

        Args:
            df: rows identifying which groups to return; only their group ids
                are used.
            args: unused; accepted for interface compatibility.

        Returns:
            The forecasts for the requested groups, with the model's forecast
            column renamed to the target column.
        """
        # Load model arguments
        model_args = self.model_storage.json_get("model_args")
        training_df = dill.loads(self.model_storage.file_get("training_df"))
        fitted_models = dill.loads(self.model_storage.file_get("fitted_models"))

        prediction_df = transform_to_nixtla_df(df, model_args)
        groups_to_keep = prediction_df["unique_id"].unique()

        # Rebuild a StatsForecast around the already-fitted models instead of refitting.
        sf = StatsForecast(models=[], freq=model_args["frequency"], df=training_df)
        sf.fitted_ = fitted_models
        # Assumes the forecast column is named after str(model); it is renamed
        # to the target column below.
        model_name = str(fitted_models[0][0])
        forecast_df = sf.predict(model_args["horizon"])
        forecast_df.index = forecast_df.index.astype(str)

        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            hier_df = dill.loads(self.model_storage.file_get("hier_df"))
            hier_dict = dill.loads(self.model_storage.file_get("hier_dict"))
            forecast_df = reconcile_forecasts(training_df, forecast_df, hier_df, hier_dict)

        results_df = forecast_df[forecast_df.index.isin(groups_to_keep)]
        result = get_results_from_nixtla_df(results_df, model_args)
        return result.rename(columns={model_name: model_args["target"]})

    def describe(self, attribute=None):
        """Describe the trained model.

        Args:
            attribute: ``"model"``, ``"features"`` or ``"info"``; any other
                value (including None) lists the available attributes.

        Returns:
            pd.DataFrame describing the requested attribute.
        """
        model_args = self.model_storage.json_get("model_args")

        if attribute == "model":
            return pd.DataFrame({k: [model_args[k]] for k in ["model_name", "frequency", "season_length", "hierarchy"]})

        elif attribute == "features":
            return pd.DataFrame(
                {"ds": [model_args["order_by"]], "y": model_args["target"], "unique_id": [model_args["group_by"]]}
            )

        elif attribute == "info":
            outputs = model_args["target"]
            inputs = [model_args["target"], model_args["order_by"], model_args["group_by"]]
            accuracies = list(model_args["accuracies"].items())
            return pd.DataFrame({"accuracies": [accuracies], "outputs": outputs, "inputs": [inputs]})

        else:
            tables = ["info", "features", "model"]
            return pd.DataFrame(tables, columns=["tables"])