colossalai/pipeline/schedule/interleaved_pp.py (4 additions, 0 deletions)

@@ -72,6 +72,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         assert self.last_batch_size is None or self.last_batch_size == self.batch_size
         assert self.batch_size == self.microbatch_size * self.num_microbatch

+        assert (
+            self.num_microbatch % self.stage_manager.num_stages == 0
+        ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
+
         if self.forward_only:
             self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
             # NOTE: disable metadata cache when batch size changes (not valid anymore)
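The new guard makes explicit the interleaved schedule's assumption that the microbatches split into whole rounds over the pipeline stages. A minimal standalone sketch of the arithmetic being enforced (hypothetical helper, not the ColossalAI API):

# Sketch of the constraint asserted above: each scheduling round of the
# interleaved schedule issues one microbatch per stage, so a partial
# final round would leave some stages idle.
def check_interleaved(num_microbatch: int, num_stages: int) -> None:
    if num_microbatch % num_stages != 0:
        raise ValueError(
            f"{num_microbatch} microbatches do not form whole rounds over {num_stages} stages"
        )

check_interleaved(8, 4)    # OK: two full rounds of 4
# check_interleaved(6, 4)  # raises, mirroring the new assert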
colossalai/pipeline/schedule/one_f_one_b.py (4 additions, 0 deletions)

@@ -85,6 +85,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         assert self.last_batch_size is None or self.last_batch_size == self.batch_size
         assert self.batch_size == self.microbatch_size * self.num_microbatches

+        assert (
+            self.num_microbatches >= self.stage_manager.num_stages
+        ), "Number of microbatch should be larger than number of stages"
+
         if self.forward_only:
             self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
             # NOTE: disable metadata cache when batch size changes (not valid anymore)
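The 1F1B schedule needs only the weaker condition: with fewer microbatches than stages, warm-up cannot fill the pipeline and the deepest stages sit idle. A hedged illustration of the warm-up accounting (plain Python sketch, assuming the standard 1F1B warm-up of num_stages - stage - 1 forwards per stage; not the library's code):

# Each stage performs some warm-up forward passes before entering the
# steady one-forward-one-backward phase; filling the pipeline takes at
# least num_stages microbatches in total.
def warmup_forwards(stage: int, num_stages: int, num_microbatches: int) -> int:
    return min(num_stages - stage - 1, num_microbatches)

print(warmup_forwards(0, num_stages=4, num_microbatches=8))  # 3: pipeline fills
print(warmup_forwards(0, num_stages=4, num_microbatches=2))  # 2: never fills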
examples/language/llama2/benchmark.py (11 additions, 5 deletions)

@@ -74,8 +74,8 @@ def main():
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
     parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
-    parser.add_argument("--mbs", type=int, default=1)
-    parser.add_argument("--zero", type=int, default=0)
+    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
+    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     args = parser.parse_args()

     colossalai.launch_from_torch({})
@@ -98,7 +98,13 @@ def empty_init():
             extra_dp_size=args.extra_dp,
         )
     elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
+        plugin = GeminiPlugin(
+            placement_policy="auto",
+            precision="bf16",
+            warmup_non_model_data_ratio=args.warmup_ratio,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
+        )
     elif args.plugin == "fsdp":
         if use_empty_init:
             plugin = TorchFSDPPlugin(
@@ -137,7 +143,7 @@ def empty_init():
             zero_stage=args.zero,
             num_model_chunks=2,
             enable_fused_normalization=torch.cuda.is_available(),
-            num_microbatches=args.mbs,
+            microbatch_size=args.mbs,
             precision="bf16",
         )
     elif args.plugin == "3d_cpu":
@@ -147,7 +153,7 @@ def empty_init():
             zero_stage=args.zero,
             cpu_offload=True,
             enable_fused_normalization=torch.cuda.is_available(),
-            num_microbatches=args.mbs,
+            microbatch_size=args.mbs,
             initial_scale=2**8,
             precision="bf16",
         )
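After the rename, --mbs fixes the size of each microbatch instead of their number; the number of microbatches then follows from the batch size. A small sketch of that relationship, consistent with the batch_size == microbatch_size * num_microbatches invariant asserted in the schedules above (hypothetical helper, not benchmark.py code):

# Fixing the microbatch size determines the microbatch count.
def derive_num_microbatches(batch_size: int, microbatch_size: int) -> int:
    assert batch_size % microbatch_size == 0, "batch must split evenly into microbatches"
    return batch_size // microbatch_size

print(derive_num_microbatches(batch_size=8, microbatch_size=1))  # 8
print(derive_num_microbatches(batch_size=8, microbatch_size=4))  # 2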
examples/language/llama2/scripts/benchmark_70B/3d.sh (1 addition, 1 deletion)

@@ -14,4 +14,4 @@ cd ../..

 export OMP_NUM_THREADS=8

-colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4
+colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1
tests/test_pipeline/test_schedule/test_oneF_oneB.py (1 addition, 1 deletion)

@@ -155,7 +155,7 @@ def run_dist(


 @pytest.mark.dist
-@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("num_microbatch", [4, 6])
 @pytest.mark.parametrize("batch_size", [12])
 @pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
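A minimal standalone check of the new 1F1B constraint, in the same pytest style (hypothetical test, not part of this PR):

import pytest

# Hypothetical guard mirroring the new assert in one_f_one_b.py.
def validate_num_microbatches(num_microbatches: int, num_stages: int) -> None:
    assert (
        num_microbatches >= num_stages
    ), "Number of microbatch should be larger than number of stages"

def test_validate_num_microbatches():
    validate_num_microbatches(num_microbatches=4, num_stages=2)  # passes
    with pytest.raises(AssertionError):
        validate_num_microbatches(num_microbatches=2, num_stages=4)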