Coverage for mindsdb/integrations/handlers/timegpt_handler/timegpt_handler.py: 0%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional, Dict 

2 

3import pandas as pd 

4from nixtla import NixtlaClient 

5 

6from mindsdb.integrations.libs.base import BaseMLEngine 

7from mindsdb.integrations.utilities.handler_utils import get_api_key 

8from mindsdb.integrations.utilities.time_series_utils import get_results_from_nixtla_df 

9# TODO: add E2E tests. 

10 

11 

class TimeGPTHandler(BaseMLEngine):
    """
    Integration with the Nixtla TimeGPT models for
    zero-shot time series forecasting.

    Supports two task modes: 'forecasting' (default) and 'anomalydetection'.
    `create()` validates the API key and persists the model configuration via
    `model_storage`; `predict()` reads it back and calls the TimeGPT API.
    """

    # Engine name under which this handler is registered with MindsDB.
    name = "timegpt"

20 def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: 

21 """ 

22 Create the TimeGPT Handler. 

23 Requires specifying the target column and usual time series arguments. Saves model config for later usage. 

24 """ 

25 self.generative = True 

26 time_settings = args.get("timeseries_settings", {}) 

27 using_args = args["using"] 

28 

29 mode = 'forecasting' 

30 if args.get('__mdb_sql_task', False) and args['__mdb_sql_task'].lower() in ('forecasting', 'anomalydetection'): 

31 mode = args['__mdb_sql_task'].lower() 

32 

33 if mode == 'forecasting': 

34 assert time_settings["is_timeseries"], "Specify time series settings in your query" 

35 

36 timegpt_token = get_api_key('timegpt', using_args, self.engine_storage, strict=True) 

37 timegpt = NixtlaClient(api_key=timegpt_token) 

38 assert timegpt.validate_api_key(), "Invalid TimeGPT token provided." 

39 

40 model_args = { 

41 'token': timegpt_token, 

42 "target": target, 

43 "freq": using_args.get("frequency", None), 

44 "finetune_steps": using_args.get("finetune_steps", 0), 

45 "validate_token": using_args.get("validate_token", False), 

46 "date_features": using_args.get("date_features", False), 

47 "date_features_to_one_hot": using_args.get("date_features_to_one_hot", True), 

48 "clean_ex_first": using_args.get("clean_ex_first", True), 

49 "level": using_args.get("level", [90]), 

50 "add_history": using_args.get("add_history", False), 

51 'mode': mode, 

52 } 

53 

54 if time_settings: 

55 model_args["horizon"] = time_settings["horizon"] 

56 model_args["order_by"] = time_settings["order_by"] 

57 model_args["group_by"] = time_settings.get("group_by", []) 

58 

59 if mode == 'anomalydetection': 

60 model_args["target"] = using_args['target'] if target is None else target 

61 model_args["horizon"] = using_args.get('horizon', 1) 

62 model_args["order_by"] = using_args['order_by'] 

63 model_args["group_by"] = using_args.get("group_by", []) 

64 model_args['add_history'] = True 

65 

66 assert isinstance(model_args["level"], list), "`level` must be a list of integers" 

67 assert all([isinstance(level, int) for level in model_args["level"]]), "`level` must be a list of integers" 

68 

69 self.model_storage.json_set("model_args", model_args) # persist changes to handler folder 

70 

71 def predict(self, df, args={}): 

72 """ Makes forecasts with the TimeGPT API. """ 

73 model_args = self.model_storage.json_get("model_args") 

74 args = args['predict_params'] 

75 prediction_df = self._transform_to_nixtla_df(df, model_args) 

76 

77 timegpt = NixtlaClient(api_key=model_args['token']) 

78 assert timegpt.validate_api_key(), "Invalid TimeGPT token provided." 

79 

80 forecast_df = timegpt.forecast( 

81 prediction_df, 

82 

83 # TODO: supporting param override when JOINing with a WHERE clause is blocked by mindsdb_sql#285 

84 h=args.get("horizon", model_args.get("horizon", 1)), 

85 freq=args.get("freq", model_args["freq"]), # automatically infers correct frequency if not provided by user 

86 level=model_args["level"], 

87 finetune_steps=args.get('finetune_steps', model_args['finetune_steps']), 

88 validate_api_key=args.get('validate_token', model_args['validate_token']), 

89 date_features=args.get('date_features', model_args['date_features']), 

90 date_features_to_one_hot=args.get('date_features_to_one_hot', model_args['date_features_to_one_hot']), 

91 clean_ex_first=args.get('clean_ex_first', model_args['clean_ex_first']), 

92 

93 # anomaly detection 

94 add_history=args.get('add_history', model_args['add_history']) # insample bounds and anomaly detection 

95 

96 # TODO: enable this post-refactor (#6861) 

97 # X_df=None, # exogenous variables 

98 ) 

99 if model_args['mode'] == 'forecasting': 

100 results_df = forecast_df[['unique_id', 'ds', 'TimeGPT']] 

101 results_df = get_results_from_nixtla_df(results_df, model_args) 

102 elif model_args['mode'] == 'anomalydetection': 

103 forecast_df['ds'] = pd.to_datetime(forecast_df['ds']) 

104 results_df = forecast_df.merge(prediction_df, how='inner') # some rows drop because of TimeGPT's cold start 

