
PyTorch Implementation of MD-SNN: Membrane Potential-aware Distillation on Quantized Spiking Neural Network (DATE 2026)


Intelligent-Computing-Lab-Panda/MD-SNN


MD-SNN: Membrane-aware Distillation on Quantized Spiking Neural Network [Paper]

Design, Automation and Test in Europe Conference (DATE), 2026

![MD-SNN overview](figure.png)

Abstract

Spiking Neural Networks (SNNs) offer a promising, energy-efficient alternative to conventional neural networks thanks to their sparse binary activations. However, they incur memory and computation overhead due to their complex spatio-temporal dynamics and the need for backpropagation across multiple timesteps during training. To mitigate this overhead, compression techniques such as quantization are applied to SNNs. Yet naively quantizing an SNN introduces a mismatch in the membrane potential, the crucial factor governing spike firing, resulting in accuracy degradation. In this paper, we introduce Membrane-aware Distillation on quantized Spiking Neural Networks (MD-SNN), which leverages membrane potential to mitigate discrepancies after weight, membrane potential, and batch normalization quantization. To our knowledge, this study is the first application of membrane potential knowledge distillation in SNNs. We validate our approach on CIFAR10, CIFAR100, N-Caltech101, and TinyImageNet, demonstrating its effectiveness for both static and dynamic data. Furthermore, for hardware efficiency, we evaluate MD-SNN on the SpikeSim platform, finding that MD-SNNs achieve 14.85× lower energy-delay-area product (EDAP), 2.64× higher TOPS/W, and 6.19× higher TOPS/mm² compared to floating-point SNNs at iso-accuracy on the N-Caltech101 dataset.
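To make the "quantization mismatch" concrete, below is a minimal sketch of symmetric uniform b-bit quantization, the generic scheme typically applied to weights and membrane potentials. The function name, rounding policy, and scale choice are illustrative assumptions for exposition, not the exact contents of the repo's `quantization.py`:

```python
import numpy as np

def quantize_uniform(x, bits, x_max=None):
    """Symmetric uniform b-bit quantizer (illustrative sketch).

    Maps x onto 2^(bits-1)-1 signed integer levels, then dequantizes.
    The rounding step is what perturbs membrane potentials and can
    shift spike firing, which MD-SNN's distillation aims to correct.
    """
    if x_max is None:
        x_max = np.max(np.abs(x))          # per-tensor dynamic range
    levels = 2 ** (bits - 1) - 1           # e.g. 7 levels for 4-bit signed
    scale = x_max / levels
    q = np.clip(np.round(x / scale), -levels, levels)
    return q * scale                       # dequantized (fake-quant) value

# Small demo: 4-bit quantization of a toy weight vector.
w = np.array([0.8, -0.35, 0.05, -0.7])
w4 = quantize_uniform(w, bits=4)
```

The per-element error is bounded by half the quantization step (`scale / 2`); accumulated across timesteps in the membrane update, these small errors are what produce the firing mismatch the paper targets.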

Requirements

torch >= 1.10
torchvision
numpy
tqdm
spikingjelly  # for N-Caltech101 data preprocessing

File Structure

├── main.py                    
├── distiller.py               
├── load_settings.py           
├── data_loader.py              
├── quantization.py             
└── models/
    ├── MS_ResNet.py            
    └── Quan_MS_ResNet.py      

Usage

All experiments use `torch.distributed.launch` for multi-GPU training (note: this launcher is deprecated in recent PyTorch releases in favor of `torchrun`).

Step 1: Train Full-Precision Teacher

# CIFAR-10 / CIFAR-100 (ResNet-18, T=4)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
    main.py --data CIFAR100 --paper_setting original --model_depth 18 \
    --time_window 4 --batch_size 256 --seed 1234 --data_path ./data

# TinyImageNet (ResNet-34, T=4)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
    main.py --data TinyImagenet --paper_setting original --model_depth 34 \
    --time_window 4 --batch_size 256 --seed 1234 --data_path ./data

# N-Caltech101 (ResNet-18, T=10)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
    main.py --data ncaltech101 --paper_setting original --model_depth 18 \
    --time_window 10 --batch_size 32 --seed 1234 --data_path ./data

Step 2: Train Quantized Student with MD-SNN (KD)

# CIFAR-100, 4-bit W/U quantization with membrane KL distillation
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
    main.py --data CIFAR100 --paper_setting qamd --model_depth 18 \
    --time_window 4 --batch_size 256 --bits 4 --loss CLK --seed 1234 \
    --data_path ./data --teacher_path /path/to/teacher/original-best.pth

# N-Caltech101, 8-bit
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
    --master_port 2345 main.py --data ncaltech101 --paper_setting qamd --model_depth 18 \
    --time_window 10 --batch_size 32 --bits 8 --loss CLK --seed 1234 \
    --data_path ./data --teacher_path /path/to/teacher/original-best.pth

Step 3 (Optional): Train Quantized Student without KD

# Quantization-only baseline (no distillation)
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 \
    main.py --data CIFAR100 --paper_setting student --model_depth 18 \
    --time_window 4 --batch_size 256 --bits 4 --seed 1234 --data_path ./data

Key Arguments

| Argument | Description |
| --- | --- |
| `--paper_setting` | `original` (FP teacher), `student` (quantized, no KD), `qamd` (MD-SNN with KD) |
| `--bits` | Quantization bit-width for weights & membrane potentials (4 or 8) |
| `--loss` | `CLK` (CE+Logit+KL, default), `CLM` (CE+Logit+MSE), `CL` (CE+Logit only) |
| `--time_window` | Number of timesteps (T=4 for static datasets, T=10 for DVS datasets) |
| `--teacher_path` | Path to pretrained teacher checkpoint (required for `qamd`) |
| `--data_path` | Root directory for datasets |
| `--save_path` | Directory to save checkpoints |
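The `--loss` options above compose a cross-entropy term, a logit-distillation term, and an optional membrane term. The sketch below shows one plausible reading of that composition; the weighting coefficients (`alpha`, `beta`), temperature `tau`, softmax over membrane potentials, and the flattened final-layer membrane shapes are all illustrative assumptions, not the repo's exact `distiller.py` formulation:

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl(p, q, eps=1e-12):
    """Mean batched KL divergence KL(p || q); always >= 0."""
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1).mean()

def md_snn_loss(s_logits, t_logits, s_mem, t_mem, labels, mode="CLK",
                tau=4.0, alpha=1.0, beta=1.0):
    """Sketch of the CL / CLM / CLK loss modes.

    CL  = CE + logit KD
    CLM = CL + MSE between student/teacher membrane potentials
    CLK = CL + KL between softmax-normalized membrane potentials
    """
    p_s = softmax(s_logits)
    ce = -np.log(p_s[np.arange(len(labels)), labels] + 1e-12).mean()
    # Temperature-scaled logit distillation (Hinton-style KD).
    logit_kd = kl(softmax(t_logits, tau), softmax(s_logits, tau)) * tau ** 2
    loss = ce + alpha * logit_kd
    if mode == "CLK":    # membrane KL term
        loss += beta * kl(softmax(t_mem), softmax(s_mem))
    elif mode == "CLM":  # membrane MSE term
        loss += beta * np.mean((s_mem - t_mem) ** 2)
    return loss
```

Because the KL and MSE membrane terms are non-negative, `CLK` and `CLM` losses are never smaller than `CL` on the same inputs; only their gradients, not their raw magnitudes, matter for training.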

N-Caltech101 Data Preparation

N-Caltech101 requires preprocessing with SpikingJelly to convert raw events into frames:

from spikingjelly.datasets.n_caltech101 import NCaltech101
dataset = NCaltech101(root='./data/N-Caltech101', data_type='frame', frames_number=10, split_by='number')

After preprocessing, set --data_path to the parent directory containing the N-Caltech101/frames_number_10_split_by_number/ folder.

Citation

If you find MD-SNN useful or relevant to your research, please cite our paper:

@article{lee2025md,
  title={MD-SNN: Membrane Potential-aware Distillation on Quantized Spiking Neural Network},
  author={Lee, Donghyun and Moitra, Abhishek and Kim, Youngeun and Yin, Ruokai and Panda, Priyadarshini},
  journal={arXiv preprint arXiv:2512.04443},
  year={2025}
}
