From b258694932ebdab3cb278c862f35b668aa3ad505 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 24 May 2023 12:06:37 +0800 Subject: [PATCH 1/4] [zero] init context support bf16 --- colossalai/zero/legacy/init_ctx/init_context.py | 10 +++++++--- colossalai/zero/legacy/sharded_model/_utils.py | 10 +++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py index a921ca0aa83a..91e93f870544 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -14,7 +14,7 @@ from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.legacy.sharded_param import ShardedParamV2 @@ -64,6 +64,7 @@ def __init__(self, seed: int = 2**10 - 1, shard_param: bool = False, default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): super().__init__(default_dtype=default_dtype) @@ -71,6 +72,7 @@ def __init__(self, self.param_list = [] self.model_numel_tensor = model_numel_tensor self.seed = seed + self.bf16 = bf16 self.dp_process_group = gpc.get_group(ParallelMode.DATA) self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param) @@ -183,9 +185,10 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): NOTE() The module may be passed to this function multiple times. """ self.top_module = module + half_dtype = torch.float16 if not self.bf16 else torch.bfloat16 def half_fn(t: torch.Tensor): - return t.half() if t.is_floating_point() else t + return t.to(half_dtype) if t.is_floating_point() else t for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice @@ -226,9 +229,10 @@ def half_fn(t: torch.Tensor): # We must cast buffers # If we use BN, buffers may be on CPU and Float # We must cast them + cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16 for buffer in module.buffers(recurse=False): buffer.data = buffer.data.to(device=torch.cuda.current_device()) - buffer.data = cast_tensor_to_fp16(buffer.data) + buffer.data = cast_fn(buffer.data) class ZeroContextMgr(metaclass=SingletonMeta): diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py index 2bd01531a78f..f1d642cf3f13 100644 --- a/colossalai/zero/legacy/sharded_model/_utils.py +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te if isinstance(tensor, StatefulTensor): tensor = tensor.payload - if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: + if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16): return tensor.float() return tensor +def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.bfloat16() + return tensor + + def apply_to_tensors(x: Any, fn: Callable): if torch.is_tensor(x): return fn(x) From 58f12417abd14e108fd108671568c551a5e0ea6a Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 24 May 2023 14:17:22 +0800 Subject: [PATCH 2/4] [zero] legacy zero support bf16 --- .../legacy/sharded_model/sharded_model_v2.py | 6 ++- .../legacy/sharded_optim/sharded_optim_v2.py | 39 ++++++++++++------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py index b3a83b741825..365cb9cbb4c0 100644 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -28,6 +28,7 @@ from ._utils import ( cast_float_arguments, + cast_tensor_to_bf16, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, @@ -86,11 +87,13 @@ def __init__(self, tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False, + bf16: bool = False, *args, **kwargs): assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' super().__init__() self.logger = get_dist_logger() + self.bf16 = bf16 # We force users to use ZeroInitContext for submodule in module.modules(): @@ -232,7 +235,8 @@ def _post_forward_operations(self): def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_operations(*args) - args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) + cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16 + args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs) outputs = self.module(*args, **kwargs) self._post_forward_operations() return outputs diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py index be60209af434..41dd174cb65a 100644 --- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -94,6 +94,7 @@ def __init__(self, super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy self.model: ShardedModelV2 = sharded_model + self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' @@ -117,6 +118,7 @@ def __init__(self, self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose + self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward # Store fp32 param shards self._register_master_weight() @@ -166,8 +168,10 @@ def zero_grad(self, *args, **kwargs): self._zero_grad() def backward(self, loss: Tensor) -> None: - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + if not self.bf16: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward(loss) def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: @@ -175,30 +179,33 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + if not self.bf16: + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward_by_grad(tensor, grad) def clip_grad_norm(self, model: nn.Module, max_norm: float): - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): + self._prepare_grads() # unscale grads if scaled - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() self._maybe_move_fp32_shards() - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) + if not self.bf16: + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) - if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') - self._zero_grad(recover_data=True) - return + if found_inf: + self._logger.warning('found inf during ShardedOptimV2 step') + self._zero_grad(recover_data=True) + return self._point_param_fp16_to_master_param() @@ -304,6 +311,8 @@ def _maybe_move_fp32_shards(self): state[k] = v.cuda() def _prepare_grads(self): + if self._grad_prepared: + return for group in self.optim.param_groups: for p in group['params']: if p.colo_attr.saved_grad.is_null(): @@ -320,6 +329,7 @@ def _prepare_grads(self): p.grad = p.colo_attr.grad_payload # Set p.data to empty tensor, in case of memory leaking p.colo_attr.set_data_none() + self._grad_prepared = True def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. @@ -357,7 +367,8 @@ def _copy_master_param_to_param_fp16(self, p): torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + half_dtype = torch.bfloat16 if self.bf16 else torch.float16 + p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach()) p.colo_attr.set_data_none() if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: From b74038d9c7894ae8db2cae7e4d6ca840870d9d39 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 24 May 2023 15:27:13 +0800 Subject: [PATCH 3/4] [test] add zero bf16 test --- .../test_zero/test_legacy/test_zero_engine.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py index dc8847ce56ab..826a543db861 100644 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ b/tests/test_zero/test_legacy/test_zero_engine.py @@ -16,7 +16,11 @@ from tests.components_to_test.registry import non_distributed_component_funcs -def run_dist(rank, world_size, port, parallel_config): +def run_dist(rank, world_size, port, parallel_config, bf16): + is_mp_config = parallel_config == MP_PARALLEL_CONFIG + is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG + if bf16: + parallel_config['zero']['model_config']['bf16'] = True colossalai.launch(config=parallel_config, rank=rank, world_size=world_size, @@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config): model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): + shard_param=True, + bf16=bf16): colo_model = model_builder(checkpoint=True) colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) @@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config): optimizer=colo_optimizer, criterion=criterion, train_dataloader=train_dataloader) - torch_model = model_builder(checkpoint=True).half() + dtype = torch.bfloat16 if bf16 else torch.float16 + torch_model = model_builder(checkpoint=True).to(dtype) col_model_deepcopy(engine.model, torch_model) torch_model = torch_model.cuda().float() @@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config): torch_optimizer.step() i += 1 - if parallel_config == MP_PARALLEL_CONFIG: + if is_mp_config: check_params(torch_model, colo_model, loose=True) - elif parallel_config == ZERO_PARALLEL_CONFIG: + elif is_zero_config: check_sharded_model_params(torch_model, colo_model, loose=True) @@ -97,9 +103,10 @@ def test_mp_engine(world_size): @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2]) +@pytest.mark.parametrize("bf16", [True, False]) @rerun_if_address_is_in_use() -def test_zero_engine(world_size): - spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) +def test_zero_engine(world_size, bf16): + spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16) if __name__ == '__main__': From 75bb1fa1ccef254c4a2a710e6f3b09b67d6f7141 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 1 Jun 2023 15:38:32 +0800 Subject: [PATCH 4/4] [doc] add bf16 related docstring for legacy zero --- colossalai/zero/legacy/init_ctx/init_context.py | 1 + colossalai/zero/legacy/sharded_model/sharded_model_v2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py index 91e93f870544..a3fa46b38b5a 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): seed (int, optional): Random seed for weight initialization shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False. model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). """ diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py index 365cb9cbb4c0..be3842beb208 100644 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -75,6 +75,7 @@ class ShardedModelV2(nn.Module): In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). We find that PyTorch's optimizers don't support mixed precision, so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. """ def __init__(self,