From b8a8a1789a695a8b9982ab6b8b9779187cf0d152 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Wed, 16 Jul 2025 09:22:15 +0000 Subject: [PATCH 1/5] Implemented 32bit optimizers in triton --- bitsandbytes/_ops.py | 4 +- bitsandbytes/backends/triton/kernels_optim.py | 594 ++++++++++++++++++ bitsandbytes/backends/triton/ops.py | 48 +- bitsandbytes/backends/xpu/ops.py | 1 + bitsandbytes/functional.py | 2 +- tests/test_optim.py | 3 + 6 files changed, 648 insertions(+), 4 deletions(-) mode change 100644 => 100755 bitsandbytes/_ops.py create mode 100644 bitsandbytes/backends/triton/kernels_optim.py mode change 100644 => 100755 bitsandbytes/backends/triton/ops.py diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py old mode 100644 new mode 100755 index e47e6f436..38ec62988 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -352,7 +352,7 @@ def _( torch.library.define( "bitsandbytes::optimizer_update_32bit", - "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()", + "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()", ) @@ -371,9 +371,9 @@ def _( beta3: float, alpha: float, eps: float, - weight_decay: float, step: int, lr: float, + weight_decay: float, gnorm_scale: float, skip_zeros=False, ) -> None: diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py new file mode 100644 index 000000000..5be5553cd --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -0,0 +1,594 @@ +import math +from typing import Optional + +import torch + +import triton +import triton.language as tl +# from triton.language.extra import libdevice + +########################################### +# Pure torch implementation for reference # +########################################### + +@torch.compile +def optimizer_update_32bit_impl_torch( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + Torch实现的32位优化器,用于性能对比 + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + # 应用梯度缩放 + g_scaled = gnorm_scale * g + + update_scale = 1.0 + + # 根据优化器类型进行参数更新 + if optimizer_name == "adam": + # 更新状态 + state1.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) + state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) + + # 计算修正因子 + correction1 = 1.0 - beta1 ** step + correction2_sqrt = math.sqrt(1.0 - beta2 ** step) + + # 计算 unorm + if max_unorm > 0.0 and unorm_vec is not None: + # AdamW: unorm is computed from corrected first moment, but before weight decay + # See https://github.com/NVIDIA/apex/blob/22603ab6109e346438b8bc439427f8791055b416/apex/optimizers/fused_adam.py#L265 + s1_corrected = state1 / correction1 + update_vals = s1_corrected / (torch.sqrt(state2) + eps) + update_norm = 
torch.sum(update_vals * update_vals) + unorm_vec.fill_(update_norm) + current_unorm = torch.sqrt(update_norm) + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + # 应用权重衰减 (decoupled weight decay) + if weight_decay > 0.0: + p.mul_(1.0 - lr * weight_decay) + + # 更新参数 + step_size = -lr * correction2_sqrt / correction1 + update_val = state1 / (torch.sqrt(state2) + eps * correction2_sqrt) + p.add_(update_val, alpha=update_scale * step_size) + + elif optimizer_name == "ademamix": + s1_vals = state1[0] + s3_vals = state1[1] + + # 更新状态 + s1_vals.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) + s3_vals.mul_(beta3).add_(g_scaled, alpha=1.0 - beta3) + state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) + + # 计算修正因子 + correction1 = 1.0 - beta1 ** step + correction2_sqrt = math.sqrt(1.0 - beta2 ** step) + + # 计算更新值 + numerator = (s1_vals / correction1) + (alpha * s3_vals) + denominator = (torch.sqrt(state2) / correction2_sqrt) + eps + update_vals = numerator / denominator + + if max_unorm > 0.0 and unorm_vec is not None: + update_norm = torch.sum(update_vals * update_vals) + unorm_vec.fill_(update_norm) + current_unorm = torch.sqrt(update_norm) + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + # 应用权重衰减 + if weight_decay > 0.0: + p.mul_(1.0 - lr * weight_decay) + + # 更新参数 + p.add_(update_vals, alpha=-lr * update_scale) + + elif optimizer_name in ["momentum", "rmsprop", "adagrad", "lion"]: + # 这些优化器的 weight_decay 是耦合的 + g_with_decay = g_scaled + if weight_decay > 0.0: + g_with_decay = g_with_decay.add(p, alpha=weight_decay) + + if optimizer_name == "momentum": + state1.mul_(beta1).add_(g_with_decay) + update_vals = state1 + elif optimizer_name == "rmsprop": + state1.mul_(beta1).addcmul_(g_with_decay, g_with_decay, value=1.0 - beta1) + update_vals = g_with_decay / (torch.sqrt(state1) + eps) + elif optimizer_name == "adagrad": + state1.addcmul_(g_with_decay, g_with_decay, value=1.0) + update_vals = g_with_decay / (torch.sqrt(state1) + eps) + elif optimizer_name == "lion": + # Lion 更新: c = sign(beta1 * m + (1-beta1) * g) + # p = p - lr * c + # m = beta2 * m + (1-beta2) * g + momentum_update = state1.mul(beta1).add(g_with_decay, alpha=1.0 - beta1) + update_vals = torch.sign(momentum_update) + state1.mul_(beta2).add_(g_with_decay, alpha=1.0 - beta2) + + # 计算 unorm + if max_unorm > 0.0 and unorm_vec is not None: + # 对于Lion, unorm是基于更新后的动量计算的 + unorm_calc_source = state1 if optimizer_name == "lion" else update_vals + update_norm = torch.sum(unorm_calc_source * unorm_calc_source) + unorm_vec.fill_(update_norm) + current_unorm = torch.sqrt(update_norm) + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + # 更新参数 + if optimizer_name == "lion": + p.add_(update_vals, alpha=-lr * update_scale) + else: + p.add_(update_vals, alpha=-lr * update_scale) + + else: + raise ValueError(f"Unsupported optimizer: {optimizer_name}") + +######################### +# Triton implementation # +######################### + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + + +@triton.jit +def _optimizer_precondition_2state_32bit( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + beta1: 
tl.constexpr, + beta2: tl.constexpr, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """预处理优化器,计算更新范数(2状态优化器)""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + # 加载梯度和状态 + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) + + # 应用梯度缩放 + g_vals = gnorm_scale * g_vals + + # 计算修正因子 + correction1 = 1.0 / (1.0 - beta1_step) + correction2 = 1.0 / (1.0 - beta2_step) + + # 根据优化器类型更新状态 + if OPTIMIZER_ID == 3: # ADAM + # 更新动量 + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + # 应用修正 + s1_vals = s1_vals * correction1 + s2_vals = s2_vals * correction2 + + # 计算更新值 + update_vals = s1_vals / (tl.sqrt(s2_vals) + eps) + # 计算更新范数 + update_norm = update_vals * update_vals + + elif OPTIMIZER_ID == 5: # ADEMAMIX + update_norm = s1_vals + + # 累加更新范数 + total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) + + # 原子加到全局范数 + tl.atomic_add(unorm_ptr, total_norm) + + +@triton.jit +def _optimizer_precondition_1state_32bit( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + eps: tl.constexpr, + weight_decay, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """预处理优化器,计算更新范数(1状态优化器)""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + # 加载梯度和状态 + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + + # 应用梯度缩放 + g_vals = gnorm_scale * g_vals + + if OPTIMIZER_ID == 0: # MOMENTUM + # 更新动量 + if step == 1: + s1_vals = g_vals + else: + s1_vals = s1_vals * beta1 + g_vals + update_norm = s1_vals * s1_vals + + elif OPTIMIZER_ID == 4: # LION + # LION 只更新状态,不计算范数 + s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals + # update_norm = tl.zeros_like(g_vals) + update_norm = s1_vals + + elif OPTIMIZER_ID == 1: # RMSPROP + # 更新RMS状态 + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals + update_vals = g_vals / (tl.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + elif OPTIMIZER_ID == 2: # ADAGRAD + # 更新累积梯度平方 + s1_vals = s1_vals + g_vals * g_vals + update_vals = g_vals / (tl.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + # 累加更新范数 + total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) + + # 原子加到全局范数 + tl.atomic_add(unorm_ptr, total_norm) + + +@triton.jit +def _optimizer_update_2state_32bit_triton_kernel( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + max_unorm: tl.constexpr, + param_norm, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + skip_zeros, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """2状态优化器内核""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + # 加载数据 + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) + + # 对于ADEMAMIX,需要加载额外的状态 + if OPTIMIZER_ID == 5: # ADEMAMIX + s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0) + + # 应用梯度缩放 + g_vals = gnorm_scale * g_vals + + # 计算更新缩放因子 + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = tl.sqrt(tl.load(unorm_ptr)) + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + # 根据优化器类型进行更新 + if OPTIMIZER_ID == 3: # ADAM + # 更新状态 + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + # 计算修正因子 + correction1 = 1.0 - beta1_step + correction2 = tl.sqrt(1.0 - beta2_step) + step_size = -lr * correction2 / correction1 + + # 应用权重衰减 + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + # 更新参数 + update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2)) + p_vals = p_vals + update_val + + elif OPTIMIZER_ID == 5: # ADEMAMIX + # 更新状态 + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals # m1 + s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals # m2 + s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals # nu + + # 计算修正因子 + correction1 = 1.0 - beta1_step + correction2 = tl.sqrt(1.0 - beta2_step) + + # 应用权重衰减 + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + # 更新参数 + numerator = (s1_vals / correction1) + (alpha * s3_vals) + denominator = (tl.sqrt(s2_vals) / correction2) + eps + p_vals = p_vals - lr * (numerator / denominator) + + # 存储更新后的值 + tl.store(p_ptr + offsets, p_vals, mask=mask) + tl.store(state1_ptr + offsets, s1_vals, mask=mask) + tl.store(state2_ptr + offsets, s2_vals, mask=mask) + + # 对于ADEMAMIX,存储额外状态 + if OPTIMIZER_ID == 5: # ADEMAMIX + tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask) + + +@triton.jit +def _optimizer_update_1state_32bit_triton_kernel( + g_ptr, + p_ptr, + state1_ptr, + state2_ptr, + unorm_ptr, + max_unorm: tl.constexpr, + param_norm, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + weight_decay: tl.constexpr, + step, + beta1_step, + beta2_step, + lr, + gnorm_scale: tl.constexpr, + skip_zeros, + n_elements, + OPTIMIZER_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + """1状态优化器内核""" + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) + mask = offsets < n_elements + + # 加载数据 + g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) + + # 应用梯度缩放和权重衰减 + g_vals = gnorm_scale * g_vals + if weight_decay > 0.0: + g_vals = g_vals + p_vals * weight_decay + + # 计算更新缩放因子 + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = tl.sqrt(tl.load(unorm_ptr)) + if current_unorm > max_unorm * param_norm + eps: + update_scale = (max_unorm * param_norm + eps) / current_unorm + + # 根据优化器类型进行更新 + if OPTIMIZER_ID == 0: # MOMENTUM + # 更新动量 + if step == 1: + s1_vals = g_vals + else: + s1_vals = s1_vals * beta1 + g_vals + + # 更新参数 + update_val = update_scale * (-lr * s1_vals) 
+ p_vals = p_vals + update_val + + elif OPTIMIZER_ID == 4: # LION + # LION 优化器 + momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals + update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0)) + p_vals = p_vals - update_val + + # 更新动量状态 + s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals + + elif OPTIMIZER_ID == 1: # RMSPROP + # 更新RMS状态 + s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals + + # 更新参数 + update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + elif OPTIMIZER_ID == 2: # ADAGRAD + # 更新累积梯度平方 + s1_vals = s1_vals + g_vals * g_vals + + # 更新参数 + update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + # 存储更新后的值 + tl.store(p_ptr + offsets, p_vals, mask=mask) + tl.store(state1_ptr + offsets, s1_vals, mask=mask) + + +name2optimizer_32bit_fn = { + "adam": { + "preprocess": _optimizer_precondition_2state_32bit, + "update": _optimizer_update_2state_32bit_triton_kernel, + }, + "ademamix": { + "preprocess": _optimizer_precondition_2state_32bit, + "update": _optimizer_update_2state_32bit_triton_kernel, + }, + "momentum": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "rmsprop": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "adagrad": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, + "lion": { + "preprocess": _optimizer_precondition_1state_32bit, + "update": _optimizer_update_1state_32bit_triton_kernel, + }, +} + + +def optimizer_update_32bit_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + Triton实现的32位优化器 + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + BLOCK_SIZE = 256 + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + optimizer_id = name2optimizer_id[optimizer_name] + fn_preprocess = name2optimizer_32bit_fn[optimizer_name]["preprocess"] + fn_update = name2optimizer_32bit_fn[optimizer_name]["update"] + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. 
+ beta1_step = beta1**step + beta2_step = beta2**step + + # lion特殊处理 + if optimizer_name == "lion": + fn_update[grid]( + g, p, state1, state2, unorm_vec, max_unorm, param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + beta1_step, beta2_step, lr, gnorm_scale, skip_zeros, + p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2, + ) + + if max_unorm > 0.0: + unorm_vec.zero_() + fn_preprocess[grid]( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + beta1_step, beta2_step, lr, gnorm_scale, + p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2, + ) + + else: + if max_unorm > 0.0: + unorm_vec.zero_() + fn_preprocess[grid]( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + beta1_step, beta2_step, lr, gnorm_scale, + p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2, + ) + + fn_update[grid]( + g, p, state1, state2, unorm_vec, max_unorm, param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + beta1_step, beta2_step, lr, gnorm_scale, skip_zeros, + p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2, + ) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py old mode 100644 new mode 100755 index 058c2747d..cc0f06e6e --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -1,8 +1,9 @@ from collections.abc import Sequence +from typing import Optional import torch -from . import triton_kernels +from . import triton_kernels, kernels_optim # currently codes unused, kept for reference # Should be the same for quant/dequant @@ -175,3 +176,48 @@ def gemv_4bit( B_dq_triton, bias=None, ) + + +# optimizer_update_32bit_impl = kernels_optim.optimizer_update_32bit_impl_torch +optimizer_update_32bit_impl = kernels_optim.optimizer_update_32bit_impl +def optimizer_update_32bit( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + with torch_accelerator_module.device(state1.device): + optimizer_update_32bit_impl( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + unorm_vec=unorm_vec, + max_unorm=max_unorm, + param_norm=param_norm, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + eps=eps, + step=step, + lr=lr, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, + ) diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 88f448bcd..83c8537fb 100755 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -65,5 +65,6 @@ def _( register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) + register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit) else: warnings.warn("XPU available but no ipex or triton packages found.") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c9f5ece60..b148415ca 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1198,9 +1198,9 @@ def optimizer_update_32bit( beta3, alpha, eps, - 
weight_decay, step, lr, + weight_decay, gnorm_scale, skip_zeros, ) diff --git a/tests/test_optim.py b/tests/test_optim.py index 858adbe4c..3d4157152 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -178,6 +178,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") + if optim_name.startswith("paged_") and device == "xpu": + pytest.skip("Paged optimizers are not supported on XPU currently.") + if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1: From 5b784a3af5d2cfecd17cdcbb4b8696099d98de08 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Wed, 16 Jul 2025 09:47:38 +0000 Subject: [PATCH 2/5] Modify Comments --- bitsandbytes/backends/triton/kernels_optim.py | 83 ++----------------- 1 file changed, 7 insertions(+), 76 deletions(-) mode change 100644 => 100755 bitsandbytes/backends/triton/kernels_optim.py diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py old mode 100644 new mode 100755 index 5be5553cd..04c2ec8e1 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -33,30 +33,22 @@ def optimizer_update_32bit_impl_torch( skip_zeros=False, ) -> None: """ - Torch实现的32位优化器,用于性能对比 + 32-bit optimizer implemented by Torch for performance comparison """ if skip_zeros: raise NotImplementedError("skip_zeros is not supported on XPU yet") - # 应用梯度缩放 g_scaled = gnorm_scale * g - update_scale = 1.0 - # 根据优化器类型进行参数更新 if optimizer_name == "adam": - # 更新状态 state1.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) - # 计算修正因子 correction1 = 1.0 - beta1 ** step correction2_sqrt = math.sqrt(1.0 - beta2 ** step) - # 计算 unorm if max_unorm > 0.0 and unorm_vec is not None: - # AdamW: unorm is computed from corrected first moment, but before weight decay - # See https://github.com/NVIDIA/apex/blob/22603ab6109e346438b8bc439427f8791055b416/apex/optimizers/fused_adam.py#L265 s1_corrected = state1 / correction1 update_vals = s1_corrected / (torch.sqrt(state2) + eps) update_norm = torch.sum(update_vals * update_vals) @@ -65,11 +57,9 @@ def optimizer_update_32bit_impl_torch( if current_unorm > max_unorm * param_norm: update_scale = (max_unorm * param_norm) / current_unorm - # 应用权重衰减 (decoupled weight decay) if weight_decay > 0.0: p.mul_(1.0 - lr * weight_decay) - # 更新参数 step_size = -lr * correction2_sqrt / correction1 update_val = state1 / (torch.sqrt(state2) + eps * correction2_sqrt) p.add_(update_val, alpha=update_scale * step_size) @@ -78,16 +68,13 @@ def optimizer_update_32bit_impl_torch( s1_vals = state1[0] s3_vals = state1[1] - # 更新状态 s1_vals.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) s3_vals.mul_(beta3).add_(g_scaled, alpha=1.0 - beta3) state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) - # 计算修正因子 correction1 = 1.0 - beta1 ** step correction2_sqrt = math.sqrt(1.0 - beta2 ** step) - # 计算更新值 numerator = (s1_vals / correction1) + (alpha * s3_vals) denominator = (torch.sqrt(state2) / correction2_sqrt) + eps update_vals = numerator / denominator @@ -99,15 +86,12 @@ def optimizer_update_32bit_impl_torch( if current_unorm > max_unorm * param_norm: update_scale = (max_unorm * param_norm) / current_unorm - # 应用权重衰减 if weight_decay > 0.0: p.mul_(1.0 - lr * weight_decay) - # 更新参数 p.add_(update_vals, alpha=-lr * update_scale) elif 
optimizer_name in ["momentum", "rmsprop", "adagrad", "lion"]: - # 这些优化器的 weight_decay 是耦合的 g_with_decay = g_scaled if weight_decay > 0.0: g_with_decay = g_with_decay.add(p, alpha=weight_decay) @@ -122,16 +106,11 @@ def optimizer_update_32bit_impl_torch( state1.addcmul_(g_with_decay, g_with_decay, value=1.0) update_vals = g_with_decay / (torch.sqrt(state1) + eps) elif optimizer_name == "lion": - # Lion 更新: c = sign(beta1 * m + (1-beta1) * g) - # p = p - lr * c - # m = beta2 * m + (1-beta2) * g momentum_update = state1.mul(beta1).add(g_with_decay, alpha=1.0 - beta1) update_vals = torch.sign(momentum_update) state1.mul_(beta2).add_(g_with_decay, alpha=1.0 - beta2) - # 计算 unorm if max_unorm > 0.0 and unorm_vec is not None: - # 对于Lion, unorm是基于更新后的动量计算的 unorm_calc_source = state1 if optimizer_name == "lion" else update_vals update_norm = torch.sum(unorm_calc_source * unorm_calc_source) unorm_vec.fill_(update_norm) @@ -139,7 +118,6 @@ def optimizer_update_32bit_impl_torch( if current_unorm > max_unorm * param_norm: update_scale = (max_unorm * param_norm) / current_unorm - # 更新参数 if optimizer_name == "lion": p.add_(update_vals, alpha=-lr * update_scale) else: @@ -191,46 +169,37 @@ def _optimizer_precondition_2state_32bit( BLOCK_SIZE: tl.constexpr, N_PER_TH: tl.constexpr, ): - """预处理优化器,计算更新范数(2状态优化器)""" + """Preprocessing optimizer, computing update norm (2-state optimizer)""" pid = tl.program_id(axis=0) block_start_idx = pid * N_PER_TH offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) mask = offsets < n_elements - # 加载梯度和状态 g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) - # 应用梯度缩放 g_vals = gnorm_scale * g_vals - # 计算修正因子 correction1 = 1.0 / (1.0 - beta1_step) correction2 = 1.0 / (1.0 - beta2_step) - # 根据优化器类型更新状态 if OPTIMIZER_ID == 3: # ADAM - # 更新动量 s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals - # 应用修正 s1_vals = s1_vals * correction1 s2_vals = s2_vals * correction2 - # 计算更新值 update_vals = s1_vals / (tl.sqrt(s2_vals) + eps) - # 计算更新范数 + update_norm = update_vals * update_vals elif OPTIMIZER_ID == 5: # ADEMAMIX update_norm = s1_vals - # 累加更新范数 total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) - # 原子加到全局范数 tl.atomic_add(unorm_ptr, total_norm) @@ -255,21 +224,18 @@ def _optimizer_precondition_1state_32bit( BLOCK_SIZE: tl.constexpr, N_PER_TH: tl.constexpr, ): - """预处理优化器,计算更新范数(1状态优化器)""" + """Preprocessing optimizer, computing update norm (1-state optimizer)""" pid = tl.program_id(axis=0) block_start_idx = pid * N_PER_TH offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) mask = offsets < n_elements - # 加载梯度和状态 g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0) s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) - # 应用梯度缩放 g_vals = gnorm_scale * g_vals if OPTIMIZER_ID == 0: # MOMENTUM - # 更新动量 if step == 1: s1_vals = g_vals else: @@ -277,27 +243,21 @@ def _optimizer_precondition_1state_32bit( update_norm = s1_vals * s1_vals elif OPTIMIZER_ID == 4: # LION - # LION 只更新状态,不计算范数 s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals - # update_norm = tl.zeros_like(g_vals) update_norm = s1_vals elif OPTIMIZER_ID == 1: # RMSPROP - # 更新RMS状态 s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals update_vals = g_vals / (tl.sqrt(s1_vals) + eps) update_norm = update_vals * update_vals elif OPTIMIZER_ID == 2: # ADAGRAD - # 更新累积梯度平方 s1_vals = 
s1_vals + g_vals * g_vals update_vals = g_vals / (tl.sqrt(s1_vals) + eps) update_norm = update_vals * update_vals - # 累加更新范数 total_norm = tl.sum(tl.where(mask, update_norm, 0.0)) - # 原子加到全局范数 tl.atomic_add(unorm_ptr, total_norm) @@ -327,76 +287,61 @@ def _optimizer_update_2state_32bit_triton_kernel( BLOCK_SIZE: tl.constexpr, N_PER_TH: tl.constexpr, ): - """2状态优化器内核""" + """2-state optimizer kernel""" pid = tl.program_id(axis=0) block_start_idx = pid * N_PER_TH offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) mask = offsets < n_elements - # 加载数据 g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0) - # 对于ADEMAMIX,需要加载额外的状态 if OPTIMIZER_ID == 5: # ADEMAMIX s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0) - # 应用梯度缩放 g_vals = gnorm_scale * g_vals - # 计算更新缩放因子 update_scale = 1.0 if max_unorm > 0.0: current_unorm = tl.sqrt(tl.load(unorm_ptr)) if current_unorm > max_unorm * param_norm: update_scale = (max_unorm * param_norm) / current_unorm - # 根据优化器类型进行更新 if OPTIMIZER_ID == 3: # ADAM - # 更新状态 s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals - # 计算修正因子 correction1 = 1.0 - beta1_step correction2 = tl.sqrt(1.0 - beta2_step) step_size = -lr * correction2 / correction1 - # 应用权重衰减 if weight_decay > 0.0: p_vals = p_vals * (1.0 - lr * weight_decay) - # 更新参数 update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2)) p_vals = p_vals + update_val elif OPTIMIZER_ID == 5: # ADEMAMIX - # 更新状态 s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals # m1 s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals # m2 s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals # nu - # 计算修正因子 correction1 = 1.0 - beta1_step correction2 = tl.sqrt(1.0 - beta2_step) - # 应用权重衰减 if weight_decay > 0.0: p_vals = p_vals * (1.0 - lr * weight_decay) - # 更新参数 numerator = (s1_vals / correction1) + (alpha * s3_vals) denominator = (tl.sqrt(s2_vals) / correction2) + eps p_vals = p_vals - lr * (numerator / denominator) - # 存储更新后的值 tl.store(p_ptr + offsets, p_vals, mask=mask) tl.store(state1_ptr + offsets, s1_vals, mask=mask) tl.store(state2_ptr + offsets, s2_vals, mask=mask) - # 对于ADEMAMIX,存储额外状态 if OPTIMIZER_ID == 5: # ADEMAMIX tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask) @@ -427,67 +372,54 @@ def _optimizer_update_1state_32bit_triton_kernel( BLOCK_SIZE: tl.constexpr, N_PER_TH: tl.constexpr, ): - """1状态优化器内核""" + """1-state optimizer kernel""" pid = tl.program_id(axis=0) block_start_idx = pid * N_PER_TH offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH) mask = offsets < n_elements - # 加载数据 g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0) - # 应用梯度缩放和权重衰减 g_vals = gnorm_scale * g_vals if weight_decay > 0.0: g_vals = g_vals + p_vals * weight_decay - # 计算更新缩放因子 update_scale = 1.0 if max_unorm > 0.0: current_unorm = tl.sqrt(tl.load(unorm_ptr)) if current_unorm > max_unorm * param_norm + eps: update_scale = (max_unorm * param_norm + eps) / current_unorm - # 根据优化器类型进行更新 if OPTIMIZER_ID == 0: # MOMENTUM - # 更新动量 if step == 1: s1_vals = g_vals else: s1_vals = s1_vals * beta1 + g_vals - # 更新参数 
update_val = update_scale * (-lr * s1_vals) p_vals = p_vals + update_val elif OPTIMIZER_ID == 4: # LION - # LION 优化器 momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0)) p_vals = p_vals - update_val - # 更新动量状态 s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals elif OPTIMIZER_ID == 1: # RMSPROP - # 更新RMS状态 s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals - # 更新参数 update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps) p_vals = p_vals - update_val elif OPTIMIZER_ID == 2: # ADAGRAD - # 更新累积梯度平方 s1_vals = s1_vals + g_vals * g_vals - # 更新参数 update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps) p_vals = p_vals - update_val - # 存储更新后的值 tl.store(p_ptr + offsets, p_vals, mask=mask) tl.store(state1_ptr + offsets, s1_vals, mask=mask) @@ -541,7 +473,7 @@ def optimizer_update_32bit_impl( skip_zeros=False, ) -> None: """ - Triton实现的32位优化器 + 32-bit optimizer implemented by Triton """ if skip_zeros: raise NotImplementedError("skip_zeros is not supported on XPU yet") @@ -558,7 +490,6 @@ def optimizer_update_32bit_impl( beta1_step = beta1**step beta2_step = beta2**step - # lion特殊处理 if optimizer_name == "lion": fn_update[grid]( g, p, state1, state2, unorm_vec, max_unorm, param_norm, From 4e40c7fe575b30baf2cd9d761cab9e038422ab40 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Thu, 17 Jul 2025 10:15:06 +0000 Subject: [PATCH 3/5] Optimizing pure torch implementation --- bitsandbytes/backends/triton/kernels_optim.py | 325 ++++++++++++------ 1 file changed, 220 insertions(+), 105 deletions(-) diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py index 04c2ec8e1..fb5aed84f 100755 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -7,13 +7,90 @@ import triton.language as tl # from triton.language.extra import libdevice +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + ########################################### # Pure torch implementation for reference # ########################################### @torch.compile -def optimizer_update_32bit_impl_torch( - optimizer_name: str, +def _optimizer_precondition_32bit_torch( + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + optimizer_id: int, +): + """Preprocessing optimizer, computing update norm""" + + g_vals = gnorm_scale * g + + if optimizer_id == 3: # ADAM + correction1 = 1.0 / (1.0 - beta1**step) + correction2 = 1.0 / (1.0 - beta2**step) + + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + s1_vals = s1_vals * correction1 + s2_vals = s2_vals * correction2 + + update_vals = s1_vals / (torch.sqrt(s2_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 5: # ADEMAMIX + update_norm = state1 + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + update_norm = s1_vals * s1_vals + + elif optimizer_id == 4: # LION + s1_vals = state1 * 
beta2 + (1.0 - beta2) * g_vals + update_norm = s1_vals + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + total_norm = torch.sum(update_norm) + unorm_vec.add_(total_norm) + + +@torch.compile +def _optimizer_update_32bit_torch( g: torch.Tensor, p: torch.Tensor, state1: torch.Tensor, @@ -26,128 +103,166 @@ def optimizer_update_32bit_impl_torch( beta3: float, alpha: float, eps: float, + weight_decay: float, step: int, lr: float, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - skip_zeros=False, -) -> None: - """ - 32-bit optimizer implemented by Torch for performance comparison - """ - if skip_zeros: - raise NotImplementedError("skip_zeros is not supported on XPU yet") + gnorm_scale: float, + optimizer_id: int, +): + """Unified optimizer update kernel""" - g_scaled = gnorm_scale * g - update_scale = 1.0 + p_vals = p.float() + g_vals = (gnorm_scale * g).float() + if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0: + g_vals = g_vals + p_vals * weight_decay - if optimizer_name == "adam": - state1.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) - state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) - - correction1 = 1.0 - beta1 ** step - correction2_sqrt = math.sqrt(1.0 - beta2 ** step) - - if max_unorm > 0.0 and unorm_vec is not None: - s1_corrected = state1 / correction1 - update_vals = s1_corrected / (torch.sqrt(state2) + eps) - update_norm = torch.sum(update_vals * update_vals) - unorm_vec.fill_(update_norm) - current_unorm = torch.sqrt(update_norm) + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = torch.sqrt(unorm_vec) + if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers + if current_unorm > max_unorm * param_norm + eps: + update_scale = (max_unorm * param_norm + eps) / current_unorm + else: # 2-state optimizers if current_unorm > max_unorm * param_norm: update_scale = (max_unorm * param_norm) / current_unorm + if optimizer_id == 3: # ADAM + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = math.sqrt(1.0 - beta2**step) + step_size = -lr * correction2 / correction1 + if weight_decay > 0.0: - p.mul_(1.0 - lr * weight_decay) - - step_size = -lr * correction2_sqrt / correction1 - update_val = state1 / (torch.sqrt(state2) + eps * correction2_sqrt) - p.add_(update_val, alpha=update_scale * step_size) - - elif optimizer_name == "ademamix": + p_vals = p_vals * (1.0 - lr * weight_decay) + + update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2)) + p_vals = p_vals + update_val + + state1.copy_(s1_vals) + state2.copy_(s2_vals) + + elif optimizer_id == 5: # ADEMAMIX s1_vals = state1[0] s3_vals = state1[1] - - s1_vals.mul_(beta1).add_(g_scaled, alpha=1.0 - beta1) - s3_vals.mul_(beta3).add_(g_scaled, alpha=1.0 - beta3) - state2.mul_(beta2).addcmul_(g_scaled, g_scaled, value=1.0 - beta2) - - correction1 = 1.0 - beta1 ** step - correction2_sqrt = math.sqrt(1.0 - beta2 ** step) - - numerator = (s1_vals / correction1) + (alpha * s3_vals) - denominator = (torch.sqrt(state2) / correction2_sqrt) + eps - update_vals = numerator / denominator - - if max_unorm > 0.0 and unorm_vec is not None: - update_norm = 
torch.sum(update_vals * update_vals) - unorm_vec.fill_(update_norm) - current_unorm = torch.sqrt(update_norm) - if current_unorm > max_unorm * param_norm: - update_scale = (max_unorm * param_norm) / current_unorm + s2_vals = state2 + + m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals + m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals + nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = math.sqrt(1.0 - beta2**step) if weight_decay > 0.0: - p.mul_(1.0 - lr * weight_decay) + p_vals = p_vals * (1.0 - lr * weight_decay) + + mixed_momentum = (m1 / correction1) + (alpha * m2) + adaptive_term = (torch.sqrt(nu) / correction2) + eps + p_vals = p_vals - lr * (mixed_momentum / adaptive_term) + + state1[0].copy_(m1) + state1[1].copy_(m2) + state2.copy_(nu) + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + + update_val = update_scale * (-lr * s1_vals) + p_vals = p_vals + update_val - p.add_(update_vals, alpha=-lr * update_scale) + state1.copy_(s1_vals) - elif optimizer_name in ["momentum", "rmsprop", "adagrad", "lion"]: - g_with_decay = g_scaled - if weight_decay > 0.0: - g_with_decay = g_with_decay.add(p, alpha=weight_decay) - - if optimizer_name == "momentum": - state1.mul_(beta1).add_(g_with_decay) - update_vals = state1 - elif optimizer_name == "rmsprop": - state1.mul_(beta1).addcmul_(g_with_decay, g_with_decay, value=1.0 - beta1) - update_vals = g_with_decay / (torch.sqrt(state1) + eps) - elif optimizer_name == "adagrad": - state1.addcmul_(g_with_decay, g_with_decay, value=1.0) - update_vals = g_with_decay / (torch.sqrt(state1) + eps) - elif optimizer_name == "lion": - momentum_update = state1.mul(beta1).add(g_with_decay, alpha=1.0 - beta1) - update_vals = torch.sign(momentum_update) - state1.mul_(beta2).add_(g_with_decay, alpha=1.0 - beta2) - - if max_unorm > 0.0 and unorm_vec is not None: - unorm_calc_source = state1 if optimizer_name == "lion" else update_vals - update_norm = torch.sum(unorm_calc_source * unorm_calc_source) - unorm_vec.fill_(update_norm) - current_unorm = torch.sqrt(update_norm) - if current_unorm > max_unorm * param_norm: - update_scale = (max_unorm * param_norm) / current_unorm + elif optimizer_id == 4: # LION + momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals + update_val = update_scale * lr * torch.sign(momentum_update) + p_vals = p_vals - update_val + + s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals + state1.copy_(s1_vals) + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) - if optimizer_name == "lion": - p.add_(update_vals, alpha=-lr * update_scale) - else: - p.add_(update_vals, alpha=-lr * update_scale) + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) + + p.copy_(p_vals) + + +def optimizer_update_32bit_impl_torch( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + 32-bit optimizer implemented 
by PyTorch with @torch.compile + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported yet") + optimizer_id = name2optimizer_id[optimizer_name] + + if optimizer_name == "lion": + _optimizer_update_32bit_torch( + g, p, state1, state2, unorm_vec, max_unorm, param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) + + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit_torch( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) else: - raise ValueError(f"Unsupported optimizer: {optimizer_name}") + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit_torch( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) + + _optimizer_update_32bit_torch( + g, p, state1, state2, unorm_vec, max_unorm, param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) ######################### # Triton implementation # ######################### -MOMENTUM = 0 -RMSPROP = 1 -ADAGRAD = 2 -ADAM = 3 -# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels -LION = 4 -ADEMAMIX = 5 - -name2optimizer_id = { - "momentum": MOMENTUM, - "rmsprop": RMSPROP, - "adagrad": ADAGRAD, - "adam": ADAM, - "lion": LION, - "ademamix": ADEMAMIX, -} - - @triton.jit def _optimizer_precondition_2state_32bit( g_ptr, @@ -334,9 +449,9 @@ def _optimizer_update_2state_32bit_triton_kernel( if weight_decay > 0.0: p_vals = p_vals * (1.0 - lr * weight_decay) - numerator = (s1_vals / correction1) + (alpha * s3_vals) - denominator = (tl.sqrt(s2_vals) / correction2) + eps - p_vals = p_vals - lr * (numerator / denominator) + mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals) + adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps + p_vals = p_vals - lr * (mixed_momentum / adaptive_term) tl.store(p_ptr + offsets, p_vals, mask=mask) tl.store(state1_ptr + offsets, s1_vals, mask=mask) From 06279af86a5cd9af199ab807f2ee90b9a2ca21aa Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Fri, 18 Jul 2025 05:48:31 +0000 Subject: [PATCH 4/5] Restore the order of parameters and modify the position of pure pytorch implementation --- bitsandbytes/_ops.py | 4 +- bitsandbytes/backends/default/ops.py | 252 +++++++++++++++++- bitsandbytes/backends/triton/kernels_optim.py | 242 +---------------- bitsandbytes/backends/triton/ops.py | 10 +- bitsandbytes/functional.py | 2 +- 5 files changed, 259 insertions(+), 251 deletions(-) mode change 100644 => 100755 bitsandbytes/backends/default/ops.py diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index 38ec62988..e47e6f436 100755 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -352,7 +352,7 @@ def _( torch.library.define( "bitsandbytes::optimizer_update_32bit", - "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()", + "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? 
unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()", ) @@ -371,9 +371,9 @@ def _( beta3: float, alpha: float, eps: float, + weight_decay: float, step: int, lr: float, - weight_decay: float, gnorm_scale: float, skip_zeros=False, ) -> None: diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py old mode 100644 new mode 100755 index ce5926979..a7cfb17a6 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from math import prod +from math import prod, sqrt from typing import Optional import torch @@ -301,3 +301,253 @@ def _( B_dq, bias=None, ) + + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + +@torch.compile +def _optimizer_precondition_32bit( + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + optimizer_id: int, +): + """Preprocessing optimizer, computing update norm""" + + g_vals = gnorm_scale * g + + if optimizer_id == 3: # ADAM + correction1 = 1.0 / (1.0 - beta1**step) + correction2 = 1.0 / (1.0 - beta2**step) + + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + s1_vals = s1_vals * correction1 + s2_vals = s2_vals * correction2 + + update_vals = s1_vals / (torch.sqrt(s2_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 5: # ADEMAMIX + update_norm = state1 + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + update_norm = s1_vals * s1_vals + + elif optimizer_id == 4: # LION + s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals + update_norm = s1_vals + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_vals = g_vals / (torch.sqrt(s1_vals) + eps) + update_norm = update_vals * update_vals + + total_norm = torch.sum(update_norm) + unorm_vec.add_(total_norm) + + +@torch.compile +def _optimizer_update_32bit( + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float, + optimizer_id: int, +): + """Unified optimizer update kernel""" + + p_vals = p.float() + g_vals = (gnorm_scale * g).float() + if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0: + g_vals = g_vals + p_vals * weight_decay + + update_scale = 1.0 + if max_unorm > 0.0: + current_unorm = torch.sqrt(unorm_vec) + if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers + if current_unorm > max_unorm * param_norm + eps: + update_scale = (max_unorm * param_norm + eps) / current_unorm + else: # 
2-state optimizers + if current_unorm > max_unorm * param_norm: + update_scale = (max_unorm * param_norm) / current_unorm + + if optimizer_id == 3: # ADAM + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals + s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = sqrt(1.0 - beta2**step) + step_size = -lr * correction2 / correction1 + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2)) + p_vals = p_vals + update_val + + state1.copy_(s1_vals) + state2.copy_(s2_vals) + + elif optimizer_id == 5: # ADEMAMIX + s1_vals = state1[0] + s3_vals = state1[1] + s2_vals = state2 + + m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals + m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals + nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals + + correction1 = 1.0 - beta1**step + correction2 = sqrt(1.0 - beta2**step) + + if weight_decay > 0.0: + p_vals = p_vals * (1.0 - lr * weight_decay) + + mixed_momentum = (m1 / correction1) + (alpha * m2) + adaptive_term = (torch.sqrt(nu) / correction2) + eps + p_vals = p_vals - lr * (mixed_momentum / adaptive_term) + + state1[0].copy_(m1) + state1[1].copy_(m2) + state2.copy_(nu) + + elif optimizer_id == 0: # MOMENTUM + if step == 1: + s1_vals = g_vals + else: + s1_vals = state1 * beta1 + g_vals + + update_val = update_scale * (-lr * s1_vals) + p_vals = p_vals + update_val + + state1.copy_(s1_vals) + + elif optimizer_id == 4: # LION + momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals + update_val = update_scale * lr * torch.sign(momentum_update) + p_vals = p_vals - update_val + + s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals + state1.copy_(s1_vals) + + elif optimizer_id == 1: # RMSPROP + s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals + update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) + + elif optimizer_id == 2: # ADAGRAD + s1_vals = state1 + g_vals * g_vals + update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps) + p_vals = p_vals - update_val + + state1.copy_(s1_vals) + + p.copy_(p_vals) + + +@register_kernel("bitsandbytes::optimizer_update_32bit", "default") +def _( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + unorm_vec: Optional[torch.Tensor], + max_unorm: float, + param_norm: float, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + weight_decay: float, + step: int, + lr: float, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + """ + 32-bit optimizer implemented by PyTorch with @torch.compile + """ + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported yet") + + optimizer_id = name2optimizer_id[optimizer_name] + + if optimizer_name == "lion": + _optimizer_update_32bit( + g, p, state1, state2, unorm_vec, max_unorm, param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) + + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) + else: + if max_unorm > 0.0: + unorm_vec.zero_() + _optimizer_precondition_32bit( + g, p, state1, state2, unorm_vec, + beta1, beta2, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) + + _optimizer_update_32bit( + g, p, state1, state2, unorm_vec, max_unorm, 
param_norm, + beta1, beta2, beta3, alpha, eps, weight_decay, step, + lr, gnorm_scale, optimizer_id + ) diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py index fb5aed84f..e2dcaac5f 100755 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -1,4 +1,3 @@ -import math from typing import Optional import torch @@ -24,245 +23,6 @@ "ademamix": ADEMAMIX, } -########################################### -# Pure torch implementation for reference # -########################################### - -@torch.compile -def _optimizer_precondition_32bit_torch( - g: torch.Tensor, - p: torch.Tensor, - state1: torch.Tensor, - state2: Optional[torch.Tensor], - unorm_vec: torch.Tensor, - beta1: float, - beta2: float, - eps: float, - weight_decay: float, - step: int, - lr: float, - gnorm_scale: float, - optimizer_id: int, -): - """Preprocessing optimizer, computing update norm""" - - g_vals = gnorm_scale * g - - if optimizer_id == 3: # ADAM - correction1 = 1.0 / (1.0 - beta1**step) - correction2 = 1.0 / (1.0 - beta2**step) - - s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals - s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals - - s1_vals = s1_vals * correction1 - s2_vals = s2_vals * correction2 - - update_vals = s1_vals / (torch.sqrt(s2_vals) + eps) - update_norm = update_vals * update_vals - - elif optimizer_id == 5: # ADEMAMIX - update_norm = state1 - - elif optimizer_id == 0: # MOMENTUM - if step == 1: - s1_vals = g_vals - else: - s1_vals = state1 * beta1 + g_vals - update_norm = s1_vals * s1_vals - - elif optimizer_id == 4: # LION - s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals - update_norm = s1_vals - - elif optimizer_id == 1: # RMSPROP - s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals - update_vals = g_vals / (torch.sqrt(s1_vals) + eps) - update_norm = update_vals * update_vals - - elif optimizer_id == 2: # ADAGRAD - s1_vals = state1 + g_vals * g_vals - update_vals = g_vals / (torch.sqrt(s1_vals) + eps) - update_norm = update_vals * update_vals - - total_norm = torch.sum(update_norm) - unorm_vec.add_(total_norm) - - -@torch.compile -def _optimizer_update_32bit_torch( - g: torch.Tensor, - p: torch.Tensor, - state1: torch.Tensor, - state2: Optional[torch.Tensor], - unorm_vec: Optional[torch.Tensor], - max_unorm: float, - param_norm: float, - beta1: float, - beta2: float, - beta3: float, - alpha: float, - eps: float, - weight_decay: float, - step: int, - lr: float, - gnorm_scale: float, - optimizer_id: int, -): - """Unified optimizer update kernel""" - - p_vals = p.float() - g_vals = (gnorm_scale * g).float() - if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0: - g_vals = g_vals + p_vals * weight_decay - - update_scale = 1.0 - if max_unorm > 0.0: - current_unorm = torch.sqrt(unorm_vec) - if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers - if current_unorm > max_unorm * param_norm + eps: - update_scale = (max_unorm * param_norm + eps) / current_unorm - else: # 2-state optimizers - if current_unorm > max_unorm * param_norm: - update_scale = (max_unorm * param_norm) / current_unorm - - if optimizer_id == 3: # ADAM - s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals - s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals - - correction1 = 1.0 - beta1**step - correction2 = math.sqrt(1.0 - beta2**step) - step_size = -lr * correction2 / correction1 - - if weight_decay > 0.0: - p_vals = p_vals * (1.0 - lr * weight_decay) - - update_val = update_scale * step_size * (s1_vals 
/ (torch.sqrt(s2_vals) + eps * correction2)) - p_vals = p_vals + update_val - - state1.copy_(s1_vals) - state2.copy_(s2_vals) - - elif optimizer_id == 5: # ADEMAMIX - s1_vals = state1[0] - s3_vals = state1[1] - s2_vals = state2 - - m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals - m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals - nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals - - correction1 = 1.0 - beta1**step - correction2 = math.sqrt(1.0 - beta2**step) - - if weight_decay > 0.0: - p_vals = p_vals * (1.0 - lr * weight_decay) - - mixed_momentum = (m1 / correction1) + (alpha * m2) - adaptive_term = (torch.sqrt(nu) / correction2) + eps - p_vals = p_vals - lr * (mixed_momentum / adaptive_term) - - state1[0].copy_(m1) - state1[1].copy_(m2) - state2.copy_(nu) - - elif optimizer_id == 0: # MOMENTUM - if step == 1: - s1_vals = g_vals - else: - s1_vals = state1 * beta1 + g_vals - - update_val = update_scale * (-lr * s1_vals) - p_vals = p_vals + update_val - - state1.copy_(s1_vals) - - elif optimizer_id == 4: # LION - momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals - update_val = update_scale * lr * torch.sign(momentum_update) - p_vals = p_vals - update_val - - s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals - state1.copy_(s1_vals) - - elif optimizer_id == 1: # RMSPROP - s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals - update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps) - p_vals = p_vals - update_val - - state1.copy_(s1_vals) - - elif optimizer_id == 2: # ADAGRAD - s1_vals = state1 + g_vals * g_vals - update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps) - p_vals = p_vals - update_val - - state1.copy_(s1_vals) - - p.copy_(p_vals) - - -def optimizer_update_32bit_impl_torch( - optimizer_name: str, - g: torch.Tensor, - p: torch.Tensor, - state1: torch.Tensor, - state2: Optional[torch.Tensor], - unorm_vec: Optional[torch.Tensor], - max_unorm: float, - param_norm: float, - beta1: float, - beta2: float, - beta3: float, - alpha: float, - eps: float, - step: int, - lr: float, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - skip_zeros=False, -) -> None: - """ - 32-bit optimizer implemented by PyTorch with @torch.compile - """ - if skip_zeros: - raise NotImplementedError("skip_zeros is not supported yet") - - optimizer_id = name2optimizer_id[optimizer_name] - - if optimizer_name == "lion": - _optimizer_update_32bit_torch( - g, p, state1, state2, unorm_vec, max_unorm, param_norm, - beta1, beta2, beta3, alpha, eps, weight_decay, step, - lr, gnorm_scale, optimizer_id - ) - - if max_unorm > 0.0: - unorm_vec.zero_() - _optimizer_precondition_32bit_torch( - g, p, state1, state2, unorm_vec, - beta1, beta2, eps, weight_decay, step, - lr, gnorm_scale, optimizer_id - ) - else: - if max_unorm > 0.0: - unorm_vec.zero_() - _optimizer_precondition_32bit_torch( - g, p, state1, state2, unorm_vec, - beta1, beta2, eps, weight_decay, step, - lr, gnorm_scale, optimizer_id - ) - - _optimizer_update_32bit_torch( - g, p, state1, state2, unorm_vec, max_unorm, param_norm, - beta1, beta2, beta3, alpha, eps, weight_decay, step, - lr, gnorm_scale, optimizer_id - ) - -######################### -# Triton implementation # -######################### - @triton.jit def _optimizer_precondition_2state_32bit( g_ptr, @@ -581,9 +341,9 @@ def optimizer_update_32bit_impl( beta3: float, alpha: float, eps: float, + weight_decay: float, step: int, lr: float, - weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: diff --git a/bitsandbytes/backends/triton/ops.py 
b/bitsandbytes/backends/triton/ops.py index cc0f06e6e..645eb5c30 100755 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -178,8 +178,6 @@ def gemv_4bit( ) -# optimizer_update_32bit_impl = kernels_optim.optimizer_update_32bit_impl_torch -optimizer_update_32bit_impl = kernels_optim.optimizer_update_32bit_impl def optimizer_update_32bit( optimizer_name: str, g: torch.Tensor, @@ -194,14 +192,14 @@ def optimizer_update_32bit( beta3: float, alpha: float, eps: float, + weight_decay: float, step: int, lr: float, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, + gnorm_scale: float, skip_zeros=False, ) -> None: with torch_accelerator_module.device(state1.device): - optimizer_update_32bit_impl( + kernels_optim.optimizer_update_32bit_impl( optimizer_name=optimizer_name, g=g, p=p, @@ -215,9 +213,9 @@ def optimizer_update_32bit( beta3=beta3, alpha=alpha, eps=eps, + weight_decay=weight_decay, step=step, lr=lr, - weight_decay=weight_decay, gnorm_scale=gnorm_scale, skip_zeros=skip_zeros, ) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b148415ca..c9f5ece60 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1198,9 +1198,9 @@ def optimizer_update_32bit( beta3, alpha, eps, + weight_decay, step, lr, - weight_decay, gnorm_scale, skip_zeros, ) From 810e8cb623db5d039e7077e20442469539bcd8a7 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Thu, 17 Jul 2025 23:21:37 -0700 Subject: [PATCH 5/5] Restore files permissions --- bitsandbytes/_ops.py | 0 bitsandbytes/backends/default/ops.py | 0 bitsandbytes/backends/triton/ops.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 bitsandbytes/_ops.py mode change 100755 => 100644 bitsandbytes/backends/default/ops.py mode change 100755 => 100644 bitsandbytes/backends/triton/ops.py diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py old mode 100755 new mode 100644 diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py old mode 100755 new mode 100644 diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py old mode 100755 new mode 100644
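
With the series applied, the new bitsandbytes::optimizer_update_32bit op can be exercised directly through the torch custom-op registry: on XPU it dispatches to the Triton kernel registered in bitsandbytes/backends/xpu/ops.py, while devices without a native kernel are expected to fall back to the @torch.compile implementation added in bitsandbytes/backends/default/ops.py. The sketch below shows a single 32-bit Adam step; it is illustrative only — the tensor sizes and hyperparameter values are assumptions, and the positional order follows the final schema from bitsandbytes/_ops.py (..., eps, weight_decay, step, lr, gnorm_scale, skip_zeros).

# Minimal sketch (assumptions: a bitsandbytes build containing these patches;
# illustrative shapes/hyperparameters; XPU routes to the Triton kernel, other
# devices without a native kernel are expected to use the "default" kernel).
import torch
import bitsandbytes  # noqa: F401  -- importing registers the custom ops

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

p = torch.randn(4096, device=device)   # parameters (float32)
g = torch.randn_like(p)                # gradients
state1 = torch.zeros_like(p)           # Adam first moment m
state2 = torch.zeros_like(p)           # Adam second moment v

torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam",
    g, p, state1, state2,
    None,    # unorm_vec: only needed when max_unorm > 0
    0.0,     # max_unorm: 0 disables update-norm clipping
    0.0,     # param_norm (unused when max_unorm == 0)
    0.9,     # beta1
    0.999,   # beta2
    0.0,     # beta3 (ademamix only)
    0.0,     # alpha (ademamix only)
    1e-8,    # eps
    1e-2,    # weight_decay (decoupled for adam)
    1,       # step (1-based; drives the bias correction)
    1e-3,    # lr
    1.0,     # gnorm_scale
    False,   # skip_zeros (raises NotImplementedError if True)
)

The same call shape works for the single-state optimizers (momentum, rmsprop, adagrad, lion), which only read and write state1 and accept state2=None.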
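A layout detail worth noting for ademamix: both backends expect the two momentum EMAs to share one state1 buffer — the Triton kernel reads the slow EMA at state1_ptr + n_elements, and the torch path indexes state1[0] / state1[1]. A sketch of the expected allocation follows (shapes and the beta3/alpha values are illustrative assumptions, not recommendations).

# Sketch of the ademamix state layout assumed by the kernels above:
# state1 stacks the fast EMA (m1, decayed by beta1) and slow EMA (m2, decayed
# by beta3); state2 holds the second moment (nu).
import torch
import bitsandbytes  # noqa: F401

p = torch.randn(1024)
g = torch.randn_like(p)
state1 = torch.zeros(2, p.numel())   # row 0: m1, row 1: m2 (contiguous, so m2 sits at offset numel)
state2 = torch.zeros_like(p)         # nu

torch.ops.bitsandbytes.optimizer_update_32bit(
    "ademamix",
    g, p, state1, state2, None,
    0.0, 0.0,             # max_unorm, param_norm
    0.9, 0.999, 0.9999,   # beta1, beta2, beta3
    5.0,                  # alpha: weight of the slow EMA in the numerator
    1e-8,                 # eps
    0.0,                  # weight_decay
    1,                    # step
    1e-3,                 # lr
    1.0,                  # gnorm_scale
    False,                # skip_zeros
)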