Coverage for mindsdb / integrations / handlers / pycaret_handler / pycaret_handler.py: 0%
74 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional, Dict
2import os
4import pandas as pd
6from mindsdb.integrations.libs.base import BaseMLEngine
7from pycaret.classification import ClassificationExperiment
8from pycaret.regression import RegressionExperiment
9from pycaret.time_series import TSForecastingExperiment
10from pycaret.clustering import ClusteringExperiment
11from pycaret.anomaly import AnomalyExperiment
class PyCaretHandler(BaseMLEngine):
    """MindsDB ML engine that trains, persists and serves PyCaret models.

    The USING clause selects the PyCaret experiment via ``model_type`` and the
    estimator via ``model_name``. Extra USING keys are routed to PyCaret by
    prefix: ``setup_*`` -> ``setup()``, ``create_*`` -> ``create_model()`` /
    ``compare_models()``, ``predict_*`` -> ``predict_model()``.
    """

    name = 'pycaret'

    # model_type -> PyCaret experiment class used for that task
    _EXPERIMENT_CLASSES = {
        'classification': ClassificationExperiment,
        'regression': RegressionExperiment,
        'time_series': TSForecastingExperiment,
        'clustering': ClusteringExperiment,
        'anomaly': AnomalyExperiment,
    }
    # model types whose setup() receives a target column and which accept model_name='best'
    _TARGET_MODEL_TYPES = frozenset({'classification', 'regression', 'time_series'})
    # model types that are set up without a target column
    _UNSUPERVISED_MODEL_TYPES = frozenset({'clustering', 'anomaly'})
    # keys that every USING clause must provide
    _REQUIRED_USING_KEYS = ('model_type', 'model_name')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Create and train a PyCaret model on the given data, then persist it.

        :param target: name of the target column; used when ``args`` has no 'target' key.
        :param df: training data (required).
        :param args: engine arguments; must contain a 'using' dict with at least
            'model_type' and 'model_name'.
        :raises Exception: if the USING clause, required USING keys or training data
            are missing, or the model type is unrecognized.
        """
        # args defaults to None, so guard before the membership test
        if not args or 'using' not in args:
            raise Exception("PyCaret engine requires a USING clause! Refer to its documentation for more details.")
        using = args['using']
        if df is None:
            raise Exception("PyCaret engine requires some data to initialize!")
        # validate up front so a missing key does not surface only after the costly setup()
        missing = [key for key in self._REQUIRED_USING_KEYS if key not in using]
        if missing:
            raise Exception(f"PyCaret engine requires {', '.join(missing)} in the USING clause!")
        # an explicit 'target' in args keeps precedence; otherwise use the target parameter
        target = args.get('target', target)

        experiment = self._get_experiment(using['model_type'])
        experiment.setup(df, **self._get_experiment_setup_kwargs(using, target))
        model = self._train_model(experiment, using)

        model_file_path = os.path.join(self.model_storage.fileStorage.folder_path, 'model')
        experiment.save_model(model, model_file_path)
        # store the USING options (incl. predict_* keys) so predict() can rebuild the experiment
        self.model_storage.json_set('saved_args', {**using, 'model_path': model_file_path})

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        """Load the saved model and return PyCaret's predictions.

        :param df: rows to score; not forwarded to PyCaret for 'time_series' models.
        :param args: unused; prediction options come from the ``predict_*`` keys
            saved from the USING clause at training time.
        :return: the result of the experiment's ``predict_model``.
        """
        saved_args = self.model_storage.json_get('saved_args')
        # NOTE(review): model_path is the absolute path recorded during create();
        # confirm it stays valid if the model storage folder is relocated.
        experiment = self._get_experiment(saved_args['model_type'])
        model = experiment.load_model(saved_args['model_path'])
        return self._predict_model(experiment, model, df, saved_args)

    def _get_experiment(self, model_type):
        """Return a fresh PyCaret experiment instance for ``model_type``.

        :raises Exception: if ``model_type`` is not one of the supported types.
        """
        experiment_cls = self._EXPERIMENT_CLASSES.get(model_type)
        if experiment_cls is None:
            raise Exception(f"Unrecognized model type '{model_type}'")
        return experiment_cls()

    def _get_experiment_setup_kwargs(self, args: Dict, target: str) -> Dict:
        """Build the keyword arguments for the experiment's ``setup()`` call.

        ``setup_*`` keys from ``args`` are passed through with the prefix removed;
        'target' is added only for model types trained against a target column.

        :raises Exception: if ``args['model_type']`` is not recognized.
        """
        model_type = args['model_type']
        kwargs = self._select_keys(args, "setup_")
        if model_type in self._TARGET_MODEL_TYPES:
            return {**kwargs, 'target': target}
        if model_type in self._UNSUPERVISED_MODEL_TYPES:
            return kwargs
        raise Exception(f"Unrecognized model type '{model_type}'")

    def _predict_model(self, s, model, df, args):
        """Call ``s.predict_model`` with the saved ``predict_*`` options.

        For every supported type except 'time_series', ``df`` is passed as ``data``.
        For 'time_series' only the ``predict_*`` options are forwarded (presumably
        the horizon, e.g. ``predict_fh`` - confirm against PyCaret's docs).

        :raises Exception: if ``args['model_type']`` is not recognized.
        """
        model_type = args["model_type"]
        if model_type not in self._EXPERIMENT_CLASSES:
            raise Exception(f"Unrecognized model type '{model_type}'")
        kwargs = self._select_keys(args, "predict_")
        if model_type != 'time_series':
            kwargs["data"] = df
        return s.predict_model(model, **kwargs)

    def _train_model(self, experiment, args):
        """Train the requested model and return it.

        With ``model_name == 'best'``, target-based tasks use ``compare_models`` and
        return its pick; otherwise ``create_model(model_name)`` is used. ``create_*``
        keys from ``args`` are forwarded (prefix removed) in both cases.

        :raises Exception: if 'best' is requested for a task without a target column.
        """
        model_type = args['model_type']
        model_name = args['model_name']
        kwargs = self._select_keys(args, "create_")
        if model_name == 'best':
            if model_type in self._TARGET_MODEL_TYPES:
                return experiment.compare_models(**kwargs)
            raise Exception("Specific model name must be provided for clustering or anomaly tasks")
        return experiment.create_model(model_name, **kwargs)

    def _select_keys(self, d, prefix):
        """Return a new dict of the entries of ``d`` whose key starts with ``prefix``,
        with that prefix stripped from each key (insertion order preserved)."""
        return {key[len(prefix):]: value for key, value in d.items() if key.startswith(prefix)}