Coverage for mindsdb / integrations / handlers / ludwig_handler / ludwig_handler.py: 0%
36 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional
3import dill
4import dask
5import pandas as pd
6from ludwig.automl import auto_train
8from mindsdb.integrations.libs.base import BaseMLEngine
9from .utils import RayConnection
class LudwigHandler(BaseMLEngine):
    """
    Integration with the Ludwig declarative ML library.

    `create` runs Ludwig AutoML (`auto_train`) inside a `RayConnection`
    context and stores the best model (serialized with `dill`) plus some
    metadata in `self.model_storage`. `predict` reloads that model and joins
    its predictions onto the input dataframe.
    """  # noqa

    name = 'ludwig'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """
        Train a model for `target` with Ludwig AutoML and persist it.

        Args:
            target: name of the column to predict.
            df: training data, passed to `auto_train` as `dataset`.
            args: problem definition; only its 'using' entry is kept, the rest
                is ignored. A missing `args` or 'using' entry is treated as
                an empty dict.

        Side effects:
            Writes 'args' (the 'using' values plus 'dtype_dict' and
            'accuracies') and 'model' (dill-serialized best model) to
            `self.model_storage`.
        """
        # Work on a copy so the caller's problem definition is not mutated,
        # and tolerate `args=None` (the declared default) instead of failing
        # with an opaque TypeError.
        using_args = dict((args or {}).get('using') or {})

        # TODO: filter out incompatible use cases (e.g. time series won't work currently)
        # TODO: enable custom values via `args` (mindful of local vs cloud)
        user_config = {'hyperopt': {'executor': {'gpu_resources_per_trial': 0, 'num_samples': 3}}}  # no GPU for now

        with RayConnection():
            # Other `auto_train` options (output_directory, random_seed,
            # use_reference_config, kwargs) are intentionally not passed.
            results = auto_train(
                dataset=df,
                target=target,
                tune_for_memory=False,
                time_limit_s=120,
                user_config=user_config,
            )
            # NOTE(review): result handling is kept inside the Ray context in
            # case resolving the best model needs the live connection — confirm.
            model = results.best_model
            using_args['dtype_dict'] = {
                feature['name']: feature['type'] for feature in model.base_config['input_features']
            }
            using_args['accuracies'] = {'metric': results.experiment_analysis.best_result['metric_score']}
            self.model_storage.json_set('args', using_args)
            self.model_storage.file_set('model', dill.dumps(model))

    def predict(self, df, args=None):
        """
        Load the stored model and return `df` joined with its predictions.

        Args:
            df: input dataframe to run the model on.
            args: accepted for interface compatibility; not used here.

        Returns:
            The dataframe produced by `_call_model`.
        """
        # NOTE(review): `dill.loads` executes arbitrary code from the payload;
        # this presumes model storage only ever holds blobs written by `create`.
        model = dill.loads(self.model_storage.file_get('model'))
        with RayConnection():
            return self._call_model(df, model)

    @staticmethod
    def _call_model(df, model):
        """
        Run `model.predict` on `df` and attach the result to the input rows.

        The output contains every input column, a column named after the
        target holding the predicted values, and an always-empty
        `<target>_explain` column. If the input already has a column named
        like the target, that input column is renamed to `<target>_original`.

        Args:
            df: input dataframe.
            model: object exposing `predict(df)` and a `config` dict with an
                'output_features' list (only the first output feature is used).

        Returns:
            `df` joined (on index) with the prediction columns.
        """
        predictions = dask.compute(model.predict(df)[0])[0]
        output_feature = model.config['output_features'][0]
        target_name = output_feature['column']

        # The column assignment below requires exactly one prediction column.
        # When the model returns several columns for the output feature, keep
        # only the predicted values instead of failing on a length mismatch.
        # NOTE(review): assumes Ludwig names that column
        # `<feature name>_predictions` — confirm for the installed version.
        predicted_column = f"{output_feature.get('name', target_name)}_predictions"
        if len(predictions.columns) > 1 and predicted_column in predictions.columns:
            predictions = predictions[[predicted_column]].copy()

        predictions.columns = [target_name]
        predictions[f'{target_name}_explain'] = None

        # Decide explicitly whether the input carries the target, rather than
        # using a 'prediction' placeholder column as a sentinel: an input
        # column that happens to be called 'prediction' would otherwise be
        # mistaken for the model output (or make the join fail on overlap).
        features = df
        if target_name in df:
            features = df.rename({target_name: f'{target_name}_original'}, axis=1)

        # NOTE(review): the join aligns on index, so it presumes the
        # predictions frame shares the input's index — confirm for inputs
        # with a non-default index.
        return features.join(predictions)