TimeSHAP explains Recurrent Neural Network predictions.


feedzai/timeshap


TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.

This repository is the code implementation of the TimeSHAP algorithm presented in the paper TimeSHAP: Explaining Recurrent Models through Sequence Perturbations, published at KDD 2021.

Links to the paper here, and to the video presentation here.

Install TimeSHAP

Via pip
pip install timeshap
Via GitHub

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Move into the cloned repo and install the package:

cd timeshap
pip install .
Test your installation

Start a Python session in your terminal using

python

And import TimeSHAP

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.

Outputs

  • Local pruning output (explaining a single instance);
  • Local event explanations (explaining a single instance);
  • Local feature explanations (explaining a single instance);
  • Global pruning statistics (explaining multiple instances);
  • Global event explanations (explaining multiple instances);
  • Global feature explanations (explaining multiple instances).

Model Interface

In order for TimeSHAP to explain a model, an entry point must be provided. This Callable entry point must receive a 3-D numpy array of shape (#sequences; #sequence length; #features) and return a 2-D numpy array of shape (#sequences; 1) with the corresponding score of each sequence.
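As an illustration of this contract, the sketch below defines a toy entry point in plain NumPy. The scoring rule (a fixed random linear projection averaged over time, then a sigmoid) and the name `toy_model` are hypothetical stand-ins for a trained recurrent model; only the input and output shapes follow the interface described above.

```python
import numpy as np

rng = np.random.default_rng(0)
N_FEATURES = 4
# Hypothetical fixed weights standing in for a trained recurrent model.
WEIGHTS = rng.normal(size=N_FEATURES)


def toy_model(x: np.ndarray) -> np.ndarray:
    """Score a batch of sequences.

    x: array of shape (#sequences, #sequence length, #features)
    returns: array of shape (#sequences, 1), one score per sequence
    """
    if x.ndim != 3:
        raise ValueError("expected (#sequences, #sequence length, #features)")
    logits = (x @ WEIGHTS).mean(axis=1)  # project every event, average over time
    scores = 1.0 / (1.0 + np.exp(-logits))  # squash into (0, 1)
    return scores.reshape(-1, 1)


batch = rng.normal(size=(5, 12, N_FEATURES))
print(toy_model(batch).shape)  # (5, 1)
```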

In addition, to make TimeSHAP more efficient, the entry point can also return the hidden state of the model together with the score (if applicable). Although this is optional, we highly recommend it, as it has a very high impact on performance. If you choose to return the hidden state, it should be one of the following (see the notebook for specific examples):

  • a 3-D numpy array of shape (#rnn layers, #sequences, #hidden_dimension) (class ExplainedRNN in the notebook);
  • a tuple of numpy arrays, each following the format described above (usually used for stacked RNNs with different hidden dimensions) (class ExplainedGRU2Layer in the notebook);
  • a tuple of tuples of numpy arrays (usually used for LSTMs) (class ExplainedLSTM in the notebook).

TimeSHAP is able to explain any black-box model as long as it complies with the previously described interface, including both PyTorch and TensorFlow models, both exemplified in our tutorials (PyTorch, TensorFlow).
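The sketch below illustrates why returning the hidden state can pay off. `rnn_with_hidden_state` is a hypothetical single-layer Elman RNN written in NumPy, not one of the notebook classes. It returns its final hidden state in the 3-D format (#rnn layers, #sequences, #hidden_dimension) and accepts it back as an optional second argument, in the same two-argument style as the `f_hs` example in the tutorials. Resuming from the state of a prefix gives the same score as processing the whole sequence, which is what would let an explainer avoid re-running an unchanged prefix for every perturbation; how TimeSHAP exploits this internally is an assumption and is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
N_FEATURES, HIDDEN = 3, 5
# Hypothetical weights of a single-layer Elman RNN.
W_IN = rng.normal(scale=0.5, size=(N_FEATURES, HIDDEN))
W_REC = rng.normal(scale=0.5, size=(HIDDEN, HIDDEN))
W_OUT = rng.normal(size=HIDDEN)


def rnn_with_hidden_state(x, hidden=None):
    """x: (#sequences, #sequence length, #features).
    hidden: None, or a previous state of shape (1, #sequences, #hidden_dimension).
    Returns (scores of shape (#sequences, 1), final hidden state of shape
    (1, #sequences, #hidden_dimension))."""
    h = np.zeros((x.shape[0], HIDDEN)) if hidden is None else hidden[0]
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_IN + h @ W_REC)  # one recurrent step
    scores = 1.0 / (1.0 + np.exp(-(h @ W_OUT)))
    return scores.reshape(-1, 1), h[np.newaxis]  # add the layer axis back


x = rng.normal(size=(2, 8, N_FEATURES))
full_scores, _ = rnn_with_hidden_state(x)
_, prefix_hidden = rnn_with_hidden_state(x[:, :5])  # process events 0..4 once
resumed_scores, _ = rnn_with_hidden_state(x[:, 5:], prefix_hidden)  # resume from there
print(np.allclose(full_scores, resumed_scores))  # True
```

Stacked or LSTM models would return the tuple formats listed above instead of a single array; the resume-from-state idea is the same.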

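Examples provided in our tutorials:

  • TensorFlow
import tensorflow as tf
# `inputs` and `ff2` are the input and output layers built earlier in the tutorial
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
  • PyTorch (example where the model receives and returns hidden states)
from timeshap.wrappers import TorchModelWrapper
# `model` is the trained PyTorch model from the tutorial
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)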
Model Wrappers

To facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, allow greater flexibility in the models that can be explained, as they handle:

  • Batching logic: useful when very large inputs or values of NSamples cannot fit in GPU memory, so a batching mechanism is required;
  • Input format/type: useful when your model does not work with numpy arrays. This is the case for our PyTorch example;
  • Hidden state logic: useful when the hidden states of your model do not match the hidden state format required by TimeSHAP.
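The batching idea from the first bullet can be sketched as follows. `BatchedModel` is a hypothetical class written for illustration; it is not TimeSHAP's ModelWrapper or TorchModelWrapper API. It wraps any callable that follows the entry-point contract described earlier and scores the input in fixed-size slices, so the underlying model only ever sees one slice at a time.

```python
import numpy as np


class BatchedModel:
    """Hypothetical wrapper that scores a large input in fixed-size batches.

    `predict_fn` is any callable following the entry-point contract:
    (#sequences, #sequence length, #features) -> (#sequences, 1).
    """

    def __init__(self, predict_fn, batch_size=256):
        self.predict_fn = predict_fn
        self.batch_size = batch_size

    def __call__(self, x):
        outputs = [
            self.predict_fn(x[start:start + self.batch_size])
            for start in range(0, x.shape[0], self.batch_size)
        ]
        return np.concatenate(outputs, axis=0)  # back to (#sequences, 1)


calls = []


def model(x):
    calls.append(x.shape[0])  # record the size of every batch the model sees
    return x.sum(axis=(1, 2)).reshape(-1, 1)


wrapped = BatchedModel(model, batch_size=4)
out = wrapped(np.ones((10, 3, 2)))
print(out.shape, calls)  # (10, 1) [4, 4, 2]
```

A wrapper for a PyTorch model would additionally convert each slice to a tensor (the input-format bullet) and reshape hidden states (the hidden-state bullet); those steps are omitted here.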

TimeSHAP Explanation Methods

TimeSHAP offers several methods, depending on the desired explanations. Local methods provide a detailed view of the model decision for the specific sequence being explained. Global methods aggregate the local explanations of a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance and returns the pruning index along with the information needed for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
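The idea behind the pruning step can be illustrated with a simplified re-implementation written from the description above, not from TimeSHAP's source. The sequence is split into two coalitions, the older events x[:split] and the recent events x[split:]; an absent coalition has its events replaced by a baseline event, and with only two players the Shapley values can be computed exactly from four model calls. Moving the split backwards from the most recent event, the sketch returns the first split at which the older coalition's absolute Shapley value drops below the tolerance. The function names, the single repeated baseline event and this exact stopping rule are assumptions; local_pruning() is the function to use in practice.

```python
import numpy as np


def coalition_shapley(f, x, baseline, split):
    """Exact Shapley values of the two-player game {older, recent}.

    older = events x[:split], recent = events x[split:]; an absent coalition
    has its events replaced by the `baseline` event. x: (seq_len, n_feat).
    """
    def value(keep_older, keep_recent):
        seq = np.tile(baseline, (x.shape[0], 1)).astype(float)
        if keep_older:
            seq[:split] = x[:split]
        if keep_recent:
            seq[split:] = x[split:]
        return f(seq[np.newaxis])[0, 0]  # f scores a batch of one sequence

    v_none, v_old = value(False, False), value(True, False)
    v_rec, v_all = value(False, True), value(True, True)
    phi_older = 0.5 * ((v_old - v_none) + (v_all - v_rec))
    phi_recent = 0.5 * ((v_rec - v_none) + (v_all - v_old))
    return phi_older, phi_recent


def prune_index(f, x, baseline, tolerance):
    """Grow the recent coalition backwards from the last event; return the
    first split whose older coalition has |Shapley value| below `tolerance`
    (events before it would be pruned). Returns 0 if no split qualifies."""
    for split in range(x.shape[0] - 1, 0, -1):
        phi_older, _ = coalition_shapley(f, x, baseline, split)
        if abs(phi_older) < tolerance:
            return split
    return 0


def model(X):
    return X[:, -3:, 0].sum(axis=1, keepdims=True)  # only the last 3 events matter


x = np.ones((10, 2))
print(prune_index(model, x, baseline=np.zeros(2), tolerance=0.1))  # 7
```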

Event level explanations

local_event() calculates event-level explanations of a given sequence with the user-given parameters and returns the respective event-level explanations.

plot_event_heatmap() plots the event-level explanations calculated by local_event().
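To make the notion of an event-level explanation concrete, the sketch below computes exact Shapley values with every event as a player, replacing absent events by a baseline event. This is a swapped-in technique: TimeSHAP builds on KernelSHAP, which approximates these values from sampled coalitions, whereas brute-force enumeration is only feasible for a handful of events. `exact_event_shapley` is a hypothetical helper, not part of TimeSHAP.

```python
import itertools
import math

import numpy as np


def exact_event_shapley(f, x, baseline):
    """Exact Shapley value of every event in x (seq_len, n_feat).

    Absent events are replaced by the `baseline` event. Cost grows as
    2**seq_len model calls, so this only works for very short sequences.
    """
    n = x.shape[0]

    def value(coalition):
        seq = np.tile(baseline, (n, 1)).astype(float)
        idx = np.array(coalition, dtype=int)
        seq[idx] = x[idx]  # keep the events in the coalition
        return f(seq[np.newaxis])[0, 0]

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


weights = np.array([1.0, 2.0, 3.0, 4.0])


def linear_model(X):
    return (X[:, :, 0] * weights).sum(axis=1, keepdims=True)  # shape (#sequences, 1)


phi = exact_event_shapley(linear_model, np.ones((4, 1)), baseline=np.zeros(1))
print(np.round(phi, 6).tolist())  # [1.0, 2.0, 3.0, 4.0]
```

For a linear model with a zero baseline, each event's Shapley value equals its weight times its value, which the example reproduces.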

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with the user-given parameters and returns the respective feature-level explanations.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().
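Feature-level explanations can be sketched with features as the players: here, a feature that is absent from a coalition is replaced by its baseline value in every event of the sequence. Exact enumeration stands in for TimeSHAP's KernelSHAP-based estimation, so this only scales to a few features, and `exact_feature_shapley` is a hypothetical name.

```python
import itertools
import math

import numpy as np


def exact_feature_shapley(f, x, baseline):
    """Exact Shapley value of every feature of x (seq_len, n_feat).

    A feature absent from a coalition is replaced by its baseline value
    in every event of the sequence.
    """
    n_feat = x.shape[1]

    def value(coalition):
        seq = np.tile(baseline, (x.shape[0], 1)).astype(float)
        cols = np.array(coalition, dtype=int)
        seq[:, cols] = x[:, cols]  # keep the features in the coalition
        return f(seq[np.newaxis])[0, 0]

    phi = np.zeros(n_feat)
    for i in range(n_feat):
        others = [j for j in range(n_feat) if j != i]
        for size in range(n_feat):
            weight = math.factorial(size) * math.factorial(n_feat - size - 1) / math.factorial(n_feat)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


w = np.array([1.0, -2.0])


def model(X):
    return (X.sum(axis=1) @ w).reshape(-1, 1)  # sum over time, then weight features


phi = exact_feature_shapley(model, np.ones((3, 2)), baseline=np.zeros(2))
print(phi.tolist())  # [3.0, -6.0]
```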

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence with the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().
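Cell-level explanations treat individual (event, feature) cells as players. Explaining every cell separately would be expensive, so the sketch below uses a simplified grouping of our own choosing: a few chosen cells are individual players and all remaining cells are merged into one grouped player, with absent cells taking their baseline values. TimeSHAP's actual grouping, which local_cell_level() derives from the event- and feature-level explanations, is not reproduced here; `cell_shapley` is a hypothetical helper that uses exact enumeration instead of KernelSHAP.

```python
import itertools
import math

import numpy as np


def cell_shapley(f, x, baseline, cells):
    """Exact Shapley values for the chosen (event, feature) `cells` plus one
    grouped player holding every other cell of x (seq_len, n_feat).

    Absent cells take their baseline value. The first len(cells) entries of
    the result belong to the chosen cells, the last one to the grouped player.
    """
    n = len(cells) + 1  # chosen cells + one grouped "other cells" player
    other_mask = np.ones(x.shape, dtype=bool)
    for e, j in cells:
        other_mask[e, j] = False  # chosen cells are not part of the group

    def value(coalition):
        seq = np.tile(baseline, (x.shape[0], 1)).astype(float)
        for p in coalition:
            if p == n - 1:
                seq[other_mask] = x[other_mask]  # restore all grouped cells
            else:
                e, j = cells[p]
                seq[e, j] = x[e, j]  # restore one chosen cell
        return f(seq[np.newaxis])[0, 0]

    phi = np.zeros(n)
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for size in range(n):
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


def model(X):
    # Cells (2, 0) and (2, 1) interact; cell (0, 0) contributes additively.
    return (X[:, 2, 0] * X[:, 2, 1] + X[:, 0, 0]).reshape(-1, 1)


phi = cell_shapley(model, np.ones((3, 2)), np.zeros(2), cells=[(2, 0), (2, 1)])
print(np.round(phi, 6).tolist())  # [0.5, 0.5, 1.0]
```

The two chosen cells interact multiplicatively and split their joint effect equally, while the additive contribution of cell (0, 0) is credited to the grouped player.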

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.
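Assuming the pruning data from prune_all() holds, for every sequence, the older-coalition Shapley values at each split (so that any tolerance can be applied afterwards without calling the model again), statistics for several tolerances could be derived as sketched below. The data layout, the function names and the row fields are hypothetical; pandas.DataFrame(rows) would turn the returned rows into a DataFrame.

```python
import statistics


def kept_events(older_values, tolerance):
    """Number of recent events kept for one sequence.

    older_values[k] is the |Shapley value| of the older coalition when the
    recent coalition holds the last k + 1 events (hypothetical layout).
    """
    for k, value in enumerate(older_values):
        if value < tolerance:
            return k + 1
    return len(older_values) + 1  # no split qualifies: keep the whole sequence


def pruning_summary(all_older_values, tolerances):
    """One row of statistics per tolerance (field names are hypothetical)."""
    rows = []
    for tol in tolerances:
        kept = [kept_events(values, tol) for values in all_older_values]
        rows.append({
            "tolerance": tol,
            "mean_kept": statistics.mean(kept),
            "std_kept": statistics.pstdev(kept),
        })
    return rows


# Illustrative values for two sequences of four events each.
data = [[0.9, 0.4, 0.05], [0.3, 0.01, 0.0]]
rows = pruning_summary(data, [0.1, 0.5])
for row in rows:
    print(row)
```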

Global event level explanations

event_explain_all() calculates TimeSHAP event-level explanations for multiple instances given user-defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().
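One plausible way to aggregate local event-level explanations across instances, assumed here rather than taken from event_explain_all(), is to align sequences by distance from their most recent event (-1, -2, ...) and average the values at each position. The sketch below does that for sequences of different lengths; `mean_by_event_position` is a hypothetical helper.

```python
import numpy as np


def mean_by_event_position(local_event_values):
    """Average event-level values by distance from the most recent event.

    local_event_values: one 1-D array per sequence, ordered from oldest to
    most recent event; sequences may have different lengths. Returns a dict
    mapping -1 (most recent event), -2, ... to the mean over the sequences
    that are long enough to have an event at that position.
    """
    max_len = max(len(values) for values in local_event_values)
    summary = {}
    for offset in range(1, max_len + 1):
        at_position = [values[-offset] for values in local_event_values if len(values) >= offset]
        summary[-offset] = float(np.mean(at_position))
    return summary


explanations = [np.array([0.1, 0.3, 0.6]), np.array([0.2, 0.4])]  # illustrative values
summary = mean_by_event_position(explanations)
print(summary)
```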

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances given user-defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().
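A common summary of feature-level explanations over many instances is the mean absolute attribution per feature. The sketch below ranks features that way; it is an assumed aggregation, not necessarily what feat_explain_all() or plot_global_feat() compute, and the feature names in the usage example are placeholders.

```python
import numpy as np


def rank_features(local_feat_values, feature_names):
    """Rank features by mean absolute feature-level value across sequences.

    local_feat_values: array-like of shape (#sequences, #features).
    Returns (name, mean |value|) pairs from most to least important.
    """
    importance = np.abs(np.asarray(local_feat_values, dtype=float)).mean(axis=0)
    order = np.argsort(-importance)  # descending importance
    return [(feature_names[i], float(importance[i])) for i in order]


values = [[0.5, -0.1, 0.0], [-0.3, 0.2, 0.1]]  # illustrative values
ranking = rank_features(values, ["feat_a", "feat_b", "feat_c"])  # placeholder names
print(ranking)
```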

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations into two plots and returning them.

Tutorial

To demonstrate TimeSHAP's interfaces and methods, you can consult AReM.ipynb. In this tutorial we take an open-source dataset, process it, train a PyTorch recurrent model on it and use TimeSHAP to explain it, showcasing all the previously described methods.

Additionally, we train a TensorFlow model on the same dataset in AReM_TF.ipynb.

Repository Structure

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
    pages = {2565–2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}
