15 changes: 3 additions & 12 deletions src/transformers/integrations/torchao.py
@@ -35,19 +35,10 @@
 logger = logging.get_logger(__name__)
 
 
-def _quantization_type(weight):
-    from torchao.dtypes import AffineQuantizedTensor
-    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
-
-    if isinstance(weight, AffineQuantizedTensor):
-        return f"{weight.__class__.__name__}({weight._quantization_type()})"
-
-    if isinstance(weight, LinearActivationQuantizedTensor):
-        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
-
-
 def _linear_extra_repr(self):
-    weight = _quantization_type(self.weight)
+    from torchao.utils import TorchAOBaseTensor
+
+    weight = self.weight.__class__.__name__ if isinstance(self.weight, TorchAOBaseTensor) else None
     if weight is None:
         return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
     else:
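A note on why the simplification above holds (my inference, not stated in the diff): torchao's quantized weight subclasses, including AffineQuantizedTensor, LinearActivationQuantizedTensor, and Float8Tensor, all derive from the common base class TorchAOBaseTensor, so one isinstance check against the base covers every case the deleted helper matched individually, plus any subclasses added later. A minimal sketch, assuming a recent torchao build and hardware that supports float8 dynamic quantization:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.utils import TorchAOBaseTensor

# Quantize a single linear layer in place.
linear = torch.nn.Linear(128, 128, dtype=torch.bfloat16)
quantize_(linear, Float8DynamicActivationFloat8WeightConfig())

# The quantized weight is now a tensor subclass deriving from TorchAOBaseTensor,
# so the simplified repr helper can report its class name with a single check.
assert isinstance(linear.weight, TorchAOBaseTensor)
print(linear.weight.__class__.__name__)  # e.g. "Float8Tensor"

With this change, printing a quantized model shows entries like weight=Float8Tensor rather than the old per-subclass detail string, and falls back to weight=None for unquantized layers.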
34 changes: 17 additions & 17 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -36,9 +36,6 @@
 import torch
 
 if is_torchao_available():
-    from torchao.dtypes import (
-        AffineQuantizedTensor,
-    )
     from torchao.quantization import (
         Float8DynamicActivationFloat8WeightConfig,
         Float8Tensor,
@@ -52,6 +49,9 @@
         MappingType,
         PerAxis,
     )
+    from torchao.utils import (
+        TorchAOBaseTensor,
+    )
 
 
 @require_torchao
@@ -191,7 +191,7 @@ def test_per_module_config_skip(self):
             torch_dtype=torch.bfloat16,
         )
         # making sure `model.layers.0.self_attn.q_proj` is skipped
-        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -211,7 +211,7 @@ def test_fqn_to_config_regex_basic(self):
             torch_dtype=torch.bfloat16,
         )
         # making sure `model.layers.0.self_attn.q_proj` is skipped
-        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -244,7 +244,7 @@ def test_fqn_to_config_regex_fullmatch(self):
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
         # because regex `model\.layers\.+*\.self_attn\.q_pro` didn't fully match `model.layers.1.self_attn.q_proj` (missing last `j`)
         # this layer is not expected to be quantized to int8
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -273,9 +273,9 @@ def test_fqn_to_config_module_regex_precedence(self):
         # highest precedence is fully specified module fqn
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
         # second precedence: regex
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
         # last precedence: _default
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -302,8 +302,8 @@ def test_fqn_to_config_regex_precedence(self):
             torch_dtype=torch.bfloat16,
         )
         self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -329,8 +329,8 @@ def test_fqn_to_config_param_over_module_regex_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -356,8 +356,8 @@ def test_fqn_to_config_param_over_module_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, TorchAOBaseTensor))
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
@@ -383,8 +383,8 @@ def test_fqn_to_config_exact_over_regex_precedence(self):
             quantization_config=quant_config,
             torch_dtype=torch.bfloat16,
         )
-        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, TorchAOBaseTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
         self.assertTrue(isinstance(quantized_model.model.layers[2].self_attn.q_proj.weight, Float8Tensor))
 
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
@@ -418,7 +418,7 @@ def test_fqn_to_config_non_weight_param(self):
         self.assertTrue(
             not isinstance(quantized_model.model.layers[0].feed_forward.experts.gate_up_proj, Float8Tensor)
         )
-        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, TorchAOBaseTensor))
 
     def test_compute_module_sizes(self):
         r"""
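For reference, a hypothetical sketch of the kind of fqn-to-config mapping these precedence tests exercise; the model name and fqn keys below are illustrative, not taken from the test file. Per the assertions above: an exact module fqn wins over a regex key, a regex key wins over the "_default" fallback, and mapping an fqn to None leaves that module unquantized. The "re:" prefix for regex keys follows torchao's ModuleFqnToConfig convention as I understand it:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, ModuleFqnToConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig(
    quant_type=ModuleFqnToConfig(
        {
            # Highest precedence: a fully specified module fqn.
            "model.layers.3.self_attn.q_proj": Float8DynamicActivationFloat8WeightConfig(),
            # Second precedence: regex keys, which must fully match the fqn.
            r"re:model\.layers\.\d+\.self_attn\.q_proj": None,  # None -> skip quantization
            # Last precedence: fallback for every module not matched above.
            "_default": Float8DynamicActivationFloat8WeightConfig(),
        }
    )
)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # illustrative; any model with model.layers.*.self_attn fqns
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

After loading, isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, TorchAOBaseTensor) should hold, while layer 1's q_proj stays a plain torch.Tensor, mirroring the checks in the tests above.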