diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index bf9b85997c78..db1f26d5f6da 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -142,6 +142,7 @@ class Adam_Optimizer {
     }
   }
 
+#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
                         AVX_Data &data) {
     if (is_half) {
@@ -159,6 +160,7 @@ class Adam_Optimizer {
       SIMD_STORE(ptr, data.data);
     }
   }
+#endif
 
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
new file mode 100644
index 000000000000..a715a2711576
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
@@ -0,0 +1,304 @@
+#include "cpu_adam_arm.h"
+
+void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
+                           void *_exp_avg_sq, size_t _param_size,
+                           at::ScalarType param_dtype,
+                           at::ScalarType grad_dtype,
+                           at::ScalarType exp_avg_dtype,
+                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+  size_t rounded_size = 0;
+#if defined(__aarch64__)
+  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
+#endif
+
+  float betta1_minus1 = 1 - _betta1;
+  float betta2_minus1 = 1 - _betta2;
+  float step_size = -1 * _alpha / _bias_correction1;
+  float w_decay = -1 * _alpha * _weight_decay;
+
+#if defined(__aarch64__)
+  float32x4_t betta1_4 = simd_set(_betta1);
+  float32x4_t betta2_4 = simd_set(_betta2);
+  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+  float32x4_t eps_4 = simd_set(_eps);
+  float32x4_t step_size_4 = simd_set(step_size);
+  float32x4_t weight_decay_4;
+  if (_weight_decay > 0) {
+    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+  }
+  for (size_t t = 0; t < rounded_size; t += TILE) {
+    size_t copy_size = TILE;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    size_t offset = copy_size + t;
+
+#pragma omp parallel for
+    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
+      float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
+      if (loss_scale > 0) {
+        float32x4_t loss_scale_vec = simd_set(loss_scale);
+        grad_4 = vdivq_f32(grad_4, loss_scale_vec);
+      }
+      float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
+      float32x4_t variance_4 =
+          simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
+      float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
+      if (_weight_decay > 0 && !_adamw_mode) {
+        grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
+      }
+      momentum_4 = vmulq_f32(momentum_4, betta1_4);
+      momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
+      variance_4 = vmulq_f32(variance_4, betta2_4);
+      grad_4 = vmulq_f32(grad_4, grad_4);
+      variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
+      grad_4 = vsqrtq_f32(variance_4);
+      grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
+      grad_4 = vdivq_f32(momentum_4, grad_4);
+      if (_weight_decay > 0 && _adamw_mode) {
+        param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
+      }
+      param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
+      simd_store_offset(_params, param_dtype, param_4, i);
+      simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
+      simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
+    }
+  }
+#endif
+  if (_param_size > rounded_size) {
+    for (size_t t = rounded_size; t < _param_size; t += TILE) {
+      size_t copy_size = TILE;
+      if ((t + TILE) > _param_size) copy_size = _param_size - t;
+      size_t offset = copy_size + t;
+
+#pragma omp parallel for
+      for (size_t k = t; k < offset; k++) {
+        float grad = scalar_load_offset(grads, grad_dtype, k);
+        if (loss_scale > 0) {
+          grad /= loss_scale;
+        }
+        float param = scalar_load_offset(_params, param_dtype, k);
+        float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
+        float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
+        if (_weight_decay > 0 && !_adamw_mode) {
+          grad = param * _weight_decay + grad;
+        }
+        momentum = momentum * _betta1;
+        momentum = grad * betta1_minus1 + momentum;
+
+        variance = variance * _betta2;
+        grad = grad * grad;
+        variance = grad * betta2_minus1 + variance;
+
+        grad = sqrt(variance);
+        grad = grad * _bias_correction2 + _eps;
+        grad = momentum / grad;
+        if (_weight_decay > 0 && _adamw_mode) {
+          param += w_decay * param;
+        }
+        param = grad * step_size + param;
+
+        scalar_store_offset(_params, param_dtype, param, k);
+        scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
+        scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
+      }
+    }
+  }
+}
+
+void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
+                           void *_exp_avg_sq, size_t _param_size,
+                           at::ScalarType param_dtype,
+                           at::ScalarType grad_dtype,
+                           at::ScalarType exp_avg_dtype,
+                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+  size_t rounded_size = 0;
+#if defined(__aarch64__)
+  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
+#endif
+
+  float betta1_minus1 = 1 - _betta1;
+  float betta2_minus1 = 1 - _betta2;
+  float step_size = -1 * _alpha / _bias_correction1;
+  float w_decay = -1 * _alpha * _weight_decay;
+
+#if defined(__aarch64__)
+  float32x4_t betta1_4 = simd_set(_betta1);
+  float32x4_t betta2_4 = simd_set(_betta2);
+  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+  float32x4_t eps_4 = simd_set(_eps);
+  float32x4_t step_size_4 = simd_set(step_size);
+  float32x4_t weight_decay_4;
+  if (_weight_decay > 0) {
+    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+  }
+
+  for (size_t t = 0; t < rounded_size; t += TILE) {
+    size_t copy_size = TILE;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    size_t offset = copy_size + t;
+
+#pragma omp parallel for
+    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
+      float32x4_t grad_4[4];
+      float32x4_t momentum_4[4];
+      float32x4_t variance_4[4];
+      float32x4_t param_4[4];
+#pragma unroll 4
+      for (int j = 0; j < 4; j++) {
+        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
+        if (loss_scale > 0) {
+          float32x4_t loss_scale_vec = simd_set(loss_scale);
+          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
+        }
+        momentum_4[j] =
+            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
+        variance_4[j] =
+            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
+        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
+        if (_weight_decay > 0 && !_adamw_mode) {
+          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
+        }
+        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
+        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
+        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
+        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
+        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
+        grad_4[j] = vsqrtq_f32(variance_4[j]);
+        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
+        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
+        if (_weight_decay > 0 && _adamw_mode) {
+          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
+        }
+        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
+        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
+        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
+                          i + SIMD_WIDTH * j);
+        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
+                          i + SIMD_WIDTH * j);
+      }
+    }
+  }
+#endif
+  if (_param_size > rounded_size) {
+    Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
+           scalar_seek_offset(grads, grad_dtype, rounded_size),
+           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
+           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
+           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
+           exp_avg_sq_dtype, loss_scale);
+  }
+}
+
+void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
+                           void *_exp_avg_sq, size_t _param_size,
+                           at::ScalarType param_dtype,
+                           at::ScalarType grad_dtype,
+                           at::ScalarType exp_avg_dtype,
+                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+  size_t rounded_size = 0;
+#if defined(__aarch64__)
+  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+#endif
+
+  float betta1_minus1 = 1 - _betta1;
+  float betta2_minus1 = 1 - _betta2;
+  float step_size = -1 * _alpha / _bias_correction1;
+  float w_decay = -1 * _alpha * _weight_decay;
+#if defined(__aarch64__)
+  float32x4_t betta1_4 = simd_set(_betta1);
+  float32x4_t betta2_4 = simd_set(_betta2);
+  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+  float32x4_t eps_4 = simd_set(_eps);
+  float32x4_t step_size_4 = simd_set(step_size);
+  float32x4_t weight_decay_4;
+  if (_weight_decay > 0) {
+    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+  }
+
+  for (size_t t = 0; t < rounded_size; t += TILE) {
+    size_t copy_size = TILE;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    size_t offset = copy_size + t;
+
+#pragma omp parallel for
+    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
+      float32x4_t grad_4[8];
+      float32x4_t momentum_4[8];
+      float32x4_t variance_4[8];
+      float32x4_t param_4[8];
+#pragma unroll 4
+      for (int j = 0; j < 8; j++) {
+        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
+        if (loss_scale > 0) {
+          float32x4_t loss_scale_vec = simd_set(loss_scale);
+          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
+        }
+        momentum_4[j] =
+            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
+        variance_4[j] =
+            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
+        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
+        if (_weight_decay > 0 && !_adamw_mode) {
+          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
+        }
+        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
+        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
+        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
+        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
+        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
+        grad_4[j] = vsqrtq_f32(variance_4[j]);
+        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
+        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
+        if (_weight_decay > 0 && _adamw_mode) {
+          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
+        }
+        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
+        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
+        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
+                          i + SIMD_WIDTH * j);
+        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
+                          i + SIMD_WIDTH * j);
+      }
+    }
+  }
+#endif
+  if (_param_size > rounded_size) {
+    Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
+           scalar_seek_offset(grads, grad_dtype, rounded_size),
+           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
+           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
+           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
+           exp_avg_sq_dtype, loss_scale);
+  }
+}
+
+void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
+                         float epsilon, float weight_decay,
+                         bool bias_correction, torch::Tensor &params,
+                         torch::Tensor &grads, torch::Tensor &exp_avg,
+                         torch::Tensor &exp_avg_sq, float loss_scale) {
+  auto params_c = params.contiguous();
+  auto grads_c = grads.contiguous();
+  auto exp_avg_c = exp_avg.contiguous();
+  auto exp_avg_sq_c = exp_avg_sq.contiguous();
+
+  this->IncrementStep(step, beta1, beta2);
+  this->update_state(lr, epsilon, weight_decay, bias_correction);
+  this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
+               exp_avg_sq_c.data_ptr(), params_c.numel(),
+               params_c.scalar_type(), grads_c.scalar_type(),
+               exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
+}
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
+      .def(py::init<float, float, float, float, float, bool>())
+      .def("step", &AdamOptimizer::step);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
new file mode 100644
index 000000000000..c731850edc31
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
@@ -0,0 +1,201 @@
+#pragma once
+#include <torch/extension.h>
+
+#include <cmath>
+
+#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define TILE (128 * 1024 * 1024)
+
+#if defined(__aarch64__)
+#include <arm_neon.h>
+#define SIMD_WIDTH 4
+
+inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
+                                    size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float: {
+      auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
+      return vld1q_f32(ptr_f + offset);
+    }
+    case at::ScalarType::Half: {
+      auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
+      return vcvt_f32_f16(vld1_f16(ptr_h + offset));
+    }
+    // case at::ScalarType::BFloat16: {
+    //   auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
+    //   return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
+    // }
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
+  return simd_load_offset(ptr, dtype, 0);
+}
+
+inline void simd_store_offset(void *ptr, at::ScalarType dtype,
+                              float32x4_t data, size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float: {
+      auto ptr_f = reinterpret_cast<float32_t *>(ptr);
+      vst1q_f32(ptr_f + offset, data);
+      break;
+    }
+    case at::ScalarType::Half: {
+      auto ptr_h = reinterpret_cast<float16_t *>(ptr);
+      vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
+      break;
+    }
+    // case at::ScalarType::BFloat16: {
+    //   auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
+    //   vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
+    //   break;
+    // }
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
+  return simd_store_offset(ptr, dtype, data, 0);
+}
+
+inline float32x4_t simd_set(float value) {
+  auto val = static_cast<float32_t>(value);
+  return vdupq_n_f32(val);
+}
+
+#endif
+
+inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return *(reinterpret_cast<const float *>(ptr) + offset);
+    case at::ScalarType::Half:
+      return static_cast<float>(
+          *(reinterpret_cast<const at::Half *>(ptr) + offset));
+    // case at::ScalarType::BFloat16:
+    //   return static_cast<float>(
+    //       *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      *(reinterpret_cast<float *>(ptr) + offset) = data;
+      break;
+    case at::ScalarType::Half:
+      *(reinterpret_cast<at::Half *>(ptr) + offset) = data;
+      break;
+    // case at::ScalarType::BFloat16:
+    //   *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
+    //   break;
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return reinterpret_cast<float *>(ptr) + offset;
+    case at::ScalarType::Half:
+      return reinterpret_cast<at::Half *>(ptr) + offset;
+    // case at::ScalarType::BFloat16:
+    //   return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+#define STEP(SPAN)                                                         \
+  void Step_##SPAN(void *_params, void *grads, void *_exp_avg,             \
+                   void *_exp_avg_sq, size_t _param_size,                  \
+                   at::ScalarType param_dtype, at::ScalarType grad_dtype,  \
+                   at::ScalarType exp_avg_dtype,                           \
+                   at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
+
+class AdamOptimizer {
+ private:
+  float _alpha;
+  float _betta1;
+  float _betta2;
+  float _eps;
+  float _weight_decay;
+
+  float _betta1_t;
+  float _betta2_t;
+  size_t _step;
+
+  float _bias_correction1;
+  float _bias_correction2;
+
+  bool _adamw_mode;
+
+ public:
+  AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
+                float eps = 1e-8, float weight_decay = 0,
+                bool adamw_mode = true)
+      : _alpha(alpha),
+        _betta1(betta1),
+        _betta2(betta2),
+        _eps(eps),
+        _weight_decay(weight_decay),
+        _betta1_t(1.0),
+        _betta2_t(1.0),
+        _step(0),
+        _adamw_mode(adamw_mode) {}
+  ~AdamOptimizer() {}
+
+  STEP(1)
+  STEP(4)
+  STEP(8)
+  inline void IncrementStep(size_t step, float beta1, float beta2) {
+    if (beta1 != _betta1 || beta2 != _betta2) {
+      _step = step;
+      _betta1 = beta1;
+      _betta2 = beta2;
+      _betta1_t = std::pow(_betta1, step);
+      _betta2_t = std::pow(_betta2, step);
+    } else {
+      _step++;
+      if (_step != step) {
+        _betta1_t = std::pow(_betta1, step);
+        _betta2_t = std::pow(_betta2, step);
+        _step = step;
+      } else {
+        _betta1_t *= _betta1;
+        _betta2_t *= _betta2;
+      }
+    }
+  }
+  inline void update_state(float lr, float epsilon, float weight_decay,
+                           bool bias_correction) {
+    _alpha = lr;
+    _eps = epsilon;
+    _weight_decay = weight_decay;
+
+    _bias_correction1 = 1.0f;
+    _bias_correction2 = 1.0f;
+    if (bias_correction == 1) {
+      _bias_correction1 = 1 - _betta1_t;
+      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
+    }
+  }
+
+  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
+            float weight_decay, bool bias_correction, torch::Tensor &params,
+            torch::Tensor &grads, torch::Tensor &exp_avg,
+            torch::Tensor &exp_avg_sq, float loss_scale);
+};
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index c3c0180e8516..7d53a1dd6834 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -1,9 +1,10 @@
 import math
