Coverage for mindsdb / integrations / handlers / autosklearn_handler / autosklearn_handler.py: 0%

30 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional 

2 

3import dill 

4import pandas as pd 

5from type_infer.api import infer_types 

6import autosklearn.classification as automl_classification 

7import autosklearn.regression as automl_regression 

8 

9from .config import ClassificationConfig, RegressionConfig 

10 

11from mindsdb.integrations.libs.base import BaseMLEngine 

12 

13 

class AutoSklearnHandler(BaseMLEngine):
    """
    Integration with the Auto-Sklearn ML library.

    ``create`` fits an Auto-Sklearn classifier or regressor, chosen from the
    inferred dtype of the target column, and persists it through
    ``self.model_storage``. ``predict`` reloads that model and returns the
    input frame with the predictions stored under the target column.
    """

    name = 'autosklearn'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """Fit an Auto-Sklearn model on ``df`` and persist it.

        Args:
            target: Name of the column in ``df`` to predict.
            df: Training data; every column except ``target`` is used as a
                feature.
            args: Creation arguments. ``args['using']`` is forwarded as
                keyword arguments to ``ClassificationConfig`` or
                ``RegressionConfig``; the whole dict is stored for later use
                by ``predict``.

        Raises:
            ValueError: If the inferred dtype of ``target`` is neither a
                classification nor a regression dtype.
        """
        config_args = args['using']

        # The dtype inferred for the target column selects the estimator type.
        target_dtype = infer_types(df, 0).to_dict()["dtypes"][target]

        if target_dtype in ('binary', 'categorical', 'tags'):
            config = ClassificationConfig(**config_args)
            model = automl_classification.AutoSklearnClassifier(**vars(config))
        elif target_dtype in ('integer', 'float', 'quantity'):
            config = RegressionConfig(**config_args)
            model = automl_regression.AutoSklearnRegressor(**vars(config))
        else:
            # ValueError is a subclass of Exception, so callers that catch
            # Exception keep working; the message is unchanged.
            raise ValueError('This task is not supported!')

        model.fit(df.drop(target, axis=1), df[target])

        self.model_storage.file_set('model', dill.dumps(model))
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> pd.DataFrame:
        """Predict the target column for the rows of ``df``.

        Args:
            df: Rows to predict on. The ``__mindsdb_row_id`` column and the
                target column are removed before prediction if present.
            args: Ignored; the arguments stored by ``create`` are used.

        Returns:
            A new DataFrame holding the feature columns of ``df`` plus the
            predictions under the stored target column name. ``df`` itself is
            not modified.
        """
        # NOTE(security): dill.loads executes arbitrary code embedded in the
        # payload; this is only safe while model storage holds trusted bytes
        # written by create().
        model = dill.loads(self.model_storage.file_get('model'))

        stored_args = self.model_storage.json_get('args')
        target = stored_args['target']

        # The model was fitted without the target column (see create), so it
        # must not be fed back in as a feature. errors='ignore' lets frames
        # that lack either column pass through instead of raising KeyError.
        features = df.drop(columns=['__mindsdb_row_id', target], errors='ignore')

        features[target] = model.predict(features)

        return features