Coverage for mindsdb / integrations / handlers / autogluon_handler / autogluon_handler.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional 

2 

3import dill 

4import pandas as pd 

5 

6from autogluon.tabular import TabularPredictor 

7from type_infer.api import infer_types 

8 

9from mindsdb.integrations.libs.base import BaseMLEngine 

10from mindsdb.utilities import log 

11from .config import ClassificationConfig, RegressionConfig 

12 

13 

# Module-level logger, obtained from MindsDB's logging utilities.
logger = log.getLogger(__name__)

15 

16 

class AutoGluonHandler(BaseMLEngine):
    """MindsDB ML engine backed by an AutoGluon ``TabularPredictor``.

    The task type (classification vs. regression) is chosen from the target
    column's dtype as inferred by ``type_infer``; the matching config class
    supplies the keyword arguments for ``TabularPredictor.fit``. The fitted
    predictor is serialized with ``dill`` into model storage, alongside the
    ``args`` it was created with.
    """

    name = "autogluon"

    # type_infer dtypes treated as classification / regression targets.
    _CLASSIFICATION_DTYPES = ('binary', 'categorical', 'tags')
    _REGRESSION_DTYPES = ('integer', 'float', 'quantity')

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """Train a predictor for ``target`` on ``df`` and persist it.

        Args:
            target: Name of the column to predict.
            df: Training data; must contain ``target``.
            args: Model arguments. Config keyword arguments are read from
                ``args['using']`` (treated as empty when missing). The whole
                dict is stored so ``predict`` can read ``args['target']``.

        Raises:
            ValueError: If ``df`` is missing, or the inferred dtype of
                ``target`` is neither a classification nor a regression type.
        """
        if df is None:
            raise ValueError('AutoGluon handler requires training data to create a model.')
        args = args or {}
        config_args = args.get('using') or {}

        target_dtype = infer_types(df).to_dict()["dtypes"][target]
        logger.debug("AutoGluon: target %r inferred as dtype %r", target, target_dtype)

        if target_dtype in self._CLASSIFICATION_DTYPES:
            config = ClassificationConfig(**config_args)
        elif target_dtype in self._REGRESSION_DTYPES:
            config = RegressionConfig(**config_args)
        else:
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError(
                f"This task is not supported! Target '{target}' has inferred type '{target_dtype}'."
            )

        # Both task types build the same predictor: problem_type is not passed,
        # so AutoGluon infers it from the label column; the config only
        # differs in the fit() arguments it provides.
        model = TabularPredictor(label=target)
        model.fit(df, **vars(config))

        self.model_storage.file_set('model', dill.dumps(model))
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> pd.DataFrame:
        """Predict the stored target column for ``df``.

        Args:
            df: Rows to predict on.
            args: Accepted for interface compatibility; not used here.

        Returns:
            A copy of ``df`` without the ``__mindsdb_row_id`` column and with
            the target column filled with the model's predictions.
        """
        # NOTE(review): dill.loads can execute arbitrary code; this assumes the
        # stored blob is trusted because it is written by create().
        model = dill.loads(self.model_storage.file_get('model'))
        # errors='ignore' tolerates input frames that lack the row-id column.
        df = df.drop('__mindsdb_row_id', axis=1, errors='ignore')

        predictions = model.predict(df)

        # Separate name so the caller-supplied `args` parameter is not shadowed.
        stored_args = self.model_storage.json_get('args')
        df[stored_args['target']] = predictions

        return df

    def describe(self, attribute: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Describe the trained model.

        Only ``attribute == "args"`` is supported: it returns the stored
        creation args as a two-column (``key``, ``value``) frame. Any other
        attribute returns ``None``.
        """
        if attribute != "args":
            return None

        stored_args = self.model_storage.json_get("args")
        return pd.DataFrame(list(stored_args.items()), columns=["key", "value"])