Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/gemm.h"
Expand Down
24 changes: 24 additions & 0 deletions include/infiniop/ops/logsoftmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_LOGSOFTMAX_API_H__
#define __INFINIOP_LOGSOFTMAX_API_H__

#include "../operator_descriptor.h"

/* Handle type for LogSoftmax operator descriptors: an alias for a pointer to
 * struct InfiniopDescriptor (the struct itself is declared outside this header). */
typedef struct InfiniopDescriptor *infiniopLogSoftmaxDescriptor_t;

/**
 * Creates a LogSoftmax operator descriptor for the given output/input tensors.
 *
 * @param handle   Library handle the descriptor is created for.
 * @param desc_ptr Location that receives the newly created descriptor.
 * @param y_desc   Tensor descriptor of the output tensor y.
 * @param x_desc   Tensor descriptor of the input tensor x.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): which shapes / dtypes are accepted is decided by the backend
 * implementation, not by this header — confirm against the backend in use.
 */
__C __export infiniStatus_t infiniopCreateLogSoftmaxDescriptor(infiniopHandle_t handle,
infiniopLogSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc);

/**
 * Queries the workspace size associated with a LogSoftmax descriptor.
 *
 * @param desc Descriptor to query.
 * @param size Location that receives the workspace size
 *             (presumably in bytes — TODO confirm against the implementation).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size);

/**
 * Runs the LogSoftmax operator described by `desc`, reading x and writing y.
 *
 * @param desc           Descriptor describing the operation.
 * @param workspace      Scratch buffer handed to the backend.
 * @param workspace_size Size of `workspace`.
 * @param y              Output buffer.
 * @param x              Input buffer (read-only).
 * @param stream         Backend-specific execution stream handle.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): presumably `workspace_size` must be at least the value reported
 * by infiniopGetLogSoftmaxWorkspaceSize — confirm against the dispatcher.
 */
