11 changes: 8 additions & 3 deletions colossalai/zero/legacy/init_ctx/init_context.py
@@ -14,7 +14,7 @@
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2

@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""

@@ -64,13 +65,15 @@ def __init__(self,
seed: int = 2**10 - 1,
shard_param: bool = False,
default_dtype: Optional[torch.dtype] = None,
bf16: bool = False,
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):

super().__init__(default_dtype=default_dtype)
self.shard_strategy = shard_strategy
self.param_list = []
self.model_numel_tensor = model_numel_tensor
self.seed = seed
self.bf16 = bf16
self.dp_process_group = gpc.get_group(ParallelMode.DATA)

self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
@@ -183,9 +186,10 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
NOTE() The module may be passed to this function multiple times.
"""
self.top_module = module
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
return t.to(half_dtype) if t.is_floating_point() else t

for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
@@ -226,9 +230,10 @@ def half_fn(t: torch.Tensor):
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_tensor_to_fp16(buffer.data)
buffer.data = cast_fn(buffer.data)


class ZeroContextMgr(metaclass=SingletonMeta):
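For context, a minimal sketch of how the new flag is meant to be used (the shard strategy, model, and import paths here are illustrative assumptions, and a launched distributed environment is assumed):

import torch
import torch.nn as nn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy

shard_strategy = TensorShardStrategy()
# bf16=True makes the context cast parameters to torch.bfloat16 instead of torch.float16
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True,
                     bf16=True):
    model = nn.Linear(1024, 1024)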
10 changes: 9 additions & 1 deletion colossalai/zero/legacy/sharded_model/_utils.py
@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload

if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
return tensor.float()
return tensor


def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
return tensor.bfloat16()
return tensor


def apply_to_tensors(x: Any, fn: Callable):
if torch.is_tensor(x):
return fn(x)
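A quick sketch of what the new helper does for plain tensors (StatefulTensor payloads are unwrapped the same way); the import path follows the file above:

import torch
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp32

x = torch.randn(4)                                   # float32
b = cast_tensor_to_bf16(x)                           # floating-point tensors become bfloat16
assert b.dtype is torch.bfloat16
assert cast_tensor_to_fp32(b).dtype is torch.float32               # fp32 cast now also covers bf16
assert cast_tensor_to_bf16(torch.arange(4)).dtype is torch.int64   # non-float tensors pass through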
7 changes: 6 additions & 1 deletion colossalai/zero/legacy/sharded_model/sharded_model_v2.py
@@ -28,6 +28,7 @@

from ._utils import (
cast_float_arguments,
cast_tensor_to_bf16,
cast_tensor_to_fp16,
cast_tensor_to_fp32,
chunk_and_pad,
@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""

def __init__(self,
@@ -86,11 +88,13 @@ def __init__(self,
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__()
self.logger = get_dist_logger()
self.bf16 = bf16

# We force users to use ZeroInitContext
for submodule in module.modules():
@@ -232,7 +236,8 @@ def _post_forward_operations(self):

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations(*args)
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
outputs = self.module(*args, **kwargs)
self._post_forward_operations()
return outputs
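Continuing the sketch above: the same flag is passed when wrapping the module, so forward arguments are cast to bfloat16 rather than float16 (import path assumed; the wrapped module must have been created under ZeroInitContext):

from colossalai.zero.legacy.sharded_model import ShardedModelV2

sharded_model = ShardedModelV2(model, shard_strategy, bf16=True)
x = torch.randn(8, 1024, device=torch.cuda.current_device())   # float32 input
out = sharded_model(x)   # cast_float_arguments(cast_tensor_to_bf16, ...) runs before the module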
39 changes: 25 additions & 14 deletions colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py
@@ -94,6 +94,7 @@ def __init__(self,
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
self.bf16 = sharded_model.bf16

self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
@@ -117,6 +118,7 @@ def __init__(self,
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
self._verbose = verbose
self._grad_prepared: bool = False # set to True in _prepare_grads() and reset to False when backward is called

# Store fp32 param shards
self._register_master_weight()
@@ -166,39 +168,44 @@ def zero_grad(self, *args, **kwargs):
self._zero_grad()

def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
if not self.bf16:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward(loss)

def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
# This function is called in every pipeline-parallel stage except the last
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
if not self.bf16:
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward_by_grad(tensor, grad)

def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._prepare_grads()
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)

def step(self, *args, **kwargs):

self._prepare_grads()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()

self._maybe_move_fp32_shards()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if not self.bf16:
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)

if found_inf:
self._logger.warning('found inf during ShardedOptimV2 step')
self._zero_grad(recover_data=True)
return
if found_inf:
self._logger.warning('found inf during ShardedOptimV2 step')
self._zero_grad(recover_data=True)
return

self._point_param_fp16_to_master_param()

@@ -304,6 +311,8 @@ def _maybe_move_fp32_shards(self):
state[k] = v.cuda()

def _prepare_grads(self):
if self._grad_prepared:
return
for group in self.optim.param_groups:
for p in group['params']:
if p.colo_attr.saved_grad.is_null():
Expand All @@ -320,6 +329,7 @@ def _prepare_grads(self):
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.set_data_none()
self._grad_prepared = True

def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
@@ -357,7 +367,8 @@ def _copy_master_param_to_param_fp16(self, p):
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))

# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
p.colo_attr.set_data_none()

if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
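On the optimizer side, the flag is read from the sharded model and disables loss scaling, gradient unscaling, and the overflow check. A rough sketch of a training step under these assumptions (CPUAdam and the import paths are illustrative):

from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2

optimizer = ShardedOptimizerV2(sharded_model, CPUAdam(sharded_model.parameters(), lr=1e-3))
loss = sharded_model(x).sum()
optimizer.backward(loss)   # no loss_scale multiplication on the bf16 path
optimizer.step()           # no unscale or overflow check; fp32 master params updated directly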
21 changes: 14 additions & 7 deletions tests/test_zero/test_legacy/test_zero_engine.py
@@ -16,7 +16,11 @@
from tests.components_to_test.registry import non_distributed_component_funcs


def run_dist(rank, world_size, port, parallel_config):
def run_dist(rank, world_size, port, parallel_config, bf16):
is_mp_config = parallel_config == MP_PARALLEL_CONFIG
is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
if bf16:
parallel_config['zero']['model_config']['bf16'] = True
colossalai.launch(config=parallel_config,
rank=rank,
world_size=world_size,
@@ -30,15 +34,17 @@ def run_dist(rank, world_size, port, parallel_config):
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True):
shard_param=True,
bf16=bf16):
colo_model = model_builder(checkpoint=True)

colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
optimizer=colo_optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
torch_model = model_builder(checkpoint=True).half()
dtype = torch.bfloat16 if bf16 else torch.float16
torch_model = model_builder(checkpoint=True).to(dtype)
col_model_deepcopy(engine.model, torch_model)
torch_model = torch_model.cuda().float()

@@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config):
torch_optimizer.step()
i += 1

if parallel_config == MP_PARALLEL_CONFIG:
if is_mp_config:
check_params(torch_model, colo_model, loose=True)
elif parallel_config == ZERO_PARALLEL_CONFIG:
elif is_zero_config:
check_sharded_model_params(torch_model, colo_model, loose=True)


@@ -97,9 +103,10 @@ def test_mp_engine(world_size):

@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("bf16", [True, False])
@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG)
def test_zero_engine(world_size, bf16):
spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)


if __name__ == '__main__':