Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/compiler/test/test_dp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ TEST_CASE("optimal_cost") {

Node n0 = g.add_node(InputAttrs());
Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2));
Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0));
Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1));
Node n2 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 0));
Node n3 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 1));
Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1)));
Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2));

Expand Down
41 changes: 23 additions & 18 deletions lib/kernels/include/kernels/element_unary_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,47 @@

#include "kernels/accessor.h"
#include "kernels/device.h"
#include "legion.h"
#include "kernels/ff_handle.h"
#include "op-attrs/ops/element_unary.h"
#include <cstddef>

namespace FlexFlow {

class ElementUnaryPerDeviceState : public PerDeviceOpState {
public:
ElementUnaryPerDeviceState(FFHandler handle);
using ElementUnaryUnifiedAttrs =
variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>;

struct ElementUnaryPerDeviceState {
ffTensorDescriptor_t inputTensor, outputTensor;
ffActivationDescriptor_t actiDesc;

OperatorType op_type;
DataType data_type;
bool inplace;
float scalar;
char op_name[MAX_OPNAME];
};

FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState,
inputTensor,
outputTensor,
actiDesc);

namespace Kernels {
namespace ElementUnary {

void init_kernel(ElementUnaryPerDeviceState *m,
Legion::Domain const &input_domain,
Legion::Domain const &output_domain);
ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
ArrayShape const &output_shape,
ElementUnaryUnifiedAttrs const &attrs);

void forward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);

void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad);
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &output_grad);

} // namespace ElementUnary
} // namespace Kernels
Expand Down
175 changes: 103 additions & 72 deletions lib/kernels/src/cuda/element_unary_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@
#include "kernels/element_unary_kernels.h"

namespace FlexFlow {

// declare Legion names
using Legion::coord_t;
using Legion::Domain;

ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {
checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));
}

namespace Kernels {
namespace ElementUnary {

Expand All @@ -45,13 +33,31 @@ static bool use_cudnn(OperatorType op_type) {
}
}

void init_kernel(ElementUnaryPerDeviceState *m,
Domain const &input_domain,
Domain const &output_domain) {
/**
 * Scalar lookup for attribute types that carry no scalar operand.
 *
 * This overload never reads `attrs`; it exists so that a generic visitor can
 * call get_scalar<T> on either alternative of the attrs variant.
 *
 * @tparam T  Element type the caller wants the scalar expressed in.
 * @param attrs  Unused.
 * @return An empty optional, always.
 */
template <typename T>
optional<T> get_scalar(ElementUnaryAttrs const &attrs) {
  // The original body was empty: flowing off the end of a value-returning
  // function is undefined behaviour. Return a disengaged optional explicitly.
  return {};
}

/**
 * Scalar lookup for scalar element-unary attributes.
 *
 * @tparam T  Element type the caller wants the scalar expressed in.
 * @param attrs  Attributes whose `scalar` member is read.
 * @return An engaged optional holding `attrs.scalar` converted to T.
 */
template <typename T>
optional<T> get_scalar(ElementScalarUnaryAttrs const &attrs) {
  // `template <T>` is ill-formed; T must be declared as a type parameter for
  // both the optional<T> return type and this conversion to be valid.
  return static_cast<T>(attrs.scalar);
}

ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
ArrayShape const &output_shape,
ElementUnaryUnifiedAttrs const &attrs) {

ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffActivationDescriptor_t actiDesc;

if (use_cudnn(m->op_type)) {
checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));

Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);

if (use_cudnn(op_type)) {
cudnnActivationMode_t mode;
switch (m->op_type) {
switch (op_type) {
case OP_SIGMOID:
mode = CUDNN_ACTIVATION_SIGMOID;
break;
Expand All @@ -67,78 +73,89 @@ void init_kernel(ElementUnaryPerDeviceState *m,
default:
assert(false);
}
checkCUDNN(cudnnSetActivationDescriptor(
m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));
checkCUDNN(
cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain));
// input_domain == output_domain
cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));
checkCUDNN(
cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
checkCUDNN(
cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain));
cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape));
}

ElementUnaryPerDeviceState per_device_state = {
inputTensor, outputTensor, actiDesc};

return per_device_state;
}

