Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions lib/kernels/include/kernels/embedding_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,25 @@
#include "kernels/device.h"

namespace FlexFlow {

class EmbeddingPerDeviceState : public PerDeviceOpState {
public:
EmbeddingPerDeviceState(FFHandler handle);
DataType input_data_type, output_data_type;
AggrMode aggr;
};

namespace Kernels {
namespace Embedding {
void forward_kernel(ffStream_t stream,
EmbeddingPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &weight,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size);
void backward_kernel(ffStream_t stream,
EmbeddingPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &weight_grad,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size);
Expand Down
44 changes: 24 additions & 20 deletions lib/kernels/src/cuda/embedding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace Embedding {
template <DataType TI, DataType TD>
struct ForwardKernel {
void operator()(cudaStream_t stream,
EmbeddingPerDeviceState const *m,
AggrMode aggr,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &weight,
Expand All @@ -35,8 +35,8 @@ struct ForwardKernel {
assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT ||
weight.data_type == DT_DOUBLE);

if (m->aggr == AGGR_MODE_NONE) {
embed_forward_no_aggr<TI, TD><<<GET_BLOCKS(output.domain.get_volume()),
if (aggr == AGGR_MODE_NONE) {
embed_forward_no_aggr<TI, TD><<<GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream>>>(input.get<TI>(),
Expand All @@ -45,8 +45,8 @@ struct ForwardKernel {
out_dim,
batch_size);
} else {
assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM);
embed_forward_with_aggr<TI, TD><<<GET_BLOCKS(output.domain.get_volume()),
assert(aggr == AGGR_MODE_AVG || aggr == AGGR_MODE_SUM);
embed_forward_with_aggr<TI, TD><<<GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream>>>(input.get<TI>(),
Expand All @@ -55,15 +55,15 @@ struct ForwardKernel {
out_dim,
in_dim,
batch_size,
m->aggr);
aggr);
}
}
}

template <DataType TI, DataType TD>
struct BackwardKernel {
void operator()(cudaStream_t stream,
EmbeddingPerDeviceState const *m,
AggrMode aggr,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &weight_grad,
Expand All @@ -73,8 +73,8 @@ struct BackwardKernel {
assert(input.data_type == DT_INT32 || input.data_type == DT_INT64);
assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT,
|| output.data_type == DT_DOUBLE);
if (m->aggr == AGGR_MODE_NONE) {
embed_backward_no_aggr<TI, TD><<<GET_BLOCKS(output.domain.get_volume()),
if (aggr == AGGR_MODE_NONE) {
embed_backward_no_aggr<TI, TD><<<GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream>>>(input.get<TI>(),
Expand All @@ -83,7 +83,7 @@ struct BackwardKernel {
out_dim,
batch_size);
} else {
embed_backward_with_aggr<TI, TD><<<GET_BLOCKS(output.domain.get_volume()),
embed_backward_with_aggr<TI, TD><<<GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream>>>(input.get<TI>(),
Expand All @@ -92,23 +92,25 @@ struct BackwardKernel {
out_dim,
in_dim,
batch_size,
m->aggr);
aggr);
}
}
}

void forward_kernel(cudaStream_t stream,
EmbeddingPerDeviceState const *m,
void forward_kernel(ffStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &weight,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size) {
DataTypeDispatch2<ForwardKernel>{}(m->input_data_type,
m->output_data_type,
DataTypeDispatch2<ForwardKernel>{}(input_data_type,
output_data_type,
stream,
m,
aggr,
input,
output,
weight,
Expand All @@ -118,17 +120,19 @@ void forward_kernel(cudaStream_t stream,
}

void backward_kernel(cudaStream_t stream,
EmbeddingPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &weight_grad,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size) {
DataTypeDispatch2<BackwardKernel>{}(m->input_data_type,
m->output_data_type,
DataTypeDispatch2<BackwardKernel>{}(input_data_type,
output_data_type,
stream,
m,
aggr,
input,
output,
weight,
Expand Down
40 changes: 22 additions & 18 deletions lib/kernels/src/hip/embedding_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace Embedding {
template <DataType TI, DataType TD>
struct ForwardKernel {
void operator()(hipStream_t stream,
EmbeddingPerDeviceState const *m,
AggrMode aggr,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &weight,
Expand All @@ -36,9 +36,9 @@ struct ForwardKernel {
assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT ||
weight.data_type == DT_DOUBLE);

if (m->aggr == AGGR_MODE_NONE) {
if (aggr == AGGR_MODE_NONE) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr<TI, TD>),
GET_BLOCKS(output.domain.get_volume()),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream,
Expand All @@ -49,7 +49,7 @@ struct ForwardKernel {
batch_size);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr<TI, TD>),
GET_BLOCKS(output.domain.get_volume()),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream,
Expand All @@ -59,15 +59,15 @@ struct ForwardKernel {
out_dim,
in_dim,
batch_size,
m->aggr);
aggr);
}
}
}

template <DataType TI, DataType TD>
struct BackwardKernel {
void operator()(hipStream_t stream,
EmbeddingPerDeviceState const *m,
AggrMode aggr,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &weight_grad,
Expand All @@ -77,9 +77,9 @@ struct BackwardKernel {
assert(input.data_type == DT_INT32 || input.data_type == DT_INT64);
assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT,
|| output.data_type == DT_DOUBLE);
if (m->aggr == AGGR_MODE_NONE) {
if (aggr == AGGR_MODE_NONE) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr<TI, TD>),
GET_BLOCKS(output.domain.get_volume()),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream,
Expand All @@ -90,7 +90,7 @@ struct BackwardKernel {
batch_size);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr<TI, TD>),
GET_BLOCKS(output.domain.get_volume()),
GET_BLOCKS(output.shape.get_volume()),
CUDA_NUM_THREADS,
0,
stream,
Expand All @@ -100,23 +100,25 @@ struct BackwardKernel {
out_dim,
in_dim,
batch_size,
m->aggr);
aggr);
}
}
}

void forward_kernel(hipStream_t stream,
EmbeddingPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const &weight,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size) {
DataTypeDispatch2<ForwardKernel>{}(m->input_data_type,
m->output_data_type,
DataTypeDispatch2<ForwardKernel>{}(input_data_type,
output_data_type,
stream,
m,
aggr,
input,
output,
weight,
Expand All @@ -126,17 +128,19 @@ void forward_kernel(hipStream_t stream,
}

void backward_kernel(hipStream_t stream,
EmbeddingPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &weight_grad,
DataType input_data_type,
DataType output_data_type,
AggrMode aggr,
int in_dim,
int out_dim,
int batch_size) {
DataTypeDispatch2<BackwardKernel>{}(m->input_data_type,
m->output_data_type,
DataTypeDispatch2<BackwardKernel>{}(input_data_type,
output_data_type,
stream,
m,
aggr,
input,
output,
weight,
Expand Down
Loading