Coverage for mindsdb / integrations / handlers / autosklearn_handler / autosklearn_handler.py: 0%
30 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional
3import dill
4import pandas as pd
5from type_infer.api import infer_types
6import autosklearn.classification as automl_classification
7import autosklearn.regression as automl_regression
9from .config import ClassificationConfig, RegressionConfig
11from mindsdb.integrations.libs.base import BaseMLEngine
class AutoSklearnHandler(BaseMLEngine):
    """
    Integration with the Auto-Sklearn ML library.

    ``create`` trains an Auto-Sklearn classifier or regressor (chosen from the
    inferred dtype of the target column) and persists it through
    ``self.model_storage``; ``predict`` loads that model back and appends its
    predictions to a copy of the input frame.
    """

    name = 'autosklearn'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """
        Train a model for ``target`` on ``df`` and store it.

        Args:
            target: name of the column in ``df`` to predict.
            df: training data; must contain the ``target`` column.
            args: creation arguments. The ``'using'`` entry, when present, is
                forwarded as keyword arguments to the task config class.

        Raises:
            ValueError: if ``df`` is not provided.
            Exception: if the inferred dtype of the target column is not one
                of the classification or regression dtypes handled below.
        """
        if df is None:
            raise ValueError('Training data is required to create an Auto-Sklearn model.')

        # `args` defaults to None in the signature, so never index it blindly.
        args = {} if args is None else args
        config_args = args.get('using') or {}

        # NOTE(review): the second positional argument of infer_types is passed
        # through unchanged from the original code; confirm its meaning
        # against the type_infer API.
        target_dtype = infer_types(df, 0).to_dict()["dtypes"][target]

        if target_dtype in ['binary', 'categorical', 'tags']:
            config = ClassificationConfig(**config_args)
            model = automl_classification.AutoSklearnClassifier(**vars(config))
        elif target_dtype in ['integer', 'float', 'quantity']:
            config = RegressionConfig(**config_args)
            model = automl_regression.AutoSklearnRegressor(**vars(config))
        else:
            raise Exception('This task is not supported!')

        # Every column except the target is used as a feature.
        model.fit(df.drop(target, axis=1), df[target])

        # predict() looks up the target name in the stored args, so make sure
        # it is persisted even when the caller did not include it. A shallow
        # copy keeps the caller's dict untouched.
        stored_args = dict(args)
        stored_args.setdefault('target', target)

        self.model_storage.file_set('model', dill.dumps(model))
        self.model_storage.json_set('args', stored_args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> pd.DataFrame:
        """
        Predict the stored target for every row of ``df``.

        Args:
            df: rows to predict on. A ``'__mindsdb_row_id'`` column, if
                present, is removed before the frame is passed to the model.
            args: accepted for interface compatibility; not read here. The
                target name comes from the args persisted by ``create``.

        Returns:
            A new DataFrame (the input without ``'__mindsdb_row_id'``) with
            one extra column, named after the target, holding the predictions.

        Raises:
            ValueError: if ``df`` is not provided.
        """
        if df is None:
            raise ValueError('Input data is required to make predictions.')

        # SECURITY: dill.loads can execute arbitrary code while unpickling;
        # the model storage must only ever contain blobs written by create().
        model = dill.loads(self.model_storage.file_get('model'))

        # drop() returns a new frame, so the caller's DataFrame is not mutated.
        # errors='ignore' tolerates inputs that lack the internal row-id column.
        features = df.drop('__mindsdb_row_id', axis=1, errors='ignore')

        predictions = model.predict(features)

        stored_args = self.model_storage.json_get('args')
        features[stored_args['target']] = predictions

        return features