CriGrad (Critical Gradient) is a novel add-on framework designed to improve the clarity and reliability of feature attribution maps in deep learning models by filtering out gradient noise.
The following run modes are implemented, with CriGrad as an optional add-on to the baseline method:
| Method | Command | Description |
|---|---|---|
| Baseline | python generate_saliency_script.py | Runs the standard feature attribution method (e.g., InputGrad). |
| CriGrad-Enhanced | python generate_saliency_script.py --crigrad | Applies the CriGrad framework to the baseline method for noise reduction. |
To reproduce the three metric scores reported in our paper, run the following command, which executes the evaluation script (eval.py) with the threshold passed via --k:
python eval.py --k 0.5
CriGrad is designed to replace the standard gradient calculation within existing input gradient-based saliency methods. This modular approach allows users to easily incorporate CriGrad into other visualization frameworks (e.g., Integrated Gradients, GuidedBackProp, ...) that rely on a gradient path.
If you wish to apply CriGrad to another saliency method, you can follow this general three-step process, demonstrated by the provided InputGradient class structure:
- Initialization (Instantiation of criGradExtractor)
In the initializer (__init__) of your saliency method class, check whether CriGrad is enabled. If so, instantiate the criGradExtractor module using your target model. This module prepares the model for CriGrad's specialized forward and backward passes.
from saliency.crigrad import criGradExtractor

class InputGradient():
    def __init__(...):
        self.model = model
        self.loss = loss
        self.cri_grad = cri_grad
        if cri_grad:
            self.model_ext = criGradExtractor(model, im_size=im_size, importance_measurement='gxi', crigrad_layers=crigrad_layers)
- Forward Pass Modification
When executing the forward pass, replace the standard model call (self.model(image)) with the specialized CriGrad forward method, which also returns intermediate feature maps.
if self.cri_grad:
    outputs, intermediates = self.model_ext._forward_crigrad(image)
else:
    outputs = self.model(image)
- Gradient Calculation Replacement (Backward Pass)
Crucially, the standard torch.autograd.grad or equivalent calculation is replaced by the CriGrad backward method, which implements the critical gradient modification:
if self.cri_grad:
    gradients = self.model_ext._backward_crigrad(image, target_class, intermediates)
else:
    self.model.zero_grad()
    # Gradients w.r.t. input and features
    gradients = torch.autograd.grad(outputs=agg, inputs=image, only_inputs=True, retain_graph=False)[0]
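For reference, the non-CriGrad branch of the steps above can be sketched as a self-contained function. This is a minimal sketch rather than the repository's code: the helper name input_gradient, the toy linear model, and the definition of agg as the sum of target-class logits over the batch are all assumptions made for illustration. The CriGrad branch would swap the forward call and the torch.autograd.grad call for the criGradExtractor methods shown above.

```python
import torch
import torch.nn as nn


def input_gradient(model, image, target_class=None):
    """Plain input-gradient saliency: gradient of the target-class score w.r.t. the image."""
    model.eval()
    # Work on a copy so the caller's tensor is left untouched.
    image = image.clone().requires_grad_(True)
    outputs = model(image)  # logits, shape (N, num_classes)
    if target_class is None:
        target_class = outputs.argmax(dim=1)  # default to the predicted class
    # Assumed definition of `agg`: one scalar summing the target-class logits.
    agg = outputs.gather(1, target_class.view(-1, 1)).sum()
    model.zero_grad()
    gradients = torch.autograd.grad(outputs=agg, inputs=image)[0]
    return gradients


# Hypothetical toy model, used only to make the sketch runnable.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
image = torch.randn(2, 3, 8, 8)
saliency = input_gradient(toy_model, image)  # same shape as `image`
```

For a purely linear model such as this toy one, the resulting map for each image is just the weight row of its predicted class, which makes the sketch easy to sanity-check before moving to a real network.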
The implementation of CriGrad Saliency Maps references the framework and methodology of FullGrad Saliency (github link).
torch==2.7.1
torchvision==0.22.1
opencv-python==4.12.0.88
numpy==1.24.1
timm==1.0.17
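A quick way to confirm that the environment matches these pins is sketched below. It is an illustrative helper, not part of the repository: the name check_pins and the choice to ignore local build tags (such as a "+cu..." suffix on a torch build) are assumptions.

```python
from importlib.metadata import PackageNotFoundError, version

# Pinned versions copied from the requirements list above.
PINNED = {
    "torch": "2.7.1",
    "torchvision": "0.22.1",
    "opencv-python": "4.12.0.88",
    "numpy": "1.24.1",
    "timm": "1.0.17",
}


def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Strip a local build tag before comparing (assumed to be acceptable).
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, got) in check_pins(PINNED).items():
        print(f"{pkg}: expected {want}, found {got}")
```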
If you found our work helpful for your research, please do consider citing us.
To Be Updated



