Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions colossalai/kernel/cpu_adam_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader):
Usage:
# init
cpu_adam = CPUAdamLoader().load()
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
)
...
# optim step
cpu_adam_op.step(
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
params, grads, exp_avg, exp_avg_sq, loss_scale,
)

Args:
func CPUAdamOptimizer:
alpha (float): learning rate. Default to 1e-3.
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
weight_decay (float): weight decay (L2 penalty). Default to 0.
adamw_mode (bool): whether to use the adamw. Default to True.
func step:
step (int): current step.
lr (float): learning rate.
beta1 (float): coefficients used for computing running averages of gradient.
beta2 (float): coefficients used for computing running averages of its square.
epsilon (float): term added to the denominator to improve numerical stability.
weight_decay (float): weight decay (L2 penalty).
bias_correction (bool): whether to use bias correction.
params (torch.Tensor): parameter.
grads (torch.Tensor): gradient.
exp_avg (torch.Tensor): exp average.
exp_avg_sq (torch.Tensor): exp average square.
loss_scale (float): loss scale value.
"""

def __init__(self):
Expand All @@ -57,7 +26,7 @@ def __init__(self):
def fetch_kernel(self):
if platform.machine() == "x86_64":
kernel = self._extension_map["x86"]().fetch()
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
elif platform.machine() == "aarch64":
kernel = self._extension_map["arm"]().fetch()
else:
raise Exception("not supported")
Expand Down
4 changes: 2 additions & 2 deletions colossalai/kernel/extensions/cpu_adam/arm_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ class ArmCPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = ArmCPUAdamBuilder()
self._requires_build = False
self._requires_build = True

Comment thread
wangbluo marked this conversation as resolved.
@property
def requires_build(self) -> bool:
return self._requires_build

def build(self):
self.kernel_builder.build()
self._requires_build = True
self._requires_build = False

def load(self):
return self.kernel_builder.load()
Expand Down
4 changes: 2 additions & 2 deletions colossalai/kernel/extensions/cpu_adam/x86_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ class X86CPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = X86CPUAdamBuilder()
self._requires_build = False
self._requires_build = True

@property
def requires_build(self) -> bool:
return self._requires_build

def build(self):
self.kernel_builder.build()
self._requires_build = True
self._requires_build = False

def load(self):
return self.kernel_builder.load()
Expand Down