This repository contains code for the NeurIPS 2025 paper *NdLinear: Don't Flatten! Building Superior Neural Architectures by Preserving N-D Structure*.
- Requirements
- Vision Models
- NLP Models
- Tabular Models
- Open Pre-trained Transformer (OPT) Models
- LLAMA-based Models
## Requirements

Install the dependencies with:

```bash
pip install -r requirements.txt
```
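Before diving into the individual models: the core NdLinear idea, applying a separate small linear map along each axis of an input tensor rather than flattening it into one long vector, can be sketched in a few lines of NumPy. This is an illustrative sketch only, not the repository's `ndlinear.py` API; the `nd_linear` helper and the shapes below are hypothetical.

```python
import numpy as np

def nd_linear(x, weights):
    """Illustrative mode-wise linear layer (not the repo's API).

    Applies weights[i] along axis i of x, so the N-D structure of x
    is preserved instead of being flattened into one long vector.
    weights[i] has shape (new_dim_i, x.shape[i]).
    """
    for axis, w in enumerate(weights):
        # Contract w with the target axis, then move the new axis back
        # into place so the tensor keeps its original axis order.
        x = np.moveaxis(np.tensordot(w, x, axes=(1, axis)), 0, axis)
    return x

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))        # 2-D input, kept as 2-D
w0 = rng.normal(size=(8, 4))       # maps axis 0: 4 -> 8
w1 = rng.normal(size=(6, 5))       # maps axis 1: 5 -> 6
print(nd_linear(x, [w0, w1]).shape)  # (8, 6)
```

The same loop works for any number of axes, which is what lets the layer scale to images, token grids, and other N-D inputs.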
## Vision Models

- File: `cnn_img_classification.py`
- Description: Standard CNN pipeline for image classification tasks.
- Usage:
```bash
python cnn_img_classification.py
```

- Files: `vit/vit.py`, `vit/vit_distill.py`, `vit/ndlinear.py`
- Description: Implements a Vision Transformer model and a distillation variant.
- Usage:
```bash
torchrun --nnodes 1 --nproc_per_node=4 \
    src/vit_distill.py \
    --num_epochs 30 \
    --num_transformers 6 \
    --dataset CIFAR100 \
    --batch_size 256 \
    --lr 2.75e-4
```

- Files: `dit/train_dit.py`, `dit/model.py`, `dit/mlp_ndlinear.py`, `dit/ndlinear.py`
- Description: Implements a DiT model for vision tasks, including custom NdLinear/MLP layers.
- Usage:
```bash
python dit/train_dit.py \
    --feature-path /path/to/features \
    --results-dir ./results \
    --model DiT-XS/8 \
    --image-size 256 \
    --num-classes 100 \
    --epochs 1400 \
    --global-batch-size 128 \
    --num-workers 8 \
    --log-every 100 \
    --ckpt-every 50000 \
    --lr 2e-4 \
    --accumulation-steps 1 \
    --use-ndmlp \
    --use-ndtse \
    --mlp-variant 4 \
    --tse-scale-factor 1 \
    --use-num-transforms 2
```

## NLP Models

- File: `txt_classify_bert.py`
- Description: Uses a HuggingFace BERT model for text classification.
- Usage:
```bash
python txt_classify_bert.py \
    --learning_rate 3e-5 \
    --epochs 10 \
    --batch_size 32 \
    --file_path data/CoLA/train.tsv
```

## Tabular Models

- File: `tabluar/mlp_cls.py`
- Description: Compares dense Linear and NdLinear architectures for tabular data classification on the Cardio Disease dataset, reporting accuracy and plotting learning curves.
- Usage:
```bash
python tabluar/mlp_cls.py
```

This will train both models on `datasets/cardio_disease.csv` and produce:

- `training_loss_classification.png`
- `test_accuracy_classification.png`
- File: `tabluar/mlp_reg.py`
- Description: Benchmarks Linear vs. NdLinear models for regression on the Delivery Time dataset, comparing mean squared error and saving loss curves.
- Usage:
```bash
python tabluar/mlp_reg.py
```

This will train both models on `datasets/delivery_time.csv` and produce:

- `training_loss_regression.png`
- `test_loss_regression.png`
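The scripts above report accuracy and error; the parameter savings that motivate the Linear-vs-NdLinear comparison are easy to see with a back-of-the-envelope count. The sizes below are hypothetical, chosen only to illustrate the factorization (a flattened dense layer vs. one small matrix per axis), and do not reflect the scripts' actual configurations:

```python
# Hypothetical sizes: 64 input features viewed as an 8x8 grid,
# mapped to a 16x16 hidden representation.
d1, d2 = 8, 8
h1, h2 = 16, 16

dense_params = (d1 * d2) * (h1 * h2)  # flatten, then one big matrix
nd_params = h1 * d1 + h2 * d2         # one small matrix per axis

print(dense_params)  # 16384
print(nd_params)     # 256
```

The gap widens as dimensions grow, since the dense count is a product of all sizes while the factorized count is a sum of per-axis products.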
## Open Pre-trained Transformer (OPT) Models

See `opt/README.md` for instructions on using OPT models.
## LLAMA-based Models

- Files: `llama/accelerate_llama_control.py`, `llama/evaluate_ndlinear_model.py`, `llama/ndlinear.py`, `llama/Makefile`
- Description: Inference, evaluation, and custom (NdLinear) adaptation layers for LLaMA models.
With the Makefile:

```bash
make install
make run_factorized_ndlinear_lora
```

Manual Execution:
```bash
python llama/accelerate_llama_control.py \
    --model_name "Qwen/Qwen3-1.7B-Base" \
    --dataset "lmms-lab/Math10K" \
    --lora_type "ndlinear" \
    --output_dir "./output_qwen3_1.7B_math10k_ndlinear_factorized" \
    --target_modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" \
    --lora_r 1 \
    --lora_alpha 1 \
    --epochs 2 \
    --batch_size 1 \
    --learning_rate 1e-4 \
    --max_length 512 \
    --seed 42
```

Evaluation:
```bash
python llama/evaluate_ndlinear_model.py \
    --model_path "./output_qwen3_1.7B_math10k_ndlinear_factorized/ndlinear_lora/" \
    --base_model_name "Qwen/Qwen3-1.7B-Base" \
    --dataset_name "openai/gsm8k" \
    --dataset_config "main" \
    --dataset_split "test" \
    --max_examples 2000 \
    --output_dir "./evaluation_results_test_ndlinear_model"
```

See `llama/Makefile` and the docstrings within the scripts for more advanced options and details.
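As an aside on why a factorized adapter can get away with tiny ranks such as `--lora_r 1`: applying one small matrix per axis of a reshaped activation is equivalent to multiplying by the Kronecker product of those matrices, which touches every entry of the full-size weight while storing only the small factors. The NumPy sketch below illustrates that identity; the shapes are hypothetical and this is not the repository's actual NdLinear-LoRA implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d1, d2 = 8, 16                        # view a d1*d2 = 128-dim input as a grid
U = 0.01 * rng.normal(size=(d1, d1))  # small factor acting on axis 0
V = 0.01 * rng.normal(size=(d2, d2))  # small factor acting on axis 1

x = rng.normal(size=(d1 * d2,))

# Mode-wise application on the reshaped input ...
y_modes = (U @ x.reshape(d1, d2) @ V.T).reshape(-1)
# ... equals multiplying by the full Kronecker-product matrix
# (row-major vec identity: kron(U, V) @ vec(X) == vec(U @ X @ V.T)):
y_full = np.kron(U, V) @ x

print(np.allclose(y_modes, y_full))  # True
# Stored parameters: d1*d1 + d2*d2 = 320, vs 128*128 = 16384 for the full matrix.
```

The stored-parameter gap is the same sum-vs-product effect as in the tabular comparison, which is why per-mode factors remain expressive at very low nominal rank.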