From 2020a9a1c04a094fdf3124be31e260c19747a7a4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 16 Aug 2023 17:17:41 +0800 Subject: [PATCH] [example] update resnet example result --- examples/images/resnet/README.md | 6 +++--- examples/images/resnet/train.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/images/resnet/README.md b/examples/images/resnet/README.md index c69828637269..9a7493ea31a6 100644 --- a/examples/images/resnet/README.md +++ b/examples/images/resnet/README.md @@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80 Expected accuracy performance will be: -| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | -| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% | **Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index fe0dabf08377..fa300395c9f3 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -104,7 +104,7 @@ def main(): '--plugin', type=str, default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], + choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'], help="plugin to use") parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") @@ -141,7 +141,7 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5)