An XAI framework that provides Contrastive Whole-output Explanations for image classification.
vaynexie/CWOX
CWOX-2s (Two-Stage Contrastive Whole-output Explanation) is a novel explanation framework in which one can examine the evidence for competing classes and thereby obtain contrastive explanations for image classification (see the paper for details and citations).
Requirements:
The main code is based on PyTorch; please refer to requirements.txt for the required package versions.
Building the hierarchical latent tree model (HLTM) in Part A requires Java 11 and Scala 2.12.12.
In the following, we give the step-by-step tutorial for generating the CWOX-2s explanations.
CWOX-2s has a preprocessing step that partitions all class labels into confusion clusters with respect to the classifier to be explained. Classes in each of those clusters (e.g., cello, violin) are confusing to the classifier, and are often competing labels for the same object/region in the input image. CWOX-2s does so by analyzing the co-occurrence of labels in classification results and thereby building a hierarchical latent tree model (HLTM).
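To make the co-occurrence idea concrete, the following is a minimal sketch of the raw statistic only: it counts how often two labels appear together in the top-k predictions of the same image. It is not the HLTM learning code (which lives in the HLTM sub-directory and runs on Java/Scala). The function name `cooccurrence_counts` and the toy predictions are assumptions made up for illustration.

```python
from collections import Counter
from itertools import combinations


def cooccurrence_counts(top_k_per_image):
    """Count how often each unordered pair of labels shares a top-k list."""
    counts = Counter()
    for labels in top_k_per_image:
        # sorted(set(...)) gives every pair a canonical (smaller, larger) order
        for a, b in combinations(sorted(set(labels)), 2):
            counts[(a, b)] += 1
    return counts


# Hypothetical top-3 predictions for four images (ImageNet class indices).
predictions = [
    [486, 889, 402],
    [486, 889, 420],
    [402, 420, 546],
    [402, 546, 486],
]

counts = cooccurrence_counts(predictions)
for pair, n in counts.most_common(3):
    print(pair, n)
```

Pairs that co-occur often, such as cello (486) and violin (889) in this toy data, are the kind of labels one would expect to end up in the same confusion cluster.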
The code for learning HLTMs is given in the sub-directory HLTM, along with the structures of the models obtained for ResNet50 and GoogleNet. The HLTM code outputs a JSON file named output_name_fullname.nodes.json, which contains the learned hierarchical latent tree model. This JSON file is used below to partition the labels into confusion clusters. The ResNet50.json in the example code shown below is renamed from output_name_fullname.nodes.json.
When interpreting the output of the classifier on a target image, CWOX-2s obtains a subtree for the top classes by removing all the irrelevant nodes from the HLTM. The top classes are partitioned into Label Confusion Clusters by cutting the subtree at a certain level, the default being the lowest level. This is how the two clusters in Figure 1 are obtained from the tree in Figure 2.
The example code for partitioning the top classes is given below:
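The cutting step can be pictured with a small toy tree. The sketch below is an illustration under stated assumptions, not the logic of `CWOX.apply_hltm`: the tree is a hand-written child-to-parent map, the latent node names (`Z1`, `Z2`, `Z3`) are invented, and `cut_into_clusters` is a hypothetical helper. Grouping only the top classes by their ancestor has the same effect as first dropping the irrelevant nodes and then cutting.

```python
def cut_into_clusters(parent, top_classes, cut_level=0):
    """Group classes by the latent ancestor found `cut_level` steps above
    the latent node directly over each class."""
    clusters = {}
    for c in top_classes:
        node = parent[c]                   # lowest-level latent node of class c
        for _ in range(cut_level):         # climb higher for a coarser cut
            node = parent.get(node, node)  # stay at the root once reached
        clusters.setdefault(node, []).append(c)
    return list(clusters.values())


# Hypothetical toy tree (child -> parent); NOT the real ResNet50 HLTM.
toy_parent = {
    486: "Z1", 889: "Z1",
    402: "Z2", 420: "Z2", 546: "Z2",
    "Z1": "Z3", "Z2": "Z3",
}

top_k = [486, 889, 402, 420, 546]
lowest_cut = cut_into_clusters(toy_parent, top_k, cut_level=0)
higher_cut = cut_into_clusters(toy_parent, top_k, cut_level=1)
print(lowest_cut)
print(higher_cut)
```

Cutting at the lowest level keeps the string instruments and the guitar-like instruments apart in this toy tree, while cutting one level higher merges all five classes into a single cluster.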
```python
'''
Partition the top-k labels into different clusters.
Users can select the cut_level to use. By default, we apply the latent nodes at the lowest level (cut_level=0) to divide the top classes into clusters.
The JSON files containing the HLTM information we obtained for ImageNet image classification, ResNet50.json and GoogleNet.json, can be found in
https://github.com/xie-lin-li/CWOX/blob/main/HLTM/result_json or https://github.com/xie-lin-li/CWOX/tree/main/resources
'''
from CWOX.apply_hltm import *
clusterify_resnet50 = apply_hltm(cut_level=0, json_path="ResNet50.json")

'''
An example: the top-5 predicted classes for eval_image/cello_guitar.jpg by ResNet50
486-cello 889-violin 402-acoustic guitar 420-banjo 546-electric guitar
The index of a given class can be looked up in imagenet_class_index.json
'''
top_k = [486, 889, 402, 420, 546]
cluster_resnet50 = clusterify_resnet50.get_cluster(top_k)
print(cluster_resnet50)

'''
Output: [[486, 889], [402, 420, 546]],
which indicates that the top-5 classes are divided into two clusters, as suggested by the lowest-level latent nodes of the ResNet50 HLTM:
Cluster 1: cello, violin;
Cluster 2: acoustic guitar, banjo, electric guitar.
'''
```
For the following discussion, we assume the top classes for an input x are partitioned into confusion clusters C_i, each consisting of individual classes c_ij.
CWOX-2s requires a base explainer, which can be any existing explanation method, such as Grad-CAM, MWP, LIME or RISE, that yields nonnegative heatmaps. CWOX-2s first runs the base explainer on the confusion clusters (C_i’s) and the individual classes (c_ij’s), and then combines the base heatmaps to form contrastive heatmaps.
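The exact rule CWOX-2s uses to combine base heatmaps is described in the paper and implemented in CWOX.CWOX_2s; it is not reproduced here. Purely as an illustration of what "contrasting" nonnegative heatmaps can mean, the sketch below keeps, at each pixel, only the evidence for a target that exceeds the strongest competitor. The function `contrast` and the 2x2 heatmaps are assumptions invented for this example and should not be read as the CWOX-2s formula.

```python
def contrast(target, competitors):
    """Per pixel, keep the part of `target` above the strongest competitor.

    Heatmaps are nested lists of nonnegative floats with identical shapes.
    This is a generic illustration, not the combination rule of CWOX-2s.
    """
    rows, cols = len(target), len(target[0])
    result = []
    for r in range(rows):
        row = []
        for c in range(cols):
            strongest = max(h[r][c] for h in competitors)
            row.append(max(target[r][c] - strongest, 0.0))
        result.append(row)
    return result


# Hypothetical 2x2 base heatmaps for two competing clusters.
heat_a = [[0.75, 0.25], [0.5, 0.0]]
heat_b = [[0.25, 0.75], [0.5, 0.25]]

contrast_a = contrast(heat_a, [heat_b])
contrast_b = contrast(heat_b, [heat_a])
print(contrast_a)
print(contrast_b)
```

Pixels where both heatmaps are equally strong contribute nothing to either result, which is the intuition behind highlighting evidence that is specific to one cluster or class rather than shared.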
A score function is needed in order to produce a base heatmap for a class c. It is usually either the logit of the class (for Grad-CAM and MWP) or the probability of the class (for RISE and LIME). For confusion clusters, the logit is replaced by the generalized logit, and the probability is replaced by the probability of the cluster.
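The two cluster-level scores can be sketched numerically. The definitions used below are assumptions, since the formulas are not reproduced in this text: the probability of a cluster is taken to be the sum of the softmax probabilities of its classes, and the generalized logit is taken to be the log-sum-exp of the class logits in the cluster. The function names and the toy logits are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cluster_probability(probs, cluster):
    """Assumed definition: sum of the class probabilities in the cluster."""
    return sum(probs[c] for c in cluster)


def generalized_logit(logits, cluster):
    """Assumed definition: log-sum-exp of the class logits in the cluster."""
    m = max(logits[c] for c in cluster)
    return m + math.log(sum(math.exp(logits[c] - m) for c in cluster))


# Hypothetical logits for a toy 5-class problem, split into two clusters.
logits = [2.0, 1.5, 0.3, -0.2, -1.0]
clusters = [[0, 1], [2, 3, 4]]

probs = softmax(logits)
cluster_probs = [cluster_probability(probs, C) for C in clusters]
cluster_logits = [generalized_logit(logits, C) for C in clusters]
print(cluster_probs)
print(cluster_logits)
```

Under these assumed definitions the two scores are consistent with each other: when the clusters partition all classes, applying softmax to the generalized logits gives back the cluster probabilities.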
CWOX.IOX(algo): Produces a base heatmap using the explanation method named algo. Currently, algo = “Grad-CAM”, “MWP”, “RISE”, and “LIME” are supported.
- CWOX.grad_cam_cwox: Produces a base heatmap with Grad-CAM;
- CWOX.excitationbackprop_cwox: Produces a base heatmap with MWP;
- CWOX.rise_cwox: Produces a base heatmap with RISE;
- CWOX.lime_cwox: Produces a base heatmap with LIME.
CWOX.CWOX_2s: Produces contrastive heatmaps with CWOX-2s.
CWOX.plt_wox.plot_cwox: Visualize CWOX-2s results.
The following examples illustrate the use of CWOX-2s to explain the results of ResNet50 on one image. The complete code can be found at CWOX_Example.ipynb, and more test images can be found in the sub-directory eval_image.
```python
# Load the required packages
import torch
from torchvision import datasets, models, transforms, utils
from PIL import Image
from CWOX.CWOX_2s import CWOX_2s
from CWOX.IOX import IOX
from CWOX.plt_wox import plot_cwox

# Load the image and the model
loader = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img_path = 'eval_image/cello_guitar.jpg'
image = Image.open(img_path).convert('RGB')
image = loader(image).float()
image = torch.unsqueeze(image, 0)  # add the batch dimension
model = models.resnet50(pretrained=True)
_ = model.train(False)  # put model in evaluation mode
```
```python
# Base explainer: Grad-CAM
from CWOX.grad_cam_cwox import grad_cam_cwox
IOX_cluster = IOX(grad_cam_cwox(model, layer='layer4'))
IOX_class = IOX(grad_cam_cwox(model, layer='layer3.5.relu'))

# Confusion cluster information from the HLTM (see *Part A. Label Confusion Clusters Identification* for how to obtain it)
cluster_use_final = [[486, 889], [402, 420, 546]]
sal_dict = CWOX_2s(image, cluster_use_final, cluster_method=IOX_cluster, class_method=IOX_class, delta=50, multiple_output=False)

# Plot the CWOX-2s results
plot_cwox(sal_dict, image, cluster_use_final)
```
```python
# Base explainer: MWP (excitation backprop)
from CWOX.excitationbackprop_cwox import excitationbackprop_cwox

# Update a ResNet model to use :class:`EltwiseSum` for the skip connection.
from CWOX.base_explainer.attribution.excitation_backprop import update_resnet
model_update = update_resnet(model)
_ = model_update.train(False)  # put model in evaluation mode
IOX_cluster = IOX(excitationbackprop_cwox(model_update, layer='layer4'))
IOX_class = IOX(excitationbackprop_cwox(model_update, layer='layer4.0.relu'))

# Confusion cluster information from the HLTM (see *Part A. Label Confusion Clusters Identification* for how to obtain it)
cluster_use_final = [[486, 889], [402, 420, 546]]
sal_dict = CWOX_2s(image, cluster_use_final, cluster_method=IOX_cluster, class_method=IOX_class, delta=60, multiple_output=False)

# Plot the CWOX-2s results
plot_cwox(sal_dict, image, cluster_use_final)
```
```python
# Base explainer: RISE
from CWOX.base_explainer.attribution.rise import rise
from CWOX.rise_cwox import rise_cwox
IOX_cluster = IOX(rise_cwox(model, N=5000, mask_probability=0.3, down_sample_size=15, gpu_batch=30))
IOX_class = IOX(rise(model, N=3000, mask_probability=0.14, down_sample_size=10, gpu_batch=30))

# Confusion cluster information from the HLTM (see *Part A. Label Confusion Clusters Identification* for how to obtain it)
cluster_use_final = [[486, 889], [402, 420, 546]]
sal_dict = CWOX_2s(image, cluster_use_final, cluster_method=IOX_cluster, class_method=IOX_class, delta=70, multiple_output=True)

# Plot the CWOX-2s results
plot_cwox(sal_dict, image, cluster_use_final)
```
```python
# Base explainer: LIME
from CWOX.base_explainer.attribution.lime import lime
from CWOX.lime_cwox import lime_cwox
IOX_cluster = IOX(lime_cwox(model, kernel_size=4, number_sample=2000, gpu_batch=100))
IOX_class = IOX(lime(model, kernel_size=4, number_sample=2000, gpu_batch=100))

# Confusion cluster information from the HLTM (see *Part A. Label Confusion Clusters Identification* for how to obtain it)
cluster_use_final = [[486, 889], [402, 420, 546]]
# Note: this example passes the image path (img_path), not the preprocessed tensor
sal_dict = CWOX_2s(img_path, cluster_use_final, cluster_method=IOX_cluster, class_method=IOX_class, delta=85, multiple_output=True)

# Plot the CWOX-2s results
plot_cwox(sal_dict, image, cluster_use_final)
```
The sub-directory Evaluation provides code for computing the evaluation metrics that measure Contrastive Faithfulness.
An application for running the Contrastive Whole-output Explanation process is also provided. Currently, only ResNet50 and GoogleNet for ImageNet image classification are supported.
See the README page in the sub-directory CWOX_Explainer for guidelines on using the application.
- Weiyan Xie (wxieai@cse.ust.hk) (The Hong Kong University of Science and Technology)