Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/quantizers/quantizer_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def get_weight_conversions(self):
source_patterns=[
"_weight_qdata",
"_weight_scale_and_zero",
"_weight_per_tensor_scale",
"_weight_scale",
"_weight_zero_point",
"_weight_act_pre_scale",
Expand Down
26 changes: 16 additions & 10 deletions tests/quantization/torchao_integration/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8Tensor,
Expand Down Expand Up @@ -587,13 +588,14 @@ class TorchAoSerializationTest(unittest.TestCase):

test_params = (
[
(Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON),
(Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON),
(Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\nJess: (smiling) I", ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})),
(Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})),
(Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})),
(Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})),
(IntxWeightOnlyConfig(), ALL_DEVICES_COMMON),
("Int8WeightOnlyConfig", Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON),
("Int8DynamicActivationInt8WeightConfig", Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON),
("Float8DynamicActivationFloat8WeightConfig", Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})),
("Float8WeightOnlyConfig", Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})),
("Int4WeightOnlyConfig", Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})),
("Int8DynamicActivationIntxWeightConfig", Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})),
("IntxWeightOnlyConfig", IntxWeightOnlyConfig(), ALL_DEVICES_COMMON),
("NVFP4DynamicActivationNVFP4WeightConfig", NVFP4DynamicActivationNVFP4WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\n10. Avoid using \"I"})),
]
if is_torchao_available()
else []
Expand All @@ -609,8 +611,12 @@ def _check_serialization(self, device, config, expected_output):
if isinstance(config, (Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig)):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9):
self.skipTest(f"{type(config).__name__} requires CUDA capability >= (8, 9)")
if isinstance(config, NVFP4DynamicActivationNVFP4WeightConfig):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (10, 0):
self.skipTest(f"{type(config).__name__} requires CUDA capability >= (10, 0) (SM100)")
quant_config = TorchAoConfig(config)
dtype = torch.bfloat16 if isinstance(config, Int4WeightOnlyConfig) else "auto"
needs_bfloat16 = isinstance(config, Int4WeightOnlyConfig | NVFP4DynamicActivationNVFP4WeightConfig)
dtype = torch.bfloat16 if needs_bfloat16 else "auto"
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=dtype,
Expand All @@ -629,7 +635,7 @@ def _check_serialization(self, device, config, expected_output):
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

@parameterized.expand(test_params, skip_on_empty=True)
def test_serialization_cpu(self, config, expected_outputs):
def test_serialization_cpu(self, _name, config, expected_outputs):
try:
expected = expected_outputs.find_expectation(("cpu", None, None))
except ValueError:
Expand All @@ -638,7 +644,7 @@ def test_serialization_cpu(self, config, expected_outputs):

@parameterized.expand(test_params, skip_on_empty=True)
@require_torch_accelerator
def test_serialization_accelerator(self, config, expected_outputs):
def test_serialization_accelerator(self, _name, config, expected_outputs):
try:
expected = expected_outputs.get_expectation()
except ValueError:
Expand Down
Loading