105 results_df['anomaly'] = (results_df['y'] > results_df[f'TimeGPT-hi-{model_args["level"][0]}']) | (results_df['y'] < results_df[f'TimeGPT-lo-{model_args["level"][0]}']) 

106 

107 forecast_df = results_df # rewrite forecast_df so that we can reuse code below for prediction intervals 

108 results_df = get_results_from_nixtla_df(results_df, model_args) 

109 results_df = results_df.rename({'y': f'observed_{model_args["target"]}'}, axis=1) 

110 else: 

111 raise Exception(f'Unsupported prediction mode: {model_args["mode"]}') 

112 

113 # infer date 

114 ds_col = model_args["order_by"] 

115 if not pd.api.types.is_datetime64_any_dtype(results_df[ds_col]): 

116 results_df[ds_col] = pd.to_datetime(results_df[ds_col]) 

117 

118 results_df = results_df.rename({'TimeGPT': model_args['target']}, axis=1) 

119 

120 # add prediction intervals 

121 levels = sorted(model_args['level'], reverse=True) 

122 for i, level in enumerate(levels): 

123 if i == 0: 

124 # NOTE: this should be simplified once we refactor the expected time series output within MindsDB 

125 results_df['confidence'] = level / 100 # we report the highest level as the overall confidence 

126 results_df['lower'] = forecast_df[f'TimeGPT-lo-{level}'] 

127 results_df['upper'] = forecast_df[f'TimeGPT-hi-{level}'] 

128 else: 

129 results_df[f'lower_{level}'] = forecast_df[f'TimeGPT-lo-{level}'] 

130 results_df[f'upper_{level}'] = forecast_df[f'TimeGPT-hi-{level}'] 

131 

132 return results_df 

133 

134 def describe(self, attribute=None): 

135 model_args = self.model_storage.json_get("model_args") 

136 

137 if attribute == "model": 

138 df = pd.DataFrame({"frequency": [model_args["freq"] if model_args["freq"] else "automatic"]}) 

139 return df 

140 

141 elif attribute == "features": 

142 df = pd.DataFrame({ 

143 "order by": [model_args["order_by"]], 

144 "target": model_args["target"] 

145 }) 

146 if model_args["group_by"]: 

147 df["group by"] = [model_args["group_by"]] 

148 return df 

149 

150 elif attribute == 'info': 

151 outputs = model_args["target"] 

152 inputs = [model_args["target"], model_args["order_by"]] 

153 if model_args["group_by"]: 

154 inputs.append(model_args["group_by"]) 

155 return pd.DataFrame({"output": outputs, "input": [inputs]}) 

156 

157 else: 

158 tables = ['info', 'features', 'model'] 

159 return pd.DataFrame(tables, columns=['tables']) 

160 

161 # TODO: consolidate this method with the ones in time_series_utils.py 

162 @staticmethod 

163 def _convert_to_iso(df, date_column): 

164 # whether values in date_column are numeric (Unix timestamp) or string (date) 

165 if pd.api.types.is_numeric_dtype(df[date_column]): 

166 unit = '' 

167 # ascending unit order 

168 for u in ['ns', 'us', 'ms', 's']: 

169 mindate = pd.to_datetime(df[date_column].min(), unit=u, origin='unix') 

170 maxdate = pd.to_datetime(df[date_column].max(), unit=u, origin='unix') 

171 if mindate > pd.to_datetime('1970-01-01T00:00:00') and maxdate < pd.to_datetime('2050-12-31T23:59:59'): 

172 unit = u 

173 df[date_column] = pd.to_datetime(df[date_column], unit=unit, origin='unix') 

174 else: 

175 df[date_column] = pd.to_datetime(df[date_column]) 

176 df[date_column] = df[date_column].dt.strftime('%Y-%m-%dT%H:%M:%S') # convert to ISO 8601 format 

177 df[date_column] = pd.to_datetime(df[date_column]) 

178 return df 

179 

180 # TODO: consolidate this method with the ones in time_series_utils.py 

181 def _transform_to_nixtla_df(self, df, settings_dict, exog_vars=[]): 

182 nixtla_df = df.copy() 

183 # Transform group columns into single unique_id column 

184 gby = settings_dict["group_by"] 

185 if len(gby) > 1: 

186 for col in gby: 

187 nixtla_df[col] = nixtla_df[col].astype(str) 

188 nixtla_df["unique_id"] = nixtla_df[gby].agg("/".join, axis=1) 

189 group_col = "ignore this" 

190 elif len(gby) == 1 and gby[0] is not None: 

191 group_col = settings_dict["group_by"][0] 

192 else: 

193 group_col = '__unique_id' 

194 nixtla_df[group_col] = '1' 

195 

196 # Rename columns to statsforecast names 

197 nixtla_df = nixtla_df.rename( 

198 {settings_dict["target"]: "y", settings_dict["order_by"]: "ds", group_col: "unique_id"}, axis=1 

199 ) 

200 

201 columns_to_keep = ["unique_id", "ds", "y"] + exog_vars 

202 nixtla_df = self._convert_to_iso(nixtla_df, "ds") 

203 nixtla_df = nixtla_df[columns_to_keep].sort_values(by=['unique_id', 'ds'], ascending=True) # expects ascending 

204 nixtla_df['y'] = nixtla_df['y'].astype(float) 

205 return nixtla_df.reset_index(drop=True)