Coverage for mindsdb / integrations / handlers / merlion_handler / merlion_handler.py: 0%

145 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from enum import Enum 

2import json 

3from typing import Optional, Dict 

4 

5import numpy as np 

6import pandas as pd 

7import copy 

8 

9from mindsdb.integrations.libs.base import BaseMLEngine 

10from mindsdb.utilities import log 

11from .adapters import BaseMerlionForecastAdapter, DefaultForecasterAdapter, MerlionArguments, DefaultDetectorAdapter, \ 

12 SarimaForecasterAdapter, ProphetForecasterAdapter, MSESForecasterAdapter, IsolationForestDetectorAdapter, \ 

13 WindStatsDetectorAdapter, ProphetDetectorAdapter 

14 

15logger = log.getLogger(__name__) 

16 

17 

class DetectorModelType(Enum):
    """Maps an anomaly-detector model-type name (the value accepted in
    ``using.model_type``) to its Merlion detector adapter class.

    Member names are what users type; member values are the adapter classes
    instantiated by ``MerlionHandler.__args_to_adapter_class``.
    """
    default = DefaultDetectorAdapter
    isolation = IsolationForestDetectorAdapter
    windstats = WindStatsDetectorAdapter
    prophet = ProphetDetectorAdapter

23 

24 

class ForecastModelType(Enum):
    """Maps a forecaster model-type name (the value accepted in
    ``using.model_type``) to its Merlion forecaster adapter class.

    Member names are what users type; member values are the adapter classes
    instantiated by ``MerlionHandler.__args_to_adapter_class``.
    """
    default = DefaultForecasterAdapter
    sarima = SarimaForecasterAdapter
    prophet = ProphetForecasterAdapter
    mses = MSESForecasterAdapter

30 

31 

class TaskType(Enum):
    """Maps a task name (the value accepted in ``using.task``) to the enum of
    model types valid for that task (``DetectorModelType`` or
    ``ForecastModelType``)."""
    detector = DetectorModelType
    forecast = ForecastModelType

35 

36 

def is_invalid_type(name: str, type_class: Enum) -> bool:
    """Tell whether ``name`` does not identify a member of ``type_class``.

    A ``None`` name is always considered invalid.
    """
    return name is None or name not in type_class._member_names_

41 

42 

def enum_to_str(type_class: Enum) -> str:
    """Return the member names of ``type_class`` joined with ``|``.

    Used to render the list of valid options in error messages, e.g.
    ``"default|sarima|prophet|mses"``.
    """
    # Generator expression + join replaces the original append-loop, which
    # also shadowed the builtin ``all`` with its accumulator variable.
    return "|".join(member.name for member in type_class)

48 

49 

def to_ts_dataframe(df: pd.DataFrame, time_col: Optional[str] = None) -> (pd.DataFrame, str):
    """Convert ``df`` into a time-series dataframe indexed by its time column.

    Args:
        df: input dataframe; must contain (or allow detection of) a time column.
        time_col: name of the time column. When ``None``, the first column with
            a datetime dtype is used.

    Returns:
        A tuple ``(ts_df, time_col)`` where ``ts_df`` is a deep copy of ``df``
        with the time column removed and used as the (datetime) index, and
        ``time_col`` is the name of the column that was used.

    Raises:
        Exception: if ``time_col`` is not a column of ``df``, cannot be parsed
            as datetimes, or no datetime column can be found.
    """
    columns = list(df.columns.values)
    # if time column has been specified, validate and parse it
    if time_col is not None:
        if time_col not in columns:
            raise Exception("invalid column name: " + time_col)
        # BUGFIX: the original only assigned ``idx`` when the column dtype was
        # not datetime64, leaving ``idx`` unbound otherwise and crashing below
        # with NameError. ``pd.to_datetime`` is a no-op on already-datetime
        # columns, so converting unconditionally is safe.
        try:
            idx = pd.to_datetime(df[time_col])
        except Exception as e:
            raise Exception("can not convert column to datetime: " + time_col + " " + str(e))
    # if time column has not been specified, try to find one
    else:
        datetime_cols = list(df.select_dtypes(include=["datetime"]).columns.values)
        if len(datetime_cols) > 0:
            time_col = datetime_cols[0]
            idx = pd.to_datetime(df[time_col])
        else:
            raise Exception("can not find datetime column for time series")
    # build return dataframe: deep copy so the caller's df keeps its time column
    rt_df = copy.deepcopy(df)
    rt_df.drop(columns=[time_col], inplace=True)
    rt_df.index = idx
    return rt_df, time_col

