# MD-SNN: Membrane-aware Distillation on Quantized Spiking Neural Network [Paper]
Spiking Neural Networks (SNNs) offer a promising and energy-efficient alternative to conventional neural networks, thanks to their sparse binary activations. However, they face memory and computation overhead due to complex spatio-temporal dynamics and the need for multiple backpropagation computations across timesteps during training. To mitigate this overhead, compression techniques such as quantization are applied to SNNs. Yet, naively quantizing SNNs introduces a mismatch in membrane potential, a crucial factor for spike firing, resulting in accuracy degradation. In this paper, we introduce Membrane-aware Distillation on quantized Spiking Neural Network (MD-SNN), which leverages membrane potential to mitigate discrepancies after weight, membrane potential, and batch normalization quantization. To our knowledge, this study represents the first application of membrane potential knowledge distillation in SNNs. We validate our approach on various datasets, including CIFAR10, CIFAR100, N-Caltech101, and TinyImageNet, demonstrating its effectiveness for both static and dynamic data. Furthermore, for hardware efficiency, we evaluate MD-SNN on the SpikeSim platform, finding that MD-SNNs achieve 14.85× lower energy-delay-area product (EDAP), 2.64× higher TOPS/W, and 6.19× higher TOPS/mm² compared to floating-point SNNs at iso-accuracy on the N-Caltech101 dataset.
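The exact distillation objective is defined in `distiller.py`; purely as an illustration of the core idea (the function names, the temperature parameter, and the plain-list representation are our own assumptions, not the repo's API), a membrane-potential distillation term can be sketched as a KL divergence between softmax-normalized teacher and student membrane potentials, averaged over timesteps:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def membrane_kl_loss(teacher_u, student_u, temperature=2.0):
    """KL(teacher || student) between membrane-potential distributions,
    averaged over timesteps.

    teacher_u / student_u: per-timestep membrane potentials, each a list
    of floats (one value per output neuron). Illustrative sketch only.
    """
    total = 0.0
    for ut, us in zip(teacher_u, student_u):
        p = softmax([u / temperature for u in ut])  # teacher distribution
        q = softmax([u / temperature for u in us])  # student distribution
        total += sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return total / len(teacher_u)
```

Identical teacher and student potentials give zero loss; any mismatch introduced by quantization yields a positive penalty that the student is trained to reduce.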
## Requirements

```
torch >= 1.10
torchvision
numpy
tqdm
spikingjelly  # for N-Caltech101 data preprocessing
```
## Repository Structure

```
├── main.py
├── distiller.py
├── load_settings.py
├── data_loader.py
├── quantization.py
└── models/
    ├── MS_ResNet.py
    └── Quan_MS_ResNet.py
```
## Usage

All experiments use `torch.distributed.launch` for multi-GPU training.
### 1. Train the floating-point teacher

```bash
# CIFAR-10 / CIFAR-100 (ResNet-18, T=4)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
  main.py --data CIFAR100 --paper_setting original --model_depth 18 \
  --time_window 4 --batch_size 256 --seed 1234 --data_path ./data

# TinyImageNet (ResNet-34, T=4)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
  main.py --data TinyImagenet --paper_setting original --model_depth 34 \
  --time_window 4 --batch_size 256 --seed 1234 --data_path ./data

# N-Caltech101 (ResNet-18, T=10)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \
  main.py --data ncaltech101 --paper_setting original --model_depth 18 \
  --time_window 10 --batch_size 32 --seed 1234 --data_path ./data
```

### 2. Train the quantized student with membrane-aware distillation

```bash
# CIFAR-100, 4-bit W/U quantization with membrane KL distillation
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
  main.py --data CIFAR100 --paper_setting qamd --model_depth 18 \
  --time_window 4 --batch_size 256 --bits 4 --loss CLK --seed 1234 \
  --data_path ./data --teacher_path /path/to/teacher/original-best.pth

# N-Caltech101, 8-bit
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
  --master_port 2345 main.py --data ncaltech101 --paper_setting qamd --model_depth 18 \
  --time_window 10 --batch_size 32 --bits 8 --loss CLK --seed 1234 \
  --data_path ./data --teacher_path /path/to/teacher/original-best.pth
```

### 3. Quantization-only baseline (no distillation)

```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 \
  main.py --data CIFAR100 --paper_setting student --model_depth 18 \
  --time_window 4 --batch_size 256 --bits 4 --seed 1234 --data_path ./data
```

## Arguments

| Argument | Description |
|---|---|
| `--paper_setting` | `original` (FP teacher), `student` (quantized, no KD), `qamd` (MD-SNN with KD) |
| `--bits` | Quantization bit-width for weights & membrane potentials (4 or 8) |
| `--loss` | `CLK` (CE + Logit + KL, default), `CLM` (CE + Logit + MSE), `CL` (CE + Logit only) |
| `--time_window` | Number of timesteps (T=4 for static datasets, T=10 for DVS datasets) |
| `--teacher_path` | Path to pretrained teacher checkpoint (required for `qamd`) |
| `--data_path` | Root directory for datasets |
| `--save_path` | Directory to save checkpoints |
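The repo's actual quantizer lives in `quantization.py`; purely for intuition about what `--bits` controls (the function, its symmetric clipping range, and the round-to-nearest scheme below are illustrative assumptions), uniform b-bit quantization of a weight or membrane-potential value can be sketched as:

```python
def quantize_uniform(x, bits=4, x_min=-1.0, x_max=1.0):
    """Uniform b-bit quantization sketch: clip x to [x_min, x_max],
    snap it to one of 2**bits evenly spaced levels, and return the
    dequantized (reconstructed) value."""
    levels = (1 << bits) - 1                   # e.g. 15 levels gaps for 4 bits
    x = min(max(x, x_min), x_max)              # clip to the representable range
    step = (x_max - x_min) / levels            # quantization step size
    q = round((x - x_min) / step)              # integer code in [0, levels]
    return x_min + q * step                    # dequantize back to float
```

With 4 bits there are 16 representable values over the range, versus 256 with 8 bits, which is why the 8-bit setting tracks the floating-point teacher more closely and needs less corrective distillation.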
## N-Caltech101 Preprocessing

N-Caltech101 requires preprocessing with SpikingJelly to convert raw events into frames:

```python
from spikingjelly.datasets.n_caltech101 import NCaltech101

dataset = NCaltech101(root='./data/N-Caltech101', data_type='frame',
                      frames_number=10, split_by='number')
```

After preprocessing, set `--data_path` to the parent directory containing the `N-Caltech101/frames_number_10_split_by_number/` folder.
## Citation

If you find MD-SNN useful or relevant to your research, please cite our paper:

```bibtex
@article{lee2025md,
  title={MD-SNN: Membrane Potential-aware Distillation on Quantized Spiking Neural Network},
  author={Lee, Donghyun and Moitra, Abhishek and Kim, Youngeun and Yin, Ruokai and Panda, Priyadarshini},
  journal={arXiv preprint arXiv:2512.04443},
  year={2025}
}
```