Coverage for mindsdb / integrations / handlers / statsforecast_handler / statsforecast_handler.py: 0%
96 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import pandas as pd
2import dill
3from mindsdb.integrations.libs.base import BaseMLEngine
4from mindsdb.integrations.utilities.time_series_utils import (
5 transform_to_nixtla_df,
6 get_results_from_nixtla_df,
7 infer_frequency,
8 get_best_model_from_results_df,
9 get_model_accuracy_dict,
10 reconcile_forecasts,
11 get_hierarchy_from_df
12)
13from sklearn.metrics import r2_score
14from statsforecast import StatsForecast
15from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta
17# hierarchicalforecast is an optional dependency
18try:
19 from hierarchicalforecast.core import HierarchicalReconciliation
20except ImportError:
21 HierarchicalReconciliation = None
# Model fitted when the USING clause does not specify ``model_name``.
DEFAULT_MODEL_NAME = "AutoARIMA"

# Maps user-facing ``model_name`` values to StatsForecast model classes.
# With ``model_name = "auto"`` every entry is cross-validated, in this order.
model_dict = dict(
    AutoARIMA=AutoARIMA,
    AutoCES=AutoCES,
    AutoETS=AutoETS,
    AutoTheta=AutoTheta,
)
def get_season_length(frequency):
    """Infer a sensible default season length from a pandas frequency alias.

    We assume a natural cycle for each frequency: for example, monthly data
    has a season length of 12 (months in a year) and hourly data a season
    length of 24 (hours in a day). Anchored aliases such as ``"Q-DEC"`` are
    reduced to their base alias (``"Q"``) before the lookup.

    Both the legacy aliases (``"H"``, ``"M"``, ``"Q"``, ...) and the ones that
    pandas >= 2.2 returns from frequency inference (``"h"``, ``"ME"``,
    ``"QE"``, ...) are recognised.

    Args:
        frequency: pandas offset alias, e.g. ``"MS"`` or ``"QE-DEC"``.

    Returns:
        int: the season length, or 1 (no seasonality) when the alias is not
        recognised or ``frequency`` is empty/None.
    """
    # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
    season_dict = {
        # hourly -> daily cycle (legacy and pandas >= 2.2 spellings)
        "H": 24, "h": 24, "BH": 24, "bh": 24,
        # semi-monthly -> yearly cycle
        "SM": 24, "SME": 24,
        # monthly -> yearly cycle
        "M": 12, "ME": 12, "MS": 12, "BM": 12, "BME": 12, "BMS": 12,
        # quarterly -> yearly cycle
        "Q": 4, "QE": 4, "QS": 4, "BQ": 4, "BQE": 4, "BQS": 4,
    }
    if not frequency:
        # pandas may fail to infer a frequency; treat that as "no seasonality".
        return 1
    # Shorten anchored aliases like "Q-DEC" to their base alias "Q".
    base_frequency = frequency.split("-")[0]
    return season_dict.get(base_frequency, 1)
def get_insample_cv_results(model_args, df):
    """Run in-sample cross validation for the candidate model(s).

    When ``model_args["model_name"]`` is ``"auto"`` every model in
    ``model_dict`` is evaluated; otherwise only the named model is. The season
    length is ``model_args["season_length"]`` when that is set to a truthy
    value, and is otherwise inferred from ``model_args["frequency"]``.

    Args:
        model_args: handler settings; reads ``model_name``, ``frequency``,
            ``horizon`` and, optionally, ``season_length``.
        df: training data (presumably already in the Nixtla
            unique_id/ds/y layout built by the handler).

    Returns:
        The cross-validation fitted values, with the "CES" column renamed to
        "AutoCES".
    """
    season_length = model_args.get("season_length") or get_season_length(model_args["frequency"])

    if model_args["model_name"] == "auto":
        candidate_classes = list(model_dict.values())
    else:
        candidate_classes = [model_dict[model_args["model_name"]]]
    candidates = [cls(season_length=season_length) for cls in candidate_classes]

    forecaster = StatsForecast(candidates, model_args["frequency"])
    forecaster.cross_validation(model_args["horizon"], df, fitted=True)
    fitted_values = forecaster.cross_validation_fitted_values()
    # Nixtla labels AutoCES output "CES"; align it with the model_dict key.
    return fitted_values.rename({"CES": "AutoCES"}, axis=1)
def choose_model(model_args, results_df):
    """Pick the StatsForecast model to fit and return a fresh instance of it.

    This mutates ``model_args`` in place:

    * if ``model_name`` is ``"auto"``, it is replaced by the model that
      ``get_best_model_from_results_df`` selects from the in-sample cross
      validation results;
    * ``season_length`` is always written. A truthy user-supplied value is
      kept; otherwise the length is inferred from ``frequency``.

    Args:
        model_args: handler settings dictionary (modified in place).
        results_df: in-sample cross validation results.

    Returns:
        An instance of the chosen model class, built with the resolved
        ``season_length``.
    """
    if model_args["model_name"] == "auto":
        model_args["model_name"] = get_best_model_from_results_df(results_df)

    explicit_length = model_args.get("season_length")
    model_args["season_length"] = explicit_length or get_season_length(model_args["frequency"])

    chosen_cls = model_dict[model_args["model_name"]]
    return chosen_cls(season_length=model_args["season_length"])
