Coverage for mindsdb / integrations / handlers / merlion_handler / merlion_handler.py: 0%
145 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from enum import Enum
2import json
3from typing import Optional, Dict
5import numpy as np
6import pandas as pd
7import copy
9from mindsdb.integrations.libs.base import BaseMLEngine
10from mindsdb.utilities import log
11from .adapters import BaseMerlionForecastAdapter, DefaultForecasterAdapter, MerlionArguments, DefaultDetectorAdapter, \
12 SarimaForecasterAdapter, ProphetForecasterAdapter, MSESForecasterAdapter, IsolationForestDetectorAdapter, \
13 WindStatsDetectorAdapter, ProphetDetectorAdapter
15logger = log.getLogger(__name__)
class DetectorModelType(Enum):
    """Anomaly-detector adapters, keyed by their ``model_type`` option name.

    The member name is the string looked up by name in this module; the
    member value is the adapter class imported from ``.adapters``.
    """

    default = DefaultDetectorAdapter
    isolation = IsolationForestDetectorAdapter
    windstats = WindStatsDetectorAdapter
    prophet = ProphetDetectorAdapter
class ForecastModelType(Enum):
    """Forecaster adapters, keyed by their ``model_type`` option name.

    The member name is the string looked up by name in this module; the
    member value is the adapter class imported from ``.adapters``.
    """

    default = DefaultForecasterAdapter
    sarima = SarimaForecasterAdapter
    prophet = ProphetForecasterAdapter
    mses = MSESForecasterAdapter
class TaskType(Enum):
    """Supported tasks, each mapped to the enum of model types it accepts.

    ``TaskType[task].value`` is itself an ``Enum`` class whose members map a
    model-type name to an adapter class.
    """

    detector = DetectorModelType
    forecast = ForecastModelType
def is_invalid_type(name: str, type_class: "type[Enum]") -> bool:
    """Return True when ``name`` cannot be used to look up a member of ``type_class``.

    Args:
        name: Candidate member name; ``None`` is always invalid.
        type_class: The ``Enum`` class (not a member) to check against.

    Returns:
        ``True`` if ``type_class[name]`` would fail, ``False`` otherwise.
    """
    if name is None:
        return True
    try:
        # ``__members__`` is the public name -> member mapping and, unlike the
        # private ``_member_names_`` list, also contains aliases, so the
        # result agrees with what ``type_class[name]`` actually accepts.
        return name not in type_class.__members__
    except TypeError:
        # Unhashable values can never be a member name.
        return True
def enum_to_str(type_class: "type[Enum]") -> str:
    """Return the member names of ``type_class`` joined with ``|``.

    Iterating an ``Enum`` class yields its members in definition order and
    skips aliases, so the result lists each distinct option exactly once.

    Args:
        type_class: The ``Enum`` class whose member names are listed.

    Returns:
        A string such as ``"detector|forecast"``; empty for an enum with no
        members.
    """
    return "|".join(member.name for member in type_class)
def to_ts_dataframe(df: pd.DataFrame, time_col=None) -> "tuple[pd.DataFrame, str]":
    """Return a copy of ``df`` indexed by its time column.

    Args:
        df: Input frame; it is not modified.
        time_col: Name of the column holding timestamps. When ``None``, the
            first column selected by ``select_dtypes(include=["datetime"])``
            is used.

    Returns:
        A ``(frame, time_col)`` pair: ``frame`` is ``df`` without the time
        column and with that column (converted by ``pd.to_datetime``) as its
        index; ``time_col`` is the name of the column that was used.

    Raises:
        Exception: If ``time_col`` is not a column of ``df``, if no datetime
            column can be found, or if the column cannot be converted to
            datetimes.
    """
    if time_col is not None:
        if time_col not in df.columns:
            raise Exception(f"invalid column name: {time_col}")
    else:
        # No column requested: fall back to the first datetime-typed one.
        datetime_cols = df.select_dtypes(include=["datetime"]).columns
        if len(datetime_cols) == 0:
            raise Exception("can not find datetime column for time series")
        time_col = datetime_cols[0]

    # Always convert, so the index is defined on every path; converting a
    # column that already holds datetimes is harmless.
    try:
        idx = pd.to_datetime(df[time_col])
    except Exception as e:
        raise Exception(f"can not convert column to datetime: {time_col} {e}") from e

    # ``drop`` returns a new frame, leaving the caller's ``df`` untouched.
    rt_df = df.drop(columns=[time_col])
    rt_df.index = idx
    return rt_df, time_col