__C __export infiniStatus_t infiniopLogSoftmax(infiniopLogSoftmaxDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/**
 * Destroys a LogSoftmax descriptor.
 *
 * @param desc Descriptor to destroy.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc);

#endif
1 change: 1 addition & 0 deletions scripts/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_tests(args):
"causal_softmax.py",
"clip.py",
"gemm.py",
"logsoftmax.py",
"mul.py",
"random_sample.py",
"rearrange.py",
Expand Down
130 changes: 130 additions & 0 deletions src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "logsoftmax_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
#include <algorithm>
#include <cmath>

namespace op::logsoftmax::cpu {

// The destructor performs no explicit cleanup of its own.
Descriptor::~Descriptor() = default;

/// Builds a CPU LogSoftmax descriptor from the output / input tensor descriptors.
///
/// The tensor descriptors are handed to LogSoftmaxInfo::create; its result is
/// passed through CHECK_RESULT (macro defined outside this file — presumably it
/// returns the error status on failure, TODO confirm) before being consumed.
///
/// @param handle   Supplies the device and device_id stored in the descriptor.
/// @param desc_ptr Receives the heap-allocated descriptor.
/// @param y_desc   Output tensor descriptor.
/// @param x_desc   Input tensor descriptor.
/// @return INFINI_STATUS_SUCCESS once the descriptor has been stored in *desc_ptr.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto info = LogSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(info);

    Descriptor *created = new Descriptor(nullptr, info.take(), 0, handle->device, handle->device_id);
    *desc_ptr = created;
    return INFINI_STATUS_SUCCESS;
}

/// Computes log-softmax over the probability axis of every row, on the CPU.
///
/// For each row:  y[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x))).
/// The row maximum is subtracted before exponentiating for numerical stability.
/// All intermediate arithmetic is carried out in float, whatever Tx / Ty are.
///
/// @tparam Tx  Input element type; fp16_t and bf16_t are widened with utils::cast.
/// @tparam Ty  Output element type; fp16_t and bf16_t are narrowed with utils::cast.
/// @param info Layout: batch_size rows of probs_size elements, the element
///             strides x_stride_p / y_stride_p, and the row strides
///             (x/y_stride_0 and x/y_stride_1 when ndim == 3, x/y_stride_b otherwise).
/// @param y    Base pointer of the output.
/// @param x    Base pointer of the input.
/// @return Always INFINI_STATUS_SUCCESS.
template <typename Tx, typename Ty>
infiniStatus_t logsoftmax(const LogSoftmaxInfo *info, Ty *y, const Tx *x) {
    // An empty probability axis leaves nothing to write, so return before
    // asking the reduction helper for the maximum of zero elements.
    if (info->probs_size == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    // Signed copies: index * stride must be evaluated in signed arithmetic so a
    // negative stride never goes through unsigned wrap-around.
    const ptrdiff_t num_rows = static_cast<ptrdiff_t>(info->batch_size);
    const ptrdiff_t num_probs = static_cast<ptrdiff_t>(info->probs_size);
    const ptrdiff_t x_step = info->x_stride_p;
    const ptrdiff_t y_step = info->y_stride_p;

#pragma omp parallel for
    for (ptrdiff_t batch = 0; batch < num_rows; batch++) {
        ptrdiff_t y_offset, x_offset;

        if (info->ndim == 3) {
            // For 3D tensors, split the linear row index back into its two leading indices.
            const ptrdiff_t seq_len = static_cast<ptrdiff_t>(info->seq_len);
            const ptrdiff_t batch_idx = batch / seq_len;
            const ptrdiff_t seq_idx = batch % seq_len;
            y_offset = batch_idx * info->y_stride_0 + seq_idx * info->y_stride_1;
            x_offset = batch_idx * info->x_stride_0 + seq_idx * info->x_stride_1;
        } else {
            // Otherwise rows are addressed through the flattened row strides.
            y_offset = batch * info->y_stride_b;
            x_offset = batch * info->x_stride_b;
        }

        Ty *y_row = y + y_offset;
        const Tx *x_row = x + x_offset;

        // Reads element i of the current input row as float.
        const auto load = [&](ptrdiff_t i) -> float {
            if constexpr (std::is_same_v<Tx, fp16_t> || std::is_same_v<Tx, bf16_t>) {
                return utils::cast<float>(x_row[i * x_step]);
            } else {
                return x_row[i * x_step];
            }
        };

        // Row maximum, used to keep exp() in range.
        const float max_val = op::common_cpu::reduce_op::max(x_row, info->probs_size, info->x_stride_p);

        // sum_j exp(x[j] - max)
        float sum = 0.0f;
        for (ptrdiff_t i = 0; i < num_probs; i++) {
            sum += std::exp(load(i) - max_val);
        }

        const float log_sum = std::log(sum);

        // log_softmax = x - max - log(sum)
        for (ptrdiff_t i = 0; i < num_probs; i++) {
            const float result = load(i) - max_val - log_sum;

            if constexpr (std::is_same_v<Ty, fp16_t> || std::is_same_v<Ty, bf16_t>) {
                y_row[i * y_step] = utils::cast<Ty>(result);
            } else {
                y_row[i * y_step] = result;
            }
        }
    }

    return INFINI_STATUS_SUCCESS;
}

/// Runs LogSoftmax on the CPU by dispatching on the (input dtype, output dtype) pair.
///
/// Every combination of F16 / BF16 / F32 for x and y is forwarded to the matching
/// logsoftmax<Tx, Ty> instantiation; any other dtype on either side yields
/// INFINI_STATUS_BAD_TENSOR_DTYPE. `workspace`, `workspace_size` and `stream`
/// are not used by this backend.
///
/// @param y Output buffer, interpreted according to _info.y_dtype.
/// @param x Input buffer, interpreted according to _info.x_dtype.
/// @return The status returned by logsoftmax, or INFINI_STATUS_BAD_TENSOR_DTYPE.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {

    switch (_info.x_dtype) {
    case INFINI_DTYPE_F16: {
        const fp16_t *src = static_cast<const fp16_t *>(x);
        switch (_info.y_dtype) {
        case INFINI_DTYPE_F16:
            return logsoftmax<fp16_t, fp16_t>(&_info, static_cast<fp16_t *>(y), src);
        case INFINI_DTYPE_BF16:
            return logsoftmax<fp16_t, bf16_t>(&_info, static_cast<bf16_t *>(y), src);
        case INFINI_DTYPE_F32:
            return logsoftmax<fp16_t, float>(&_info, static_cast<float *>(y), src);
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
    case INFINI_DTYPE_BF16: {
        const bf16_t *src = static_cast<const bf16_t *>(x);
        switch (_info.y_dtype) {
        case INFINI_DTYPE_F16:
            return logsoftmax<bf16_t, fp16_t>(&_info, static_cast<fp16_t *>(y), src);
        case INFINI_DTYPE_BF16:
            return logsoftmax<bf16_t, bf16_t>(&_info, static_cast<bf16_t *>(y), src);
        case INFINI_DTYPE_F32:
            return logsoftmax<bf16_t, float>(&_info, static_cast<float *>(y), src);
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
    case INFINI_DTYPE_F32: {
        const float *src = static_cast<const float *>(x);
        switch (_info.y_dtype) {
        case INFINI_DTYPE_F16:
            return logsoftmax<float, fp16_t>(&_info, static_cast<fp16_t *>(y), src);
        case INFINI_DTYPE_BF16:
            return logsoftmax<float, bf16_t>(&_info, static_cast<bf16_t *>(y), src);
        case INFINI_DTYPE_F32:
            return logsoftmax<float, float>(&_info, static_cast<float *>(y), src);
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    }
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::logsoftmax::cpu
7 changes: 7 additions & 0 deletions src/infiniop/ops/logsoftmax/cpu/logsoftmax_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef __LOGSOFTMAX_CPU_H__
#define __LOGSOFTMAX_CPU_H__
#include "../logsoftmax.h"

// Expands the DESCRIPTOR macro for the `cpu` backend. The macro is defined
// outside this header (presumably in ../logsoftmax.h, and presumably declaring
// op::logsoftmax::cpu::Descriptor — TODO confirm).
DESCRIPTOR(cpu)

#endif
108 changes: 108 additions & 0 deletions src/infiniop/ops/logsoftmax/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#ifndef __LOGSOFTMAX_KERNEL_CUH__
#define __LOGSOFTMAX_KERNEL_CUH__

#include <cub/block/block_reduce.cuh>
#include <type_traits>

/// Device-side log-softmax for one row per thread block.
///
/// The block with index blockIdx.x handles row blockIdx.x and computes
///   y[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x)))
/// in three passes over the row (max, sum of exponentials, final write), each
/// pass stepping by BLOCK_SIZE starting at threadIdx.x. The two reductions use
/// cub::BlockReduce<Tcompute, BLOCK_SIZE>, so the kernel expects BLOCK_SIZE
/// threads per block (the launch site is outside this file).
///
/// @tparam BLOCK_SIZE Threads per block used for striding and for the reduction.
/// @tparam Tdata_out  Output element type.
/// @tparam Tdata_in   Input element type.
/// @tparam Tcompute   Type used for all intermediate arithmetic.
/// @param y, x              Output / input base pointers.
/// @param batch_size        Number of rows.
/// @param probs_size        Number of elements per row.
/// @param ndim, seq_len     When ndim == 3 the row index is split as
///                          (row / seq_len, row % seq_len) and addressed with
///                          the *_stride_0 / *_stride_1 pairs.
/// @param y_stride_b, x_stride_b  Row strides used when ndim != 3.
/// @param y_stride_p, x_stride_p  Element strides within a row.
template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__device__ void logSoftmaxKernel(
    Tdata_out *y, const Tdata_in *x,
    size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
    ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
    ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
    ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
    ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {

    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ Tcompute shared_max_val;
    __shared__ Tcompute shared_sum_exp;

    // Unsigned, size_t-wide indices: they are compared against size_t bounds and
    // must not overflow for rows longer than INT_MAX elements.
    const size_t row = blockIdx.x;
    const size_t tid = threadIdx.x;

    // The condition is uniform across the block, so every thread leaves together
    // and no thread is left waiting at a later barrier.
    if (row >= batch_size) {
        return;
    }

    // Locate the start of this row in x and y.
    ptrdiff_t y_offset, x_offset;
    if (ndim == 3) {
        // For 3D tensors, split the linear row index back into its two leading indices.
        const ptrdiff_t batch_dim_idx = static_cast<ptrdiff_t>(row / seq_len);
        const ptrdiff_t seq_dim_idx = static_cast<ptrdiff_t>(row % seq_len);
        y_offset = batch_dim_idx * y_stride_0 + seq_dim_idx * y_stride_1;
        x_offset = batch_dim_idx * x_stride_0 + seq_dim_idx * x_stride_1;
    } else {
        // Otherwise rows are addressed through the flattened row strides.
        y_offset = static_cast<ptrdiff_t>(row) * y_stride_b;
        x_offset = static_cast<ptrdiff_t>(row) * x_stride_b;
    }

    const Tdata_in *x_batch = x + x_offset;
    Tdata_out *y_batch = y + y_offset;

    // Pass 1: row maximum, for numerical stability.
    Tcompute max_val = static_cast<Tcompute>(-INFINITY);
    for (size_t i = tid; i < probs_size; i += BLOCK_SIZE) {
        // Signed multiply so negative element strides are handled correctly.
        const Tcompute val = static_cast<Tcompute>(x_batch[static_cast<ptrdiff_t>(i) * x_stride_p]);
        if constexpr (std::is_same_v<Tcompute, float>) {
            max_val = fmaxf(max_val, val);
        } else {
            max_val = fmax(max_val, val);
        }
    }
    max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
    // The reduced value is taken from thread 0 and broadcast through shared memory.
    if (tid == 0) {
        shared_max_val = max_val;
    }
    // Publishes shared_max_val; this barrier also separates the two uses of temp_storage.
    __syncthreads();

    // Pass 2: sum of exp(x - max).
    Tcompute sum_exp = static_cast<Tcompute>(0.0);
    for (size_t i = tid; i < probs_size; i += BLOCK_SIZE) {
        const Tcompute val = static_cast<Tcompute>(x_batch[static_cast<ptrdiff_t>(i) * x_stride_p]);
        if constexpr (std::is_same_v<Tcompute, float>) {
            sum_exp += expf(val - shared_max_val);
        } else {
            sum_exp += exp(val - shared_max_val);
        }
    }
    sum_exp = BlockReduce(temp_storage).Sum(sum_exp);
    if (tid == 0) {
        shared_sum_exp = sum_exp;
    }
    __syncthreads();

    // Pass 3: log_softmax = x - max - log(sum_exp).
    Tcompute log_sum_exp;
    if constexpr (std::is_same_v<Tcompute, float>) {
        log_sum_exp = logf(shared_sum_exp);
    } else {
        log_sum_exp = log(shared_sum_exp);
    }
    for (size_t i = tid; i < probs_size; i += BLOCK_SIZE) {
        const ptrdiff_t idx = static_cast<ptrdiff_t>(i);
        const Tcompute val = static_cast<Tcompute>(x_batch[idx * x_stride_p]);
        const Tcompute result = val - shared_max_val - log_sum_exp;
        y_batch[idx * y_stride_p] = static_cast<Tdata_out>(result);
    }
}

/// Kernel entry point: forwards every argument unchanged to logSoftmaxKernel
/// instantiated with the same template parameters.
template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__global__ void logSoftmax(
    Tdata_out *y, const Tdata_in *x,
    size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
    ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
    ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
    ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
    ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {
    logSoftmaxKernel<BLOCK_SIZE, Tdata_out, Tdata_in, Tcompute>(
        y, x,
        batch_size, probs_size, ndim, seq_len,
        y_stride_b, y_stride_p,
        x_stride_b, x_stride_p,
        y_stride_0, y_stride_1,
        x_stride_0, x_stride_1);
}

#endif // __LOGSOFTMAX_KERNEL_CUH__
Loading