From bd870c426c7bcd18c19526ab7c5c28e20bdc9251 Mon Sep 17 00:00:00 2001
From: BurkeHulk
Date: Thu, 14 Nov 2024 13:43:12 +0800
Subject: [PATCH 1/2] remove redundant memcpy during backward

---
 colossalai/zero/low_level/bookkeeping/bucket_store.py | 6 +++---
 colossalai/zero/low_level/low_level_optim.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 19d20de2b250..6729d4615f20 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -78,13 +78,13 @@ def build_grad_in_bucket(self):
         }
         """
         for param, padding_size in zip(self._param_list, self._padding_size):
-            grad = param.grad.clone().detach().flatten()
+            grad = param.grad.detach().flatten()
             if padding_size > 0:
                 with torch.no_grad():
                     grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
             grad_list = grad.split(grad.numel() // self._world_size)
             for rank in range(self._world_size):
-                grad_current_rank = grad_list[rank].clone().detach()
+                grad_current_rank = grad_list[rank].detach()
                 self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
                 self._grad_in_bucket[rank].append(grad_current_rank)
             param.grad = None
@@ -110,7 +110,7 @@ def get_flatten_grad(self) -> Tensor:
 
         flat_grad = []
         for grad_list in self._grad_in_bucket.values():
-            flat_grad.append(_flatten_dense_tensors(grad_list))
+            flat_grad.extend(grad_list)
         flat_grad = _flatten_dense_tensors(flat_grad)
         return flat_grad
 
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 26fff75fbfdf..93f4deec36ff 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -339,7 +339,7 @@ def _run_reduction(self):
                 if self._overlap_communication:
                     stream = bucket_store.comm_stream
                     # in case of the memory being reused in the default stream
-                    flat_grads.record_stream(stream)
+                    # flat_grads.record_stream(stream)
                     # waiting for ops in the default stream finishing
                     stream.wait_stream(get_accelerator().current_stream())
                 else:

From 3c076664185d140e9d99e454ddbeb65a851c25bf Mon Sep 17 00:00:00 2001
From: BurkeHulk
Date: Thu, 14 Nov 2024 13:55:06 +0800
Subject: [PATCH 2/2] get back record_stream

---
 colossalai/zero/low_level/low_level_optim.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 93f4deec36ff..26fff75fbfdf 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -339,7 +339,7 @@ def _run_reduction(self):
                 if self._overlap_communication:
                     stream = bucket_store.comm_stream
                     # in case of the memory being reused in the default stream
-                    # flat_grads.record_stream(stream)
+                    flat_grads.record_stream(stream)
                     # waiting for ops in the default stream finishing
                     stream.wait_stream(get_accelerator().current_stream())
                 else:
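
Not part of the patch: a minimal, self-contained sketch of why the bucket_store.py changes
above remove copies. The tensors, shapes, and the world_size value are made up for
illustration; the only assumed dependencies are torch and torch._utils._flatten_dense_tensors,
the same helper the patched code already uses.

    import torch
    from torch._utils import _flatten_dense_tensors

    world_size = 4                       # hypothetical number of ranks
    grad = torch.randn(8, 4)             # stand-in for param.grad

    # 1) .clone().detach() materializes a fresh buffer (one memcpy per parameter);
    #    .detach() alone returns a view that shares storage with the gradient.
    flat = grad.detach().flatten()
    assert flat.data_ptr() == grad.data_ptr()                    # no copy made
    assert grad.clone().detach().data_ptr() != grad.data_ptr()   # clone copies

    # 2) split() also returns views, so the per-rank slices still alias `grad`.
    shards = flat.split(flat.numel() // world_size)
    assert shards[0].data_ptr() == flat.data_ptr()

    # 3) get_flatten_grad(): flattening once instead of per rank saves one pass.
    per_rank = {r: [torch.randn(16), torch.randn(16)] for r in range(world_size)}

    # old path: copy each rank's list into a buffer, then copy those buffers again
    two_copies = _flatten_dense_tensors(
        [_flatten_dense_tensors(grads) for grads in per_rank.values()]
    )
    # new path: collect every slice into one list and flatten exactly once
    one_copy = _flatten_dense_tensors([g for grads in per_rank.values() for g in grads])

    assert torch.equal(two_copies, one_copy)   # identical layout, one fewer memcpy

On the second commit: when the reduction is queued on a separate communication stream,
flat_grads.record_stream(stream) tells PyTorch's caching allocator not to reuse the tensor's
memory until the work recorded on that stream has completed, which is why dropping the call
in the first commit was reverted.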