129 changes: 83 additions & 46 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,6 +3,7 @@
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
@@ -25,9 +26,9 @@
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):

def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)

def unwrap(self):
# TODO(ver217): this is a workaround for loading model
return self


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str

sharded_optimizer_loading_epilogue(optimizer)


class LowLevelZeroModel(ModelWrapper):

def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
module = zero_model_wrapper(module, zero_stage=stage)
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
use_safetensors)

def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params()

def load_sharded_model(self,
model: LowLevelZeroModel,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
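
A hedged end-to-end sketch of how these hooks are reached through the Booster API (assumes a distributed launch such as `torchrun`; the toy model, file path, and `launch_from_torch` call are illustrative, not part of this PR):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision='fp16'))
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, 'model.pt')   # routes to save_unsharded_model above
booster.load_model(model, 'model.pt')   # routes to load_unsharded_model above,
                                        # which then re-syncs the fp32 master shards
```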


class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ def __init__(
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'

assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
)
self.verbose = verbose

# set class name with stage, for better error message
@@ -294,15 +331,15 @@ def configure(
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
model = LowLevelZeroModel(model, self.precision)

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = zero_optim_wrapper(model.unwrap(),
optimizer,
optim_config=self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
**self.zero_optim_kwargs,
verbose=self.verbose)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)

return model, optimizer, criterion, dataloader, lr_scheduler
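
The `MethodType` line is the subtle part: `optimizer.update_master_params` is already bound to the optimizer, and binding it again to the model makes `model.update_master_params()` forward the model as the `model` argument. A standalone illustration with toy classes (names are assumptions):

```python
from types import MethodType

class ToyOptimizer:
    def update_master_params(self, model):
        print(f'syncing master params of {type(model).__name__}')

class ToyModel:
    pass

opt, model = ToyOptimizer(), ToyModel()
# Rebind the optimizer's bound method to the model instance; at call time the
# model is passed in as the `model` argument of update_master_params.
model.update_master_params = MethodType(opt.update_master_params, model)
model.update_master_params()   # prints: syncing master params of ToyModel
```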

4 changes: 2 additions & 2 deletions colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
from .model import ModelWrapper
from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper', 'ModelWrapper']
__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
11 changes: 11 additions & 0 deletions colossalai/interface/model.py
@@ -23,3 +23,14 @@ def unwrap(self):

def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)


class AMPModelMixin:
"""This mixin class defines the interface for AMP training.
"""

def update_master_params(self):
"""
Update the master parameters for AMP training.
"""
pass
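
To make the contract concrete, here is a hedged sketch of a wrapper that actually implements the hook (the class below is hypothetical and not part of this PR):

```python
import torch
import torch.nn as nn
from colossalai.interface import AMPModelMixin, ModelWrapper

class ToyAMPModel(ModelWrapper, AMPModelMixin):
    """Hypothetical wrapper keeping fp32 master copies of fp16 working params."""

    def __init__(self, module: nn.Module):
        super().__init__(module.half())
        self.master_params = [p.detach().float() for p in self.module.parameters()]

    def update_master_params(self):
        # Refresh the fp32 masters from the (e.g. freshly checkpoint-loaded) fp16 weights.
        with torch.no_grad():
            for working, master in zip(self.module.parameters(), self.master_params):
                master.copy_(working.float())
```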
17 changes: 17 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
@@ -6,6 +6,7 @@

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

@@ -600,3 +601,19 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
ret_block_size += current_block_size

yield ret_block, ret_block_size

def update_master_params(self, model: nn.Module) -> None:
"""Update master params from working params

Args:
model (nn.Module): The model to update master params
"""
for p in model.parameters():
p_id = id(p)
if p_id in self._param_store.working_to_master_param:
master_param = self._param_store.working_to_master_param[p_id]
padding_size = self._param_store.get_param_padding_size(p)
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
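
The pad-and-chunk layout this relies on is easy to check numerically; a standalone sketch with an assumed `world_size=4` and `local_rank=1` (not ColossalAI code):

```python
import torch
import torch.nn.functional as F

world_size, local_rank = 4, 1
working = torch.arange(10, dtype=torch.float32)   # flattened working param, numel=10
padding_size = (world_size - working.numel() % world_size) % world_size   # -> 2
padded = F.pad(working.view(-1), [0, padding_size])   # numel=12, divisible by world_size
shard = padded.chunk(world_size)[local_rank]          # this rank's slice: tensor([3., 4., 5.])
master_param = torch.empty_like(shard)
master_param.copy_(shard)   # the fp32 master shard now mirrors the working param
```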
12 changes: 12 additions & 0 deletions tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -14,6 +14,7 @@
rerun_if_address_is_in_use,
spawn,
)
from colossalai.zero import LowLevelZeroOptimizer


# stage 1 and 2 process the optimizer/model the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):

booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
assert p_id in working_param_id_set
working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
padding = new_optimizer._param_store.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(working_shard,
master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
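    # Note: this loop is the inverse of LowLevelZeroOptimizer.update_master_params:
    # it pads and chunks each working param the same way and requires this rank's
    # shard to equal its fp32 master copy after booster.load_model().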

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)