diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 565bdc1fe..bc7b9f364 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -1007,7 +1007,7 @@ def init_param_buffer(self) -> None:
             if param_buffer_view.dtype != param.dtype:
                 raise RuntimeError(
                     f"Attempted to change a parameter with dtype={param.dtype} "
-                    f"into a buffer view with dtype={param_view_buffer.dtype}"
+                    f"into a buffer view with dtype={param_buffer_view.dtype}"
                 )
             param_flat_views.append(param.detach().view(-1))
             param_buffer_views.append(param_buffer_view)
@@ -2210,8 +2210,23 @@ def step(
             params_bucket = self.ParameterBucket()
             state_bucket = self.state["buckets"][bucket_id]
             shard_size = state_bucket.shard_size
+            dtype = state_bucket.dtype
             param_sync_dtype = state_bucket.param_sync_dtype
-            if self.contiguous_param_buffer:
+            if not param_sync_dtype.is_floating_point:
+                # Make sure param shard buffer is floating-point
+                overlap_first_bucket = False
+                if (
+                    state_bucket.params_shard is not None
+                    and dtype.is_floating_point
+                ):
+                    params_bucket.params_shard = state_bucket.params_shard
+                else:
+                    params_bucket.params_shard = torch.empty(
+                        [shard_size],
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
+            elif self.contiguous_param_buffer:
                 # Construct view into contiguous param buffer
                 if not self._param_buffers:
                     self.init_param_buffer()
@@ -2225,23 +2240,19 @@ def step(
                 params_bucket.params_shard = params_bucket.params_bucket[
                     bucket_start:bucket_end
                 ]
-            elif not param_sync_dtype.is_floating_point:
-                # Allocate temporary buffer for param shard
-                # Note: Adam kernel only supports floating-point
-                # dtypes.
-                params_bucket.params_shard = torch.empty(
-                    [shard_size],
-                    dtype=self.dtype,
-                    device=self.device,
-                )
-                overlap_first_bucket = False
             else:
                 # Allocate param shard buffer
-                params_bucket.params_shard = torch.empty(
-                    [shard_size],
-                    dtype=param_sync_dtype,
-                    device=self.device,
-                )
+                if (
+                    state_bucket.params_shard is not None
+                    and dtype == param_sync_dtype
+                ):
+                    params_bucket.params_shard = state_bucket.params_shard
+                else:
+                    params_bucket.params_shard = torch.empty(
+                        [shard_size],
+                        dtype=param_sync_dtype,
+                        device=self.device,
+                    )
             self._params_buckets[bucket_id] = params_bucket
 
         # Apply optimizer step
diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py
index 298e5ec16..d41016d32 100644
--- a/apex/contrib/test/optimizers/test_dist_adam.py
+++ b/apex/contrib/test/optimizers/test_dist_adam.py
@@ -296,6 +296,12 @@ def test_matches_pytorch_int64_param_sync(self):
         self.test_matches_pytorch(
             param_sync_dtype=torch.int64,
         )
 
+    def test_matches_pytorch_int64_param_sync_contiguous_buffers(self):
+        self.test_matches_pytorch(
+            param_sync_dtype=torch.int64,
+            contiguous_buffers=True,
+        )
+
     def test_matches_pytorch_uint8_param_sync(self):
         self.test_matches_pytorch(
             rtol=0.5,