From 40500a382ef087f362626102e26bcb98aff8c1f2 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Tue, 31 Oct 2023 16:19:03 -0700
Subject: [PATCH 1/3] Debug distopt contiguous param buffers with uint8 param
 all-gathers

Signed-off-by: Tim Moon
---
 .../optimizers/distributed_fused_adam.py | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 565bdc1fe..b8581cff0 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -2211,7 +2211,17 @@ def step(
             state_bucket = self.state["buckets"][bucket_id]
             shard_size = state_bucket.shard_size
             param_sync_dtype = state_bucket.param_sync_dtype
-            if self.contiguous_param_buffer:
+            if not param_sync_dtype.is_floating_point:
+                # Allocate temporary buffer for param shard
+                # Note: Adam kernel only supports floating-point
+                # dtypes.
+                params_bucket.params_shard = torch.empty(
+                    [shard_size],
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+                overlap_first_bucket = False
+            elif self.contiguous_param_buffer:
                 # Construct view into contiguous param buffer
                 if not self._param_buffers:
                     self.init_param_buffer()
@@ -2225,16 +2235,6 @@ def step(
                 params_bucket.params_shard = params_bucket.params_bucket[
                     bucket_start:bucket_end
                 ]
-            elif not param_sync_dtype.is_floating_point:
-                # Allocate temporary buffer for param shard
-                # Note: Adam kernel only supports floating-point
-                # dtypes.
-                params_bucket.params_shard = torch.empty(
-                    [shard_size],
-                    dtype=self.dtype,
-                    device=self.device,
-                )
-                overlap_first_bucket = False
             else:
                 # Allocate param shard buffer
                 params_bucket.params_shard = torch.empty(

From 1696bd46be60c06cb6b6c40c25c5a85657e59dfe Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Fri, 17 Nov 2023 22:55:21 +0000
Subject: [PATCH 2/3] Add test

Signed-off-by: Tim Moon
---
 apex/contrib/optimizers/distributed_fused_adam.py | 2 +-
 apex/contrib/test/optimizers/test_dist_adam.py    | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index b8581cff0..168f3b69b 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -1007,7 +1007,7 @@ def init_param_buffer(self) -> None:
             if param_buffer_view.dtype != param.dtype:
                 raise RuntimeError(
                     f"Attempted to change a parameter with dtype={param.dtype} "
-                    f"into a buffer view with dtype={param_view_buffer.dtype}"
+                    f"into a buffer view with dtype={param_buffer_view.dtype}"
                 )
             param_flat_views.append(param.detach().view(-1))
             param_buffer_views.append(param_buffer_view)
diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py
index 298e5ec16..d41016d32 100644
--- a/apex/contrib/test/optimizers/test_dist_adam.py
+++ b/apex/contrib/test/optimizers/test_dist_adam.py
@@ -296,6 +296,12 @@ def test_matches_pytorch_int64_param_sync(self):
             param_sync_dtype=torch.int64,
         )
 
+    def test_matches_pytorch_int64_param_sync_contiguous_buffers(self):
+        self.test_matches_pytorch(
+            param_sync_dtype=torch.int64,
+            contiguous_buffers=True,
+        )
+
     def test_matches_pytorch_uint8_param_sync(self):
         self.test_matches_pytorch(
             rtol=0.5,

From 42d5d8defc864e20f6f62bfa1655fd1ad063ec85 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Sun, 19 Nov 2023 03:58:36 +0000
Subject: [PATCH 3/3] Avoid temporary buffer for param shard in optim step if
 possible

Signed-off-by: Tim Moon
---
 .../optimizers/distributed_fused_adam.py | 37 ++++++++++++-------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 168f3b69b..bc7b9f364 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -2210,17 +2210,22 @@ def step(
             params_bucket = self.ParameterBucket()
             state_bucket = self.state["buckets"][bucket_id]
             shard_size = state_bucket.shard_size
+            dtype = state_bucket.dtype
             param_sync_dtype = state_bucket.param_sync_dtype
             if not param_sync_dtype.is_floating_point:
-                # Allocate temporary buffer for param shard
-                # Note: Adam kernel only supports floating-point
-                # dtypes.
-                params_bucket.params_shard = torch.empty(
-                    [shard_size],
-                    dtype=self.dtype,
-                    device=self.device,
-                )
+                # Make sure param shard buffer is floating-point
                 overlap_first_bucket = False
+                if (
+                    state_bucket.params_shard is not None
+                    and dtype.is_floating_point
+                ):
+                    params_bucket.params_shard = state_bucket.params_shard
+                else:
+                    params_bucket.params_shard = torch.empty(
+                        [shard_size],
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
             elif self.contiguous_param_buffer:
                 # Construct view into contiguous param buffer
                 if not self._param_buffers:
@@ -2237,11 +2242,17 @@ def step(
                 ]
             else:
                 # Allocate param shard buffer
-                params_bucket.params_shard = torch.empty(
-                    [shard_size],
-                    dtype=param_sync_dtype,
-                    device=self.device,
-                )
+                if (
+                    state_bucket.params_shard is not None
+                    and dtype == param_sync_dtype
+                ):
+                    params_bucket.params_shard = state_bucket.params_shard
+                else:
+                    params_bucket.params_shard = torch.empty(
+                        [shard_size],
+                        dtype=param_sync_dtype,
+                        device=self.device,
+                    )
             self._params_buckets[bucket_id] = params_bucket
 
         # Apply optimizer step
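
A minimal usage sketch of the paths this series touches: constructing DistributedFusedAdam with a non-floating-point param_sync_dtype while the contiguous param buffer is enabled. The single-rank NCCL setup, the toy Linear model, and the store_params choice below are illustrative assumptions rather than anything the patches add; the constructor arguments shown follow the optimizer's existing signature.

    import os
    import torch
    import torch.distributed as dist
    from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

    # Illustrative single-rank setup; a real run would use torchrun/srun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = torch.nn.Linear(32, 32).cuda()
    optim = DistributedFusedAdam(
        model.parameters(),
        lr=1e-3,
        param_sync_dtype=torch.uint8,   # non-floating-point param all-gathers
        contiguous_param_buffer=True,   # contiguous param buffer path from patch 1/3
        store_params=True,              # keep a floating-point param shard in optimizer state
    )

    # One training step: grad shards are reduced, the local shard is updated,
    # and updated params are all-gathered in the uint8 sync dtype.
    loss = model(torch.randn(8, 32, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()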