Coverage for mindsdb / integrations / handlers / timegpt_handler / timegpt_handler.py: 0%
115 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional, Dict
3import pandas as pd
4from nixtla import NixtlaClient
6from mindsdb.integrations.libs.base import BaseMLEngine
7from mindsdb.integrations.utilities.handler_utils import get_api_key
8from mindsdb.integrations.utilities.time_series_utils import get_results_from_nixtla_df
9# TODO: add E2E tests.
class TimeGPTHandler(BaseMLEngine):
    """
    Integration with the Nixtla TimeGPT models for
    zero-shot time series forecasting and anomaly detection.
    """

    name = "timegpt"

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Create the TimeGPT Handler.

        Requires specifying the target column and the usual time series arguments
        (horizon, order_by, optionally group_by). Validates the API key and saves
        the model config for later usage by `predict`.

        Args:
            target: name of the column to forecast.
            df: training data (unused here; TimeGPT is zero-shot).
            args: statement arguments; must contain a "using" dict and, for
                forecasting, "timeseries_settings".

        Raises:
            AssertionError: if time series settings are missing in forecasting
                mode, the API key is invalid, or `level` is malformed.
        """
        self.generative = True
        time_settings = args.get("timeseries_settings", {})
        using_args = args["using"]

        # MindsDB can request either plain forecasting or anomaly detection via
        # the internal `__mdb_sql_task` flag; default to forecasting.
        mode = 'forecasting'
        if args.get('__mdb_sql_task', False) and args['__mdb_sql_task'].lower() in ('forecasting', 'anomalydetection'):
            mode = args['__mdb_sql_task'].lower()

        if mode == 'forecasting':
            assert time_settings["is_timeseries"], "Specify time series settings in your query"

        timegpt_token = get_api_key('timegpt', using_args, self.engine_storage, strict=True)
        timegpt = NixtlaClient(api_key=timegpt_token)
        assert timegpt.validate_api_key(), "Invalid TimeGPT token provided."

        model_args = {
            'token': timegpt_token,
            "target": target,
            "freq": using_args.get("frequency", None),  # None -> TimeGPT infers the frequency
            "finetune_steps": using_args.get("finetune_steps", 0),
            "validate_token": using_args.get("validate_token", False),
            "date_features": using_args.get("date_features", False),
            "date_features_to_one_hot": using_args.get("date_features_to_one_hot", True),
            "clean_ex_first": using_args.get("clean_ex_first", True),
            "level": using_args.get("level", [90]),
            "add_history": using_args.get("add_history", False),
            'mode': mode,
        }

        if time_settings:
            model_args["horizon"] = time_settings["horizon"]
            model_args["order_by"] = time_settings["order_by"]
            model_args["group_by"] = time_settings.get("group_by", [])

        if mode == 'anomalydetection':
            # Anomaly detection queries do not carry timeseries_settings, so the
            # equivalent parameters come from the USING clause instead.
            model_args["target"] = using_args['target'] if target is None else target
            model_args["horizon"] = using_args.get('horizon', 1)
            model_args["order_by"] = using_args['order_by']
            model_args["group_by"] = using_args.get("group_by", [])
            model_args['add_history'] = True  # in-sample bounds are required to flag anomalies

        assert isinstance(model_args["level"], list), "`level` must be a list of integers"
        assert all(isinstance(level, int) for level in model_args["level"]), "`level` must be a list of integers"

        self.model_storage.json_set("model_args", model_args)  # persist changes to handler folder

    def predict(self, df, args=None):
        """
        Makes forecasts (or flags anomalies) with the TimeGPT API.

        Args:
            df: dataframe holding the historical series to extend.
            args: optional runtime arguments; per-call overrides are read from
                `args['predict_params']`. (Previously used a mutable default {};
                now defaults to None to avoid sharing state across calls.)

        Returns:
            pd.DataFrame with the forecast, prediction intervals and, in anomaly
            detection mode, an `anomaly` flag plus the observed target values.
        """
        model_args = self.model_storage.json_get("model_args")
        # Tolerate a missing 'predict_params' key instead of raising KeyError.
        args = (args or {}).get('predict_params', {})
        prediction_df = self._transform_to_nixtla_df(df, model_args)

        timegpt = NixtlaClient(api_key=model_args['token'])
        assert timegpt.validate_api_key(), "Invalid TimeGPT token provided."

        forecast_df = timegpt.forecast(
            prediction_df,
            # TODO: supporting param override when JOINing with a WHERE clause is blocked by mindsdb_sql#285
            h=args.get("horizon", model_args.get("horizon", 1)),
            freq=args.get("freq", model_args["freq"]),  # automatically infers correct frequency if not provided by user
            level=model_args["level"],
            finetune_steps=args.get('finetune_steps', model_args['finetune_steps']),
            validate_api_key=args.get('validate_token', model_args['validate_token']),
            date_features=args.get('date_features', model_args['date_features']),
            date_features_to_one_hot=args.get('date_features_to_one_hot', model_args['date_features_to_one_hot']),
            clean_ex_first=args.get('clean_ex_first', model_args['clean_ex_first']),
            # anomaly detection
            add_history=args.get('add_history', model_args['add_history'])  # insample bounds and anomaly detection
            # TODO: enable this post-refactor (#6861)
            # X_df=None,  # exogenous variables
        )

        if model_args['mode'] == 'forecasting':
            results_df = forecast_df[['unique_id', 'ds', 'TimeGPT']]
            results_df = get_results_from_nixtla_df(results_df, model_args)
        elif model_args['mode'] == 'anomalydetection':
            forecast_df['ds'] = pd.to_datetime(forecast_df['ds'])
            results_df = forecast_df.merge(prediction_df, how='inner')  # some rows drop because of TimeGPT's cold start
            # A point is anomalous when the observed value escapes the interval
            # of the first configured level.
            # NOTE(review): with multiple levels this is not necessarily the
            # widest interval — confirm whether `max(level)` was intended.
            results_df['anomaly'] = (results_df['y'] > results_df[f'TimeGPT-hi-{model_args["level"][0]}']) | (results_df['y'] < results_df[f'TimeGPT-lo-{model_args["level"][0]}'])
            forecast_df = results_df  # rewrite forecast_df so that we can reuse code below for prediction intervals
            results_df = get_results_from_nixtla_df(results_df, model_args)
            results_df = results_df.rename({'y': f'observed_{model_args["target"]}'}, axis=1)
        else:
            raise Exception(f'Unsupported prediction mode: {model_args["mode"]}')

        # infer date
        ds_col = model_args["order_by"]
        if not pd.api.types.is_datetime64_any_dtype(results_df[ds_col]):
            results_df[ds_col] = pd.to_datetime(results_df[ds_col])

        results_df = results_df.rename({'TimeGPT': model_args['target']}, axis=1)

        # add prediction intervals
        levels = sorted(model_args['level'], reverse=True)
        for i, level in enumerate(levels):
            if i == 0:
                # NOTE: this should be simplified once we refactor the expected time series output within MindsDB
                results_df['confidence'] = level / 100  # we report the highest level as the overall confidence
                results_df['lower'] = forecast_df[f'TimeGPT-lo-{level}']
                results_df['upper'] = forecast_df[f'TimeGPT-hi-{level}']
            else:
                results_df[f'lower_{level}'] = forecast_df[f'TimeGPT-lo-{level}']
                results_df[f'upper_{level}'] = forecast_df[f'TimeGPT-hi-{level}']

        return results_df

    def describe(self, attribute=None):
        """
        Describe the stored model configuration.

        Supported attributes: 'model' (frequency), 'features' (columns used),
        'info' (inputs/outputs). Any other value returns the list of available
        describe tables.
        """
        model_args = self.model_storage.json_get("model_args")

        if attribute == "model":
            return pd.DataFrame({"frequency": [model_args["freq"] if model_args["freq"] else "automatic"]})

        elif attribute == "features":
            df = pd.DataFrame({
                "order by": [model_args["order_by"]],
                "target": model_args["target"]
            })
            if model_args["group_by"]:
                df["group by"] = [model_args["group_by"]]
            return df

        elif attribute == 'info':
            outputs = model_args["target"]
            inputs = [model_args["target"], model_args["order_by"]]
            if model_args["group_by"]:
                inputs.append(model_args["group_by"])
            return pd.DataFrame({"output": outputs, "input": [inputs]})

        else:
            tables = ['info', 'features', 'model']
            return pd.DataFrame(tables, columns=['tables'])

    # TODO: consolidate this method with the ones in time_series_utils.py
    @staticmethod
    def _convert_to_iso(df, date_column):
        """
        Normalize `date_column` to pandas datetimes (second resolution).

        Numeric columns are treated as Unix timestamps; the unit is guessed by
        trying each resolution and keeping the coarsest one that places the whole
        column inside a sane date range (1970 - 2050).
        """
        if pd.api.types.is_numeric_dtype(df[date_column]):
            unit = ''
            lower_bound = pd.to_datetime('1970-01-01T00:00:00')
            upper_bound = pd.to_datetime('2050-12-31T23:59:59')
            # ascending unit order: the last (coarsest) unit that fits wins
            for u in ['ns', 'us', 'ms', 's']:
                try:
                    mindate = pd.to_datetime(df[date_column].min(), unit=u, origin='unix')
                    maxdate = pd.to_datetime(df[date_column].max(), unit=u, origin='unix')
                except (ValueError, OverflowError, pd.errors.OutOfBoundsDatetime):
                    # value overflows datetime64 at this resolution; try the next unit
                    continue
                if mindate > lower_bound and maxdate < upper_bound:
                    unit = u
            if not unit:
                # previously `unit=''` was passed through, yielding a cryptic pandas error
                raise ValueError(f"Could not infer a Unix timestamp unit for column '{date_column}'")
            df[date_column] = pd.to_datetime(df[date_column], unit=unit, origin='unix')
        else:
            df[date_column] = pd.to_datetime(df[date_column])
        df[date_column] = df[date_column].dt.strftime('%Y-%m-%dT%H:%M:%S')  # convert to ISO 8601 format
        df[date_column] = pd.to_datetime(df[date_column])
        return df

    # TODO: consolidate this method with the ones in time_series_utils.py
    def _transform_to_nixtla_df(self, df, settings_dict, exog_vars=None):
        """
        Reshape a MindsDB dataframe into the (unique_id, ds, y) layout that the
        Nixtla client expects, sorted ascending by series and date.
        """
        exog_vars = exog_vars if exog_vars is not None else []  # avoid mutable default argument
        nixtla_df = df.copy()

        # Transform group columns into single unique_id column
        gby = settings_dict["group_by"]
        if len(gby) > 1:
            for col in gby:
                nixtla_df[col] = nixtla_df[col].astype(str)
            nixtla_df["unique_id"] = nixtla_df[gby].agg("/".join, axis=1)
            group_col = "ignore this"  # sentinel: unique_id already exists, rename below is a no-op
        elif len(gby) == 1 and gby[0] is not None:
            group_col = settings_dict["group_by"][0]
        else:
            # single ungrouped series: synthesize a constant series id
            group_col = '__unique_id'
            nixtla_df[group_col] = '1'

        # Rename columns to statsforecast names
        nixtla_df = nixtla_df.rename(
            {settings_dict["target"]: "y", settings_dict["order_by"]: "ds", group_col: "unique_id"}, axis=1
        )

        columns_to_keep = ["unique_id", "ds", "y"] + exog_vars
        nixtla_df = self._convert_to_iso(nixtla_df, "ds")
        nixtla_df = nixtla_df[columns_to_keep].sort_values(by=['unique_id', 'ds'], ascending=True)  # expects ascending
        nixtla_df['y'] = nixtla_df['y'].astype(float)
        return nixtla_df.reset_index(drop=True)