Model interpretability and understanding for PyTorch
meta-pytorch/captum
Captum is a model interpretability and understanding library for PyTorch. Captum means comprehension in Latin and contains general purpose implementations of integrated gradients, saliency maps, smoothgrad, vargrad and others for PyTorch models. It has quick integration for models built with domain-specific libraries such as torchvision, torchtext, and others.
With the increase in model complexity and the resulting lack of transparency, model interpretability methods have become increasingly important. Model understanding is both an active area of research and an area of focus for practical applications across industries that use machine learning. Captum provides state-of-the-art algorithms such as Integrated Gradients, Testing with Concept Activation Vectors (TCAV) and TracIn influence functions, among others, that give researchers and developers an easy way to understand which features, training examples or concepts contribute to a model's predictions and, more generally, what and how the model learns. In addition, Captum provides adversarial attacks and minimal input perturbation capabilities that can be used both for generating counterfactual explanations and adversarial perturbations.
Captum helps ML researchers more easily implement interpretability algorithms that can interact with PyTorch models. Captum also allows researchers to quickly benchmark their work against other existing algorithms available in the library.
The primary audiences for Captum are model developers who want to improve their models and understand which concepts, features or training examples are important, and interpretability researchers focused on identifying algorithms that can better interpret many types of models.
Captum can also be used by application engineers who are using trained models in production. Captum makes troubleshooting easier through improved model interpretability and can help deliver better explanations to end users on why they’re seeing a specific piece of content, such as a movie recommendation.
Installation Requirements
- Python >= 3.10
- PyTorch >= 2.3
Install released Captum via `pip`.

With `pip`

```bash
pip install captum
```
Manual / Dev install
If you'd like to try our bleeding edge features (and don't mind potentially running into the occasional bug here or there), you can install the latest master directly from GitHub. For a basic install, run:

```bash
git clone https://github.com/pytorch/captum.git
cd captum
pip install -e .
```

To customize the installation, you can also run the following variants of the above:

- `pip install -e .[dev]`: Also installs all tools necessary for development (testing, linting, docs building; see Contributing below).
- `pip install -e .[tutorials]`: Also installs all packages necessary for running the tutorial notebooks.
To execute unit tests from a manual install, run:
```bash
# running a single unit test
python -m unittest -v tests.attr.test_saliency
# running all unit tests
pytest -ra
```
Captum helps you interpret and understand predictions of PyTorch models by exploring features that contribute to a prediction the model makes. It also helps understand which neurons and layers are important for model predictions.
Let's apply some of those algorithms to a toy model we have created for demonstration purposes. For simplicity, we will use the following architecture, but users are welcome to use any PyTorch model of their choice.
```python
import numpy as np
import torch
import torch.nn as nn

from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1, 3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1, 2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))
```
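For reference, the `torch.arange` initialization above fixes the weights to the matrices below (`nn.Linear` computes $Wx + b$ for each example), so `ToyModel` computes the function $F$, with one output row per target class:

```math
W_1 = \begin{pmatrix} -4 & -3 & -2 \\ -1 & 0 & 1 \\ 2 & 3 & 4 \end{pmatrix},\quad
b_1 = \mathbf{0},\quad
W_2 = \begin{pmatrix} -3 & -2 & -1 \\ 0 & 1 & 2 \end{pmatrix},\quad
b_2 = \begin{pmatrix} 1 \\ 1 \end{pmatrix},\qquad
F(x) = W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2
```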
Let's create an instance of our model and set it to eval mode.
```python
model = ToyModel()
model.eval()
```
Next, we need to define simple input and baseline tensors. Baselines belong to the input space and often carry no predictive signal. A zero tensor can serve as a baseline for many tasks. Some interpretability algorithms, such as IntegratedGradients, DeepLift and GradientShap, are designed to attribute the change between the input and the baseline to a predictive class or a value that the neural network outputs.
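For Integrated Gradients, Sundararajan et al. (2017, listed in the references below) define the attribution of feature $i$, for an input $x$, a baseline $x'$ and a target output $F$, as the first expression below. Completeness (second expression) says the attributions add up to the change in output between baseline and input, and the convergence delta discussed later in this section can be read as the gap in that equality (our paraphrase, not a transcription of Captum's code):

```math
\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\,d\alpha,
\qquad
\sum_i \mathrm{IG}_i(x) = F(x) - F(x'),
\qquad
\delta = \sum_i \mathrm{IG}_i(x) - \big(F(x) - F(x')\big)
```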
We will apply model interpretability algorithms to the network mentioned above in order to understand the importance of individual neurons/layers and the parts of the input that play an important role in the final prediction.
To make computations deterministic, let's fix random seeds.
```python
torch.manual_seed(123)
np.random.seed(123)
```
Let's define our input and baseline tensors. Baselines are used in some interpretability algorithms such as IntegratedGradients, DeepLift, GradientShap, NeuronConductance, LayerConductance, InternalInfluence and NeuronIntegratedGradients.

```python
input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)
```

Next we will use the IntegratedGradients algorithm to assign attribution scores to each input feature with respect to the first target output.

```python
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)
```
Output:
```
IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                         [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
```

The algorithm outputs an attribution score for each input element and a convergence delta. The lower the absolute value of the convergence delta, the better the approximation. If we choose not to return the delta, we can simply omit the `return_convergence_delta` input argument. The absolute value of the returned deltas can be interpreted as an approximation error for each input sample. It can also serve as a proxy of how accurate the integral approximation for the given inputs and baselines is. If the approximation error is large, we can try a larger number of integral approximation steps by setting `n_steps` to a larger value. Not all algorithms return an approximation error. Those that return it compute it based on the completeness property of the respective algorithm.

A positive attribution score means that the input in that particular position contributed positively to the final prediction, and a negative score means the opposite. The magnitude of the attribution score signifies the strength of the contribution. A zero attribution score means no contribution from that particular feature.
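To make the convergence delta concrete, here is a dependency-free sketch (an illustration, not Captum's implementation) that re-creates ToyModel's fixed weights in plain Python, approximates the IG integral with a midpoint Riemann sum (Captum's default integration rule is different), and reports `sum(attributions) - (F(x) - F(baseline))` as the delta. The input `x`, the non-zero baseline and all function names are hypothetical.

```python
# Hypothetical pure-Python mirror of ToyModel; not part of Captum.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2 = [[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]                    # lin2.weight
B2 = [1.0, 1.0]                                               # lin2.bias; lin1.bias is zero


def hidden(x):
    """Output of lin1 (pre-ReLU) for a single example with 3 features."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def forward(x, target):
    """Output `target` of lin2(relu(lin1(x)))."""
    return sum(w * max(h, 0.0) for w, h in zip(W2[target], hidden(x))) + B2[target]


def grad(x, target):
    """Analytic gradient of forward(x, target) w.r.t. x, taking ReLU'(0) = 0."""
    h = hidden(x)
    return [
        sum(W2[target][j] * (1.0 if h[j] > 0 else 0.0) * W1[j][i] for j in range(3))
        for i in range(3)
    ]


def integrated_gradients(x, baseline, target, n_steps=50):
    """Midpoint-Riemann IG for one example; returns (attributions, delta)."""
    diff = [xi - bi for xi, bi in zip(x, baseline)]
    grad_sum = [0.0, 0.0, 0.0]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [bi + alpha * di for bi, di in zip(baseline, diff)]
        grad_sum = [s + g for s, g in zip(grad_sum, grad(point, target))]
    attributions = [di * s / n_steps for di, s in zip(diff, grad_sum)]
    # Convergence delta: gap between the attribution sum and F(x) - F(baseline).
    delta = sum(attributions) - (forward(x, target) - forward(baseline, target))
    return attributions, delta


x = [0.3, 0.6, 0.9]
attr_zero, delta_zero = integrated_gradients(x, [0.0, 0.0, 0.0], target=0)
_, delta_coarse = integrated_gradients(x, [-0.5, 0.4, -1.0], target=0, n_steps=5)
_, delta_fine = integrated_gradients(x, [-0.5, 0.4, -1.0], target=0, n_steps=2000)
print(attr_zero, delta_zero)
print(delta_coarse, delta_fine)
```

With the zero baseline, `lin1` has no bias, so every ReLU keeps the same on/off state along the whole path; the integrand is constant and the delta is zero up to rounding, which is consistent with the tiny deltas in the IG output above. With the non-zero baseline all three ReLUs switch state along the path: the delta is about 0.72 with `n_steps=5` and falls below 0.01 with `n_steps=2000`, which illustrates why raising `n_steps` helps.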
Similarly, we can apply GradientShap, DeepLift and other attribution algorithms to the model.
GradientShap first chooses a random baseline from the baselines' distribution, then adds Gaussian noise with std=0.09 to each input example `n_samples` times. Afterwards, it chooses a random point between each example-baseline pair and computes the gradients with respect to the target class (in this case target=0). The resulting attribution is the mean of gradients * (inputs - baselines).
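The estimator just described can be sketched in plain Python for a single example. This is a simplified illustration that follows the steps in the paragraph above rather than Captum's code; `gradient_shap`, `grad_fn` and the linear test function are hypothetical, and `grad_fn` is assumed to return the gradient of the target output at a given point.

```python
import random


def gradient_shap(grad_fn, x, baselines, stdevs=0.09, n_samples=4, seed=0):
    """Hypothetical sketch of the GradientShap estimator for one example."""
    rng = random.Random(seed)
    totals = [0.0] * len(x)
    for _ in range(n_samples):
        b = rng.choice(baselines)                          # random baseline from the distribution
        noisy = [xi + rng.gauss(0.0, stdevs) for xi in x]  # Gaussian noise added to the input
        alpha = rng.random()                               # random point between input and baseline
        point = [bi + alpha * (ni - bi) for bi, ni in zip(b, noisy)]
        g = grad_fn(point)
        for i in range(len(x)):
            totals[i] += g[i] * (noisy[i] - b[i])          # gradients * (inputs - baselines)
    return [t / n_samples for t in totals]                 # mean over the n_samples draws


def linear_grad(point):
    """Gradient of the stand-in function f(p) = 2*p0 - p1 + 0.5*p2 (constant)."""
    return [2.0, -1.0, 0.5]


print(gradient_shap(linear_grad, [1.0, 2.0, 3.0], baselines=[[0.0, 0.0, 0.0]], stdevs=0.0))
# [2.0, -2.0, 1.5]
```

With `stdevs=0.0`, a single zero baseline and a linear function, every draw gives exactly `gradient * input`, so the estimate does not depend on the random draws.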
```python
gs = GradientShap(model)

# We define a distribution of baselines and draw `n_samples` from that
# distribution in order to estimate the expectations of gradients across all baselines
baseline_dist = torch.randn(10, 3) * 0.001
attributions, delta = gs.attribute(input, stdevs=0.09, n_samples=4, baselines=baseline_dist,
                                   target=0, return_convergence_delta=True)
print('GradientShap Attributions:', attributions)
print('Convergence Delta:', delta)
```
Output
```
GradientShap Attributions: tensor([[-0.1542, -1.6229, -1.5835],
                                   [-0.3916, -0.2836, -4.6851]])
Convergence Delta: tensor([ 0.0000, -0.0005, -0.0029, -0.0084, -0.0087, -0.0405,  0.0000, -0.0084])
```

Deltas are computed for each of the `n_samples * input.shape[0]` examples. The user can, for instance, average them:

```python
deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)
```

to get the average delta per example.
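The reshape-and-mean call above groups consecutive deltas per example. A plain-Python equivalent (a hypothetical helper, shown only to make the grouping explicit) applied to the eight deltas printed above gives roughly -0.00295 for the first example and -0.0144 for the second:

```python
def per_example_mean(deltas, batch_size, use_abs=False):
    """Pure-Python equivalent of delta.reshape(batch_size, -1).mean(dim=1).

    Like the reshape, it assumes the deltas of each example are stored contiguously.
    use_abs=True averages |delta| instead (an optional variant, not the reshape above).
    """
    if len(deltas) % batch_size != 0:
        raise ValueError("len(deltas) must be a multiple of batch_size")
    per_example = len(deltas) // batch_size
    means = []
    for e in range(batch_size):
        chunk = deltas[e * per_example:(e + 1) * per_example]  # deltas of example e
        if use_abs:
            chunk = [abs(d) for d in chunk]
        means.append(sum(chunk) / per_example)
    return means


# The eight GradientShap deltas printed above (n_samples=4, two input examples).
deltas = [0.0000, -0.0005, -0.0029, -0.0084, -0.0087, -0.0405, 0.0000, -0.0084]
print(per_example_mean(deltas, batch_size=2))
print(per_example_mean(deltas, batch_size=2, use_abs=True))
```

Averaging `|delta|` keeps positive and negative deltas from cancelling, which matches reading the absolute delta as an approximation error; here all deltas are non-positive, so both variants differ only in sign.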
Below is an example of how we can apply DeepLift and DeepLiftShap to the ToyModel described above. The current implementation of DeepLift supports only the Rescale rule. For more details on alternative implementations, please see the DeepLift paper.

```python
dl = DeepLift(model)
attributions, delta = dl.attribute(input, baseline, target=0, return_convergence_delta=True)
print('DeepLift Attributions:', attributions)
print('Convergence Delta:', delta)
```
Output
```
DeepLift Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                               [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([0., 0.])
```

DeepLift assigns similar attribution scores to inputs as IntegratedGradients, but it has a lower execution time. Another important thing to remember about DeepLift is that it currently doesn't support all non-linear activation types. For more details on the limitations of the current implementation, please see the DeepLift paper.

Similar to integrated gradients, DeepLift returns a convergence delta score per input example. The approximation error is then the absolute value of the convergence deltas and can serve as a proxy of how accurate the algorithm's approximation is.
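The Rescale rule can be illustrated on the same toy network in plain Python. In this sketch (hypothetical names, not Captum's implementation), each ReLU's multiplier is the change in its output divided by the change in its input relative to the baseline, and the multipliers are propagated back through the linear weights. Because those ratios telescope exactly, the attributions sum to `F(x) - F(baseline)` and the delta is zero up to rounding, in line with the `tensor([0., 0.])` delta printed above. The fallback to the plain gradient when a ReLU input barely changes, and the `eps` threshold, are assumptions of this sketch.

```python
# Hypothetical pure-Python mirror of ToyModel; not part of Captum.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2 = [[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]                    # lin2.weight
B2 = [1.0, 1.0]                                               # lin2.bias; lin1.bias is zero


def hidden(x):
    """Output of lin1 (pre-ReLU) for a single example with 3 features."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def forward(x, target):
    """Output `target` of lin2(relu(lin1(x)))."""
    return sum(w * max(h, 0.0) for w, h in zip(W2[target], hidden(x))) + B2[target]


def deeplift_rescale(x, baseline, target, eps=1e-10):
    """DeepLift attributions (Rescale rule) for one example; returns (attributions, delta)."""
    h, h_ref = hidden(x), hidden(baseline)
    multipliers = []
    for hj, rj in zip(h, h_ref):
        if abs(hj - rj) > eps:
            # Rescale rule: change in ReLU output / change in ReLU input.
            multipliers.append((max(hj, 0.0) - max(rj, 0.0)) / (hj - rj))
        else:
            # Input barely changed: fall back to the ordinary ReLU gradient.
            multipliers.append(1.0 if hj > 0 else 0.0)
    # Linear layers pass the multipliers back through their weights.
    attributions = [
        (x[i] - baseline[i]) * sum(W2[target][j] * multipliers[j] * W1[j][i] for j in range(3))
        for i in range(3)
    ]
    delta = sum(attributions) - (forward(x, target) - forward(baseline, target))
    return attributions, delta


x = [0.3, 0.6, 0.9]
print(deeplift_rescale(x, [0.0, 0.0, 0.0], target=0))
print(deeplift_rescale(x, [-0.5, 0.4, -1.0], target=0))
```

With the zero baseline and a bias-free `lin1`, each multiplier equals the ReLU gradient at `x`, so on this network the attributions match Integrated Gradients; this is consistent with the identical IG and DeepLift outputs above. With the non-zero baseline the multipliers become fractions such as 2.8/7.6, yet the delta remains zero up to rounding.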
Now let's look into DeepLiftShap. Similar to GradientShap, DeepLiftShap uses a baseline distribution. In the example below, we use the same baseline distribution as for GradientShap.

```python
dl = DeepLiftShap(model)
attributions, delta = dl.attribute(input, baseline_dist, target=0, return_convergence_delta=True)
print('DeepLiftShap Attributions:', attributions)
print('Convergence Delta:', delta)
```
Output
```
DeepLiftShap Attributions: tensor([[-5.9169e-01, -1.5491e+00, -1.0076e+00],
                                   [-4.7101e-03, -2.2300e-01, -5.1926e+00]],
                                  grad_fn=<MeanBackward1>)
Convergence Delta: tensor([-4.6120e-03, -1.6267e-03, -5.1045e-04, -1.4184e-03, -6.8886e-03,
                           -2.2224e-02,  0.0000e+00, -2.8790e-02, -4.1285e-03, -2.7295e-02,
                           -3.2349e-03, -1.6265e-03, -4.7684e-07, -1.4191e-03, -6.8889e-03,
                           -2.2224e-02,  0.0000e+00, -2.4792e-02, -4.1289e-03, -2.7296e-02])
```

DeepLiftShap uses DeepLift to compute the attribution score for each input-baseline pair and averages them for each input across all baselines.

It computes deltas for each input example-baseline pair, thus resulting in `input.shape[0] * baseline_dist.shape[0]` delta values.

Similar to GradientShap, in order to compute example-based deltas we can average them per example:

```python
deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)
```
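The averaging step can be sketched in plain Python with a single ReLU unit `f(x) = relu(w . x)`: run Rescale-rule DeepLift once per input-baseline pair, average the attributions over baselines, and keep every pair's delta. The unit, the weights and the helper names are hypothetical; only the structure (one delta per pair, grouped example by example as the reshape above assumes) mirrors the DeepLiftShap description above.

```python
def deeplift_unit(x, baseline, w, eps=1e-10):
    """Rescale-rule DeepLift for f(x) = relu(w . x); returns (attributions, delta)."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    z_ref = sum(wi * bi for wi, bi in zip(w, baseline))
    if abs(z - z_ref) > eps:
        m = (max(z, 0.0) - max(z_ref, 0.0)) / (z - z_ref)  # Rescale multiplier
    else:
        m = 1.0 if z > 0 else 0.0  # fall back to the ReLU gradient
    attr = [m * wi * (xi - bi) for wi, xi, bi in zip(w, x, baseline)]
    delta = sum(attr) - (max(z, 0.0) - max(z_ref, 0.0))
    return attr, delta


def deeplift_shap(inputs, baselines, w):
    """Average DeepLift attributions over all baselines; keep one delta per pair."""
    attributions, deltas = [], []
    for x in inputs:
        pairs = [deeplift_unit(x, b, w) for b in baselines]  # one DeepLift run per baseline
        attributions.append(
            [sum(attr[i] for attr, _ in pairs) / len(pairs) for i in range(len(x))]
        )
        deltas.extend(delta for _, delta in pairs)  # example-major order
    return attributions, deltas


w = [1.0, -2.0, 0.5]
inputs = [[1.0, 0.0, 2.0], [0.2, 0.4, 0.6]]
baselines = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [-0.1, 0.2, 0.0]]
attributions, deltas = deeplift_shap(inputs, baselines, w)
print(attributions)
print(len(deltas))  # 2 inputs * 3 baselines = 6
```

As in the output above, the number of deltas is the number of inputs times the number of baselines, and averaging each consecutive block of `len(baselines)` deltas gives a per-example value.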
In order to smooth and improve the quality of the attributions, we can run IntegratedGradients and other attribution methods through a NoiseTunnel. NoiseTunnel allows us to use the SmoothGrad, SmoothGrad_Sq and VarGrad techniques to smooth the attributions by aggregating them over multiple noisy samples that were generated by adding Gaussian noise.
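The three aggregation modes can be sketched in plain Python, assuming the usual definitions: SmoothGrad averages the attributions of the noisy samples, SmoothGrad_Sq averages their squares, and VarGrad takes their variance. The mode strings mirror the `nt_type` values accepted by `NoiseTunnel` (`'smoothgrad'`, `'smoothgrad_sq'`, `'vargrad'`); the helper and the stand-in attribution function are hypothetical.

```python
import random


def noise_tunnel(attribute_fn, x, nt_type="smoothgrad", stdevs=0.02, nt_samples=4, seed=0):
    """Hypothetical sketch of NoiseTunnel-style aggregation for one example.

    attribute_fn(noisy_x) returns one attribution per feature for a noisy copy of x.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(nt_samples):
        noisy = [xi + rng.gauss(0.0, stdevs) for xi in x]  # add Gaussian noise to the input
        samples.append(attribute_fn(noisy))
    n = len(samples)
    mean = [sum(s[i] for s in samples) / n for i in range(len(x))]
    mean_sq = [sum(s[i] ** 2 for s in samples) / n for i in range(len(x))]
    if nt_type == "smoothgrad":       # average of the attributions
        return mean
    if nt_type == "smoothgrad_sq":    # average of the squared attributions
        return mean_sq
    if nt_type == "vargrad":          # variance of the attributions
        return [msq - m * m for msq, m in zip(mean_sq, mean)]
    raise ValueError(f"unknown nt_type: {nt_type}")


def times_two(p):
    """Stand-in attribution function, used only for illustration."""
    return [2.0 * v for v in p]


for kind in ("smoothgrad", "smoothgrad_sq", "vargrad"):
    print(kind, noise_tunnel(times_two, [1.0, 2.0], nt_type=kind, stdevs=0.0))
```

With `stdevs=0.0` every noisy sample equals the input, so SmoothGrad returns the plain attribution `[2.0, 4.0]`, SmoothGrad_Sq returns its square `[4.0, 16.0]`, and VarGrad returns zeros; with real noise, VarGrad measures how much the attributions fluctuate.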
Here is an example of how we can use NoiseTunnel with IntegratedGradients.

```python
ig = IntegratedGradients(model)
nt = NoiseTunnel(ig)
attributions, delta = nt.attribute(input, nt_type='smoothgrad', stdevs=0.02, nt_samples=4,
                                   baselines=baseline, target=0, return_convergence_delta=True)
print('IG + SmoothGrad Attributions:', attributions)
print('Convergence Delta:', delta)
```
Output
```
IG + SmoothGrad Attributions: tensor([[-0.4574, -1.5493, -1.0893],
                                      [ 0.0000, -0.2647, -5.1619]])
Convergence Delta: tensor([ 0.0000e+00,  2.3842e-07,  0.0000e+00, -2.3842e-07,  0.0000e+00,
                           -4.7684e-07,  0.0000e+00, -4.7684e-07])
```

The number of elements in the `delta` tensor is equal to `nt_samples * input.shape[0]`. In order to get an example-wise delta, we can, for example, average them:

```python
deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)
```
Let's look into the internals of our network and understand which layers and neurons are important for the predictions.
We will start with NeuronConductance. NeuronConductance helps us identify input features that are important for a particular neuron in a given layer. It decomposes the computation of integrated gradients via the chain rule by defining the importance of a neuron as the path integral of the derivative of the output with respect to the neuron times the derivatives of the neuron with respect to the inputs of the model.
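Written out in the notation of Dhamdhere et al. (2018, listed in the references below), with input $x$, baseline $x'$, target output $F$ and a hidden neuron $y$, the conductance of neuron $y$ for input feature $i$ is the following (our transcription of the description above). Summing it over all neurons of a layer that every path from $x$ to $F$ passes through recovers $\mathrm{IG}_i(x)$ by the chain rule.

```math
\mathrm{Cond}^{\,y}_i(x) = (x_i - x'_i)\int_0^1
\frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial y}\,
\frac{\partial y\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\,d\alpha
```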
In this case, we choose to analyze the second neuron (index 1, selected via `neuron_selector=1`) in the first linear layer, `lin1`.

```python
nc = NeuronConductance(model, model.lin1)
attributions = nc.attribute(input, neuron_selector=1, target=0)
print('Neuron Attributions:', attributions)
```
Output
```
Neuron Attributions: tensor([[ 0.0000,  0.0000,  0.0000],
                             [ 1.3358,  0.0000, -1.6811]])
```

Layer conductance shows the importance of neurons for a layer and a given input. It is an extension of path integrated gradients to hidden layers and holds the completeness property as well.

It doesn't attribute the contribution scores to the input features but shows the importance of each neuron in the selected layer.

```python
lc = LayerConductance(model, model.lin1)
attributions, delta = lc.attribute(input, baselines=baseline, target=0, return_convergence_delta=True)
print('Layer Attributions:', attributions)
print('Convergence Delta:', delta)
```

Output

```
Layer Attributions: tensor([[ 0.0000,  0.0000, -3.0856],
                            [ 0.0000, -0.3488, -4.9638]], grad_fn=<SumBackward1>)
Convergence Delta: tensor([0.0630, 0.1084])
```

Similar to other attribution algorithms that return a convergence delta, LayerConductance returns the deltas for each example. The approximation error is then the absolute value of the convergence deltas and can serve as a proxy of how accurate the integral approximation for the given inputs and baselines is.
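As a plain-Python illustration of the completeness property (a sketch with hypothetical names, not Captum's implementation), the conductance of each `lin1` neuron can be approximated by accumulating dF/dy times the change in that neuron's pre-ReLU output over small steps along the baseline-to-input path; summing over the neurons then recovers `F(x) - F(baseline)`. The input `x` is hypothetical.

```python
# Hypothetical pure-Python mirror of ToyModel; not part of Captum.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2 = [[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]                    # lin2.weight
B2 = [1.0, 1.0]                                               # lin2.bias; lin1.bias is zero


def hidden(x):
    """Output of lin1 (pre-ReLU) for a single example with 3 features."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def forward(x, target):
    """Output `target` of lin2(relu(lin1(x)))."""
    return sum(w * max(h, 0.0) for w, h in zip(W2[target], hidden(x))) + B2[target]


def layer_conductance(x, baseline, target, n_steps=50):
    """Conductance of each lin1 neuron for one example; returns (conductances, delta)."""
    def point(alpha):
        return [b + alpha * (xi - b) for xi, b in zip(x, baseline)]

    cond = [0.0, 0.0, 0.0]
    for k in range(n_steps):
        y_start = hidden(point(k / n_steps))
        y_end = hidden(point((k + 1) / n_steps))
        y_mid = hidden(point((k + 0.5) / n_steps))
        for j in range(3):
            # dF/dy_j at the step midpoint: lin2 weight times the ReLU derivative.
            dF_dy = W2[target][j] * (1.0 if y_mid[j] > 0 else 0.0)
            # Accumulate dF/dy_j times the change of y_j over this step.
            cond[j] += dF_dy * (y_end[j] - y_start[j])
    delta = sum(cond) - (forward(x, target) - forward(baseline, target))
    return cond, delta


cond, delta = layer_conductance([0.3, 0.6, 0.9], [0.0, 0.0, 0.0], target=0)
print(cond, delta)
```

For this zero baseline the conductances are about `[0.0, -1.2, -6.0]` and the delta is zero up to rounding, because the per-step changes telescope. This sketch uses a plain midpoint rule, while Captum's default integration rule differs, so the deltas printed above need not match it.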
More details on the list of supported algorithms and how to apply Captum to different types of models can be found in our tutorials.
If you have questions about using Captum methods, please check this FAQ, which addresses many common issues.
See the CONTRIBUTING file for how to help out.
- NeurIPS 2019: The slides of our presentation can be found here.
- KDD 2020: The slides of our presentation from the KDD 2020 tutorial can be found here. You can watch the recorded talk here.
- GTC 2020: Opening Up the Black Box: Model Understanding with Captum and PyTorch. You can watch the recorded talk here.
- XAI Summit 2020: Using Captum and Fiddler to Improve Model Understanding with Explainable AI. You can watch the recorded talk here.
- PyTorch Developer Day 2020: Model Interpretability. You can watch the recorded talk here.
- NAACL 2021: Tutorial on Fine-grained Interpretation and Causation Analysis in Deep NLP Models. You can watch the recorded talk here.
- ICLR 2021 workshop on Responsible AI:
- Summer school on medical imaging at University of Lyon has a class on model explainability you can watch here.
- IntegratedGradients, LayerIntegratedGradients: Axiomatic Attribution for Deep Networks, Mukund Sundararajan et al. 2017 and Did the Model Understand the Question?, Pramod K. Mudrakarta et al. 2018
- InputXGradient: Not Just a Black Box: Learning Important Features Through Propagating Activation Differences, Avanti Shrikumar et al. 2016
- SmoothGrad: SmoothGrad: removing noise by adding noise, Daniel Smilkov et al. 2017
- NoiseTunnel: Sanity Checks for Saliency Maps, Julius Adebayo et al. 2018
- NeuronConductance: How Important is a neuron?, Kedar Dhamdhere et al. 2018
- LayerConductance: Computationally Efficient Measures of Internal Neuron Importance, Avanti Shrikumar et al. 2018
- DeepLift, NeuronDeepLift, LayerDeepLift: Learning Important Features Through Propagating Activation Differences, Avanti Shrikumar et al. 2017 and Towards better understanding of gradient-based attribution methods for deep neural networks, Marco Ancona et al. 2018
- NeuronIntegratedGradients: Computationally Efficient Measures of Internal Neuron Importance, Avanti Shrikumar et al. 2018
- GradientShap, NeuronGradientShap, LayerGradientShap, DeepLiftShap, NeuronDeepLiftShap, LayerDeepLiftShap: A Unified Approach to Interpreting Model Predictions, Scott M. Lundberg et al. 2017
- InternalInfluence: Influence-Directed Explanations for Deep Convolutional Networks, Klas Leino et al. 2018
- Saliency, NeuronGradient: Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps, K. Simonyan et al. 2014
- GradCAM, Guided GradCAM: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, Ramprasaath R. Selvaraju et al. 2017
- Deconvolution, Neuron Deconvolution: Visualizing and Understanding Convolutional Networks, Matthew D Zeiler et al. 2014
- Guided Backpropagation, Neuron Guided Backpropagation: Striving for Simplicity: The All Convolutional Net, Jost Tobias Springenberg et al. 2015
- Feature Permutation: Permutation Feature Importance
- Occlusion: Visualizing and Understanding Convolutional Networks
- Shapley Value: A value for n-person games. Contributions to the Theory of Games 2.28 (1953): 307-317
- Shapley Value Sampling: Polynomial calculation of the Shapley value based on sampling
- Infidelity and Sensitivity: On the (In)fidelity and Sensitivity for Explanations
- TracInCP, TracInCPFast, TracInCPRandProj: Estimating Training Data Influence by Tracing Gradient Descent
- SimilarityInfluence: [Pairwise similarities between train and test examples based on predefined similarity metrics]
- BinaryConcreteStochasticGates: Stochastic Gates with Binary Concrete Distribution
- GaussianStochasticGates: Stochastic Gates with Gaussian Distribution
More details about the above-mentioned attribution algorithms and their pros and cons can be found on our website.
Captum is BSD licensed, as found in the LICENSE file.