template <DataType T>
struct ForwardKernel {
void operator()(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &m,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) const {
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
if (use_cudnn(m->op_type)) {
checkCUDNN(cudnnSetStream(handle.dnn, stream));
Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);
if (use_cudnn(op_type)) {
float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnActivationForward(m->handle.dnn,
m->actiDesc,
checkCUDNN(cudnnActivationForward(handle.dnn,
m.actiDesc,
&alpha,
m->inputTensor,
m.inputTensor,
input.get<T>(),
&beta,
m->outputTensor,
m.outputTensor,
output.get<T>()));
} else {
optional<T> scalar =
std::visit([](auto &&arg) { get_scalar<T>(arg); }, attrs);
size_t num_elements = input.shape.num_elements();
elewise_unary_forward_kernel<<<GET_BLOCKS(num_elements),
CUDA_NUM_THREADS,
0,
stream>>>(num_elements,
(T)m->scalar,
m->op_type,
input.get<T>(),
output.get<T>());
stream>>>(
num_elements, scalar, op_type, input.get<T>(), output.get<T>());
}
}
}

template <DataType T>
struct BackwardKernel {
void operator()(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &m,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad) {
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &output_grad) {
checkCUDNN(cudnnSetStream(handle.dnn, stream));

if (use_cudnn(m->op_type)) {
Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);
if (use_cudnn(op_type)) {
float alpha = 1.0f;
checkCUDNN(cudnnActivationBackward(m->handle.dnn,
m->actiDesc,
checkCUDNN(cudnnActivationBackward(handle.dnn,
m.actiDesc,
&alpha,
m->outputTensor,
m.outputTensor,
output.get<T>(),
m->outputTensor,
m.outputTensor,
output_grad.get<T>()),
m->inputTensor,
m.inputTensor,
input.get<T>(),
&alpha,
m->inputTensor,
m.inputTensor,
input_grad.get<T>()));
} else {
optional<T> scalar =
std::visit([](auto &&arg) { get_scalar<T>(arg); }, attrs);
size_t num_elements = input.shape.num_elements();
elewise_unary_backward_kernel<T>
<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
num_elements,
m->scalar,
m->op_type,
scalar,
op_type,
output.get<T>(),
output_grad.get<T>(),
input.get<T>(),
Expand All @@ -148,26 +165,40 @@ struct BackwardKernel {
}

void forward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
{
DataTypeDispatch1<ForwardKernel>{}(m->data_type, stream, m, input, output);
}
DataTypeDispatch1<ForwardKernel>{}(
input.data_type, stream, m, attrs, handle, input, output);
}

void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad)
DataTypeDispatch1<BackwardKernel>{}(
m->data_type, stream, m, input, input_grad, output, output_grad);
void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad) {
DataTypeDispatch1<BackwardKernel>{}(input.data_type,
stream,
m,
attrs,
handle,
input,
input_grad,
output,
output_grad);
}

template <typename T>
__global__ void elewise_unary_forward_kernel(
coord_t volume, const T scalar, OperatorType type, T const *in, T *out) {
__global__ void elewise_unary_forward_kernel(coord_t volume,
optional<T> const scalar,
OperatorType type,
T const *in,
T *out) {
CUDA_KERNEL_LOOP(i, volume) {
switch (type) {
case OP_EXP: {
Expand All @@ -179,19 +210,19 @@ __global__ void elewise_unary_forward_kernel(
break;
}
case OP_SCALAR_MULTIPLY: {
out[i] = in[i] * scalar;
out[i] = in[i] * scalar.value();
break;
}
case OP_SCALAR_ADD: {
out[i] = in[i] + scalar;
out[i] = in[i] + scalar.value();
break;
}
case OP_SCALAR_SUB: {
out[i] = in[i] - scalar;
out[i] = in[i] - scalar.value();
break;
}
case OP_SCALAR_TRUE_DIV: {
out[i] = in[i] / scalar;
out[i] = in[i] / scalar.value();
break;
}
case OP_GELU: {
Expand All @@ -203,7 +234,7 @@ __global__ void elewise_unary_forward_kernel(
break;
}
case OP_POW: {
out[i] = (T)(powf(in[i], scalar));
out[i] = (T)(powf(in[i], scalar.value()));
break;
}
case OP_SIN: {
Expand All @@ -222,7 +253,7 @@ __global__ void elewise_unary_forward_kernel(

template <typename T>
__global__ void elewise_unary_backward_kernel(coord_t volume,
const T scalar,
optional<T> const scalar,
OperatorType type,
T const *output,
T const *output_grad,
Expand All @@ -240,7 +271,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_SCALAR_MULTIPLY: {
input_grad[i] += output_grad[i] * scalar;
input_grad[i] += output_grad[i] * scalar.value();
break;
}
case OP_SCALAR_ADD: {
Expand All @@ -252,7 +283,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_SCALAR_TRUE_DIV: {
input_grad[i] += output_grad[i] / scalar;
input_grad[i] += output_grad[i] / scalar.value();
break;
}
case OP_GELU: {
Expand All @@ -268,8 +299,8 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_POW: {
input_grad[i] =
(T)(output_grad[i] * scalar * powf(input[i], scalar - 1));
input_grad[i] = (T)(output_grad[i] * scalar.value() *
powf(input[i], scalar.value() - 1));
break;
}
case OP_SIN: {
Expand Down
Loading