diff --git a/tests/test_functional.py b/tests/test_functional.py index fc37cb4c3..34d3e8412 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1169,8 +1169,12 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): 4096: 0.262457, } - assert err < error_dict[quant_type]["err"][blocksize] + 1e-3 - assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3 + # Allow higher tolerance for fp32 on CPU with larger block sizes + reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3 + errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3 + + assert err < error_dict[quant_type]["err"][blocksize] + errtol + assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])