diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 0fcfffa07..0d865b541 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -370,25 +370,6 @@ def quantize_4bit_impl( quant_type=quant_type, ) - if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4": - # lowp_mode: lowest precision for computation - lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 - state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( - out.reshape([input_shape[0], input_shape[1] // 2]), - ipex_cpu.quantization.WoqWeightDtype.NF4, - input_shape, # weight shape - absmax.view(input_shape[0], input_shape[1] // blocksize), # scales - None, # zero_points - None, # bias - None, # g_idx - None, # batch_size - blocksize, - int(lowp_mode), - -1, # act_quant_mode. -1 means don't quant activation - ) - state.absmax = torch.Tensor() - return torch.empty([1, 0], dtype=torch.uint8), state - return out.unsqueeze(0), state diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2348d0791..ad424a6f4 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -19,6 +19,7 @@ INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, + enable_ipex_fusion, ) T = TypeVar("T", bound="torch.nn.Module") @@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ + if ( + getattr(self.weight, "quant_state", None) is not None + and getattr(self.weight.quant_state, "op_context", None) is not None + ): + context = self.weight.quant_state.op_context + self.weight.data = context.to_public(context.get_weight()).reshape([1, -1]) + super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: + if ( + self.weight.quant_state.absmax.shape.numel() == 0 + and getattr(self.weight.quant_state, "op_context", None) is not None + ): + self.weight.quant_state.absmax = context.get_scales().reshape(-1) + delattr(self.weight.quant_state, "op_context") for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." + k] = v if keep_vars else v.detach() - if getattr(self.weight.quant_state, "op_context", None) is not None: - context = self.weight.quant_state.op_context - destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1) - self.weight.data = context.to_public(context.get_weight()).reshape([1, -1]) def forward(self, x: torch.Tensor): + # Check if ipex fusion can be used + if ( + x.device.type == "cpu" + and not hasattr(self.weight.quant_state, "op_context") + and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 + and self.weight.quant_state.quant_type == "nf4" + ): + enable_ipex_fusion(self.weight, self.weight.quant_state) + # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index fa9a7eb70..9e52c915d 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict +def enable_ipex_fusion(weight, quant_state): + from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq + + if _ipex_cpu_version_prereq(2, 3): + import intel_extension_for_pytorch as ipex + + lowp_mode = ipex.quantization.WoqLowpMode.BF16 + quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( + weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + ipex.quantization.WoqWeightDtype.NF4, + quant_state.shape, # weight shape + quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales + None, # zero_points + None, # bias + None, # g_idx + None, # batch_size + quant_state.blocksize, + int(lowp_mode), + -1, # act_quant_mode. -1 means don't quant activation + ) + quant_state.absmax = torch.Tensor() + weight.data = torch.empty([1, 0], dtype=torch.uint8) + + class QuantState: """container for quantization state components to work with Params4bit and similar classes"""