From 28b3c31c38511e7fa54f618fe333ddef27a7cc8d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 15 Jan 2024 12:12:59 +0800 Subject: [PATCH 1/3] fix: fix misleading mbs arg --- examples/language/llama2/benchmark.py | 16 +++++++++++----- .../language/llama2/scripts/benchmark_70B/3d.sh | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index a4c29b7c8231..c931e5ba79d9 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -74,8 +74,8 @@ def main(): parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") - parser.add_argument("--mbs", type=int, default=1) - parser.add_argument("--zero", type=int, default=0) + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") args = parser.parse_args() colossalai.launch_from_torch({}) @@ -98,7 +98,13 @@ def empty_init(): extra_dp_size=args.extra_dp, ) elif args.plugin == "gemini_auto": - plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp) + plugin = GeminiPlugin( + placement_policy="auto", + precision="bf16", + warmup_non_model_data_ratio=args.warmup_ratio, + tp_size=args.tp, + extra_dp_size=args.extra_dp, + ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( @@ -137,7 +143,7 @@ def empty_init(): zero_stage=args.zero, num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, precision="bf16", ) elif args.plugin == "3d_cpu": @@ -147,7 +153,7 @@ def 
empty_init(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", ) diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama2/scripts/benchmark_70B/3d.sh index d50c57042d1a..cb8f218fa3fc 100644 --- a/examples/language/llama2/scripts/benchmark_70B/3d.sh +++ b/examples/language/llama2/scripts/benchmark_70B/3d.sh @@ -14,4 +14,4 @@ cd ../.. export OMP_NUM_THREADS=8 -colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4 +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1 From 4a30bf8a5f515837a5f45c6144f3c0fca8e9834d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 15 Jan 2024 12:48:45 +0800 Subject: [PATCH 2/3] feat: add pp sanity check --- colossalai/pipeline/schedule/interleaved_pp.py | 4 ++++ colossalai/pipeline/schedule/one_f_one_b.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 0a01a1e7864b..53fc43040831 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -72,6 +72,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatch + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py 
b/colossalai/pipeline/schedule/one_f_one_b.py index cb078b25faeb..c1dc4dfc3303 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -85,6 +85,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatches + assert ( + self.num_microbatches % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + if self.forward_only: self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) From 1e9980b9a8c0995ece5becc476ce226f6513049a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 15 Jan 2024 15:46:57 +0800 Subject: [PATCH 3/3] fix: fix 1f1b sanity check --- colossalai/pipeline/schedule/one_f_one_b.py | 4 ++-- tests/test_pipeline/test_schedule/test_oneF_oneB.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index c1dc4dfc3303..d69f28e74be9 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -86,8 +86,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) assert self.batch_size == self.microbatch_size * self.num_microbatches assert ( - self.num_microbatches % self.stage_manager.num_stages == 0 - ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + self.num_microbatches >= self.stage_manager.num_stages + ), "Number of microbatches should be no less than the number of stages" if self.forward_only: self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1 diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py
b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 5f27be39657d..a08dc6d277d0 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -155,7 +155,7 @@ def run_dist( @pytest.mark.dist -@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("num_microbatch", [4, 6]) @pytest.mark.parametrize("batch_size", [12]) @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use()