
[PyTorch] Add dtype information to QuantizedTensorStorage class #2676

Open

ptrendx wants to merge 4 commits into NVIDIA:main from ptrendx:pr_dtype_in_storage

Conversation

ptrendx (Member) commented Feb 12, 2026

Description

This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantization, as was previously done in distributed.py, and it prevents unintentional dequantization to FP32 when calling dequantize() on the Storage class with no dtype argument.
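
As a rough sketch of the behavior change (the class name mirrors the PR, but the body and the _dequantize_impl helper are illustrative, not the exact Transformer Engine code):

    from typing import Optional
    import torch

    class QuantizedTensorStorage:
        """Sketch only: shows the _dtype fallback this PR introduces."""

        # "Fake" high-precision dtype recorded when the data was quantized.
        _dtype: Optional[torch.dtype] = None

        def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
            # Old behavior: the default was dtype=torch.float32, so calling
            # dequantize() with no argument silently produced FP32.
            # New behavior: a missing dtype falls back to the stored _dtype.
            if dtype is None:
                dtype = self._dtype if self._dtype is not None else torch.float32
            return self._dequantize_impl(dtype)

        def _dequantize_impl(self, dtype: torch.dtype) -> torch.Tensor:
            # Hypothetical hook implemented by the concrete storage classes.
            raise NotImplementedError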

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added the _dtype field to the QuantizedTensorStorage class
  • Modified dequantize() to fall back to that new field when called without a dtype argument
  • Removed guessing of the dtype from distributed.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
ptrendx requested a review from timmoon10 on February 12, 2026 19:07
greptile-apps bot (Contributor) commented Feb 12, 2026

Greptile Summary

This PR adds a _dtype field to the QuantizedTensorStorage base class and all its subclasses (Float8, Float8Blockwise, MXFP8, NVFP4) to track the original high-precision dtype (e.g., BF16, FP16) of the data before quantization. This eliminates the need for guessing the dequantization dtype (previously hardcoded as torch.bfloat16 in distributed.py) and prevents unintentional dequantization to FP32 when calling dequantize() with no arguments on storage classes.

  • Added fake_dtype parameter to all storage class constructors, stored as _dtype, propagated through get_metadata() and view() methods (see the sketch after this list)
  • Changed dequantize() default from dtype: torch.dtype = torch.float32 to dtype: Optional[torch.dtype] = None with fallback to self._dtype across all storage classes
  • Removed hardcoded torch.bfloat16 dtype guesses in distributed.py all-gather functions, now using inp._dtype
  • C++ quantizer.cpp now passes GetATenDType(dtype) as fake_dtype to all storage class constructors
  • Simplified __repr__ methods across all tensor classes to call self.dequantize() without explicit dtype
  • Fixed structural inconsistency in NVFP4TensorStorage.__new__ by adding the cls is check pattern used by other storage classes
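
A minimal sketch of that constructor/metadata pattern (simplified: the real storage classes also carry FP8 data and scale-inverse tensors, and construct via __new__):

    from typing import Any, Dict, Optional
    import torch

    class Float8TensorStorageSketch:
        def __init__(self, *, fake_dtype: Optional[torch.dtype] = None) -> None:
            # Keyword-only fake_dtype: recorded once at quantization time and
            # used later as the default target for dequantize().
            self._dtype = fake_dtype

        def get_metadata(self) -> Dict[str, Any]:
            # Propagating fake_dtype through get_metadata() lets view() and
            # make_like rebuild a storage object that remembers the
            # original precision.
            return {"fake_dtype": self._dtype}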

Confidence Score: 5/5

  • This PR is safe to merge — it adds a well-structured dtype tracking field with consistent implementation across all tensor storage types and proper fallback defaults.
  • The changes are mechanical and consistent across all four storage class families. Every construction site (both Python and C++) properly passes the new fake_dtype parameter. The fallback to torch.float32 when fake_dtype is not provided maintains backward compatibility. The removal of hardcoded torch.bfloat16 guesses in distributed.py is a clear correctness improvement. No logical errors or edge cases were found.
  • No files require special attention.

Important Files Changed

  • transformer_engine/pytorch/quantized_tensor.py: Added the _dtype annotation to QuantizedTensorStorage, added a fake_dtype parameter to QuantizedTensor.__new__ with validation, set instance._dtype = dtype, and simplified dequantize() calls to use the stored dtype instead of passing dtype=self.dtype explicitly.
  • transformer_engine/pytorch/csrc/quantizer.cpp: Passes fake_dtype=GetATenDType(dtype) to all storage class constructors (Float8, Float8CurrentScaling, Float8Block, MXFP8, NVFP4) when creating internal (storage) tensors from C++. The dtype parameter is the high-precision type of the original data.
  • transformer_engine/pytorch/distributed.py: Replaced the hardcoded torch.bfloat16 guess with inp._dtype for Float8Blockwise, NVFP4, and MXFP8 storage types in the all-gather functions (see the sketch after this list). This is now correct since _dtype carries the actual original precision dtype.
  • transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py: Added a fake_dtype parameter to __new__, stores it as _dtype, propagates it through get_metadata() and view(), and uses it as the default in dequantize(). Signature changed from dtype: torch.dtype = torch.float32 to dtype: Optional[torch.dtype] = None with a _dtype fallback.
  • transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py: Same pattern as Float8TensorStorage: added fake_dtype, stores it as _dtype, uses it as the default in both dequantize() and _dequantize_vectorwise(). Added the cls is check for direct vs. subclass instantiation.
  • transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py: Same pattern: added fake_dtype, stores it as _dtype, uses it as the default in dequantize(). Propagated through get_metadata() and the view() method.
  • transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py: Added the fake_dtype parameter and a cls is NVFP4TensorStorage check (previously missing, unlike the other storage classes). This fixes a structural inconsistency while adding the dtype tracking.
  • transformer_engine/pytorch/tensor/float8_tensor.py: Passes fake_dtype to Float8TensorStorage in create_tensor_from_data for both Float8Quantizer and Float8CurrentScalingQuantizer. Simplified __repr__ to use self.dequantize() without an explicit dtype.
  • transformer_engine/pytorch/module/base.py: Passes fake_dtype=local_tensor._dtype when constructing Float8TensorStorage and MXFP8TensorStorage in the Userbuffers all-gather path.
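
For the distributed.py item above, the change boils down to something like the following (gather_and_dequantize is a hypothetical stand-in for the actual all-gather helpers):

    import torch

    def gather_and_dequantize(inp) -> torch.Tensor:
        # Before this PR: the target dtype had to be guessed, hardcoded as
        # torch.bfloat16, which was wrong whenever the original data was
        # FP16 or FP32.
        # After this PR: the storage object remembers the dtype it was
        # quantized from, so no guess is needed.
        return inp.dequantize(dtype=inp._dtype)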

Flowchart

flowchart TD
    A[High-precision Tensor\ne.g. BF16/FP16/FP32] -->|quantize| B{Quantizer\nC++ or Python}
    B -->|"creates with fake_dtype=original_dtype"| C[QuantizedTensorStorage\n_dtype = original dtype]
    C -->|"dequantize()"| D{dtype argument?}
    D -->|"dtype=None (default)"| E[Use self._dtype\nRestores original precision]
    D -->|"dtype=explicit"| F[Use provided dtype]
    E --> G[High-precision Tensor]
    F --> G

    subgraph Storage Classes
        C1[Float8TensorStorage]
        C2[Float8BlockwiseQTensorStorage]
        C3[MXFP8TensorStorage]
        C4[NVFP4TensorStorage]
    end
    C --- C1
    C --- C2
    C --- C3
    C --- C4

    style A fill:#4CAF50,color:#fff
    style G fill:#4CAF50,color:#fff
    style C fill:#FF9800,color:#fff
    style E fill:#2196F3,color:#fff

Last reviewed commit: be723b2

greptile-apps bot (Contributor) left a comment:

13 files reviewed, no comments


ptrendx (Member, Author) commented Feb 12, 2026

/te-ci pytorch

ksivaman previously approved these changes Feb 12, 2026

ksivaman (Member) left a comment:

LGTM

timmoon10 previously approved these changes Feb 14, 2026

timmoon10 (Collaborator) left a comment:

Overall this is a big improvement. I have some naming nits.

shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
timmoon10 (Collaborator): Isn't this redundant with the dtype kwarg?

ptrendx (Member, Author) replied:
This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.

ptrendx (Member, Author) replied:
Also just noticed that the make_like call would be problematic there otherwise: we want to include fake_dtype in the get_metadata call, but if it were named dtype, it would clash with the dtype that we pass directly in make_like.
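
A contrived illustration of that clash (StorageSketch and make_like here are simplified stand-ins, not the actual Transformer Engine helpers):

    from typing import Optional
    import torch

    class StorageSketch:
        """Toy stand-in for a storage class whose constructor takes dtype."""

        def __init__(self, *, dtype: torch.dtype,
                     fake_dtype: Optional[torch.dtype] = None) -> None:
            self.dtype = dtype        # dtype of the quantized payload
            self._dtype = fake_dtype  # original high-precision dtype

        def get_metadata(self):
            # Exposed as "fake_dtype": no collision with the dtype kwarg below.
            return {"fake_dtype": self._dtype}

    def make_like(template: StorageSketch, *, dtype: torch.dtype) -> StorageSketch:
        # If get_metadata() returned {"dtype": ...} instead, this call would
        # raise TypeError: got multiple values for keyword argument 'dtype'.
        return StorageSketch(dtype=dtype, **template.get_metadata())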

data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
timmoon10 (Collaborator): I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.

Suggested change:
- fake_dtype: Optional[torch.dtype] = None,
+ dtype: Optional[torch.dtype] = None,

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
ptrendx dismissed stale reviews from timmoon10 and ksivaman via be723b2 on February 18, 2026 01:40
greptile-apps bot (Contributor) left a comment:
13 files reviewed, no comments

