Coverage for mindsdb / integrations / handlers / ludwig_handler / ludwig_handler.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional 

2 

3import dill 

4import dask 

5import pandas as pd 

6from ludwig.automl import auto_train 

7 

8from mindsdb.integrations.libs.base import BaseMLEngine 

9from .utils import RayConnection 

10 

11 

class LudwigHandler(BaseMLEngine):
    """
    Integration with the Ludwig declarative ML library.

    Training runs Ludwig's AutoML search (``auto_train``) inside a
    ``RayConnection`` context. The best model found is serialized with ``dill``
    into model storage, alongside the ``USING`` arguments, the inferred input
    feature types and the best trial's metric score.
    """  # noqa

    name = 'ludwig'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[dict] = None) -> None:
        """Train a model with Ludwig AutoML and persist it to model storage.

        :param target: name of the column to predict.
        :param df: training data passed to ``auto_train``.
        :param args: problem definition; only its ``'using'`` sub-dict is kept.
        """
        # Keep only the USING clause (the rest of the problem definition is ignored).
        # Copy it so the keys added below do not leak into the caller's dict,
        # and tolerate a missing/None `args` instead of crashing on the default.
        args = dict((args or {}).get('using') or {})

        # TODO: filter out incompatible use cases (e.g. time series won't work currently)
        # TODO: enable custom values via `args` (mindful of local vs cloud)
        user_config = {'hyperopt': {'executor': {'gpu_resources_per_trial': 0, 'num_samples': 3}}}  # no GPU for now

        with RayConnection():
            results = auto_train(
                dataset=df,
                target=target,
                tune_for_memory=False,
                time_limit_s=120,
                user_config=user_config,
            )

        model = results.best_model
        # Record the type Ludwig assigned to each input feature.
        args['dtype_dict'] = {f['name']: f['type'] for f in model.base_config['input_features']}
        # Score of the best trial from the hyperparameter search.
        args['accuracies'] = {'metric': results.experiment_analysis.best_result['metric_score']}
        self.model_storage.json_set('args', args)
        self.model_storage.file_set('model', dill.dumps(model))

    def predict(self, df, args=None):
        """Load the stored model and return ``df`` joined with its predictions.

        :param df: rows to predict on.
        :param args: unused; accepted for interface compatibility.
        """
        # NOTE(security): dill.loads can execute arbitrary code. The blob must only
        # ever come from this handler's own model storage, never from user data.
        model = dill.loads(self.model_storage.file_get('model'))
        with RayConnection():
            return self._call_model(df, model)

    @staticmethod
    def _call_model(df, model):
        """Run ``model`` on ``df`` and join the predictions onto ``df``.

        The model output ends up in a column named after the target. If ``df``
        already contains a column with the target's name, that input column is
        kept as ``<target>_original``. A ``<target>_explain`` column filled with
        ``None`` is appended as well.
        """
        # The first element of model.predict's result holds the predictions;
        # dask.compute materializes it (presumably in case it is a lazy frame).
        predictions = dask.compute(model.predict(df)[0])[0]
        target_name = model.config['output_features'][0]['column']

        target_in_input = target_name in df
        # When the input already carries the target column, use a temporary name
        # so the join does not clash with it; it is renamed after the join.
        predictions.columns = ['prediction'] if target_in_input else [target_name]

        predictions[f'{target_name}_explain'] = None
        joined = df.join(predictions)

        # Branch on the flag, not on `'prediction' in joined`: an input column that
        # happens to be called 'prediction' must not be mistaken for model output.
        if target_in_input:
            joined = joined.rename({
                target_name: f'{target_name}_original',
                'prediction': target_name,
            }, axis=1)
        return joined