Tutorial - Implementing a custom mixer in Lightwood
Introduction
Mixers are the centerpiece of Lightwood. They are tasked with learning the mapping between the encoded feature representations and the target representation.
Objective
In this tutorial, we’ll implement an sklearn random forest as a mixer that handles categorical and binary targets.
Step 1: The Mixer Interface
The Mixer interface is defined by the BaseMixer class. A mixer needs methods for 4 tasks:
* fitting (fit)
* predicting (__call__)
* construction (__init__)
* partial fitting (partial_fit), though this one is optional
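To make that contract concrete, here is a minimal, self-contained sketch. MixerSketch and MajorityMixer are hypothetical names that only mirror the four tasks listed above; they are not Lightwood’s BaseMixer, whose real method signatures (for example fit(train_data, dev_data) taking EncodedDs objects) appear in the next step. For simplicity, datasets are assumed to be plain iterables of (features, target) pairs.

```python
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Iterable, List, Sequence, Tuple

# Hypothetical: each dataset is an iterable of (features, target) pairs.
Rows = Iterable[Tuple[Sequence[float], Any]]


class MixerSketch(ABC):
    """Hypothetical stand-in mirroring the four mixer tasks; NOT Lightwood's BaseMixer."""

    def __init__(self, stop_after: float):
        # Construction: keep the time budget (in seconds) the mixer may spend training.
        self.stop_after = stop_after

    @abstractmethod
    def fit(self, train_data: Rows, dev_data: Rows) -> None:
        """Fitting: learn the feature -> target mapping."""

    @abstractmethod
    def __call__(self, ds: Rows) -> List[Any]:
        """Predicting: return one prediction per row of `ds`."""

    def partial_fit(self, train_data: Rows, dev_data: Rows) -> None:
        # Partial fitting is optional; this default simply refuses.
        raise NotImplementedError(f'{type(self).__name__} does not support partial_fit')


class MajorityMixer(MixerSketch):
    """Toy mixer that always predicts the most frequent target seen while fitting."""

    def fit(self, train_data: Rows, dev_data: Rows) -> None:
        targets = [y for _, y in list(train_data) + list(dev_data)]
        self.majority = Counter(targets).most_common(1)[0][0]

    def __call__(self, ds: Rows) -> List[Any]:
        return [self.majority for _ in ds]


mixer = MajorityMixer(stop_after=10)
mixer.fit([([1.0], 'yes'), ([2.0], 'no')], [([3.0], 'yes')])
print(mixer([([4.0], None), ([5.0], None)]))  # ['yes', 'yes']
```

Leaving partial_fit as a method that raises, rather than an abstract one, reflects that it is optional: subclasses compile and run without it, but fail loudly if someone tries to use them for online training.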
Step 2: Writing our mixer
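I’m going to create a file called random_forest_mixer.py inside /etc/lightwood_modules, which is where Lightwood sources custom modules from. (In the notebook cell below, the %%writefile magic writes the file to the current working directory instead, and Step 3 loads it explicitly with load_custom_module.)
Inside it, I’m going to write the following code: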
%%writefile random_forest_mixer.py
from lightwood.mixer import BaseMixer
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs, ConcatedEncodedDs
from type_infer.dtype import dtype
from lightwood.encoder import BaseEncoder

import torch
import pandas as pd
from sklearn.ensemble import RandomForestClassifier


class RandomForestMixer(BaseMixer):
    clf: RandomForestClassifier

    def __init__(self, stop_after: int, dtype_dict: dict, target: str, target_encoder: BaseEncoder):
        super().__init__(stop_after)
        self.target_encoder = target_encoder

        # Throw in case someone tries to use this for a problem that's not classification.
        # It would fail anyway, but this way the error message is more intuitive
        if dtype_dict[target] not in (dtype.categorical, dtype.binary):
            raise Exception(f'This mixer can only be used for classification problems! Got target dtype {dtype_dict[target]} instead!')

        # We could also initialize this in `fit` if some of the parameters depend on the input data, since `fit` is called exactly once
        self.clf = RandomForestClassifier(max_depth=30)

    def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
        X, Y = [], []
        # By default, mixers get some train data and a bit of dev data on which to do early stopping or
        # hyperparameter optimization. This mixer doesn't need dev data, so we concatenate the two to get
        # more training data, then turn them into an sklearn-friendly format.
        for x, y in ConcatedEncodedDs([train_data, dev_data]):
            X.append(x.tolist())
            Y.append(y.tolist())
        self.clf.fit(X, Y)

    def __call__(self, ds: EncodedDs,
                 args: PredictionArguments = PredictionArguments()) -> pd.DataFrame:
        # Turn the data into an sklearn-friendly format
        X = []
        for x, _ in ds:
            X.append(x.tolist())

        Yh = self.clf.predict(X)

        # Lightwood encoders are meant to decode torch tensors, so we have to cast the predictions first
        decoded_predictions = self.target_encoder.decode(torch.Tensor(Yh))

        # Finally, turn the decoded predictions into a dataframe with a single column called `prediction`.
        # This is the standard behaviour all Lightwood mixers use
        ydf = pd.DataFrame({'prediction': decoded_predictions})
        return ydf

    # We'll skip implementing `partial_fit`, thus making this mixer unsuitable for online training tasks
Writing random_forest_mixer.py
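The mixer’s data handling follows a common pattern: chain the train and dev rows together, flatten every row into plain Python lists for an sklearn-style estimator, and decode the predicted target vectors back into labels. The sketch below reproduces that flow with stand-ins, so it runs without Lightwood: plain lists of (features, one-hot target) tuples instead of EncodedDs, a hypothetical 1-nearest-neighbour classifier instead of RandomForestClassifier, and an argmax decoder instead of the target encoder’s decode. The one-hot target layout is an assumption for illustration; the exact encoding Lightwood’s encoders produce may differ.

```python
from itertools import chain
from typing import List, Sequence, Tuple

# Hypothetical row format: (encoded feature vector, one-hot encoded target).
Row = Tuple[Sequence[float], Sequence[float]]


def to_xy(*datasets: Sequence[Row]) -> Tuple[List[List[float]], List[List[float]]]:
    # Chain the datasets end to end (the role ConcatedEncodedDs plays in the mixer)
    # and convert every row into plain lists, the "sklearn-friendly" format.
    X, Y = [], []
    for x, y in chain(*datasets):
        X.append(list(x))
        Y.append(list(y))
    return X, Y


class NearestNeighbour:
    """Stand-in for RandomForestClassifier: predicts the target of the closest training row."""

    def fit(self, X: List[List[float]], Y: List[List[float]]) -> "NearestNeighbour":
        self.rows = list(zip(X, Y))
        return self

    def predict(self, X: List[List[float]]) -> List[List[float]]:
        def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
            return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
        return [min(self.rows, key=lambda row: sq_dist(row[0], x))[1] for x in X]


def decode(vectors: List[List[float]], labels: List[str]) -> List[str]:
    # Stand-in for target_encoder.decode: map each vector to the label of its largest entry.
    return [labels[max(range(len(v)), key=v.__getitem__)] for v in vectors]


train = [([0.0, 0.0], [1.0, 0.0]), ([1.0, 1.0], [0.0, 1.0])]
dev = [([0.9, 0.8], [0.0, 1.0])]
X, Y = to_xy(train, dev)
clf = NearestNeighbour().fit(X, Y)
predictions = decode(clf.predict([[0.1, 0.2], [0.95, 0.9]]), labels=['0', '1'])
print({'prediction': predictions})  # {'prediction': ['0', '1']}
```

The final dict mirrors the single prediction column that the real mixer wraps in a pandas DataFrame.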
Step 3: Using our mixer
We’re going to use our mixer to diagnose heart disease, using this dataset: https://github.com/mindsdb/benchmarks/blob/main/benchmarks/datasets/heart_disease/data.csv
First, since we don’t want to bother writing a Json AI for this dataset from scratch, we’re going to let Lightwood auto-generate one.
from lightwood.api.high_level import ProblemDefinition, json_ai_from_problem, load_custom_module
import pandas as pd
# load the code
load_custom_module('random_forest_mixer.py')
# read dataset
df = pd.read_csv('https://raw.githubusercontent.com/mindsdb/benchmarks/main/benchmarks/datasets/heart_disease/data.csv')
# define the predictive task
pdef = ProblemDefinition.from_dict({
    'target': 'target',  # column you want to predict
})
# generate the Json AI intermediate representation from the data and its corresponding settings
json_ai = json_ai_from_problem(df, problem_definition=pdef)
# Print it (you can also put it in a file and edit it there)
print(json_ai.to_json())
INFO:lightwood-2627:No torchvision detected, image helpers not supported.
INFO:lightwood-2627:No torchvision/pillow detected, image encoder not supported
INFO:type_infer-2627:Analyzing a sample of 298
INFO:type_infer-2627:from a total population of 303, this is equivalent to 98.3% of your data.
INFO:type_infer-2627:Infering type for: age
INFO:type_infer-2627:Column age has data type integer
INFO:type_infer-2627:Infering type for: sex
INFO:type_infer-2627:Column sex has data type binary
INFO:type_infer-2627:Infering type for: cp
INFO:type_infer-2627:Column cp has data type categorical
INFO:type_infer-2627:Infering type for: trestbps
INFO:type_infer-2627:Column trestbps has data type integer
INFO:type_infer-2627:Infering type for: chol
INFO:type_infer-2627:Column chol has data type integer
INFO:type_infer-2627:Infering type for: fbs
INFO:type_infer-2627:Column fbs has data type binary
INFO:type_infer-2627:Infering type for: restecg
INFO:type_infer-2627:Column restecg has data type categorical
INFO:type_infer-2627:Infering type for: thalach
INFO:type_infer-2627:Column thalach has data type integer
INFO:type_infer-2627:Infering type for: exang
INFO:type_infer-2627:Column exang has data type binary
INFO:type_infer-2627:Infering type for: oldpeak
INFO:type_infer-2627:Column oldpeak has data type float
INFO:type_infer-2627:Infering type for: slope
INFO:type_infer-2627:Column slope has data type categorical
INFO:type_infer-2627:Infering type for: ca
INFO:type_infer-2627:Column ca has data type categorical
INFO:type_infer-2627:Infering type for: thal
INFO:type_infer-2627:Column thal has data type categorical
INFO:type_infer-2627:Infering type for: target
INFO:type_infer-2627:Column target has data type binary
INFO:dataprep_ml-2627:Starting statistical analysis
INFO:dataprep_ml-2627:Finished statistical analysis
{
    "encoders": {
        "target": {
            "module": "BinaryEncoder",
            "args": {
                "is_target": "True",
                "target_weights": "$statistical_analysis.target_weights"
            }
        },
        "age": {
            "module": "NumericEncoder",
            "args": {}
        },
        "sex": {
            "module": "BinaryEncoder",
            "args": {}
        },
        "cp": {
            "module": "OneHotEncoder",
            "args": {}
        },
        "trestbps": {
            "module": "NumericEncoder",
            "args": {}
        },
        "chol": {
            "module": "NumericEncoder",
            "args": {}
        },
        "fbs": {
            "module": "BinaryEncoder",
            "args": {}
        },
        "restecg": {
            "module": "OneHotEncoder",
            "args": {}
        },
        "thalach": {
            "module": "NumericEncoder",
            "args": {}
        },
        "exang": {
            "module": "BinaryEncoder",
            "args": {}
        },
        "oldpeak": {
            "module": "NumericEncoder",
            "args": {}
        },
        "slope": {
            "module": "OneHotEncoder",
            "args": {}
        },
        "ca": {
            "module": "OneHotEncoder",
            "args": {}
        },
        "thal": {
            "module": "OneHotEncoder",
            "args": {}
        }
    },
    "dtype_dict": {
        "age": "integer",
        "sex": "binary",
        "cp": "categorical",
        "trestbps": "integer",
        "chol": "integer",
        "fbs": "binary",
        "restecg": "categorical",
        "thalach": "integer",
        "exang": "binary",
        "oldpeak": "float",
        "slope": "categorical",
        "ca": "categorical",
        "thal": "categorical",
        "target": "binary"
    },
    "dependency_dict": {},
    "model": {
        "module": "BestOf",
        "args": {
            "submodels": [
                {
                    "module": "Neural",
                    "args": {
                        "fit_on_dev": true,
                        "stop_after": "$problem_definition.seconds_per_mixer",
                        "search_hyperparameters": true
                    }
                },
                {
                    "module": "XGBoostMixer",
                    "args": {
                        "stop_after": "$problem_definition.seconds_per_mixer",
                        "fit_on_dev": true
                    }
                },
                {
                    "module": "Regression",
                    "args": {
                        "stop_after": "$problem_definition.seconds_per_mixer"
                    }
                },
                {
                    "module": "RandomForest",
                    "args": {
                        "stop_after": "$problem_definition.seconds_per_mixer",
                        "fit_on_dev": true
                    }
                }
            ]
        }
    },
    "problem_definition": {
        "target": "target",
        "pct_invalid": 2,
        "unbias_target": true,
        "seconds_per_mixer": 42768.0,
        "seconds_per_encoder": null,
        "expected_additional_time": 0.06708312034606934,
        "time_aim": 259200,
        "target_weights": null,
        "positive_domain": false,
        "timeseries_settings": {
            "is_timeseries": false,
            "order_by": null,
            "window": null,
            "group_by": null,
            "use_previous_target": true,
            "horizon": null,
            "historical_columns": null,
            "target_type": "",
            "allow_incomplete_history": true,
            "eval_incomplete": false,
            "interval_periods": []
        },
        "anomaly_detection": false,
        "use_default_analysis": true,
        "embedding_only": false,
        "dtype_dict": {},
        "ignore_features": [],
        "fit_on_all": true,
        "strict_mode": true,
        "seed_nr": 1
    },
    "identifiers": {},
    "imputers": [],
    "accuracy_functions": [
        "balanced_accuracy_score"
    ]
}
Now we have to edit the model key of this Json AI (specifically, its list of submodels) to tell Lightwood to use our custom mixer. We can use it together with the others and have it ensembled with them at the end, or use it standalone. In this case, I’m going to replace all existing mixers with this one.
json_ai.model['args']['submodels'] = [{
    'module': 'random_forest_mixer.RandomForestMixer',
    'args': {
        'stop_after': '$problem_definition.seconds_per_mixer',
        'dtype_dict': '$dtype_dict',
        'target': '$target',
        'target_encoder': '$encoders[self.target]'
    }
}]
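The edit in the cell above is ordinary manipulation of a nested dictionary. As a self-contained sketch of the two options mentioned earlier (standalone versus alongside the default mixers), the code below works on a trimmed, hypothetical copy of the model section, keeping only two of the default submodels and only their stop_after argument, rather than on a real Json AI object.

```python
import copy
import json

# Trimmed copy of the "model" section printed above (only two default submodels kept).
model = json.loads('''
{"module": "BestOf",
 "args": {"submodels": [
    {"module": "Neural", "args": {"stop_after": "$problem_definition.seconds_per_mixer"}},
    {"module": "XGBoostMixer", "args": {"stop_after": "$problem_definition.seconds_per_mixer"}}
 ]}}
''')

custom = {
    'module': 'random_forest_mixer.RandomForestMixer',
    'args': {
        'stop_after': '$problem_definition.seconds_per_mixer',
        'dtype_dict': '$dtype_dict',
        'target': '$target',
        'target_encoder': '$encoders[self.target]',
    },
}

# Option 1 (what this tutorial does): the custom mixer replaces all the others.
standalone = copy.deepcopy(model)
standalone['args']['submodels'] = [custom]

# Option 2: keep the default mixers and add ours, so the ensemble (BestOf here) considers all of them.
ensembled = copy.deepcopy(model)
ensembled['args']['submodels'].append(custom)

print([m['module'] for m in standalone['args']['submodels']])
# ['random_forest_mixer.RandomForestMixer']
print([m['module'] for m in ensembled['args']['submodels']])
# ['Neural', 'XGBoostMixer', 'random_forest_mixer.RandomForestMixer']
```

Using copy.deepcopy keeps the two variants independent of each other and of the original section.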
Then we’ll generate some code, and finally turn that code into a predictor object and fit it on the original data.
from lightwood.api.high_level import code_from_json_ai, predictor_from_code
code = code_from_json_ai(json_ai)
predictor = predictor_from_code(code)
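Here, code_from_json_ai produces Python source code as a string, and predictor_from_code turns that source into a predictor object. We don’t reproduce Lightwood’s implementation here. As a generic, standard-library illustration of the second step only, the sketch below executes a source string inside a fresh module object and instantiates a class defined in it. The generated source, the helper object_from_code and the class name Predictor are all hypothetical; the real generated code is far longer.

```python
import types

# Toy "generated" source standing in for the output of a code generator.
generated_code = '''
class Predictor:
    def __init__(self):
        self.threshold = 50

    def predict(self, ages):
        return [int(age > self.threshold) for age in ages]
'''


def object_from_code(code: str, class_name: str = 'Predictor'):
    # Execute the source in the namespace of a new, empty module...
    module = types.ModuleType('generated_predictor')
    exec(compile(code, '<generated>', 'exec'), module.__dict__)
    # ...then look up the class it defined and instantiate it.
    return getattr(module, class_name)()


toy_predictor = object_from_code(generated_code)
print(toy_predictor.predict([63, 15]))  # [1, 0]
```

Because exec runs whatever the string contains, this pattern is only appropriate for code you generated yourself or have reviewed.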
predictor.learn(df)
INFO:dataprep_ml-2627:[Learn phase 1/8] - Statistical analysis
INFO:dataprep_ml-2627:Starting statistical analysis
INFO:dataprep_ml-2627:Finished statistical analysis
DEBUG:lightwood-2627: `analyze_data` runtime: 0.03 seconds
INFO:dataprep_ml-2627:[Learn phase 2/8] - Data preprocessing
INFO:dataprep_ml-2627:Cleaning the data
DEBUG:lightwood-2627: `preprocess` runtime: 0.01 seconds
INFO:dataprep_ml-2627:[Learn phase 3/8] - Data splitting
INFO:dataprep_ml-2627:Splitting the data into train/test
DEBUG:lightwood-2627: `split` runtime: 0.01 seconds
INFO:dataprep_ml-2627:[Learn phase 4/8] - Preparing encoders
DEBUG:dataprep_ml-2627:Preparing sequentially...
DEBUG:dataprep_ml-2627:Preparing encoder for age...
DEBUG:dataprep_ml-2627:Preparing encoder for sex...
DEBUG:dataprep_ml-2627:Preparing encoder for cp...
DEBUG:lightwood-2627:Encoding UNKNOWN categories as index 0
DEBUG:dataprep_ml-2627:Preparing encoder for trestbps...
DEBUG:dataprep_ml-2627:Preparing encoder for chol...
DEBUG:dataprep_ml-2627:Preparing encoder for fbs...
DEBUG:dataprep_ml-2627:Preparing encoder for restecg...
DEBUG:lightwood-2627:Encoding UNKNOWN categories as index 0
DEBUG:dataprep_ml-2627:Preparing encoder for thalach...
DEBUG:dataprep_ml-2627:Preparing encoder for exang...
DEBUG:dataprep_ml-2627:Preparing encoder for oldpeak...
DEBUG:dataprep_ml-2627:Preparing encoder for slope...
DEBUG:lightwood-2627:Encoding UNKNOWN categories as index 0
DEBUG:dataprep_ml-2627:Preparing encoder for ca...
DEBUG:lightwood-2627:Encoding UNKNOWN categories as index 0
DEBUG:dataprep_ml-2627:Preparing encoder for thal...
DEBUG:lightwood-2627:Encoding UNKNOWN categories as index 0
DEBUG:lightwood-2627: `prepare` runtime: 0.02 seconds
INFO:dataprep_ml-2627:[Learn phase 5/8] - Feature generation
INFO:dataprep_ml-2627:Featurizing the data
DEBUG:lightwood-2627: `featurize` runtime: 0.09 seconds
INFO:dataprep_ml-2627:[Learn phase 6/8] - Mixer training
INFO:dataprep_ml-2627:Training the mixers
DEBUG:lightwood-2627: `fit_mixer` runtime: 0.12 seconds
INFO:dataprep_ml-2627:Ensembling the mixer
INFO:lightwood-2627:Mixer: RandomForestMixer got accuracy: 0.798
INFO:lightwood-2627:Picked best mixer: RandomForestMixer
DEBUG:lightwood-2627: `fit` runtime: 0.13 seconds
INFO:dataprep_ml-2627:[Learn phase 7/8] - Ensemble analysis
INFO:dataprep_ml-2627:Analyzing the ensemble of mixers
INFO:lightwood-2627:The block ICP is now running its analyze() method
INFO:lightwood-2627:The block ConfStats is now running its analyze() method
INFO:lightwood-2627:The block AccStats is now running its analyze() method
INFO:lightwood-2627:The block PermutationFeatureImportance is now running its analyze() method
INFO:lightwood-2627:[PFI] Using a random sample (1000 rows out of 31).
INFO:lightwood-2627:[PFI] Set to consider first 10 columns out of 10: ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak'].
/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/sklearn/metrics/_classification.py:2480: UserWarning: y_pred contains classes not in y_true
warnings.warn("y_pred contains classes not in y_true")
DEBUG:lightwood-2627: `analyze_ensemble` runtime: 0.24 seconds
INFO:dataprep_ml-2627:[Learn phase 8/8] - Adjustment on validation requested
INFO:dataprep_ml-2627:Updating the mixers
DEBUG:lightwood-2627: `adjust` runtime: 0.04 seconds
DEBUG:lightwood-2627: `learn` runtime: 0.59 seconds
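The Json AI above lists balanced_accuracy_score under accuracy_functions, so the 0.798 reported for RandomForestMixer is presumably a balanced accuracy computed on held-out data. Balanced accuracy is the mean of the per-class recalls, which keeps a majority class from dominating the score. The sketch below is a plain-Python version written for illustration, not scikit-learn’s code. Like scikit-learn’s default, it averages only over classes present in y_true. Classes that appear only in the predictions have undefined recall; scikit-learn drops them and emits the "y_pred contains classes not in y_true" warning, which is likely the source of the warnings in the log above.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred):
    # Per-class recall = correct predictions of class c / number of rows whose true class is c.
    totals = defaultdict(int)
    hits = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    # Only classes present in y_true have a defined recall; classes that appear
    # only in y_pred are left out of the average.
    recalls = [hits[c] / totals[c] for c in totals]
    return sum(recalls) / len(recalls)


# Imbalanced toy labels: plain accuracy is 0.9, balanced accuracy is lower.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
print(balanced_accuracy(y_true, y_pred))  # 0.75
```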
Finally, we can use the trained predictor to make some predictions, or save it to a pickle file for later use.
predictions = predictor.predict(pd.DataFrame({
    'age': [63, 15, None],
    'sex': [1, 1, 0],
    'thal': [3, 1, 1]
}))
print(predictions)
predictor.save('my_custom_heart_disease_predictor.pickle')
INFO:dataprep_ml-2627:[Predict phase 1/4] - Data preprocessing
INFO:dataprep_ml-2627:Cleaning the data
DEBUG:lightwood-2627: `preprocess` runtime: 0.01 seconds
INFO:dataprep_ml-2627:[Predict phase 2/4] - Feature generation
INFO:dataprep_ml-2627:Featurizing the data
/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/numpy/lib/function_base.py:2455: RuntimeWarning: invalid value encountered in _none_fn (vectorized)
outputs = ufunc(*inputs)
DEBUG:lightwood-2627: `featurize` runtime: 0.02 seconds
INFO:dataprep_ml-2627:[Predict phase 3/4] - Calling ensemble
DEBUG:lightwood-2627: `_timed_call` runtime: 0.0 seconds
INFO:dataprep_ml-2627:[Predict phase 4/4] - Analyzing output
INFO:lightwood-2627:The block ICP is now running its explain() method
INFO:lightwood-2627:The block ConfStats is now running its explain() method
INFO:lightwood-2627:ConfStats.explain() has not been implemented, no modifications will be done to the data insights.
INFO:lightwood-2627:The block AccStats is now running its explain() method
INFO:lightwood-2627:AccStats.explain() has not been implemented, no modifications will be done to the data insights.
INFO:lightwood-2627:The block PermutationFeatureImportance is now running its explain() method
INFO:lightwood-2627:PermutationFeatureImportance.explain() has not been implemented, no modifications will be done to the data insights.
DEBUG:lightwood-2627: `explain` runtime: 0.01 seconds
DEBUG:lightwood-2627: `predict` runtime: 0.05 seconds
   original_index  prediction  confidence
0               0           1    0.073676
1               1           0    0.250612
2               2           0    0.462595
That’s it: that is all it takes to solve a predictive problem with Lightwood using your own custom mixer.