class MerlionHandler(BaseMLEngine):
    """ML engine that trains and serves Merlion adapters for time series.

    ``create`` trains the adapter chosen by the ``using`` arguments (``task``
    and ``model_type``) and persists both the serialized model and the
    resolved arguments in ``self.model_storage``. ``predict`` reloads them and
    joins the adapter's output onto the input frame.
    """

    name = 'merlion'

    # Keys read from the ``using`` clause.
    ARG_USING_TASK = "task"
    ARG_USING_MODEL_TYPE = "model_type"
    # Keys used only when persisting the resolved arguments.
    ARG_TIME_COLUMN = "time_column"
    ARG_BASE_WINDOW = "base_window"
    ARG_PREDICT_HORIZON = "predict_horizon"
    ARG_TARGET = "target"
    ARG_COLUMN_SEQUENCE = "column_sequence"

    KWARGS_DF = "df"

    DEFAULT_MODEL_TYPE = "default"
    DEFAULT_MAX_PREDICT_STEP = 100
    DEFAULT_PREDICT_BASE_WINDOW = 10

    PERSIST_MODEL_FILE_NAME = "merlion_model"
    PERSIST_ARGS_KEY_IN_JSON_STORAGE = "args"

    def create(self, target, args=None, **kwargs):
        """Train an adapter on ``kwargs["df"]`` and persist it with its arguments.

        Args:
            target: Name of the column to forecast / score.
            args: Optional dict; the ``"using"`` and ``"timeseries_settings"``
                sub-dicts are read from it. ``None`` is treated as empty.
            **kwargs: Must contain ``df``, the training ``pd.DataFrame``.

        Raises:
            Exception: If ``df`` is missing, the task or model type is
                unknown, or no usable time column can be derived.
        """
        df = kwargs.get(self.KWARGS_DF)
        if df is None:
            raise Exception("missing required key in args: " + self.KWARGS_DF)

        # ``args`` defaults to None, so guard before reading from it.
        args = args or {}
        using_args = args.get("using", {})
        timeseries_settings = args.get("timeseries_settings", {})

        task = using_args.get(self.ARG_USING_TASK, TaskType.forecast.name)
        model_type = using_args.get(self.ARG_USING_MODEL_TYPE, self.DEFAULT_MODEL_TYPE)
        time_column = timeseries_settings.get("order_by")
        horizon = timeseries_settings.get("horizon", self.DEFAULT_MAX_PREDICT_STEP)
        window = timeseries_settings.get("window", self.DEFAULT_PREDICT_BASE_WINDOW)

        # Train on a fixed (sorted) column order and remember it, so that
        # predict() can rebuild the same layout.
        column_sequence = sorted(df.columns.values)
        df = df[column_sequence]

        # Persist the resolved values (defaults included); the time column is
        # filled in once it has been determined below.
        serialize_args = {
            self.ARG_TARGET: target,
            self.ARG_USING_TASK: task,
            self.ARG_USING_MODEL_TYPE: model_type,
            self.ARG_PREDICT_HORIZON: horizon,
            self.ARG_BASE_WINDOW: window,
            self.ARG_COLUMN_SEQUENCE: column_sequence,
        }

        # Validates task and model_type, raising on unknown values.
        adapter_class = self.__args_to_adapter_class(task=task, model_type=model_type)
        task_enum = TaskType[task]

        ts_df, time_column = to_ts_dataframe(df=df, time_col=time_column)
        serialize_args[self.ARG_TIME_COLUMN] = time_column

        # Only forecast adapters receive the horizon argument.
        model_args = {}
        if task_enum == TaskType.forecast:
            model_args[MerlionArguments.max_forecast_steps.value] = horizon
        adapter = adapter_class(**model_args)

        logger.info("Training model, args: %s", json.dumps(serialize_args))
        adapter.train(df=ts_df, target=target)
        logger.info("Training model completed.")

        self.model_storage.file_set(self.PERSIST_MODEL_FILE_NAME, adapter.to_bytes())
        self.model_storage.json_set(self.PERSIST_ARGS_KEY_IN_JSON_STORAGE, serialize_args)
        logger.info("Model and args saved.")

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Run the persisted adapter on ``df`` and return the joined result.

        Args:
            df: Input frame; it must contain every column seen at training
                time, except the target column for forecast tasks.
            args: Unused by this method.

        Returns:
            A copy of ``df`` right-joined with the adapter's prediction
            frame. For forecast tasks, rows without a predicted target are
            dropped and any input target column is replaced by the
            prediction. For detector tasks, a missing
            ``<target>__anomaly_score`` is filled with 0.

        Raises:
            Exception: If required columns are missing from ``df``.
        """
        rt_df = df.copy(deep=True)

        model_bytes = self.model_storage.file_get(self.PERSIST_MODEL_FILE_NAME)
        deserialize_args = self.model_storage.json_get(self.PERSIST_ARGS_KEY_IN_JSON_STORAGE)

        task = deserialize_args[self.ARG_USING_TASK]
        model_type = deserialize_args[self.ARG_USING_MODEL_TYPE]
        time_column = deserialize_args[self.ARG_TIME_COLUMN]
        target = deserialize_args[self.ARG_TARGET]
        horizon = deserialize_args[self.ARG_PREDICT_HORIZON]
        feature_column_sequence = list(deserialize_args[self.ARG_COLUMN_SEQUENCE])
        task_enum = TaskType[task]

        if task_enum == TaskType.forecast:
            # The target is what gets predicted, so it is not a required
            # input. Filtering (rather than list.remove) tolerates a stored
            # sequence that does not contain the target.
            feature_column_sequence = [c for c in feature_column_sequence if c != target]

        missing_required_columns = set(feature_column_sequence) - set(rt_df.columns.values)
        if missing_required_columns:
            # Sort so the message is deterministic; str() tolerates
            # non-string column labels.
            raise Exception(
                "Missing required columns: " + ",".join(sorted(map(str, missing_required_columns)))
            )

        feature_df = rt_df[feature_column_sequence]
        ts_feature_df, _ = to_ts_dataframe(df=feature_df, time_col=time_column)

        # Rebuild the adapter the same way create() did, then load its state.
        adapter_class = self.__args_to_adapter_class(task=task, model_type=model_type)
        model_args = {}
        if task_enum == TaskType.forecast:
            model_args[MerlionArguments.max_forecast_steps.value] = horizon
        adapter = adapter_class(**model_args)
        adapter.initialize_model(bytes=model_bytes)

        pred_df = adapter.predict(df=ts_feature_df, target=target)

        # Align predictions to the input timestamps (column-less left frame),
        # then switch back to the caller's original index.
        pred_df = ts_feature_df[[]].join(pred_df, how="left")
        pred_df.index = rt_df.index

        if task_enum == TaskType.forecast:
            pred_df = pred_df[~pred_df[target].isna()]
            # The input may or may not carry a target column; drop it when
            # present so the join below does not see overlapping columns.
            rt_df.drop(columns=[target], inplace=True, errors="ignore")
        elif task_enum == TaskType.detector:
            score_col = f"{target}__anomaly_score"
            # Assign the result back: an in-place fillna on a selected column
            # does not reliably modify the parent frame.
            pred_df[score_col] = pred_df[score_col].fillna(0)

        return rt_df.join(pred_df, how="right")

    def __args_to_adapter_class(self, task: str, model_type: str):
        """Resolve ``task`` and ``model_type`` names to an adapter class.

        Args:
            task: Name of a ``TaskType`` member.
            model_type: Name of a member of the enum held by that task.

        Returns:
            The adapter class stored as the value of the selected member.

        Raises:
            Exception: If either name is unknown; the message lists the
                valid options.
        """
        try:
            task_enum = TaskType[task]
        except (KeyError, TypeError) as e:
            raise Exception(
                f"wrong using.task: {task}, valid options: {enum_to_str(TaskType)}"
            ) from e

        try:
            adapter_class = task_enum.value[model_type].value
        except (KeyError, TypeError) as e:
            raise Exception(
                f"Wrong using.model_type: {model_type}, valid options: "
                f"{enum_to_str(task_enum.value)}, {e}"
            ) from e
        return adapter_class