Coverage for mindsdb / integrations / handlers / tpot_handler / tpot_handler.py: 0%
43 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import dill
2import pandas as pd
3from mindsdb.integrations.libs.base import BaseMLEngine
4from typing import Dict, Optional
5from type_infer.api import infer_types
6from tpot import TPOTClassifier, TPOTRegressor
7from sklearn.preprocessing import LabelEncoder
class TPOTHandler(BaseMLEngine):
    """MindsDB ML engine that trains and serves TPOT AutoML pipelines.

    The task type (classification vs. regression) is chosen from the target
    column's dtype as reported by ``type_infer``. Columns inferred as
    'categorical' or 'binary' are label-encoded before training, and the fitted
    encoders are persisted so the same encoding is applied at prediction time.
    """

    name = "TPOT"

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Train a TPOT pipeline on ``df`` and persist it to model storage.

        Args:
            target: Name of the column to predict.
            df: Training data; must contain ``target``.
            args: Optional settings. ``generations``, ``population_size``,
                ``max_time_mins`` and ``n_jobs`` are read from its top level.
                A copy of it (with ``target`` added) is stored under 'args'.

        Raises:
            Exception: If ``df`` is missing or empty, or if the target column's
                inferred dtype is neither classification- nor regression-like.
        """
        # Validate before type inference, which cannot work on a missing/empty frame.
        if df is None or df.empty:
            raise Exception(
                "Data is empty!!"
            )

        # Record the target so predict() can name (and decode) its output column,
        # without mutating the caller's dict.
        args = {**(args or {}), 'target': target}
        # Encoding below assigns columns; copy so the caller's frame is untouched.
        df = df.copy()

        type_of_cols = infer_types(df, 0).dtypes
        target_dtype = type_of_cols[target]

        # NOTE(review): hyperparameters are read from the top level of ``args``;
        # confirm this matches how the caller passes user USING options.
        tpot_params = dict(
            generations=args.get('generations', 10),
            population_size=args.get('population_size', 100),
            verbosity=0,
            max_time_mins=args.get('max_time_mins', None),
            n_jobs=args.get('n_jobs', -1),
        )
        if target_dtype in ('binary', 'categorical', 'tags'):
            model = TPOTClassifier(**tpot_params)
        elif target_dtype in ('integer', 'float', 'quantity'):
            model = TPOTRegressor(**tpot_params)
        else:
            # Without this branch ``model`` would be unbound below.
            raise Exception(
                f"Unsupported dtype '{target_dtype}' for target column '{target}'"
            )

        # Label-encode categorical/binary columns (including the target when it is
        # one) and keep each fitted encoder so predict() can reuse and invert it.
        # NOTE(review): other non-numeric dtypes (e.g. 'tags') are passed through
        # unencoded and may not be accepted by TPOT.
        le_dict = {}
        for col, type_col in type_of_cols.items():
            if type_col not in ('categorical', 'binary'):
                continue
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            le_dict[col] = le

        model.fit(df.drop(columns=[target]), df[target])
        self.model_storage.json_set('args', args)
        self.model_storage.file_set('le_dict', dill.dumps(le_dict))
        self.model_storage.file_set('model', dill.dumps(model.fitted_pipeline_))

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Predict the target for ``df`` with the persisted TPOT pipeline.

        Args:
            df: Rows to score. A column named like the target, if present, is
                ignored, because the pipeline was trained without it. ``df``
                itself is not modified.
            args: Unused; accepted for interface compatibility.

        Returns:
            A single-column DataFrame named after the target. If the target was
            label-encoded during training, predictions are decoded back to the
            original labels.

        Raises:
            ValueError: From ``LabelEncoder.transform`` if a categorical column
                contains a label not seen during training.
        """
        # NOTE(review): dill.loads can execute arbitrary code; acceptable only
        # because these blobs are written by create() into this engine's storage.
        model = dill.loads(self.model_storage.file_get("model"))
        le_dict = dill.loads(self.model_storage.file_get("le_dict"))
        target = self.model_storage.json_get('args').get("target")

        # The pipeline was fit on the frame minus the target column, so drop it
        # here (if present). drop() returns a new frame, so encoding below does
        # not mutate the caller's df.
        features = df.drop(columns=[target], errors='ignore')

        # Encode categorical columns with the encoders fitted during training.
        for col, le in le_dict.items():
            if col in features.columns:
                features[col] = le.transform(features[col])

        results = pd.DataFrame(model.predict(features), columns=[target])

        # Decode the target back to its original labels when it was label-encoded.
        if target in le_dict:
            results[target] = le_dict[target].inverse_transform(results[target])
        return results