74 

75 

class MerlionHandler(BaseMLEngine):
    """MindsDB ML engine backed by Merlion time-series models.

    Two task types are supported (see ``TaskType``): forecasting and anomaly
    detection. ``create`` trains a model via the adapter class matching
    ``using.task`` / ``using.model_type`` and persists it; ``predict`` restores
    the persisted model and applies it to new data.
    """
    name = 'merlion'

    # keys accepted under the USING clause of CREATE MODEL
    ARG_USING_TASK = "task"
    ARG_USING_MODEL_TYPE = "model_type"
    # keys only be used to persist args to args.json
    ARG_TIME_COLUMN = "time_column"
    ARG_BASE_WINDOW = "base_window"
    ARG_PREDICT_HORIZON = "predict_horizon"
    ARG_TARGET = "target"
    ARG_COLUMN_SEQUENCE = "column_sequence"

    # kwargs key under which ``create`` receives the training dataframe
    KWARGS_DF = "df"

    DEFAULT_MODEL_TYPE = "default"
    DEFAULT_MAX_PREDICT_STEP = 100
    DEFAULT_PREDICT_BASE_WINDOW = 10

    # storage keys for the serialized model bytes and its training arguments
    PERSIST_MODEL_FILE_NAME = "merlion_model"
    PERSIST_ARGS_KEY_IN_JSON_STORAGE = "args"

    def create(self, target, args=None, **kwargs):
        """Train a Merlion model for ``target`` on the dataframe in
        ``kwargs["df"]`` and persist model bytes plus training args.

        Raises:
            Exception: if the dataframe is missing, the task/model_type is
                unknown, or no usable time column exists.
        """
        df: pd.DataFrame = kwargs.get(self.KWARGS_DF, None)
        # prepare arguments
        using_args = args.get("using", dict())
        task = using_args.get(self.ARG_USING_TASK, TaskType.forecast.name)
        model_type = using_args.get(self.ARG_USING_MODEL_TYPE, self.DEFAULT_MODEL_TYPE)
        timeseries_settings = args.get("timeseries_settings", dict())
        time_column = timeseries_settings.get("order_by", None)
        horizon = timeseries_settings.get("horizon", self.DEFAULT_MAX_PREDICT_STEP)
        window = timeseries_settings.get("window", self.DEFAULT_PREDICT_BASE_WINDOW)
        # update args for default value maybe has been used, only time column will be set afterwards
        serialize_args = dict()
        serialize_args[self.ARG_TARGET] = target
        serialize_args[self.ARG_USING_TASK] = task
        serialize_args[self.ARG_USING_MODEL_TYPE] = model_type
        serialize_args[self.ARG_PREDICT_HORIZON] = horizon
        serialize_args[self.ARG_BASE_WINDOW] = window

        # check df
        if df is None:
            raise Exception("missing required key in args: " + self.KWARGS_DF)
        else:
            # sort columns so train/predict see the same column order; the
            # order is persisted and re-applied in predict()
            column_sequence = sorted(list(df.columns.values))
            df = df[column_sequence]
            serialize_args[self.ARG_COLUMN_SEQUENCE] = column_sequence

        # check task, model_type and get the adapter_class
        adapter_class = self.__args_to_adapter_class(task=task, model_type=model_type)
        task_enum = TaskType[task]

        # check and cast to ts dataframe
        ts_df, time_column = to_ts_dataframe(df=df, time_col=time_column)
        serialize_args[self.ARG_TIME_COLUMN] = time_column

        # train model; only the forecast task takes a max-steps argument
        model_args = {}
        if task_enum == TaskType.forecast:
            model_args[MerlionArguments.max_forecast_steps.value] = horizon
        adapter: BaseMerlionForecastAdapter = adapter_class(**model_args)
        logger.info("Training model, args: " + json.dumps(serialize_args))
        adapter.train(df=ts_df, target=target)
        logger.info("Training model completed.")

        # persist save model
        model_bytes = adapter.to_bytes()
        self.model_storage.file_set(self.PERSIST_MODEL_FILE_NAME, model_bytes)
        self.model_storage.json_set(self.PERSIST_ARGS_KEY_IN_JSON_STORAGE, serialize_args)
        logger.info("Model and args saved.")

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Apply the persisted model to ``df`` and return predictions.

        For the forecast task the returned frame contains forecast rows only
        (rows where the target prediction is not NaN); for the detector task it
        contains the input rows plus a ``<target>__anomaly_score`` column.

        Raises:
            Exception: if required feature columns are missing from ``df``.
        """
        rt_df = df.copy(deep=True)
        # read model and args from storage
        model_bytes = self.model_storage.file_get(self.PERSIST_MODEL_FILE_NAME)
        deserialize_args = self.model_storage.json_get(self.PERSIST_ARGS_KEY_IN_JSON_STORAGE)

        # resolve args
        task = deserialize_args[self.ARG_USING_TASK]
        model_type = deserialize_args[self.ARG_USING_MODEL_TYPE]
        time_column = deserialize_args[self.ARG_TIME_COLUMN]
        target = deserialize_args[self.ARG_TARGET]
        horizon = deserialize_args[self.ARG_PREDICT_HORIZON]
        feature_column_sequence = list(deserialize_args[self.ARG_COLUMN_SEQUENCE])
        task_enum = TaskType[task]

        # check df and prepare data
        if task_enum == TaskType.forecast:
            # forecasting predicts the target, so it is not a required input
            feature_column_sequence.remove(target)
        missing_required_columns = set(feature_column_sequence) - set(rt_df.columns.values)
        if len(missing_required_columns) > 0:
            raise Exception("Missing required columns: " + ",".join(missing_required_columns))
        feature_df = rt_df[feature_column_sequence]
        ts_feature_df, _ = to_ts_dataframe(df=feature_df, time_col=time_column)

        # init model adapter
        adapter_class: BaseMerlionForecastAdapter = self.__args_to_adapter_class(task=task, model_type=model_type)
        model_args = {}
        if task_enum == TaskType.forecast:
            model_args[MerlionArguments.max_forecast_steps.value] = horizon
        adapter = adapter_class(**model_args)
        adapter.initialize_model(bytes=model_bytes)

        # predict
        pred_df = adapter.predict(df=ts_feature_df, target=target)

        # build result: left-join on the datetime index so predictions align
        # with the input rows (the empty-column selection keeps only the index)
        pred_df = ts_feature_df[[]].join(pred_df, how="left")

        # arrange data: restore the caller's original (positional) index
        pred_df.index = rt_df.index
        if task_enum == TaskType.forecast:
            # keep only rows that actually carry a forecast value
            pred_df = pred_df[~pred_df[target].isna()]
            # NOTE(review): assumes the input df contains the target column;
            # drop() raises otherwise — confirm against caller
            rt_df.drop(columns=[target], inplace=True)
        elif task_enum == TaskType.detector:
            # rows without a score are treated as non-anomalous
            pred_df[f"{target}__anomaly_score"].fillna(0, inplace=True)
        rt_df = rt_df.join(pred_df, how="right")
        return rt_df

    def __args_to_adapter_class(self, task: str, model_type: str):
        """Resolve ``task`` and ``model_type`` names to an adapter class.

        Raises:
            Exception: with the list of valid options when either name is
                unknown.
        """
        # check task_type
        try:
            task_enum = TaskType[task]
        except Exception:
            raise Exception("wrong using.task: " + task + ", valid options: " + enum_to_str(TaskType))
        # check and get model class: task_enum.value is the per-task model-type
        # enum, whose member values are the adapter classes
        try:
            adapter_class = task_enum.value[model_type].value
        except Exception as e:
            raise Exception("Wrong using.model_type: " + model_type + ", valid options: "
                            + enum_to_str(task_enum.value) + ", " + str(e))
        return adapter_class