Tutorial - Implementing a custom mixer in Lightwood
Introduction
Mixers are the centerpiece of lightwood, tasked with learning the mapping between the encoded feature and target representations.
Objective
In this tutorial we’ll implement an sklearn random forest as a mixer that handles categorical and binary targets.
Step 1: The Mixer Interface
The Mixer interface is defined by the BaseMixer class. A mixer needs methods for 4 tasks:
* fitting (fit)
* predicting (__call__)
* construction (__init__)
* partial fitting (partial_fit), though this one is optional
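The four tasks above can be sketched without lightwood installed. This is only an illustration: `SketchBaseMixer` and `MajorityMixer` are hypothetical names, not lightwood's `BaseMixer`, and the no-op `partial_fit` is an assumption about how an optional method might be handled.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Sequence


class SketchBaseMixer(ABC):
    """Hypothetical stand-in for a mixer base class; not lightwood's BaseMixer."""

    def __init__(self, stop_after: float):
        # construction: store the time budget (in seconds) for training
        self.stop_after = stop_after

    @abstractmethod
    def fit(self, train_data: Sequence[Any], dev_data: Sequence[Any]) -> None:
        """fitting: learn from the encoded training data."""

    @abstractmethod
    def __call__(self, ds: Sequence[Any]) -> List[Any]:
        """predicting: return one prediction per row of `ds`."""

    def partial_fit(self, train_data: Sequence[Any], dev_data: Sequence[Any]) -> None:
        """partial fitting: optional, so the default here does nothing."""


class MajorityMixer(SketchBaseMixer):
    """Toy mixer that always predicts the most frequent training label."""

    def fit(self, train_data, dev_data):
        # Each row is a (features, label) pair; only the labels are used here
        labels = [y for _, y in list(train_data) + list(dev_data)]
        self.majority = max(set(labels), key=labels.count)

    def __call__(self, ds):
        return [self.majority for _ in ds]


mixer = MajorityMixer(stop_after=10)
mixer.fit([([0.1], 'yes'), ([0.5], 'no'), ([0.9], 'yes')], [([0.3], 'yes')])
predictions = mixer([([0.2], None), ([0.8], None)])
print(predictions)  # ['yes', 'yes']
```

A subclass only has to provide `fit` and `__call__`; leaving `partial_fit` untouched mirrors the "optional" status described above.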
Step 2: Writing our mixer
I’m going to create a file called random_forest_mixer.py inside /etc/lightwood_modules, which is where lightwood sources custom modules from.
Inside of it I’m going to write the following code:
[1]:
from lightwood.mixer import BaseMixer
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs, ConcatedEncodedDs
from lightwood import dtype
from lightwood.encoder import BaseEncoder
import torch
import pandas as pd
from sklearn.ensemble import RandomForestClassifier


class RandomForestMixer(BaseMixer):
    clf: RandomForestClassifier

    def __init__(self, stop_after: int, dtype_dict: dict, target: str, target_encoder: BaseEncoder):
        super().__init__(stop_after)
        self.target_encoder = target_encoder

        # Throw in case someone tries to use this for a problem that's not classification.
        # It'd fail anyway, but this way the error message is more intuitive
        if dtype_dict[target] not in (dtype.categorical, dtype.binary):
            raise Exception(f'This mixer can only be used for classification problems! Got target dtype {dtype_dict[target]} instead!')

        # We could also initialize this in `fit` if some of the parameters depend on the input data,
        # since `fit` is called exactly once
        self.clf = RandomForestClassifier(max_depth=30)

    def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
        X, Y = [], []
        # By default mixers get some train data and a bit of dev data on which to do early stopping
        # or hyperparameter optimization. For this mixer, we don't need dev data, so we're going to
        # concat the two in order to get more training data. Then, we're going to turn them into an
        # sklearn-friendly format.
        for x, y in ConcatedEncodedDs([train_data, dev_data]):
            X.append(x.tolist())
            Y.append(y.tolist())
        self.clf.fit(X, Y)

    def __call__(self, ds: EncodedDs,
                 args: PredictionArguments = PredictionArguments()) -> pd.DataFrame:
        # Turn the data into an sklearn-friendly format
        X = []
        for x, _ in ds:
            X.append(x.tolist())

        Yh = self.clf.predict(X)

        # Lightwood encoders are meant to decode torch tensors, so we have to cast the predictions first
        decoded_predictions = self.target_encoder.decode(torch.Tensor(Yh))

        # Finally, turn the decoded predictions into a dataframe with a single column called `prediction`.
        # This is the standard behaviour all lightwood mixers use
        ydf = pd.DataFrame({'prediction': decoded_predictions})

        return ydf

    # We'll skip implementing `partial_fit`, thus making this mixer unsuitable for online training tasks
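The conversion step inside `fit` (concatenate train and dev data, then split each row into feature and target lists) can be tried in isolation. The sketch below is an approximation: plain Python sequences stand in for the encoded tensors, `itertools.chain` stands in for `ConcatedEncodedDs`, and both the helper name `to_sklearn_format` and the toy values are hypothetical.

```python
from itertools import chain
from typing import Iterable, List, Sequence, Tuple

# One row is a pair: (encoded features, encoded target)
Row = Tuple[Sequence[float], Sequence[float]]


def to_sklearn_format(*datasets: Iterable[Row]) -> Tuple[List[List[float]], List[List[float]]]:
    """Concatenate the datasets and split them into feature rows X and target rows Y."""
    X, Y = [], []
    for x, y in chain(*datasets):
        X.append(list(x))
        Y.append(list(y))
    return X, Y


# Toy stand-ins for the encoded train and dev sets (made-up values)
train = [([0.1, 1.0], [1.0, 0.0]), ([0.7, 0.0], [0.0, 1.0])]
dev = [([0.4, 1.0], [0.0, 1.0])]

X, Y = to_sklearn_format(train, dev)
print(len(X), len(Y))  # 3 3
```

Nested lists of numbers like `X` and `Y` are the kind of input sklearn estimators accept in `fit`, which is why the mixer builds them before calling the classifier.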
Step 3: Using our mixer
We’re going to use our mixer for diagnosing heart disease using this dataset: https://github.com/mindsdb/benchmarks/blob/main/benchmarks/datasets/heart_disease/data.csv
First, since we don’t want to bother writing a Json AI for this dataset from scratch, we’re going to let lightwood auto-generate one.
[2]:
from lightwood.api.high_level import ProblemDefinition, json_ai_from_problem
import pandas as pd
# read dataset
df = pd.read_csv('https://raw.githubusercontent.com/mindsdb/benchmarks/main/benchmarks/datasets/heart_disease/data.csv')
# define the predictive task
pdef = ProblemDefinition.from_dict({
    'target': 'target',  # column you want to predict
})
# generate the Json AI intermediate representation from the data and its corresponding settings
json_ai = json_ai_from_problem(df, problem_definition=pdef)
# Print it (you can also put it in a file and edit it there)
print(json_ai.to_json())
INFO:lightwood-56096:Dropping features: []
INFO:lightwood-56096:Analyzing a sample of 298
INFO:lightwood-56096:from a total population of 303, this is equivalent to 98.3% of your data.
INFO:lightwood-56096:Using 15 processes to deduct types.
INFO:lightwood-56096:Infering type for: age
INFO:lightwood-56096:Infering type for: sex
INFO:lightwood-56096:Infering type for: cp
INFO:lightwood-56096:Infering type for: trestbps
INFO:lightwood-56096:Infering type for: fbs
INFO:lightwood-56096:Infering type for: chol
INFO:lightwood-56096:Infering type for: thalach
INFO:lightwood-56096:Infering type for: restecg
INFO:lightwood-56096:Infering type for: exang
INFO:lightwood-56096:Infering type for: ca
INFO:lightwood-56096:Infering type for: slope
INFO:lightwood-56096:Infering type for: thal
INFO:lightwood-56096:Column age has data type integer
INFO:lightwood-56096:Infering type for: target
INFO:lightwood-56096:Column sex has data type binary
INFO:lightwood-56096:Column fbs has data type binary
INFO:lightwood-56096:Column cp has data type categorical
INFO:lightwood-56096:Infering type for: oldpeak
INFO:lightwood-56096:Column trestbps has data type integer
INFO:lightwood-56096:Column chol has data type integer
INFO:lightwood-56096:Column thalach has data type integer
INFO:lightwood-56096:Column restecg has data type categorical
INFO:lightwood-56096:Column exang has data type binary
INFO:lightwood-56096:Column ca has data type categorical
INFO:lightwood-56096:Column slope has data type categorical
INFO:lightwood-56096:Column thal has data type categorical
INFO:lightwood-56096:Column target has data type binary
INFO:lightwood-56096:Column oldpeak has data type float
INFO:lightwood-56096:Starting statistical analysis
INFO:lightwood-56096:Finished statistical analysis
random_forest_mixer.py
random_forest_mixer
{
    "features": {
        "age": {
            "encoder": {
                "module": "Integer.NumericEncoder",
                "args": {}
            }
        },
        "sex": {
            "encoder": {
                "module": "Binary.BinaryEncoder",
                "args": {}
            }
        },
        "cp": {
            "encoder": {
                "module": "Categorical.OneHotEncoder",
                "args": {}
            }
        },
        "trestbps": {
            "encoder": {
                "module": "Integer.NumericEncoder",
                "args": {}
            }
        },
        "chol": {
            "encoder": {
                "module": "Integer.NumericEncoder",
                "args": {}
            }
        },
        "fbs": {
            "encoder": {
                "module": "Binary.BinaryEncoder",
                "args": {}
            }
        },
        "restecg": {
            "encoder": {
                "module": "Categorical.OneHotEncoder",
                "args": {}
            }
        },
        "thalach": {
            "encoder": {
                "module": "Integer.NumericEncoder",
                "args": {}
            }
        },
        "exang": {
            "encoder": {
                "module": "Binary.BinaryEncoder",
                "args": {}
            }
        },
        "oldpeak": {
            "encoder": {
                "module": "Float.NumericEncoder",
                "args": {}
            }
        },
        "slope": {
            "encoder": {
                "module": "Categorical.OneHotEncoder",
                "args": {}
            }
        },
        "ca": {
            "encoder": {
                "module": "Categorical.OneHotEncoder",
                "args": {}
            }
        },
        "thal": {
            "encoder": {
                "module": "Categorical.OneHotEncoder",
                "args": {}
            }
        }
    },
    "outputs": {
        "target": {
            "data_dtype": "binary",
            "encoder": {
                "module": "Binary.BinaryEncoder",
                "args": {
                    "is_target": "True",
                    "target_class_distribution": "$statistical_analysis.target_class_distribution"
                }
            },
            "mixers": [
                {
                    "module": "Neural",
                    "args": {
                        "fit_on_dev": true,
                        "stop_after": "$problem_definition.seconds_per_mixer",
                        "search_hyperparameters": true
                    }
                },
                {
                    "module": "LightGBM",
                    "args": {
                        "stop_after": "$problem_definition.seconds_per_mixer",
                        "fit_on_dev": true
                    }
                },
                {
                    "module": "Regression",
                    "args": {
                        "stop_after": "$problem_definition.seconds_per_mixer"
                    }
                }
            ],
            "ensemble": {
                "module": "BestOf",
                "args": {
                    "args": "$pred_args",
                    "accuracy_functions": "$accuracy_functions",
                    "ts_analysis": null
                }
            }
        }
    },
    "problem_definition": {
        "target": "target",
        "pct_invalid": 2,
        "unbias_target": true,
        "seconds_per_mixer": 2364,
        "seconds_per_encoder": 0,
        "time_aim": 10642.1306731291,
        "target_weights": null,
        "positive_domain": false,
        "timeseries_settings": {
            "is_timeseries": false,
            "order_by": null,
            "window": null,
            "group_by": null,
            "use_previous_target": true,
            "nr_predictions": null,
            "historical_columns": null,
            "target_type": "",
            "allow_incomplete_history": false
        },
        "anomaly_detection": true,
        "ignore_features": [],
        "fit_on_validation": true,
        "strict_mode": true,
        "seed_nr": 420
    },
    "identifiers": {},
    "accuracy_functions": [
        "balanced_accuracy_score"
    ]
}
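Since the printed Json AI is ordinary JSON, it can be inspected with the standard library once it has been saved or copied. The sketch below works on a trimmed-down copy of the output above; it keeps only the `mixers` list and shortens the `args`, so it is an excerpt for illustration rather than the full document.

```python
import json

# Trimmed-down copy of the Json AI printed above (only the mixers are kept, args shortened)
json_ai_text = """
{
    "outputs": {
        "target": {
            "data_dtype": "binary",
            "mixers": [
                {"module": "Neural", "args": {"fit_on_dev": true}},
                {"module": "LightGBM", "args": {"fit_on_dev": true}},
                {"module": "Regression", "args": {}}
            ]
        }
    }
}
"""

parsed = json.loads(json_ai_text)

# Collect the module name of every mixer configured for the target column
mixer_modules = [m["module"] for m in parsed["outputs"]["target"]["mixers"]]
print(mixer_modules)  # ['Neural', 'LightGBM', 'Regression']
```

These three entries are the defaults lightwood generated for this dataset.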
Now we have to edit the mixers key of this Json AI to tell lightwood to use our custom mixer. We can use it together with the others and have it ensembled with them at the end, or use it standalone. In this case I’m going to replace all existing mixers with this one.
[3]:
json_ai.outputs['target'].mixers = [{
    'module': 'random_forest_mixer.RandomForestMixer',
    'args': {
        'stop_after': '$problem_definition.seconds_per_mixer',
        'dtype_dict': '$dtype_dict',
        'target': '$target',
        'target_encoder': '$encoders[self.target]'
    }
}]
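The cell above takes the standalone route. To illustrate the other option (keeping the default mixers and adding ours so they are all ensembled), here is a sketch on a plain Python list standing in for `json_ai.outputs['target'].mixers`. The default entries have their `args` trimmed, and whether simple list concatenation is all a real Json AI object needs is an assumption.

```python
from typing import Any, Dict, List

Mixer = Dict[str, Any]

# Plain-list stand-in for the auto-generated `mixers` entry (args trimmed for brevity)
default_mixers: List[Mixer] = [
    {'module': 'Neural', 'args': {}},
    {'module': 'LightGBM', 'args': {}},
    {'module': 'Regression', 'args': {}},
]

custom_mixer: Mixer = {
    'module': 'random_forest_mixer.RandomForestMixer',
    'args': {
        'stop_after': '$problem_definition.seconds_per_mixer',
        'dtype_dict': '$dtype_dict',
        'target': '$target',
        'target_encoder': '$encoders[self.target]',
    },
}

# Option 1: standalone, as done in this tutorial
standalone = [custom_mixer]

# Option 2: keep the defaults and add ours, so the ensemble can compare all of them
combined = default_mixers + [custom_mixer]

print([m['module'] for m in combined])
```

With the second option the ensemble step (BestOf in the generated Json AI) would have four candidates to choose from instead of one.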
Then we’ll generate some code, and finally turn that code into a predictor object and fit it on the original data.
[4]:
from lightwood.api.high_level import code_from_json_ai, predictor_from_code
code = code_from_json_ai(json_ai)
predictor = predictor_from_code(code)
random_forest_mixer.py
random_forest_mixer
[5]:
predictor.learn(df)
INFO:lightwood-56096:Dropping features: []
INFO:lightwood-56096:Performing statistical analysis on data
INFO:lightwood-56096:Starting statistical analysis
INFO:lightwood-56096:Finished statistical analysis
INFO:lightwood-56096:Cleaning the data
INFO:lightwood-56096:Splitting the data into train/test
INFO:lightwood-56096:Preparing the encoders
INFO:lightwood-56096:Encoder prepping dict length of: 1
INFO:lightwood-56096:Encoder prepping dict length of: 2
INFO:lightwood-56096:Encoder prepping dict length of: 3
INFO:lightwood-56096:Encoder prepping dict length of: 4
INFO:lightwood-56096:Encoder prepping dict length of: 5
INFO:lightwood-56096:Encoder prepping dict length of: 6
INFO:lightwood-56096:Encoder prepping dict length of: 7
INFO:lightwood-56096:Encoder prepping dict length of: 8
INFO:lightwood-56096:Encoder prepping dict length of: 9
INFO:lightwood-56096:Encoder prepping dict length of: 10
INFO:lightwood-56096:Encoder prepping dict length of: 11
INFO:lightwood-56096:Encoder prepping dict length of: 12
INFO:lightwood-56096:Encoder prepping dict length of: 13
INFO:lightwood-56096:Encoder prepping dict length of: 14
INFO:lightwood-56096:Done running for: target
INFO:lightwood-56096:Done running for: age
INFO:lightwood-56096:Done running for: sex
INFO:lightwood-56096:Done running for: cp
INFO:lightwood-56096:Done running for: trestbps
INFO:lightwood-56096:Done running for: chol
INFO:lightwood-56096:Done running for: fbs
INFO:lightwood-56096:Done running for: restecg
INFO:lightwood-56096:Done running for: thalach
INFO:lightwood-56096:Done running for: exang
INFO:lightwood-56096:Done running for: oldpeak
INFO:lightwood-56096:Done running for: slope
INFO:lightwood-56096:Done running for: ca
INFO:lightwood-56096:Done running for: thal
INFO:lightwood-56096:Featurizing the data
INFO:lightwood-56096:Training the mixers
INFO:lightwood-56096:Ensembling the mixer
INFO:lightwood-56096:Mixer: RandomForestMixer got accuracy: 0.8149038461538461
INFO:lightwood-56096:Picked best mixer: RandomForestMixer
INFO:lightwood-56096:Analyzing the ensemble of mixers
INFO:lightwood-56096:Adjustment on validation requested.
INFO:lightwood-56096:Updating the mixers
Finally, we can use the trained predictor to make some predictions, or save it to a pickle for later use.
[6]:
predictions = predictor.predict(pd.DataFrame({
    'age': [63, 15, None],
    'sex': [1, 1, 0],
    'thal': [3, 1, 1]
}))
print(predictions)
predictor.save('my_custom_heart_disease_predictor.pickle')
INFO:lightwood-56096:Dropping features: []
INFO:lightwood-56096:Cleaning the data
INFO:lightwood-56096:AccStats.explain() has not been implemented, no modifications will be done to the data insights.
INFO:lightwood-56096:GlobalFeatureImportance.explain() has not been implemented, no modifications will be done to the data insights.
   prediction truth  confidence
0           0  None        0.95
1           0  None        0.94
2           1  None        0.97
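Each prediction comes with a confidence value, so a natural follow-up is to keep only the rows the predictor is most sure about. The sketch below copies the three rows printed above into plain dictionaries; the `MIN_CONFIDENCE` cut-off is a hypothetical value chosen only for illustration.

```python
# Rows copied from the printed predictions above
rows = [
    {'prediction': 0, 'truth': None, 'confidence': 0.95},
    {'prediction': 0, 'truth': None, 'confidence': 0.94},
    {'prediction': 1, 'truth': None, 'confidence': 0.97},
]

MIN_CONFIDENCE = 0.95  # hypothetical cut-off, not a lightwood setting

# Keep only the predictions at or above the cut-off
confident = [r for r in rows if r['confidence'] >= MIN_CONFIDENCE]
for r in confident:
    print(r['prediction'], r['confidence'])
```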
That’s all it takes to solve a predictive problem with lightwood using your own custom mixer.