tests/pytorch/test_float8tensor.py (1 addition, 1 deletion)

@@ -263,7 +263,7 @@ def test_transpose(
     dims: DimsType,
     transpose_dims: Tuple[int, int],
     fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
-    scale: float = 1,
+    scale: float = 0.5,
Member (reviewer): Why?

Collaborator (PR author): I thought there was a correctness issue that was hidden by scale=1, but I don't think it's actually an issue. Making it non-one does a better job of stress-testing this in any case, though.

     dtype: torch.dtype = torch.float32,
 ) -> None:
     """Test transpose"""
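The exchange above comes down to how FP8 scaling is handled: a Float8Tensor stores quantized data together with a scale_inv factor, and with scale=1 a transpose that drops the scaling metadata still dequantizes to the correct values, so that class of bug stays invisible. Below is a minimal sketch in plain torch of why a non-unit scale is the stronger test (illustrative only, not TE's FP8 kernels; all names are made up for the example):

import torch

scale = 0.5                      # quantization scale under test
x = torch.randn(16, 32)
x_q = x * scale                  # stand-in for the quantized FP8 payload
scale_inv = 1.0 / scale

# A transpose that carries the scaling metadata dequantizes correctly.
good = x_q.t() * scale_inv
assert torch.allclose(good, x.t())

# A transpose that silently drops scale_inv only looks correct when scale == 1.
bad = x_q.t() * 1.0
assert not torch.allclose(bad, x.t())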
transformer_engine/pytorch/float8_tensor.py (28 additions, 34 deletions)

@@ -435,30 +435,12 @@ def expand_as(self, other: torch.Tensor):
             return _IdentityFunc.apply(self)
         return super().expand_as(other)
 
-    def _transpose_no_cache(self) -> torch.Tensor:
-        """
-        Swap tensor dimensions
-
-        For basic 2D matrix transposes, an optimized transpose kernel
-        is applied and a Float8Tensor is returned.
-        """
-
-        # Use optimized kernel for basic 2D transpose
-        # TODO Support differentiation # pylint: disable=fixme
-        return Float8Tensor.make_like(
-            self,
-            data=tex.fp8_transpose(
-                self._data.contiguous().detach(),
-                self._fp8_dtype,
-            ),
-        )
-
     def transpose(
         self,
         dim0: int = 0,
         dim1: int = 1,
         *,
-        update_cache: Optional[bool] = None,
+        update_cache: bool = False,
     ) -> torch.Tensor:
         """
         Swap tensor dimensions
@@ -472,12 +454,14 @@ def transpose(
             The first dimension to be transposed
         dim1: int, default = 1
             The second dimension to be transposed
-        update_cache: Optional[bool], default = None
-            If set to `True`, the result is computed and stored in a cache.
-            If set to `False`, the result is computed only if the cache is
-            empty, otherwise the cache is returned. If set to `None`, the
-            result is not cached. Caching is only supported for basic 2D
-            transposes and the cache is reset after any in-place operations.
+        update_cache: bool, default = False
+            If `True`, the transpose is computed and stored
+            in a cache. If `False`, a cached version is
+            returned if available and otherwise the
+            transpose is computed. Caching is only supported
+            for basic 2D transposes and the cache is reset
+            after any in-place operations.
+
         """
 
         # Handle non-2D transposes
@@ -486,22 +470,32 @@ def transpose(
         if -self.dim() <= dim1 < 0:
             dim1 += self.dim()
         if self.dim() != 2 or dim0 == dim1:
-            if update_cache is not None:
+            if update_cache:
                 raise ValueError(
                     "Transpose caching is only supported for basic 2D transposes "
                     f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
                 )
             return super().transpose(dim0, dim1)
 
-        # No caching.
-        if update_cache is None:
-            return self._transpose_no_cache()
-
-        # Update cache.
-        if update_cache or self._transpose is None:
-            self._transpose = self._transpose_no_cache()
-
-        return self._transpose
+        # Clear cache if needed
+        if update_cache:
+            self._transpose = None
+
+        # Compute transpose if needed
+        out = self._transpose
+        if out is None:
+            out = Float8Tensor.make_like(
+                self,
+                data=tex.fp8_transpose(
+                    self._data.contiguous(),
+                    self._fp8_dtype,
+                ),
+            )
+
+        # Update cache if needed
+        if update_cache:
+            self._transpose = out
+        return out
 
     @torch.no_grad()
     def reset_fp8_meta_scale_inv(self) -> None:
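To read the new control flow outside of diff markers, here is a minimal standalone sketch of the same caching pattern. The CachedTranspose class is a hypothetical stand-in, not part of Transformer Engine, and a plain torch transpose takes the place of tex.fp8_transpose:

import torch
from typing import Optional

class CachedTranspose:
    """Toy 2D tensor wrapper mirroring the transpose cache logic above."""

    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose: Optional[torch.Tensor] = None  # cache; reset after in-place ops

    def transpose(self, *, update_cache: bool = False) -> torch.Tensor:
        # Clear the cache first so a fresh result is always computed below.
        if update_cache:
            self._transpose = None
        # Compute the transpose only on a cache miss.
        out = self._transpose
        if out is None:
            out = self._data.t().contiguous()
        # Store the result only when caching was requested.
        if update_cache:
            self._transpose = out
        return out

t = CachedTranspose(torch.randn(4, 8))
a = t.transpose(update_cache=True)   # computed and stored in the cache
b = t.transpose()                    # served from the cache without recomputing
assert a is b

As in the diff, update_cache=False reads the cache but never writes it, so plain calls recompute the transpose until some caller populates the cache explicitly.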