From e9dd6dd47a29cacba04d382ca78361a6a7614e09 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 5 Sep 2024 09:56:49 -0400
Subject: [PATCH 1/3] fix nf4 memory issue by init op_context in forward

---
 bitsandbytes/nn/modules.py | 18 ++++++++++++++----
 bitsandbytes/utils.py      | 22 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 2348d0791..502c78a3c 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -19,6 +19,7 @@
     INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
+    enable_ipex_fusion,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,26 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
+        if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "op_context", None) is not None:
+            context = self.weight.quant_state.op_context
+            self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
+
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
+            if self.weight.quant_state.absmax.shape.numel() == 0 and getattr(self.weight.quant_state, "op_context", None) is not None:
+                self.weight.quant_state.absmax = context.get_scales().reshape(-1)
+                delattr(self.weight.quant_state, "op_context")
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
-            if getattr(self.weight.quant_state, "op_context", None) is not None:
-                context = self.weight.quant_state.op_context
-                destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
-                self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
 
     def forward(self, x: torch.Tensor):
+        # Check if ipex fusion can be used
+        if x.device.type == "cpu" and not hasattr(self.weight.quant_state, "op_context") and \
+            self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and \
+            self.weight.quant_state.quant_type == "nf4":
+            enable_ipex_fusion(self.weight, self.weight.quant_state)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index fa9a7eb70..8460602fe 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -200,6 +200,28 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
+def enable_ipex_fusion(weight, quant_state):
+    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
+    if _ipex_cpu_version_prereq(2, 3):
+        import intel_extension_for_pytorch as ipex
+        lowp_mode = ipex.quantization.WoqLowpMode.BF16
+        quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            ipex.quantization.WoqWeightDtype.NF4,
+            quant_state.shape,  # weight shape
+            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+            None,  # zero_points
+            None,  # bias
+            None,  # g_idx
+            None,  # batch_size
+            quant_state.blocksize,
+            int(lowp_mode),
+            -1,  # act_quant_mode. -1 means don't quant activation
+        )
+        quant_state.absmax = torch.Tensor()
+        weight.data = torch.empty([1, 0], dtype=torch.uint8)
+
+
 class QuantState:
     """container for quantization state components to work with Params4bit and similar classes"""
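
The core idea of this first patch is to defer the IPEX weight prepack from quantization time to the first CPU forward, so that once the op_context owns the packed weight and scales, only one resident copy survives. Below is a minimal, self-contained sketch of that lazy-prepack pattern; the class and method names are illustrative stand-ins, not bitsandbytes or IPEX API:

```python
import torch


class LazyPrepackLinear(torch.nn.Module):
    """Toy stand-in for the pattern used in the patch: keep the portable payload
    at init and build the backend-specific handle only on the first forward."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.op_context = None  # backend-specific packed handle, created lazily

    def _prepack(self) -> None:
        # Stand-in for torch.ops.ipex_prepack.weight_only_qlinear_prepack(...):
        # build the kernel's private layout once, then drop the duplicate payload
        # so the weight is not held twice (the "memory issue" being fixed).
        self.op_context = self.weight.detach().t().contiguous()
        self.weight.data = torch.empty(1, 0)  # mirrors weight.data = torch.empty([1, 0], ...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same role as the hasattr(self.weight.quant_state, "op_context") check in forward().
        if self.op_context is None:
            self._prepack()
        return x @ self.op_context


layer = LazyPrepackLinear(torch.randn(8, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8]); prepack happened on this call
```
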
From 461a5401ad3e32e794ac78619b11162e1510be6a Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 9 Sep 2024 12:05:19 -0400
Subject: [PATCH 2/3] disable repack in init

---
 bitsandbytes/backends/cpu_xpu_common.py | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 0fcfffa07..0d865b541 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
-        # lowp_mode: lowest precision for computation
-        lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
-        state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-            out.reshape([input_shape[0], input_shape[1] // 2]),
-            ipex_cpu.quantization.WoqWeightDtype.NF4,
-            input_shape,  # weight shape
-            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
-            None,  # zero_points
-            None,  # bias
-            None,  # g_idx
-            None,  # batch_size
-            blocksize,
-            int(lowp_mode),
-            -1,  # act_quant_mode. -1 means don't quant activation
-        )
-        state.absmax = torch.Tensor()
-        return torch.empty([1, 0], dtype=torch.uint8), state
-
     return out.unsqueeze(0), state
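
With the eager prepack removed from quantize_4bit_impl, CPU NF4 quantization should now hand back the ordinary packed uint8 payload and a QuantState whose absmax is still populated; the IPEX op_context only appears later, on the first Linear4bit.forward. A quick check along these lines, assuming the multi-backend CPU build of bitsandbytes (the tensor size below is a placeholder):

```python
import torch

import bitsandbytes.functional as F

w = torch.randn(1024, 1024, dtype=torch.float32, device="cpu")
packed, state = F.quantize_4bit(w, blocksize=64, quant_type="nf4")

print(packed.dtype)                  # torch.uint8 packed payload, not an empty [1, 0] stub
print(state.absmax.numel() > 0)      # True: the scales stay in the QuantState
print(hasattr(state, "op_context"))  # False: prepacking is deferred to the first CPU forward
```
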
From cfb7663953084295d8b3d243087a9e3deaad3bf9 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 12 Sep 2024 05:07:38 -0400
Subject: [PATCH 3/3] fix code style

---
 bitsandbytes/nn/modules.py | 19 ++++++++++++++-----
 bitsandbytes/utils.py      |  2 ++
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 502c78a3c..ad424a6f4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -445,14 +445,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
-        if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "op_context", None) is not None:
+        if (
+            getattr(self.weight, "quant_state", None) is not None
+            and getattr(self.weight.quant_state, "op_context", None) is not None
+        ):
             context = self.weight.quant_state.op_context
             self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
-            if self.weight.quant_state.absmax.shape.numel() == 0 and getattr(self.weight.quant_state, "op_context", None) is not None:
+            if (
+                self.weight.quant_state.absmax.shape.numel() == 0
+                and getattr(self.weight.quant_state, "op_context", None) is not None
+            ):
                 self.weight.quant_state.absmax = context.get_scales().reshape(-1)
                 delattr(self.weight.quant_state, "op_context")
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
@@ -460,9 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
 
     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used
-        if x.device.type == "cpu" and not hasattr(self.weight.quant_state, "op_context") and \
-            self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and \
-            self.weight.quant_state.quant_type == "nf4":
+        if (
+            x.device.type == "cpu"
+            and not hasattr(self.weight.quant_state, "op_context")
+            and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
+            and self.weight.quant_state.quant_type == "nf4"
+        ):
             enable_ipex_fusion(self.weight, self.weight.quant_state)
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 8460602fe..9e52c915d 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -202,8 +202,10 @@ def unpack_tensor_to_dict(tensor_data):
 
 def enable_ipex_fusion(weight, quant_state):
     from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
+
     if _ipex_cpu_version_prereq(2, 3):
         import intel_extension_for_pytorch as ipex
+
         lowp_mode = ipex.quantization.WoqLowpMode.BF16
         quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
             weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
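
Taken together, the series aims to keep checkpoints in the standard bitsandbytes NF4 layout even after IPEX fusion has kicked in: _save_to_state_dict unpacks the weight from the op_context and restores absmax before serializing. A rough end-to-end sketch of that round trip, assuming the multi-backend CPU build with intel_extension_for_pytorch >= 2.3; the layer size and the .to("cpu") quantization step are illustrative, not taken from the patches:

```python
import torch

import bitsandbytes as bnb

linear = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4")
linear = linear.to("cpu")          # assumed to quantize the weight on the CPU backend

linear(torch.randn(2, 4096))       # first CPU forward is expected to call enable_ipex_fusion
qs = linear.weight.quant_state
print(hasattr(qs, "op_context"))   # True when the IPEX >= 2.3 prerequisite is met
print(qs.absmax.numel())           # 0: the scales now live inside the op_context

sd = linear.state_dict()           # weight is unpacked back into the public NF4 format
print(sd["weight"].dtype)          # torch.uint8 packed payload, not the IPEX-private layout
print("weight.absmax" in sd)       # True: scales restored from the op_context for saving
```
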