{ "cells": [ { "cell_type": "markdown", "id": "03e21aaf", "metadata": {}, "source": [ "# Study of the plausibility of generated counterfactuals\n", "\n", "The objective of this notebook is to study the plausibility of the counterfactuals generated by VCNet.\n", "We want to assess two properties of the generated counterfactuals:\n", "\n", "* Are the generated counterfactuals within the distribution of the examples? If not, some generated counterfactuals may be seen as unrealistic.\n", "* Are the generated counterfactuals diverse? If not, VCNet would always generate the same counterfactuals (a sign of overfitting when learning the distribution of instances).\n", "\n", "In this study, we use synthetic data and compare the counterfactuals generated by the post-hoc version of VCNet with those generated by the jointly trained VCNet model.\n", "\n", "Let us first install some useful libraries for these experiments, and then load all the required libraries.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a6ccd2ba", "metadata": {}, "outputs": [], "source": [ "!pip install openTSNE\n", "!pip install seaborn\n", "!pip install tensorboard\n", "!pip install tensorboardX" ] }, { "cell_type": "code", "execution_count": null, "id": "1030d4ae", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import lightning as L\n", "from lightning.pytorch.loggers import TensorBoardLogger\n", "from openTSNE import TSNE" ] }, { "cell_type": "code", "execution_count": null, "id": "30f8f813", "metadata": {}, "outputs": [], "source": [ "from vcnet import DataCatalog, VCNet\n", "from vcnet import PHVCNet\n", "from vcnet import SKLearnClassifier" ] }, { "cell_type": "markdown", "id": "c445c796", "metadata": {}, "source": [ "## Test VCNet on synthetic data\n", "\n", "In this section, we test VCNet on simple synthetic datasets.\n", "Datasets are generated as blobs with two classes: class
centers are randomly drawn from a uniform distribution over [-10, 10], and samples are then generated around each center according to a normal distribution whose standard deviation is drawn uniformly in [0, 3) for each class and each feature." ] }, { "cell_type": "code", "execution_count": null, "id": "0d21e557", "metadata": {}, "outputs": [], "source": [ "np.random.seed(531)\n", "\n", "class_size = 2  # number of classes\n", "nbpd = 1000  # total number of samples\n", "num_features = 5" ] }, { "cell_type": "code", "execution_count": null, "id": "478dfcd7", "metadata": {}, "outputs": [], "source": [ "# one center and one standard deviation per class and per feature\n", "class_centers = np.random.uniform(low=-10, high=10, size=(class_size, num_features))\n", "class_spreads = np.random.uniform(low=0, high=3, size=(class_size, num_features))\n", "\n", "print(class_centers)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "066d115a", "metadata": {}, "outputs": [], "source": [ "samples = []\n", "labels = []\n", "\n", "for c in range(class_size):\n", " # center and spread of class c\n", " class_center = class_centers[c] # (num_features,)\n", " scale = class_spreads[c] # (num_features,)\n", " \n", " # draw nbpd // class_size samples around the class center\n", " generated_samples = np.random.normal(loc=class_center, scale=scale, size=(nbpd // class_size, num_features))\n", " samples.append(generated_samples)\n", " labels.append(np.full(nbpd // class_size, c)) # y\n", "\n", "samples = np.vstack(samples) # (nbpd, num_features)\n", "labels = np.concatenate(labels) # (nbpd,)\n", "\n", "# create a dataframe\n", "data = pd.DataFrame(samples, columns=[f\"x_{i+1}\" for i in range(num_features)])\n", "data[\"y\"] = labels\n", "\n", "# shuffle rows\n", "data = data.sample(frac=1).reset_index(drop=True)\n", "data.head()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d7cf4748", "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=data, x=\"x_1\", y=\"x_3\", hue=\"y\")" ] }, { "cell_type": "code", "execution_count": null, "id": "88facc86", "metadata": {}, "outputs": [], "source": [ "dataset_settings = {\n", " \"target\":\"y\",\n", " \"continuous\" : data.columns[:-1].to_list(),\n", " \"categorical\": [],\n", " \"immutables\" : [],\n", " \"scaling_method\": \"MinMax\",\n", " \"encoding_method\": \"Identity\"\n", "}\n", "dataset = DataCatalog(dataset_settings)" ] }, { "cell_type": "code", "execution_count": null, "id": "4978d384", "metadata": {}, "outputs": [], "source": [ "dataset_settings = dataset.prepare_data(data)" ] }, { "cell_type": "markdown", "id": "a1c35158", "metadata": {}, "source": [ "### Learn the classifier\n", "\n", "We now train a basic random forest classifier, which is expected to reach a high accuracy on these easily separable datasets." ] }, { "cell_type": "code", "execution_count": null, "id": "90ec2553", "metadata": {}, "outputs": [], "source": [ "hp = {\n", " \"dataset\": dataset_settings,\n", " \"classifier_params\": {\n", " \"skname\": \"RandomForestClassifier\",\n", " \"kwargs\": {\n", " \"n_estimators\": 50,\n", " }\n", " }\n", "}\n", "\n", "classifier = SKLearnClassifier(hp)\n", "classifier.fit(dataset.df_train)" ] }, { "cell_type": "markdown", "id": "b656f945", "metadata": {}, "source": [ "### Learn a post-hoc counterfactual generator\n", "\n", "We now train a post-hoc VCNet model (PHVCNet) that generates realistic counterfactuals for the random forest trained above." ] }, { "cell_type": "code", "execution_count": null, "id": "fbe99530", "metadata": {}, "outputs": [], "source": [ "hp[\"vcnet_params\"]= {\n", " \"lr\": 1e-3,\n", " \"epochs\": 200,\n", " \"lambda_KLD\": 5,\n", " \"lambda_BCE\": 1,\n", " \"latent_size\": 16,\n", " \"latent_size_share\": 64,\n", " \"mid_reduce_size\": 32\n", " }\n", "\n", "# Define the post-hoc VCNet model\n", "vcnet = PHVCNet(hp, classifier)\n", "\n", "# Finally, fit it with a Lightning Trainer (training logs are written to TensorBoard)\n", "logger = TensorBoardLogger(\"tb_logs\", name=\"PHVCNet\")\n", "trainer = L.Trainer(max_epochs=hp['vcnet_params']['epochs'], enable_checkpointing=False, logger=logger)\n", "trainer.fit(model=vcnet, train_dataloaders=dataset.train_dataloader())" ] }, { "cell_type": "markdown", "id": "cdb84fd6", "metadata": {}, "source": [ "### Some basic verification
on the test set" ] }, { "cell_type": "markdown", "id": "23da6e5a", "metadata": {}, "source": [ "We finally check the prediction accuracy of the classifier and the validity of the generated counterfactuals, i.e. the proportion of test examples whose counterfactual is predicted in a different class from the example itself." ] }, { "cell_type": "code", "execution_count": null, "id": "fd1ae1c5", "metadata": {}, "outputs": [], "source": [ "vcnet.eval()\n", "# note: after the loop, cl, cf, clcf and rlcf only hold the results of the last test batch\n", "for ldata, labels in dataset.test_dataloader():\n", " cl = vcnet.forward_pred(ldata)\n", " cf, clcf = vcnet.counterfactuals(ldata)\n", " rlcf = dataset.data_unloader(cf,clcf)\n", "\n", "\n", " # accuracy: predicted class (threshold 0.5) compared with the true label\n", " acc = torch.sum((cl[:, 0] > 0.5).int() == labels[:, 0]) / len(ldata)\n", " # validity: share of counterfactuals whose predicted class differs from the prediction on the original example\n", " validity = torch.sum((cl[:, 0] > 0.5).int() != (clcf[:, 0] > 0.5).int()) / len(ldata)\n", " print(f\"Accuracy: {acc}, validity:{validity}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "3edf6b4e", "metadata": {}, "outputs": [], "source": [ "#torch.dstack((cl[:, 0],clcf[:, 0]))" ] }, { "cell_type": "code", "execution_count": null, "id": "3dfec7c4", "metadata": {}, "outputs": [], "source": [ "rlcf" ] }, { "cell_type": "markdown", "id": "37b150fe", "metadata": {}, "source": [ "### Visualisation of the spread of generated counterfactuals: realism and diversity\n", "\n", "In this visualisation, we compute a t-SNE projection to visualize the dataset in two dimensions.\n", "\n", "The t-SNE projection is computed on the training set; we then reuse the same projection to visualize the test set examples (x) and their counterfactuals (o) in the same space."
] }, { "cell_type": "code", "execution_count": null, "id": "7178af04", "metadata": {}, "outputs": [], "source": [ "# t-SNE (openTSNE implementation)\n", "tsne = TSNE(\n", " perplexity=30,\n", " metric=\"euclidean\",\n", " n_jobs=8,\n", " random_state=42,\n", " verbose=True,\n", ")\n", "\n", "# fit the projection on the features of the training set\n", "embedding_train = tsne.fit( dataset.df_train.drop(columns='y').to_numpy() )" ] }, { "cell_type": "code", "execution_count": null, "id": "5d8483d4", "metadata": {}, "outputs": [], "source": [ "# use the projection to embed the generated counterfactuals of the test set\n", "# (detach() in case the tensor tracks gradients: numpy() fails on tensors that require grad)\n", "embedding_cf_test = embedding_train.transform( cf.detach().numpy() )" ] }, { "cell_type": "code", "execution_count": null, "id": "e266fc18", "metadata": {}, "outputs": [], "source": [ "# use the same projection to embed the examples of the test set\n", "embedding_test = embedding_train.transform( dataset.df_test.drop(columns='y').to_numpy() )" ] }, { "cell_type": "code", "execution_count": null, "id": "d4ee89be", "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Paired\")\n", "plotvalues = pd.DataFrame({'x':embedding_train[:,0], 'y':embedding_train[:,1], 'c':dataset.df_train['y']})\n", "fig=sns.kdeplot(data=plotvalues, x=\"x\", y=\"y\", hue=\"c\", fill=False, levels=5)\n", "plt.scatter( embedding_test[:,0], embedding_test[:,1], c=(clcf[:, 0] > 0.5).int(), cmap='Paired', marker='x')\n", "plt.scatter( embedding_cf_test[:,0], embedding_cf_test[:,1], c=(clcf[:, 0] > 0.5).int(), cmap='Paired', marker='o')\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)\n", "plt.legend('',frameon=False)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "daaccf72", "metadata": {}, "source": [ "### Same experiment with the jointly trained VCNet\n", "\n", "The post-hoc version seems to work correctly. We now investigate the joint-learning version of VCNet, in which the predictor and the counterfactual generator are trained together."
] }, { "cell_type": "code", "execution_count": null, "id": "9537ce13", "metadata": {}, "outputs": [], "source": [ "hp[\"vcnet_params\"]= {\n", " \"lr\": 2e-3,\n", " \"epochs\" : 10,\n", " \"lambda_KLD\": 0.5,\n", " \"lambda_CE\": 0.93,\n", " \"lambda_BCE\": 1,\n", " \"latent_size\" : 19,\n", " \"latent_size_share\" : 304,\n", " \"mid_reduce_size\" : 152\n", " }\n", "\n", "# Define the joint VCNet model (predictor and counterfactual generator are trained together)\n", "vcnet = VCNet(hp)" ] }, { "cell_type": "code", "execution_count": null, "id": "b3fb0880", "metadata": {}, "outputs": [], "source": [ "# Finally, fit it with a Lightning Trainer\n", "trainer = L.Trainer(max_epochs=hp['vcnet_params']['epochs'])\n", "trainer.fit(model=vcnet, train_dataloaders=dataset.train_dataloader())" ] }, { "cell_type": "code", "execution_count": null, "id": "70368f74", "metadata": {}, "outputs": [], "source": [ "vcnet.eval()\n", "# evaluate batch by batch (after the loop, cl, cf and clcf only hold the last test batch)\n", "for ldata, labels in dataset.test_dataloader():\n", " cl = vcnet.forward_pred(ldata)\n", " cf, clcf = vcnet.counterfactuals(ldata)\n", " rlcf = dataset.data_unloader(cf,clcf)\n", "\n", "\n", " acc = torch.sum((cl[:, 0] > 0.5).int() == labels[:, 0]) / len(ldata)\n", " validity = torch.sum((cl[:, 0] > 0.5).int() != (clcf[:, 0] > 0.5).int()) / len(ldata)\n", " print(f\"Accuracy: {acc}, validity:{validity}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "6e3c0ea7", "metadata": {}, "outputs": [], "source": [ "# reproject the new counterfactuals into the same t-SNE space as before\n", "embedding_cf_test = embedding_train.transform( cf.detach().numpy() )" ] }, { "cell_type": "code", "execution_count": null, "id": "43133ed1", "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Paired\")\n", "plotvalues = pd.DataFrame({'x':embedding_train[:,0], 'y':embedding_train[:,1], 'c':dataset.df_train['y']})\n", "fig=sns.kdeplot(data=plotvalues, x=\"x\", y=\"y\", hue=\"c\", fill=False, levels=5)\n", "plt.scatter( embedding_test[:,0], embedding_test[:,1], c=(clcf[:, 0] > 0.5).int(), cmap='Paired', marker='x')\n", "plt.scatter( embedding_cf_test[:,0],
embedding_cf_test[:,1], c=(clcf[:, 0] > 0.5).int(), cmap='Paired', marker='o')\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)\n", "plt.legend('',frameon=False)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "28b328c3", "metadata": {}, "source": [ "With the joint version, the spread of the counterfactuals seems better than with the post-hoc version." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }