21 changes: 12 additions & 9 deletions transformer_engine/pytorch/csrc/quantizer.cpp
@@ -155,9 +155,10 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   py::object out_py;
   if (internal) {
     py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
-    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
-                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
-                               "quantizer"_a = this->quantizer);
+    out_py =
+        Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
+                          "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
+                          "quantizer"_a = this->quantizer, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
@@ -360,9 +361,10 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
   py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
   if (internal) {
     py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
-    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
-                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
-                               "quantizer"_a = this->quantizer);
+    out_py =
+        Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
+                          "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
+                          "quantizer"_a = this->quantizer, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
@@ -624,7 +626,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "is_2D_scaled"_a = (block_scaling_dim == 2), "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
@@ -911,7 +913,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
     out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
                               columnwise_scale_inv_py, this->dtype, this->quantizer,
-                              with_gemm_swizzled_scales);
+                              with_gemm_swizzled_scales, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
     out_py = MXFP8TensorClass(
@@ -1202,7 +1204,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
     py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
     out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
                               columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py,
-                              this->dtype, this->quantizer, with_gemm_swizzled_scales);
+                              this->dtype, this->quantizer, with_gemm_swizzled_scales,
+                              "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
     out_py = NVFP4TensorClass(
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/distributed.py
@@ -1079,7 +1079,7 @@ def _start_all_gather_fp8_blockwise(
             device = inp._columnwise_data.device
         else:
             raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
-        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
+        dtype = inp._dtype
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or"
@@ -1317,7 +1317,7 @@ def _all_gather_nvfp4(
         if inp._columnwise_data is not None:
             in_shape_t = inp._columnwise_data.size()
             device = inp._columnwise_data.device
-            dtype = torch.bfloat16
+            dtype = inp._dtype
         else:
             raise ValueError(
                 "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
@@ -1486,7 +1486,7 @@ def _all_gather_mxfp8(
             device = inp._columnwise_data.device
         else:
             raise ValueError("Got MXFP8 input tensor without any data")
-        dtype = torch.bfloat16  # Guess high-precision dtype.
+        dtype = inp._dtype
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
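The distributed.py edits all make the same substitution: the all-gather helpers used to hardcode a BF16 guess for the dequantization dtype, and now read the nominal high-precision dtype recorded on the quantized tensor. A minimal sketch of the behavioral difference, using an illustrative stand-in class rather than the actual TE storage types:

```python
import torch

class QuantizedStub:
    """Illustrative stand-in for a TE quantized-tensor storage object."""

    def __init__(self, data: torch.Tensor, fake_dtype: torch.dtype = torch.float32):
        self._columnwise_data = data
        self._dtype = fake_dtype  # nominal high-precision dtype, carried on the tensor

inp = QuantizedStub(torch.zeros(4, dtype=torch.uint8), fake_dtype=torch.float16)

dtype_before = torch.bfloat16  # old behavior: "Guess BF16 for dequant."
dtype_after = inp._dtype       # new behavior: trust the tensor's own metadata

assert dtype_after == torch.float16  # FP16 inputs no longer dequantize to BF16
```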
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/module/base.py
@@ -522,6 +522,7 @@ def fill_userbuffers_buffer_for_all_gather(
             data=global_tensor_data,
             fp8_scale_inv=local_tensor._scale_inv,
             fp8_dtype=local_tensor._fp8_dtype,
+            fake_dtype=local_tensor._dtype,
             quantizer=quantizer,
         )
         return global_tensor, local_tensor
@@ -596,6 +597,7 @@ def fill_userbuffers_buffer_for_all_gather(
             fp8_dtype=local_tensor._fp8_dtype,
             quantizer=quantizer,
             with_gemm_swizzled_scales=False,
+            fake_dtype=local_tensor._dtype,
         )
         return global_tensor, local_tensor

9 changes: 7 additions & 2 deletions transformer_engine/pytorch/quantized_tensor.py
@@ -37,6 +37,7 @@ class QuantizedTensorStorage:
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (like __torch_dispatch__)."""

+    _dtype: torch.dtype
     _quantizer: Optional[Quantizer]

     def update_usage(
@@ -355,9 +356,12 @@ def __new__(
         shape: Iterable[int],
         dtype: torch.dtype,
         *,
+        fake_dtype: Optional[torch.dtype] = None,

> Review thread on the fake_dtype kwarg:
> Collaborator: Isn't this redundant with the dtype kwarg?
> Member (author): This is mostly to avoid issues with MRO while still keeping fairly straightforward constructors for the Storage classes.
> Member (author): Also just noticed that the make_like call would otherwise be problematic: we want to include the fake_dtype in the get_metadata call, but if it were named dtype it would clash with the dtype that we pass directly in make_like. (A sketch of this clash follows this file's diff.)

         requires_grad: bool = False,
         device: Optional[torch.device] = None,
     ):
+        if fake_dtype is not None and fake_dtype != dtype:
+            raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
         # We are assuming only contiguous tensors
         stride = _stride_from_shape(shape)
         instance = torch.Tensor._make_wrapper_subclass(
@@ -370,6 +374,7 @@
             requires_grad=requires_grad,
             device=torch.cuda.current_device() if device is None else device,
         )
+        instance._dtype = dtype

         return instance

@@ -403,7 +408,7 @@ def clear(self):
         )

     def __repr__(self, *, tensor_contents=None) -> str:
-        return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
+        return f"{self.__class__.__name__}(data={self.dequantize()})"

     def float(self) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -506,7 +511,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

         def maybe_unwrap(arg):
             if isinstance(arg, QuantizedTensor):
-                return arg.dequantize(dtype=arg.dtype)
+                return arg.dequantize()
             return arg

         def maybe_update_inplace(arg, new_arg, schema_arg):
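To make the author's make_like point concrete, here is a toy sketch of the keyword clash, with simplified stand-ins (`Storage`, `make_like`, and their signatures are illustrative, not the actual TE APIs): a helper that both forwards `get_metadata()` as keyword arguments and accepts its own explicit `dtype` parameter would pass `dtype` twice if the metadata used that name.

```python
import torch

class Storage:
    """Toy stand-in for a TE storage class carrying a nominal dtype."""

    def __init__(self, *, fake_dtype: torch.dtype = torch.float32):
        self._dtype = fake_dtype

    def get_metadata(self) -> dict:
        # Were this key named "dtype", make_like below would raise
        # TypeError: got multiple values for keyword argument 'dtype'.
        return {"fake_dtype": self._dtype}

def make_like(template: Storage, *, dtype: torch.dtype) -> Storage:
    kwargs = template.get_metadata()
    # dtype is passed explicitly alongside the forwarded metadata in a
    # make_like-style helper, so the metadata keys must not collide with it.
    assert "dtype" not in kwargs
    return Storage(**kwargs)

src = Storage(fake_dtype=torch.bfloat16)
dst = make_like(src, dtype=torch.bfloat16)
assert dst._dtype == torch.bfloat16
```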
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -325,7 +325,7 @@ def __repr__(self, *, tensor_contents=None):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
             f" is_2D_scaled={self._is_2D_scaled},"
-            f" data={self.dequantize(dtype=self.dtype)})"
+            f" data={self.dequantize()})"
         )

     def quantize_(
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
@@ -174,6 +174,7 @@ def create_tensor_from_data(
             data=data,
             fp8_scale_inv=1 / self.scale,
             fp8_dtype=self.dtype,
+            fake_dtype=fake_dtype,
             requires_grad=requires_grad,
             data_transpose=None,
             quantizer=self,
@@ -393,6 +394,7 @@ def create_tensor_from_data(
             data=data,
             fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
             fp8_dtype=self.dtype,
+            fake_dtype=fake_dtype,
             requires_grad=requires_grad,
             data_transpose=None,
             quantizer=self,
@@ -480,7 +482,7 @@ def __repr__(self, *, tensor_contents=None):
             "Float8Tensor("
             f"fp8_dtype={self._fp8_dtype}, "
             f"scale_inv={self._scale_inv.item()}, "
-            f"data={self.dequantize(dtype=self.dtype)}"
+            f"data={self.dequantize()}"
             ")"
         )

2 changes: 1 addition & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -253,7 +253,7 @@ def __new__(
         )

     def __repr__(self, *, tensor_contents=None):
-        return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
+        return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})"

     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -440,7 +440,7 @@ def __new__(
         return instance

     def __repr__(self, *, tensor_contents=None):
-        return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})"
+        return f"NVFP4Tensor, data={self.dequantize()})"

     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
@@ -46,12 +46,14 @@ def __new__(
         quantizer: Quantizer,
         is_2D_scaled: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
         if cls is Float8BlockwiseQTensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._quantizer = quantizer.copy() if quantizer is not None else None
@@ -83,6 +85,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "is_2D_scaled": self._is_2D_scaled,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(
@@ -131,7 +134,9 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
             permute_dims.append(0)
         return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()

-    def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def _dequantize_vectorwise(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if dtype is None:
+            dtype = self._dtype
         block_len = 128

         q_M, q_K = 1, 1
@@ -193,10 +198,12 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
                 return self._transpose_dq_columnwise_output(result)
         return result

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
         Construct plain PyTorch tensor from Float8BlockwiseQTensor
         """
+        if dtype is None:
+            dtype = self._dtype
         block_len = 128
         if not self._is_2D_scaled:
             return self._dequantize_vectorwise(dtype=dtype)
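The net effect for callers: `dequantize()` with no arguments now resolves to the dtype recorded at construction instead of a hardcoded `torch.float32`. A self-contained toy model of the new default (not the TE class):

```python
import torch
from typing import Optional

class BlockwiseStub:
    """Toy model of the new dequantize default behavior."""

    def __init__(self, fake_dtype: Optional[torch.dtype] = None):
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        if dtype is None:
            dtype = self._dtype  # new default: the recorded nominal dtype
        return torch.zeros(2, dtype=dtype)

t = BlockwiseStub(fake_dtype=torch.bfloat16)
assert t.dequantize().dtype == torch.bfloat16                    # no-arg call follows _dtype
assert t.dequantize(dtype=torch.float32).dtype == torch.float32  # explicit override still works
```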
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
@@ -75,14 +75,16 @@ def __new__(
         data: Optional[torch.Tensor],
         fp8_scale_inv: torch.Tensor,
         fp8_dtype: TE_DType,
+        fake_dtype: Optional[torch.dtype] = None,

> Review comment on the fake_dtype kwarg:
> Collaborator: I'd prefer to just name it dtype, since QuantizedTensor already uses that name in its constructor.
> Suggested change:
>     -        fake_dtype: Optional[torch.dtype] = None,
>     +        dtype: Optional[torch.dtype] = None,

         data_transpose: Optional[torch.Tensor] = None,
         quantizer: Optional[Quantizer] = None,
         **kwargs,
     ):
         if cls is Float8TensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._data = data
         instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
@@ -112,6 +114,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "data_transpose": self._transpose,
             "quantizer": self._quantizer,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
@@ -141,8 +144,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
                 return self._transpose
         raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromFloat8Func.forward(None, self, dtype)

     def size(self, *args, **kwargs):
@@ -165,6 +170,7 @@ def view(self, shape: torch.Size):
             data=out_data,
             fp8_scale_inv=self._scale_inv,
             fp8_dtype=self._fp8_dtype,
+            fake_dtype=self._dtype,
             data_transpose=out_transpose,
             quantizer=self._quantizer,
         )
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
@@ -84,12 +84,14 @@ def __new__(
         quantizer: Optional[Quantizer],
         with_gemm_swizzled_scales: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
         if cls is MXFP8TensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._rowwise_scale_inv = rowwise_scale_inv
@@ -121,6 +123,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
@@ -157,8 +160,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
             return self._columnwise_data
         raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromMXFP8Func.forward(None, self, dtype)

     def size(self, *args, **kwargs):
@@ -211,6 +216,7 @@ def view(self, shape: torch.Size):
             fp8_dtype=self._fp8_dtype,
             quantizer=self._quantizer,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            fake_dtype=self._dtype,
         )

     def __repr__(self):
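The storage constructors in these files all follow the same cooperative `__new__` pattern, which is what the author's MRO comment refers to: instantiated directly, the storage records the dtype itself; instantiated as the base of a full tensor class, it forwards `fake_dtype` up the MRO so the wrapper base can consume it. A self-contained sketch with illustrative class names (`WrapperBase`, `Storage`, `FullTensor` are stand-ins, not the TE classes):

```python
import torch
from typing import Optional

class WrapperBase:
    """Stands in for the torch.Tensor wrapper-subclass base (QuantizedTensor)."""

    def __new__(cls, *args, fake_dtype: Optional[torch.dtype] = None, **kwargs):
        instance = object.__new__(cls)
        instance._dtype = fake_dtype  # in TE this is validated against dtype
        return instance

class Storage:
    """Stands in for a TE *TensorStorage class."""

    def __new__(cls, *args, fake_dtype: Optional[torch.dtype] = None, **kwargs):
        if cls is Storage:
            # Standalone storage: not a torch.Tensor, so record the dtype directly.
            instance = object.__new__(cls)
            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
        else:
            # Full tensor subclass: pass fake_dtype along the MRO.
            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
        return instance

class FullTensor(Storage, WrapperBase):
    """Stands in for a full quantized tensor class."""

s = Storage(fake_dtype=torch.bfloat16)     # direct storage construction
t = FullTensor(fake_dtype=torch.bfloat16)  # routed through WrapperBase.__new__
assert s._dtype == t._dtype == torch.bfloat16
```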
14 changes: 11 additions & 3 deletions transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -106,10 +106,14 @@ def __new__(
         quantizer: Optional[Quantizer],
         with_gemm_swizzled_scales: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
-
-        instance = super().__new__(cls, *args, **kwargs)
+        if cls is NVFP4TensorStorage:
+            instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
+        else:
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)

         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
@@ -148,6 +152,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp4_dtype": self._fp4_dtype,
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
@@ -184,8 +189,10 @@ def get_data_tensors(self):
         """Get this Tensor's data."""
         return self._rowwise_data, self._columnwise_data

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromNVFP4Func.forward(None, self, dtype)

     def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
@@ -266,6 +273,7 @@ def view(self, shape: torch.Size):
             quantizer=self._quantizer,
             fp4_dtype=self._fp4_dtype,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            fake_dtype=self._dtype,
         )

     def __repr__(self):
1 change: 1 addition & 0 deletions transformer_engine/pytorch/utils.py
@@ -242,6 +242,7 @@ def forward(
                     fp8_dtype=mixed_x_layer._fp8_dtype,
                     data=x.squeeze(split_dim) if squeeze else x,
                     shape=x.squeeze(split_dim).shape if squeeze else x.shape,
+                    fake_dtype=mixed_x_layer._dtype,
                     quantizer=mixed_x_layer._quantizer,
                 )
                 for x in torch.split(