diff --git a/include/infiniop.h b/include/infiniop.h
index c0a09fcb4..e87839bc2 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -2,6 +2,10 @@
 #define __INFINIOP_API_H__
 
 #include "infiniop/handle.h"
+// Unified headers for elementwise operators
+#include "infiniop/ops/unary_ops_api.h"
+#include "infiniop/ops/binary_ops_api.h"
+// Other operators
 #include "infiniop/ops/add.h"
 #include "infiniop/ops/add_rms_norm.h"
 #include "infiniop/ops/attention.h"
diff --git a/include/infiniop/ops/binary_op_api.h b/include/infiniop/ops/binary_op_api.h
new file mode 100644
index 000000000..4ab2401b9
--- /dev/null
+++ b/include/infiniop/ops/binary_op_api.h
@@ -0,0 +1,50 @@
+#ifndef __INFINIOP_BINARY_OP_API_H__
+#define __INFINIOP_BINARY_OP_API_H__
+
+#include "../operator_descriptor.h"
+
+/**
+ * @brief Macro to generate the C API header for a binary operator.
+ *
+ * This macro generates all the necessary declarations for a binary operator:
+ * - Descriptor type definition
+ * - Create descriptor function
+ * - Get workspace size function
+ * - Execute operator function
+ * - Destroy descriptor function
+ *
+ * Usage:
+ *   BINARY_OP_API_DECLARE(div, Div)
+ *   BINARY_OP_API_DECLARE(pow, Pow)
+ *
+ * @param OP_NAME Lowercase operator name (e.g., div, pow, mod)
+ * @param OP_NAME_UPPER Capitalized operator name (e.g., Div, Pow, Mod)
+ */
+#define BINARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
+ \
+    typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
+ \
+    __C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
+        infiniopHandle_t handle, \
+        infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
+        infiniopTensorDescriptor_t c, \
+        infiniopTensorDescriptor_t a, \
+        infiniopTensorDescriptor_t b); \
+ \
+    __C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        size_t *size); \
+ \
+    __C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        void *workspace, \
+        size_t workspace_size, \
+        void *c, \
+        const void *a, \
+        const void *b, \
+        void *stream); \
+ \
+    __C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc);
+
+#endif // __INFINIOP_BINARY_OP_API_H__
diff --git a/include/infiniop/ops/binary_ops_api.h b/include/infiniop/ops/binary_ops_api.h
new file mode 100644
index 000000000..24d7715c9
--- /dev/null
+++ b/include/infiniop/ops/binary_ops_api.h
@@ -0,0 +1,23 @@
+#ifndef __INFINIOP_BINARY_OPS_API_H__
+#define __INFINIOP_BINARY_OPS_API_H__
+
+#include "binary_op_api.h"
+
+/**
+ * @brief Unified API declarations for all binary operators.
+ *
+ * This header declares the APIs of all binary operators in a single file,
+ * eliminating the need for an individual header file per operator.
+ *
+ * All binary operator APIs are declared here:
+ * - div, pow, mod, max, min
+ */
+
+// Declare all binary operator APIs
+BINARY_OP_API_DECLARE(div, Div)
+BINARY_OP_API_DECLARE(pow, Pow)
+BINARY_OP_API_DECLARE(mod, Mod)
+BINARY_OP_API_DECLARE(max, Max)
+BINARY_OP_API_DECLARE(min, Min)
+
+#endif // __INFINIOP_BINARY_OPS_API_H__
diff --git a/include/infiniop/ops/unary_op_api.h b/include/infiniop/ops/unary_op_api.h
new file mode 100644
index 000000000..eefe3c3a4
--- /dev/null
+++ b/include/infiniop/ops/unary_op_api.h
@@ -0,0 +1,48 @@
+#ifndef __INFINIOP_UNARY_OP_API_H__
+#define __INFINIOP_UNARY_OP_API_H__
+
+#include "../operator_descriptor.h"
+
+/**
+ * @brief Macro to generate the C API header for a unary operator.
+ *
+ * This macro generates all the necessary declarations for a unary operator:
+ * - Descriptor type definition
+ * - Create descriptor function
+ * - Get workspace size function
+ * - Execute operator function
+ * - Destroy descriptor function
+ *
+ * Usage:
+ *   UNARY_OP_API_DECLARE(abs, Abs)
+ *   UNARY_OP_API_DECLARE(log, Log)
+ *
+ * @param OP_NAME Lowercase operator name (e.g., abs, log, sin)
+ * @param OP_NAME_UPPER Capitalized operator name (e.g., Abs, Log, Sin)
+ */
+#define UNARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
+ \
+    typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
+ \
+    __C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
+        infiniopHandle_t handle, \
+        infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
+        infiniopTensorDescriptor_t y, \
+        infiniopTensorDescriptor_t x); \
+ \
+    __C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        size_t *size); \
+ \
+    __C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        void *workspace, \
+        size_t workspace_size, \
+        void *y, \
+        const void *x, \
+        void *stream); \
+ \
+    __C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc);
+
+#endif // __INFINIOP_UNARY_OP_API_H__
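Taken together, these headers replace dozens of per-operator headers with a handful of macro expansions. For reference, a caller-side sketch of the generated binary API, using div as the example (error handling and device allocation elided; the entry-point names come from the expansion of BINARY_OP_API_DECLARE(div, Div), and device_alloc is a hypothetical allocator, not part of this diff):

    // Hedged sketch: exercises the API generated by BINARY_OP_API_DECLARE(div, Div).
    infiniopDivDescriptor_t desc;
    infiniopCreateDivDescriptor(handle, &desc, c_desc, a_desc, b_desc);

    size_t workspace_size = 0;
    infiniopGetDivWorkspaceSize(desc, &workspace_size);
    void *workspace = device_alloc(workspace_size); // hypothetical allocator

    infiniopDiv(desc, workspace, workspace_size, c, a, b, stream);
    infiniopDestroyDivDescriptor(desc);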
diff --git a/include/infiniop/ops/unary_ops_api.h b/include/infiniop/ops/unary_ops_api.h
new file mode 100644
index 000000000..95b0773b6
--- /dev/null
+++ b/include/infiniop/ops/unary_ops_api.h
@@ -0,0 +1,39 @@
+#ifndef __INFINIOP_UNARY_OPS_API_H__
+#define __INFINIOP_UNARY_OPS_API_H__
+
+#include "unary_op_api.h"
+
+/**
+ * @brief Unified API declarations for all unary operators.
+ *
+ * This header declares the APIs of all unary operators in a single file,
+ * eliminating the need for an individual header file per operator.
+ *
+ * All unary operator APIs are declared here:
+ * - abs, log, sqrt, reciprocal, neg, round, sinh, sign, tan
+ * - acosh, asinh, cos, atanh, asin, floor, cosh, erf, atan, acos, ceil
+ */
+
+// Declare all unary operator APIs
+UNARY_OP_API_DECLARE(abs, Abs)
+UNARY_OP_API_DECLARE(log, Log)
+UNARY_OP_API_DECLARE(sqrt, Sqrt)
+UNARY_OP_API_DECLARE(reciprocal, Reciprocal)
+UNARY_OP_API_DECLARE(neg, Neg)
+UNARY_OP_API_DECLARE(round, Round)
+UNARY_OP_API_DECLARE(sinh, Sinh)
+UNARY_OP_API_DECLARE(sign, Sign)
+UNARY_OP_API_DECLARE(tan, Tan)
+UNARY_OP_API_DECLARE(acosh, Acosh)
+UNARY_OP_API_DECLARE(asinh, Asinh)
+UNARY_OP_API_DECLARE(cos, Cos)
+UNARY_OP_API_DECLARE(atanh, Atanh)
+UNARY_OP_API_DECLARE(asin, Asin)
+UNARY_OP_API_DECLARE(floor, Floor)
+UNARY_OP_API_DECLARE(cosh, Cosh)
+UNARY_OP_API_DECLARE(erf, Erf)
+UNARY_OP_API_DECLARE(atan, Atan)
+UNARY_OP_API_DECLARE(acos, Acos)
+UNARY_OP_API_DECLARE(ceil, Ceil)
+
+#endif // __INFINIOP_UNARY_OPS_API_H__
diff --git a/src/infiniop/elementwise/binary.h b/src/infiniop/elementwise/binary.h
new file mode 100644
index 000000000..1823fac3f
--- /dev/null
+++ b/src/infiniop/elementwise/binary.h
@@ -0,0 +1,261 @@
+#ifndef __INFINIOP_ELEMENTWISE_BINARY_H__
+#define __INFINIOP_ELEMENTWISE_BINARY_H__
+
+#include <algorithm>
+#include <cmath>
+#include <type_traits>
+
+#ifdef __CUDACC__
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+// Include device-specific type aliases for cuda_bfloat16
+#include "../devices/nvidia/nvidia_kernel_common.cuh"
+#endif
+
+namespace op::elementwise::binary {
+
+/**
+ * @brief Represents all the currently defined binary operations.
+ *
+ * This enum is used to specify which binary operation to perform
+ * in the generic BinaryOp template.
+ */
+enum class BinaryMode {
+    // Arithmetic operations:
+    Add,
+    Subtract,
+    Multiply,
+    Divide,
+    Pow,
+    Mod,
+    Max,
+    Min,
+    // Logical operations (for future use):
+    // And, Or, Xor, Less, LessOrEqual, Equal, Greater, GreaterOrEqual
+};
+
+/**
+ * @brief Generic binary operation template that performs different operations
+ * based on the specified BinaryMode.
+ *
+ * This template allows multiple binary operators (pow, div, mod, min, max, etc.)
+ * to share the same implementation infrastructure while differing only in the
+ * operation mode.
+ *
+ * @tparam Mode The binary operation mode (from the BinaryMode enum)
+ */
+template <BinaryMode Mode>
+struct BinaryOp {
+    static constexpr size_t num_inputs = 2;
+
+    template <typename T>
+    T operator()(const T &a, const T &b) const {
+        if constexpr (Mode == BinaryMode::Add) {
+            return a + b;
+        } else if constexpr (Mode == BinaryMode::Subtract) {
+            return a - b;
+        } else if constexpr (Mode == BinaryMode::Multiply) {
+            return a * b;
+        } else if constexpr (Mode == BinaryMode::Divide) {
+            return a / b;
+        } else if constexpr (Mode == BinaryMode::Pow) {
+            return std::pow(a, b);
+        } else if constexpr (Mode == BinaryMode::Mod) {
+            if constexpr (std::is_floating_point_v<T>) {
+                return std::fmod(a, b);
+            } else {
+                return a % b;
+            }
+        } else if constexpr (Mode == BinaryMode::Max) {
+            if constexpr (std::is_floating_point_v<T>) {
+                return std::fmax(a, b);
+            } else {
+                return std::max(a, b);
+            }
+        } else if constexpr (Mode == BinaryMode::Min) {
+            if constexpr (std::is_floating_point_v<T>) {
+                return std::fmin(a, b);
+            } else {
+                return std::min(a, b);
+            }
+        } else {
+            static_assert(Mode != Mode, "Unsupported binary operation mode");
+            return a;
+        }
+    }
+};
+
+#ifdef __CUDACC__
+/**
+ * @brief CUDA-specific binary operation template that performs different operations
+ * based on the specified BinaryMode, using CUDA-optimized functions.
+ *
+ * This template provides CUDA device functions optimized for GPU execution,
+ * using intrinsics such as __powf, __h2div, __hmin2, and __hmax2.
+ *
+ * @tparam Mode The binary operation mode (from the BinaryMode enum)
+ */
+namespace cuda {
+template <BinaryMode Mode>
+struct BinaryOp {
+    static constexpr size_t num_inputs = 2;
+
+    template <typename T>
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        if constexpr (Mode == BinaryMode::Add) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hadd2(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return __hadd(a, b);
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __fadd_rn(a, b);
+            } else {
+                return a + b;
+            }
+        } else if constexpr (Mode == BinaryMode::Subtract) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hsub2(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return __hsub(a, b);
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __fsub_rn(a, b);
+            } else {
+                return a - b;
+            }
+        } else if constexpr (Mode == BinaryMode::Multiply) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hmul2(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return __hmul(a, b);
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __fmul_rd(a, b);
+            } else {
+                return a * b;
+            }
+        } else if constexpr (Mode == BinaryMode::Divide) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __h2div(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return a / b;
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __fdividef(a, b);
+            } else {
+                return a / b;
+            }
+        } else if constexpr (Mode == BinaryMode::Pow) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 a_f2 = __half22float2(a);
+                float2 b_f2 = __half22float2(b);
+                return __float22half2_rn(make_float2(__powf(a_f2.x, b_f2.x), __powf(a_f2.y, b_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                float a_ = __half2float(a);
+                float b_ = __half2float(b);
+                float ans_f = __powf(a_, b_);
+                return __float2half(isnan(ans_f) ? std::pow(a_, b_) : ans_f);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float2 a_f2 = __bfloat1622float2(a);
+                float2 b_f2 = __bfloat1622float2(b);
+                return __floats2bfloat162_rn(__powf(a_f2.x, b_f2.x), __powf(a_f2.y, b_f2.y));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                float a_ = __bfloat162float(a);
+                float b_ = __bfloat162float(b);
+                return __float2bfloat16_rn(__powf(a_, b_));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __powf(a, b);
+            } else {
+                return std::pow(a, b);
+            }
+        } else if constexpr (Mode == BinaryMode::Mod) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 a_f2 = __half22float2(a);
+                float2 b_f2 = __half22float2(b);
+                return __float22half2_rn(make_float2(std::fmod(a_f2.x, b_f2.x), std::fmod(a_f2.y, b_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                float a_ = __half2float(a);
+                float b_ = __half2float(b);
+                return __float2half(std::fmod(a_, b_));
+            } else if constexpr (std::is_floating_point_v<T>) {
+                return std::fmod(a, b);
+            } else {
+                return a % b;
+            }
+        } else if constexpr (Mode == BinaryMode::Max) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hmax2(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return a > b ? a : b;
+            } else if constexpr (std::is_same_v<T, float>) {
+                return fmaxf(a, b);
+            } else {
+                return a > b ? a : b;
+            }
+        } else if constexpr (Mode == BinaryMode::Min) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hmin2(a, b);
+            } else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, cuda_bfloat16>) {
+                return a < b ? a : b;
+            } else if constexpr (std::is_same_v<T, float>) {
+                return fminf(a, b);
+            } else {
+                return a < b ? a : b;
+            }
+        } else {
+            static_assert(Mode != Mode, "Unsupported binary operation mode");
+            return a;
+        }
+    }
+};
+} // namespace cuda
+#endif // __CUDACC__
+
+/**
+ * @brief Macro to define a binary elementwise descriptor for a specific operation.
+ *
+ * This macro simplifies the definition of binary operators (pow, div, mod, min, max, etc.)
+ * by automatically generating the Descriptor class and operation struct using the
+ * ELEMENTWISE_DESCRIPTOR macro and the BinaryOp template.
+ *
+ * Usage:
+ *   BINARY_ELEMENTWISE_DESCRIPTOR(pow, cpu, BinaryMode::Pow)
+ *   BINARY_ELEMENTWISE_DESCRIPTOR(div, cpu, BinaryMode::Divide)
+ *
+ * @param OP The operator name (e.g., pow, div, mod)
+ * @param NAMESPACE The device namespace (e.g., cpu, nvidia)
+ * @param MODE The BinaryMode enum value for this operation
+ */
+#define BINARY_ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE, MODE) \
+ \
+    ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
+ \
+    namespace op::OP::NAMESPACE { \
+    using Op = op::elementwise::binary::BinaryOp<MODE>; \
+    }
+
+/**
+ * @brief Macro to define a binary elementwise descriptor for the CUDA/NVIDIA backend.
+ *
+ * This macro is similar to BINARY_ELEMENTWISE_DESCRIPTOR but uses the CUDA-specific
+ * BinaryOp implementation for better GPU performance.
+ *
+ * Usage:
+ *   BINARY_ELEMENTWISE_DESCRIPTOR_CUDA(pow, nvidia, BinaryMode::Pow)
+ *   BINARY_ELEMENTWISE_DESCRIPTOR_CUDA(div, nvidia, BinaryMode::Divide)
+ *
+ * @param OP The operator name (e.g., pow, div, mod)
+ * @param NAMESPACE The device namespace (e.g., nvidia)
+ * @param MODE The BinaryMode enum value for this operation
+ */
+#ifdef __CUDACC__
+#define BINARY_ELEMENTWISE_DESCRIPTOR_CUDA(OP, NAMESPACE, MODE) \
+ \
+    ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
+ \
+    namespace op::OP::cuda { \
+    using Op = op::elementwise::binary::cuda::BinaryOp<MODE>; \
+    }
+#endif // __CUDACC__
+
+} // namespace op::elementwise::binary
+
+#endif // __INFINIOP_ELEMENTWISE_BINARY_H__
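Because the host-side functor is an ordinary callable, a mode can be sanity-checked in isolation before it is wired into a descriptor. A minimal standalone sketch (not part of the diff; it only assumes binary.h above is on the include path):

    #include <cassert>
    #include <cmath>

    #include "src/infiniop/elementwise/binary.h" // assumed include path

    int main() {
        using namespace op::elementwise::binary;
        BinaryOp<BinaryMode::Pow> pow_op;
        BinaryOp<BinaryMode::Mod> mod_op;
        assert(pow_op(2.0f, 3.0f) == 8.0f);                     // float -> std::pow
        assert(mod_op(7, 3) == 1);                              // integral -> operator%
        assert(std::fabs(mod_op(7.5f, 2.0f) - 1.5f) < 1e-6f);   // float -> std::fmod
        return 0;
    }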
diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu_impl.h b/src/infiniop/elementwise/cpu/elementwise_cpu_impl.h
new file mode 100644
index 000000000..030f4d87e
--- /dev/null
+++ b/src/infiniop/elementwise/cpu/elementwise_cpu_impl.h
@@ -0,0 +1,130 @@
+#ifndef __INFINIOP_ELEMENTWISE_CPU_IMPL_H__
+#define __INFINIOP_ELEMENTWISE_CPU_IMPL_H__
+
+#include "../../../utils/check.h"
+#include "../../../utils/result.hpp"
+#include "../../devices/cpu/common_cpu.h"
+#include "elementwise_cpu.h"
+
+/**
+ * @brief Generic implementation for elementwise CPU operators.
+ *
+ * This file provides a generic implementation template that can be used
+ * by all binary and unary operators to reduce code duplication.
+ *
+ * Usage:
+ *   #include "elementwise_cpu_impl.h"
+ *   namespace op::pow::cpu {
+ *   using Op = op::elementwise::binary::BinaryOp<op::elementwise::binary::BinaryMode::Pow>;
+ *   ELEMENTWISE_CPU_IMPL_BINARY(pow)
+ *   }
+ *
+ *   namespace op::sqrt::cpu {
+ *   using Op = op::elementwise::unary::UnaryOp<op::elementwise::unary::UnaryMode::Sqrt>;
+ *   ELEMENTWISE_CPU_IMPL_UNARY(sqrt)
+ *   }
+ */
+
+/**
+ * @brief Macro to generate a binary operator implementation.
+ *
+ * This macro generates the Descriptor destructor, create, and calculate methods
+ * for binary operators, using the generic implementation.
+ *
+ * Usage:
+ *   namespace op::pow::cpu {
+ *   using Op = op::elementwise::binary::BinaryOp<op::elementwise::binary::BinaryMode::Pow>;
+ *   ELEMENTWISE_CPU_IMPL_BINARY(pow)
+ *   }
+ */
+#define ELEMENTWISE_CPU_IMPL_BINARY(OP) \
+ \
+    Descriptor::~Descriptor() = default; \
+ \
+    infiniStatus_t Descriptor::create( \
+        infiniopHandle_t handle_, \
+        Descriptor **desc_ptr, \
+        infiniopTensorDescriptor_t out_desc, \
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
+        auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
+        auto dtype = out_desc->dtype(); \
+        const auto &a_desc = input_desc_vec.at(0); \
+        const auto &b_desc = input_desc_vec.at(1); \
+        const auto &out_shape = out_desc->shape(); \
+        const auto &a_shape = a_desc->shape(); \
+        const auto &b_shape = b_desc->shape(); \
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
+        CHECK_SAME_SHAPE(out_shape, a_shape, b_shape); \
+        CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
+        return INFINI_STATUS_SUCCESS; \
+    } \
+ \
+    infiniStatus_t Descriptor::calculate( \
+        void *workspace, \
+        size_t workspace_size, \
+        void *output, \
+        std::vector<const void *> inputs, \
+        void *stream) const { \
+        switch (_dtype) { \
+        case INFINI_DTYPE_F16: \
+            return _device_info->template calculate<Op, fp16_t>( \
+                _info, output, inputs, stream); \
+        case INFINI_DTYPE_F32: \
+            return _device_info->template calculate<Op, float>( \
+                _info, output, inputs, stream); \
+        default: \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
+/**
+ * @brief Macro to generate a unary operator implementation.
+ *
+ * This macro generates the Descriptor destructor, create, and calculate methods
+ * for unary operators, using the generic implementation.
+ *
+ * Usage:
+ *   namespace op::sqrt::cpu {
+ *   using Op = op::elementwise::unary::UnaryOp<op::elementwise::unary::UnaryMode::Sqrt>;
+ *   ELEMENTWISE_CPU_IMPL_UNARY(sqrt)
+ *   }
+ */
+#define ELEMENTWISE_CPU_IMPL_UNARY(OP) \
+ \
+    Descriptor::~Descriptor() = default; \
+ \
+    infiniStatus_t Descriptor::create( \
+        infiniopHandle_t handle_, \
+        Descriptor **desc_ptr, \
+        infiniopTensorDescriptor_t out_desc, \
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
+        auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
+        auto dtype = out_desc->dtype(); \
+        const auto &x_desc = input_desc_vec.at(0); \
+        const auto &y_shape = out_desc->shape(); \
+        const auto &x_shape = x_desc->shape(); \
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
+        CHECK_SAME_SHAPE(y_shape, x_shape); \
+        CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
+        return INFINI_STATUS_SUCCESS; \
+    } \
+ \
+    infiniStatus_t Descriptor::calculate( \
+        void *workspace, \
+        size_t workspace_size, \
+        void *output, \
+        std::vector<const void *> inputs, \
+        void *stream) const { \
+        switch (_dtype) { \
+        case INFINI_DTYPE_F16: \
+            return _device_info->template calculate<Op, fp16_t>( \
+                _info, output, inputs, stream); \
+        case INFINI_DTYPE_F32: \
+            return _device_info->template calculate<Op, float>( \
+                _info, output, inputs, stream); \
+        default: \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
+#endif // __INFINIOP_ELEMENTWISE_CPU_IMPL_H__
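With the descriptor macros from binary.h and the implementation macros above, a per-operator CPU translation unit collapses to a few lines. A sketch of a hypothetical pow_cpu.cc under this scheme (pow is only illustrative here; the real instances for abs, acos, etc. appear later in this diff):

    // Hypothetical pow_cpu.cc. Assumes pow_cpu.h invokes
    // BINARY_ELEMENTWISE_DESCRIPTOR(pow, cpu, op::elementwise::binary::BinaryMode::Pow),
    // which declares op::pow::cpu::Descriptor and the Op alias.
    #include "pow_cpu.h"
    #include "../../../elementwise/cpu/elementwise_cpu_impl.h"

    namespace op::pow::cpu {

    ELEMENTWISE_CPU_IMPL_BINARY(pow)

    } // namespace op::pow::cpu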
diff --git a/src/infiniop/elementwise/nvidia/elementwise_nvidia_impl.cuh b/src/infiniop/elementwise/nvidia/elementwise_nvidia_impl.cuh
new file mode 100644
index 000000000..39b78884a
--- /dev/null
+++ b/src/infiniop/elementwise/nvidia/elementwise_nvidia_impl.cuh
@@ -0,0 +1,134 @@
+#ifndef __INFINIOP_ELEMENTWISE_NVIDIA_IMPL_CUH__
+#define __INFINIOP_ELEMENTWISE_NVIDIA_IMPL_CUH__
+
+#include "../../../utils/check.h"
+#include "../../../utils/result.hpp"
+#include "../../devices/nvidia/nvidia_common.cuh"
+#include "elementwise_nvidia.cuh"
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+/**
+ * @brief Generic implementation for elementwise NVIDIA/CUDA operators.
+ *
+ * This file provides a generic implementation template that can be used
+ * by all binary and unary operators to reduce code duplication.
+ *
+ * Usage:
+ *   #include "elementwise_nvidia_impl.cuh"
+ *   namespace op::pow::nvidia {
+ *   ELEMENTWISE_NVIDIA_IMPL_BINARY(pow)
+ *   }
+ *
+ *   namespace op::sqrt::nvidia {
+ *   ELEMENTWISE_NVIDIA_IMPL_UNARY(sqrt)
+ *   }
+ */
+
+/**
+ * @brief Macro to generate a binary operator implementation for NVIDIA/CUDA.
+ *
+ * This macro generates the Descriptor destructor, create, and calculate methods
+ * for binary operators, using the generic implementation.
+ *
+ * Usage:
+ *   namespace op::pow::nvidia {
+ *   ELEMENTWISE_NVIDIA_IMPL_BINARY(pow)
+ *   }
+ */
+#define ELEMENTWISE_NVIDIA_IMPL_BINARY(OP) \
+ \
+    Descriptor::~Descriptor() = default; \
+ \
+    infiniStatus_t Descriptor::create( \
+        infiniopHandle_t handle_, \
+        Descriptor **desc_ptr, \
+        infiniopTensorDescriptor_t out_desc, \
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
+        auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_); \
+        auto dtype = out_desc->dtype(); \
+        const auto &a_desc = input_desc_vec.at(0); \
+        const auto &b_desc = input_desc_vec.at(1); \
+        const auto &c_shape = out_desc->shape(); \
+        const auto &a_shape = a_desc->shape(); \
+        const auto &b_shape = b_desc->shape(); \
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
+        CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); \
+        CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
+        return INFINI_STATUS_SUCCESS; \
+    } \
+ \
+    infiniStatus_t Descriptor::calculate( \
+        void *workspace, \
+        size_t workspace_size, \
+        void *output, \
+        std::vector<const void *> inputs, \
+        void *stream) const { \
+        if (workspace_size < _workspace_size) { \
+            return INFINI_STATUS_INSUFFICIENT_WORKSPACE; \
+        } \
+        switch (_dtype) { \
+        case INFINI_DTYPE_F16: \
+            return _device_info->calculate<256, cuda::Op, half>( \
+                _info, workspace, output, inputs, stream); \
+        case INFINI_DTYPE_F32: \
+            return _device_info->calculate<256, cuda::Op, float>( \
+                _info, workspace, output, inputs, stream); \
+        default: \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
+/**
+ * @brief Macro to generate a unary operator implementation for NVIDIA/CUDA.
+ *
+ * This macro generates the Descriptor destructor, create, and calculate methods
+ * for unary operators, using the generic implementation.
+ *
+ * Usage:
+ *   namespace op::sqrt::nvidia {
+ *   ELEMENTWISE_NVIDIA_IMPL_UNARY(sqrt)
+ *   }
+ */
+#define ELEMENTWISE_NVIDIA_IMPL_UNARY(OP) \
+ \
+    Descriptor::~Descriptor() = default; \
+ \
+    infiniStatus_t Descriptor::create( \
+        infiniopHandle_t handle_, \
+        Descriptor **desc_ptr, \
+        infiniopTensorDescriptor_t out_desc, \
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
+        auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_); \
+        auto dtype = out_desc->dtype(); \
+        const auto &x_desc = input_desc_vec.at(0); \
+        const auto &y_shape = out_desc->shape(); \
+        const auto &x_shape = x_desc->shape(); \
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
+        CHECK_SAME_SHAPE(y_shape, x_shape); \
+        CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
+        return INFINI_STATUS_SUCCESS; \
+    } \
+ \
+    infiniStatus_t Descriptor::calculate( \
+        void *workspace, \
+        size_t workspace_size, \
+        void *output, \
+        std::vector<const void *> inputs, \
+        void *stream) const { \
+        if (workspace_size < _workspace_size) { \
+            return INFINI_STATUS_INSUFFICIENT_WORKSPACE; \
+        } \
+        switch (_dtype) { \
+        case INFINI_DTYPE_F16: \
+            return _device_info->calculate<256, cuda::Op, half>( \
+                _info, workspace, output, inputs, stream); \
+        case INFINI_DTYPE_F32: \
+            return _device_info->calculate<256, cuda::Op, float>( \
+                _info, workspace, output, inputs, stream); \
+        default: \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
+#endif // __INFINIOP_ELEMENTWISE_NVIDIA_IMPL_CUH__
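The NVIDIA side mirrors the CPU side. A sketch of a hypothetical pow_nvidia.cu (again illustrative; the abs/acos/... files below are the real instances, and the two headers named here are assumptions following their pattern):

    // Hypothetical pow_nvidia.cu.
    #include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"

    #include "../cuda/kernel.cuh" // assumed to define op::pow::cuda::Op
    #include "pow_nvidia.cuh"     // assumed to invoke ELEMENTWISE_DESCRIPTOR(pow, nvidia)

    namespace op::pow::nvidia {

    ELEMENTWISE_NVIDIA_IMPL_BINARY(pow)

    } // namespace op::pow::nvidia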
diff --git a/src/infiniop/elementwise/unary.h b/src/infiniop/elementwise/unary.h
new file mode 100644
index 000000000..9f41dedb2
--- /dev/null
+++ b/src/infiniop/elementwise/unary.h
@@ -0,0 +1,524 @@
+#ifndef __INFINIOP_ELEMENTWISE_UNARY_H__
+#define __INFINIOP_ELEMENTWISE_UNARY_H__
+
+#include <cmath>
+#include <cstdlib>
+#include <type_traits>
+
+#ifdef __CUDACC__
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+// Include device-specific type aliases for cuda_bfloat16
+#include "../devices/nvidia/nvidia_kernel_common.cuh"
+#endif
+
+namespace op::elementwise::unary {
+
+/**
+ * @brief Represents all the currently defined unary operations.
+ *
+ * This enum is used to specify which unary operation to perform
+ * in the generic UnaryOp template.
+ */
+enum class UnaryMode {
+    // Math operations:
+    Abs,
+    Exp,
+    Log,
+    Reciprocal,
+    Sqrt,
+    Neg,
+    Ceil,
+    Floor,
+    Round,
+    Sin,
+    Cos,
+    Tan,
+    Asin,
+    Acos,
+    Atan,
+    Sinh,
+    Cosh,
+    Tanh,
+    Asinh,
+    Acosh,
+    Atanh,
+    Relu,
+    Sigmoid,
+    Sign,
+    Erf,
+};
+
+/**
+ * @brief Generic unary operation template that performs different operations
+ * based on the specified UnaryMode.
+ *
+ * This template allows multiple unary operators (abs, log, sin, cos, etc.)
+ * to share the same implementation infrastructure while differing only in the
+ * operation mode.
+ *
+ * @tparam Mode The unary operation mode (from the UnaryMode enum)
+ */
+template <UnaryMode Mode>
+struct UnaryOp {
+    static constexpr size_t num_inputs = 1;
+
+    template <typename T>
+    T operator()(const T &x) const {
+        if constexpr (Mode == UnaryMode::Abs) {
+            if constexpr (std::is_floating_point_v<T>) {
+                return std::fabs(x);
+            } else {
+                return std::abs(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Exp) {
+            return std::exp(x);
+        } else if constexpr (Mode == UnaryMode::Log) {
+            return std::log(x);
+        } else if constexpr (Mode == UnaryMode::Reciprocal) {
+            return T(1) / x;
+        } else if constexpr (Mode == UnaryMode::Sqrt) {
+            return std::sqrt(x);
+        } else if constexpr (Mode == UnaryMode::Neg) {
+            return -x;
+        } else if constexpr (Mode == UnaryMode::Ceil) {
+            return std::ceil(x);
+        } else if constexpr (Mode == UnaryMode::Floor) {
+            return std::floor(x);
+        } else if constexpr (Mode == UnaryMode::Round) {
+            if constexpr (std::is_integral_v<T>) {
+                return x;
+            } else {
+                return std::nearbyint(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Sin) {
+            return std::sin(x);
+        } else if constexpr (Mode == UnaryMode::Cos) {
+            return std::cos(x);
+        } else if constexpr (Mode == UnaryMode::Tan) {
+            return std::tan(x);
+        } else if constexpr (Mode == UnaryMode::Asin) {
+            return std::asin(x);
+        } else if constexpr (Mode == UnaryMode::Acos) {
+            return std::acos(x);
+        } else if constexpr (Mode == UnaryMode::Atan) {
+            return std::atan(x);
+        } else if constexpr (Mode == UnaryMode::Sinh) {
+            return std::sinh(x);
+        } else if constexpr (Mode == UnaryMode::Cosh) {
+            return std::cosh(x);
+        } else if constexpr (Mode == UnaryMode::Tanh) {
+            return std::tanh(x);
+        } else if constexpr (Mode == UnaryMode::Asinh) {
+            return std::asinh(x);
+        } else if constexpr (Mode == UnaryMode::Acosh) {
+            return std::acosh(x);
+        } else if constexpr (Mode == UnaryMode::Atanh) {
+            return std::atanh(x);
+        } else if constexpr (Mode == UnaryMode::Relu) {
+            return x > T(0) ? x : T(0);
+        } else if constexpr (Mode == UnaryMode::Sigmoid) {
+            return T(1) / (T(1) + std::exp(-x));
+        } else if constexpr (Mode == UnaryMode::Sign) {
+            return x > T(0) ? T(1) : (x == T(0) ? T(0) : T(-1));
+        } else if constexpr (Mode == UnaryMode::Erf) {
+            return std::erf(x);
+        } else {
+            static_assert(Mode != Mode, "Unsupported unary operation mode");
+            return x;
+        }
+    }
+};
+
+#ifdef __CUDACC__
+/**
+ * @brief CUDA-specific unary operation template that performs different operations
+ * based on the specified UnaryMode, using CUDA-optimized functions.
+ *
+ * This template provides CUDA device functions optimized for GPU execution,
+ * using intrinsics such as __habs2, __logf, and __sinf.
+ *
+ * @tparam Mode The unary operation mode (from the UnaryMode enum)
+ */
+namespace cuda {
+template <UnaryMode Mode>
+struct UnaryOp {
+    static constexpr size_t num_inputs = 1;
+
+    template <typename T>
+    __device__ __forceinline__ T operator()(const T &x) const {
+        if constexpr (Mode == UnaryMode::Abs) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __habs2(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __habs(x);
+            } else if constexpr (std::is_floating_point_v<T>) {
+                return std::fabs(x);
+            } else {
+                return std::abs(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Exp) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(__expf(x_f2.x), __expf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(__expf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float2 x_f2 = __bfloat1622float2(x);
+                return __floats2bfloat162_rn(__expf(x_f2.x), __expf(x_f2.y));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(__expf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __expf(x);
+            } else {
+                return std::exp(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Log) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2log(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(__logf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(logf(x0), logf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(logf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __logf(x);
+            } else {
+                return std::log(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Reciprocal) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2rcp(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return hrcp(x);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(__frcp_rn(x0), __frcp_rn(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(__frcp_rn(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __frcp_rn(x);
+            } else {
+                return T(1) / x;
+            }
+        } else if constexpr (Mode == UnaryMode::Sqrt) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2sqrt(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return hsqrt(x);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(sqrtf(x0), sqrtf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(sqrtf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __fsqrt_rn(x);
+            } else {
+                return std::sqrt(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Neg) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hneg2(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __hneg(x);
+            } else {
+                return -x;
+            }
+        } else if constexpr (Mode == UnaryMode::Ceil) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2ceil(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return hceil(x);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(ceilf(x0), ceilf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(ceilf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return ceilf(x);
+            } else if constexpr (std::is_integral_v<T>) {
+                return x;
+            } else {
+                return std::ceil(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Floor) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2floor(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return hfloor(x);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(floorf(x0), floorf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(floorf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return floorf(x);
+            } else if constexpr (std::is_integral_v<T>) {
+                return x;
+            } else {
+                return std::floor(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Round) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return h2rint(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return hrint(x);
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(rintf(x0), rintf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(rintf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return rintf(x);
+            } else if constexpr (std::is_integral_v<T>) {
+                return x;
+            } else {
+                return std::nearbyint(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Sin) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(__sinf(x_f2.x), __sinf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(__sinf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(sinf(x0), sinf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(sinf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __sinf(x);
+            } else {
+                return std::sin(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Cos) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(__cosf(x_f2.x), __cosf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(__cosf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(cosf(x0), cosf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(cosf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return __cosf(x);
+            } else {
+                return std::cos(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Tan) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(tanf(x_f2.x), tanf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(tanf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return tanf(x);
+            } else {
+                return std::tan(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Asin) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(asinf(x_f2.x), asinf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(asinf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return asinf(x);
+            } else {
+                return std::asin(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Acos) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(acosf(x_f2.x), acosf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(acosf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return acosf(x);
+            } else {
+                return std::acos(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Atan) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(atanf(x_f2.x), atanf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(atanf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return atanf(x);
+            } else {
+                return std::atan(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Sinh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(sinhf(x_f2.x), sinhf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(sinhf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return sinhf(x);
+            } else {
+                return std::sinh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Cosh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(coshf(x_f2.x), coshf(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(coshf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return coshf(x);
+            } else {
+                return std::cosh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Tanh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __h2tanh(x);
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(tanhf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float f0 = __bfloat162float(__low2bfloat16(x));
+                float f1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(tanhf(f0), tanhf(f1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(tanhf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return tanhf(x);
+            } else {
+                return std::tanh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Asinh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __floats2half2_rn(asinhf(__half2float(__low2half(x))), asinhf(__half2float(__high2half(x))));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(asinhf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(asinhf(x0), asinhf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(asinhf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return asinhf(x);
+            } else {
+                return std::asinh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Acosh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __floats2half2_rn(acoshf(__half2float(__low2half(x))), acoshf(__half2float(__high2half(x))));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(acoshf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(acoshf(x0), acoshf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(acoshf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return acoshf(x);
+            } else {
+                return std::acosh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Atanh) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __floats2half2_rn(atanhf(__half2float(__low2half(x))), atanhf(__half2float(__high2half(x))));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(atanhf(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
+                float x0 = __bfloat162float(__low2bfloat16(x));
+                float x1 = __bfloat162float(__high2bfloat16(x));
+                return __floats2bfloat162_rn(atanhf(x0), atanhf(x1));
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                return __float2bfloat16_rn(atanhf(__bfloat162float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return atanhf(x);
+            } else {
+                return std::atanh(x);
+            }
+        } else if constexpr (Mode == UnaryMode::Relu) {
+            if constexpr (std::is_same_v<T, half2>) {
+                return __hmax2(x, __floats2half2_rn(0.0f, 0.0f));
+            } else {
+                return x > T(0) ? x : T(0);
+            }
+        } else if constexpr (Mode == UnaryMode::Sigmoid) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                float2 exp_neg_x = make_float2(__expf(-x_f2.x), __expf(-x_f2.y));
+                return __float22half2_rn(make_float2(1.0f / (1.0f + exp_neg_x.x), 1.0f / (1.0f + exp_neg_x.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                float x_ = __half2float(x);
+                return __float2half(1.0f / (1.0f + __expf(-x_)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return 1.0f / (1.0f + __expf(-x));
+            } else {
+                return T(1) / (T(1) + std::exp(-x));
+            }
+        } else if constexpr (Mode == UnaryMode::Sign) {
+            if constexpr (std::is_same_v<T, half2>) {
+                const auto lt_mask = __hlt2(x, __floats2half2_rn(0.0f, 0.0f));
+                return __hadd2(__hneg2(lt_mask), __hsub2(__floats2half2_rn(1.0f, 1.0f), lt_mask));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return x > half(0) ? half(1) : (x == half(0) ? half(0) : half(-1));
+            } else {
+                return x > T(0) ? T(1) : (x == T(0) ? T(0) : T(-1));
+            }
+        } else if constexpr (Mode == UnaryMode::Erf) {
+            if constexpr (std::is_same_v<T, half2>) {
+                float2 x_f2 = __half22float2(x);
+                return __float22half2_rn(make_float2(erff(x_f2.x), erff(x_f2.y)));
+            } else if constexpr (std::is_same_v<T, half>) {
+                return __float2half(erff(__half2float(x)));
+            } else if constexpr (std::is_same_v<T, float>) {
+                return erff(x);
+            } else {
+                return std::erf(x);
+            }
+        } else {
+            static_assert(Mode != Mode, "Unsupported unary operation mode");
+            return x;
+        }
+    }
+};
+} // namespace cuda
+#endif // __CUDACC__
+
+/**
+ * @brief Macro to define a unary elementwise descriptor for a specific operation.
+ *
+ * This macro simplifies the definition of unary operators (abs, log, sin, cos, etc.)
+ * by automatically generating the Descriptor class and operation struct using the
+ * ELEMENTWISE_DESCRIPTOR macro and the UnaryOp template.
+ *
+ * Usage:
+ *   UNARY_ELEMENTWISE_DESCRIPTOR(abs, cpu, UnaryMode::Abs)
+ *   UNARY_ELEMENTWISE_DESCRIPTOR(log, cpu, UnaryMode::Log)
+ *
+ * @param OP The operator name (e.g., abs, log, sin)
+ * @param NAMESPACE The device namespace (e.g., cpu, nvidia)
+ * @param MODE The UnaryMode enum value for this operation
+ */
+#define UNARY_ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE, MODE) \
+ \
+    ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
+ \
+    namespace op::OP::NAMESPACE { \
+    using Op = op::elementwise::unary::UnaryOp<MODE>; \
+    }
+
+} // namespace op::elementwise::unary
+
+#endif // __INFINIOP_ELEMENTWISE_UNARY_H__
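The descriptor macro at the end of unary.h is what the per-operator CPU headers later in this diff invoke. Roughly, for abs on cpu it expands as sketched below (ELEMENTWISE_DESCRIPTOR predates this diff, so its output is summarized rather than quoted):

    // Approximate expansion of
    // UNARY_ELEMENTWISE_DESCRIPTOR(abs, cpu, op::elementwise::unary::UnaryMode::Abs):
    ELEMENTWISE_DESCRIPTOR(abs, cpu) // declares op::abs::cpu::Descriptor (pre-existing macro)

    namespace op::abs::cpu {
    using Op = op::elementwise::unary::UnaryOp<op::elementwise::unary::UnaryMode::Abs>;
    }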
diff --git a/src/infiniop/operator_impl.h b/src/infiniop/operator_impl.h
new file mode 100644
index 000000000..3ff543f7e
--- /dev/null
+++ b/src/infiniop/operator_impl.h
@@ -0,0 +1,288 @@
+#ifndef __INFINIOP_OPERATOR_IMPL_H__
+#define __INFINIOP_OPERATOR_IMPL_H__
+
+#include "handle.h"
+#include "operator.h"
+
+// Conditional compilation helpers
+#ifdef ENABLE_CPU_API
+#define IF_ENABLE_CPU_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_CPU_API(...)
+#endif
+
+#ifdef ENABLE_NVIDIA_API
+#define IF_ENABLE_NVIDIA_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_NVIDIA_API(...)
+#endif
+
+#ifdef ENABLE_ILUVATAR_API
+#define IF_ENABLE_ILUVATAR_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_ILUVATAR_API(...)
+#endif
+
+#ifdef ENABLE_QY_API
+#define IF_ENABLE_QY_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_QY_API(...)
+#endif
+
+#ifdef ENABLE_METAX_API
+#define IF_ENABLE_METAX_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_METAX_API(...)
+#endif
+
+#ifdef ENABLE_KUNLUN_API
+#define IF_ENABLE_KUNLUN_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_KUNLUN_API(...)
+#endif
+
+#ifdef ENABLE_CAMBRICON_API
+#define IF_ENABLE_CAMBRICON_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_CAMBRICON_API(...)
+#endif
+
+#ifdef ENABLE_MOORE_API
+#define IF_ENABLE_MOORE_API(...) __VA_ARGS__
+#else
+#define IF_ENABLE_MOORE_API(...)
+#endif
+
+/**
+ * Binary operator implementation macros
+ */
+#define BINARY_OP_IMPL_CASE(OP_NAME, DEVICE, NAMESPACE, c_desc, a_desc, b_desc) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            return op::OP_NAME::NAMESPACE::Descriptor::create( \
+                handle, \
+                reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor **>(desc_ptr), \
+                c_desc, \
+                {a_desc, b_desc});)
+
+#define BINARY_OP_IMPL_DEVICE_CASES(OP_NAME, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, CPU, cpu, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, NVIDIA, nvidia, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, ILUVATAR, nvidia, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, QY, nvidia, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, METAX, metax, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, KUNLUN, kunlun, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, CAMBRICON, bang, c_desc, a_desc, b_desc) \
+    BINARY_OP_IMPL_CASE(OP_NAME, MOORE, moore, c_desc, a_desc, b_desc)
+
+#define BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, DEVICE, NAMESPACE) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            *size = reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+            return INFINI_STATUS_SUCCESS;)
+
+#define BINARY_OP_IMPL_GET_WORKSPACE_CASES(OP_NAME) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, CPU, cpu) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, NVIDIA, nvidia) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, ILUVATAR, nvidia) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, QY, nvidia) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, METAX, metax) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, KUNLUN, kunlun) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, CAMBRICON, bang) \
+    BINARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, MOORE, moore)
+
+#define BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, DEVICE, NAMESPACE, c, a, b) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            return reinterpret_cast<const op::OP_NAME::NAMESPACE::Descriptor *>(desc) \
+                ->calculate(workspace, workspace_size, c, {a, b}, stream);)
+
+#define BINARY_OP_IMPL_CALCULATE_CASES(OP_NAME, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, CPU, cpu, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, NVIDIA, nvidia, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, ILUVATAR, nvidia, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, QY, nvidia, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, METAX, metax, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, KUNLUN, kunlun, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, CAMBRICON, bang, c, a, b) \
+    BINARY_OP_IMPL_CALCULATE_CASE(OP_NAME, MOORE, moore, c, a, b)
+
+#define BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, DEVICE, NAMESPACE) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            delete reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor *>(desc); \
+            return INFINI_STATUS_SUCCESS;)
+
+#define BINARY_OP_IMPL_DESTROY_CASES(OP_NAME) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, CPU, cpu) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, NVIDIA, nvidia) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, ILUVATAR, nvidia) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, QY, nvidia) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, METAX, metax) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, KUNLUN, kunlun) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, CAMBRICON, bang) \
+    BINARY_OP_IMPL_DESTROY_CASE(OP_NAME, MOORE, moore)
+
+#define BINARY_OP_IMPL(OP_NAME, OP_NAME_UPPER) \
+    __C infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
+        infiniopHandle_t handle, \
+        infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
+        infiniopTensorDescriptor_t c_desc, \
+        infiniopTensorDescriptor_t a_desc, \
+        infiniopTensorDescriptor_t b_desc) { \
+        switch (handle->device) { \
+            BINARY_OP_IMPL_DEVICE_CASES(OP_NAME, c_desc, a_desc, b_desc) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    } \
+    __C infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        size_t *size) { \
+        switch (desc->device_type) { \
+            BINARY_OP_IMPL_GET_WORKSPACE_CASES(OP_NAME) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+    } \
+    __C infiniStatus_t infiniop##OP_NAME_UPPER( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        void *workspace, \
+        size_t workspace_size, \
+        void *c, \
+        const void *a, \
+        const void *b, \
+        void *stream) { \
+        switch (desc->device_type) { \
+            BINARY_OP_IMPL_CALCULATE_CASES(OP_NAME, c, a, b) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    } \
+    __C infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc) { \
+        switch (desc->device_type) { \
+            BINARY_OP_IMPL_DESTROY_CASES(OP_NAME) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    }
+
+/**
+ * Unary operator implementation macros
+ */
+#define UNARY_OP_IMPL_CASE(OP_NAME, DEVICE, NAMESPACE, y_desc, x_desc) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            return op::OP_NAME::NAMESPACE::Descriptor::create( \
+                handle, \
+                reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor **>(desc_ptr), \
+                y_desc, \
+                {x_desc});)
+
+#define UNARY_OP_IMPL_DEVICE_CASES(OP_NAME, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, CPU, cpu, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, NVIDIA, nvidia, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, ILUVATAR, nvidia, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, QY, nvidia, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, METAX, metax, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, KUNLUN, kunlun, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, CAMBRICON, bang, y_desc, x_desc) \
+    UNARY_OP_IMPL_CASE(OP_NAME, MOORE, moore, y_desc, x_desc)
+
+#define UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, DEVICE, NAMESPACE) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            *size = reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+            return INFINI_STATUS_SUCCESS;)
+
+#define UNARY_OP_IMPL_GET_WORKSPACE_CASES(OP_NAME) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, CPU, cpu) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, NVIDIA, nvidia) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, ILUVATAR, nvidia) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, QY, nvidia) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, METAX, metax) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, KUNLUN, kunlun) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, CAMBRICON, bang) \
+    UNARY_OP_IMPL_GET_WORKSPACE_CASE(OP_NAME, MOORE, moore)
+
+#define UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, DEVICE, NAMESPACE, y, x) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            return reinterpret_cast<const op::OP_NAME::NAMESPACE::Descriptor *>(desc) \
+                ->calculate(workspace, workspace_size, y, {x}, stream);)
+
+#define UNARY_OP_IMPL_CALCULATE_CASES(OP_NAME, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, CPU, cpu, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, NVIDIA, nvidia, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, ILUVATAR, nvidia, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, QY, nvidia, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, METAX, metax, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, KUNLUN, kunlun, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, CAMBRICON, bang, y, x) \
+    UNARY_OP_IMPL_CALCULATE_CASE(OP_NAME, MOORE, moore, y, x)
+
+#define UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, DEVICE, NAMESPACE) \
+    IF_ENABLE_##DEVICE##_API( \
+        case INFINI_DEVICE_##DEVICE: \
+            delete reinterpret_cast<op::OP_NAME::NAMESPACE::Descriptor *>(desc); \
+            return INFINI_STATUS_SUCCESS;)
+
+#define UNARY_OP_IMPL_DESTROY_CASES(OP_NAME) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, CPU, cpu) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, NVIDIA, nvidia) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, ILUVATAR, nvidia) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, QY, nvidia) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, METAX, metax) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, KUNLUN, kunlun) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, CAMBRICON, bang) \
+    UNARY_OP_IMPL_DESTROY_CASE(OP_NAME, MOORE, moore)
+
+#define UNARY_OP_IMPL(OP_NAME, OP_NAME_UPPER) \
+    __C infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
+        infiniopHandle_t handle, \
+        infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
+        infiniopTensorDescriptor_t y_desc, \
+        infiniopTensorDescriptor_t x_desc) { \
+        switch (handle->device) { \
+            UNARY_OP_IMPL_DEVICE_CASES(OP_NAME, y_desc, x_desc) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    } \
+    __C infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        size_t *size) { \
+        switch (desc->device_type) { \
+            UNARY_OP_IMPL_GET_WORKSPACE_CASES(OP_NAME) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+    } \
+    __C infiniStatus_t infiniop##OP_NAME_UPPER( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc, \
+        void *workspace, \
+        size_t workspace_size, \
+        void *y, \
+        const void *x, \
+        void *stream) { \
+        switch (desc->device_type) { \
+            UNARY_OP_IMPL_CALCULATE_CASES(OP_NAME, y, x) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    } \
+    __C infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
+        infiniop##OP_NAME_UPPER##Descriptor_t desc) { \
+        switch (desc->device_type) { \
+            UNARY_OP_IMPL_DESTROY_CASES(OP_NAME) \
+        default: \
+            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+        } \
+    }
+
+#endif // __INFINIOP_OPERATOR_IMPL_H__
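For orientation before the per-operator files: what UNARY_OP_IMPL(abs, Abs) in abs/operator.cc below expands to, abridged to a CPU-only build for readability (the real expansion emits one case per enabled device, plus the workspace, execute, and destroy entry points, which follow the same dispatch pattern):

    __C infiniStatus_t infiniopCreateAbsDescriptor(
        infiniopHandle_t handle,
        infiniopAbsDescriptor_t *desc_ptr,
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc) {
        switch (handle->device) {
        case INFINI_DEVICE_CPU:
            return op::abs::cpu::Descriptor::create(
                handle,
                reinterpret_cast<op::abs::cpu::Descriptor **>(desc_ptr),
                y_desc,
                {x_desc});
        default:
            return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
        }
    }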
diff --git a/src/infiniop/ops/abs/cpu/abs_cpu.cc b/src/infiniop/ops/abs/cpu/abs_cpu.cc
new file mode 100644
index 000000000..d4b541ba7
--- /dev/null
+++ b/src/infiniop/ops/abs/cpu/abs_cpu.cc
@@ -0,0 +1,8 @@
+#include "abs_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::abs::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(abs)
+
+} // namespace op::abs::cpu
diff --git a/src/infiniop/ops/abs/cpu/abs_cpu.h b/src/infiniop/ops/abs/cpu/abs_cpu.h
new file mode 100644
index 000000000..cba8274e6
--- /dev/null
+++ b/src/infiniop/ops/abs/cpu/abs_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ABS_CPU_H__
+#define __ABS_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(abs, cpu, op::elementwise::unary::UnaryMode::Abs)
+
+#endif // __ABS_CPU_H__
diff --git a/src/infiniop/ops/abs/cuda/kernel.cuh b/src/infiniop/ops/abs/cuda/kernel.cuh
new file mode 100644
index 000000000..406aa423f
--- /dev/null
+++ b/src/infiniop/ops/abs/cuda/kernel.cuh
@@ -0,0 +1,10 @@
+#ifndef __ABS_CUDA_H__
+#define __ABS_CUDA_H__
+
+#include "../../../elementwise/unary.h"
+
+namespace op::abs::cuda {
+using Op = op::elementwise::unary::cuda::UnaryOp<op::elementwise::unary::UnaryMode::Abs>;
+} // namespace op::abs::cuda
+
+#endif // __ABS_CUDA_H__
diff --git a/src/infiniop/ops/abs/nvidia/abs_nvidia.cu b/src/infiniop/ops/abs/nvidia/abs_nvidia.cu
new file mode 100644
index 000000000..b9687226a
--- /dev/null
+++ b/src/infiniop/ops/abs/nvidia/abs_nvidia.cu
@@ -0,0 +1,10 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "abs_nvidia.cuh"
+
+namespace op::abs::nvidia {
+
+ELEMENTWISE_NVIDIA_IMPL_UNARY(abs)
+
+} // namespace op::abs::nvidia
diff --git a/src/infiniop/ops/abs/nvidia/abs_nvidia.cuh b/src/infiniop/ops/abs/nvidia/abs_nvidia.cuh
new file mode 100644
index 000000000..db1751e26
--- /dev/null
+++ b/src/infiniop/ops/abs/nvidia/abs_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __ABS_NVIDIA_API_H__
+#define __ABS_NVIDIA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(abs, nvidia)
+
+#endif // __ABS_NVIDIA_API_H__
diff --git a/src/infiniop/ops/abs/operator.cc b/src/infiniop/ops/abs/operator.cc
new file mode 100644
index 000000000..8439236eb
--- /dev/null
+++ b/src/infiniop/ops/abs/operator.cc
@@ -0,0 +1,11 @@
+#include "../../operator_impl.h"
+#include "infiniop/ops/unary_ops_api.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/abs_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/abs_nvidia.cuh"
+#endif
+
+UNARY_OP_IMPL(abs, Abs)
diff --git a/src/infiniop/ops/acos/cpu/acos_cpu.cc b/src/infiniop/ops/acos/cpu/acos_cpu.cc
new file mode 100644
index 000000000..9be4ca1fe
--- /dev/null
+++ b/src/infiniop/ops/acos/cpu/acos_cpu.cc
@@ -0,0 +1,8 @@
+#include "acos_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::acos::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(acos)
+
+} // namespace op::acos::cpu
diff --git a/src/infiniop/ops/acos/cpu/acos_cpu.h b/src/infiniop/ops/acos/cpu/acos_cpu.h
new file mode 100644
index 000000000..50900e217
--- /dev/null
+++ b/src/infiniop/ops/acos/cpu/acos_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ACOS_CPU_H__
+#define __ACOS_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(acos, cpu, op::elementwise::unary::UnaryMode::Acos)
+
+#endif // __ACOS_CPU_H__
diff --git a/src/infiniop/ops/acos/cuda/kernel.cuh b/src/infiniop/ops/acos/cuda/kernel.cuh
new file mode 100644
index 000000000..b62bf1e88
--- /dev/null
+++ b/src/infiniop/ops/acos/cuda/kernel.cuh
@@ -0,0 +1,10 @@
+#ifndef __ACOS_CUDA_H__
+#define __ACOS_CUDA_H__
+
+#include "../../../elementwise/unary.h"
+
+namespace op::acos::cuda {
+using Op = op::elementwise::unary::cuda::UnaryOp<op::elementwise::unary::UnaryMode::Acos>;
+} // namespace op::acos::cuda
+
+#endif // __ACOS_CUDA_H__
diff --git a/src/infiniop/ops/acos/nvidia/acos_nvidia.cu b/src/infiniop/ops/acos/nvidia/acos_nvidia.cu
new file mode 100644
index 000000000..e7cf1feea
--- /dev/null
+++ b/src/infiniop/ops/acos/nvidia/acos_nvidia.cu
@@ -0,0 +1,10 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "acos_nvidia.cuh"
+
+namespace op::acos::nvidia {
+
+ELEMENTWISE_NVIDIA_IMPL_UNARY(acos)
+
+} // namespace op::acos::nvidia
diff --git a/src/infiniop/ops/acos/nvidia/acos_nvidia.cuh b/src/infiniop/ops/acos/nvidia/acos_nvidia.cuh
new file mode 100644
index 000000000..a7ac7e190
--- /dev/null
+++ b/src/infiniop/ops/acos/nvidia/acos_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __ACOS_NVIDIA_API_H__
+#define __ACOS_NVIDIA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(acos, nvidia)
+
+#endif // __ACOS_NVIDIA_API_H__
diff --git a/src/infiniop/ops/acos/operator.cc b/src/infiniop/ops/acos/operator.cc
new file mode 100644
index 000000000..3fd50fb51
--- /dev/null
+++ b/src/infiniop/ops/acos/operator.cc
@@ -0,0 +1,11 @@
+#include "../../operator_impl.h"
+#include "infiniop/ops/unary_ops_api.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/acos_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/acos_nvidia.cuh"
+#endif
+
+UNARY_OP_IMPL(acos, Acos)
diff --git a/src/infiniop/ops/acosh/cpu/acosh_cpu.cc b/src/infiniop/ops/acosh/cpu/acosh_cpu.cc
new file mode 100644
index 000000000..0cb424c00
--- /dev/null
+++ b/src/infiniop/ops/acosh/cpu/acosh_cpu.cc
@@ -0,0 +1,8 @@
+#include "acosh_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::acosh::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(acosh)
+
+} // namespace op::acosh::cpu
diff --git a/src/infiniop/ops/acosh/cpu/acosh_cpu.h b/src/infiniop/ops/acosh/cpu/acosh_cpu.h
new file mode 100644
index 000000000..bb05baf14
--- /dev/null
+++ b/src/infiniop/ops/acosh/cpu/acosh_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ACOSH_CPU_H__
+#define __ACOSH_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(acosh, cpu, op::elementwise::unary::UnaryMode::Acosh)
+
+#endif // __ACOSH_CPU_H__
diff --git a/src/infiniop/ops/acosh/cuda/kernel.cuh b/src/infiniop/ops/acosh/cuda/kernel.cuh
new file mode 100644
index 000000000..9fbb54636
--- /dev/null
+++ b/src/infiniop/ops/acosh/cuda/kernel.cuh
@@ -0,0 +1,10 @@
+#ifndef __ACOSH_CUDA_H__
+#define __ACOSH_CUDA_H__
+
+#include "../../../elementwise/unary.h"
+
+namespace op::acosh::cuda {
+using Op = op::elementwise::unary::cuda::UnaryOp<op::elementwise::unary::UnaryMode::Acosh>;
+} // namespace op::acosh::cuda
+
+#endif // __ACOSH_CUDA_H__
diff --git a/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cu b/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cu
new file mode 100644
index 000000000..5d065bdbc
--- /dev/null
+++ b/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cu
@@ -0,0 +1,10 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "acosh_nvidia.cuh"
+
+namespace op::acosh::nvidia {
+
+ELEMENTWISE_NVIDIA_IMPL_UNARY(acosh)
+
+} // namespace op::acosh::nvidia
diff --git a/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cuh b/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cuh
new file mode 100644
index 000000000..b13332431
--- /dev/null
+++ b/src/infiniop/ops/acosh/nvidia/acosh_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __ACOSH_NVIDIA_API_H__
+#define __ACOSH_NVIDIA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(acosh, nvidia)
+
+#endif // __ACOSH_NVIDIA_API_H__
diff --git a/src/infiniop/ops/acosh/operator.cc b/src/infiniop/ops/acosh/operator.cc
new file mode 100644
index 000000000..0fb30c0f6
--- /dev/null
+++ b/src/infiniop/ops/acosh/operator.cc
@@ -0,0 +1,11 @@
+#include "../../operator_impl.h"
+#include "infiniop/ops/unary_ops_api.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/acosh_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/acosh_nvidia.cuh"
+#endif
+
+UNARY_OP_IMPL(acosh, Acosh)
diff --git a/src/infiniop/ops/asin/cpu/asin_cpu.cc b/src/infiniop/ops/asin/cpu/asin_cpu.cc
new file mode 100644
index 000000000..de42639ff
--- /dev/null
+++ b/src/infiniop/ops/asin/cpu/asin_cpu.cc
@@ -0,0 +1,8 @@
+#include "asin_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::asin::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(asin)
+
+} // namespace op::asin::cpu
diff --git a/src/infiniop/ops/asin/cpu/asin_cpu.h b/src/infiniop/ops/asin/cpu/asin_cpu.h
new file mode 100644
index 000000000..8c6da5e20
--- /dev/null
+++ b/src/infiniop/ops/asin/cpu/asin_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ASIN_CPU_H__
+#define __ASIN_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(asin, cpu, op::elementwise::unary::UnaryMode::Asin)
+
+#endif // __ASIN_CPU_H__
diff --git a/src/infiniop/ops/asin/cuda/kernel.cuh b/src/infiniop/ops/asin/cuda/kernel.cuh
new file mode 100644
index 000000000..a7063f015
--- /dev/null
+++ b/src/infiniop/ops/asin/cuda/kernel.cuh
@@ -0,0 +1,10 @@
+#ifndef __ASIN_CUDA_H__
+#define __ASIN_CUDA_H__
+
+#include "../../../elementwise/unary.h"
+
+namespace op::asin::cuda {
+using Op = op::elementwise::unary::cuda::UnaryOp<op::elementwise::unary::UnaryMode::Asin>;
+} // namespace op::asin::cuda
+
+#endif // __ASIN_CUDA_H__
diff --git a/src/infiniop/ops/asin/nvidia/asin_nvidia.cu b/src/infiniop/ops/asin/nvidia/asin_nvidia.cu
new file mode 100644
index 000000000..262755d50
--- /dev/null
+++ b/src/infiniop/ops/asin/nvidia/asin_nvidia.cu
@@ -0,0 +1,10 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "asin_nvidia.cuh"
+
+namespace op::asin::nvidia {
+
+ELEMENTWISE_NVIDIA_IMPL_UNARY(asin)
+
+} // namespace op::asin::nvidia
diff --git a/src/infiniop/ops/asin/nvidia/asin_nvidia.cuh b/src/infiniop/ops/asin/nvidia/asin_nvidia.cuh
new file mode 100644
index 000000000..46e168ede
--- /dev/null
+++ b/src/infiniop/ops/asin/nvidia/asin_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __ASIN_NVIDIA_API_H__
+#define __ASIN_NVIDIA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(asin, nvidia)
+
+#endif // __ASIN_NVIDIA_API_H__
diff --git a/src/infiniop/ops/asin/operator.cc b/src/infiniop/ops/asin/operator.cc
new file mode 100644
index 000000000..8ed07d55d
--- /dev/null
+++ b/src/infiniop/ops/asin/operator.cc
@@ -0,0 +1,11 @@
+#include "../../operator_impl.h"
+#include "infiniop/ops/unary_ops_api.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/asin_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/asin_nvidia.cuh"
+#endif
+
+UNARY_OP_IMPL(asin, Asin)
diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.cc b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc
new file mode 100644
index 000000000..8b18ab6f8
--- /dev/null
+++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc
@@ -0,0 +1,8 @@
+#include "asinh_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::asinh::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(asinh)
+
+} // namespace op::asinh::cpu
diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.h b/src/infiniop/ops/asinh/cpu/asinh_cpu.h
new file mode 100644
index 000000000..4c3603752
--- /dev/null
+++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ASINH_CPU_H__
+#define __ASINH_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(asinh, cpu, op::elementwise::unary::UnaryMode::Asinh)
+
+#endif // __ASINH_CPU_H__
diff --git a/src/infiniop/ops/asinh/cuda/kernel.cuh b/src/infiniop/ops/asinh/cuda/kernel.cuh
new file mode 100644
index 000000000..866ea147a
--- /dev/null
+++ b/src/infiniop/ops/asinh/cuda/kernel.cuh
@@ -0,0 +1,10 @@
+#ifndef __ASINH_CUDA_H__
+#define __ASINH_CUDA_H__
+
+#include "../../../elementwise/unary.h"
+
+namespace op::asinh::cuda {
+using Op = op::elementwise::unary::cuda::UnaryOp<op::elementwise::unary::UnaryMode::Asinh>;
+} // namespace op::asinh::cuda
+
+#endif // __ASINH_CUDA_H__
diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu
new file mode 100644
index 000000000..37c44baf0
--- /dev/null
+++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu
@@ -0,0 +1,10 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "asinh_nvidia.cuh"
+
+namespace op::asinh::nvidia {
+
+ELEMENTWISE_NVIDIA_IMPL_UNARY(asinh)
+
+} // namespace op::asinh::nvidia
diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh
new file mode 100644
index 000000000..d1dcb4287
--- /dev/null
+++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __ASINH_NVIDIA_API_H__
+#define __ASINH_NVIDIA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(asinh, nvidia)
+
+#endif // __ASINH_NVIDIA_API_H__
diff --git a/src/infiniop/ops/asinh/operator.cc b/src/infiniop/ops/asinh/operator.cc
new file mode 100644
index 000000000..020f83dc4
--- /dev/null
+++ b/src/infiniop/ops/asinh/operator.cc
@@ -0,0 +1,11 @@
+#include "../../operator_impl.h"
+#include "infiniop/ops/unary_ops_api.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/asinh_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/asinh_nvidia.cuh"
+#endif
+
+UNARY_OP_IMPL(asinh, Asinh)
diff --git a/src/infiniop/ops/atan/cpu/atan_cpu.cc b/src/infiniop/ops/atan/cpu/atan_cpu.cc
new file mode 100644
index 000000000..075c7fd4e
--- /dev/null
+++ b/src/infiniop/ops/atan/cpu/atan_cpu.cc
@@ -0,0 +1,8 @@
+#include "atan_cpu.h"
+#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
+
+namespace op::atan::cpu {
+
+ELEMENTWISE_CPU_IMPL_UNARY(atan)
+
+} // namespace op::atan::cpu
diff --git a/src/infiniop/ops/atan/cpu/atan_cpu.h b/src/infiniop/ops/atan/cpu/atan_cpu.h
new file mode 100644
index 000000000..6b333cfb1
--- /dev/null
+++ b/src/infiniop/ops/atan/cpu/atan_cpu.h
@@ -0,0 +1,9 @@
+#ifndef __ATAN_CPU_H__
+#define __ATAN_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+#include "../../../elementwise/unary.h"
+
+UNARY_ELEMENTWISE_DESCRIPTOR(atan, cpu, op::elementwise::unary::UnaryMode::Atan)
+
+#endif // __ATAN_CPU_H__
diff --git a/src/infiniop/ops/atan/cuda/kernel.cuh b/src/infiniop/ops/atan/cuda/kernel.cuh
new file mode 100644
index 000000000..ce553c1c1
--- /dev/null
+++ 
b/src/infiniop/ops/atan/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __ATAN_CUDA_H__ +#define __ATAN_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::atan::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::atan::cuda + +#endif // __ATAN_CUDA_H__ diff --git a/src/infiniop/ops/atan/nvidia/atan_nvidia.cu b/src/infiniop/ops/atan/nvidia/atan_nvidia.cu new file mode 100644 index 000000000..a05d65b79 --- /dev/null +++ b/src/infiniop/ops/atan/nvidia/atan_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "atan_nvidia.cuh" + +namespace op::atan::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(atan) + +} // namespace op::atan::nvidia diff --git a/src/infiniop/ops/atan/nvidia/atan_nvidia.cuh b/src/infiniop/ops/atan/nvidia/atan_nvidia.cuh new file mode 100644 index 000000000..2aaee1ad9 --- /dev/null +++ b/src/infiniop/ops/atan/nvidia/atan_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ATAN_NVIDIA_API_H__ +#define __ATAN_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(atan, nvidia) + +#endif // __ATAN_NVIDIA_API_H__ diff --git a/src/infiniop/ops/atan/operator.cc b/src/infiniop/ops/atan/operator.cc new file mode 100644 index 000000000..2ee3ad449 --- /dev/null +++ b/src/infiniop/ops/atan/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/atan_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/atan_nvidia.cuh" +#endif + +UNARY_OP_IMPL(atan, Atan) diff --git a/src/infiniop/ops/atanh/cpu/atanh_cpu.cc b/src/infiniop/ops/atanh/cpu/atanh_cpu.cc new file mode 100644 index 000000000..d19c978e4 --- /dev/null +++ b/src/infiniop/ops/atanh/cpu/atanh_cpu.cc @@ -0,0 +1,8 @@ +#include "atanh_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::atanh::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(atanh) + +} // namespace op::atanh::cpu diff --git a/src/infiniop/ops/atanh/cpu/atanh_cpu.h b/src/infiniop/ops/atanh/cpu/atanh_cpu.h new file mode 100644 index 000000000..1a37453f0 --- /dev/null +++ b/src/infiniop/ops/atanh/cpu/atanh_cpu.h @@ -0,0 +1,9 @@ +#ifndef __ATANH_CPU_H__ +#define __ATANH_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(atanh, cpu, op::elementwise::unary::UnaryMode::Atanh) + +#endif // __ATANH_CPU_H__ diff --git a/src/infiniop/ops/atanh/cuda/kernel.cuh b/src/infiniop/ops/atanh/cuda/kernel.cuh new file mode 100644 index 000000000..de0866ba5 --- /dev/null +++ b/src/infiniop/ops/atanh/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __ATANH_CUDA_H__ +#define __ATANH_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::atanh::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::atanh::cuda + +#endif // __ATANH_CUDA_H__ diff --git a/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cu b/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cu new file mode 100644 index 000000000..55b435920 --- /dev/null +++ b/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "atanh_nvidia.cuh" + +namespace op::atanh::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(atanh) + +} // namespace op::atanh::nvidia diff --git 
a/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cuh b/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cuh new file mode 100644 index 000000000..da73cfa99 --- /dev/null +++ b/src/infiniop/ops/atanh/nvidia/atanh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ATANH_NVIDIA_API_H__ +#define __ATANH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(atanh, nvidia) + +#endif // __ATANH_NVIDIA_API_H__ diff --git a/src/infiniop/ops/atanh/operator.cc b/src/infiniop/ops/atanh/operator.cc new file mode 100644 index 000000000..fb991051c --- /dev/null +++ b/src/infiniop/ops/atanh/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/atanh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/atanh_nvidia.cuh" +#endif + +UNARY_OP_IMPL(atanh, Atanh) diff --git a/src/infiniop/ops/ceil/cpu/ceil_cpu.cc b/src/infiniop/ops/ceil/cpu/ceil_cpu.cc new file mode 100644 index 000000000..81ca2fe7a --- /dev/null +++ b/src/infiniop/ops/ceil/cpu/ceil_cpu.cc @@ -0,0 +1,8 @@ +#include "ceil_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::ceil::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(ceil) + +} // namespace op::ceil::cpu diff --git a/src/infiniop/ops/ceil/cpu/ceil_cpu.h b/src/infiniop/ops/ceil/cpu/ceil_cpu.h new file mode 100644 index 000000000..423c784cc --- /dev/null +++ b/src/infiniop/ops/ceil/cpu/ceil_cpu.h @@ -0,0 +1,9 @@ +#ifndef __CEIL_CPU_H__ +#define __CEIL_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(ceil, cpu, op::elementwise::unary::UnaryMode::Ceil) + +#endif // __CEIL_CPU_H__ diff --git a/src/infiniop/ops/ceil/cuda/kernel.cuh b/src/infiniop/ops/ceil/cuda/kernel.cuh new file mode 100644 index 000000000..1d30a42eb --- /dev/null +++ b/src/infiniop/ops/ceil/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __CEIL_CUDA_H__ +#define __CEIL_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::ceil::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::ceil::cuda + +#endif // __CEIL_CUDA_H__ diff --git a/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cu b/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cu new file mode 100644 index 000000000..88ee35be8 --- /dev/null +++ b/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "ceil_nvidia.cuh" + +namespace op::ceil::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(ceil) + +} // namespace op::ceil::nvidia diff --git a/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cuh b/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cuh new file mode 100644 index 000000000..9bada334d --- /dev/null +++ b/src/infiniop/ops/ceil/nvidia/ceil_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __CEIL_NVIDIA_API_H__ +#define __CEIL_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(ceil, nvidia) + +#endif // __CEIL_NVIDIA_API_H__ diff --git a/src/infiniop/ops/ceil/operator.cc b/src/infiniop/ops/ceil/operator.cc new file mode 100644 index 000000000..26252ec16 --- /dev/null +++ b/src/infiniop/ops/ceil/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/ceil_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || 
defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/ceil_nvidia.cuh" +#endif + +UNARY_OP_IMPL(ceil, Ceil) diff --git a/src/infiniop/ops/cos/cpu/cos_cpu.cc b/src/infiniop/ops/cos/cpu/cos_cpu.cc new file mode 100644 index 000000000..19ef002cf --- /dev/null +++ b/src/infiniop/ops/cos/cpu/cos_cpu.cc @@ -0,0 +1,8 @@ +#include "cos_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::cos::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(cos) + +} // namespace op::cos::cpu diff --git a/src/infiniop/ops/cos/cpu/cos_cpu.h b/src/infiniop/ops/cos/cpu/cos_cpu.h new file mode 100644 index 000000000..d62aa91b8 --- /dev/null +++ b/src/infiniop/ops/cos/cpu/cos_cpu.h @@ -0,0 +1,9 @@ +#ifndef __COS_CPU_H__ +#define __COS_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(cos, cpu, op::elementwise::unary::UnaryMode::Cos) + +#endif // __COS_CPU_H__ diff --git a/src/infiniop/ops/cos/cuda/kernel.cuh b/src/infiniop/ops/cos/cuda/kernel.cuh new file mode 100644 index 000000000..57fe4f50e --- /dev/null +++ b/src/infiniop/ops/cos/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __COS_CUDA_H__ +#define __COS_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::cos::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::cos::cuda + +#endif // __COS_CUDA_H__ diff --git a/src/infiniop/ops/cos/nvidia/cos_nvidia.cu b/src/infiniop/ops/cos/nvidia/cos_nvidia.cu new file mode 100644 index 000000000..5da3c02e8 --- /dev/null +++ b/src/infiniop/ops/cos/nvidia/cos_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "cos_nvidia.cuh" + +namespace op::cos::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(cos) + +} // namespace op::cos::nvidia diff --git a/src/infiniop/ops/cos/nvidia/cos_nvidia.cuh b/src/infiniop/ops/cos/nvidia/cos_nvidia.cuh new file mode 100644 index 000000000..a9866e4d2 --- /dev/null +++ b/src/infiniop/ops/cos/nvidia/cos_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __COS_NVIDIA_API_H__ +#define __COS_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(cos, nvidia) + +#endif // __COS_NVIDIA_API_H__ diff --git a/src/infiniop/ops/cos/operator.cc b/src/infiniop/ops/cos/operator.cc new file mode 100644 index 000000000..e3d9237a9 --- /dev/null +++ b/src/infiniop/ops/cos/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/cos_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/cos_nvidia.cuh" +#endif + +UNARY_OP_IMPL(cos, Cos) diff --git a/src/infiniop/ops/cosh/cpu/cosh_cpu.cc b/src/infiniop/ops/cosh/cpu/cosh_cpu.cc new file mode 100644 index 000000000..e7b2a6dad --- /dev/null +++ b/src/infiniop/ops/cosh/cpu/cosh_cpu.cc @@ -0,0 +1,8 @@ +#include "cosh_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::cosh::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(cosh) + +} // namespace op::cosh::cpu diff --git a/src/infiniop/ops/cosh/cpu/cosh_cpu.h b/src/infiniop/ops/cosh/cpu/cosh_cpu.h new file mode 100644 index 000000000..c789d38ea --- /dev/null +++ b/src/infiniop/ops/cosh/cpu/cosh_cpu.h @@ -0,0 +1,9 @@ +#ifndef __COSH_CPU_H__ +#define __COSH_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + 
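+// The UnaryMode::Cosh tag below is what selects the cosh computation inside the shared unary CPU implementation.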
+UNARY_ELEMENTWISE_DESCRIPTOR(cosh, cpu, op::elementwise::unary::UnaryMode::Cosh) + +#endif // __COSH_CPU_H__ diff --git a/src/infiniop/ops/cosh/cuda/kernel.cuh b/src/infiniop/ops/cosh/cuda/kernel.cuh new file mode 100644 index 000000000..934bfe12d --- /dev/null +++ b/src/infiniop/ops/cosh/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __COSH_CUDA_H__ +#define __COSH_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::cosh::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::cosh::cuda + +#endif // __COSH_CUDA_H__ diff --git a/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu new file mode 100644 index 000000000..038b0373e --- /dev/null +++ b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "cosh_nvidia.cuh" + +namespace op::cosh::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(cosh) + +} // namespace op::cosh::nvidia diff --git a/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh new file mode 100644 index 000000000..6a032b0bb --- /dev/null +++ b/src/infiniop/ops/cosh/nvidia/cosh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __COSH_NVIDIA_API_H__ +#define __COSH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(cosh, nvidia) + +#endif // __COSH_NVIDIA_API_H__ diff --git a/src/infiniop/ops/cosh/operator.cc b/src/infiniop/ops/cosh/operator.cc new file mode 100644 index 000000000..c1a6159c1 --- /dev/null +++ b/src/infiniop/ops/cosh/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/cosh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/cosh_nvidia.cuh" +#endif + +UNARY_OP_IMPL(cosh, Cosh) diff --git a/src/infiniop/ops/div/cpu/div_cpu.cc b/src/infiniop/ops/div/cpu/div_cpu.cc new file mode 100644 index 000000000..6d150070c --- /dev/null +++ b/src/infiniop/ops/div/cpu/div_cpu.cc @@ -0,0 +1,8 @@ +#include "div_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::div::cpu { + +ELEMENTWISE_CPU_IMPL_BINARY(div) + +} // namespace op::div::cpu diff --git a/src/infiniop/ops/div/cpu/div_cpu.h b/src/infiniop/ops/div/cpu/div_cpu.h new file mode 100644 index 000000000..ad76e7ef1 --- /dev/null +++ b/src/infiniop/ops/div/cpu/div_cpu.h @@ -0,0 +1,9 @@ +#ifndef __DIV_CPU_H__ +#define __DIV_CPU_H__ + +#include "../../../elementwise/binary.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" + +BINARY_ELEMENTWISE_DESCRIPTOR(div, cpu, op::elementwise::binary::BinaryMode::Divide) + +#endif // __DIV_CPU_H__ diff --git a/src/infiniop/ops/div/cuda/kernel.cuh b/src/infiniop/ops/div/cuda/kernel.cuh new file mode 100644 index 000000000..f1ab13152 --- /dev/null +++ b/src/infiniop/ops/div/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __DIV_CUDA_H__ +#define __DIV_CUDA_H__ + +#include "../../../elementwise/binary.h" + +namespace op::div::cuda { +using Op = op::elementwise::binary::cuda::BinaryOp; +} // namespace op::div::cuda + +#endif // __DIV_CUDA_H__ diff --git a/src/infiniop/ops/div/nvidia/div_nvidia.cu b/src/infiniop/ops/div/nvidia/div_nvidia.cu new file mode 100644 index 000000000..8aaba09b4 --- /dev/null +++ b/src/infiniop/ops/div/nvidia/div_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + 
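+// kernel.cuh aliases op::div::cuda::Op to the shared BinaryOp functor, which the IMPL macro below instantiates.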
+#include "../cuda/kernel.cuh" +#include "div_nvidia.cuh" + +namespace op::div::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_BINARY(div) + +} // namespace op::div::nvidia diff --git a/src/infiniop/ops/div/nvidia/div_nvidia.cuh b/src/infiniop/ops/div/nvidia/div_nvidia.cuh new file mode 100644 index 000000000..1ad8af94e --- /dev/null +++ b/src/infiniop/ops/div/nvidia/div_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __DIV_CUDA_API_H__ +#define __DIV_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(div, nvidia) + +#endif // __DIV_CUDA_API_H__ diff --git a/src/infiniop/ops/div/operator.cc b/src/infiniop/ops/div/operator.cc new file mode 100644 index 000000000..4ed2374af --- /dev/null +++ b/src/infiniop/ops/div/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/binary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/div_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/div_nvidia.cuh" +#endif + +BINARY_OP_IMPL(div, Div) diff --git a/src/infiniop/ops/erf/cpu/erf_cpu.cc b/src/infiniop/ops/erf/cpu/erf_cpu.cc new file mode 100644 index 000000000..d9119c697 --- /dev/null +++ b/src/infiniop/ops/erf/cpu/erf_cpu.cc @@ -0,0 +1,8 @@ +#include "erf_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::erf::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(erf) + +} // namespace op::erf::cpu diff --git a/src/infiniop/ops/erf/cpu/erf_cpu.h b/src/infiniop/ops/erf/cpu/erf_cpu.h new file mode 100644 index 000000000..f50cd157d --- /dev/null +++ b/src/infiniop/ops/erf/cpu/erf_cpu.h @@ -0,0 +1,9 @@ +#ifndef __ERF_CPU_H__ +#define __ERF_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(erf, cpu, op::elementwise::unary::UnaryMode::Erf) + +#endif // __ERF_CPU_H__ diff --git a/src/infiniop/ops/erf/cuda/kernel.cuh b/src/infiniop/ops/erf/cuda/kernel.cuh new file mode 100644 index 000000000..978890cff --- /dev/null +++ b/src/infiniop/ops/erf/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __ERF_CUDA_H__ +#define __ERF_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::erf::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::erf::cuda + +#endif // __ERF_CUDA_H__ diff --git a/src/infiniop/ops/erf/nvidia/erf_nvidia.cu b/src/infiniop/ops/erf/nvidia/erf_nvidia.cu new file mode 100644 index 000000000..0d743b538 --- /dev/null +++ b/src/infiniop/ops/erf/nvidia/erf_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "erf_nvidia.cuh" + +namespace op::erf::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(erf) + +} // namespace op::erf::nvidia diff --git a/src/infiniop/ops/erf/nvidia/erf_nvidia.cuh b/src/infiniop/ops/erf/nvidia/erf_nvidia.cuh new file mode 100644 index 000000000..0621150fa --- /dev/null +++ b/src/infiniop/ops/erf/nvidia/erf_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ERF_NVIDIA_API_H__ +#define __ERF_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(erf, nvidia) + +#endif // __ERF_NVIDIA_API_H__ diff --git a/src/infiniop/ops/erf/operator.cc b/src/infiniop/ops/erf/operator.cc new file mode 100644 index 000000000..eeee864ee --- /dev/null +++ b/src/infiniop/ops/erf/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API 
+#include "cpu/erf_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/erf_nvidia.cuh" +#endif + +UNARY_OP_IMPL(erf, Erf) diff --git a/src/infiniop/ops/floor/cpu/floor_cpu.cc b/src/infiniop/ops/floor/cpu/floor_cpu.cc new file mode 100644 index 000000000..cc717ac11 --- /dev/null +++ b/src/infiniop/ops/floor/cpu/floor_cpu.cc @@ -0,0 +1,8 @@ +#include "floor_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::floor::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(floor) + +} // namespace op::floor::cpu diff --git a/src/infiniop/ops/floor/cpu/floor_cpu.h b/src/infiniop/ops/floor/cpu/floor_cpu.h new file mode 100644 index 000000000..a246309e8 --- /dev/null +++ b/src/infiniop/ops/floor/cpu/floor_cpu.h @@ -0,0 +1,9 @@ +#ifndef __FLOOR_CPU_H__ +#define __FLOOR_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(floor, cpu, op::elementwise::unary::UnaryMode::Floor) + +#endif // __FLOOR_CPU_H__ diff --git a/src/infiniop/ops/floor/cuda/kernel.cuh b/src/infiniop/ops/floor/cuda/kernel.cuh new file mode 100644 index 000000000..23a7a44e9 --- /dev/null +++ b/src/infiniop/ops/floor/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __FLOOR_CUDA_H__ +#define __FLOOR_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::floor::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::floor::cuda + +#endif // __FLOOR_CUDA_H__ diff --git a/src/infiniop/ops/floor/nvidia/floor_nvidia.cu b/src/infiniop/ops/floor/nvidia/floor_nvidia.cu new file mode 100644 index 000000000..cec304a1c --- /dev/null +++ b/src/infiniop/ops/floor/nvidia/floor_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "floor_nvidia.cuh" + +namespace op::floor::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(floor) + +} // namespace op::floor::nvidia diff --git a/src/infiniop/ops/floor/nvidia/floor_nvidia.cuh b/src/infiniop/ops/floor/nvidia/floor_nvidia.cuh new file mode 100644 index 000000000..7a3c2f5c7 --- /dev/null +++ b/src/infiniop/ops/floor/nvidia/floor_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __FLOOR_NVIDIA_API_H__ +#define __FLOOR_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(floor, nvidia) + +#endif // __FLOOR_NVIDIA_API_H__ diff --git a/src/infiniop/ops/floor/operator.cc b/src/infiniop/ops/floor/operator.cc new file mode 100644 index 000000000..bfb4a2466 --- /dev/null +++ b/src/infiniop/ops/floor/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/floor_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/floor_nvidia.cuh" +#endif + +UNARY_OP_IMPL(floor, Floor) diff --git a/src/infiniop/ops/log/cpu/log_cpu.cc b/src/infiniop/ops/log/cpu/log_cpu.cc new file mode 100644 index 000000000..734ad1617 --- /dev/null +++ b/src/infiniop/ops/log/cpu/log_cpu.cc @@ -0,0 +1,8 @@ +#include "log_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::log::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(log) + +} // namespace op::log::cpu diff --git a/src/infiniop/ops/log/cpu/log_cpu.h b/src/infiniop/ops/log/cpu/log_cpu.h new file mode 100644 index 000000000..b13d01442 --- /dev/null +++ b/src/infiniop/ops/log/cpu/log_cpu.h @@ -0,0 +1,9 @@ 
+#ifndef __LOG_CPU_H__ +#define __LOG_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(log, cpu, op::elementwise::unary::UnaryMode::Log) + +#endif // __LOG_CPU_H__ diff --git a/src/infiniop/ops/log/cuda/kernel.cuh b/src/infiniop/ops/log/cuda/kernel.cuh new file mode 100644 index 000000000..80980ada1 --- /dev/null +++ b/src/infiniop/ops/log/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __LOG_CUDA_H__ +#define __LOG_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::log::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::log::cuda + +#endif // __LOG_CUDA_H__ diff --git a/src/infiniop/ops/log/nvidia/log_nvidia.cu b/src/infiniop/ops/log/nvidia/log_nvidia.cu new file mode 100644 index 000000000..87aaa0388 --- /dev/null +++ b/src/infiniop/ops/log/nvidia/log_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "log_nvidia.cuh" + +namespace op::log::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(log) + +} // namespace op::log::nvidia diff --git a/src/infiniop/ops/log/nvidia/log_nvidia.cuh b/src/infiniop/ops/log/nvidia/log_nvidia.cuh new file mode 100644 index 000000000..c48841622 --- /dev/null +++ b/src/infiniop/ops/log/nvidia/log_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __LOG_NVIDIA_API_H__ +#define __LOG_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(log, nvidia) + +#endif // __LOG_NVIDIA_API_H__ diff --git a/src/infiniop/ops/log/operator.cc b/src/infiniop/ops/log/operator.cc new file mode 100644 index 000000000..b4814ff72 --- /dev/null +++ b/src/infiniop/ops/log/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/log_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/log_nvidia.cuh" +#endif + +UNARY_OP_IMPL(log, Log) diff --git a/src/infiniop/ops/max/cpu/max_cpu.cc b/src/infiniop/ops/max/cpu/max_cpu.cc new file mode 100644 index 000000000..98e8a52a2 --- /dev/null +++ b/src/infiniop/ops/max/cpu/max_cpu.cc @@ -0,0 +1,8 @@ +#include "max_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::max::cpu { + +ELEMENTWISE_CPU_IMPL_BINARY(max) + +} // namespace op::max::cpu diff --git a/src/infiniop/ops/max/cpu/max_cpu.h b/src/infiniop/ops/max/cpu/max_cpu.h new file mode 100644 index 000000000..2219994d5 --- /dev/null +++ b/src/infiniop/ops/max/cpu/max_cpu.h @@ -0,0 +1,9 @@ +#ifndef __MAX_CPU_H__ +#define __MAX_CPU_H__ + +#include "../../../elementwise/binary.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" + +BINARY_ELEMENTWISE_DESCRIPTOR(max, cpu, op::elementwise::binary::BinaryMode::Max) + +#endif // __MAX_CPU_H__ diff --git a/src/infiniop/ops/max/cuda/kernel.cuh b/src/infiniop/ops/max/cuda/kernel.cuh new file mode 100644 index 000000000..68f634559 --- /dev/null +++ b/src/infiniop/ops/max/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __MAX_CUDA_H__ +#define __MAX_CUDA_H__ + +#include "../../../elementwise/binary.h" + +namespace op::max::cuda { +using Op = op::elementwise::binary::cuda::BinaryOp; +} // namespace op::max::cuda + +#endif // __MAX_CUDA_H__ diff --git a/src/infiniop/ops/max/nvidia/max_nvidia.cu b/src/infiniop/ops/max/nvidia/max_nvidia.cu new file mode 100644 index 000000000..ba4620f3b --- /dev/null +++ 
b/src/infiniop/ops/max/nvidia/max_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "max_nvidia.cuh" + +namespace op::max::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_BINARY(max) + +} // namespace op::max::nvidia diff --git a/src/infiniop/ops/max/nvidia/max_nvidia.cuh b/src/infiniop/ops/max/nvidia/max_nvidia.cuh new file mode 100644 index 000000000..b3b60dd2a --- /dev/null +++ b/src/infiniop/ops/max/nvidia/max_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __MAX_CUDA_API_H__ +#define __MAX_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(max, nvidia) + +#endif // __MAX_CUDA_API_H__ diff --git a/src/infiniop/ops/max/operator.cc b/src/infiniop/ops/max/operator.cc new file mode 100644 index 000000000..03b6d4eeb --- /dev/null +++ b/src/infiniop/ops/max/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/binary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/max_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/max_nvidia.cuh" +#endif + +BINARY_OP_IMPL(max, Max) diff --git a/src/infiniop/ops/min/cpu/min_cpu.cc b/src/infiniop/ops/min/cpu/min_cpu.cc new file mode 100644 index 000000000..1bac9ea61 --- /dev/null +++ b/src/infiniop/ops/min/cpu/min_cpu.cc @@ -0,0 +1,8 @@ +#include "min_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::min::cpu { + +ELEMENTWISE_CPU_IMPL_BINARY(min) + +} // namespace op::min::cpu diff --git a/src/infiniop/ops/min/cpu/min_cpu.h b/src/infiniop/ops/min/cpu/min_cpu.h new file mode 100644 index 000000000..74042db50 --- /dev/null +++ b/src/infiniop/ops/min/cpu/min_cpu.h @@ -0,0 +1,9 @@ +#ifndef __MIN_CPU_H__ +#define __MIN_CPU_H__ + +#include "../../../elementwise/binary.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" + +BINARY_ELEMENTWISE_DESCRIPTOR(min, cpu, op::elementwise::binary::BinaryMode::Min) + +#endif // __MIN_CPU_H__ diff --git a/src/infiniop/ops/min/cuda/kernel.cuh b/src/infiniop/ops/min/cuda/kernel.cuh new file mode 100644 index 000000000..75c6ab6b9 --- /dev/null +++ b/src/infiniop/ops/min/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __MIN_CUDA_H__ +#define __MIN_CUDA_H__ + +#include "../../../elementwise/binary.h" + +namespace op::min::cuda { +using Op = op::elementwise::binary::cuda::BinaryOp; +} // namespace op::min::cuda + +#endif // __MIN_CUDA_H__ diff --git a/src/infiniop/ops/min/nvidia/min_nvidia.cu b/src/infiniop/ops/min/nvidia/min_nvidia.cu new file mode 100644 index 000000000..0708cbcaf --- /dev/null +++ b/src/infiniop/ops/min/nvidia/min_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "min_nvidia.cuh" + +namespace op::min::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_BINARY(min) + +} // namespace op::min::nvidia diff --git a/src/infiniop/ops/min/nvidia/min_nvidia.cuh b/src/infiniop/ops/min/nvidia/min_nvidia.cuh new file mode 100644 index 000000000..ada9a3545 --- /dev/null +++ b/src/infiniop/ops/min/nvidia/min_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __MIN_CUDA_API_H__ +#define __MIN_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(min, nvidia) + +#endif // __MIN_CUDA_API_H__ diff --git a/src/infiniop/ops/min/operator.cc b/src/infiniop/ops/min/operator.cc new file mode 100644 index 000000000..1597bb5d3 --- /dev/null +++ 
b/src/infiniop/ops/min/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/binary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/min_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/min_nvidia.cuh" +#endif + +BINARY_OP_IMPL(min, Min) diff --git a/src/infiniop/ops/mod/cpu/mod_cpu.cc b/src/infiniop/ops/mod/cpu/mod_cpu.cc new file mode 100644 index 000000000..609c2e76e --- /dev/null +++ b/src/infiniop/ops/mod/cpu/mod_cpu.cc @@ -0,0 +1,8 @@ +#include "mod_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::mod::cpu { + +ELEMENTWISE_CPU_IMPL_BINARY(mod) + +} // namespace op::mod::cpu diff --git a/src/infiniop/ops/mod/cpu/mod_cpu.h b/src/infiniop/ops/mod/cpu/mod_cpu.h new file mode 100644 index 000000000..72ea7dede --- /dev/null +++ b/src/infiniop/ops/mod/cpu/mod_cpu.h @@ -0,0 +1,9 @@ +#ifndef __MOD_CPU_H__ +#define __MOD_CPU_H__ + +#include "../../../elementwise/binary.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" + +BINARY_ELEMENTWISE_DESCRIPTOR(mod, cpu, op::elementwise::binary::BinaryMode::Mod) + +#endif // __MOD_CPU_H__ diff --git a/src/infiniop/ops/mod/cuda/kernel.cuh b/src/infiniop/ops/mod/cuda/kernel.cuh new file mode 100644 index 000000000..164784081 --- /dev/null +++ b/src/infiniop/ops/mod/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __MOD_CUDA_H__ +#define __MOD_CUDA_H__ + +#include "../../../elementwise/binary.h" + +namespace op::mod::cuda { +using Op = op::elementwise::binary::cuda::BinaryOp; +} // namespace op::mod::cuda + +#endif // __MOD_CUDA_H__ diff --git a/src/infiniop/ops/mod/nvidia/mod_nvidia.cu b/src/infiniop/ops/mod/nvidia/mod_nvidia.cu new file mode 100644 index 000000000..68b78ee70 --- /dev/null +++ b/src/infiniop/ops/mod/nvidia/mod_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "mod_nvidia.cuh" + +namespace op::mod::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_BINARY(mod) + +} // namespace op::mod::nvidia diff --git a/src/infiniop/ops/mod/nvidia/mod_nvidia.cuh b/src/infiniop/ops/mod/nvidia/mod_nvidia.cuh new file mode 100644 index 000000000..31788cfd2 --- /dev/null +++ b/src/infiniop/ops/mod/nvidia/mod_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __MOD_CUDA_API_H__ +#define __MOD_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(mod, nvidia) + +#endif // __MOD_CUDA_API_H__ diff --git a/src/infiniop/ops/mod/operator.cc b/src/infiniop/ops/mod/operator.cc new file mode 100644 index 000000000..9f635d6e6 --- /dev/null +++ b/src/infiniop/ops/mod/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/binary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/mod_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/mod_nvidia.cuh" +#endif + +BINARY_OP_IMPL(mod, Mod) diff --git a/src/infiniop/ops/neg/cpu/neg_cpu.cc b/src/infiniop/ops/neg/cpu/neg_cpu.cc new file mode 100644 index 000000000..47f4d2b2e --- /dev/null +++ b/src/infiniop/ops/neg/cpu/neg_cpu.cc @@ -0,0 +1,8 @@ +#include "neg_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::neg::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(neg) + +} // namespace op::neg::cpu diff --git a/src/infiniop/ops/neg/cpu/neg_cpu.h b/src/infiniop/ops/neg/cpu/neg_cpu.h new file mode 100644 index 000000000..f6778a6d3 --- /dev/null +++ 
b/src/infiniop/ops/neg/cpu/neg_cpu.h @@ -0,0 +1,9 @@ +#ifndef __NEG_CPU_H__ +#define __NEG_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(neg, cpu, op::elementwise::unary::UnaryMode::Neg) + +#endif // __NEG_CPU_H__ diff --git a/src/infiniop/ops/neg/cuda/kernel.cuh b/src/infiniop/ops/neg/cuda/kernel.cuh new file mode 100644 index 000000000..f5cf5a449 --- /dev/null +++ b/src/infiniop/ops/neg/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __NEG_CUDA_H__ +#define __NEG_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::neg::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::neg::cuda + +#endif // __NEG_CUDA_H__ diff --git a/src/infiniop/ops/neg/nvidia/neg_nvidia.cu b/src/infiniop/ops/neg/nvidia/neg_nvidia.cu new file mode 100644 index 000000000..f568585f0 --- /dev/null +++ b/src/infiniop/ops/neg/nvidia/neg_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "neg_nvidia.cuh" + +namespace op::neg::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(neg) + +} // namespace op::neg::nvidia diff --git a/src/infiniop/ops/neg/nvidia/neg_nvidia.cuh b/src/infiniop/ops/neg/nvidia/neg_nvidia.cuh new file mode 100644 index 000000000..1265cd3df --- /dev/null +++ b/src/infiniop/ops/neg/nvidia/neg_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __NEG_NVIDIA_API_H__ +#define __NEG_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(neg, nvidia) + +#endif // __NEG_NVIDIA_API_H__ diff --git a/src/infiniop/ops/neg/operator.cc b/src/infiniop/ops/neg/operator.cc new file mode 100644 index 000000000..e8c99dcdf --- /dev/null +++ b/src/infiniop/ops/neg/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/neg_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/neg_nvidia.cuh" +#endif + +UNARY_OP_IMPL(neg, Neg) diff --git a/src/infiniop/ops/pow/cpu/pow_cpu.cc b/src/infiniop/ops/pow/cpu/pow_cpu.cc new file mode 100644 index 000000000..1134d8aae --- /dev/null +++ b/src/infiniop/ops/pow/cpu/pow_cpu.cc @@ -0,0 +1,8 @@ +#include "pow_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::pow::cpu { + +ELEMENTWISE_CPU_IMPL_BINARY(pow) + +} // namespace op::pow::cpu diff --git a/src/infiniop/ops/pow/cpu/pow_cpu.h b/src/infiniop/ops/pow/cpu/pow_cpu.h new file mode 100644 index 000000000..9c8e8a368 --- /dev/null +++ b/src/infiniop/ops/pow/cpu/pow_cpu.h @@ -0,0 +1,9 @@ +#ifndef __POW_CPU_H__ +#define __POW_CPU_H__ + +#include "../../../elementwise/binary.h" +#include "../../../elementwise/cpu/elementwise_cpu.h" + +BINARY_ELEMENTWISE_DESCRIPTOR(pow, cpu, op::elementwise::binary::BinaryMode::Pow) + +#endif // __POW_CPU_H__ diff --git a/src/infiniop/ops/pow/cuda/kernel.cuh b/src/infiniop/ops/pow/cuda/kernel.cuh new file mode 100644 index 000000000..0637240e8 --- /dev/null +++ b/src/infiniop/ops/pow/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __POW_CUDA_H__ +#define __POW_CUDA_H__ + +#include "../../../elementwise/binary.h" + +namespace op::pow::cuda { +using Op = op::elementwise::binary::cuda::BinaryOp; +} // namespace op::pow::cuda + +#endif // __POW_CUDA_H__ diff --git a/src/infiniop/ops/pow/nvidia/pow_nvidia.cu b/src/infiniop/ops/pow/nvidia/pow_nvidia.cu new file mode 100644 index 
000000000..63a3d40a3 --- /dev/null +++ b/src/infiniop/ops/pow/nvidia/pow_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "pow_nvidia.cuh" + +namespace op::pow::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_BINARY(pow) + +} // namespace op::pow::nvidia diff --git a/src/infiniop/ops/pow/nvidia/pow_nvidia.cuh b/src/infiniop/ops/pow/nvidia/pow_nvidia.cuh new file mode 100644 index 000000000..5bbb2fb8c --- /dev/null +++ b/src/infiniop/ops/pow/nvidia/pow_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __POW_CUDA_API_H__ +#define __POW_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(pow, nvidia) + +#endif // __POW_CUDA_API_H__ diff --git a/src/infiniop/ops/pow/operator.cc b/src/infiniop/ops/pow/operator.cc new file mode 100644 index 000000000..7a24d7a20 --- /dev/null +++ b/src/infiniop/ops/pow/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/binary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/pow_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/pow_nvidia.cuh" +#endif + +BINARY_OP_IMPL(pow, Pow) diff --git a/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc b/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc new file mode 100644 index 000000000..0b66eca64 --- /dev/null +++ b/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc @@ -0,0 +1,8 @@ +#include "reciprocal_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::reciprocal::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(reciprocal) + +} // namespace op::reciprocal::cpu diff --git a/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h b/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h new file mode 100644 index 000000000..9af583ab7 --- /dev/null +++ b/src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h @@ -0,0 +1,9 @@ +#ifndef __RECIPROCAL_CPU_H__ +#define __RECIPROCAL_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(reciprocal, cpu, op::elementwise::unary::UnaryMode::Reciprocal) + +#endif // __RECIPROCAL_CPU_H__ diff --git a/src/infiniop/ops/reciprocal/cuda/kernel.cuh b/src/infiniop/ops/reciprocal/cuda/kernel.cuh new file mode 100644 index 000000000..8c29a8e9e --- /dev/null +++ b/src/infiniop/ops/reciprocal/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __RECIPROCAL_CUDA_H__ +#define __RECIPROCAL_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::reciprocal::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::reciprocal::cuda + +#endif // __RECIPROCAL_CUDA_H__ diff --git a/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu b/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu new file mode 100644 index 000000000..39a41b583 --- /dev/null +++ b/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "reciprocal_nvidia.cuh" + +namespace op::reciprocal::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(reciprocal) + +} // namespace op::reciprocal::nvidia diff --git a/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh b/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh new file mode 100644 index 000000000..d98c8f4c2 --- /dev/null +++ b/src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef 
__RECIPROCAL_NVIDIA_API_H__ +#define __RECIPROCAL_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(reciprocal, nvidia) + +#endif // __RECIPROCAL_NVIDIA_API_H__ diff --git a/src/infiniop/ops/reciprocal/operator.cc b/src/infiniop/ops/reciprocal/operator.cc new file mode 100644 index 000000000..4c55fdf20 --- /dev/null +++ b/src/infiniop/ops/reciprocal/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/reciprocal_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/reciprocal_nvidia.cuh" +#endif + +UNARY_OP_IMPL(reciprocal, Reciprocal) diff --git a/src/infiniop/ops/round/cpu/round_cpu.cc b/src/infiniop/ops/round/cpu/round_cpu.cc new file mode 100644 index 000000000..20ae304bd --- /dev/null +++ b/src/infiniop/ops/round/cpu/round_cpu.cc @@ -0,0 +1,8 @@ +#include "round_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::round::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(round) + +} // namespace op::round::cpu diff --git a/src/infiniop/ops/round/cpu/round_cpu.h b/src/infiniop/ops/round/cpu/round_cpu.h new file mode 100644 index 000000000..1a755dbf8 --- /dev/null +++ b/src/infiniop/ops/round/cpu/round_cpu.h @@ -0,0 +1,9 @@ +#ifndef __ROUND_CPU_H__ +#define __ROUND_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(round, cpu, op::elementwise::unary::UnaryMode::Round) + +#endif // __ROUND_CPU_H__ diff --git a/src/infiniop/ops/round/cuda/kernel.cuh b/src/infiniop/ops/round/cuda/kernel.cuh new file mode 100644 index 000000000..f4de9c772 --- /dev/null +++ b/src/infiniop/ops/round/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __ROUND_CUDA_H__ +#define __ROUND_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::round::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::round::cuda + +#endif // __ROUND_CUDA_H__ diff --git a/src/infiniop/ops/round/nvidia/round_nvidia.cu b/src/infiniop/ops/round/nvidia/round_nvidia.cu new file mode 100644 index 000000000..dc84388a3 --- /dev/null +++ b/src/infiniop/ops/round/nvidia/round_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "round_nvidia.cuh" + +namespace op::round::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(round) + +} // namespace op::round::nvidia diff --git a/src/infiniop/ops/round/nvidia/round_nvidia.cuh b/src/infiniop/ops/round/nvidia/round_nvidia.cuh new file mode 100644 index 000000000..65bb38566 --- /dev/null +++ b/src/infiniop/ops/round/nvidia/round_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ROUND_NVIDIA_API_H__ +#define __ROUND_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(round, nvidia) + +#endif // __ROUND_NVIDIA_API_H__ diff --git a/src/infiniop/ops/round/operator.cc b/src/infiniop/ops/round/operator.cc new file mode 100644 index 000000000..5a1e0fcc5 --- /dev/null +++ b/src/infiniop/ops/round/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/round_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/round_nvidia.cuh" +#endif + +UNARY_OP_IMPL(round, Round) diff --git 
a/src/infiniop/ops/sign/cpu/sign_cpu.cc b/src/infiniop/ops/sign/cpu/sign_cpu.cc new file mode 100644 index 000000000..c65868d09 --- /dev/null +++ b/src/infiniop/ops/sign/cpu/sign_cpu.cc @@ -0,0 +1,8 @@ +#include "sign_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::sign::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(sign) + +} // namespace op::sign::cpu diff --git a/src/infiniop/ops/sign/cpu/sign_cpu.h b/src/infiniop/ops/sign/cpu/sign_cpu.h new file mode 100644 index 000000000..7ddeec543 --- /dev/null +++ b/src/infiniop/ops/sign/cpu/sign_cpu.h @@ -0,0 +1,9 @@ +#ifndef __SIGN_CPU_H__ +#define __SIGN_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(sign, cpu, op::elementwise::unary::UnaryMode::Sign) + +#endif // __SIGN_CPU_H__ diff --git a/src/infiniop/ops/sign/cuda/kernel.cuh b/src/infiniop/ops/sign/cuda/kernel.cuh new file mode 100644 index 000000000..a1216fb82 --- /dev/null +++ b/src/infiniop/ops/sign/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __SIGN_CUDA_H__ +#define __SIGN_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::sign::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::sign::cuda + +#endif // __SIGN_CUDA_H__ diff --git a/src/infiniop/ops/sign/nvidia/sign_nvidia.cu b/src/infiniop/ops/sign/nvidia/sign_nvidia.cu new file mode 100644 index 000000000..2a11f9e23 --- /dev/null +++ b/src/infiniop/ops/sign/nvidia/sign_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "sign_nvidia.cuh" + +namespace op::sign::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(sign) + +} // namespace op::sign::nvidia diff --git a/src/infiniop/ops/sign/nvidia/sign_nvidia.cuh b/src/infiniop/ops/sign/nvidia/sign_nvidia.cuh new file mode 100644 index 000000000..d5f2540a3 --- /dev/null +++ b/src/infiniop/ops/sign/nvidia/sign_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SIGN_NVIDIA_API_H__ +#define __SIGN_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(sign, nvidia) + +#endif // __SIGN_NVIDIA_API_H__ diff --git a/src/infiniop/ops/sign/operator.cc b/src/infiniop/ops/sign/operator.cc new file mode 100644 index 000000000..18850ec1f --- /dev/null +++ b/src/infiniop/ops/sign/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/sign_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/sign_nvidia.cuh" +#endif + +UNARY_OP_IMPL(sign, Sign) diff --git a/src/infiniop/ops/sinh/cpu/sinh_cpu.cc b/src/infiniop/ops/sinh/cpu/sinh_cpu.cc new file mode 100644 index 000000000..897439905 --- /dev/null +++ b/src/infiniop/ops/sinh/cpu/sinh_cpu.cc @@ -0,0 +1,8 @@ +#include "sinh_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::sinh::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(sinh) + +} // namespace op::sinh::cpu diff --git a/src/infiniop/ops/sinh/cpu/sinh_cpu.h b/src/infiniop/ops/sinh/cpu/sinh_cpu.h new file mode 100644 index 000000000..573027ee3 --- /dev/null +++ b/src/infiniop/ops/sinh/cpu/sinh_cpu.h @@ -0,0 +1,9 @@ +#ifndef __SINH_CPU_H__ +#define __SINH_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(sinh, cpu, op::elementwise::unary::UnaryMode::Sinh) + +#endif // 
__SINH_CPU_H__ diff --git a/src/infiniop/ops/sinh/cuda/kernel.cuh b/src/infiniop/ops/sinh/cuda/kernel.cuh new file mode 100644 index 000000000..d5bb7491f --- /dev/null +++ b/src/infiniop/ops/sinh/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __SINH_CUDA_H__ +#define __SINH_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::sinh::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::sinh::cuda + +#endif // __SINH_CUDA_H__ diff --git a/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cu b/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cu new file mode 100644 index 000000000..3abfc2973 --- /dev/null +++ b/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "sinh_nvidia.cuh" + +namespace op::sinh::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(sinh) + +} // namespace op::sinh::nvidia diff --git a/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cuh b/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cuh new file mode 100644 index 000000000..66e3e3e67 --- /dev/null +++ b/src/infiniop/ops/sinh/nvidia/sinh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SINH_NVIDIA_API_H__ +#define __SINH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(sinh, nvidia) + +#endif // __SINH_NVIDIA_API_H__ diff --git a/src/infiniop/ops/sinh/operator.cc b/src/infiniop/ops/sinh/operator.cc new file mode 100644 index 000000000..263d20347 --- /dev/null +++ b/src/infiniop/ops/sinh/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/sinh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/sinh_nvidia.cuh" +#endif + +UNARY_OP_IMPL(sinh, Sinh) diff --git a/src/infiniop/ops/sqrt/cpu/sqrt_cpu.cc b/src/infiniop/ops/sqrt/cpu/sqrt_cpu.cc new file mode 100644 index 000000000..eb9ac4d66 --- /dev/null +++ b/src/infiniop/ops/sqrt/cpu/sqrt_cpu.cc @@ -0,0 +1,8 @@ +#include "sqrt_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::sqrt::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(sqrt) + +} // namespace op::sqrt::cpu diff --git a/src/infiniop/ops/sqrt/cpu/sqrt_cpu.h b/src/infiniop/ops/sqrt/cpu/sqrt_cpu.h new file mode 100644 index 000000000..ed6217e1f --- /dev/null +++ b/src/infiniop/ops/sqrt/cpu/sqrt_cpu.h @@ -0,0 +1,9 @@ +#ifndef __SQRT_CPU_H__ +#define __SQRT_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(sqrt, cpu, op::elementwise::unary::UnaryMode::Sqrt) + +#endif // __SQRT_CPU_H__ diff --git a/src/infiniop/ops/sqrt/cuda/kernel.cuh b/src/infiniop/ops/sqrt/cuda/kernel.cuh new file mode 100644 index 000000000..40ab9708f --- /dev/null +++ b/src/infiniop/ops/sqrt/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __SQRT_CUDA_H__ +#define __SQRT_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::sqrt::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::sqrt::cuda + +#endif // __SQRT_CUDA_H__ diff --git a/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cu b/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cu new file mode 100644 index 000000000..4d6c70d72 --- /dev/null +++ b/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "sqrt_nvidia.cuh" + +namespace 
op::sqrt::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(sqrt) + +} // namespace op::sqrt::nvidia diff --git a/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cuh b/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cuh new file mode 100644 index 000000000..6cd98c814 --- /dev/null +++ b/src/infiniop/ops/sqrt/nvidia/sqrt_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SQRT_NVIDIA_API_H__ +#define __SQRT_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(sqrt, nvidia) + +#endif // __SQRT_NVIDIA_API_H__ diff --git a/src/infiniop/ops/sqrt/operator.cc b/src/infiniop/ops/sqrt/operator.cc new file mode 100644 index 000000000..5962860ca --- /dev/null +++ b/src/infiniop/ops/sqrt/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/sqrt_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/sqrt_nvidia.cuh" +#endif + +UNARY_OP_IMPL(sqrt, Sqrt) diff --git a/src/infiniop/ops/tan/cpu/tan_cpu.cc b/src/infiniop/ops/tan/cpu/tan_cpu.cc new file mode 100644 index 000000000..5166cf64f --- /dev/null +++ b/src/infiniop/ops/tan/cpu/tan_cpu.cc @@ -0,0 +1,8 @@ +#include "tan_cpu.h" +#include "../../../elementwise/cpu/elementwise_cpu_impl.h" + +namespace op::tan::cpu { + +ELEMENTWISE_CPU_IMPL_UNARY(tan) + +} // namespace op::tan::cpu diff --git a/src/infiniop/ops/tan/cpu/tan_cpu.h b/src/infiniop/ops/tan/cpu/tan_cpu.h new file mode 100644 index 000000000..6c697c311 --- /dev/null +++ b/src/infiniop/ops/tan/cpu/tan_cpu.h @@ -0,0 +1,9 @@ +#ifndef __TAN_CPU_H__ +#define __TAN_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" +#include "../../../elementwise/unary.h" + +UNARY_ELEMENTWISE_DESCRIPTOR(tan, cpu, op::elementwise::unary::UnaryMode::Tan) + +#endif // __TAN_CPU_H__ diff --git a/src/infiniop/ops/tan/cuda/kernel.cuh b/src/infiniop/ops/tan/cuda/kernel.cuh new file mode 100644 index 000000000..c3cf45350 --- /dev/null +++ b/src/infiniop/ops/tan/cuda/kernel.cuh @@ -0,0 +1,10 @@ +#ifndef __TAN_CUDA_H__ +#define __TAN_CUDA_H__ + +#include "../../../elementwise/unary.h" + +namespace op::tan::cuda { +using Op = op::elementwise::unary::cuda::UnaryOp; +} // namespace op::tan::cuda + +#endif // __TAN_CUDA_H__ diff --git a/src/infiniop/ops/tan/nvidia/tan_nvidia.cu b/src/infiniop/ops/tan/nvidia/tan_nvidia.cu new file mode 100644 index 000000000..5f56dcb6f --- /dev/null +++ b/src/infiniop/ops/tan/nvidia/tan_nvidia.cu @@ -0,0 +1,10 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" + +#include "../cuda/kernel.cuh" +#include "tan_nvidia.cuh" + +namespace op::tan::nvidia { + +ELEMENTWISE_NVIDIA_IMPL_UNARY(tan) + +} // namespace op::tan::nvidia diff --git a/src/infiniop/ops/tan/nvidia/tan_nvidia.cuh b/src/infiniop/ops/tan/nvidia/tan_nvidia.cuh new file mode 100644 index 000000000..ec620cbeb --- /dev/null +++ b/src/infiniop/ops/tan/nvidia/tan_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __TAN_NVIDIA_API_H__ +#define __TAN_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(tan, nvidia) + +#endif // __TAN_NVIDIA_API_H__ diff --git a/src/infiniop/ops/tan/operator.cc b/src/infiniop/ops/tan/operator.cc new file mode 100644 index 000000000..75dd8277e --- /dev/null +++ b/src/infiniop/ops/tan/operator.cc @@ -0,0 +1,11 @@ +#include "../../operator_impl.h" +#include "infiniop/ops/unary_ops_api.h" + +#ifdef ENABLE_CPU_API +#include "cpu/tan_cpu.h" +#endif +#if 
defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/tan_nvidia.cuh" +#endif + +UNARY_OP_IMPL(tan, Tan) diff --git a/src/infiniop/ops/tanh/cuda/kernel.cuh b/src/infiniop/ops/tanh/cuda/kernel.cuh index e336a4995..d987ac7c5 100644 --- a/src/infiniop/ops/tanh/cuda/kernel.cuh +++ b/src/infiniop/ops/tanh/cuda/kernel.cuh @@ -1,44 +1,10 @@ #ifndef __TANH_CUDA_H__ #define __TANH_CUDA_H__ -#include +#include "../../../elementwise/unary.h" namespace op::tanh::cuda { -typedef struct TanhOp { - static constexpr size_t num_inputs = 1; - - __device__ __forceinline__ float tanh_f32_func(float x) const { - return tanhf(x); - } - template <typename T> - __device__ __forceinline__ T operator()(const T &input) const { - if constexpr (std::is_same_v<T, half2>) { - float2 vf = __half22float2(input); - float2 vr = make_float2(tanh_f32_func(vf.x), tanh_f32_func(vf.y)); - return __float22half2_rn(vr); - } else if constexpr (std::is_same_v<T, half>) { - float xf = __half2float(input); - float yf = tanh_f32_func(xf); - return __float2half_rn(yf); - } else if constexpr (std::is_same_v<T, cuda_bfloat162>) { - float f0 = __bfloat162float(__low2bfloat16(input)); - float f1 = __bfloat162float(__high2bfloat16(input)); - float r0 = tanh_f32_func(f0); - float r1 = tanh_f32_func(f1); - return __floats2bfloat162_rn(r0, r1); - } else if constexpr (std::is_same_v<T, cuda_bfloat16>) { - float xf = __bfloat162float(input); - float rf = tanh_f32_func(xf); - return __float2bfloat16_rn(rf); - } else if constexpr (std::is_same_v<T, float>) { - return tanh_f32_func(input); - } else if constexpr (std::is_same_v<T, double>) { - return std::tanh(input); - } else { - return std::tanh(input); - } - } -} TanhOp; +using Op = op::elementwise::unary::cuda::UnaryOp; } // namespace op::tanh::cuda #endif // __TANH_CUDA_H__ diff --git a/src/infiniop/ops/tanh/nvidia/tanh_nvidia.cu b/src/infiniop/ops/tanh/nvidia/tanh_nvidia.cu index a2c36551c..62f02da67 100644 --- a/src/infiniop/ops/tanh/nvidia/tanh_nvidia.cu +++ b/src/infiniop/ops/tanh/nvidia/tanh_nvidia.cu @@ -1,59 +1,10 @@ -#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" +#include "../../../elementwise/nvidia/elementwise_nvidia_impl.cuh" #include "../cuda/kernel.cuh" #include "tanh_nvidia.cuh" namespace op::tanh::nvidia { -Descriptor::~Descriptor() = default; +ELEMENTWISE_NVIDIA_IMPL_UNARY(tanh) -infiniStatus_t Descriptor::create( - infiniopHandle_t handle_, - Descriptor **desc_ptr, - infiniopTensorDescriptor_t out_desc, - std::vector<infiniopTensorDescriptor_t> input_desc_vec) { - - auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_); - auto dtype = out_desc->dtype(); - - const auto &input_desc = input_desc_vec.at(0); - const auto &output_shape = out_desc->shape(); - const auto &input_shape = input_desc->shape(); - - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); - - CHECK_SAME_SHAPE(output_shape, input_shape); - - // create CUDA elementwise descriptor - CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) - - return INFINI_STATUS_SUCCESS; -} - -infiniStatus_t Descriptor::calculate( - void *workspace, - size_t workspace_size, - void *output, - std::vector<const void *> inputs, - void *stream) const { - - if (workspace_size < _workspace_size) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - - switch (_dtype) { - case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::TanhOp, half>(_info, workspace, output, inputs, stream); - case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::TanhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); - case INFINI_DTYPE_F32: - return 
_device_info->calculate<256, cuda::TanhOp, float>(_info, workspace, output, inputs, stream); - case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::TanhOp, double>(_info, workspace, output, inputs, stream); - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - - return INFINI_STATUS_SUCCESS; -} } // namespace op::tanh::nvidia diff --git a/test/infiniop/libinfiniop/binary_test_base.py b/test/infiniop/libinfiniop/binary_test_base.py new file mode 100644 index 000000000..c9da5b4de --- /dev/null +++ b/test/infiniop/libinfiniop/binary_test_base.py @@ -0,0 +1,273 @@ +""" +Base test template for binary operators. + +This module provides a unified test framework for all binary operators, +eliminating code duplication across individual test scripts. + +Usage: + from libinfiniop.binary_test_base import BinaryTestBase + + class DivTest(BinaryTestBase): + OP_NAME = "Div" + OP_NAME_LOWER = "div" + + @staticmethod + def torch_op(c, a, b): + torch.div(a, b, out=c) + + @staticmethod + def generate_input_a(shape, a_stride, dtype, device): + return TestTensor(shape, a_stride, dtype, device) + + @staticmethod + def generate_input_b(shape, b_stride, dtype, device): + # For division, ensure b doesn't contain zeros + return TestTensor(shape, b_stride, dtype, device, scale=2, bias=0.1) + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, + } + + if __name__ == "__main__": + DivTest.run() +""" + +import ctypes +from ctypes import c_uint64 +from enum import Enum, auto + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_A = auto() + INPLACE_B = auto() + + +# Common test cases for binary operators +_BINARY_TEST_CASES_ = [ + # shape, a_stride, b_stride, c_stride + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), +] + +# Inplace options applied for each test case +_BINARY_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_A, + Inplace.INPLACE_B, +] + +# Form the test cases by appending each element of _BINARY_INPLACE to each tuple in _BINARY_TEST_CASES_ +_BINARY_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _BINARY_TEST_CASES_ + for inplace_item in _BINARY_INPLACE +] + +# Data types used for testing (matching old operators library: only F16 and F32) +_BINARY_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32] + + +class BinaryTestBase: + """ + Base class for binary operator tests.
+ + Subclasses must define: + - OP_NAME: Uppercase operator name (e.g., "Div", "Pow") + - OP_NAME_LOWER: Lowercase operator name (e.g., "div", "pow") + - torch_op: Static method that performs the PyTorch operation + - generate_input_a: Static method that generates first input tensor + - generate_input_b: Static method that generates second input tensor + - TOLERANCE_MAP: Dictionary mapping dtype to tolerance values + """ + + OP_NAME = None + OP_NAME_LOWER = None + + # Default tolerance map (can be overridden) + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, + } + + # Test cases (can be overridden) + TEST_CASES = _BINARY_TEST_CASES + TENSOR_DTYPES = _BINARY_TENSOR_DTYPES + + DEBUG = False + PROFILE = False + NUM_PRERUN = 10 + NUM_ITERATIONS = 1000 + + @staticmethod + def torch_op(c, a, b): + """PyTorch operation - must be implemented by subclass""" + raise NotImplementedError("Subclass must implement torch_op") + + @staticmethod + def generate_input_a(shape, a_stride, dtype, device): + """ + Generate first input tensor - must be implemented by subclass. + + Args: + shape: Tensor shape tuple + a_stride: Stride tuple or None + dtype: InfiniDtype enum value + device: InfiniDeviceEnum value + + Returns: + TestTensor: Generated first input tensor + """ + raise NotImplementedError("Subclass must implement generate_input_a") + + @staticmethod + def generate_input_b(shape, b_stride, dtype, device): + """ + Generate second input tensor - must be implemented by subclass. + + Args: + shape: Tensor shape tuple + b_stride: Stride tuple or None + dtype: InfiniDtype enum value + device: InfiniDeviceEnum value + + Returns: + TestTensor: Generated second input tensor + """ + raise NotImplementedError("Subclass must implement generate_input_b") + + @classmethod + def test(cls, handle, device, shape, a_stride=None, b_stride=None, c_stride=None, + inplace=Inplace.OUT_OF_PLACE, dtype=InfiniDtype.F16, sync=None): + """Common test function for binary operators""" + a = cls.generate_input_a(shape, a_stride, dtype, device) + b = cls.generate_input_b(shape, b_stride, dtype, device) + + if inplace == Inplace.INPLACE_A: + if c_stride is not None and c_stride != a_stride: + return + c = a + elif inplace == Inplace.INPLACE_B: + if c_stride is not None and c_stride != b_stride: + return + c = b + else: + c = TestTensor(shape, c_stride, dtype, device) + + if c.is_broadcast(): + return + + print( + f"Testing {cls.OP_NAME} on {InfiniDeviceNames[device]} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + cls.torch_op(c.torch_tensor(), a.torch_tensor(), b.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + create_func = getattr(LIBINFINIOP, f"infiniopCreate{cls.OP_NAME}Descriptor") + check_error( + create_func( + handle, + ctypes.byref(descriptor), + c.descriptor, + a.descriptor, + b.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, c]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + get_workspace_func = getattr(LIBINFINIOP, f"infiniopGet{cls.OP_NAME}WorkspaceSize") + check_error( + get_workspace_func( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_op(): + op_func = getattr(LIBINFINIOP, f"infiniop{cls.OP_NAME}") + check_error( + 
op_func( + descriptor, + workspace.data(), + workspace_size.value, + c.data(), + a.data(), + b.data(), + None, + ) + ) + + lib_op() + if sync is not None: + sync() + + atol, rtol = get_tolerance(cls.TOLERANCE_MAP, dtype) + if cls.DEBUG: + debug(c.actual_tensor(), c.torch_tensor(), atol=atol, rtol=rtol) + + equal_nan = getattr(cls, 'EQUAL_NAN', False) + assert torch.allclose(c.actual_tensor(), c.torch_tensor(), atol=atol, rtol=rtol, equal_nan=equal_nan) + + # Profiling workflow + if cls.PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: cls.torch_op(c.torch_tensor(), a.torch_tensor(), b.torch_tensor()), device, cls.NUM_PRERUN, cls.NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_op(), device, cls.NUM_PRERUN, cls.NUM_ITERATIONS) + # fmt: on + + destroy_func = getattr(LIBINFINIOP, f"infiniopDestroy{cls.OP_NAME}Descriptor") + check_error(destroy_func(descriptor)) + + @classmethod + def run(cls): + """Run the test""" + args = get_args() + + # Configure testing options + cls.DEBUG = args.debug + cls.PROFILE = args.profile + cls.NUM_PRERUN = args.num_prerun + cls.NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, cls.test, cls.TEST_CASES, cls.TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 618be2b05..20a9188d6 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -269,6 +269,176 @@ def mul_(lib): ] +@OpRegister.operator +def pow_(lib): + lib.infiniopCreatePowDescriptor.restype = c_int32 + lib.infiniopCreatePowDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetPowWorkspaceSize.restype = c_int32 + lib.infiniopGetPowWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopPow.restype = c_int32 + lib.infiniopPow.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyPowDescriptor.restype = c_int32 + lib.infiniopDestroyPowDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def div_(lib): + lib.infiniopCreateDivDescriptor.restype = c_int32 + lib.infiniopCreateDivDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetDivWorkspaceSize.restype = c_int32 + lib.infiniopGetDivWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopDiv.restype = c_int32 + lib.infiniopDiv.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyDivDescriptor.restype = c_int32 + lib.infiniopDestroyDivDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def mod_(lib): + lib.infiniopCreateModDescriptor.restype = c_int32 + lib.infiniopCreateModDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetModWorkspaceSize.restype = c_int32 + lib.infiniopGetModWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + 
lib.infiniopMod.restype = c_int32 + lib.infiniopMod.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyModDescriptor.restype = c_int32 + lib.infiniopDestroyModDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def max_(lib): + lib.infiniopCreateMaxDescriptor.restype = c_int32 + lib.infiniopCreateMaxDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetMaxWorkspaceSize.restype = c_int32 + lib.infiniopGetMaxWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopMax.restype = c_int32 + lib.infiniopMax.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyMaxDescriptor.restype = c_int32 + lib.infiniopDestroyMaxDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def min_(lib): + lib.infiniopCreateMinDescriptor.restype = c_int32 + lib.infiniopCreateMinDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetMinWorkspaceSize.restype = c_int32 + lib.infiniopGetMinWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopMin.restype = c_int32 + lib.infiniopMin.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyMinDescriptor.restype = c_int32 + lib.infiniopDestroyMinDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def random_sample_(lib): lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 @@ -326,6 +496,589 @@ def rearrange_(lib): lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopOperatorDescriptor_t] +@OpRegister.operator +def abs_(lib): + lib.infiniopCreateAbsDescriptor.restype = c_int32 + lib.infiniopCreateAbsDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetAbsWorkspaceSize.restype = c_int32 + lib.infiniopGetAbsWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopAbs.restype = c_int32 + lib.infiniopAbs.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyAbsDescriptor.restype = c_int32 + lib.infiniopDestroyAbsDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def acos_(lib): + lib.infiniopCreateAcosDescriptor.restype = c_int32 + lib.infiniopCreateAcosDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAcosWorkspaceSize.restype = c_int32 + lib.infiniopGetAcosWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAcos.restype = c_int32 + lib.infiniopAcos.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyAcosDescriptor.restype = c_int32 + lib.infiniopDestroyAcosDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + 
+@OpRegister.operator +def acosh_(lib): + lib.infiniopCreateAcoshDescriptor.restype = c_int32 + lib.infiniopCreateAcoshDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAcoshWorkspaceSize.restype = c_int32 + lib.infiniopGetAcoshWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAcosh.restype = c_int32 + lib.infiniopAcosh.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyAcoshDescriptor.restype = c_int32 + lib.infiniopDestroyAcoshDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def asin_(lib): + lib.infiniopCreateAsinDescriptor.restype = c_int32 + lib.infiniopCreateAsinDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAsinWorkspaceSize.restype = c_int32 + lib.infiniopGetAsinWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAsin.restype = c_int32 + lib.infiniopAsin.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyAsinDescriptor.restype = c_int32 + lib.infiniopDestroyAsinDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def asinh_(lib): + lib.infiniopCreateAsinhDescriptor.restype = c_int32 + lib.infiniopCreateAsinhDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAsinhWorkspaceSize.restype = c_int32 + lib.infiniopGetAsinhWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAsinh.restype = c_int32 + lib.infiniopAsinh.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyAsinhDescriptor.restype = c_int32 + lib.infiniopDestroyAsinhDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def atan_(lib): + lib.infiniopCreateAtanDescriptor.restype = c_int32 + lib.infiniopCreateAtanDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAtanWorkspaceSize.restype = c_int32 + lib.infiniopGetAtanWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAtan.restype = c_int32 + lib.infiniopAtan.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyAtanDescriptor.restype = c_int32 + lib.infiniopDestroyAtanDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def atanh_(lib): + lib.infiniopCreateAtanhDescriptor.restype = c_int32 + lib.infiniopCreateAtanhDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetAtanhWorkspaceSize.restype = c_int32 + lib.infiniopGetAtanhWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopAtanh.restype = c_int32 + lib.infiniopAtanh.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + 
lib.infiniopDestroyAtanhDescriptor.restype = c_int32 + lib.infiniopDestroyAtanhDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def ceil_(lib): + lib.infiniopCreateCeilDescriptor.restype = c_int32 + lib.infiniopCreateCeilDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetCeilWorkspaceSize.restype = c_int32 + lib.infiniopGetCeilWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopCeil.restype = c_int32 + lib.infiniopCeil.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyCeilDescriptor.restype = c_int32 + lib.infiniopDestroyCeilDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def cos_(lib): + lib.infiniopCreateCosDescriptor.restype = c_int32 + lib.infiniopCreateCosDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetCosWorkspaceSize.restype = c_int32 + lib.infiniopGetCosWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopCos.restype = c_int32 + lib.infiniopCos.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyCosDescriptor.restype = c_int32 + lib.infiniopDestroyCosDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def cosh_(lib): + lib.infiniopCreateCoshDescriptor.restype = c_int32 + lib.infiniopCreateCoshDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetCoshWorkspaceSize.restype = c_int32 + lib.infiniopGetCoshWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopCosh.restype = c_int32 + lib.infiniopCosh.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyCoshDescriptor.restype = c_int32 + lib.infiniopDestroyCoshDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def sinh_(lib): + lib.infiniopCreateSinhDescriptor.restype = c_int32 + lib.infiniopCreateSinhDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetSinhWorkspaceSize.restype = c_int32 + lib.infiniopGetSinhWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopSinh.restype = c_int32 + lib.infiniopSinh.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroySinhDescriptor.restype = c_int32 + lib.infiniopDestroySinhDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def erf_(lib): + lib.infiniopCreateErfDescriptor.restype = c_int32 + lib.infiniopCreateErfDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetErfWorkspaceSize.restype = c_int32 + lib.infiniopGetErfWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopErf.restype = c_int32 + lib.infiniopErf.argtypes = [ + 
infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyErfDescriptor.restype = c_int32 + lib.infiniopDestroyErfDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def floor_(lib): + lib.infiniopCreateFloorDescriptor.restype = c_int32 + lib.infiniopCreateFloorDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetFloorWorkspaceSize.restype = c_int32 + lib.infiniopGetFloorWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopFloor.restype = c_int32 + lib.infiniopFloor.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyFloorDescriptor.restype = c_int32 + lib.infiniopDestroyFloorDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def neg_(lib): + lib.infiniopCreateNegDescriptor.restype = c_int32 + lib.infiniopCreateNegDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetNegWorkspaceSize.restype = c_int32 + lib.infiniopGetNegWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopNeg.restype = c_int32 + lib.infiniopNeg.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyNegDescriptor.restype = c_int32 + lib.infiniopDestroyNegDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def reciprocal_(lib): + lib.infiniopCreateReciprocalDescriptor.restype = c_int32 + lib.infiniopCreateReciprocalDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetReciprocalWorkspaceSize.restype = c_int32 + lib.infiniopGetReciprocalWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopReciprocal.restype = c_int32 + lib.infiniopReciprocal.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyReciprocalDescriptor.restype = c_int32 + lib.infiniopDestroyReciprocalDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def round_(lib): + lib.infiniopCreateRoundDescriptor.restype = c_int32 + lib.infiniopCreateRoundDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetRoundWorkspaceSize.restype = c_int32 + lib.infiniopGetRoundWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopRound.restype = c_int32 + lib.infiniopRound.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyRoundDescriptor.restype = c_int32 + lib.infiniopDestroyRoundDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def sign_(lib): + lib.infiniopCreateSignDescriptor.restype = c_int32 + lib.infiniopCreateSignDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetSignWorkspaceSize.restype = c_int32 + 
lib.infiniopGetSignWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopSign.restype = c_int32 + lib.infiniopSign.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroySignDescriptor.restype = c_int32 + lib.infiniopDestroySignDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def sqrt_(lib): + lib.infiniopCreateSqrtDescriptor.restype = c_int32 + lib.infiniopCreateSqrtDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetSqrtWorkspaceSize.restype = c_int32 + lib.infiniopGetSqrtWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopSqrt.restype = c_int32 + lib.infiniopSqrt.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroySqrtDescriptor.restype = c_int32 + lib.infiniopDestroySqrtDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def log_(lib): + lib.infiniopCreateLogDescriptor.restype = c_int32 + lib.infiniopCreateLogDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetLogWorkspaceSize.restype = c_int32 + lib.infiniopGetLogWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopLog.restype = c_int32 + lib.infiniopLog.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyLogDescriptor.restype = c_int32 + lib.infiniopDestroyLogDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def tan_(lib): + lib.infiniopCreateTanDescriptor.restype = c_int32 + lib.infiniopCreateTanDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetTanWorkspaceSize.restype = c_int32 + lib.infiniopGetTanWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopTan.restype = c_int32 + lib.infiniopTan.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyTanDescriptor.restype = c_int32 + lib.infiniopDestroyTanDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def relu_(lib): lib.infiniopCreateReluDescriptor.restype = c_int32 diff --git a/test/infiniop/libinfiniop/unary_test_base.py b/test/infiniop/libinfiniop/unary_test_base.py new file mode 100644 index 000000000..648a97d3e --- /dev/null +++ b/test/infiniop/libinfiniop/unary_test_base.py @@ -0,0 +1,242 @@ +""" +Base test template for unary operators. + +This module provides a unified test framework for all unary operators, +eliminating code duplication across individual test scripts. 
+ +Usage: + from libinfiniop.unary_test_base import UnaryTestBase + + class AbsTest(UnaryTestBase): + OP_NAME = "Abs" + OP_NAME_LOWER = "abs" + + @staticmethod + def torch_op(x): + return torch.abs(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # Generate test tensors with values in range [-1, 1) for abs operation + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + if __name__ == "__main__": + AbsTest.run() +""" + +import ctypes +from ctypes import c_uint64 +from enum import Enum, auto + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) +from libinfiniop.utils import to_torch_dtype +from libinfiniop.devices import torch_device_map + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +# Common test cases for unary operators +_UNARY_TEST_CASES_ = [ + # tensor_shape, inplace + ((1, 3),), + ((3, 3),), + ((32, 20, 512),), + ((33, 333, 333),), + ((32, 256, 112, 112),), + ((3, 3, 13, 9, 17),), +] + +# Inplace options applied for each test case +_UNARY_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_X, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ +_UNARY_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _UNARY_TEST_CASES_ + for inplace_item in _UNARY_INPLACE +] + +# Data types used for testing (matching old operators library: only F16 and F32) +_UNARY_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32] + + +class UnaryTestBase: + """ + Base class for unary operator tests. + + Subclasses must define: + - OP_NAME: Uppercase operator name (e.g., "Abs", "Log") + - OP_NAME_LOWER: Lowercase operator name (e.g., "abs", "log") + - torch_op: Static method that performs the PyTorch operation + - generate_input: Static method that generates input tensor + - TOLERANCE_MAP: Dictionary mapping dtype to tolerance values + """ + + OP_NAME = None + OP_NAME_LOWER = None + + # Default tolerance map (can be overridden) + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + # Test cases (can be overridden) + TEST_CASES = _UNARY_TEST_CASES + TENSOR_DTYPES = _UNARY_TENSOR_DTYPES + + DEBUG = False + PROFILE = False + NUM_PRERUN = 10 + NUM_ITERATIONS = 1000 + + @staticmethod + def torch_op(x): + """PyTorch operation - must be implemented by subclass""" + raise NotImplementedError("Subclass must implement torch_op") + + @staticmethod + def generate_input(shape, dtype, device): + """ + Generate input tensor - must be implemented by subclass. 
+ +        Args: +            shape: Tensor shape tuple +            dtype: PyTorch dtype (e.g., torch.float16, torch.float32) +            device: PyTorch device string (e.g., "cpu", "cuda") + +        Returns: +            torch.Tensor: Generated input tensor +        """ +        raise NotImplementedError("Subclass must implement generate_input") + +    @classmethod +    def test(cls, handle, device, shape, inplace=Inplace.OUT_OF_PLACE, dtype=InfiniDtype.F16, sync=None): +        """Common test function for unary operators""" +        from libinfiniop.devices import torch_device_map +        from libinfiniop.utils import to_torch_dtype + +        # Generate input tensor +        torch_dtype = to_torch_dtype(dtype) +        torch_device = torch_device_map[device] +        x_torch_tensor = cls.generate_input(shape, torch_dtype, torch_device) + +        x = TestTensor( +            shape, +            x_torch_tensor.stride(), +            dtype, +            device, +            mode="manual", +            set_tensor=x_torch_tensor, +        ) + +        if inplace == Inplace.INPLACE_X: +            y = x +        else: +            y = TestTensor(shape, None, dtype, device) + +        if y.is_broadcast(): +            return + +        print( +            f"Testing {cls.OP_NAME} on {InfiniDeviceNames[device]} with shape:{shape} dtype:{InfiniDtypeNames[dtype]} inplace: {inplace}" +        ) + +        ans = cls.torch_op(x.torch_tensor()) + +        if sync is not None: +            sync() + +        descriptor = infiniopOperatorDescriptor_t() +        create_func = getattr(LIBINFINIOP, f"infiniopCreate{cls.OP_NAME}Descriptor") +        check_error( +            create_func( +                handle, ctypes.byref(descriptor), y.descriptor, x.descriptor +            ) +        ) + +        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel +        for tensor in [x, y]: +            tensor.destroy_desc() + +        workspace_size = c_uint64(0) +        get_workspace_func = getattr(LIBINFINIOP, f"infiniopGet{cls.OP_NAME}WorkspaceSize") +        check_error( +            get_workspace_func( +                descriptor, ctypes.byref(workspace_size) +            ) +        ) +        workspace = TestWorkspace(workspace_size.value, y.device) + +        def lib_op(): +            op_func = getattr(LIBINFINIOP, f"infiniop{cls.OP_NAME}") +            check_error( +                op_func( +                    descriptor, workspace.data(), workspace_size.value, y.data(), x.data(), None +                ) +            ) + +        lib_op() +        if sync is not None: +            sync() + +        atol, rtol = get_tolerance(cls.TOLERANCE_MAP, dtype) +        equal_nan = getattr(cls, 'EQUAL_NAN', False) + +        if cls.DEBUG: +            debug(y.actual_tensor(), ans, atol=atol, rtol=rtol, equal_nan=equal_nan) + +        assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol, equal_nan=equal_nan) + +        # Profiling workflow +        if cls.PROFILE: +            # fmt: off +            profile_operation("PyTorch", lambda: cls.torch_op(x.torch_tensor()), device, cls.NUM_PRERUN, cls.NUM_ITERATIONS) +            profile_operation("    lib", lambda: lib_op(), device, cls.NUM_PRERUN, cls.NUM_ITERATIONS) +            # fmt: on + +        destroy_func = getattr(LIBINFINIOP, f"infiniopDestroy{cls.OP_NAME}Descriptor") +        check_error(destroy_func(descriptor)) + +    @classmethod +    def run(cls): +        """Run the test""" +        args = get_args() + +        # Configure testing options +        cls.DEBUG = args.debug +        cls.PROFILE = args.profile +        cls.NUM_PRERUN = args.num_prerun +        cls.NUM_ITERATIONS = args.num_iterations + +        for device in get_test_devices(args): +            test_operator(device, cls.test, cls.TEST_CASES, cls.TENSOR_DTYPES) + +        print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/test_all_binary_ops.py b/test/infiniop/test_all_binary_ops.py new file mode 100644 index 000000000..e08b3e41b --- /dev/null +++ b/test/infiniop/test_all_binary_ops.py @@ -0,0 +1,251 @@ +""" +Unified tests for all binary operators + +This file gathers the tests for every binary operator so they can be managed and run in one place. +Command-line arguments select which operators to run, or all of them by default. + +Usage: +    # Run the tests for all binary operators +    python test_all_binary_ops.py + +    # Run only the div and pow operators
python test_all_binary_ops.py --ops div pow + +    # Run the tests on specific devices +    python test_all_binary_ops.py --cpu --nvidia +""" + +import torch +import argparse +from libinfiniop import InfiniDtype, TestTensor +from libinfiniop.binary_test_base import BinaryTestBase + + +# ============================================================================== +# Test class definitions for all binary operators +# ============================================================================== + +class DivTest(BinaryTestBase): +    OP_NAME = "Div" +    OP_NAME_LOWER = "div" + +    @staticmethod +    def torch_op(c, a, b): +        torch.div(a, b, out=c) + +    @staticmethod +    def generate_input_a(shape, a_stride, dtype, device): +        return TestTensor(shape, a_stride, dtype, device) + +    @staticmethod +    def generate_input_b(shape, b_stride, dtype, device): +        # For division, ensure b doesn't contain zeros +        return TestTensor(shape, b_stride, dtype, device, scale=2, bias=0.1) + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, +    } + +    EQUAL_NAN = True + + +class PowTest(BinaryTestBase): +    OP_NAME = "Pow" +    OP_NAME_LOWER = "pow" + +    @staticmethod +    def torch_op(c, a, b): +        torch.pow(a, b, out=c) + +    @staticmethod +    def generate_input_a(shape, a_stride, dtype, device): +        # Avoid negative bases and very large exponents +        return TestTensor(shape, a_stride, dtype, device, mode="random", scale=5.0, bias=0.1) + +    @staticmethod +    def generate_input_b(shape, b_stride, dtype, device): +        return TestTensor(shape, b_stride, dtype, device, mode="random", scale=3.0, bias=0.1) + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-3}, +    } + +    EQUAL_NAN = True + + +class ModTest(BinaryTestBase): +    OP_NAME = "Mod" +    OP_NAME_LOWER = "mod" + +    @staticmethod +    def torch_op(c, a, b): +        torch.remainder(a, b, out=c) + +    @staticmethod +    def generate_input_a(shape, a_stride, dtype, device): +        return TestTensor(shape, a_stride, dtype, device) + +    @staticmethod +    def generate_input_b(shape, b_stride, dtype, device): +        # Avoid zeros +        return TestTensor(shape, b_stride, dtype, device, scale=2, bias=0.1) + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, +    } + +    EQUAL_NAN = True + + +class MaxTest(BinaryTestBase): +    OP_NAME = "Max" +    OP_NAME_LOWER = "max" + +    @staticmethod +    def torch_op(c, a, b): +        torch.maximum(a, b, out=c) + +    @staticmethod +    def generate_input_a(shape, a_stride, dtype, device): +        return TestTensor(shape, a_stride, dtype, device) + +    @staticmethod +    def generate_input_b(shape, b_stride, dtype, device): +        return TestTensor(shape, b_stride, dtype, device) + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, +    } + +    EQUAL_NAN = True + + +class MinTest(BinaryTestBase): +    OP_NAME = "Min" +    OP_NAME_LOWER = "min" + +    @staticmethod +    def torch_op(c, a, b): +        torch.minimum(a, b, out=c) + +    @staticmethod +    def generate_input_a(shape, a_stride, dtype, device): +        return TestTensor(shape, a_stride, dtype, device) + +    @staticmethod +    def generate_input_b(shape, b_stride, dtype, device): +        return TestTensor(shape, b_stride, dtype, device) + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, +    } + +    EQUAL_NAN = True + + +# ============================================================================== +# Operator registry +# 
============================================================================== + +# Mapping from operator name to its test class +BINARY_OP_TESTS = { +    "div": DivTest, +    "pow": PowTest, +    "mod": ModTest, +    "max": MaxTest, +    "min": MinTest, +} + + +# ============================================================================== +# Main entry point +# ============================================================================== + +def main(): +    # Pull in the base argument parser first +    from libinfiniop.utils import get_args as get_base_args +    import sys + +    # Create a new parser that adds the --ops argument +    parser = argparse.ArgumentParser(description="Test all binary operators", parents=[]) +    parser.add_argument( +        "--ops", +        nargs="+", +        choices=list(BINARY_OP_TESTS.keys()), +        default=list(BINARY_OP_TESTS.keys()), +        help="Specify which operators to test (default: all)", +    ) + +    # Parse the arguments +    args, unknown = parser.parse_known_args() + +    # Forward any unknown arguments to the base parser +    if unknown: +        sys.argv = [sys.argv[0]] + unknown +        base_args = get_base_args() +    else: +        # No extra arguments, so use the defaults +        sys.argv = [sys.argv[0]] +        base_args = get_base_args() + +    # Merge the two argument namespaces +    for attr in dir(base_args): +        if not attr.startswith("_") and not hasattr(args, attr): +            setattr(args, attr, getattr(base_args, attr)) + +    # Run the selected operator tests +    print(f"\n{'='*60}") +    print(f"Testing {len(args.ops)} binary operator(s): {', '.join(args.ops)}") +    print(f"{'='*60}\n") + +    failed_ops = [] +    passed_ops = [] + +    for op_name in args.ops: +        test_class = BINARY_OP_TESTS[op_name] +        print(f"\n{'='*60}") +        print(f"Testing {test_class.OP_NAME} operator") +        print(f"{'='*60}") + +        try: +            # Pass the parsed options on to the test class +            test_class.DEBUG = args.debug +            test_class.PROFILE = args.profile +            test_class.NUM_PRERUN = args.num_prerun +            test_class.NUM_ITERATIONS = args.num_iterations + +            # Run the tests +            for device in get_test_devices(args): +                test_operator(device, test_class.test, test_class.TEST_CASES, test_class.TENSOR_DTYPES) + +            print(f"\033[92m{test_class.OP_NAME} test passed!\033[0m") +            passed_ops.append(op_name) +        except Exception as e: +            print(f"\033[91m{test_class.OP_NAME} test failed: {e}\033[0m") +            failed_ops.append(op_name) +            if args.debug: +                import traceback +                traceback.print_exc() + +    # Print a summary +    print(f"\n{'='*60}") +    print("Test Summary") +    print(f"{'='*60}") +    print(f"Total operators: {len(args.ops)}") +    print(f"\033[92mPassed: {len(passed_ops)} - {', '.join(passed_ops)}\033[0m") +    if failed_ops: +        print(f"\033[91mFailed: {len(failed_ops)} - {', '.join(failed_ops)}\033[0m") +    print(f"{'='*60}\n") + +    if failed_ops: +        exit(1) + + +if __name__ == "__main__": +    from libinfiniop.utils import get_test_devices, test_operator +    main() diff --git a/test/infiniop/test_all_unary_ops.py b/test/infiniop/test_all_unary_ops.py new file mode 100644 index 000000000..b9d7cdc8b --- /dev/null +++ b/test/infiniop/test_all_unary_ops.py @@ -0,0 +1,548 @@ +""" +Unified tests for all unary operators + +This file gathers the tests for every unary operator so they can be managed and run in one place. +Command-line arguments select which operators to run, or all of them by default. + +Usage: +    # Run the tests for all unary operators +    python test_all_unary_ops.py + +    # Run only the abs and log operators +    python test_all_unary_ops.py --ops abs log + +    # Run the tests on specific devices +    python test_all_unary_ops.py --cpu --nvidia +""" + +import torch +import argparse +from libinfiniop import InfiniDtype +from libinfiniop.unary_test_base import UnaryTestBase + + +# ============================================================================== +# Test class definitions for all unary operators +# ============================================================================== + +class AbsTest(UnaryTestBase): +    OP_NAME = "Abs" +    OP_NAME_LOWER = "abs" + +    @staticmethod +    def torch_op(x): +        return torch.abs(x).to(x.dtype) + +    @staticmethod
+ def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + +class AcosTest(UnaryTestBase): + OP_NAME = "Acos" + OP_NAME_LOWER = "acos" + + @staticmethod + def torch_op(x): + return torch.acos(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # acos domain is [-1, 1] + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class AcoshTest(UnaryTestBase): + OP_NAME = "Acosh" + OP_NAME_LOWER = "acosh" + + @staticmethod + def torch_op(x): + return torch.acosh(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # acosh domain is [1, +∞) + return torch.rand(shape, dtype=dtype, device=device) * 10 + 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class AsinTest(UnaryTestBase): + OP_NAME = "Asin" + OP_NAME_LOWER = "asin" + + @staticmethod + def torch_op(x): + return torch.asin(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # asin domain is [-1, 1] + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class AsinhTest(UnaryTestBase): + OP_NAME = "Asinh" + OP_NAME_LOWER = "asinh" + + @staticmethod + def torch_op(x): + return torch.asinh(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class AtanTest(UnaryTestBase): + OP_NAME = "Atan" + OP_NAME_LOWER = "atan" + + @staticmethod + def torch_op(x): + return torch.atan(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class AtanhTest(UnaryTestBase): + OP_NAME = "Atanh" + OP_NAME_LOWER = "atanh" + + @staticmethod + def torch_op(x): + return torch.atanh(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # atanh domain is (-1, 1) + return torch.rand(shape, dtype=dtype, device=device) * 1.8 - 0.9 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class CeilTest(UnaryTestBase): + OP_NAME = "Ceil" + OP_NAME_LOWER = "ceil" + + @staticmethod + def torch_op(x): + return torch.ceil(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 10 - 5 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + +class CosTest(UnaryTestBase): + OP_NAME = "Cos" + OP_NAME_LOWER = "cos" + + @staticmethod + def torch_op(x): + return torch.cos(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # Generate test tensors with values in range 
[-200, -100) for cos operation + # cos domain is (-∞, +∞), so we use range [-200, -100) + return torch.rand(shape, dtype=dtype, device=device) * 100 - 200 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-4, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-2}, + } + + EQUAL_NAN = True + + +class CoshTest(UnaryTestBase): + OP_NAME = "Cosh" + OP_NAME_LOWER = "cosh" + + @staticmethod + def torch_op(x): + return torch.cosh(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class ErfTest(UnaryTestBase): + OP_NAME = "Erf" + OP_NAME_LOWER = "erf" + + @staticmethod + def torch_op(x): + return torch.erf(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class FloorTest(UnaryTestBase): + OP_NAME = "Floor" + OP_NAME_LOWER = "floor" + + @staticmethod + def torch_op(x): + return torch.floor(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 10 - 5 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class LogTest(UnaryTestBase): + OP_NAME = "Log" + OP_NAME_LOWER = "log" + + @staticmethod + def torch_op(x): + return torch.log(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # log domain is (0, +∞), so we use range [0.1, 1.1) + return torch.rand(shape, dtype=dtype, device=device) + 0.1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-7, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-3}, + } + + EQUAL_NAN = True + + +class NegTest(UnaryTestBase): + OP_NAME = "Neg" + OP_NAME_LOWER = "neg" + + @staticmethod + def torch_op(x): + return torch.neg(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class ReciprocalTest(UnaryTestBase): + OP_NAME = "Reciprocal" + OP_NAME_LOWER = "reciprocal" + + @staticmethod + def torch_op(x): + return torch.reciprocal(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + # Avoid zeros + return torch.rand(shape, dtype=dtype, device=device) * 2 + 0.1 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class RoundTest(UnaryTestBase): + OP_NAME = "Round" + OP_NAME_LOWER = "round" + + @staticmethod + def torch_op(x): + return torch.round(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, device): + return torch.rand(shape, dtype=dtype, device=device) * 10 - 5 + + TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + } + + EQUAL_NAN = True + + +class SignTest(UnaryTestBase): + OP_NAME = "Sign" + OP_NAME_LOWER = "sign" + + @staticmethod + def torch_op(x): + return torch.sign(x).to(x.dtype) + + @staticmethod + def generate_input(shape, dtype, 
device): +        return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, +    } + +    EQUAL_NAN = True + + +class SinhTest(UnaryTestBase): +    OP_NAME = "Sinh" +    OP_NAME_LOWER = "sinh" + +    @staticmethod +    def torch_op(x): +        return torch.sinh(x).to(x.dtype) + +    @staticmethod +    def generate_input(shape, dtype, device): +        return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, +    } + +    EQUAL_NAN = True + + +class SqrtTest(UnaryTestBase): +    OP_NAME = "Sqrt" +    OP_NAME_LOWER = "sqrt" + +    @staticmethod +    def torch_op(x): +        return torch.sqrt(x).to(x.dtype) + +    @staticmethod +    def generate_input(shape, dtype, device): +        # sqrt domain is [0, +∞) +        return torch.rand(shape, dtype=dtype, device=device) * 100 + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 0, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, +    } + +    EQUAL_NAN = True + + +class TanTest(UnaryTestBase): +    OP_NAME = "Tan" +    OP_NAME_LOWER = "tan" + +    @staticmethod +    def torch_op(x): +        return torch.tan(x).to(x.dtype) + +    @staticmethod +    def generate_input(shape, dtype, device): +        return torch.rand(shape, dtype=dtype, device=device) * 2 - 1 + +    TOLERANCE_MAP = { +        InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, +        InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, +    } + +    EQUAL_NAN = True + + +# ============================================================================== +# Operator registry +# ============================================================================== + +# Mapping from operator name to its test class +UNARY_OP_TESTS = { +    "abs": AbsTest, +    "acos": AcosTest, +    "acosh": AcoshTest, +    "asin": AsinTest, +    "asinh": AsinhTest, +    "atan": AtanTest, +    "atanh": AtanhTest, +    "ceil": CeilTest, +    "cos": CosTest, +    "cosh": CoshTest, +    "erf": ErfTest, +    "floor": FloorTest, +    "log": LogTest, +    "neg": NegTest, +    "reciprocal": ReciprocalTest, +    "round": RoundTest, +    "sign": SignTest, +    "sinh": SinhTest, +    "sqrt": SqrtTest, +    "tan": TanTest, +} + + +# ============================================================================== +# Main entry point +# ============================================================================== + +def main(): +    # Pull in the base argument parser first +    from libinfiniop.utils import get_args as get_base_args +    import sys + +    # Create a new parser that adds the --ops argument +    parser = argparse.ArgumentParser(description="Test all unary operators", parents=[]) +    parser.add_argument( +        "--ops", +        nargs="+", +        choices=list(UNARY_OP_TESTS.keys()), +        default=list(UNARY_OP_TESTS.keys()), +        help="Specify which operators to test (default: all)", +    ) + +    # Parse the arguments +    args, unknown = parser.parse_known_args() + +    # Forward any unknown arguments to the base parser +    if unknown: +        sys.argv = [sys.argv[0]] + unknown +        base_args = get_base_args() +    else: +        # No extra arguments, so use the defaults +        sys.argv = [sys.argv[0]] +        base_args = get_base_args() + +    # Merge the two argument namespaces +    for attr in dir(base_args): +        if not attr.startswith("_") and not hasattr(args, attr): +            setattr(args, attr, getattr(base_args, attr)) + +    # Run the selected operator tests +    print(f"\n{'='*60}") +    print(f"Testing {len(args.ops)} unary operator(s): {', '.join(args.ops)}") +    print(f"{'='*60}\n") + +    failed_ops = [] +    passed_ops = [] + +    for op_name in args.ops: +        test_class = UNARY_OP_TESTS[op_name] +        print(f"\n{'='*60}") +        print(f"Testing {test_class.OP_NAME} operator") +        print(f"{'='*60}") + +        try: +            # Pass the parsed options on to the test class +            test_class.DEBUG = args.debug +            test_class.PROFILE = args.profile + 
test_class.NUM_PRERUN = args.num_prerun +            test_class.NUM_ITERATIONS = args.num_iterations + +            # Run the tests +            for device in get_test_devices(args): +                test_operator(device, test_class.test, test_class.TEST_CASES, test_class.TENSOR_DTYPES) + +            print(f"\033[92m{test_class.OP_NAME} test passed!\033[0m") +            passed_ops.append(op_name) +        except Exception as e: +            print(f"\033[91m{test_class.OP_NAME} test failed: {e}\033[0m") +            failed_ops.append(op_name) +            if args.debug: +                import traceback +                traceback.print_exc() + +    # Print a summary +    print(f"\n{'='*60}") +    print("Test Summary") +    print(f"{'='*60}") +    print(f"Total operators: {len(args.ops)}") +    print(f"\033[92mPassed: {len(passed_ops)} - {', '.join(passed_ops)}\033[0m") +    if failed_ops: +        print(f"\033[91mFailed: {len(failed_ops)} - {', '.join(failed_ops)}\033[0m") +    print(f"{'='*60}\n") + +    if failed_ops: +        exit(1) + + +if __name__ == "__main__": +    from libinfiniop.utils import get_test_devices, test_operator +    main()
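
Both driver scripts rely on the same two-stage argument-parsing trick: a local parser consumes --ops via parse_known_args, the leftover flags are replayed through sys.argv so the library's own get_args() can parse the shared options, and the two namespaces are then merged. A minimal standalone sketch of that pattern follows; the flag names here are placeholders, not the repo's real option set.

import argparse
import sys

def base_args():
    # Stand-in for libinfiniop.utils.get_args(): owns the shared flags.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()

def main():
    # The local parser only knows --ops; everything else lands in `unknown`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ops", nargs="+", default=["div", "pow"])
    args, unknown = parser.parse_known_args()

    # Replay the unrecognized flags so the base parser can consume them.
    sys.argv = [sys.argv[0]] + unknown
    base = base_args()

    # Merge: copy over anything the local namespace does not define yet.
    for attr, value in vars(base).items():
        if not hasattr(args, attr):
            setattr(args, attr, value)
    print(args)  # e.g. Namespace(ops=['div', 'pow'], cpu=False, debug=False)

if __name__ == "__main__":
    main()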
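The stride tuples in _BINARY_TEST_CASES_ that contain a 0, such as (0, 1) or (4, 0, 1), describe broadcast views: a zero stride makes every index along that dimension read the same memory, so a smaller operand can feed a larger logical shape without copying. Writing through such a view is ambiguous, which is why the test functions skip cases where the output c.is_broadcast(). A short torch illustration of the mechanism, independent of the repo's TestTensor helper:

import torch

row = torch.arange(4.0)                        # shape (4,), contiguous
view = torch.as_strided(row, (13, 4), (0, 1))  # shape (13, 4), stride (0, 1)

# Every row aliases the same 4 elements, so this is a broadcast "read" view.
assert view.stride() == (0, 1)
assert torch.equal(view[0], view[12])

# Writing through the view would hit the same storage 13 times over,
# which is why broadcast *outputs* are skipped in the tests above.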
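Every unary entry added to op_register.py configures the same four symbols with identical ctypes signatures, differing only in the operator name. The helper below is a hypothetical refactoring sketch, not part of the actual bindings: it shows how those blocks could be generated from a name list, with c_void_p standing in for the opaque handle and descriptor types used by the real registrations.

from ctypes import POINTER, c_int32, c_size_t, c_void_p

def register_unary(lib, name):
    """Configure restype/argtypes for one unary operator, e.g. name="Abs"."""
    create = getattr(lib, f"infiniopCreate{name}Descriptor")
    create.restype = c_int32
    create.argtypes = [c_void_p, POINTER(c_void_p), c_void_p, c_void_p]  # handle, desc*, y, x

    workspace = getattr(lib, f"infiniopGet{name}WorkspaceSize")
    workspace.restype = c_int32
    workspace.argtypes = [c_void_p, POINTER(c_size_t)]

    execute = getattr(lib, f"infiniop{name}")
    execute.restype = c_int32
    execute.argtypes = [c_void_p, c_void_p, c_size_t, c_void_p, c_void_p, c_void_p]  # desc, workspace, size, y, x, stream

    destroy = getattr(lib, f"infiniopDestroy{name}Descriptor")
    destroy.restype = c_int32
    destroy.argtypes = [c_void_p]

# Usage sketch: for n in ("Abs", "Acos", "Sqrt"): register_unary(LIBINFINIOP, n)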