diff --git a/src/transformers/integrations/torchao.py b/src/transformers/integrations/torchao.py
index 421a004dd6e9..2fa20a3982b9 100644
--- a/src/transformers/integrations/torchao.py
+++ b/src/transformers/integrations/torchao.py
@@ -35,19 +35,10 @@
 logger = logging.get_logger(__name__)


-def _quantization_type(weight):
-    from torchao.dtypes import AffineQuantizedTensor
-    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
-
-    if isinstance(weight, AffineQuantizedTensor):
-        return f"{weight.__class__.__name__}({weight._quantization_type()})"
-
-    if isinstance(weight, LinearActivationQuantizedTensor):
-        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
-
-
 def _linear_extra_repr(self):
-    weight = _quantization_type(self.weight)
+    from torchao.utils import TorchAOBaseTensor
+
+    weight = self.weight.__class__.__name__ if isinstance(self.weight, TorchAOBaseTensor) else None
     if weight is None:
         return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
     else:
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index ebcc08816d95..678ae34aac03 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -36,9 +36,6 @@
     import torch

 if is_torchao_available():
-    from torchao.dtypes import (
-        AffineQuantizedTensor,
-    )
     from torchao.quantization import (
         Float8DynamicActivationFloat8WeightConfig,
         Float8Tensor,
@@ -52,6 +49,9 @@
         MappingType,
         PerAxis,
     )
+    from torchao.utils import (
+        TorchAOBaseTensor,
+    )


 @require_torchao
@@ -191,7 +191,7 @@ def test_per_module_config_skip(self):
             torch_dtype=torch.bfloat16,
         )
         # making sure `model.layers.0.self_attn.q_proj` is skipped
-        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -211,7 +211,7 @@ def test_fqn_to_config_regex_basic(self):
             torch_dtype=torch.bfloat16,
         )
         # making sure `model.layers.0.self_attn.q_proj` is skipped
-        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -244,7 +244,7 @@ def test_fqn_to_config_regex_fullmatch(self):
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
         # because regex `model\.layers\.+*\.self_attn\.q_pro` didn't fully match `model.layers.1.self_attn.q_proj` (missing last `j`)
         # this layer is not expected to be quantized to int8
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -273,9 +273,9 @@ def test_fqn_to_config_module_regex_precedence(self):
         # highest precedence is fully specified module fqn
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
         # second precedence: regex
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
         # last precedence: _default
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -302,8 +302,8 @@ def test_fqn_to_config_regex_precedence(self):
             torch_dtype=torch.bfloat16,
         )
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -329,8 +329,8 @@ def test_fqn_to_config_param_over_module_regex_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -356,8 +356,8 @@ def test_fqn_to_config_param_over_module_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, TorchAOBaseTensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -383,8 +383,8 @@ def test_fqn_to_config_exact_over_regex_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
         self.assertTrue(isinstance(quantized_model.model.layers[2].self_attn.q_proj.weight, Float8Tensor))

         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
@@ -418,7 +418,7 @@ def test_fqn_to_config_non_weight_param(self):
         self.assertTrue(
             not isinstance(quantized_model.model.layers[0].feed_forward.experts.gate_up_proj, Float8Tensor)
         )
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))

     def test_compute_module_sizes(self):
         r"""
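The refactor rests on the fact that torchao's quantized tensor subclasses (`AffineQuantizedTensor`, `Float8Tensor`, and, to our knowledge, `LinearActivationQuantizedTensor` as well) all derive from the common base class `TorchAOBaseTensor`, so a single `isinstance` check replaces the per-subclass dispatch that `_quantization_type` used to do. A minimal sketch, outside the patch, of what the new check reduces to; the `describe_weight` helper name is illustrative only:

```python
from torchao.utils import TorchAOBaseTensor


def describe_weight(weight):
    # One isinstance check against the shared base class covers every
    # torchao tensor subclass; the repr then just reports the concrete
    # subclass name (e.g. "Float8Tensor") instead of reimplementing
    # per-subclass formatting.
    return weight.__class__.__name__ if isinstance(weight, TorchAOBaseTensor) else None
```

The trade-off is that the extra repr now shows only the subclass name rather than the detailed quantization parameters the old `_quantization_type` helper embedded, which is why the `weight is None` fallback branch in `_linear_extra_repr` is unchanged.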