31 changes: 27 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -2,7 +2,7 @@
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from types import MethodType
@@ -33,8 +33,11 @@
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

from .pp_plugin_base import PipelinePluginBase

@@ -61,6 +64,7 @@ def __init__(
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@@ -69,6 +73,7 @@ def __init__(
self.sp_group = sp_group
self.use_dpp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -106,6 +111,12 @@ def __init__(
module = DDP(module, process_group=dp_group, **ddp_config)

super().__init__(module)
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
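The added lines above convert each trainable parameter in place: reassigning `__class__` changes the Python type without touching the underlying storage, so references captured elsewhere (for example, in optimizer param groups) see the converted object. A minimal sketch of the pattern, with `HookableParameter` as a hypothetical stand-in for ColoParameter:

import torch

class HookableParameter(torch.nn.Parameter):
    """Hypothetical stand-in for ColoParameter."""

p = torch.nn.Parameter(torch.randn(4))
ref = p  # a pre-existing reference, e.g. held by an optimizer
if p.requires_grad and type(p) is not HookableParameter:
    p.__class__ = HookableParameter  # in-place type swap; data is untouched
assert isinstance(ref, HookableParameter)  # old references see the new type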

def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@@ -197,14 +208,22 @@ def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
with self._wait_all_gather():
return super().forward(*args, **kwargs)

def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module

def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)

def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
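`_wait_all_gather` returns the real hook context only when overlap is enabled, so `forward` can use a single `with` statement either way. A runnable sketch of that conditional-context pattern (`HookContext` is an illustrative stand-in for `ColoParamOpHookManager.use_hooks(...)`):

from contextlib import nullcontext

class HookContext:
    # Stand-in for ColoParamOpHookManager.use_hooks(hook): installs the
    # parameter op hooks on entry and removes them on exit.
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

def wait_all_gather(overlap_allgather: bool):
    # Call sites can always write `with wait_all_gather(flag):`,
    # whether or not overlapped all-gather is enabled.
    return HookContext() if overlap_allgather else nullcontext()

with wait_all_gather(True):
    pass  # the forward pass would run here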


def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@@ -650,6 +669,7 @@ def __init__(
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
self.model = model
self.param_info = param_info
@@ -677,7 +697,7 @@ def __init__(
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=False,
overlap_allgather=overlap_allgather,
)

def sync_dp_grads(self):
@@ -993,6 +1013,7 @@ def __init__(
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None:
super().__init__()
assert (
@@ -1144,6 +1165,7 @@ def __init__(
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)

self.max_norm = max_norm
@@ -1221,6 +1243,7 @@ def configure(
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
@@ -1303,7 +1326,7 @@ def execute_pipeline(
# so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

with ctx:
with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
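Taken together, the changes in this file let users opt in at plugin construction time. A hedged usage sketch (the `overlap_allgather` flag comes from this diff; the remaining arguments are illustrative, and per the `configure` hunk the hook is only active when `zero_stage > 0`):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,             # overlap_allgather has no effect at stage 0
    precision="bf16",
    overlap_allgather=True,   # overlap ZeRO all-gather with forward compute
)
booster = Booster(plugin=plugin)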
16 changes: 9 additions & 7 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -62,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@@ -76,8 +76,8 @@ def __init__(self, module: nn.Module, precision: str, overlap_communication: boo
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_communication = overlap_communication
if overlap_communication:
self.overlap_allgather = overlap_allgather
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@@ -88,7 +88,7 @@ def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
with ctx:
return super().forward(*args, **kwargs)

@@ -356,8 +356,8 @@ def __init__(
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
overlap_allgather=overlap_allgather,
)
self.overlap_allgather = overlap_allgather
self.lora_enabled = False
self.verbose = verbose

@@ -473,11 +473,13 @@ def configure(
self.add_lora_params_to_optimizer(model, optimizer)

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)
model = LowLevelZeroModel(
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)

# TODO: Support Galore + ZeRO
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size()

# Replace with the distributed implementation if exists
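The same opt-in exists on the low-level plugin, where the flag (renamed from `overlap_communication` to `overlap_allgather`) now reaches the model wrapper via `zero_optim_kwargs`. A hedged sketch, with all arguments other than the flag being illustrative:

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(
    stage=1,
    precision="bf16",
    overlap_allgather=True,  # forwarded to both the wrapper and the optimizer
)
booster = Booster(plugin=plugin)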
4 changes: 4 additions & 0 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -195,6 +195,7 @@ def save_sharded_model(
"""

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()

if os.path.isfile(checkpoint):
@@ -303,6 +304,7 @@ def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, s
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()

@@ -639,6 +641,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()

if self.dp_rank != 0:
@@ -679,6 +682,7 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")

assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
strict = False
model_before_wrapping = model
model = model.unwrap()
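Each checkpoint entry point now drains in-flight all-gathers before touching parameters; otherwise a checkpoint could serialize partially gathered weights. A simplified sketch of the invariant (the body is reduced to a single torch.save; the real implementation shards, gathers, and handles ranks):

import torch

def save_model(model_wrapper, path: str):
    # Wait on any all-gather handle left pending by ZeroOpHook so every
    # parameter holds complete, up-to-date data before serialization.
    model_wrapper._force_wait_all_gather()
    module = model_wrapper.unwrap()
    torch.save(module.state_dict(), path)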
3 changes: 2 additions & 1 deletion examples/language/llama/benchmark.py
@@ -98,6 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args()

colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def empty_init():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
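Finally, a hedged sketch of how the new benchmark flag flows into the plugin (only `--overlap_allgather` comes from this diff; the surrounding wiring is condensed and illustrative):

import argparse

from colossalai.booster.plugin import HybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args()

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    zero_stage=1,
    precision="bf16",
    overlap_allgather=args.overlap_allgather,  # CLI flag wired through
)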