Novel add-on framework designed to improve the clarity and reliability of feature attribution maps in deep learning models by filtering out gradient noise.

heejbin/CriGrad

CriGrad Saliency Maps

CriGrad

CriGrad (Critical Gradient) is a novel add-on framework designed to improve the clarity and reliability of feature attribution maps in deep learning models by filtering out gradient noise.

CriGrad is implemented as an add-on to the following methods:

  1. Input-gradient (paper)
  2. GxI (paper, paper)

| Method | Command | Description |
| --- | --- | --- |
| Baseline | python generate_saliency_script.py | Runs the standard feature attribution method (e.g., InputGrad). |
| CriGrad-Enhanced | python generate_saliency_script.py --crigrad | Applies the CriGrad framework to the baseline method for noise reduction. |
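
For reference, the two baseline methods listed above can be sketched in plain PyTorch. This is a minimal illustration, not the repository's code: the function names, the toy model, and the choice to differentiate the target-class logit (rather than a loss) are assumptions made for the example.

```python
import torch
import torch.nn as nn


def input_gradient(model, image, target_class):
    """Gradient of the target-class logit with respect to the input."""
    image = image.detach().clone().requires_grad_(True)
    logits = model(image)
    # Sum the selected logits so that one backward call covers the whole batch.
    score = logits[torch.arange(logits.shape[0]), target_class].sum()
    return torch.autograd.grad(score, image)[0]


def gradient_times_input(model, image, target_class):
    """GxI: elementwise product of the input gradient and the input."""
    return input_gradient(model, image, target_class) * image


torch.manual_seed(0)
# Hypothetical stand-in for a real classifier.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
images = torch.rand(2, 3, 8, 8)
targets = torch.tensor([1, 3])

grad_map = input_gradient(toy_model, images, targets)
gxi_map = gradient_times_input(toy_model, images, targets)
```

Both maps have the same shape as the input batch. CriGrad, as described in this README, modifies the gradient that such methods start from.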

Evaluation Metrics and Results

To reproduce the three metric scores reported in our paper, run the following command. It executes the evaluation script (eval.py) with the threshold $k=0.5$ for the metrics.

python eval.py --k 0.5

  1. Pixel Perturbation (paper): [Figure: Pixel Perturbation result table]
  2. Sensitivity-n (paper): [Figure: Sensitivity-n result table]
  3. Remove and Retrain (paper): [Figure: Remove and Retrain result table]
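
To make the first metric more concrete, here is a rough sketch of a pixel-perturbation check. It is not taken from eval.py: it assumes that $k$ is the fraction of pixels removed, that the least salient pixels are the ones removed, and that the score is the mean absolute change in the model outputs. The function names and the toy model are hypothetical.

```python
import torch
import torch.nn as nn


def remove_least_salient(image, saliency, k):
    """Zero out the fraction k of pixels with the lowest saliency, per image."""
    b, _, h, w = image.shape
    flat = saliency.reshape(b, -1)          # one saliency value per pixel
    num_remove = int(k * h * w)
    mask = torch.ones_like(flat)
    if num_remove > 0:
        idx = flat.topk(num_remove, dim=1, largest=False).indices
        mask.scatter_(1, idx, 0.0)
    # Broadcast the per-pixel mask over the channel dimension.
    return image * mask.reshape(b, 1, h, w)


def output_change(model, image, saliency, k):
    """Mean absolute change in the outputs after removing pixels."""
    with torch.no_grad():
        before = model(image)
        after = model(remove_least_salient(image, saliency, k))
    return (before - after).abs().mean().item()


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))  # stand-in model
images = torch.rand(2, 3, 8, 8)
saliency = torch.rand(2, 1, 8, 8)       # stand-in saliency maps
change = output_change(toy_model, images, saliency, k=0.5)
```

Under these assumptions, a saliency map that ranks pixels well should produce a small change when the least salient half of the image is removed.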

Interfaces

CriGrad is designed to replace the standard gradient calculation within existing input gradient-based saliency methods. This modular approach allows users to easily incorporate CriGrad into other visualization frameworks (e.g., Integrated Gradients, GuidedBackProp, ...) that rely on a gradient path.
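
As an illustration of what "relying on a gradient path" means, the sketch below writes Integrated Gradients so that the gradient computation is a pluggable function. Everything here is an assumption made for the example (the `grad_fn` parameter, the zero baseline, the toy model); it only shows the kind of seam where a replacement gradient such as CriGrad's could be passed in, and it does not implement CriGrad itself.

```python
import torch
import torch.nn as nn


def plain_gradient(model, x, target_class):
    """Standard gradient of the target-class logit with respect to x."""
    x = x.detach().clone().requires_grad_(True)
    logits = model(x)
    score = logits[torch.arange(logits.shape[0]), target_class].sum()
    return torch.autograd.grad(score, x)[0]


def integrated_gradients(model, image, target_class,
                         grad_fn=plain_gradient, steps=32):
    """Riemann-sum Integrated Gradients along a straight path from zero."""
    baseline = torch.zeros_like(image)
    total = torch.zeros_like(image)
    for i in range(1, steps + 1):
        point = baseline + (i / steps) * (image - baseline)
        # The only place a gradient is computed: swap grad_fn to change it.
        total += grad_fn(model, point, target_class)
    return (image - baseline) * total / steps


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))  # stand-in model
images = torch.rand(2, 3, 8, 8)
targets = torch.tensor([1, 3])
ig_map = integrated_gradients(toy_model, images, targets)
```

Because the method touches the model only through `grad_fn`, a different gradient routine with the same call signature can be substituted without changing the rest of the method.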

If you wish to apply CriGrad to another saliency method, you can follow this general three-step process, demonstrated by the provided InputGradient class structure:

  1. Initialization (instantiation of criGradExtractor). In the initializer (__init__) of your saliency method class, check whether CriGrad is enabled. If so, instantiate the criGradExtractor module with your target model. This module prepares the model for CriGrad's specialized forward and backward passes.
from saliency.crigrad import criGradExtractor

class InputGradient:
    def __init__(...):  # signature elided in the original
        self.model = model
        self.loss = loss

        # Build the CriGrad extractor only when CriGrad is enabled.
        self.cri_grad = cri_grad
        if cri_grad:
            self.model_ext = criGradExtractor(
                model,
                im_size=im_size,
                importance_measurement='gxi',
                crigrad_layers=crigrad_layers,
            )
  2. Forward pass modification. When executing the forward pass, replace the standard model call (self.model(image)) with the specialized CriGrad forward method, which also returns the intermediate feature maps.
        if self.cri_grad:
            outputs, intermediates = self.model_ext._forward_crigrad(image)
        else:
            outputs = self.model(image)
  3. Gradient calculation replacement (backward pass). The standard torch.autograd.grad call (or an equivalent calculation) is replaced by the CriGrad backward method, which implements the critical gradient modification:
        if self.cri_grad:
            gradients = self.model_ext._backward_crigrad(image, target_class, intermediates)
        else:
            self.model.zero_grad()
            # Gradient of the scalar `agg` (computed elsewhere, not shown here) w.r.t. the input image
            gradients = torch.autograd.grad(outputs=agg, inputs=image, retain_graph=False)[0]
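
The three steps above can be wired together as in the following sketch. To keep it self-contained, it uses a hypothetical `PlainGradientExtractor` in place of criGradExtractor: it exposes the same two method names shown above but returns ordinary gradients and performs no CriGrad filtering. The class names, the `extractor` argument, and the use of the target-class logit as `agg` are assumptions for illustration only.

```python
import torch
import torch.nn as nn


class PlainGradientExtractor:
    """Hypothetical stand-in for criGradExtractor (no CriGrad filtering)."""

    def __init__(self, model):
        self.model = model

    def _forward_crigrad(self, image):
        x = image.detach().clone().requires_grad_(True)
        outputs = self.model(x)
        # Keep what the backward step needs; the real module returns feature maps.
        return outputs, {"input": x, "outputs": outputs}

    def _backward_crigrad(self, image, target_class, intermediates):
        outputs = intermediates["outputs"]
        score = outputs[torch.arange(outputs.shape[0]), target_class].sum()
        return torch.autograd.grad(score, intermediates["input"])[0]


class InputGradientSketch:
    def __init__(self, model, cri_grad=False, extractor=None):
        self.model = model
        self.cri_grad = cri_grad
        if cri_grad:
            self.model_ext = extractor                      # step 1

    def saliency(self, image, target_class):
        if self.cri_grad:
            outputs, intermediates = self.model_ext._forward_crigrad(image)  # step 2
            return self.model_ext._backward_crigrad(                         # step 3
                image, target_class, intermediates)
        image = image.detach().clone().requires_grad_(True)
        outputs = self.model(image)
        agg = outputs[torch.arange(outputs.shape[0]), target_class].sum()
        self.model.zero_grad()
        return torch.autograd.grad(outputs=agg, inputs=image)[0]


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))  # stand-in model
images = torch.rand(2, 3, 8, 8)
targets = torch.tensor([1, 3])

baseline_map = InputGradientSketch(toy_model).saliency(images, targets)
extractor_map = InputGradientSketch(
    toy_model, cri_grad=True, extractor=PlainGradientExtractor(toy_model)
).saliency(images, targets)
```

With the stand-in extractor both branches return the same gradients, which confirms that only the two extractor calls differ between the baseline and the CriGrad path.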
            

The implementation of CriGrad Saliency Maps references the framework and methodology of FullGrad Saliency (github link).

Examples

[Figure: Example feature attribution map of a running shoe]
[Figure: Example feature attribution map of a sundial]

Dependencies

torch==2.7.1
torchvision==0.22.1
opencv-python==4.12.0.88
numpy==1.24.1
timm==1.0.17

Research

If you find our work helpful for your research, please consider citing us.

To Be Updated
