From d043a504fe1882fd620191a8a806e55a9f9c59c0 Mon Sep 17 00:00:00 2001
From: lizimin
Date: Thu, 17 Oct 2024 11:18:11 +0800
Subject: [PATCH 1/5] Support fp32 for add operator

---
 operatorspy/tests/add.py   |  9 ++++--
 src/ops/add/cpu/add_cpu.cc | 24 ++++++++++++++--
 src/ops/add/cuda/add.cc    |  5 +++-
 src/ops/add/cuda/add.cu    | 56 ++++++++++++++++++++++++--------------
 src/ops/utils.h            |  2 +-
 5 files changed, 68 insertions(+), 28 deletions(-)

diff --git a/operatorspy/tests/add.py b/operatorspy/tests/add.py
index 2b74e1b9..d766208c 100644
--- a/operatorspy/tests/add.py
+++ b/operatorspy/tests/add.py
@@ -85,7 +85,8 @@ def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
-        test(lib, handle, "cpu", c_shape, a_shape, b_shape, inplace=inplace)
+        test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
 
 
@@ -93,7 +94,8 @@ def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
-        test(lib, handle, "cuda", c_shape, a_shape, b_shape, inplace=inplace)
+        test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
 
 
@@ -103,7 +105,8 @@ def test_bang(lib, test_cases):
     device = DeviceEnum.DEVICE_BANG
     handle = create_handle(lib, device)
     for c_shape, a_shape, b_shape, inplace in test_cases:
-        test(lib, handle, "mlu", c_shape, a_shape, b_shape, inplace=inplace)
+        test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
 
 
diff --git a/src/ops/add/cpu/add_cpu.cc b/src/ops/add/cpu/add_cpu.cc
index 8a20f933..430c00a3 100644
--- a/src/ops/add/cpu/add_cpu.cc
+++ b/src/ops/add/cpu/add_cpu.cc
@@ -27,7 +27,10 @@ infiniopStatus_t cpuCreateAddDescriptor(infiniopHandle_t,
     if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) {
         return STATUS_BAD_TENSOR_STRIDES;
     }
-    if (!dtype_eq(c->dt, F16) || c->dt != a->dt || c->dt != b->dt) {
+    if (c->dt != F16 && c->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (c->dt != a->dt || c->dt != b->dt) {
         return STATUS_BAD_TENSOR_DTYPE;
     }
 
@@ -79,12 +82,29 @@ void add_cpu_f16(AddCpuDescriptor_t desc, void *c, void const *a, void const *b)
     }
 }
 
+void add_cpu_f32(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
+    auto a_ = reinterpret_cast<float const *>(a);
+    auto b_ = reinterpret_cast<float const *>(b);
+    auto c_ = reinterpret_cast<float *>(c);
+    const auto &indices = desc->c_indices;
+
+    for (uint64_t i = 0; i < desc->c_data_size; ++i, incrementOne(indices, desc->c_shape, desc->ndim)) {
+        auto a_index = compactToFlat(indices, desc->a_strides, desc->ndim);
+        auto b_index = compactToFlat(indices, desc->b_strides, desc->ndim);
+        c_[i] = a_[a_index] + b_[b_index];
+    }
+}
+
 infiniopStatus_t cpuAdd(AddCpuDescriptor_t desc,
                         void *c, void const *a, void const *b,
                         void *stream) {
-    if (dtype_eq(desc->dtype, F16)) {
+    if (desc->dtype == F16) {
         add_cpu_f16(desc, c, a, b);
         return STATUS_SUCCESS;
     }
+    if (desc->dtype == F32) {
+        add_cpu_f32(desc, c, a, b);
+        return STATUS_SUCCESS;
+    }
     return STATUS_BAD_TENSOR_DTYPE;
 }
diff --git a/src/ops/add/cuda/add.cc b/src/ops/add/cuda/add.cc
index 0b610e57..bfb885c1 100644
--- a/src/ops/add/cuda/add.cc
+++ b/src/ops/add/cuda/add.cc
@@ -14,7 +14,10 @@ infiniopStatus_t cudaCreateAddDescriptor(CudaHandle_t handle,
     if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) {
         return STATUS_BAD_TENSOR_STRIDES;
     }
-    if (!dtype_eq(c->dt, F16) || c->dt != a->dt || c->dt != b->dt) {
+    if (c->dt != F16 && c->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (c->dt != a->dt || c->dt != b->dt) {
         return STATUS_BAD_TENSOR_DTYPE;
     }
     bool broadcasted = false;
diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu
index 4d880e4e..4615d385 100644
--- a/src/ops/add/cuda/add.cu
+++ b/src/ops/add/cuda/add.cu
@@ -2,11 +2,20 @@
 #include "../../utils.h"
 #include "add.cuh"
 
-struct half4 {
-    __half x, y, z, w;
+template <typename T, size_t N>
+struct vecN {
+    T data[N];
 
-    __device__ half4 operator+(const half4 &other) const {
-        return half4{__hadd(x, other.x), __hadd(y, other.y), __hadd(z, other.z), __hadd(w, other.w)};
+    __device__ vecN operator+(const vecN &other) const {
+        vecN result;
+        for (int i = 0; i < N; ++i) {
+            result.data[i] = data[i] + other.data[i];
+        }
+        return result;
+    }
+
+    __device__ const T &operator[](int i) const {
+        return data[i];
     }
 };
 
@@ -52,7 +61,7 @@ __global__ void add(
 }
 
 template <typename Tdata>
-void add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
+void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
     if (data_size == 0) {
         return;
     }
@@ -68,27 +77,32 @@ void add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const
     }
 }
 
-void add_nv_gpu_f16(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream) {
-    auto data_size = desc->c_data_size / 4;
-    auto a_half4 = reinterpret_cast<half4 const *>(a);
-    auto b_half4 = reinterpret_cast<half4 const *>(b);
-    auto c_half4 = reinterpret_cast<half4 *>(c);
-    add_nv_gpu(desc, c_half4, a_half4, b_half4, data_size, 4, 0, stream);
+template <typename Tvec, typename Tdata>
+void add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) {
+    auto data_size = desc->c_data_size / pack_size;
+    auto a_vec = reinterpret_cast<Tvec const *>(a);
+    auto b_vec = reinterpret_cast<Tvec const *>(b);
+    auto c_vec = reinterpret_cast<Tvec *>(c);
+    _add_nv_gpu(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream);
 
-    auto remainder = desc->c_data_size % 4;
-    auto a_half = reinterpret_cast<half const *>(a);
-    auto b_half = reinterpret_cast<half const *>(b);
-    auto c_half = reinterpret_cast<half *>(c);
-    add_nv_gpu(desc, c_half, a_half, b_half, remainder, 1, data_size * 4, stream);
+    auto remainder = desc->c_data_size % pack_size;
+    auto a_ = reinterpret_cast<Tdata const *>(a);
+    auto b_ = reinterpret_cast<Tdata const *>(b);
+    auto c_ = reinterpret_cast<Tdata *>(c);
+    _add_nv_gpu(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
 }
 
 infiniopStatus_t cudaAdd(AddCudaDescriptor_t desc,
                          void *c, void const *a, void const *b,
                          void *stream) {
-    if (!dtype_eq(desc->dtype, F16)) {
-        return STATUS_BAD_TENSOR_DTYPE;
-    }
     checkCudaError(cudaSetDevice(desc->device_id));
-    add_nv_gpu_f16(desc, c, a, b, stream);
-    return STATUS_SUCCESS;
+    if (desc->dtype == F16) {
+        add_nv_gpu<vecN<half, 4>, half>(desc, c, a, b, stream, 4);
+        return STATUS_SUCCESS;
+    }
+    if (desc->dtype == F32) {
+        add_nv_gpu<vecN<float, 4>, float>(desc, c, a, b, stream, 4);
+        return STATUS_SUCCESS;
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
 }
diff --git a/src/ops/utils.h b/src/ops/utils.h
index bb4de8c6..a22dae2b 100644
--- a/src/ops/utils.h
+++ b/src/ops/utils.h
@@ -93,7 +93,7 @@ inline bool getBroadcastShape(const uint64_t *shape1, uint64_t ndim1,
     std::copy(shape2, shape2 + ndim2, padded_shape2 + max_rank - ndim2);
 
     // compute broadcasted shape
-    for (int i = 0; i < max_rank; ++i) {
+    for (size_t i = 0; i < max_rank; ++i) {
         if (padded_shape1[i] == padded_shape2[i] || padded_shape1[i] == 1 || padded_shape2[i] == 1) {
             broadcast_shape[i] = std::max(padded_shape1[i], padded_shape2[i]);
         } else {
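[Illustrative sketch, not part of the patch series: add_cpu_f32 above leans on
the repo's incrementOne/compactToFlat helpers. A minimal standalone model of
their assumed semantics (the real helpers take raw pointers and an explicit
ndim argument; std::vector is used here only to keep the sketch self-contained):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Advance a row-major multi-index one step within `shape`.
    void incrementOne(std::vector<uint64_t> &idx, const std::vector<uint64_t> &shape) {
        for (size_t i = idx.size(); i-- > 0;) {
            if (++idx[i] < shape[i]) { return; }
            idx[i] = 0;
        }
    }

    // Project a multi-index onto a flat offset through per-dimension strides.
    uint64_t compactToFlat(const std::vector<uint64_t> &idx, const std::vector<int64_t> &strides) {
        uint64_t flat = 0;
        for (size_t i = 0; i < idx.size(); ++i) { flat += idx[i] * strides[i]; }
        return flat;
    }

    int main() {
        // c[2][3] = a[2][3] + b[3]: b is broadcast along dim 0 via stride 0.
        std::vector<uint64_t> shape{2, 3}, idx{0, 0};
        std::vector<int64_t> a_strides{3, 1}, b_strides{0, 1};
        for (int i = 0; i < 6; ++i, incrementOne(idx, shape)) {
            std::cout << compactToFlat(idx, a_strides) << " "   // 0 1 2 3 4 5
                      << compactToFlat(idx, b_strides) << "\n"; // 0 1 2 0 1 2
        }
        return 0;
    }

Because c is checked to be contiguous, the output is indexed linearly with i,
while a and b go through their own strides; a stride of 0 on a broadcast
dimension makes the same input element repeat.]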
From cbfe86c79f1cbf4891f11f7ac3f9ffa892ae7521 Mon Sep 17 00:00:00 2001
From: lizimin
Date: Thu, 17 Oct 2024 11:45:43 +0800
Subject: [PATCH 2/5] generalize the add_cpu function

---
 src/ops/add/cpu/add_cpu.cc | 32 ++++++++++++--------------------
 src/ops/add/cpu/add_cpu.h  |  1 +
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/src/ops/add/cpu/add_cpu.cc b/src/ops/add/cpu/add_cpu.cc
index 430c00a3..7ca674fd 100644
--- a/src/ops/add/cpu/add_cpu.cc
+++ b/src/ops/add/cpu/add_cpu.cc
@@ -69,29 +69,21 @@ infiniopStatus_t cpuDestroyAddDescriptor(AddCpuDescriptor_t desc) {
     return STATUS_SUCCESS;
 }
 
-void add_cpu_f16(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
-    auto a_ = reinterpret_cast<uint16_t const *>(a);
-    auto b_ = reinterpret_cast<uint16_t const *>(b);
-    auto c_ = reinterpret_cast<uint16_t *>(c);
+template <typename Tdata>
+void add_cpu(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
+    auto a_ = reinterpret_cast<Tdata const *>(a);
+    auto b_ = reinterpret_cast<Tdata const *>(b);
+    auto c_ = reinterpret_cast<Tdata *>(c);
     const auto &indices = desc->c_indices;
 
     for (uint64_t i = 0; i < desc->c_data_size; ++i, incrementOne(indices, desc->c_shape, desc->ndim)) {
         auto a_index = compactToFlat(indices, desc->a_strides, desc->ndim);
         auto b_index = compactToFlat(indices, desc->b_strides, desc->ndim);
-        c_[i] = f32_to_f16(f16_to_f32(a_[a_index]) + f16_to_f32(b_[b_index]));
-    }
-}
-
-void add_cpu_f32(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
-    auto a_ = reinterpret_cast<float const *>(a);
-    auto b_ = reinterpret_cast<float const *>(b);
-    auto c_ = reinterpret_cast<float *>(c);
-    const auto &indices = desc->c_indices;
-
-    for (uint64_t i = 0; i < desc->c_data_size; ++i, incrementOne(indices, desc->c_shape, desc->ndim)) {
-        auto a_index = compactToFlat(indices, desc->a_strides, desc->ndim);
-        auto b_index = compactToFlat(indices, desc->b_strides, desc->ndim);
-        c_[i] = a_[a_index] + b_[b_index];
+        if constexpr (std::is_same<Tdata, uint16_t>::value) {
+            c_[i] = f32_to_f16(f16_to_f32(a_[a_index]) + f16_to_f32(b_[b_index]));
+        } else {
+            c_[i] = a_[a_index] + b_[b_index];
+        }
     }
 }
 
@@ -99,11 +91,11 @@ infiniopStatus_t cpuAdd(AddCpuDescriptor_t desc,
                         void *c, void const *a, void const *b,
                         void *stream) {
     if (desc->dtype == F16) {
-        add_cpu_f16(desc, c, a, b);
+        add_cpu<uint16_t>(desc, c, a, b);
         return STATUS_SUCCESS;
     }
     if (desc->dtype == F32) {
-        add_cpu_f32(desc, c, a, b);
+        add_cpu<float>(desc, c, a, b);
         return STATUS_SUCCESS;
     }
     return STATUS_BAD_TENSOR_DTYPE;
diff --git a/src/ops/add/cpu/add_cpu.h b/src/ops/add/cpu/add_cpu.h
index c9c8d98e..42e62435 100644
--- a/src/ops/add/cpu/add_cpu.h
+++ b/src/ops/add/cpu/add_cpu.h
@@ -3,6 +3,7 @@
 
 #include "operators.h"
 #include <numeric>
+#include <type_traits>
 
 struct AddCpuDescriptor {
     Device device;
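[Illustrative sketch, not part of the patch series: the templated add_cpu keeps
fp16 values as raw uint16_t bits and routes their arithmetic through float,
while fp32 adds natively; `if constexpr` resolves the branch per instantiation
at compile time. A standalone model of the pattern (requires C++17; half_add is
a placeholder for the repo's f32_to_f16(f16_to_f32(a) + f16_to_f32(b)) round
trip, not a real fp16 addition):

    #include <cstdint>
    #include <iostream>
    #include <type_traits>

    // Placeholder for the repo's fp16 round trip: NOT real fp16 math.
    uint16_t half_add(uint16_t a, uint16_t b) { return static_cast<uint16_t>(a + b); }

    template <typename Tdata>
    Tdata add_one(Tdata a, Tdata b) {
        if constexpr (std::is_same<Tdata, uint16_t>::value) {
            return half_add(a, b); // compiled only for Tdata == uint16_t
        } else {
            return a + b;          // native arithmetic, e.g. for float
        }
    }

    int main() {
        std::cout << add_one<float>(1.5f, 2.25f) << "\n";                 // 3.75
        std::cout << add_one<uint16_t>(uint16_t{1}, uint16_t{2}) << "\n"; // 3 (placeholder math)
        return 0;
    }

Each instantiation contains only its own branch, so the fp32 loop pays nothing
for the fp16 conversion path.]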
From 69a91c7f5bb55290247746aa89feace36fc47c12 Mon Sep 17 00:00:00 2001
From: lizimin
Date: Thu, 17 Oct 2024 11:53:05 +0800
Subject: [PATCH 3/5] optimized add_cpu format

---
 src/ops/add/cpu/add_cpu.cc | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/ops/add/cpu/add_cpu.cc b/src/ops/add/cpu/add_cpu.cc
index 7ca674fd..649fa052 100644
--- a/src/ops/add/cpu/add_cpu.cc
+++ b/src/ops/add/cpu/add_cpu.cc
@@ -70,7 +70,7 @@ infiniopStatus_t cpuDestroyAddDescriptor(AddCpuDescriptor_t desc) {
 }
 
 template <typename Tdata>
-void add_cpu(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
+infiniopStatus_t add_cpu(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
     auto a_ = reinterpret_cast<Tdata const *>(a);
     auto b_ = reinterpret_cast<Tdata const *>(b);
     auto c_ = reinterpret_cast<Tdata *>(c);
@@ -85,18 +85,17 @@ void add_cpu(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
             c_[i] = a_[a_index] + b_[b_index];
         }
     }
+    return STATUS_SUCCESS;
 }
 
 infiniopStatus_t cpuAdd(AddCpuDescriptor_t desc,
                         void *c, void const *a, void const *b,
                         void *stream) {
     if (desc->dtype == F16) {
-        add_cpu<uint16_t>(desc, c, a, b);
-        return STATUS_SUCCESS;
+        return add_cpu<uint16_t>(desc, c, a, b);
     }
     if (desc->dtype == F32) {
-        add_cpu<float>(desc, c, a, b);
-        return STATUS_SUCCESS;
+        return add_cpu<float>(desc, c, a, b);
     }
     return STATUS_BAD_TENSOR_DTYPE;
 }
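[Illustrative sketch, not part of the patch series: with add_cpu returning
infiniopStatus_t itself, every dtype branch in cpuAdd collapses to a single
return, and a future failure inside the element loop could propagate without
touching the dispatcher. The shape of the pattern, reduced to essentials
(names here are simplified stand-ins, not the repo's API):

    #include <cstdint>

    enum Status { STATUS_SUCCESS, STATUS_BAD_TENSOR_DTYPE };
    enum DType { F16, F32 };

    template <typename Tdata>
    Status do_add() { /* element loop elided */ return STATUS_SUCCESS; }

    // One return per dtype; the helper's status flows straight through.
    Status dispatch(DType dt) {
        if (dt == F16) { return do_add<uint16_t>(); }
        if (dt == F32) { return do_add<float>(); }
        return STATUS_BAD_TENSOR_DTYPE;
    }

    int main() { return dispatch(F32) == STATUS_SUCCESS ? 0 : 1; }
]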
From 1b4810fc23faacc6d06fb4464d583a933125fda3 Mon Sep 17 00:00:00 2001
From: lizimin
Date: Fri, 18 Oct 2024 13:57:02 +0800
Subject: [PATCH 4/5] optimized add_nv_gpu

---
 src/ops/add/cuda/add.cu | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu
index 4615d385..547712bb 100644
--- a/src/ops/add/cuda/add.cu
+++ b/src/ops/add/cuda/add.cu
@@ -49,6 +49,7 @@ __global__ void add(
         auto a_ = reinterpret_cast<Tdata const *>(a);
         auto b_ = reinterpret_cast<Tdata const *>(b);
         auto c_ = reinterpret_cast<Tdata *>(c);
+#pragma unroll
         for (size_t i = 0; i < pack_size; ++i) {
             auto a_idx = getDstIndex(idx + i, ndim, c_strides, a_strides);
             auto b_idx = getDstIndex(idx + i, ndim, c_strides, b_strides);
@@ -71,6 +72,7 @@ void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const
 
     cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
 
+#pragma unroll
     for (uint64_t i = 0; i < data_size; i += step) {
         add<<<gridDims, blockDims, 0, cuda_stream>>>(
             c, a, b, desc->a_strides, desc->b_strides, desc->c_strides, offset + data_size, desc->ndim, offset + i, desc->broadcasted, pack_size);
@@ -78,7 +80,7 @@ void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const
 }
 
 template <typename Tvec, typename Tdata>
-void add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) {
+infiniopStatus_t add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) {
     auto data_size = desc->c_data_size / pack_size;
     auto a_vec = reinterpret_cast<Tvec const *>(a);
     auto b_vec = reinterpret_cast<Tvec const *>(b);
@@ -90,6 +92,7 @@ void add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b,
     auto b_ = reinterpret_cast<Tdata const *>(b);
     auto c_ = reinterpret_cast<Tdata *>(c);
     _add_nv_gpu(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
+    return STATUS_SUCCESS;
 }
 
 infiniopStatus_t cudaAdd(AddCudaDescriptor_t desc,
@@ -97,12 +100,10 @@ infiniopStatus_t cudaAdd(AddCudaDescriptor_t desc,
                          void *stream) {
     checkCudaError(cudaSetDevice(desc->device_id));
     if (desc->dtype == F16) {
-        add_nv_gpu<vecN<half, 4>, half>(desc, c, a, b, stream, 4);
-        return STATUS_SUCCESS;
+        return add_nv_gpu<vecN<half, 8>, half>(desc, c, a, b, stream, 8);
     }
     if (desc->dtype == F32) {
-        add_nv_gpu<vecN<float, 4>, float>(desc, c, a, b, stream, 4);
-        return STATUS_SUCCESS;
+        return add_nv_gpu<vecN<float, 4>, float>(desc, c, a, b, stream, 4);
     }
     return STATUS_BAD_TENSOR_DTYPE;
 }
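[Illustrative sketch, not part of the patch series: the launcher splits the
flat element count into a vectorized body plus a scalar tail, and the tail
launch begins exactly where the vectorized one ends. The covering arithmetic,
with arbitrary example values:

    #include <cassert>
    #include <cstdint>

    int main() {
        // F16 path after this patch: pack_size = 8 halves per access.
        uint64_t c_data_size = 1000003, pack_size = 8;
        uint64_t data_size = c_data_size / pack_size; // packed elements, vectorized kernel
        uint64_t remainder = c_data_size % pack_size; // trailing scalars, second kernel
        uint64_t tail_offset = data_size * pack_size; // where the scalar launch starts
        assert(tail_offset + remainder == c_data_size); // every element covered once
        return 0;
    }

One caveat worth flagging in review: the second `#pragma unroll` annotates a
host-side loop, where it is not a standard pragma and is most likely ignored
by the host compiler; only the one inside the kernel can affect codegen.]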
From a9aec43ab71624c3800df127a080ab22d9e69312 Mon Sep 17 00:00:00 2001
From: lizimin
Date: Tue, 22 Oct 2024 10:33:26 +0800
Subject: [PATCH 5/5] Enhanced fp16 and fp32 performance by applying better block size and vectorization

---
 src/ops/add/cuda/add.cu | 49 +++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 16 deletions(-)

diff --git a/src/ops/add/cuda/add.cu b/src/ops/add/cuda/add.cu
index 547712bb..6c1dfec4 100644
--- a/src/ops/add/cuda/add.cu
+++ b/src/ops/add/cuda/add.cu
@@ -2,23 +2,40 @@
 #include "../../utils.h"
 #include "add.cuh"
 
-template <typename T, size_t N>
+/**
+ * @brief A templated vector struct that supports element-wise addition on arrays.
+ *
+ * @tparam T - The access data type for elements in the vector.
+ * @tparam TComp - The computation data type used for arithmetic operations.
+ * @tparam N - The number of elements of type T in the vector for a single access.
+ */
+template <typename T, typename TComp, size_t N>
 struct vecN {
     T data[N];
 
-    __device__ vecN operator+(const vecN &other) const {
-        vecN result;
+    __device__ __forceinline__ vecN operator+(const vecN &other) const {
+        vecN result;
+
         for (int i = 0; i < N; ++i) {
-            result.data[i] = data[i] + other.data[i];
+            if constexpr (std::is_same<T, TComp>::value) {
+                result.data[i] = data[i] + other.data[i];
+            } else {
+                constexpr static size_t pack_size = sizeof(T) / sizeof(TComp);
+                auto data_ = reinterpret_cast<vecN<TComp, TComp, pack_size> *>(result.data);
+                data_[i] = std::move(reinterpret_cast<vecN<TComp, TComp, pack_size> const *>(data)[i] +
+                                     reinterpret_cast<vecN<TComp, TComp, pack_size> const *>(other.data)[i]);
+            }
         }
+
         return result;
     }
 
-    __device__ const T &operator[](int i) const {
+    __device__ __forceinline__ const T &operator[](size_t i) const {
         return data[i];
     }
 };
 
+// get the corresponding index in the destination given the flat index of the source
 __device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
     uint64_t res = 0;
     for (uint64_t i = 0; i < ndim; ++i) {
@@ -66,7 +83,7 @@ void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const
     if (data_size == 0) {
         return;
     }
-    dim3 blockDims = dim3(std::min(static_cast<uint64_t>(MAX_THREADS_PER_BLOCK), data_size));
+    dim3 blockDims = dim3(std::min(static_cast<uint64_t>(256), data_size));
     dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size));
     uint64_t step = gridDims.x * blockDims.x;
 
@@ -81,16 +98,16 @@ void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const
 
 template <typename Tvec, typename Tdata>
 infiniopStatus_t add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) {
-    auto data_size = desc->c_data_size / pack_size;
-    auto a_vec = reinterpret_cast<Tvec const *>(a);
-    auto b_vec = reinterpret_cast<Tvec const *>(b);
-    auto c_vec = reinterpret_cast<Tvec *>(c);
+    const auto data_size = desc->c_data_size / pack_size;
+    const auto a_vec = reinterpret_cast<Tvec const *>(a);
+    const auto b_vec = reinterpret_cast<Tvec const *>(b);
+    const auto c_vec = reinterpret_cast<Tvec *>(c);
     _add_nv_gpu(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream);
 
-    auto remainder = desc->c_data_size % pack_size;
-    auto a_ = reinterpret_cast<Tdata const *>(a);
-    auto b_ = reinterpret_cast<Tdata const *>(b);
-    auto c_ = reinterpret_cast<Tdata *>(c);
+    const auto remainder = desc->c_data_size % pack_size;
+    const auto a_ = reinterpret_cast<Tdata const *>(a);
+    const auto b_ = reinterpret_cast<Tdata const *>(b);
+    const auto c_ = reinterpret_cast<Tdata *>(c);
     _add_nv_gpu(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
     return STATUS_SUCCESS;
 }
@@ -100,10 +117,10 @@ infiniopStatus_t cudaAdd(AddCudaDescriptor_t desc,
                          void *stream) {
     checkCudaError(cudaSetDevice(desc->device_id));
     if (desc->dtype == F16) {
-        return add_nv_gpu<vecN<half, 8>, half>(desc, c, a, b, stream, 8);
+        return add_nv_gpu<vecN<half2, half, 4>, half>(desc, c, a, b, stream, 8);
     }
     if (desc->dtype == F32) {
-        return add_nv_gpu<vecN<float, 4>, float>(desc, c, a, b, stream, 4);
+        return add_nv_gpu<vecN<float, float, 4>, float>(desc, c, a, b, stream, 4);
     }
     return STATUS_BAD_TENSOR_DTYPE;
 }
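[Illustrative sketch, not part of the patch series: after this patch the F16
path accesses 16 bytes per thread (four half2 values, so the additions can use
half2 arithmetic) and the F32 path four floats; in both cases that is one
128-bit transaction, the widest per-thread vectorized access on NVIDIA GPUs. A
compile-time check of that layout claim, with vecN's data layout redeclared
locally (T is the access type, TComp the computation type, N the count,
matching the @tparam comments above):

    // Compile with nvcc: half2 comes from cuda_fp16.h.
    #include <cuda_fp16.h>
    #include <cstddef>

    template <typename T, typename TComp, size_t N>
    struct vecN {
        T data[N];
    };

    static_assert(sizeof(vecN<half2, half, 4>) == 16, "F16: 4 x half2 = 8 halves = 16 bytes");
    static_assert(sizeof(vecN<float, float, 4>) == 16, "F32: 4 x float = 16 bytes");

    int main() { return 0; }
]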