From 7639d27be41ac42d6e47950d707e081587865fe2 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 17 Dec 2024 11:22:37 +0100
Subject: [PATCH] update

---
 tests/quantization/torchao/test_torchao.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 5c71fc4e0ae7..58c1d3613daf 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -228,8 +228,7 @@ def test_quantization(self):
             ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
-            ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
-            ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]
 
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
@@ -253,8 +252,8 @@ def test_quantization(self):
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
             quant_kwargs = {}
-            if quantization_name in ["uint4wo", "uint_a16w7"]:
-                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+            if quantization_name in ["uint4wo", "uint7wo"]:
+                # The dummy Flux model that we use has smaller dimensions than a real checkpoint, which imposes some restrictions on group_size here
                 quant_kwargs.update({"group_size": 16})
             quantization_config = TorchAoConfig(
                 quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
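
For reference, the configuration pattern exercised in this test mirrors real-world usage: extra kwargs passed to TorchAoConfig (such as group_size) are forwarded to the underlying torchao quantization function. Below is a minimal sketch of that usage outside the test suite, assuming diffusers is installed with torchao available; the checkpoint id and dtype are illustrative choices, not taken from this patch.

    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    # "uint4wo" is one of the quant types exercised in the test above. The
    # dummy test model needs group_size=16 because its hidden dimensions are
    # small; a full-size checkpoint could use the torchao default instead.
    quantization_config = TorchAoConfig(
        quant_type="uint4wo",
        modules_to_not_convert=["x_embedder"],  # keep the input embedder unquantized
        group_size=16,
    )

    # Illustrative checkpoint id; substitute any Flux-style transformer.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )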