From b5ae587d50f0f331f6ca4330e9bc82a2e01f83bb Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 28 May 2024 14:23:22 +0000
Subject: [PATCH 1/7] [gemini] optimize reduce scatter d2h copy

---
 colossalai/booster/plugin/gemini_plugin.py |  5 ++
 colossalai/zero/gemini/chunk/chunk.py      |  6 +-
 colossalai/zero/gemini/gemini_ddp.py       | 90 +++++++++-------------
 3 files changed, 43 insertions(+), 58 deletions(-)

diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index eb8db6212835..10406f009eeb 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -368,6 +368,11 @@ def __init__(
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        if placement_policy == "auto" and enable_async_reduce:
+            logging.warning(
+                "enable_async_reduce requires pin_memory to achieve best performance; implicitly setting pin_memory=True."
+            )
+            pin_memory = True
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 8f048f0b7183..ed4566fe0981 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -339,7 +339,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -349,7 +349,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
 
             if self.pin_memory:
                 if force_copy or not self.cpu_vis_flag:
-                    self.cpu_shard.copy_(self.cuda_shard)
+                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=True)
                 # if cpu_shard has been visited
                 # copy operation is not needed
             else:
@@ -547,7 +547,7 @@ def __paired_shard_move(self):
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)
         # avoid transforming FP32 on CPU
         self.cuda_shard = temp.to(self.dtype)
 
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 23f6ee683657..01f5087246cc 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -145,6 +145,12 @@ def __init__(
         self.extra_dp_group = extra_dp_group
         self.master_weights = master_weights
 
+        self.enable_async_reduce = enable_async_reduce
+
+        if enable_async_reduce:
+            self.async_reduce_stream = torch.cuda.Stream()
+        else:
+            self.async_reduce_stream = None
 
         self._logger = get_dist_logger()
 
@@ -174,6 +180,7 @@ def __init__(
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
+
         # register grad hook
         for p in module.parameters():
             if is_ddp_ignored(p):
@@ -189,7 +196,7 @@ def __init__(
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
-                        async_reduce=enable_async_reduce,
+                        async_reduce_stream=self.async_reduce_stream,
                     )
                 )
 
@@ -337,10 +344,8 @@ def _pre_backward(self):
                setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
-        for param in self.param2name:
-            if hasattr(param, "_release_grad_chunk_cb"):
-                param._release_grad_chunk_cb()
-                delattr(param, "_release_grad_chunk_cb")
+        if self.enable_async_reduce:
+            self.async_reduce_stream.synchronize()
 
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at the following parameters:"]
@@ -379,7 +384,7 @@ def grad_handle(
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce: bool,
+        async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -415,56 +420,31 @@ def grad_handle(
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
-        if reduced:  # if not async, can release immediately, else release in when work finished
-            if async_reduce:
-                # dirty fix by installing callback
-                assert not hasattr(p, "_release_grad_chunk_cb")
-
-                def _release_grad_chunk_cb():
-                    grad_chunk.wait_async_reduce()
-                    GeminiDDP.release_grad_chunk_handle(
-                        chunk_manager,
-                        grads_device,
-                        master_weights,
-                        enable_gradient_accumulation,
-                        p,
-                        chunk,
-                        grad_chunk,
-                    )
-
-                p._release_grad_chunk_cb = _release_grad_chunk_cb
-            else:
-                GeminiDDP.release_grad_chunk_handle(
-                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-                )
         return empty_grad
 
-    @staticmethod
-    def release_grad_chunk_handle(
-        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-    ):
-        if not chunk_manager.reuse_fp16_chunk:
-            if chunk.keep_gathered:
-                chunk_manager.fake_release_chunk(chunk)
-            else:
-                chunk_manager.release_chunk(chunk)
-        if grad_chunk.is_gathered:
-            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-        else:
-            grad_chunk.cuda_shard.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-        # check overflow elements
-        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-        if chunk.l2_norm_flag:
-            grad_chunk.set_l2_norm()
-        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-        if not (master_weights) or (enable_gradient_accumulation):
-            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        with torch.cuda.stream(async_reduce_stream):
+            chunk_manager.reduce_chunk(grad_chunk)
+
+            if not chunk_manager.reuse_fp16_chunk:
+                if chunk.keep_gathered:
+                    chunk_manager.fake_release_chunk(chunk)
+                else:
+                    chunk_manager.release_chunk(chunk)
+            if grad_chunk.is_gathered:
+                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                if chunk.extra_dp_group is not None:
+                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+            else:
+                grad_chunk.cuda_shard.div_(chunk.pg_size)
+                if chunk.extra_dp_group is not None:
+                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+            # check overflow elements
+            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+            if chunk.l2_norm_flag:
+                grad_chunk.set_l2_norm()
+            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+            if not (master_weights) or (enable_gradient_accumulation):
+                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

From fee35678e549ff27a32ffbb2c7fc871d2ff81925 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 02:09:14 +0000
Subject: [PATCH 2/7] [fix] fix missing reduce variable

---
 colossalai/zero/gemini/gemini_ddp.py | 44 ++++++++++++++--------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 01f5087246cc..4e301201061a 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -422,29 +422,29 @@ def grad_handle(
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
 
         with torch.cuda.stream(async_reduce_stream):
-            chunk_manager.reduce_chunk(grad_chunk)
-
-            if not chunk_manager.reuse_fp16_chunk:
-                if chunk.keep_gathered:
-                    chunk_manager.fake_release_chunk(chunk)
+            reduced = chunk_manager.reduce_chunk(grad_chunk)
+            if reduced:
+                if not chunk_manager.reuse_fp16_chunk:
+                    if chunk.keep_gathered:
+                        chunk_manager.fake_release_chunk(chunk)
+                    else:
+                        chunk_manager.release_chunk(chunk)
+                if grad_chunk.is_gathered:
+                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
                 else:
-                    chunk_manager.release_chunk(chunk)
-            if grad_chunk.is_gathered:
-                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-            else:
-                grad_chunk.cuda_shard.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-            # check overflow elements
-            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-            if chunk.l2_norm_flag:
-                grad_chunk.set_l2_norm()
-            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-            if not (master_weights) or (enable_gradient_accumulation):
-                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    grad_chunk.cuda_shard.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+                # check overflow elements
+                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+                if chunk.l2_norm_flag:
+                    grad_chunk.set_l2_norm()
+                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                if not (master_weights) or (enable_gradient_accumulation):
+                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

From 58ad76d4665032bbe548d066116d1c572ce98979 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 02:22:04 +0000
Subject: [PATCH 3/7] [refactor] remove legacy async reduce scatter code

---
 colossalai/zero/gemini/chunk/chunk.py       | 29 +++++----------------
 colossalai/zero/gemini/chunk/manager.py     |  4 +--
 tests/test_zero/test_gemini/test_chunkv2.py |  8 ++----
 3 files changed, 10 insertions(+), 31 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index ed4566fe0981..6e85335558b2 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -164,8 +164,6 @@ def __init__(
         self.l2_norm = None
 
         self.grad_chunk = None
-        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
-        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -376,49 +374,34 @@ def release_chunk(self):
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self, async_op: bool = False):
+    def reduce(self):
        """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
-            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
             if self.extra_dp_group is not None:
                 # cannot guarantee the order of multiple all-reduce
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(
-                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
-                )
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
-
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
-            )
-
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
             if self.extra_dp_group is not None:
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
-
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
-    def wait_async_reduce(self) -> None:
-        if self.grad_reduce_work is not None:
-            self.grad_reduce_work.wait()
-            self.grad_reduce_work = None
-
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 6ec595914f37..5ad83d20f186 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -143,12 +143,12 @@ def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce(async_op=async_op)
+        chunk.reduce()
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index 51b20c400c1c..25731132887b 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -34,8 +34,7 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-@parameterize("async_op", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -95,12 +94,9 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce(async_op)
+    my_chunk.reduce()
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
 
-    if async_op:
-        my_chunk.wait_async_reduce()
-
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"

From cd736c81f3d9178d5bdf343cdf88c4c51cabafc6 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 03:20:54 +0000
Subject: [PATCH 4/7] [gemini] missing sync

---
 colossalai/zero/gemini/gemini_ddp.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 4e301201061a..347ce9a8482f 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -421,6 +421,9 @@ def grad_handle(
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
 
+        if async_reduce_stream is not None:
+            async_reduce_stream.wait_stream(torch.cuda.current_stream())
+
         with torch.cuda.stream(async_reduce_stream):
             reduced = chunk_manager.reduce_chunk(grad_chunk)
             if reduced:

From f454bd960e19ae14c0f8648ed2ba7fc45de4d993 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 05:39:50 +0000
Subject: [PATCH 5/7] Revert "[refactor] remove legacy async reduce scatter code"

This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979.
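The async handle plumbing (grad_reduce_work / wait_async_reduce) is needed
again: the follow-up patch re-launches the reduce as an async op on the
dedicated stream and waits on the returned work handle (see PATCH 6/7).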
---
 colossalai/zero/gemini/chunk/chunk.py       | 29 ++++++++++++++++-----
 colossalai/zero/gemini/chunk/manager.py     |  4 +--
 tests/test_zero/test_gemini/test_chunkv2.py |  8 ++++--
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 6e85335558b2..ed4566fe0981 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -164,6 +164,8 @@ def __init__(
         self.l2_norm = None
 
         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -374,34 +376,49 @@ def release_chunk(self):
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
             if self.extra_dp_group is not None:
                 # cannot guarantee the order of multiple all-reduce
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
+
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )
+
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 5ad83d20f186..6ec595914f37 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -143,12 +143,12 @@ def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index 25731132887b..51b20c400c1c 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -34,7 +34,8 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory):
+@parameterize("async_op", [True, False])
+def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce()
+    my_chunk.reduce(async_op)
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
 
+    if async_op:
+        my_chunk.wait_async_reduce()
+
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"

From 77e7ce7253d1119c0c744c1bf3b7f7956d9df384 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 05:43:27 +0000
Subject: [PATCH 6/7] [gemini] further optimize with async all reduce

---
 colossalai/zero/gemini/gemini_ddp.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 347ce9a8482f..b9ea22de8d86 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -425,8 +425,9 @@ def grad_handle(
             async_reduce_stream.wait_stream(torch.cuda.current_stream())
 
         with torch.cuda.stream(async_reduce_stream):
-            reduced = chunk_manager.reduce_chunk(grad_chunk)
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
             if reduced:
+                grad_chunk.wait_async_reduce()
                 if not chunk_manager.reuse_fp16_chunk:
                     if chunk.keep_gathered:
                         chunk_manager.fake_release_chunk(chunk)

From a53d8526539c724f89e426b5181d446f864b820e Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 4 Jun 2024 10:08:07 +0000
Subject: [PATCH 7/7] [fix] pass flag from manager to chunk

---
 colossalai/zero/gemini/chunk/chunk.py   | 15 ++++++++-------
 colossalai/zero/gemini/chunk/manager.py |  2 +-
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index bfd6b81d5ab4..18fbf8fc31fa 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -316,12 +316,13 @@ def close_chunk(self):
         if self.shard_device.type == "cpu":
             self.cuda_shard = None
 
-    def shard_move(self, device: torch.device, force_copy: bool = False):
+    def shard_move(self, device: torch.device, force_copy: bool = False, non_blocking: bool = False):
         """Move the shard tensor in the chunk.
 
         Args:
             device: the device to which the shard will move
             force_copy: if True, copy function is called mandatorily
+            non_blocking: if True, the operation is non-blocking; the caller is responsible for synchronization
         """
         # sanity check
         assert not self.is_gathered
         # when the current chunk is not synchronized with the optimizer
         # just use another way for the movement
         if not self.optim_sync_flag:
             assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
-            self.__paired_shard_move()
+            self.__paired_shard_move(non_blocking=non_blocking)
             self.optim_sync_flag = True
             return
@@ -339,7 +340,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -349,11 +350,11 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
 
             if self.pin_memory:
                 if force_copy or not self.cpu_vis_flag:
-                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=True)
+                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=non_blocking)
                 # if cpu_shard has been visited
                 # copy operation is not needed
             else:
-                self.cpu_shard = self.cuda_shard.cpu()
+                self.cpu_shard = self.cuda_shard.to("cpu", non_blocking=non_blocking)
                 self.cpu_vis_flag = True
             self.cuda_shard = None
         else:
@@ -542,7 +543,7 @@ def __scatter(self):
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
 
-    def __paired_shard_move(self):
+    def __paired_shard_move(self, non_blocking: bool = False):
         assert self.paired_chunk is not None, "chunks should be paired before training"
         optim_chunk = self.paired_chunk
         assert self.chunk_size == optim_chunk.chunk_size
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=True)
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
         # avoid transforming FP32 on CPU
         self.cuda_shard = temp.to(self.dtype)
 
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 36e7ee57bad4..45066ca898ef 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -117,7 +117,7 @@ def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
             return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_accelerator().get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device(), non_blocking=async_access)
         maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
         return maybe_work
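
Taken together, the series lands on this flow: the backward hook writes the
gradient into the chunk on the default stream; a dedicated reduce stream
waits on the default stream, launches the reduce-scatter/all-reduce as an
async op, and scales/releases the chunk there; _post_backward() synchronizes
the reduce stream once per step. Below is a minimal, standalone sketch of
that pattern. It is hypothetical: reduce_stream, grad_buffer, shard, and
reduce_grad_async are illustrative names, not ColossalAI APIs, and it
assumes an initialized NCCL process group.

    import torch
    import torch.distributed as dist

    reduce_stream = torch.cuda.Stream()

    def reduce_grad_async(grad_buffer, shard, pg):
        # Make the side stream wait for the producer stream so the gradient
        # is fully written before the collective reads it (cf. PATCH 4).
        reduce_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(reduce_stream):
            # Launch without blocking the host; each rank receives one
            # shard of the summed gradient (cf. PATCH 6/7).
            input_list = list(torch.chunk(grad_buffer, chunks=pg.size(), dim=0))
            work = dist.reduce_scatter(shard, input_list, group=pg, async_op=True)
        return work

    # Once per step, at the end of backward (cf. _post_backward):
    #     work.wait()
    #     reduce_stream.synchronize()

Note that the d2h side of the series (the non_blocking copies in shard_move)
only overlaps with compute when the CPU buffer is pinned, which is why
PATCH 1 forces pin_memory=True under the auto placement policy.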