Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

A reproduction of the method defined in the paper "Explanations based on the Missing".

License

NotificationsYou must be signed in to change notification settings

davidvos/contrastive-explanation-method

Repository files navigation

This repo contains a reproduction of the method defined in the paper"Explanations based on the Missing". It extends code created by the original authors, which can be foundhere.

The implementation on this github repository is given for two datasets: MNIST and FashionMNIST, but allows for easy extensions to new datasets (see below).

Requirements

Seerequirements.txt.

Installation

Download the github repo and navigate to its root folder.

Usage of the python package

An example of the usage of the python implementation on the MNIST dataset is given below. For further usage examples seeusage-examples.ipynb.

Load required modules

fromcem.datasets.mnistimportMNISTfromcem.models.cae_modelimportCAEfromcem.models.conv_modelimportCNNfromcem.trainimporttrain_ae,train_cnnfromcem.cemimportContrastiveExplanationMethod

Dataset and models

This repo comes with two pretrained sets of models. These models are contained inmodels/saved_models/. By default, the MNIST models will be loaded. To instead train the classifier and autoencoder from scratch, specify theload_path argument for thetrain_cnn andtrain_ae functions as an empty string:"".

dataset=MNIST()# load / train classifier model and weightscnn=CNNtrain_cnn(cnn,dataset)# load / train autoencoder model and weightscae=CAE()train_ae(cae,dataset)

Initialising the Contrastive Explanation Method class

TheContrastiveExplanationMethod class takes a classifier as positional argument. For a full overview of all arguments and their uses, seecem/cem.py.

CEM=ContrastiveExplanationMethod(cnn,cae,iterations=1000,n_searches=9,kappa=10.0,gamma=1.0beta=0.1,learning_rate=0.1,c_init=10.0)

Producing pertinent positives or negatives

The explain function takes as input a single sample, and returns the perturbed image that satisfies the classification objective with the lowest corresponding overall loss as defined inequation 1. To obtain the delta for said perturbed image, see the following example code.

# obtain a sample from the dataset, discard the labelsample,_=dataset.get_sample()# obtaining the PPperturbed_image=CEM.explain(sample,mode="PP")pp_delta=sample-perturbed_image# obtaining the PNperturbed_image=CEM.explain(sample,mode="PP")pn_delta=perturbed_image-sample

Usage of the command line implementation

Experiments can also be ran from the command line by calling 'main.py'. For an overview of all the arguments see below.

An example of the usage of this script for the FashionMNIST dataset is given below.

python main.py --verbose -mode PP -dataset FashionMNIST \  -cnn_load_path ./cem/models/saved_models/fashion-mnist-cnn.h5\  -cae_load_path ./cem/models/saved_models/fashion-mnist-cae.h5

Command line arguments

usage: main.py [-h] [-dataset DATASET] [-cnn_load_path CNN_LOAD_PATH]               [--no_cae] [-cae_load_path CAE_LOAD_PATH]               [-sample_from_class SAMPLE_FROM_CLASS] [--discard_images]               [-mode MODE] [-kappa KAPPA] [-beta BETA] [-gamma GAMMA]               [-c_init C_INIT] [-c_converge C_CONVERGE]               [-iterations ITERATIONS] [-n_searches N_SEARCHES]               [-learning_rate LEARNING_RATE] [-input_shape INPUT_SHAPE]               [--verbose] [-print_every PRINT_EVERY] [-device DEVICE]optional arguments:  -h, --help            show this help message and exit  -dataset DATASET      choose a dataset (MNIST or FashionMNIST) to apply the                        contrastive explanation method to. (default: MNIST)  -cnn_load_path CNN_LOAD_PATH                        path to load classifier weights from. (default:                        ./models/saved_models/mnist-cnn.h5)  --no_cae              disable the autoencoder (default: False)  -cae_load_path CAE_LOAD_PATH                        path to load autoencoder weights from. (default:                        ./models/saved_models/mnist-cae.h5)  -sample_from_class SAMPLE_FROM_CLASS                        specify which class to sample from for pertinent                        negative or positive (default: 3)  --discard_images      specify whether or not to save the created images                        (default: False)  -mode MODE            Either PP for pertinent positive or PN for pertinent                        negative. (default: PN)  -kappa KAPPA          kappa value used in the CEM attack loss. (default:                        10.0)  -beta BETA            beta value used as L1 regularisation coefficient.                        (default: 0.1)  -gamma GAMMA          gamma value used as reconstruction regularisation                        coefficient (default: 1.0)  -c_init C_INIT        initial c value used as regularisation coefficient for                        the attack loss (default: 10.0)  -c_converge C_CONVERGE                        c value to amend the value of c towards if no solution                        has been found in the current iterations (default:                        0.1)  -iterations ITERATIONS                        number of iterations per search (default: 1000)  -n_searches N_SEARCHES                        number of searches (default: 9)  -learning_rate LEARNING_RATE                        initial learning rate used to optimise the slack                        variable (default: 0.01)  -input_shape INPUT_SHAPE                        shape of a single sample, used to reshape input for                        classifier and autoencoder input (default: (1, 28,                        28))  --verbose             print loss information during training (default:                        False)  -print_every PRINT_EVERY                        if verbose mode is enabled, interval to print the                        current loss (default: 100)  -device DEVICE        device to run experiment on (default: cpu)

Extending to new datasets

To extend this implementation to a new dataset, inherit the 'Dataset' class specified in 'datasets.dataset.py' and overwrite the initialisation by specifying a train_data and test_data attribute containing a Pytorch Dataset, train_loader and test_loader attributes containing Pytorch Dataloaders and train_list and test_list attributes containing a list of samples.

License

MIT

About

A reproduction of the method defined in the paper "Explanations based on the Missing".

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Contributors4

  •  
  •  
  •  
  •  

[8]ページ先頭

©2009-2025 Movatter.jp