VCNet Post-Hoc Examples (with scikit-learn Classifier)

In this page, we illustrate how to use PHVCNet to generate counterfactuals for a shallow classifier, taking a scikit-learn classifier as an example.

Warning

The classifier must be probabilistic, i.e., able to output class probabilities rather than only hard labels.
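For scikit-learn estimators, being probabilistic amounts to implementing predict_proba(). A quick sanity check (a sketch, not part of the library):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC

# A random forest exposes class probabilities...
has_proba_rf = hasattr(RandomForestClassifier(), "predict_proba")   # True
# ...while a plain LinearSVC does not, so it is not suitable here.
has_proba_svc = hasattr(LinearSVC(), "predict_proba")               # False
print(has_proba_rf, has_proba_svc)
```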

We use the same dataset as in the VCNet example: the Adult dataset.

This time, we need to import two classes from our library:

from vcnet import PHVCNet
from vcnet import SKLearnClassifier

PHVCNet is the implementation of the counterfactual generator. SKLearnClassifier is a wrapper for scikit-learn classifiers. It is configured through a dictionary specifying the classifier to use.

Preparing the Dataset

The data preparation is exactly the same as for VCNet. Please refer to the previous example.

In short, you create a DataCatalog from the settings and prepare the data:

dataset = DataCatalog(dataset_settings)
dataset_settings = dataset.prepare_data(df)
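Here, df is expected to be a pandas DataFrame holding the raw Adult data. As a purely illustrative sketch of its shape (column names abbreviated, values hypothetical):

```python
import pandas as pd

# A few rows in the spirit of the Adult dataset (illustrative values only;
# the real dataset has more columns and ~48k rows).
df = pd.DataFrame({
    "age": [39, 50, 38],
    "workclass": ["State-gov", "Self-emp-not-inc", "Private"],
    "education": ["Bachelors", "Bachelors", "HS-grad"],
    "hours-per-week": [40, 13, 40],
    "income": ["<=50K", "<=50K", ">50K"],
})
print(df.shape)
```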

Preparing and Training the Classification Model

Now, let’s define the classification model and train it using the SKLearnClassifier class.

You can use the following code as a template:

hp = {
    "dataset": dataset_settings,
    "classifier_params": {
        "skname": "RandomForestClassifier",
        "kwargs": {
            "n_estimators": 50,
        }
    }
}

train_loader = dataset.train_dataloader()
test_loader = dataset.test_dataloader()

classifier = SKLearnClassifier(hp)
classifier.fit(dataset.df_train)

In this code, the hyperparameters of the classifier are defined in classifier_params, which specifies the scikit-learn classifier class name (skname) and its constructor arguments (kwargs).

In this example, we use a random forest with 50 trees.
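To make the mapping concrete, here is a hypothetical illustration of how such a configuration could be turned into an estimator: look the class up by name and pass the kwargs to its constructor. This is an assumption about what SKLearnClassifier does internally, not its actual code.

```python
import sklearn.ensemble

classifier_params = {
    "skname": "RandomForestClassifier",
    "kwargs": {"n_estimators": 50},
}

# Resolve the class by name, then instantiate it with the given kwargs.
cls = getattr(sklearn.ensemble, classifier_params["skname"])
model = cls(**classifier_params["kwargs"])
print(type(model).__name__, model.n_estimators)  # RandomForestClassifier 50
```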

Once the classifier has been configured through the hyperparameters passed to SKLearnClassifier, calling fit() trains it on the training dataset.

At this stage, you have a trained classifier, but not yet the counterfactual generator.

Note

We also provide a wrapper for XGBoost classifiers (XGBoostClassifier).

Training the Counterfactual Generator Model

The final step is to train the counterfactual generator model post-hoc. This process is very similar to the VCNet example, with the key difference being that creating an instance of PHVCNet requires a trained classifier as a parameter.

import lightning as L

hp = {
    "dataset": dataset_settings,
    # Reuse the classifier hyperparameters defined above (the right-hand
    # side is evaluated before hp is reassigned, so this is safe)
    "classifier_params": hp['classifier_params'],
    "vcnet_params": {
        "lr": 1e-2,
        "epochs": 10,
        "lambda_KLD": 0.5,
        "lambda_BCE": 1,
        "latent_size": 16,
        "latent_size_share": 64,
        "mid_reduce_size": 32
    }
}
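Judging from the parameter names alone (an assumption, not taken from the library's source), the lambda_KLD and lambda_BCE weights plausibly balance a loss of roughly this form, combining a reconstruction term, a KL divergence on the latent space, and a binary cross-entropy classification term:

```latex
\mathcal{L} \;=\; \mathcal{L}_{\text{recon}}
\;+\; \lambda_{\text{KLD}} \, D_{\mathrm{KL}}\!\big(q(z \mid x)\,\|\,p(z)\big)
\;+\; \lambda_{\text{BCE}} \, \mathrm{BCE}(\hat{y}, y)
```

Increasing lambda_KLD regularizes the latent space more strongly, while lambda_BCE controls the weight given to predicting the class correctly.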

# Define the post-hoc VCNet model
vcnet = PHVCNet(hp, classifier)

# Finally, fit it using a Lightning module
trainer = L.Trainer(max_epochs=hp['vcnet_params']['epochs'])
trainer.fit(model=vcnet, train_dataloaders=train_loader)

Once your counterfactual generation model has been trained, it can be used in the same way as other VCNet models (see the other examples).