19 changes: 0 additions & 19 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
# lowp_mode: lowest precision for computation
lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
out.reshape([input_shape[0], input_shape[1] // 2]),
ipex_cpu.quantization.WoqWeightDtype.NF4,
input_shape, # weight shape
absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)
state.absmax = torch.Tensor()
return torch.empty([1, 0], dtype=torch.uint8), state

return out.unsqueeze(0), state


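With the block above removed, quantize_4bit_impl no longer prepacks NF4 weights for IPEX at quantization time; the fused path is instead set up lazily on the first CPU forward (see enable_ipex_fusion in bitsandbytes/utils.py further down). Below is a minimal sketch of the kind of availability check such a deferred path depends on; ipex_woq_available is a hypothetical helper, not part of this diff:

# Hypothetical helper (not part of the PR): report whether the optional IPEX
# dependency is installed and new enough for weight-only-quantization prepack.
def ipex_woq_available(min_major: int = 2, min_minor: int = 3) -> bool:
    try:
        import intel_extension_for_pytorch as ipex  # optional dependency
    except ImportError:
        return False
    major, minor = (int(v) for v in ipex.__version__.split(".")[:2])
    return (major, minor) >= (min_major, min_minor)
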
27 changes: 23 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -19,6 +19,7 @@
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
enable_ipex_fusion,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
save weight and bias,
then fill state_dict with components of quant_state
"""
if (
getattr(self.weight, "quant_state", None) is not None
and getattr(self.weight.quant_state, "op_context", None) is not None
):
context = self.weight.quant_state.op_context
self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias

if getattr(self.weight, "quant_state", None) is not None:
if (
self.weight.quant_state.absmax.shape.numel() == 0
and getattr(self.weight.quant_state, "op_context", None) is not None
):
self.weight.quant_state.absmax = context.get_scales().reshape(-1)
delattr(self.weight.quant_state, "op_context")
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
if getattr(self.weight.quant_state, "op_context", None) is not None:
context = self.weight.quant_state.op_context
destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
if (
x.device.type == "cpu"
and not hasattr(self.weight.quant_state, "op_context")
and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
and self.weight.quant_state.quant_type == "nf4"
):
enable_ipex_fusion(self.weight, self.weight.quant_state)

# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
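Taken together, the changes above defer the IPEX weight-only-quantization prepack to the first forward pass on CPU. The following is a hypothetical end-to-end usage sketch, assuming a bitsandbytes build with the CPU/XPU backend in which Params4bit weights are quantized when the module is moved to the CPU device (that behavior is an assumption, not shown in this diff):

import torch
import bitsandbytes as bnb

# Assumption: this build quantizes Params4bit to NF4 when the module is moved to "cpu".
layer = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4",
                          compute_dtype=torch.bfloat16).to("cpu")

x = torch.randn(1, 4096, dtype=torch.bfloat16)
y = layer(x)   # first CPU forward: enable_ipex_fusion() prepacks the weight
               # and attaches quant_state.op_context
y = layer(x)   # later calls reuse the fused IPEX kernel via op_context

Saving the module afterwards still works, because _save_to_state_dict (hunk above) recovers the packed weight and absmax from op_context before writing the state dict.
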
24 changes: 24 additions & 0 deletions bitsandbytes/utils.py
@@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def enable_ipex_fusion(weight, quant_state):
from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq

if _ipex_cpu_version_prereq(2, 3):
import intel_extension_for_pytorch as ipex

lowp_mode = ipex.quantization.WoqLowpMode.BF16
quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
ipex.quantization.WoqWeightDtype.NF4,
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
quant_state.blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)
quant_state.absmax = torch.Tensor()
weight.data = torch.empty([1, 0], dtype=torch.uint8)


class QuantState:
"""container for quantization state components to work with Params4bit and similar classes"""

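For callers outside Linear4bit.forward, the new helper can be invoked with the same guard the forward pass uses. A minimal sketch, assuming `weight` is a CPU Params4bit whose quant_state was produced with quant_type="nf4":

from bitsandbytes.utils import enable_ipex_fusion

qs = weight.quant_state
if (
    weight.device.type == "cpu"
    and not hasattr(qs, "op_context")
    and qs.quant_type == "nf4"
    and qs.shape[1] % qs.blocksize == 0
):
    # Replaces weight.data and qs.absmax with an IPEX op_context that holds
    # the prepacked weight and per-block scales.
    enable_ipex_fusion(weight, qs)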