Your new tool for asynchronous deep learning
This repository contains the code to reproduce the results presented in "PETRA: Parallel End-to-end Training with Reversible Architectures". (Link available soon)
To install the dependencies, run:

```
pip install -r requirements.txt
```

If you want to download ImageNet32 prior to training, you can run:

```
python scripts/imagenet32loader.py
```

The data is assumed to be located in the ./data folder. For ImageNet, you will need to provide the path to the dataset:

```
python main.py --dataset imagenet --dir PATH_TO_DATASET
```

For CIFAR-10 and ImageNet32, the data is assumed to be located in the ./data folder. The code will automatically download those datasets if they are not on disk.

```
python main.py --dataset cifar10
python main.py --dataset imagenet32
```

For ImageNet, you will need to provide the path to the dataset:

```
python main.py --dataset imagenet --dir PATH_TO_DATASET
```

The supported models are resnet18, resnet34, resnet50, revnet18, revnet34, and revnet50.
```
python main.py --model MODEL
```

To train a model using standard backpropagation, you can use:

```
python main.py --synchronous --store-vjp
```

To train a model using PETRA, you can use:
```
python main.py --remove-ctx-param --remove-ctx-input
```

Note that PETRA only applies to reversible architectures; on non-reversible models, the option --remove-ctx-input will not be effective.
Note also that the option --remove-ctx-param alone corresponds to the Diversely Stale Parameters (DSP) approach.
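For intuition, --remove-ctx-input exploits the fact that a reversible (RevNet-style) block can reconstruct its input from its output during the backward pass, so the input does not have to be stored. Below is a minimal sketch of this additive coupling, written as an illustration only (F and G stand in for arbitrary residual functions; this is not the repository's implementation):

```python
import torch
from torch import nn

def forward_block(x1, x2, F, G):
    # Standard additive coupling: the outputs are a function of the inputs.
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def invert_block(y1, y2, F, G):
    # The inputs are recovered exactly from the outputs, so activations
    # need not be kept in memory for the backward pass.
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

# Quick check with toy residual functions.
F, G = nn.Linear(8, 8), nn.Linear(8, 8)
x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
y1, y2 = forward_block(x1, x2, F, G)
x1_rec, x2_rec = invert_block(y1, y2, F, G)
assert torch.allclose(x1, x1_rec, atol=1e-5) and torch.allclose(x2, x2_rec, atol=1e-5)
```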
This code supports the SGD, Adam, and LARS optimizers.

To use the SGD optimizer:

```
python main.py --optimizer sgd --lr LEARNING_RATE --momentum MOMENTUM --dampening DAMPENING [--nesterov]
```

To use the LARS optimizer:

```
python main.py --optimizer lars --lr LEARNING_RATE --momentum MOMENTUM --dampening DAMPENING [--nesterov]
```

To use the Adam optimizer:

```
python main.py --optimizer adam --lr LEARNING_RATE --beta1 BETA1 --beta2 BETA2 [--amsgrad]
```

You can set the weight decay with the option --weight-decay WEIGHT_DECAY. You can also remove weight decay on biases and batch-norm parameters with the option --no-bn-weight-decay.
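Excluding biases and batch-norm parameters from weight decay is commonly done with optimizer parameter groups; the sketch below shows the general pattern (illustrative only, assuming 1-D parameters are biases or batch-norm terms; it is not the repository's exact grouping logic):

```python
import torch
from torch import nn

# Toy model with conv, batch-norm, and linear layers.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

# 1-D parameters (biases, batch-norm affine terms) are exempt from weight decay.
decay = [p for p in model.parameters() if p.ndim > 1]
no_decay = [p for p in model.parameters() if p.ndim <= 1]

optimizer = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 5e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
    momentum=0.9,
    nesterov=True,
)
```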
To perform gradient accumulation:
```
python main.py --accumulation-steps ACCUMULATION_STEPS [--accumulation-averaging]
```

The option --accumulation-averaging averages the gradients over the accumulation steps.
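For intuition, gradient accumulation with averaging typically amounts to dividing each micro-batch loss by the number of accumulation steps before calling backward, and stepping the optimizer only every ACCUMULATION_STEPS batches. A generic sketch under those assumptions (the toy model, loader, and loss below are placeholders, not the repository's training loop):

```python
import torch
from torch import nn

# Placeholder setup; in this repository these come from main.py's arguments.
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accumulation_steps = 4  # corresponds to --accumulation-steps 4
optimizer.zero_grad()

for step, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets)
    # With --accumulation-averaging, each micro-batch contributes
    # 1/accumulation_steps of the gradient, i.e. gradients are averaged.
    (loss / accumulation_steps).backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```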
If you want to use the linear scaling rule from Goyal et al., use the option --goyal-lr-scaling.
This will scale the learning rate according to the equation:

```
scaled_lr = lr * accumulation_steps * batch_size / 256
```
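For example, with the ImageNet reproduction settings below (--lr 0.1, --accumulation-steps 4, --batch-size 64), the scaled learning rate is 0.1 * 4 * 64 / 256 = 0.1.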
This code supports multiple schedulers along with linear warm-up.

```
python main.py --max-epoch MAX_EPOCH --warm-up WARM_UP_EPOCHS --scheduler SCHEDULER
```

To use the STEPLR scheduler:
```
python main.py --scheduler steplr --lr-decay-milestones MILESTONE_1 MILESTONE_2 ... --lr-decay-fact DECAY_FACTOR
```

To use the Polynomial scheduler:

```
python main.py --scheduler polynomial
```

To use the Cosine scheduler:

```
python main.py --scheduler cosine
```
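For reference, linear warm-up followed by step decay can be expressed with standard PyTorch schedulers. The sketch below is only an illustration of such a schedule (it is not the repository's scheduler code; the warm-up length, milestones, and decay factor mirror the reproduction settings below):

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LinearLR, MultiStepLR, SequentialLR

model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warm_up_epochs = 5         # --warm-up 5
milestones = [30, 60, 80]  # --lr-decay-milestones 30 60 80
decay_factor = 0.1         # --lr-decay-fact 0.1
max_epoch = 90             # --max-epoch 90

# Linear warm-up for the first epochs, then step decay at the milestones
# (milestones are shifted because MultiStepLR counts from the end of warm-up).
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=0.01, total_iters=warm_up_epochs),
        MultiStepLR(optimizer, milestones=[m - warm_up_epochs for m in milestones], gamma=decay_factor),
    ],
    milestones=[warm_up_epochs],
)

for epoch in range(max_epoch):
    # ... one training epoch ...
    scheduler.step()
```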
To save a checkpoint after each epoch:

```
python main.py --name-checkpoint CHECKPOINT_PATH
```

To resume from a checkpoint:

```
python main.py --resume CHECKPOINT_PATH
```
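Conceptually, a checkpoint bundles the training state needed to resume; the generic PyTorch pattern is sketched below (the dictionary keys and file name are assumptions for illustration, not the repository's checkpoint format):

```python
import torch
from torch import nn

model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Saving: roughly what a per-epoch checkpoint contains.
torch.save(
    {"epoch": 10, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
    "checkpoint.pth",
)

# Resuming: restore the state and continue from the next epoch.
state = torch.load("checkpoint.pth")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
start_epoch = state["epoch"] + 1
```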
To reproduce the CIFAR-10 results for revnet18:
```
python main.py --no-git --dataset cifar10 --batch-size 64 -p 78 --workers 4 --model resnet18 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0005 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 300 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 150 225

python main.py --no-git --dataset cifar10 --batch-size 64 -p 78 --workers 4 --model revnet18 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0005 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 300 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 150 225

python main.py --no-git --dataset cifar10 --batch-size 64 -p 78 --workers 4 --model revnet18 --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0005 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 300 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 150 225
```
To reproduce the ImageNet32 results for revnet34:
```
python main.py --no-git --dataset imagenet32 --batch-size 64 -p 2001 --workers 4 --model resnet34 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80

python main.py --no-git --dataset imagenet32 --batch-size 64 -p 2001 --workers 4 --model revnet34 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80

python main.py --no-git --dataset imagenet32 --batch-size 64 -p 2001 --workers 4 --model revnet34 --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 2 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80
```
To reproduce the ImageNet results for revnet50:
```
python main.py --no-git --dataset imagenet --dir [PATH_TO_DATASET] --batch-size 64 -p 2001 --workers 16 --model resnet50 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 4 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80

python main.py --no-git --dataset imagenet --dir [PATH_TO_DATASET] --batch-size 64 -p 2001 --workers 16 --model revnet50 --synchronous --store-vjp --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 4 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80

python main.py --no-git --dataset imagenet --dir [PATH_TO_DATASET] --batch-size 64 -p 2001 --workers 16 --model revnet50 --remove-ctx-input --remove-ctx-param --optimizer sgd --lr 0.1 --weight-decay 0.0001 --no-bn-weight-decay --nesterov --accumulation-steps 4 --accumulation-averaging --goyal-lr-scaling --scheduler steplr --max-epoch 90 --warm-up 5 --lr-decay-fact 0.1 --lr-decay-milestones 30 60 80
```
