VCNet Post-Hoc Examples (with scikit-learn Classifier)
In this page, we illustrate how to use PHVCNet
to generate counterfactuals for a shallow classifier,
taking a scikit-learn classifier as our example.
Warning
The classifier must be probabilistic, i.e., able to output class probabilities.
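In scikit-learn, a probabilistic classifier is one that exposes a predict_proba() method. A quick way to check a candidate estimator, using plain scikit-learn (independent of this library):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# RandomForestClassifier is probabilistic: it exposes predict_proba().
print(hasattr(RandomForestClassifier(), "predict_proba"))  # True

# SVC only provides predict_proba() when constructed with probability=True.
print(hasattr(SVC(), "predict_proba"))                     # False
print(hasattr(SVC(probability=True), "predict_proba"))     # True
```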
We use the same dataset as in the VCNet example: the Adult dataset.
This time, we need to import two classes from our library:
from vcnet import PHVCNet
from vcnet import SKLearnClassifier
PHVCNet
is the implementation of the counterfactual generator.
SKLearnClassifier
is a wrapper for scikit-learn classifiers. It is configured through a dictionary specifying the classifier to use.
Preparing the Dataset
The data preparation is exactly the same as for VCNet. Please refer to the previous example.
In short, you create a DataCatalog
from the dataset settings and then prepare the data:
dataset = DataCatalog(dataset_settings)
dataset_settings = dataset.prepare_data(df)
Preparing and Training the Classification Model
Now, let’s define the classification model and train it using the SKLearnClassifier
class.
You can use the following code as a template:
hp = {
"dataset": dataset_settings,
"classifier_params": {
"skname": "RandomForestClassifier",
"kwargs": {
"n_estimators": 50,
}
}
}
train_loader = dataset.train_dataloader()
test_loader = dataset.test_dataloader()
classifier = SKLearnClassifier(hp)
classifier.fit(dataset.df_train)
In this code, the hyperparameters of the classifier are defined in classifier_params. It specifies the scikit-learn classifier class name and its kwargs.
In this example, we use a random forest with 50 trees.
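To make the classifier_params convention concrete, here is a sketch (our own illustration, not the library's actual code) of how such a name-plus-kwargs dictionary can be resolved into a scikit-learn estimator. The lookup in sklearn.ensemble is an assumption for this example; SKLearnClassifier may resolve class names differently.

```python
import sklearn.ensemble

classifier_params = {
    "skname": "RandomForestClassifier",
    "kwargs": {"n_estimators": 50},
}

# Look up the estimator class by name and instantiate it with the kwargs.
cls = getattr(sklearn.ensemble, classifier_params["skname"])
model = cls(**classifier_params["kwargs"])

print(type(model).__name__)  # RandomForestClassifier
print(model.n_estimators)    # 50
```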
Once the classifier has been defined through the hyperparameters passed to SKLearnClassifier
, the fit()
method trains it on the training dataset.
At this stage, you have a trained classifier, but not yet the counterfactual generator.
Note
We also provide a wrapper for XGBoost classifiers (XGBoostClassifier).
Training the Counterfactual Generator Model
The final step is to train the counterfactual generator model post-hoc.
This process is very similar to the VCNet example, with the key difference being that creating an instance of PHVCNet
requires a trained classifier as a parameter.
import lightning as L
hp = {
"dataset": dataset_settings,
"classifier_params": hp['classifier_params'],
"vcnet_params": {
"lr": 1e-2,
"epochs": 10,
"lambda_KLD": 0.5,
"lambda_BCE": 1,
"latent_size": 16,
"latent_size_share": 64,
"mid_reduce_size": 32
}
}
# Define the post-hoc VCNet model
vcnet = PHVCNet(hp, classifier)
# Finally, fit it using a Lightning Trainer
trainer = L.Trainer(max_epochs=hp['vcnet_params']['epochs'])
trainer.fit(model=vcnet, train_dataloaders=train_loader)
Once your counterfactual generation model has been trained, it can be used in the same way as other VCNet models (see the other examples).