From bac77452b82e12a3131c5153eac9c5c800706f76 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Thu, 16 Oct 2025 15:31:05 +0800 Subject: [PATCH] issue/383: Add logsoftmax ops Co-authored-by: wawahejun Co-authored-by: zhuyue --- include/infiniop.h | 1 + include/infiniop/ops/logsoftmax.h | 24 ++ scripts/python_test.py | 1 + .../ops/logsoftmax/cpu/logsoftmax_cpu.cc | 130 +++++++ .../ops/logsoftmax/cpu/logsoftmax_cpu.h | 7 + src/infiniop/ops/logsoftmax/cuda/kernel.cuh | 108 ++++++ src/infiniop/ops/logsoftmax/info.h | 117 +++++++ src/infiniop/ops/logsoftmax/logsoftmax.h | 46 +++ .../logsoftmax/nvidia/logsoftmax_nvidia.cu | 131 +++++++ .../logsoftmax/nvidia/logsoftmax_nvidia.cuh | 8 + src/infiniop/ops/logsoftmax/operator.cc | 136 ++++++++ test/infiniop/libinfiniop/op_register.py | 32 ++ test/infiniop/logsoftmax.py | 324 ++++++++++++++++++ 13 files changed, 1065 insertions(+) create mode 100644 include/infiniop/ops/logsoftmax.h create mode 100644 src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.cc create mode 100644 src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.h create mode 100644 src/infiniop/ops/logsoftmax/cuda/kernel.cuh create mode 100644 src/infiniop/ops/logsoftmax/info.h create mode 100644 src/infiniop/ops/logsoftmax/logsoftmax.h create mode 100644 src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu create mode 100644 src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cuh create mode 100644 src/infiniop/ops/logsoftmax/operator.cc create mode 100644 test/infiniop/logsoftmax.py diff --git a/include/infiniop.h b/include/infiniop.h index b3cf8b6ca..f0d75abc9 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -6,6 +6,7 @@ #include "infiniop/ops/attention.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" +#include "infiniop/ops/logsoftmax.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" #include "infiniop/ops/gemm.h" diff --git a/include/infiniop/ops/logsoftmax.h b/include/infiniop/ops/logsoftmax.h new file mode 100644 index 000000000..1b944424c --- /dev/null +++ b/include/infiniop/ops/logsoftmax.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_LOGSOFTMAX_API_H__ +#define __INFINIOP_LOGSOFTMAX_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLogSoftmaxDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLogSoftmaxDescriptor(infiniopHandle_t handle, + infiniopLogSoftmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc); + +__C __export infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLogSoftmax(infiniopLogSoftmaxDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc); + +#endif diff --git a/scripts/python_test.py b/scripts/python_test.py index 5348c8c69..1f3381adc 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -17,6 +17,7 @@ def run_tests(args): "causal_softmax.py", "clip.py", "gemm.py", + "logsoftmax.py", "mul.py", "random_sample.py", "rearrange.py", diff --git a/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.cc b/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.cc new file mode 100644 index 000000000..a6a3876f9 --- /dev/null +++ b/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.cc @@ -0,0 +1,130 @@ +#include "logsoftmax_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include 
"../../../reduce/cpu/reduce.h" +#include +#include + +namespace op::logsoftmax::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + auto result = LogSoftmaxInfo::create(y_desc, x_desc); + CHECK_RESULT(result); + *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t logsoftmax(const LogSoftmaxInfo *info, Ty *y, const Tx *x) { +#pragma omp parallel for + for (ptrdiff_t batch = 0; batch < ptrdiff_t(info->batch_size); batch++) { + ptrdiff_t y_offset, x_offset; + + if (info->ndim == 3) { + // For 3D tensors, convert linear batch index back to 2D indices + ptrdiff_t batch_idx = batch / info->seq_len; + ptrdiff_t seq_idx = batch % info->seq_len; + y_offset = batch_idx * info->y_stride_0 + seq_idx * info->y_stride_1; + x_offset = batch_idx * info->x_stride_0 + seq_idx * info->x_stride_1; + } else { + // For 2D tensors, use the flattened strides + y_offset = batch * info->y_stride_b; + x_offset = batch * info->x_stride_b; + } + + Ty *y_ = y + y_offset; + const Tx *x_ = x + x_offset; + + // Find max value for numerical stability + float max_val; + if constexpr (std::is_same::value || std::is_same::value) { + max_val = op::common_cpu::reduce_op::max(x_, info->probs_size, info->x_stride_p); + } else { + max_val = op::common_cpu::reduce_op::max(x_, info->probs_size, info->x_stride_p); + } + + // Compute exp(x - max) and sum + float sum = 0.0f; + for (size_t i = 0; i < info->probs_size; i++) { + float x_val; + if constexpr (std::is_same::value || std::is_same::value) { + x_val = utils::cast(x_[i * info->x_stride_p]); + } else { + x_val = x_[i * info->x_stride_p]; + } + sum += std::exp(x_val - max_val); + } + + // Compute log(sum) + float log_sum = std::log(sum); + + // Compute log_softmax = x - max - log(sum) + for (size_t i = 0; i < info->probs_size; i++) { + float x_val; + if constexpr (std::is_same::value || std::is_same::value) { + x_val = utils::cast(x_[i * info->x_stride_p]); + } else { + x_val = x_[i * info->x_stride_p]; + } + + float result = x_val - max_val - log_sum; + + if constexpr (std::is_same::value || std::is_same::value) { + y_[i * info->y_stride_p] = utils::cast(result); + } else { + y_[i * info->y_stride_p] = result; + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream) const { + + // Handle different input/output dtype combinations + if (_info.x_dtype == INFINI_DTYPE_F16) { + if (_info.y_dtype == INFINI_DTYPE_F16) { + return logsoftmax(&_info, (fp16_t *)y, (const fp16_t *)x); + } else if (_info.y_dtype == INFINI_DTYPE_BF16) { + return logsoftmax(&_info, (bf16_t *)y, (const fp16_t *)x); + } else if (_info.y_dtype == INFINI_DTYPE_F32) { + return logsoftmax(&_info, (float *)y, (const fp16_t *)x); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_info.x_dtype == INFINI_DTYPE_BF16) { + if (_info.y_dtype == INFINI_DTYPE_F16) { + return logsoftmax(&_info, (fp16_t *)y, (const bf16_t *)x); + } else if (_info.y_dtype == INFINI_DTYPE_BF16) { + return logsoftmax(&_info, (bf16_t *)y, (const bf16_t *)x); + } else if (_info.y_dtype == INFINI_DTYPE_F32) { + return logsoftmax(&_info, (float *)y, (const bf16_t *)x); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_info.x_dtype == 
INFINI_DTYPE_F32) { + if (_info.y_dtype == INFINI_DTYPE_F16) { + return logsoftmax(&_info, (fp16_t *)y, (const float *)x); + } else if (_info.y_dtype == INFINI_DTYPE_BF16) { + return logsoftmax(&_info, (bf16_t *)y, (const float *)x); + } else if (_info.y_dtype == INFINI_DTYPE_F32) { + return logsoftmax(&_info, (float *)y, (const float *)x); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::logsoftmax::cpu diff --git a/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.h b/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.h new file mode 100644 index 000000000..371917bad --- /dev/null +++ b/src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.h @@ -0,0 +1,7 @@ +#ifndef __LOGSOFTMAX_CPU_H__ +#define __LOGSOFTMAX_CPU_H__ +#include "../logsoftmax.h" + +DESCRIPTOR(cpu) + +#endif diff --git a/src/infiniop/ops/logsoftmax/cuda/kernel.cuh b/src/infiniop/ops/logsoftmax/cuda/kernel.cuh new file mode 100644 index 000000000..1669b38a9 --- /dev/null +++ b/src/infiniop/ops/logsoftmax/cuda/kernel.cuh @@ -0,0 +1,108 @@ +#ifndef __LOGSOFTMAX_KERNEL_CUH__ +#define __LOGSOFTMAX_KERNEL_CUH__ + +#include +#include + +template +__device__ void logSoftmaxKernel( + Tdata_out *y, const Tdata_in *x, + size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len, + ptrdiff_t y_stride_b, ptrdiff_t y_stride_p, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_p, + ptrdiff_t y_stride_0, ptrdiff_t y_stride_1, + ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) { + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ Tcompute shared_max_val; + __shared__ Tcompute shared_sum_exp; + + int batch_idx = blockIdx.x; + int tid = threadIdx.x; + + if (batch_idx >= batch_size) { + return; + } + + // Calculate correct memory offsets for 3D tensors + ptrdiff_t y_offset, x_offset; + if (ndim == 3) { + // For 3D tensors, convert linear batch index back to 2D indices + ptrdiff_t batch_dim_idx = batch_idx / seq_len; + ptrdiff_t seq_dim_idx = batch_idx % seq_len; + y_offset = batch_dim_idx * y_stride_0 + seq_dim_idx * y_stride_1; + x_offset = batch_dim_idx * x_stride_0 + seq_dim_idx * x_stride_1; + } else { + // For 2D tensors, use the flattened strides + y_offset = batch_idx * y_stride_b; + x_offset = batch_idx * x_stride_b; + } + + const Tdata_in *x_batch = x + x_offset; + Tdata_out *y_batch = y + y_offset; + + // Find maximum value for numerical stability + Tcompute max_val = static_cast(-INFINITY); + for (int i = tid; i < probs_size; i += BLOCK_SIZE) { + if (i < probs_size) { // Add boundary check + Tcompute val = static_cast(x_batch[i * x_stride_p]); + if constexpr (std::is_same_v) { + max_val = fmaxf(max_val, val); + } else { + max_val = fmax(max_val, val); + } + } + } + max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max()); + if (tid == 0) { + shared_max_val = max_val; + } + __syncthreads(); + + // Compute sum of exp(x - max) + Tcompute sum_exp = static_cast(0.0); + for (int i = tid; i < probs_size; i += BLOCK_SIZE) { + if (i < probs_size) { // Add boundary check + Tcompute val = static_cast(x_batch[i * x_stride_p]); + if constexpr (std::is_same_v) { + sum_exp += expf(val - shared_max_val); + } else { + sum_exp += exp(val - shared_max_val); + } + } + } + sum_exp = BlockReduce(temp_storage).Sum(sum_exp); + if (tid == 0) { + shared_sum_exp = sum_exp; + } + __syncthreads(); + + // Compute log_softmax = x - max - log(sum_exp) + Tcompute log_sum_exp; + if constexpr (std::is_same_v) { + log_sum_exp = 
logf(shared_sum_exp); + } else { + log_sum_exp = log(shared_sum_exp); + } + for (int i = tid; i < probs_size; i += BLOCK_SIZE) { + if (i < probs_size) { // Add boundary check + Tcompute val = static_cast(x_batch[i * x_stride_p]); + Tcompute result = val - shared_max_val - log_sum_exp; + y_batch[i * y_stride_p] = static_cast(result); + } + } +} + +template +__global__ void logSoftmax( + Tdata_out *y, const Tdata_in *x, + size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len, + ptrdiff_t y_stride_b, ptrdiff_t y_stride_p, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_p, + ptrdiff_t y_stride_0, ptrdiff_t y_stride_1, + ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) { + logSoftmaxKernel(y, x, batch_size, probs_size, ndim, seq_len, y_stride_b, y_stride_p, x_stride_b, x_stride_p, y_stride_0, y_stride_1, x_stride_0, x_stride_1); +} + +#endif // __LOGSOFTMAX_KERNEL_CUH__ diff --git a/src/infiniop/ops/logsoftmax/info.h b/src/infiniop/ops/logsoftmax/info.h new file mode 100644 index 000000000..10ff7815e --- /dev/null +++ b/src/infiniop/ops/logsoftmax/info.h @@ -0,0 +1,117 @@ +#ifndef __LOGSOFTMAX_INFO_H__ +#define __LOGSOFTMAX_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op::logsoftmax { + +class LogSoftmaxInfo { + LogSoftmaxInfo() = default; + +public: + infiniDtype_t x_dtype; + infiniDtype_t y_dtype; + size_t batch_size; + size_t probs_size; + + // Original tensor dimensions for 3D support + size_t ndim; + size_t seq_len; // Only used for 3D tensors + + // Flattened strides for CPU iteration + ptrdiff_t y_stride_b; + ptrdiff_t y_stride_p; + ptrdiff_t x_stride_b; + ptrdiff_t x_stride_p; + + // Original 3D strides for correct memory access + ptrdiff_t y_stride_0, y_stride_1, y_stride_2; + ptrdiff_t x_stride_0, x_stride_1, x_stride_2; + + static utils::Result create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) { + auto x_dtype = x_desc->dtype(); + auto y_dtype = y_desc->dtype(); + + CHECK_DTYPE(x_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + // Check the output data type, and any dtype is allowed to output fp32. 
+ CHECK_DTYPE(y_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + auto x_shape = x_desc->shape(); + auto y_shape = y_desc->shape(); + CHECK_SAME_SHAPE(x_shape, y_shape); + + auto ndim = x_desc->ndim(); + if (ndim < 2 || ndim > 3) { + CHECK_STATUS(INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + size_t batch_size, probs_size, seq_len = 0; + if (ndim == 2) { + batch_size = x_shape[0]; + probs_size = x_shape[1]; + } else { // ndim == 3 + batch_size = x_shape[0] * x_shape[1]; + probs_size = x_shape[2]; + seq_len = x_shape[1]; + } + + // Store original strides for all dimensions + ptrdiff_t y_stride_0 = 0, y_stride_1 = 0, y_stride_2 = 0; + ptrdiff_t x_stride_0 = 0, x_stride_1 = 0, x_stride_2 = 0; + + if (ndim == 2) { + y_stride_0 = y_desc->stride(0); // First dimension + y_stride_1 = y_desc->stride(1); // Second dimension + x_stride_0 = x_desc->stride(0); + x_stride_1 = x_desc->stride(1); + } else if (ndim == 3) { + y_stride_0 = y_desc->stride(0); // First dimension (batch) + y_stride_1 = y_desc->stride(1); // Second dimension (seq) + y_stride_2 = y_desc->stride(2); // Third dimension (prob) + x_stride_0 = x_desc->stride(0); + x_stride_1 = x_desc->stride(1); + x_stride_2 = x_desc->stride(2); + } + + ptrdiff_t y_stride_b, y_stride_p, x_stride_b, x_stride_p; + if (ndim == 2) { + y_stride_b = y_desc->stride(0); + y_stride_p = y_desc->stride(1); + x_stride_b = x_desc->stride(0); + x_stride_p = x_desc->stride(1); + } else { // ndim == 3 + // For 3D tensors, flat the first two dimensions + // The CPU implementation expects to iterate through batch_size elements + // where each batch contains probs_size elements + // For flattened iteration, we need stride between consecutive sequences + y_stride_b = y_desc->stride(1); // stride between sequences (20*512 -> 512) + y_stride_p = y_desc->stride(2); // stride within probability dimension + x_stride_b = x_desc->stride(1); // stride between sequences + x_stride_p = x_desc->stride(2); // stride within probability dimension + } + + return utils::Result(LogSoftmaxInfo{ + x_dtype, + y_dtype, + batch_size, + probs_size, + ndim, + seq_len, + y_stride_b, + y_stride_p, + x_stride_b, + x_stride_p, + y_stride_0, + y_stride_1, + y_stride_2, + x_stride_0, + x_stride_1, + x_stride_2}); + } +}; + +} // namespace op::logsoftmax + +#endif // __LOGSOFTMAX_INFO_H__ diff --git a/src/infiniop/ops/logsoftmax/logsoftmax.h b/src/infiniop/ops/logsoftmax/logsoftmax.h new file mode 100644 index 000000000..8babdeab7 --- /dev/null +++ b/src/infiniop/ops/logsoftmax/logsoftmax.h @@ -0,0 +1,46 @@ +#ifndef LOGSOFTMAX_H +#define LOGSOFTMAX_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::logsoftmax::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + LogSoftmaxInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + LogSoftmaxInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t x_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *y, \ + const void *x, \ + void *stream) const; 
\ + }; \ + } + +#endif // LOGSOFTMAX_H diff --git a/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu b/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu new file mode 100644 index 000000000..1235b2aaf --- /dev/null +++ b/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu @@ -0,0 +1,131 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "logsoftmax_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include + +#include "../cuda/kernel.cuh" + +namespace op::logsoftmax::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + auto info = LogSoftmaxInfo::create(y_desc, x_desc); + CHECK_RESULT(info); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t x_dtype, infiniDtype_t y_dtype, + size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len, + ptrdiff_t y_stride_b, ptrdiff_t y_stride_p, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_p, + ptrdiff_t y_stride_0, ptrdiff_t y_stride_1, + ptrdiff_t x_stride_0, ptrdiff_t x_stride_1, + cudaStream_t stream) { + dim3 grid(uint32_t(batch_size), 1, 1); + + // Handle mixed precision cases + if (x_dtype == INFINI_DTYPE_F16 && y_dtype == INFINI_DTYPE_F32) { + logSoftmax + <<>>((float *)y, (const half *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_F16) { + logSoftmax + <<>>((half *)y, (const float *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_BF16 && y_dtype == INFINI_DTYPE_F32) { + logSoftmax + <<>>((float *)y, (const __nv_bfloat16 *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_BF16) { + logSoftmax + <<>>((__nv_bfloat16 *)y, (const float *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_F16 && y_dtype == INFINI_DTYPE_F16) { + logSoftmax + <<>>((half *)y, (const half *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_BF16 && y_dtype == INFINI_DTYPE_BF16) { + logSoftmax + <<>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_F32) { + logSoftmax + <<>>((float *)y, (const float *)x, + batch_size, probs_size, ndim, seq_len, + y_stride_b, y_stride_p, + x_stride_b, x_stride_p, + y_stride_0, y_stride_1, + x_stride_0, x_stride_1); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + 
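// --- Editor's note (illustrative sketch, not part of this patch) ------------
// launchKernel above enumerates every supported (x_dtype, y_dtype) pair and
// launches the templated logSoftmax kernel with one CUDA block per row
// (grid.x == batch_size); inside the kernel, cub::BlockReduce produces the
// row maximum and the exp-sum. The per-row math is the standard numerically
// stable three-pass log-softmax:
//
//     m    = max_i x[i]
//     s    = sum_i exp(x[i] - m)
//     y[i] = x[i] - m - log(s)
//
// A minimal single-row host reference of the same computation (contiguous
// data assumed, names are illustrative) can be used to spot-check one row of
// kernel output:
//
//     void logSoftmaxRowRef(const float *x, float *y, size_t n) {
//         float m = x[0];
//         for (size_t i = 1; i < n; ++i) m = std::fmax(m, x[i]);
//         float s = 0.0f;
//         for (size_t i = 0; i < n; ++i) s += std::exp(x[i] - m);
//         const float log_s = std::log(s);
//         for (size_t i = 0; i < n; ++i) y[i] = x[i] - m - log_s;
//     }
// ----------------------------------------------------------------------------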
+infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len, + _info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p, + _info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len, + _info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p, + _info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len, + _info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p, + _info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::logsoftmax::nvidia diff --git a/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cuh b/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cuh new file mode 100644 index 000000000..803143ba7 --- /dev/null +++ b/src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __LOGSOFTMAX_NVIDIA_H__ +#define __LOGSOFTMAX_NVIDIA_H__ + +#include "../logsoftmax.h" + +DESCRIPTOR(nvidia) + +#endif diff --git a/src/infiniop/ops/logsoftmax/operator.cc b/src/infiniop/ops/logsoftmax/operator.cc new file mode 100644 index 000000000..ffb78135f --- /dev/null +++ b/src/infiniop/ops/logsoftmax/operator.cc @@ -0,0 +1,136 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/logsoftmax.h" + +#ifdef ENABLE_CPU_API +#include "cpu/logsoftmax_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/logsoftmax_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +// #include "metax/logsoftmax_metax.h" +#endif +#ifdef ENABLE_ASCEND_API +// #include "ascend/logsoftmax_ascend.h" +#endif + +__C infiniStatus_t infiniopCreateLogSoftmaxDescriptor( + infiniopHandle_t handle, + infiniopLogSoftmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logsoftmax::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + x_desc); + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + // CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + // CREATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ASCEND_API + // CREATE(INFINI_DEVICE_ASCEND, ascend) +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch 
(desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + // GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + // GET(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ASCEND_API + // GET(INFINI_DEVICE_ASCEND, ascend) +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopLogSoftmax( + infiniopLogSoftmaxDescriptor_t desc, + void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, y, x, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + // CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + // CALCULATE(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ASCEND_API + // CALCULATE(INFINI_DEVICE_ASCEND, ascend) +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + // DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + // DESTROY(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_ASCEND_API + // DESTROY(INFINI_DEVICE_ASCEND, ascend) +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index ba1ce33df..86cc8966a 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -162,6 +162,38 @@ def clip_(lib): ] +@OpRegister.operator +def logsoftmax_(lib): + lib.infiniopCreateLogSoftmaxDescriptor.restype = c_int32 + lib.infiniopCreateLogSoftmaxDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetLogSoftmaxWorkspaceSize.restype = c_int32 + lib.infiniopGetLogSoftmaxWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopLogSoftmax.restype = c_int32 + lib.infiniopLogSoftmax.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyLogSoftmaxDescriptor.restype = c_int32 + lib.infiniopDestroyLogSoftmaxDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def conv_(lib): pass diff --git a/test/infiniop/logsoftmax.py b/test/infiniop/logsoftmax.py new file mode 100644 index 000000000..ab7dd5ab1 --- /dev/null +++ b/test/infiniop/logsoftmax.py @@ -0,0 +1,324 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# 
============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, x_stride, y_stride + ((3, 3), None, None), + ((32, 512), None, None), + ((32, 512), (1024, 1), (1024, 1)), + ((32, 5, 5), None, None), + ((32, 20, 512), None, None), + ((32, 20, 512), (20480, 512, 1), None), + ((28, 15, 15), None, None), + ((1, 1000), None, None), + ((16, 50257), None, None), + ((4, 8, 256), None, None), + ((2, 16, 1024), None, None), +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +# Mixed precision test cases - support y_dtype == x_dtype or y_dtype == F32 +_MIXED_PRECISION_CASES = [ + (InfiniDtype.F16, InfiniDtype.F32), + (InfiniDtype.BF16, InfiniDtype.F32), + (InfiniDtype.F16, InfiniDtype.F16), + (InfiniDtype.BF16, InfiniDtype.BF16), + (InfiniDtype.F32, InfiniDtype.F32), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +_INPLACE = [ + Inplace.INPLACE_X, + Inplace.OUT_OF_PLACE, +] + +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def logsoftmax(x): + """PyTorch reference implementation of log_softmax""" + return torch.nn.functional.log_softmax(x.to(torch.float32), dim=-1) + + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing LogSoftmax on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + x = TestTensor(shape, x_stride, dtype, device) + ans = logsoftmax(x.actual_tensor()) + + # Convert answer to match input dtype for default behavior + if dtype == InfiniDtype.F16: + ans = ans.to(torch.float16) + elif dtype == InfiniDtype.BF16: + ans = ans.to(torch.bfloat16) + elif dtype == InfiniDtype.F32: + ans = ans.to(torch.float32) + + if inplace == Inplace.INPLACE_X: + y = x + else: + y = TestTensor(shape, y_stride, dtype, device) # Default: same dtype as input + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + status = LIBINFINIOP.infiniopCreateLogSoftmaxDescriptor( + handle, ctypes.byref(descriptor), y.descriptor, x.descriptor + ) + check_error(status) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + x.destroy_desc() + y.destroy_desc() + + workspace_size = c_uint64(0) + status = LIBINFINIOP.infiniopGetLogSoftmaxWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + check_error(status) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_logsoftmax(): + check_error( + LIBINFINIOP.infiniopLogSoftmax( + descriptor, + workspace.data(), + workspace_size.value, + y.data(), + x.data(), + None, + ) + ) + + lib_logsoftmax() + + if sync is not None: + sync() + + # Use tolerance based on input dtype for numerical stability + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + # Always print debug info for failed cases + actual = 
y.actual_tensor() + max_diff = torch.max(torch.abs(actual - ans)) + is_close = torch.allclose(actual, ans, atol=atol, rtol=rtol) + + if DEBUG or not is_close: + print(f"\n=== Debug Info ===") + print(f"Shape: {shape}, Stride: {x_stride}, Dtype: {dtype}") + print(f"Input tensor: {x.torch_tensor()}") + print(f"Expected output: {ans}") + print(f"Actual output: {actual}") + print(f"Max diff: {max_diff}") + print(f"Tolerance: atol={atol}, rtol={rtol}") + print(f"Is close: {is_close}") + print(f"First few values - Actual: {actual.flatten()[:5]}") + print(f"First few values - Expected: {ans.flatten()[:5]}") + if DEBUG: + debug(actual, ans, atol=atol, rtol=rtol) + + assert is_close + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: logsoftmax(x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_logsoftmax(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyLogSoftmaxDescriptor(descriptor)) + + +def test_mixed_precision( + handle, + device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + x_dtype=InfiniDtype.F16, + y_dtype=InfiniDtype.F32, + sync=None, +): + print( + f"Testing LogSoftmax (Mixed) on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} x_dtype:{InfiniDtypeNames[x_dtype]} y_dtype:{InfiniDtypeNames[y_dtype]} inplace:{inplace}" + ) + + x = TestTensor(shape, x_stride, x_dtype, device) + ans = logsoftmax(x.actual_tensor()) + + # Convert answer to target dtype for comparison + if y_dtype == InfiniDtype.F16: + ans = ans.to(torch.float16) + elif y_dtype == InfiniDtype.BF16: + ans = ans.to(torch.bfloat16) + elif y_dtype == InfiniDtype.F32: + ans = ans.to(torch.float32) + + if inplace == Inplace.INPLACE_X: + # For inplace operations, input and output must have the same dtype + if x_dtype != y_dtype: + print( + f"Skipping inplace test: x_dtype ({InfiniDtypeNames[x_dtype]}) != y_dtype ({InfiniDtypeNames[y_dtype]})" + ) + return + y = x + else: + y = TestTensor(shape, y_stride, y_dtype, device) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateLogSoftmaxDescriptor( + handle, ctypes.byref(descriptor), y.descriptor, x.descriptor + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + x.destroy_desc() + y.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetLogSoftmaxWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, x.device) + + def lib_logsoftmax(): + check_error( + LIBINFINIOP.infiniopLogSoftmax( + descriptor, + workspace.data(), + workspace_size.value, + y.data(), + x.data(), + None, + ) + ) + + lib_logsoftmax() + + if sync is not None: + sync() + + # Use tolerance based on output dtype for mixed precision cases + atol, rtol = get_tolerance(_TOLERANCE_MAP, y_dtype) + + # Ensure both tensors have the same dtype for comparison + y_tensor = y.actual_tensor() + if y_tensor.dtype != ans.dtype: + y_tensor = y_tensor.to(ans.dtype) + + if DEBUG: + debug(y_tensor, ans, atol=atol, rtol=rtol) + assert torch.allclose(y_tensor, ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: logsoftmax(x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_logsoftmax(), device, NUM_PRERUN, 
NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyLogSoftmaxDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + # Test standard cases (fp32 output) + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + # Test mixed precision cases + from libinfiniop import create_handle, destroy_handle, get_sync_func + + handle = create_handle() + sync = get_sync_func(device) + try: + for x_dtype, y_dtype in _MIXED_PRECISION_CASES: + for shape, x_stride, y_stride, inplace in _TEST_CASES[ + :5 + ]: # Test subset for mixed precision + test_mixed_precision( + handle, + device, + shape, + x_stride, + y_stride, + inplace, + x_dtype, + y_dtype, + sync, + ) + finally: + destroy_handle(handle) + + print("\033[92mTest passed!\033[0m")
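Editor's note: the sketch below is an illustrative C-level usage sequence for the public API added in include/infiniop/ops/logsoftmax.h; it is not part of the patch. It assumes the umbrella header infiniop.h is on the include path, that a valid infiniopHandle_t, the x/y tensor descriptors, device buffers, and a stream have already been created through the library's existing facilities, and that allocateWorkspace/freeWorkspace are placeholder callbacks for whatever backend-specific allocator the caller uses.

    #include <infiniop.h>
    #include <stddef.h>

    // Illustrative only: handle, descriptors, device buffers, stream and the
    // workspace allocator are assumed to exist; error handling is reduced to
    // early returns.
    infiniStatus_t runLogSoftmax(infiniopHandle_t handle,
                                 infiniopTensorDescriptor_t y_desc,
                                 infiniopTensorDescriptor_t x_desc,
                                 void *d_y, const void *d_x,
                                 void *stream,
                                 void *(*allocateWorkspace)(size_t),
                                 void (*freeWorkspace)(void *)) {
        infiniopLogSoftmaxDescriptor_t desc = NULL;
        infiniStatus_t status = infiniopCreateLogSoftmaxDescriptor(handle, &desc, y_desc, x_desc);
        if (status != INFINI_STATUS_SUCCESS) return status;

        size_t workspace_size = 0;
        status = infiniopGetLogSoftmaxWorkspaceSize(desc, &workspace_size);
        if (status != INFINI_STATUS_SUCCESS) {
            infiniopDestroyLogSoftmaxDescriptor(desc);
            return status;
        }

        void *workspace = workspace_size ? allocateWorkspace(workspace_size) : NULL;
        status = infiniopLogSoftmax(desc, workspace, workspace_size, d_y, d_x, stream);

        if (workspace) freeWorkspace(workspace);
        infiniopDestroyLogSoftmaxDescriptor(desc);
        return status;
    }

The CPU and NVIDIA descriptors in this patch are created with a workspace size of 0, so the allocation branch is currently a no-op, but querying the size keeps callers forward-compatible with backends that may need scratch memory.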