diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py
index 0df6bd49b4c9..4763f40ab197 100644
--- a/colossalai/kernel/cpu_adam_loader.py
+++ b/colossalai/kernel/cpu_adam_loader.py
@@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader):
     Usage:
         # init
         cpu_adam = CPUAdamLoader().load()
-        cpu_adam_op = cpu_adam.CPUAdamOptimizer(
-            alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
-        )
-        ...
-        # optim step
-        cpu_adam_op.step(
-            step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
-            params, grads, exp_avg, exp_avg_sq, loss_scale,
-        )
-
-    Args:
-        func CPUAdamOptimizer:
-            alpha (float): learning rate. Default to 1e-3.
-            beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
-            beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
-            epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
-            weight_decay (float): weight decay (L2 penalty). Default to 0.
-            adamw_mode (bool): whether to use the adamw. Default to True.
-        func step:
-            step (int): current step.
-            lr (float): learning rate.
-            beta1 (float): coefficients used for computing running averages of gradient.
-            beta2 (float): coefficients used for computing running averages of its square.
-            epsilon (float): term added to the denominator to improve numerical stability.
-            weight_decay (float): weight decay (L2 penalty).
-            bias_correction (bool): whether to use bias correction.
-            params (torch.Tensor): parameter.
-            grads (torch.Tensor): gradient.
-            exp_avg (torch.Tensor): exp average.
-            exp_avg_sq (torch.Tensor): exp average square.
-            loss_scale (float): loss scale value.
     """

     def __init__(self):
@@ -57,7 +26,7 @@ def __init__(self):
     def fetch_kernel(self):
         if platform.machine() == "x86_64":
             kernel = self._extension_map["x86"]().fetch()
-        elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
+        elif platform.machine() == "aarch64":
             kernel = self._extension_map["arm"]().fetch()
         else:
             raise Exception("not supported")
diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py
index 9868059bfcfd..0b552f436f3a 100644
--- a/colossalai/kernel/extensions/cpu_adam/arm_extension.py
+++ b/colossalai/kernel/extensions/cpu_adam/arm_extension.py
@@ -6,7 +6,7 @@ class ArmCPUAdamExtension(BaseExtension):
     def __init__(self) -> None:
         super().__init__()
         self.kernel_builder = ArmCPUAdamBuilder()
-        self._requires_build = False
+        self._requires_build = True

     @property
     def requires_build(self) -> bool:
@@ -14,7 +14,7 @@ def requires_build(self) -> bool:

     def build(self):
         self.kernel_builder.build()
-        self._requires_build = True
+        self._requires_build = False

     def load(self):
         return self.kernel_builder.load()
diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py b/colossalai/kernel/extensions/cpu_adam/x86_extension.py
index 687c91f35759..a5b64bed4b66 100644
--- a/colossalai/kernel/extensions/cpu_adam/x86_extension.py
+++ b/colossalai/kernel/extensions/cpu_adam/x86_extension.py
@@ -7,7 +7,7 @@ class X86CPUAdamExtension(BaseExtension):
     def __init__(self) -> None:
         super().__init__()
         self.kernel_builder = X86CPUAdamBuilder()
-        self._requires_build = False
+        self._requires_build = True

     @property
     def requires_build(self) -> bool:
@@ -15,7 +15,7 @@ def requires_build(self) -> bool:

     def build(self):
         self.kernel_builder.build()
-        self._requires_build = True
+        self._requires_build = False

     def load(self):
         return self.kernel_builder.load()