A comprehensive implementation and evaluation framework for testing cross-architecture generalization in dataset distillation methods. This repository provides systematic tools to answer three key research questions about multi-architecture dataset distillation.
RQ-1: How well does Progressive Dataset Distillation generalize across different neural network architectures (CNNs vs Transformers)?
RQ-2: How does the multi-architecture approach compare to existing cross-architecture generalization methods (Model Pool, MetaDD)?
RQ-3: What are the optimal design choices for multi-architecture dataset distillation (number of stages, architecture sampling, etc.)?
📁 Project Structure
├── 📄 README.md # This comprehensive guide
├── 📄 requirements.txt # Python dependencies
├── 📄 validate_setup.py # Environment validation
├── 📄 getting_started.py # Interactive setup guide
├── 📄 run_experiments.py # Main experiment runner
├── 📄 aggregate_results.py # Results analysis system
├── 📂 configs/
│ └── 📄 experiment_configs.yaml # All experiment configurations
│
├── 📂 Core Architectures
│ └── 📄 architectures.py # SimpleCNN, SmallViT, ResNet8
│
├── 📂 Distillation Methods
│ ├── 📄 dd_baseline_17june.py # Single-Step DD (your existing)
│ ├── 📄 pdd_baseline_17june.py # Progressive DD (your existing)
│ ├── 📄 multi_arch_distill.py # Multi-Architecture Progressive DD ⭐
│ ├── 📄 model_pool_distill.py # Model Pool baseline
│ └── 📄 metadd_distill.py # MetaDD-inspired baseline
│
├── 📂 Evaluation & Analysis
│ ├── 📄 transfer_train.py # Cross-architecture evaluation
│ ├── 📄 eval_metrics.py # Comprehensive metrics
│ └── 📄 plot_synthetic_grid.py # Visualization utilities
│
└── 📂 results/ # Generated results directory
    ├── 📂 dd_experiment/
    ├── 📂 pdd_experiment/
    ├── 📂 multi_arch_pdd/ # Your novel method results
    ├── 📂 model_pool_dd/
    ├── 📂 metadd_distill/
    ├── 📂 transfer_evaluation/
    └── 📂 final_analysis/ # Comprehensive comparison
# Clone or download your research files
# Navigate to the project directory
# Create Python environment
python -m venv dd_env
source dd_env/bin/activate # On Windows: dd_env\Scripts\activate
# Install dependencies
pip install -r requirements.txt
# Validate setup
python validate_setup.py

# Run the interactive guide (recommended for first-time users)
python getting_started.py

This will walk you through the entire setup process and run your first experiment.
The configuration system ensures reproducible experiments. Create configs/experiment_configs.yaml:
# Base Configuration
dataset:
  name: "MNIST"
  num_classes: 10
  images_per_class: 10
  batch_size: 256
computational_budget:
  total_gradient_computations: 3000
architectures:
  available: ["SimpleCNN", "SmallViT", "ResNet8"]
training:
  lr_synthetic: 0.0001
  lr_inner: 0.01
  device: "cpu"
# Add experiment-specific configurations...

# Run all experiments for all research questions
python run_experiments.py --full_study

# Run your core contribution
python run_experiments.py --config multi_arch_pdd
# Or run directly with custom parameters
python multi_arch_distill.py --architectures SimpleCNN SmallViT \
    --arch_sampling random --num_stages 5 --save_dir ./results/custom

# Single-step DD baseline
python run_experiments.py --config dd_baseline
# Progressive DD baseline
python run_experiments.py --config pdd_baseline
# Model Pool baseline (for RQ-2)
python run_experiments.py --config model_pool_dd
# MetaDD-inspired baseline (for RQ-2)
python run_experiments.py --config metadd_distill

# Run all baseline methods first
python run_experiments.py --run_all_baselines
# Then evaluate cross-architecture generalization
python run_experiments.py --transfer_only
# Or evaluate specific method on specific architecture
python transfer_train.py --dataset_path ./results/multi_arch_pdd/multi_arch_pdd_dataset.pt \
    --target_arch SmallViT --include_baseline --plot_curves

# Systematic hyperparameter sweep
python run_experiments.py --sweep --method multi_arch_pdd
# This tests different values of:
# - Number of stages (1, 2, 5, 10)
# - Sampling strategies (random, all, alternating)
# - Architecture pool sizes
# - Feature alignment weights

This is your novel approach that directly addresses the cross-architecture generalization problem. Instead of distilling with a single architecture, you randomly sample from multiple architectures during the progressive distillation process.
Key Innovation:
import random
from architectures import SimpleCNN, get_architecture

# Traditional PDD uses a single architecture
model = SimpleCNN()

# Your approach samples from an architecture pool
arch_pool = ['SimpleCNN', 'SmallViT']
outer_steps = 200  # outer steps per stage, as in the budget table
for outer_step in range(outer_steps):
    sampled_arch = random.choice(arch_pool)
    model = get_architecture(sampled_arch)
    # ... distillation with diverse architectures

Why This Works: By exposing the synthetic data to different architectural inductive biases during creation, the resulting dataset captures more generalizable features rather than architecture-specific artifacts.
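The three sampling strategies this README names (random, all, alternating) could be sketched as follows. This is only an illustration, not the code in multi_arch_distill.py: the helper name `sample_architectures` and the exact meaning given to "all" and "alternating" are assumptions.

```python
import random

def sample_architectures(arch_pool, strategy, outer_step, rng=random):
    """Return the architecture names to use at one outer step.

    Hypothetical helper; the semantics of each strategy are assumed.
    """
    if strategy == "random":
        # One architecture drawn uniformly at random from the pool.
        return [rng.choice(arch_pool)]
    if strategy == "alternating":
        # Cycle through the pool in a fixed order.
        return [arch_pool[outer_step % len(arch_pool)]]
    if strategy == "all":
        # Use every architecture in the pool at this step.
        return list(arch_pool)
    raise ValueError(f"unknown sampling strategy: {strategy!r}")

pool = ["SimpleCNN", "SmallViT"]
rng = random.Random(0)  # local RNG so the schedule is repeatable
schedule = [sample_architectures(pool, "alternating", step) for step in range(4)]
print(schedule)
```

Under the "all" reading used here, each outer step would touch every architecture in the pool, so it may cost more gradient computations per step than the other two strategies. Whether the real implementation accounts for that inside the 3000-computation budget is not stated in this README.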
All methods use exactly 3000 gradient computations for fair comparison:
| Method | Structure | Computations |
|---|---|---|
| Single-Step DD | 3000 iterations × 1 step | 3000 |
| Progressive DD | 5 stages × 200 outer × 3 inner | 3000 |
| Multi-Arch PDD | 5 stages × 200 outer × 3 inner | 3000 |
| Model Pool | 3000 iterations × 1 step | 3000 |
| MetaDD | 3000 iterations × 1 step | 3000 |
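The budget column is just the product of the factors in the Structure column. A small check of that arithmetic is sketched below. The helper name is hypothetical, and the single-step methods are encoded as 1 stage × 3000 iterations × 1 step, which is an assumption about how to read the table.

```python
BUDGET = 3000

def gradient_computations(stages, outer_steps, inner_steps):
    """Total gradient computations for a staged distillation schedule."""
    return stages * outer_steps * inner_steps

# (stages, outer steps, inner steps) per method, taken from the table above.
schedules = {
    "Single-Step DD": (1, 3000, 1),
    "Progressive DD": (5, 200, 3),
    "Multi-Arch PDD": (5, 200, 3),
    "Model Pool": (1, 3000, 1),
    "MetaDD": (1, 3000, 1),
}

totals = {name: gradient_computations(*s) for name, s in schedules.items()}
for name, total in totals.items():
    status = "ok" if total == BUDGET else "over/under budget"
    print(f"{name}: {total} ({status})")
```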
The framework tests three fundamentally different architectures:
SimpleCNN (~50K parameters):
- Local receptive fields through convolutions
- Translation equivariance
- Hierarchical feature learning
SmallViT (~45K parameters):
- Global attention mechanisms
- Explicit positional encoding
- Patch-based processing
ResNet8 (~55K parameters):
- Residual connections for gradient flow
- Batch normalization
- Skip connections
The framework provides comprehensive automated analysis:
# Generate complete comparison analysis
python aggregate_results.py --results_dir ./results --output_dir ./analysis
# Creates publication-ready:
# - Cross-architecture performance heatmaps
# - Method comparison bar charts
# - Generalization consistency analysis
# - Performance vs stability trade-offs
# - Statistical significance tests

RQ-1 Metrics (Cross-Architecture Generalization):
- Average accuracy across target architectures
- Generalization consistency: min_accuracy / max_accuracy * 100
- Performance retention: synthetic_accuracy / real_accuracy * 100
- Cross-architecture gap: |accuracy_CNN - accuracy_ViT|
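The RQ-1 formulas translate directly into code. The sketch below assumes accuracies are given in percent as one dict per training regime. The function name `rq1_metrics` and the input layout are assumptions (eval_metrics.py may organize this differently), and the numbers are made up purely to exercise the formulas.

```python
from statistics import mean

def rq1_metrics(synthetic_acc, real_acc):
    """Compute the RQ-1 metrics from per-architecture accuracies (percent).

    synthetic_acc / real_acc: dicts mapping architecture name -> accuracy
    after training on synthetic data / on the real dataset.
    """
    accs = list(synthetic_acc.values())
    return {
        "average_accuracy": mean(accs),
        "generalization_consistency": min(accs) / max(accs) * 100,
        "performance_retention": {
            arch: synthetic_acc[arch] / real_acc[arch] * 100
            for arch in synthetic_acc
        },
        "cnn_vit_gap": abs(synthetic_acc["SimpleCNN"] - synthetic_acc["SmallViT"]),
    }

# Made-up accuracies, for illustration only (not results from this project).
synthetic = {"SimpleCNN": 80.0, "SmallViT": 60.0, "ResNet8": 70.0}
real = {"SimpleCNN": 98.0, "SmallViT": 96.0, "ResNet8": 98.0}
metrics = rq1_metrics(synthetic, real)
print(metrics)
```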
RQ-2 Metrics (Method Comparison):
- Performance ranking by mean accuracy
- Stability ranking by accuracy variance
- Efficiency: accuracy per gradient computation
- Statistical significance tests between methods
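One way the RQ-2 rankings could be computed is sketched here: rank by mean accuracy (higher is better), rank by variance across target architectures (lower is more stable), and divide mean accuracy by the gradient budget for efficiency. The function name, the input layout, and the placeholder method names are assumptions, and the accuracies are invented for illustration.

```python
from statistics import mean, pvariance

def rank_methods(results, budget=3000):
    """results: dict mapping method name -> accuracies, one per target architecture."""
    by_mean = sorted(results, key=lambda m: mean(results[m]), reverse=True)
    by_stability = sorted(results, key=lambda m: pvariance(results[m]))
    efficiency = {m: mean(accs) / budget for m, accs in results.items()}
    return by_mean, by_stability, efficiency

# Placeholder methods with made-up accuracies (percent).
results = {
    "method_a": [80.0, 60.0, 70.0],
    "method_b": [68.0, 66.0, 67.0],
}
by_mean, by_stability, efficiency = rank_methods(results)
print("by mean accuracy:", by_mean)
print("by stability:", by_stability)
```

In this toy input the two rankings disagree: `method_a` has the higher mean but `method_b` varies less across architectures, which is the kind of trade-off the performance-versus-stability analysis is meant to surface.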
RQ-3 Metrics (Design Choices):
- Progressive vs single-step performance comparison
- Architecture pool size impact analysis
- Sampling strategy effectiveness evaluation
- Convergence speed and stability analysis
After running the complete study:
📂 results/
├── 📂 dd_experiment/
│ ├── dd_synthetic_dataset.pt # Single-step DD results
│ └── dd_synthetic_images.png
│
├── 📂 multi_arch_pdd/ # Your novel method ⭐
│ ├── multi_arch_pdd_dataset.pt # Key results for RQ-1 & RQ-2
│ ├── multi_arch_pdd_results.json
│ └── multi_arch_pdd_synthetic_images.png
│
├── 📂 transfer_evaluation/ # Cross-architecture results
│ ├── transfer_SimpleCNN_results.json # RQ-1 evaluation data
│ ├── transfer_SmallViT_results.json
│ └── transfer_ResNet8_results.json
│
└── 📂 final_analysis/ # Comprehensive comparison
    ├── research_findings_summary.txt # Direct answers to all RQs
    ├── detailed_results.csv
    └── 📂 visualizations/
        ├── cross_architecture_heatmap.png
        ├── method_comparison.png
        └── generalization_consistency.png
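Given the layout above, the per-architecture transfer files can be gathered with a few lines of standard-library code. This sketch relies only on the `transfer_<arch>_results.json` naming shown in the tree. The helper name is hypothetical, and nothing is assumed about the keys inside each JSON file.

```python
import json
from pathlib import Path

def load_transfer_results(results_dir):
    """Map target architecture name -> parsed contents of its transfer result file."""
    results = {}
    transfer_dir = Path(results_dir) / "transfer_evaluation"
    for path in sorted(transfer_dir.glob("transfer_*_results.json")):
        # "transfer_SmallViT_results" -> "SmallViT"
        arch = path.stem[len("transfer_"):-len("_results")]
        with path.open() as f:
            results[arch] = json.load(f)
    return results

if __name__ == "__main__":
    for arch, payload in load_transfer_results("./results").items():
        print(arch, type(payload).__name__)
```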
# Test specific architecture combinations
python multi_arch_distill.py --architectures SimpleCNN SmallViT ResNet8 \
--arch_sampling all --num_stages 10
# Compare sampling strategies
for strategy in random all alternating; do
    python multi_arch_distill.py --arch_sampling "$strategy" \
        --save_dir "./results/sampling_$strategy"
done
# Test different stage counts (for RQ-3)
for stages in 1 2 5 10; do
    python multi_arch_distill.py --num_stages "$stages" \
        --save_dir "./results/stages_$stages"
done

For RQ-1 (Cross-Architecture Generalization):
# Generate synthetic data with your method
python multi_arch_distill.py --architectures SimpleCNN SmallViT
# Test on all target architectures
for arch in SimpleCNN SmallViT ResNet8; do
    python transfer_train.py \
        --dataset_path ./results/multi_arch_pdd/multi_arch_pdd_dataset.pt \
        --target_arch "$arch" --include_baseline
done
# Analyze generalization gaps
python eval_metrics.py --transfer_results ./results/transfer_evaluation/

For RQ-2 (Method Comparison):
# Run all methods
python run_experiments.py --run_all_baselines
# Statistical comparison
python aggregate_results.py --results_dir ./results \
    --generate_significance_tests

For RQ-3 (Design Choices):
# Hyperparameter sweep
python run_experiments.py --sweep
# Analyze results
python aggregate_results.py --sweep_analysis

Environment Setup Issues:
# Validate environment
python validate_setup.py --verbose
# Auto-fix common problems
python validate_setup.py --fix-issues

CUDA/GPU Issues (Since you're using CPU):
# Ensure CPU-only operation
export CUDA_VISIBLE_DEVICES=""
python run_experiments.py --full_study

Memory Issues:
# Reduce batch size if needed
python multi_arch_distill.py --batch_size 128
# Quick test with smaller dataset
python run_experiments.py --config quick_test

Missing Dependencies:
# Install specific missing packages
pip install torch torchvision matplotlib seaborn pandas numpy PyYAML tqdm
# Or reinstall everything
pip install -r requirements.txt --force-reinstall

This framework follows rigorous experimental design principles essential for publishable research:
Controlled Variables: All methods share the same computational budget, synthetic dataset size, evaluation protocol, and set of random seeds.
Fair Comparison: The 3000 gradient computation budget ensures no method has an unfair advantage due to more optimization steps.
Statistical Rigor: Multiple random seeds, confidence intervals, and significance testing ensure reliable conclusions.
Reproducibility: Configuration-driven experiments with fixed seeds enable exact reproduction of results.
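Fixed seeds only help if every random number generator in play is seeded. A minimal sketch of such a helper is below. The name `set_seed` is an assumption rather than a function known to exist in this repository, and NumPy and PyTorch are seeded only when they are installed.

```python
import random

def set_seed(seed):
    """Seed Python's RNG, plus NumPy and PyTorch when they are available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed

# Re-seeding reproduces the same draws.
set_seed(0)
first = [random.random() for _ in range(3)]
set_seed(0)
second = [random.random() for _ in range(3)]
print(first == second)
```

Seeding alone may not make GPU runs bit-identical, but the CPU-only setup described in this README avoids most of that nondeterminism.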
RQ-1 Experimental Design: Your multi-architecture progressive method is compared against single-architecture baselines using identical synthetic dataset sizes. Cross-architecture transfer evaluation measures how well synthetic data from each method transfers to held-out architectures.
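The held-out evaluation can be enumerated as leave-one-architecture-out splits: distill with all architectures except one, then test transfer to the one left out. The generator below is a sketch of that idea with a hypothetical name. The RQ-1 commands in this README correspond to just one such split (distill with SimpleCNN and SmallViT, with ResNet8 unseen).

```python
def held_out_splits(architectures):
    """Yield (distillation_pool, held_out_arch) pairs, one per architecture."""
    for held_out in architectures:
        pool = [a for a in architectures if a != held_out]
        yield pool, held_out

splits = list(held_out_splits(["SimpleCNN", "SmallViT", "ResNet8"]))
for pool, held_out in splits:
    print(f"distill with {pool}, evaluate transfer to {held_out}")
```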
RQ-2 Experimental Design: Direct comparison with existing methods (Model Pool, MetaDD) using identical evaluation protocols. Statistical significance testing determines whether performance differences are meaningful.
RQ-3 Experimental Design: Systematic hyperparameter sweeps test the impact of key design choices while maintaining computational budget constraints. This identifies optimal configurations for practical deployment.
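To keep a sweep inside the fixed budget, the number of outer steps has to shrink as the number of stages grows. The sketch below builds such a grid for the stage counts and sampling strategies listed in this README. It assumes the inner-step count stays at 3 and that outer steps absorb the change; how run_experiments.py actually redistributes the budget is not specified here.

```python
from itertools import product

BUDGET = 3000
INNER_STEPS = 3  # assumed fixed, matching the 5 x 200 x 3 schedule

def sweep_configs(stage_counts=(1, 2, 5, 10),
                  strategies=("random", "all", "alternating")):
    """Build sweep configurations whose schedules hit the budget exactly."""
    configs = []
    for stages, strategy in product(stage_counts, strategies):
        outer_steps, remainder = divmod(BUDGET, stages * INNER_STEPS)
        if remainder:
            continue  # this stage count cannot match the budget exactly
        configs.append({
            "num_stages": stages,
            "outer_steps": outer_steps,
            "inner_steps": INNER_STEPS,
            "arch_sampling": strategy,
        })
    return configs

configs = sweep_configs()
for cfg in configs:
    print(cfg)
```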
The framework generates publication-ready outputs:
Tables: Detailed performance comparisons with statistical confidence intervals.
Figures: High-quality visualizations showing cross-architecture performance, method comparisons, and design choice impacts.
Statistical Analysis: Significance tests, effect sizes, and confidence intervals for all major claims.
Reproducibility Information: Complete experimental configurations and random seeds for replication.
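Confidence intervals and effect sizes over per-seed accuracies need only the standard library. The sketch below uses a normal-approximation 95% interval and Cohen's d with a pooled standard deviation. Both are common choices but are assumptions here, not necessarily what aggregate_results.py computes, and the sample accuracies are invented.

```python
from math import sqrt
from statistics import mean, stdev

def confidence_interval_95(samples):
    """Normal-approximation 95% confidence interval for the mean."""
    m = mean(samples)
    half_width = 1.96 * stdev(samples) / sqrt(len(samples))
    return m - half_width, m + half_width

def cohens_d(a, b):
    """Cohen's d effect size using the pooled sample standard deviation."""
    na, nb = len(a), len(b)
    pooled_var = ((na - 1) * stdev(a) ** 2 + (nb - 1) * stdev(b) ** 2) / (na + nb - 2)
    return (mean(a) - mean(b)) / sqrt(pooled_var)

# Made-up per-seed accuracies (percent) for two placeholder methods.
method_a = [70.0, 72.0, 74.0]
method_b = [66.0, 68.0, 70.0]
low, high = confidence_interval_95(method_a)
d = cohens_d(method_a, method_b)
print(f"method_a mean CI: [{low:.2f}, {high:.2f}], Cohen's d vs method_b: {d:.2f}")
```

With only a handful of seeds, a t-distribution critical value would give a wider and more honest interval than the 1.96 used here.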
# Template for new methods
import torch
import torch.nn as nn

def run_new_method(args):
    # 1. Create synthetic dataset (100 images, 10 per class)
    synthetic_images, synthetic_labels = create_synthetic_dataset(...)
    # 2. Distillation loop (exactly 3000 gradient computations)
    for iteration in range(3000):
        # Your distillation logic here
        # Must update synthetic_images based on some loss
        pass
    # 3. Save results in consistent format
    # (save_path is not defined in this template; set it before saving)
    torch.save({
        'synthetic_images': synthetic_images,
        'synthetic_labels': synthetic_labels,
        'method_name': 'Your Method',
        'args': args
    }, save_path)
    return synthetic_images, synthetic_labels

# Add to architectures.py
class NewArchitecture(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Your architecture here

    def forward(self, x):
        # Forward pass: compute `output` from x
        return output

# Update the factory function
def get_architecture(arch_name):
    if arch_name == 'NewArchitecture':
        return NewArchitecture()
    # ... existing architectures

Before running your experiments:
- ✅ Environment setup completed (python validate_setup.py)
- ✅ Configuration file created (configs/experiment_configs.yaml)
- ✅ Dependencies installed (pip install -r requirements.txt)
- ✅ Test run successful (python getting_started.py)
- ✅ Sufficient disk space (~2GB for complete study)
- ✅ Expected runtime understanding (2-4 hours for full study)
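The disk-space item on this checklist can be checked programmatically. The sketch below uses `shutil.disk_usage` from the standard library. The helper name and the 2 GB default (taken from the estimate above) are illustrative, and this is not a description of what validate_setup.py does.

```python
import shutil

def has_free_space(path=".", required_gb=2.0):
    """Return True if the filesystem holding `path` has at least `required_gb` free."""
    free_gb = shutil.disk_usage(path).free / 1024 ** 3
    return free_gb >= required_gb

if has_free_space():
    print("Enough free disk space for the complete study.")
else:
    print("Less than ~2GB free; consider cleaning up before the full study.")
```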
Your research will be successful when you can demonstrate:
For RQ-1: Clear evidence that your multi-architecture progressive method achieves better cross-architecture generalization than single-architecture baselines.
For RQ-2: Statistical comparison showing your method's performance relative to existing approaches like Model Pool and MetaDD.
For RQ-3: Identification of optimal design choices through systematic hyperparameter analysis.
- Run Initial Validation: python validate_setup.py
- Follow Interactive Guide: python getting_started.py
- Test Single Method: python run_experiments.py --config multi_arch_pdd
- Run Complete Study: python run_experiments.py --full_study
- Analyze Results: Check ./results/final_analysis/research_findings_summary.txt
Your comprehensive dataset distillation research framework is ready to generate the evidence needed to answer your research questions systematically and rigorously! 🚀