examples/language/gpt/gemini/test_ci.sh (1 addition, 1 deletion)
@@ -3,7 +3,7 @@ $(cd `dirname $0`;pwd)
 export TRAIN_STEP=4

 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 2; do
       for GPUNUM in 1 4; do
         for TPDEGREE in 1 2; do
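Note: the matrix value above must match one of the distplan names accepted by train_gpt_demo.py. A hedged sketch of the accepted set, inferred from the startswith/endswith checks in the training script below (the script's actual argparse choices are outside the visible hunks and are an assumption here):

# Inferred from the branches in train_gpt_demo.py; treat as an assumption,
# since the argparse definition is not shown in this diff.
ACCEPTED_DISTPLANS = ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]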
examples/language/gpt/gemini/train_gpt_demo.py (28 additions, 29 deletions)
@@ -11,11 +11,13 @@
 from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext

 CAI_VERSION = colossalai.__version__
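The import changes summarize the migration: the low-level zero_model_wrapper/zero_optim_wrapper pair is replaced by the Booster API, which wraps model, optimizer, and criterion through a plugin. A minimal sketch of the new workflow, with a toy Linear model and MSE loss standing in for the example's GPT-2 setup (run under torchrun; the launch_from_torch(config={}) call reflects the ColossalAI API of this era):

# Minimal sketch of the Booster workflow this PR adopts; the model and loss
# are placeholders, not the example's GPT-2 and its LM loss.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16).cuda()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
# boost returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.randn(4, 16, device='cuda')
loss = criterion(model(x), torch.zeros(4, 16, device='cuda'))
booster.backward(loss, optimizer)    # replaces optimizer.backward / loss.backward
optimizer.step()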

@@ -236,23 +238,6 @@ def main():
             tensor_parallelize(model, tp_pg)

         # asign running configurations
-        gemini_config = None
-        if args.distplan.startswith("CAI_ZeRO"):
-            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
-        elif args.distplan == "CAI_Gemini":
-            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
-                                 device=get_current_device(),
-                                 placement_policy=args.placement,
-                                 pin_memory=True,
-                                 hidden_dim=model.config.n_embd,
-                                 search_range_mb=128)
-            optim_config = dict(gpu_margin_mem_ratio=0.)
-        else:
-            raise RuntimeError
-
-        # build a highly optimized gpu/cpu optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-
         if args.distplan == "CAI_ZeRO1":
             zero_stage = 1
         elif args.distplan == "CAI_ZeRO2":
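The tail of this if-chain is folded in the diff view; the visible branches map CAI_ZeRO1 to stage 1 and CAI_ZeRO2 to stage 2. A hedged sketch of the full mapping, assuming the folded branch assigns stage 3 to Gemini, in line with ColossalAI's convention that Gemini builds on ZeRO stage 3:

# Assumption: the "CAI_Gemini" -> 3 entry reconstructs a folded context line
# and may not match the file exactly.
def zero_stage_for(distplan: str) -> int:
    stages = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}
    if distplan not in stages:
        raise RuntimeError(f"unknown distplan: {distplan}")
    return stages[distplan]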
@@ -262,22 +247,42 @@ def main():
         else:
             raise RuntimeError

-        # wrap your model and optimizer
-        model = zero_model_wrapper(model, zero_stage, gemini_config)
-        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+        plugin = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            plugin = LowLevelZeroPlugin(stage=zero_stage,
+                                        reduce_bucket_size_in_m=12 * 1024 * 1024,
+                                        overlap_communication=True,
+                                        verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            plugin = GeminiPlugin(device=get_current_device(),
+                                  placement_policy=args.placement,
+                                  pin_memory=True,
+                                  strict_ddp_mode=args.tp_degree == 1,
+                                  search_range_mb=128,
+                                  hidden_dim=model.config.n_embd,
+                                  gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-        model = DDP(model)
+        plugin = TorchDDPPlugin()
         if args.distplan.endswith("DDP"):
             optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
         elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
             optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)

     else:
         raise RuntimeError
+    # wrap your model and optimizer
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

     # model is shared after TP
     numel = get_model_size(model)
@@ -305,13 +310,7 @@ def train_step():
         fwd_end = time()
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
-
-        if args.distplan.startswith("CAI"):
-            optimizer.backward(loss)
-        elif args.distplan.startswith("Pytorch"):
-            loss.backward()
-        else:
-            raise RuntimeError
+        booster.backward(loss, optimizer)

         torch.cuda.synchronize()
         bwd_end = time()
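With everything boosted, the training step no longer branches on the distplan: the single booster.backward call covers the ZeRO optimizers and plain DDP alike. A condensed sketch of one step under the new API, assuming the example's get_data helper and the boosted objects from above (the helper's name and signature are taken from the example but sit outside the visible hunks):

# Condensed sketch of one training step with the Booster API; get_data,
# criterion, and the boosted model/optimizer are assumed from the example.
def train_step():
    input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
    optimizer.zero_grad()
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)
    booster.backward(loss, optimizer)    # one call replaces the old distplan branches
    optimizer.step()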