From 639f8c05a4fac7c763a6e055ee59a5698de0a7a7 Mon Sep 17 00:00:00 2001 From: Mohamed Hisham Date: Sat, 2 Aug 2025 03:14:41 +0300 Subject: [PATCH] Fixing quantization uint8 packing bug for NF4 and FP4 --- csrc/kernels.cu | 11 +++----- tests/test_functional.py | 61 ++++++++++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 649f2ee1f..97b80f050 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise( LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - unsigned char packed_4bit = 0; switch (DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH @@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise( case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { - packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } diff --git a/tests/test_functional.py b/tests/test_functional.py index b84db6502..fc37cb4c3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # With larger block sizes, we can expect this to blow up. # At blocksize>=1024, don't even bother looking at relerr. 
- if blocksize <= 64: - assert err.item() < 0.1 - assert relerr.item() < 0.28 - elif blocksize <= 256: - assert err.item() < 0.11 - assert relerr.item() < 0.30 - elif blocksize <= 512: - assert err.item() < 0.12 - assert relerr.item() < 0.31 - elif quant_type == "fp4": - # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 - assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 - else: - # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 - assert err.item() < math.log2(blocksize) * 8e-2 + # + # NOTE: the statement above no longer holds now that the 4-bit integer packing bug is fixed. + # The values below are mean errors over 1k samples per test configuration, measured after the fix; the asserts allow a 1e-3 margin on top of them. + error_dict = dict() + error_dict["fp4"] = dict() + error_dict["nf4"] = dict() + error_dict["fp4"]["err"] = { + 64: 0.096545, + 128: 0.102947, + 256: 0.108685, + 512: 0.114087, + 1024: 0.119312, + 2048: 0.124460, + 4096: 0.129573, + } + error_dict["fp4"]["rel_err"] = { + 64: 0.260130, + 128: 0.275734, + 256: 0.289842, + 512: 0.302852, + 1024: 0.314982, + 2048: 0.326402, + 4096: 0.337228, + } + + error_dict["nf4"]["err"] = { + 64: 0.072792, + 128: 0.076835, + 256: 0.080326, + 512: 0.083535, + 1024: 0.086603, + 2048: 0.089592, + 4096: 0.092537, + } + error_dict["nf4"]["rel_err"] = { + 64: 0.203299, + 128: 0.215252, + 256: 0.226044, + 512: 0.236021, + 1024: 0.245365, + 2048: 0.254146, + 4096: 0.262457, + } + + assert err < error_dict[quant_type]["err"][blocksize] + 1e-3 + assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])