class StatsForecastHandler(BaseMLEngine):
    """Integration with the Nixtla StatsForecast library for
    time series forecasting with classical methods.
    """

    name = "statsforecast"

    def create(self, target, df, args=None):
        """Create the StatsForecast Handler.

        Requires specifying the target column to predict and time series arguments for
        prediction horizon, time column (order by) and grouping column(s).

        Saves args, models params, and the formatted training df to disk. The training df
        is used later in the predict() method.

        Args:
            target: name of the column to forecast.
            df: training data.
            args: dict with a ``"timeseries_settings"`` entry (reads
                ``is_timeseries``, ``horizon``, ``order_by`` and optionally
                ``group_by``) and a ``"using"`` entry with the USING-clause
                parameters (optionally ``frequency``, ``hierarchy``,
                ``model_name``, ``season_length``).
        """
        # None instead of a shared mutable ``{}`` default.
        args = {} if args is None else args
        time_settings = args["timeseries_settings"]
        using_args = args["using"]
        assert time_settings["is_timeseries"], "Specify time series settings in your query"

        # store model args and time series settings in the model folder
        model_args = dict(using_args)
        model_args["target"] = target
        model_args["horizon"] = time_settings["horizon"]
        model_args["order_by"] = time_settings["order_by"]
        if "group_by" not in time_settings:
            # No grouping requested: fall back to a placeholder group column.
            # This also writes the placeholder into the caller's
            # timeseries_settings (kept for compatibility).
            time_settings["group_by"] = ["__group_by"]
        model_args["group_by"] = time_settings["group_by"]

        if "frequency" in using_args:
            model_args["frequency"] = using_args["frequency"]
        else:
            # Only infer when the user did not specify a frequency.
            model_args["frequency"] = infer_frequency(df, time_settings["order_by"])
        model_args["hierarchy"] = using_args.get("hierarchy", False)

        # Hierarchical reconciliation is only used when the optional
        # hierarchicalforecast package could be imported.
        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            training_df, hier_df, hier_dict = get_hierarchy_from_df(df, model_args)
            self.model_storage.file_set("hier_dict", dill.dumps(hier_dict))
            self.model_storage.file_set("hier_df", dill.dumps(hier_df))
        else:
            training_df = transform_to_nixtla_df(df, model_args)

        model_args["model_name"] = using_args.get("model_name", DEFAULT_MODEL_NAME)

        # Cross-validate the candidate model(s). choose_model() replaces an
        # "auto" model_name with the best performer and fills in season_length.
        results_df = get_insample_cv_results(model_args, training_df)
        model_args["accuracies"] = get_model_accuracy_dict(results_df, r2_score)
        model = choose_model(model_args, results_df)
        sf = StatsForecast([model], freq=model_args["frequency"], df=training_df)
        fitted_models = sf.fit().fitted_

        # persist changes to handler folder
        self.model_storage.json_set("model_args", model_args)
        self.model_storage.file_set("training_df", dill.dumps(training_df))
        self.model_storage.file_set("fitted_models", dill.dumps(fitted_models))

    def predict(self, df, args=None):
        """Makes forecasts with the StatsForecast Handler.

        StatsForecast is setup to predict for all groups, so it won't handle
        a dataframe that's been filtered to one group very well. Instead, we make
        the prediction for all groups then take care of the filtering after the
        forecasting. Prediction is nearly instant.

        Args:
            df: rows whose groups should be forecast.
            args: unused.

        Returns:
            The forecasts for the requested groups. The model's output column
            is renamed to the target column.
        """
        # Load model arguments and the artefacts saved by create()
        model_args = self.model_storage.json_get("model_args")
        training_df = dill.loads(self.model_storage.file_get("training_df"))
        fitted_models = dill.loads(self.model_storage.file_get("fitted_models"))

        prediction_df = transform_to_nixtla_df(df, model_args)
        groups_to_keep = prediction_df["unique_id"].unique()

        # Reuse the models fitted in create() rather than refitting.
        sf = StatsForecast(models=[], freq=model_args["frequency"], df=training_df)
        sf.fitted_ = fitted_models
        # The rename below assumes the forecast column is labelled str(model).
        model_name = str(fitted_models[0][0])
        forecast_df = sf.predict(model_args["horizon"])
        forecast_df.index = forecast_df.index.astype(str)

        if model_args["hierarchy"] and HierarchicalReconciliation is not None:
            hier_df = dill.loads(self.model_storage.file_get("hier_df"))
            hier_dict = dill.loads(self.model_storage.file_get("hier_dict"))
            forecast_df = reconcile_forecasts(training_df, forecast_df, hier_df, hier_dict)

        # Keep only the groups present in the incoming dataframe.
        results_df = forecast_df[forecast_df.index.isin(groups_to_keep)]
        result = get_results_from_nixtla_df(results_df, model_args)
        return result.rename(columns={model_name: model_args["target"]})

    def describe(self, attribute=None):
        """Describe the trained model.

        Args:
            attribute: ``"model"``, ``"features"`` or ``"info"``. Any other
                value (including None) returns the list of available tables.

        Returns:
            pd.DataFrame with the requested description.
        """
        model_args = self.model_storage.json_get("model_args")

        if attribute == "model":
            return pd.DataFrame({k: [model_args[k]] for k in ["model_name", "frequency", "season_length", "hierarchy"]})

        elif attribute == "features":
            return pd.DataFrame(
                {"ds": [model_args["order_by"]], "y": model_args["target"], "unique_id": [model_args["group_by"]]}
            )

        elif attribute == "info":
            outputs = model_args["target"]
            inputs = [model_args["target"], model_args["order_by"], model_args["group_by"]]
            accuracies = list(model_args["accuracies"].items())
            return pd.DataFrame({"accuracies": [accuracies], "outputs": outputs, "inputs": [inputs]})

        else:
            tables = ["info", "features", "model"]
            return pd.DataFrame(tables, columns=["tables"])