+import platform
 from typing import Optional
 
 import torch
 
-from colossalai.kernel.op_builder import CPUAdamBuilder
+from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
 
 from .nvme_optimizer import NVMeOptimizer
 
@@ -77,7 +78,7 @@ def __init__(
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
-        cpu_adam = CPUAdamBuilder().load()
+        cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
         # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
 
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index c7a309b872ce..d34fd601ab25 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -84,9 +84,10 @@ def __init__(
             nvme_offload_fraction,
             nvme_offload_dir,
         )
-        fused_optim = FusedOptimBuilder().load()
-        self.gpu_adam_op = fused_optim.multi_tensor_adam
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        if torch.cuda.is_available():
+            fused_optim = FusedOptimBuilder().load()
+            self.gpu_adam_op = fused_optim.multi_tensor_adam
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
     @torch.no_grad()
     def step(self, closure=None, div_scale: float = -1):
@@ -118,11 +119,11 @@ def step(self, closure=None, div_scale: float = -1):
                 group_step = state["step"]
                 beta1, beta2 = group["betas"]
 
-                if target_device.type == "cpu":
-                    assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
-                    assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
+                if target_device.type == "cpu" or target_device.type == "npu":
+                    assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
+                    assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    if p.grad.dtype is torch.bfloat16:
+                    if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
diff --git a/op_builder/__init__.py b/op_builder/__init__.py
index 808559ec9c2d..21e216437c47 100644
--- a/op_builder/__init__.py
+++ b/op_builder/__init__.py
@@ -1,3 +1,4 @@
+from .arm_cpu_adam import ArmCPUAdamBuilder
 from .cpu_adam import CPUAdamBuilder
 from .fused_optim import FusedOptimBuilder
 from .layernorm import LayerNormBuilder
@@ -29,4 +30,5 @@
     "MultiTensorLambBuilder",
     "MultiTensorScaleBuilder",
     "MultiTensorL2NormBuilder",
+    "ArmCPUAdamBuilder",
 ]
diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py
new file mode 100644
index 000000000000..18dd519fae46
--- /dev/null
+++ b/op_builder/arm_cpu_adam.py
@@ -0,0 +1,34 @@
+from .builder import Builder
+
+
+class ArmCPUAdamBuilder(Builder):
+    NAME = "arm_cpu_adam"
+    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
+    ext_type = "cpu"
+
+    def __init__(self):
+        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
+        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
+
+    # necessary 4 functions
+    def sources_files(self):
+        ret = [
+            self.csrc_abs_path("cpu_adam_arm.cpp"),
+        ]
+        return ret
+
+    def include_dirs(self):
+        return [self.csrc_abs_path("includes")]
+
+    def cxx_flags(self):
+        extra_cxx_flags = [
+            "-std=c++14",
+            "-std=c++17",
+            "-g",
+            "-Wno-reorder",
+            "-fopenmp",
+        ]
+        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
+
+    def nvcc_flags(self):
+        return []
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 75823ef105c7..d804cb1602e4 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -7,7 +7,7 @@
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
 
@@ -21,6 +21,8 @@ class Builder(ABC):
         prebuilt_import_path (str): the path where the extension is installed during pip install
     """
 
+    ext_type: str = "cuda"
+
     def __init__(self, name: str, prebuilt_import_path: str):
         self.name = name
         self.prebuilt_import_path = prebuilt_import_path
@@ -165,7 +167,8 @@ def load(self, verbose: Optional[bool] = None):
             )
         except ImportError:
             # check environment
-            self.check_runtime_build_environment()
+            if self.ext_type == "cuda":
+                self.check_runtime_build_environment()
 
             # time the kernel compilation
             start_build = time.time()
@@ -208,11 +211,19 @@ def load(self, verbose: Optional[bool] = None):
 
         return op_module
 
-    def builder(self) -> "CUDAExtension":
+    def builder(self) -> Union["CUDAExtension", "CppExtension"]:
         """
         get a CUDAExtension instance used for setup.py
         """
-        from torch.utils.cpp_extension import CUDAExtension
+        from torch.utils.cpp_extension import CppExtension, CUDAExtension
+
+        if self.ext_type == "cpp":
+            return CppExtension(
+                name=self.prebuilt_import_path,
+                sources=self.strip_empty_entries(self.sources_files()),
+                include_dirs=self.strip_empty_entries(self.include_dirs()),
+                extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
+            )
 
         return CUDAExtension(
             name=self.prebuilt_import_path,