Coverage for mindsdb / integrations / handlers / autogluon_handler / autogluon_handler.py: 0%
35 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional
3import dill
4import pandas as pd
6from autogluon.tabular import TabularPredictor
7from type_infer.api import infer_types
9from mindsdb.integrations.libs.base import BaseMLEngine
10from mindsdb.utilities import log
11from .config import ClassificationConfig, RegressionConfig
14logger = log.getLogger(__name__)
class AutoGluonHandler(BaseMLEngine):
    """ML engine that trains and serves AutoGluon ``TabularPredictor`` models.

    The fitted predictor is serialized with ``dill`` and kept in
    ``self.model_storage`` under the ``'model'`` key; the creation arguments
    are kept as JSON under the ``'args'`` key.
    """

    name = "autogluon"

    # Inferred target dtypes (as reported by ``infer_types``) mapped to a task.
    _CLASSIFICATION_DTYPES = ('binary', 'categorical', 'tags')
    _REGRESSION_DTYPES = ('integer', 'float', 'quantity')

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """Fit a predictor for ``target`` on ``df`` and persist it.

        Args:
            target: Name of the column in ``df`` to predict.
            df: Training data; must contain the ``target`` column.
            args: Creation arguments. The ``'using'`` entry, when present, is
                unpacked into the task config (``ClassificationConfig`` or
                ``RegressionConfig``), whose attributes are then forwarded to
                ``TabularPredictor.fit``.

        Raises:
            ValueError: If ``df`` is missing or does not contain ``target``.
            Exception: If the inferred dtype of the target maps to neither
                classification nor regression.
        """
        if df is None:
            raise ValueError('Training data is required to create an AutoGluon model.')
        if target not in df.columns:
            raise ValueError(f"Target column '{target}' was not found in the training data.")

        # Copy so the caller's dict is not mutated, and always record the
        # target: predict() looks it up from the stored args.
        stored_args = dict(args) if args else {}
        stored_args.setdefault('target', target)
        config_args = stored_args.get('using') or {}

        target_dtype = infer_types(df).to_dict()["dtypes"][target]

        if target_dtype in self._CLASSIFICATION_DTYPES:
            config = ClassificationConfig(**config_args)
        elif target_dtype in self._REGRESSION_DTYPES:
            config = RegressionConfig(**config_args)
        else:
            raise Exception('This task is not supported!')

        # Both tasks use the same predictor construction; only the fit
        # configuration differs.
        model = TabularPredictor(label=target)
        model.fit(df, **vars(config))

        self.model_storage.file_set('model', dill.dumps(model))
        self.model_storage.json_set('args', stored_args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> pd.DataFrame:
        """Return ``df`` (minus the internal row-id column) with predictions added.

        Args:
            df: Rows to predict on.
            args: Call-time arguments; currently unused. The arguments saved
                by ``create`` are read from model storage instead.

        Returns:
            A new DataFrame holding the input features plus a column named
            after the stored target that contains the model's predictions.

        Raises:
            ValueError: If ``df`` is missing.
        """
        if df is None:
            raise ValueError('Input data is required to make predictions.')

        # NOTE(security): dill.loads can execute arbitrary code while
        # deserializing. This presumably only ever reads blobs written by
        # create() above -- confirm model storage cannot hold untrusted data.
        model = dill.loads(self.model_storage.file_get('model'))

        # drop() returns a new frame, so the caller's DataFrame is untouched;
        # errors='ignore' tolerates inputs that lack the internal id column.
        features = df.drop(columns='__mindsdb_row_id', errors='ignore')

        predictions = model.predict(features)

        stored_args = self.model_storage.json_get('args')
        features[stored_args['target']] = predictions

        return features

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the stored model.

        Args:
            attribute: ``"args"`` to list the stored creation arguments.

        Returns:
            For ``"args"``: a DataFrame with ``key`` and ``value`` columns,
            one row per stored argument. For anything else: a single-column
            ``tables`` DataFrame naming the attributes that can be described,
            instead of an implicit ``None``.
        """
        if attribute == "args":
            args = self.model_storage.json_get("args") or {}
            return pd.DataFrame(list(args.items()), columns=["key", "value"])

        return pd.DataFrame(["args"], columns=["tables"])