[PyTorch] Add dtype information to QuantizedTensorStorage class #2676
ptrendx wants to merge 4 commits into NVIDIA:main
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds a fake dtype field to the QuantizedTensorStorage class, so that dequantize() can restore the original high-precision dtype by default.
Confidence Score: 5/5
Important Files Changed
Flowchart

```mermaid
flowchart TD
    A[High-precision Tensor\ne.g. BF16/FP16/FP32] -->|quantize| B{Quantizer\nC++ or Python}
    B -->|"creates with fake_dtype=original_dtype"| C[QuantizedTensorStorage\n_dtype = original dtype]
    C -->|"dequantize()"| D{dtype argument?}
    D -->|"dtype=None (default)"| E[Use self._dtype\nRestores original precision]
    D -->|"dtype=explicit"| F[Use provided dtype]
    E --> G[High-precision Tensor]
    F --> G
    subgraph Storage Classes
        C1[Float8TensorStorage]
        C2[Float8BlockwiseQTensorStorage]
        C3[MXFP8TensorStorage]
        C4[NVFP4TensorStorage]
    end
    C --- C1
    C --- C2
    C --- C3
    C --- C4
    style A fill:#4CAF50,color:#fff
    style G fill:#4CAF50,color:#fff
    style C fill:#FF9800,color:#fff
    style E fill:#2196F3,color:#fff
```
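The quantize/dequantize flow above can be sketched as follows. This is a hypothetical, minimal model (the class name `StorageSketch` and the int8 payload are illustrative stand-ins, not the actual Transformer Engine implementation): the storage remembers the quantized tensor's original dtype in `_dtype`, and `dequantize()` falls back to it when no explicit `dtype` is given.

```python
from typing import Optional

import torch


class StorageSketch:
    """Hypothetical minimal model of a QuantizedTensorStorage that
    remembers the dtype of the tensor it was quantized from."""

    def __init__(
        self,
        qdata: torch.Tensor,
        scale_inv: torch.Tensor,
        fake_dtype: Optional[torch.dtype] = None,
    ):
        self._qdata = qdata          # low-precision payload (int8 stands in for FP8)
        self._scale_inv = scale_inv  # dequantization scale
        # Remember the original high-precision dtype; FP32 remains the
        # fallback when the quantizer did not provide one.
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # dtype=None now means "restore the remembered precision",
        # not "always produce FP32".
        out_dtype = self._dtype if dtype is None else dtype
        return (self._qdata.to(torch.float32) * self._scale_inv).to(out_dtype)


# Round trip: a BF16 tensor keeps its dtype through dequantize().
x = torch.tensor([0.5, -1.0, 2.0, -0.25], dtype=torch.bfloat16)
scale = x.abs().amax().to(torch.float32) / 127.0
q = torch.clamp(torch.round(x.to(torch.float32) / scale), -127, 127).to(torch.int8)
storage = StorageSketch(q, scale, fake_dtype=x.dtype)
restored = storage.dequantize()
```

With this sketch, `restored.dtype` is `torch.bfloat16` rather than `torch.float32`, and an explicit `dequantize(dtype=torch.float16)` still overrides the remembered dtype.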
Last reviewed commit: be723b2
/te-ci pytorch
timmoon10
left a comment
Overall this is a big improvement. I have some naming nits.
```python
shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
```
Isn't this redundant with the dtype kwarg?
This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.
Also, I just noticed that the make_like call would otherwise be problematic: we want to include the fake_dtype in the get_metadata call, but if it were named dtype it would clash with the dtype that we pass directly in make_like.
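The clash described here can be illustrated with a small, self-contained sketch (all names below — `build_storage`, `make_like`, `key_for_remembered_dtype` — are hypothetical stand-ins, not the actual Transformer Engine API): if the metadata dict returned for a template stored the remembered dtype under the key `dtype`, forwarding it alongside the `dtype` that make_like passes explicitly would raise a duplicate-keyword TypeError.

```python
def build_storage(data, *, dtype=None, fake_dtype=None):
    # Hypothetical stand-in for a Storage-class constructor.
    return {"data": data, "dtype": dtype, "fake_dtype": fake_dtype}


def make_like(template, *, dtype, key_for_remembered_dtype):
    # make_like forwards the template's metadata plus an explicit dtype.
    metadata = {key_for_remembered_dtype: template["fake_dtype"]}
    return build_storage(template["data"], dtype=dtype, **metadata)


template = build_storage([1, 2, 3], fake_dtype="bfloat16")

# Storing the remembered dtype under the key "dtype" clashes with the
# dtype that make_like passes directly:
try:
    make_like(template, dtype="float16", key_for_remembered_dtype="dtype")
    clashed = False
except TypeError:  # "got multiple values for keyword argument 'dtype'"
    clashed = True

# Under a distinct key such as "fake_dtype", both arguments coexist:
rebuilt = make_like(template, dtype="float16", key_for_remembered_dtype="fake_dtype")
```

This is the kwargs-collision argument for keeping a separate `fake_dtype` name in the Storage-class constructors.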
```python
data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
```
I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.
```diff
-fake_dtype: Optional[torch.dtype] = None,
+dtype: Optional[torch.dtype] = None,
```
Description
This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantization, as was previously done in distributed.py, and it prevents unintentional dequantization to FP32 when calling dequantize() on the Storage class with no dtype argument.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: