12 changes: 7 additions & 5 deletions colossalai/booster/booster.py
@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ def execute_pipeline(self,
# return loss or outputs if needed
pass

-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
        """Context manager to disable gradient synchronization across DP process groups.
+        Currently supports Torch DDP and low-level ZeRO-1.

        Args:
-            model (nn.Module): The model to be disabled gradient synchronization.
+            model (nn.Module): The model whose gradient synchronization is to be disabled, for Torch DDP.
+            optimizer (OptimizerWrapper): The optimizer whose gradient synchronization is to be disabled, for ZeRO-1.

        Returns:
            contextmanager: Context to disable gradient synchronization.
        """
        assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.
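For context, a usage sketch of the new signature (hedged: `booster`, `model`, `optimizer`, `criterion`, and `dataloader` are assumed to come from a standard `Booster.boost(...)` setup, and `accum_interval` is a hypothetical accumulation interval; nothing below is part of this diff):

```python
# Gradient-accumulation loop against the updated Booster.no_sync API.
# Assumes the booster was built with TorchDDPPlugin or LowLevelZeroPlugin(stage=1).
accum_interval = 2    # hypothetical: micro-batches per optimizer step

for step, (inputs, labels) in enumerate(dataloader):
    if (step + 1) % accum_interval != 0:
        # Passing both objects is harmless: TorchDDPPlugin reads only `model`,
        # while LowLevelZeroPlugin (ZeRO-1) reads only `optimizer`.
        with booster.no_sync(model=model, optimizer=optimizer):
            loss = criterion(model(inputs), labels)
            booster.backward(loss, optimizer)
    else:
        loss = criterion(model(inputs), labels)
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```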
7 changes: 2 additions & 5 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -295,10 +295,7 @@ def configure(

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(),
-                                        optimizer,
-                                        self.zero_optim_config,
-                                        self.optim_kwargs,
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler
@@ -309,5 +306,5 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
15 changes: 8 additions & 7 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -179,8 +179,11 @@ def __init__(
norm_type=norm_type)
self.verbose = verbose

+        # set class name with stage, for better error message
+        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
    def support_no_sync(self) -> bool:
-        return False
+        return self.stage == 1

def control_precision(self) -> bool:
return True
@@ -208,10 +211,7 @@ def configure(

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(),
-                                              optimizer,
-                                              self.zero_optim_config,
-                                              self.optim_kwargs,
+            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler
@@ -222,5 +222,6 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert isinstance(optimizer, LowLevelZeroOptimizer)
+        return optimizer.optim.no_sync()
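A minimal sketch of the stage gate (assuming the plugin's usual `stage` constructor argument and an already-launched distributed context): ZeRO-1 keeps a full gradient replica on every rank, so deferring the all-reduce is safe, while ZeRO-2 partitions gradients during backward and would leave each rank holding only a shard.

```python
from colossalai.booster.plugin import LowLevelZeroPlugin

# Stage 1 replicates gradients, so no_sync-style accumulation is supported;
# stage 2 partitions them, so support_no_sync() now reports False.
assert LowLevelZeroPlugin(stage=1).support_no_sync()
assert not LowLevelZeroPlugin(stage=2).support_no_sync()
```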
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/plugin_base.py
@@ -61,7 +61,7 @@ def get_checkpoint_io(self) -> CheckpointIO:
pass

@abstractmethod
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -168,6 +168,6 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
return model.module.no_sync()
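For reference, the DDP branch simply defers to PyTorch's own `DistributedDataParallel.no_sync()`; a self-contained sketch of that underlying behavior (hypothetical one-layer model; an initialized process group is assumed):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed is already initialized (e.g. via colossalai.launch).
ddp_model = DDP(nn.Linear(8, 8).cuda())
data = torch.randn(4, 8).cuda()

with ddp_model.no_sync():              # gradients accumulate locally, no all-reduce
    ddp_model(data).sum().backward()
ddp_model(data).sum().backward()       # first backward outside the context syncs
```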
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -177,7 +177,7 @@ def __init__(
    def support_no_sync(self) -> bool:
        return False

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

def control_precision(self) -> bool:
29 changes: 12 additions & 17 deletions colossalai/zero/low_level/low_level_optim.py
@@ -14,10 +14,10 @@
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device

from ._utils import (
@@ -56,7 +56,7 @@ def check_local_overflow(self) -> bool:
return False


-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""

@@ -77,11 +77,12 @@ def __init__(
                 overlap_communication: bool = False,
                 partition_grad: bool = False,    # stage 2 flag
                 cpu_offload: bool = False,    # cpu offload
-                grad_accumulate_interval: int = 1,
                 forced_dtype: Optional[torch.dtype] = None):

-        assert not (partition_grad and grad_accumulate_interval > 1), \
-            "gradient accumulation is not compatible with ZeRO-2"
+        # TODO:
+        # 1. process group api
+        # 2. checkpoint IO
+
        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ def __init__(

        # grad accumulation
        self.require_grad_sync = True
-        self._accumulate_intervel = grad_accumulate_interval
-        self._accumulate_step = 0

colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ def _add_to_bucket(self, param, group_id):
################################

    def backward(self, loss, retain_graph=False):
+        assert not (self._partition_grads and not self.require_grad_sync), \
+            "ZeRO-2 (partition_grads) and gradient accumulation (no_sync) are not compatible"
+
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

-        self._accumulate_step += 1
-        no_sync = self._accumulate_step < self._accumulate_intervel
-        with conditional_context(self.no_sync(), enable=no_sync):
-            loss.backward(retain_graph=retain_graph)
+        loss.backward(retain_graph=retain_graph)

-        if no_sync:
+        if not self.require_grad_sync:
            return

self._reduce_grad(self._partition_grads)
@@ -385,15 +384,14 @@ def zero_grad(self, set_to_none=True):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

-        if not self._accumulate_step == self._accumulate_intervel:
+        if not self.require_grad_sync:
            return

        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
            self.zero_grad()
-            self._accumulate_step -= 1
            return

# record all grads for unscale and clip
@@ -463,9 +461,6 @@ def step(self, closure=None):

        self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]

-        # reset accumulate step
-        self._accumulate_step = 0
-
#############################
# Mixed Precision Utilities #
#############################
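With the interval bookkeeping removed from the optimizer, accumulation is now driven from the training loop through `no_sync()`, exactly as the updated tests below do via `conditional_context`. A minimal sketch of the new pattern (assuming a ZeRO-1 `LowLevelZeroOptimizer` named `zero_optimizer` over a model `zero_model`; `batches` and `accum_interval` are hypothetical stand-ins):

```python
from colossalai.utils import conditional_context

accum_interval = 2    # hypothetical: micro-batches per optimizer step

for idx, batch in enumerate(batches):
    is_sync_step = (idx + 1) % accum_interval == 0
    # Inside no_sync(), require_grad_sync is False, so backward() skips the
    # gradient reduction and step() would return early without an update.
    with conditional_context(zero_optimizer.no_sync(), enable=not is_sync_step):
        loss = zero_model(batch).sum()
        zero_optimizer.backward(loss)
    if is_sync_step:
        zero_optimizer.step()
        zero_optimizer.zero_grad()
```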
35 changes: 15 additions & 20 deletions tests/test_zero/test_low_level/test_grad_acc.py
@@ -9,6 +9,7 @@
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer


@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def fwd_bwd_func(number, cur_data, check_flag):
assert torch.equal(zero1_output, zero2_output)

        # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())

if check_flag:
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=False,
                                           reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)

torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

@@ -112,28 +113,22 @@
input_data2 = torch.randn(32, 128).cuda()

    def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
-
-        # torch-ddp forward
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())

-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
            torch_output = torch_model(cur_data)
            assert torch.equal(zero_output, torch_output)
            torch_output.sum().backward()

if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, z1p.grad)

fwd_bwd_func(0, input_data1, True)