From 4b8ceaa1572198afbb13cba1f7ed4e91aaddfc32 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 21 Apr 2026 21:09:04 +0000 Subject: [PATCH] fix transformers + torchao nvfp4 serialization Summary: 1. fix torchao NVFP4 serialization with transformers 2. add a test to cover the fix While i'm here, also did the following bundled into this PR: 3. make the torchao serialization test have human readable names (easier to debug) 4. fix the float8 test (update the expected output) after this PR the test command for all torchao configs passes on an NVIDIA B200 Test Plan: ``` RUN_SLOW=1 pytest tests/quantization/torchao_integration/test_torchao.py -k "Serialization" -s ``` --- .../quantizers/quantizer_torchao.py | 1 + .../torchao_integration/test_torchao.py | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index a76f73aeb562..fd117b08023b 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -184,6 +184,7 @@ def get_weight_conversions(self): source_patterns=[ "_weight_qdata", "_weight_scale_and_zero", + "_weight_per_tensor_scale", "_weight_scale", "_weight_zero_point", "_weight_act_pre_scale", diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index ebcc08816d95..b188b4f9a0c3 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -39,6 +39,7 @@ from torchao.dtypes import ( AffineQuantizedTensor, ) + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8Tensor, @@ -587,13 +588,14 @@ class TorchAoSerializationTest(unittest.TestCase): test_params = ( [ - (Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), - (Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), - (Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\nJess: (smiling) I", ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), - (Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), - (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), - (Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})), - (IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("Int8WeightOnlyConfig", Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), + ("Int8DynamicActivationInt8WeightConfig", Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), + ("Float8DynamicActivationFloat8WeightConfig", Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), + ("Float8WeightOnlyConfig", Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), + ("Int4WeightOnlyConfig", Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), + ("Int8DynamicActivationIntxWeightConfig", Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})), + ("IntxWeightOnlyConfig", IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("NVFP4DynamicActivationNVFP4WeightConfig", NVFP4DynamicActivationNVFP4WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\n10. Avoid using \"I"})), ] if is_torchao_available() else [] @@ -609,8 +611,12 @@ def _check_serialization(self, device, config, expected_output): if isinstance(config, (Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig)): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9): self.skipTest(f"{type(config).__name__} requires CUDA capability >= (8, 9)") + if isinstance(config, NVFP4DynamicActivationNVFP4WeightConfig): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (10, 0): + self.skipTest(f"{type(config).__name__} requires CUDA capability >= (10, 0) (SM100)") quant_config = TorchAoConfig(config) - dtype = torch.bfloat16 if isinstance(config, Int4WeightOnlyConfig) else "auto" + needs_bfloat16 = isinstance(config, Int4WeightOnlyConfig | NVFP4DynamicActivationNVFP4WeightConfig) + dtype = torch.bfloat16 if needs_bfloat16 else "auto" quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, dtype=dtype, @@ -629,7 +635,7 @@ def _check_serialization(self, device, config, expected_output): self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output) @parameterized.expand(test_params, skip_on_empty=True) - def test_serialization_cpu(self, config, expected_outputs): + def test_serialization_cpu(self, _name, config, expected_outputs): try: expected = expected_outputs.find_expectation(("cpu", None, None)) except ValueError: @@ -638,7 +644,7 @@ def test_serialization_cpu(self, config, expected_outputs): @parameterized.expand(test_params, skip_on_empty=True) @require_torch_accelerator - def test_serialization_accelerator(self, config, expected_outputs): + def test_serialization_accelerator(self, _name, config, expected_outputs): try: expected = expected_outputs.get_expectation() except ValueError: