Merged
9 changes: 6 additions & 3 deletions operatorspy/tests/add.py
@@ -85,15 +85,17 @@ def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for c_shape, a_shape, b_shape, inplace in test_cases:
test(lib, handle, "cpu", c_shape, a_shape, b_shape, inplace=inplace)
test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cpu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for c_shape, a_shape, b_shape, inplace in test_cases:
test(lib, handle, "cuda", c_shape, a_shape, b_shape, inplace=inplace)
test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "cuda", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
destroy_handle(lib, handle)


@@ -103,7 +105,8 @@ def test_bang(lib, test_cases):
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for c_shape, a_shape, b_shape, inplace in test_cases:
test(lib, handle, "mlu", c_shape, a_shape, b_shape, inplace=inplace)
test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
destroy_handle(lib, handle)


29 changes: 20 additions & 9 deletions src/ops/add/cpu/add_cpu.cc
@@ -27,7 +27,10 @@ infiniopStatus_t cpuCreateAddDescriptor(infiniopHandle_t,
if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) {
return STATUS_BAD_TENSOR_STRIDES;
}
if (!dtype_eq(c->dt, F16) || c->dt != a->dt || c->dt != b->dt) {
if (c->dt != F16 && c->dt != F32) {
return STATUS_BAD_TENSOR_DTYPE;
}
if (c->dt != a->dt || c->dt != b->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

@@ -66,25 +69,33 @@ infiniopStatus_t cpuDestroyAddDescriptor(AddCpuDescriptor_t desc) {
return STATUS_SUCCESS;
}

void add_cpu_f16(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
auto a_ = reinterpret_cast<uint16_t const *>(a);
auto b_ = reinterpret_cast<uint16_t const *>(b);
auto c_ = reinterpret_cast<uint16_t *>(c);
template<typename Tdata>
infiniopStatus_t add_cpu(AddCpuDescriptor_t desc, void *c, void const *a, void const *b) {
auto a_ = reinterpret_cast<Tdata const *>(a);
auto b_ = reinterpret_cast<Tdata const *>(b);
auto c_ = reinterpret_cast<Tdata *>(c);
const auto &indices = desc->c_indices;

for (uint64_t i = 0; i < desc->c_data_size; ++i, incrementOne(indices, desc->c_shape, desc->ndim)) {
auto a_index = compactToFlat(indices, desc->a_strides, desc->ndim);
auto b_index = compactToFlat(indices, desc->b_strides, desc->ndim);
c_[i] = f32_to_f16(f16_to_f32(a_[a_index]) + f16_to_f32(b_[b_index]));
if constexpr (std::is_same<Tdata, uint16_t>::value) {
c_[i] = f32_to_f16(f16_to_f32(a_[a_index]) + f16_to_f32(b_[b_index]));
} else {
c_[i] = a_[a_index] + b_[b_index];
}
}
return STATUS_SUCCESS;
}

infiniopStatus_t cpuAdd(AddCpuDescriptor_t desc,
void *c, void const *a, void const *b,
void *stream) {
if (dtype_eq(desc->dtype, F16)) {
add_cpu_f16(desc, c, a, b);
return STATUS_SUCCESS;
if (desc->dtype == F16) {
return add_cpu<uint16_t>(desc, c, a, b);
}
if (desc->dtype == F32) {
return add_cpu<float>(desc, c, a, b);
}
return STATUS_BAD_TENSOR_DTYPE;
}
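
For readers skimming the CPU change: the new `cpuAdd` keeps a single templated element-wise loop and picks the instantiation from the runtime dtype. Below is a minimal standalone sketch of that dispatch pattern, with illustrative names only (`Dtype`, `element_add`, `dispatch_add` are not the project's API); the real F16 path stores elements as `uint16_t` and converts through the project's `f16_to_f32`/`f32_to_f16` helpers, which are omitted here to keep the sketch self-contained.

```cpp
// Sketch of runtime-dtype -> template dispatch, analogous to the new cpuAdd.
#include <cstddef>
#include <cstdio>
#include <vector>

enum class Dtype { F32, F64 };

// One templated kernel covers every supported element type.
template <typename T>
void element_add(T *c, const T *a, const T *b, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        c[i] = a[i] + b[i];
    }
}

// The runtime enum selects the instantiation, mirroring the F16/F32 branches.
bool dispatch_add(Dtype dt, void *c, const void *a, const void *b, size_t n) {
    switch (dt) {
        case Dtype::F32:
            element_add(static_cast<float *>(c), static_cast<const float *>(a),
                        static_cast<const float *>(b), n);
            return true;
        case Dtype::F64:
            element_add(static_cast<double *>(c), static_cast<const double *>(a),
                        static_cast<const double *>(b), n);
            return true;
    }
    return false;  // unsupported dtype, analogous to STATUS_BAD_TENSOR_DTYPE
}

int main() {
    std::vector<float> a{1.f, 2.f}, b{3.f, 4.f}, c(2);
    dispatch_add(Dtype::F32, c.data(), a.data(), b.data(), c.size());
    std::printf("%f %f\n", c[0], c[1]);  // 4.000000 6.000000
}
```

The diff's `add_cpu` additionally handles F16 by treating the storage as `uint16_t` and doing the arithmetic in `float`, as shown in the `if constexpr` branch above.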
1 change: 1 addition & 0 deletions src/ops/add/cpu/add_cpu.h
@@ -3,6 +3,7 @@

#include "operators.h"
#include <numeric>
#include <type_traits>

struct AddCpuDescriptor {
Device device;
5 changes: 4 additions & 1 deletion src/ops/add/cuda/add.cc
@@ -14,7 +14,10 @@ infiniopStatus_t cudaCreateAddDescriptor(CudaHandle_t handle,
if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) {
return STATUS_BAD_TENSOR_STRIDES;
}
if (!dtype_eq(c->dt, F16) || c->dt != a->dt || c->dt != b->dt) {
if (c->dt != F16 && c->dt != F32) {
return STATUS_BAD_TENSOR_DTYPE;
}
if (c->dt != a->dt || c->dt != b->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}
bool broadcasted = false;
76 changes: 54 additions & 22 deletions src/ops/add/cuda/add.cu
@@ -2,14 +2,40 @@
#include "../../utils.h"
#include "add.cuh"

struct half4 {
__half x, y, z, w;
/**
* @brief A templated vector struct that supports element-wise addition on arrays.
*
* @tparam T - The access data type for elements in the vector.
* @tparam TComp - The computation data type used for arithmetic operations.
* @tparam N - The number of elements of type T in the vector for a single access.
*/
template<typename T, typename TComp, size_t N>
struct vecN {
T data[N];

__device__ half4 operator+(const half4 &other) const {
return half4{__hadd(x, other.x), __hadd(y, other.y), __hadd(z, other.z), __hadd(w, other.w)};
__device__ __forceinline__ vecN operator+(const vecN<T, TComp, N> &other) const {
vecN<T, TComp, N> result;

for (int i = 0; i < N; ++i) {
if constexpr (std::is_same<T, TComp>::value) {
result.data[i] = data[i] + other.data[i];
} else {
constexpr static size_t pack_size = sizeof(T) / sizeof(TComp);
auto data_ = reinterpret_cast<vecN<TComp, TComp, pack_size> *>(result.data);
data_[i] = std::move(reinterpret_cast<vecN<TComp, TComp, pack_size> const *>(data)[i] +
reinterpret_cast<vecN<TComp, TComp, pack_size> const *>(other.data)[i]);
}
}

return result;
}

__device__ __forceinline__ const T &operator[](size_t i) const {
return data[i];
}
};

// get the corresponding index in the destination given the flat index of the source
__device__ uint64_t getDstIndex(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
uint64_t res = 0;
for (uint64_t i = 0; i < ndim; ++i) {
@@ -40,6 +66,7 @@ __global__ void add(
auto a_ = reinterpret_cast<const BTdata *>(a);
auto b_ = reinterpret_cast<const BTdata *>(b);
auto c_ = reinterpret_cast<BTdata *>(c);
#pragma unroll
for (size_t i = 0; i < pack_size; ++i) {
auto a_idx = getDstIndex(idx + i, ndim, c_strides, a_strides);
auto b_idx = getDstIndex(idx + i, ndim, c_strides, b_strides);
@@ -52,43 +79,48 @@ }
}

template<typename Tdata, typename BTdata>
void add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
void _add_nv_gpu(AddCudaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
if (data_size == 0) {
return;
}
dim3 blockDims = dim3(std::min(static_cast<uint64_t>(MAX_THREADS_PER_BLOCK), data_size));
dim3 blockDims = dim3(std::min(static_cast<uint64_t>(256), data_size));
dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size));
uint64_t step = gridDims.x * blockDims.x;

cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);

#pragma unroll
for (uint64_t i = 0; i < data_size; i += step) {
add<Tdata, BTdata><<<gridDims, blockDims, 0, cuda_stream>>>(
c, a, b, desc->a_strides, desc->b_strides, desc->c_strides, offset + data_size, desc->ndim, offset + i, desc->broadcasted, pack_size);
}
}

void add_nv_gpu_f16(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream) {
auto data_size = desc->c_data_size / 4;
auto a_half4 = reinterpret_cast<const half4 *>(a);
auto b_half4 = reinterpret_cast<const half4 *>(b);
auto c_half4 = reinterpret_cast<half4 *>(c);
add_nv_gpu<half4, half>(desc, c_half4, a_half4, b_half4, data_size, 4, 0, stream);
template<typename Tdata, typename TIdata>
infiniopStatus_t add_nv_gpu(AddCudaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) {
const auto data_size = desc->c_data_size / pack_size;
const auto a_vec = reinterpret_cast<const Tdata *>(a);
const auto b_vec = reinterpret_cast<const Tdata *>(b);
const auto c_vec = reinterpret_cast<Tdata *>(c);
_add_nv_gpu<Tdata, TIdata>(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream);

auto remainder = desc->c_data_size % 4;
auto a_half = reinterpret_cast<const half *>(a);
auto b_half = reinterpret_cast<const half *>(b);
auto c_half = reinterpret_cast<half *>(c);
add_nv_gpu<half, half>(desc, c_half, a_half, b_half, remainder, 1, data_size * 4, stream);
const auto remainder = desc->c_data_size % pack_size;
const auto a_ = reinterpret_cast<const TIdata *>(a);
const auto b_ = reinterpret_cast<const TIdata *>(b);
const auto c_ = reinterpret_cast<TIdata *>(c);
_add_nv_gpu<TIdata, TIdata>(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
return STATUS_SUCCESS;
}

infiniopStatus_t cudaAdd(AddCudaDescriptor_t desc,
void *c, void const *a, void const *b,
void *stream) {
if (!dtype_eq(desc->dtype, F16)) {
return STATUS_BAD_TENSOR_DTYPE;
}
checkCudaError(cudaSetDevice(desc->device_id));
add_nv_gpu_f16(desc, c, a, b, stream);
return STATUS_SUCCESS;
if (desc->dtype == F16) {
return add_nv_gpu<vecN<float2, half2, 2>, half>(desc, c, a, b, stream, 8);
}
if (desc->dtype == F32) {
return add_nv_gpu<vecN<float2, float, 2>, float>(desc, c, a, b, stream, 4);
}
return STATUS_BAD_TENSOR_DTYPE;
}
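
On the template arguments in the new `cudaAdd` dispatch: `vecN<float2, half2, 2>` is a 16-byte pack (two `float2` values) whose arithmetic runs in `half2` lanes, so each access covers 8 fp16 elements (hence `pack_size` 8), while `vecN<float2, float, 2>` reads the same 16 bytes as 4 floats (hence 4). A standalone CUDA sketch of that packed-access idea follows, assuming contiguous, 16-byte-aligned, non-broadcast buffers; it is not the library's kernel, and the names and launch configuration are illustrative only.

```cuda
// Packed fp16 add: each thread loads 16 bytes (8 halves) and adds them as
// half2 lanes; a scalar kernel handles the tail that doesn't fill a pack.
// Requires fp16 arithmetic on the device (sm_53+); compile with nvcc.
#include <cuda_fp16.h>
#include <cstddef>

struct alignas(16) HalfPack8 {
    __half2 lane[4];  // 4 x half2 = 8 halves = 16 bytes per access
};

__global__ void add_f16_packed(__half *c, const __half *a, const __half *b, size_t n_packs) {
    size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i >= n_packs) return;
    HalfPack8 pa = reinterpret_cast<const HalfPack8 *>(a)[i];
    HalfPack8 pb = reinterpret_cast<const HalfPack8 *>(b)[i];
    HalfPack8 pc;
    for (int k = 0; k < 4; ++k) {
        pc.lane[k] = __hadd2(pa.lane[k], pb.lane[k]);  // two fp16 adds per intrinsic
    }
    reinterpret_cast<HalfPack8 *>(c)[i] = pc;
}

__global__ void add_f16_tail(__half *c, const __half *a, const __half *b, size_t start, size_t n) {
    size_t i = start + blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i < n) c[i] = __hadd(a[i], b[i]);
}

// Host-side launch sketch: n contiguous elements, default stream.
void launch_add_f16(__half *c, const __half *a, const __half *b, size_t n) {
    size_t n_packs = n / 8;
    size_t tail_start = n_packs * 8;
    if (n_packs) {
        unsigned grid = static_cast<unsigned>((n_packs + 255) / 256);
        add_f16_packed<<<grid, 256>>>(c, a, b, n_packs);
    }
    if (tail_start < n) {
        add_f16_tail<<<1, 256>>>(c, a, b, tail_start, n);  // tail is < 8 elements
    }
}
```

Unlike this sketch, the PR's `add` kernel recomputes per-element source indices from the strides when `broadcasted` is set, so it also covers broadcast shapes.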
2 changes: 1 addition & 1 deletion src/ops/utils.h
@@ -93,7 +93,7 @@ inline bool getBroadcastShape(const uint64_t *shape1, uint64_t ndim1,
std::copy(shape2, shape2 + ndim2, padded_shape2 + max_rank - ndim2);

// compute broadcasted shape
for (int i = 0; i < max_rank; ++i) {
for (size_t i = 0; i < max_rank; ++i) {
if (padded_shape1[i] == padded_shape2[i] || padded_shape1[i] == 1 || padded_shape2[i] == 1) {
broadcast_shape[i] = std::max(padded_shape1[i], padded_shape2[i]);
} else {