Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions lib/kernels/include/kernels/layer_norm_kernels.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef _FLEXFLOW_OPS_KERNELS_LAYER_NORM_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_LAYER_NORM_KERNELS_H

#include "kernels/allocation.h"
#include "kernels/device.h"
#include "kernels/ff_handle.h"

namespace FlexFlow {

Expand All @@ -23,18 +25,47 @@ class LayerNormPerDeviceState : public PerDeviceOpState {
DataType data_type;
};

struct LayerNormPerDeviceState {
bool elementwise_affine;
int64_t effective_batch_size, effective_num_elements;
float eps;
float *mean, *rstd, *ds, *db, *scale, *bias;
DataType data_type;
};

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(LayerNormPerDeviceState,
elementwise_affine,
effective_batch_size,
effective_num_elements,
eps,
mean,
rstd,
ds,
db,
scale,
bias,
data_type);

namespace Kernels {
namespace LayerNorm {

// todo: this may have some problem.
LayerNormPerDeviceState init_kernel(PerDeviceFFHandle const &,
Allocator const &,
bool elementwise_affine,
int64_t effective_batch_size,
int64_t effective_num_elements,
float eps);

void forward_kernel(ffStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &gamma,
GenericTensorAccessorW const &beta);

void backward_kernel(ffStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
Expand Down
36 changes: 34 additions & 2 deletions lib/kernels/src/cuda/layer_norm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,38 @@ LayerNormPerDeviceState::LayerNormPerDeviceState(
namespace Kernels {
namespace LayerNorm {

// todo: this may have some problem.
LayerNormPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
Allocator const &allocator,
bool elementwise_affine_,
int64_t effective_batch_size_,
int64_t effective_num_elements_,
float eps_) {
elementwise_affine = elementwise_affine_;
effective_batch_size = effective_batch_size_;
effective_num_elements = effective_num_elements_;
eps = eps_;
mean = allocator.allocate(sizeof(float) * effective_batch_size);
rstd = allocator.allocate(sizeof(float) * effective_batch_size);
ds = allocator.allocate(sizeof(float) * effective_batch_size);
db = allocator.allocate(sizeof(float) * effective_batch_size);
scale = allocator.allocate(sizeof(float) * effective_batch_size);
bias = allocator.allocate(sizeof(float) * effective_batch_size);
LayerNormPerDeviceState per_device_state =
LayerNormPerDeviceState(handle,
elementwise_affine,
effective_batch_size,
effective_num_elements,
eps,
mean,
rstd,
ds,
db,
scale,
bias);
return per_device_state;
}

template <DataType T>
struct ForwardKernel {
void operator()(cudaStream_t stream,
Expand Down Expand Up @@ -137,7 +169,7 @@ struct BackwardKernel {
}

void forward_kernel(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &gamma,
Expand All @@ -147,7 +179,7 @@ void forward_kernel(cudaStream_t stream,
}

void backward_kernel(cudaStream_t stream,
LayerNormPerDeviceState const *m,
LayerNormPerDeviceState const &m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
Expand Down
Loading