5 changes: 5 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -369,6 +369,11 @@ def __init__(
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
if placement_policy == "auto" and enable_async_reduce:
    logging.warning(
        "enable_async_reduce requires pin_memory to achieve best performance, so pin_memory is implicitly set to True."
    )
    pin_memory = True
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
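A side note on why this plugin change forces pin_memory = True: in PyTorch, a host-to-device copy issued with non_blocking=True can only overlap with other GPU work when the source CPU tensor lives in pinned (page-locked) memory; with ordinary pageable memory the copy effectively degrades to a synchronous transfer. A minimal standalone sketch in plain PyTorch (not ColossalAI code; tensor names are illustrative):

import torch

if torch.cuda.is_available():
    cpu_pageable = torch.randn(1 << 20)             # ordinary pageable host memory
    cpu_pinned = torch.randn(1 << 20).pin_memory()  # page-locked host memory

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # Truly asynchronous: the source is pinned, so the DMA engine can run
        # the transfer while the host and other streams keep working.
        gpu_a = cpu_pinned.to("cuda", non_blocking=True)
        # Effectively synchronous: with pageable memory, non_blocking=True
        # does not buy any overlap.
        gpu_b = cpu_pageable.to("cuda", non_blocking=True)

    # The caller must synchronize before relying on asynchronously copied data,
    # mirroring the contract documented for Chunk.shard_move below.
    copy_stream.synchronize()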
15 changes: 8 additions & 7 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -316,20 +316,21 @@ def close_chunk(self):
if self.shard_device.type == "cpu":
self.cuda_shard = None

def shard_move(self, device: torch.device, force_copy: bool = False):
def shard_move(self, device: torch.device, force_copy: bool = False, non_blocking=False):
"""Move the shard tensor in the chunk.

Args:
device: the device to which the shard will move
force_copy: if True, copy function is called mandatorily
non_blocking: if True, the movement is performed asynchronously; the caller is responsible for synchronization
"""
# sanity check
assert not self.is_gathered
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.__paired_shard_move(non_blocking=non_blocking)
self.optim_sync_flag = True
return

@@ -339,7 +340,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
if self.cuda_shard:
return

self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)

if not self.pin_memory:
self.cpu_shard = None
@@ -349,11 +350,11 @@ def shard_move(self, device: torch.device, force_copy: bool = False):

if self.pin_memory:
if force_copy or not self.cpu_vis_flag:
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_shard.copy_(self.cuda_shard, non_blocking=non_blocking)
# if cpu_shard has already been visited,
# the copy operation is not needed
else:
self.cpu_shard = self.cuda_shard.cpu()
self.cpu_shard = self.cuda_shard.to("cpu", non_blocking=non_blocking)
self.cpu_vis_flag = True
self.cuda_shard = None
else:
@@ -542,15 +543,15 @@ def __scatter(self):
free_storage(self.cuda_global_chunk)
self.is_gathered = False

def __paired_shard_move(self):
def __paired_shard_move(self, non_blocking=False):
assert self.paired_chunk is not None, "chunks should be paired before training"
optim_chunk = self.paired_chunk
assert self.chunk_size == optim_chunk.chunk_size

# only called when the optimizer state is in CPU memory
# the grad and param should be on the same device
assert self.cuda_shard is None
temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
# cast to self.dtype on the GPU to avoid doing the FP32 conversion on the CPU
self.cuda_shard = temp.to(self.dtype)

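The new non_blocking path in shard_move pushes synchronization onto the caller: a device-to-host copy issued with non_blocking=True returns before the CPU tensor holds valid data. A minimal standalone sketch of that contract in plain PyTorch (the helper name offload_to_cpu is illustrative, not part of the Chunk API):

import torch

def offload_to_cpu(gpu_tensor: torch.Tensor) -> torch.Tensor:
    # Copy into pre-allocated pinned host memory so the D2H transfer can
    # actually proceed asynchronously.
    cpu_buf = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
    cpu_buf.copy_(gpu_tensor, non_blocking=True)
    # The copy may still be in flight here; the caller must synchronize before
    # reading cpu_buf, just as with shard_move(..., non_blocking=True).
    return cpu_buf

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    y = offload_to_cpu(x)
    torch.cuda.current_stream().synchronize()  # wait for the async copy
    assert torch.allclose(y, x.cpu())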
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/chunk/manager.py
@@ -117,7 +117,7 @@ def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
return None
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())
chunk.shard_move(get_accelerator().get_current_device(), non_blocking=async_access)
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
self.__add_memory_usage(chunk.memory_usage)
return maybe_work
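For context on the Optional[dist.Work] that access_chunk returns: when async_access=True, the gather is launched with async_op=True and the caller is expected to wait on the returned handle before using the chunk. A hedged sketch of that caller-side pattern with plain torch.distributed (the helper below is illustrative, not the ChunkManager API, and assumes an already initialized process group):

import torch
import torch.distributed as dist

def gather_shard_async(shard: torch.Tensor) -> tuple[torch.Tensor, dist.Work]:
    # Launch a non-blocking all-gather and hand the Work handle back to the
    # caller, mirroring the maybe_work returned by ChunkManager.access_chunk.
    world_size = dist.get_world_size()
    out = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    work = dist.all_gather_into_tensor(out, shard, async_op=True)
    return out, work

# Usage inside an initialized process group:
#   out, work = gather_shard_async(my_shard)
#   ... overlap other work here ...
#   work.wait()  # synchronize before reading `out`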
94 changes: 39 additions & 55 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -147,6 +147,12 @@ def __init__(
self.extra_dp_group = extra_dp_group

self.master_weights = master_weights
self.enable_async_reduce = enable_async_reduce

if enable_async_reduce:
self.async_reduce_stream = torch.cuda.Stream()
else:
self.async_reduce_stream = None

self._logger = get_dist_logger()

@@ -176,6 +182,7 @@ def __init__(
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()

# register grad hook
for p in module.parameters():
if is_ddp_ignored(p):
Expand All @@ -191,7 +198,7 @@ def __init__(
master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p,
async_reduce=enable_async_reduce,
async_reduce_stream=self.async_reduce_stream,
)
)

@@ -339,10 +346,8 @@ def _pre_backward(self):
setattr(param, "_gemini_reduced", False)

def _post_backward(self):
for param in self.param2name:
if hasattr(param, "_release_grad_chunk_cb"):
param._release_grad_chunk_cb()
delattr(param, "_release_grad_chunk_cb")
if self.enable_async_reduce:
self.async_reduce_stream.synchronize()

if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
@@ -381,7 +386,7 @@ def grad_handle(
master_weights: bool,
enable_gradient_accumulation: bool,
p: nn.Parameter,
async_reduce: bool,
async_reduce_stream: Optional[torch.cuda.Stream] = None,
):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
@@ -417,56 +422,35 @@ def grad_handle(
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
if reduced:  # if not async, the chunk can be released immediately; otherwise it is released when the reduce work finishes
if async_reduce:
# dirty fix by installing callback
assert not hasattr(p, "_release_grad_chunk_cb")

def _release_grad_chunk_cb():
grad_chunk.wait_async_reduce()
GeminiDDP.release_grad_chunk_handle(
chunk_manager,
grads_device,
master_weights,
enable_gradient_accumulation,
p,
chunk,
grad_chunk,
)

p._release_grad_chunk_cb = _release_grad_chunk_cb
else:
GeminiDDP.release_grad_chunk_handle(
chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
)
return empty_grad

@staticmethod
def release_grad_chunk_handle(
chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
):
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)
else:
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
if async_reduce_stream is not None:
async_reduce_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(async_reduce_stream):
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
if reduced:
grad_chunk.wait_async_reduce()
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)
else:
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)

def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
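For readers less familiar with the stream handshake introduced above: async_reduce_stream.wait_stream(torch.cuda.current_stream()) makes the side stream wait until the gradient has actually been produced, the reduction is then queued on that side stream so it can overlap with the rest of backward, and _post_backward calls synchronize() so nothing downstream reads gradients that are still in flight. A self-contained sketch of the same pattern (names are illustrative, not the GeminiDDP API):

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()        # stand-in for self.async_reduce_stream
    grad = torch.randn(4096, device="cuda")  # stand-in for a produced gradient

    # 1. The side stream waits for everything already queued on the compute
    #    stream, so it observes the fully written gradient.
    side_stream.wait_stream(torch.cuda.current_stream())

    # 2. The "reduction" (here just a scaling) is queued on the side stream and
    #    can overlap with subsequent backward computation on the default stream.
    with torch.cuda.stream(side_stream):
        grad.div_(8)  # stand-in for chunk_manager.reduce_chunk(..., async_op=True)
        grad.record_stream(side_stream)  # tell the allocator the tensor is used on the side stream

    # 3. Before the optimizer consumes the gradients (the role of _post_backward
    #    above), drain the side stream.
    side_stream.synchronize()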