diff --git a/lib/compiler/test/test_dp.cc b/lib/compiler/test/test_dp.cc index 01e4189839..1878ade0b6 100644 --- a/lib/compiler/test/test_dp.cc +++ b/lib/compiler/test/test_dp.cc @@ -22,8 +22,8 @@ TEST_CASE("optimal_cost") { Node n0 = g.add_node(InputAttrs()); Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2)); - Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0)); - Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1)); + Node n2 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 0)); + Node n3 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 1)); Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1))); Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2)); diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 428c0ed897..826be47e32 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -3,42 +3,47 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "legion.h" +#include "kernels/ff_handle.h" +#include "op-attrs/ops/element_unary.h" #include namespace FlexFlow { -class ElementUnaryPerDeviceState : public PerDeviceOpState { -public: - ElementUnaryPerDeviceState(FFHandler handle); +using ElementUnaryUnifiedAttrs = + variant; + +struct ElementUnaryPerDeviceState { ffTensorDescriptor_t inputTensor, outputTensor; ffActivationDescriptor_t actiDesc; - - OperatorType op_type; - DataType data_type; - bool inplace; - float scalar; - char op_name[MAX_OPNAME]; }; +FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState, + inputTensor, + outputTensor, + actiDesc); + namespace Kernels { namespace ElementUnary { -void init_kernel(ElementUnaryPerDeviceState *m, - Legion::Domain const &input_domain, - Legion::Domain const &output_domain); +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementUnaryUnifiedAttrs const &attrs); void forward_kernel(ffStream_t stream, 
- ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad); } // namespace ElementUnary } // namespace Kernels diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index 1251f75603..808ff0ef97 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -18,18 +18,6 @@ #include "kernels/element_unary_kernels.h" namespace FlexFlow { - -// declare Legion names -using Legion::coord_t; -using Legion::Domain; - -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { - checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); - checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); - checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); -} - namespace Kernels { namespace ElementUnary { @@ -45,13 +33,31 @@ static bool use_cudnn(OperatorType op_type) { } } -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { +template +optional get_scalar(ElementUnaryAttrs const &attrs) { return nullopt; } + +template +optional get_scalar(ElementScalarUnaryAttrs const &attrs) { + return (T)attrs.scalar; +} + +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + 
ElementUnaryUnifiedAttrs const &attrs) { + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffActivationDescriptor_t actiDesc; - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); + checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); + checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); + + Op op_type = std::visit([](auto &&arg) { return get_op_type(arg); }, attrs); + + if (use_cudnn(op_type)) { cudnnActivationMode_t mode; - switch (m->op_type) { + switch (op_type) { case OP_SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; @@ -67,43 +73,49 @@ void init_kernel(ElementUnaryPerDeviceState *m, default: assert(false); } - checkCUDNN(cudnnSetActivationDescriptor( - m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); - // input_domain == output_domain + cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } + + ElementUnaryPerDeviceState per_device_state = { + inputTensor, outputTensor, actiDesc}; + + return per_device_state; } template struct ForwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnSetStream(handle.dnn, stream)); + Op op_type = std::visit([](auto &&arg) { return get_op_type(arg); }, attrs); + if (use_cudnn(op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(cudnnActivationForward(m->handle.dnn, - m->actiDesc, + 
checkCUDNN(cudnnActivationForward(handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { + optional scalar = + std::visit([](auto &&arg) { return get_scalar(arg); }, attrs); size_t num_elements = input.shape.num_elements(); elewise_unary_forward_kernel<<>>(num_elements, - (T)m->scalar, - m->op_type, - input.get(), - output.get()); + stream>>>( + num_elements, scalar, op_type, input.get(), output.get()); } } } @@ -111,34 +123,39 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + Op op_type = std::visit([](auto &&arg) { return get_op_type(arg); }, attrs); + if (use_cudnn(op_type)) { float alpha = 1.0f; - checkCUDNN(cudnnActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(cudnnActivationBackward(handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &alpha, - m->inputTensor, + m.inputTensor, input_grad.get())); } else { + optional scalar = + std::visit([](auto &&arg) { return get_scalar(arg); }, attrs); size_t num_elements = input.shape.num_elements(); elewise_unary_backward_kernel <<>>( num_elements, - m->scalar, - m->op_type, + scalar, + op_type, output.get(), output_grad.get(), input.get(), @@ -148,26 +165,40 @@ struct 
BackwardKernel { } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + DataTypeDispatch1{}( + input.data_type, stream, device_state, attrs, handle, input, output); +} - void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) - DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryUnifiedAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + attrs, + handle, + input, + input_grad, + output, + output_grad); } template -__global__ void elewise_unary_forward_kernel( - coord_t volume, const T scalar, OperatorType type, T const *in, T *out) { +__global__ void elewise_unary_forward_kernel(coord_t volume, + optional const scalar, + OperatorType type, + T const *in, + T *out) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { case OP_EXP: { @@ -179,19 +210,19 @@ __global__ void elewise_unary_forward_kernel( break; } case OP_SCALAR_MULTIPLY: { - out[i] = in[i] * scalar; + out[i] = in[i] * scalar.value(); break; } case OP_SCALAR_ADD: { - out[i] = in[i] + scalar; + out[i] = in[i] + scalar.value(); break; } case OP_SCALAR_SUB: { - out[i] = in[i] - scalar; + out[i] = in[i] - scalar.value(); break; } case OP_SCALAR_TRUE_DIV: { - out[i] = in[i] / 
scalar; + out[i] = in[i] / scalar.value(); break; } case OP_GELU: { @@ -203,7 +234,7 @@ __global__ void elewise_unary_forward_kernel( break; } case OP_POW: { - out[i] = (T)(powf(in[i], scalar)); + out[i] = (T)(powf(in[i], scalar.value())); break; } case OP_SIN: { @@ -222,7 +253,7 @@ __global__ void elewise_unary_forward_kernel( template __global__ void elewise_unary_backward_kernel(coord_t volume, - const T scalar, + optional const scalar, OperatorType type, T const *output, T const *output_grad, @@ -240,7 +271,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_SCALAR_MULTIPLY: { - input_grad[i] += output_grad[i] * scalar; + input_grad[i] += output_grad[i] * scalar.value(); break; } case OP_SCALAR_ADD: { @@ -252,7 +283,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_SCALAR_TRUE_DIV: { - input_grad[i] += output_grad[i] / scalar; + input_grad[i] += output_grad[i] / scalar.value(); break; } case OP_GELU: { @@ -268,8 +299,8 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_POW: { - input_grad[i] = - (T)(output_grad[i] * scalar * powf(input[i], scalar - 1)); + input_grad[i] = (T)(output_grad[i] * scalar.value() * + powf(input[i], scalar.value() - 1)); break; } case OP_SIN: { diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp index 58bec1b262..e79ef57592 100644 --- a/lib/kernels/src/hip/element_unary_kernels.cpp +++ b/lib/kernels/src/hip/element_unary_kernels.cpp @@ -14,50 +14,55 @@ */ #include "kernels/element_unary_kernels.h" +#include "kernels/datatype_dispatch.h" #include "kernels/hip_helper.h" #include namespace FlexFlow { +namespace Kernels { +namespace ElementUnary { -// declare Legion names -using Legion::coord_t; -using Legion::Domain; +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementUnaryAttrs const &attrs) { + miopenTensorDescriptor_t 
inputTensor; + miopenTensorDescriptor_t outputTensor; + miopenActivationDescriptor_t actiDesc; + miopenActivationMode_t mode; -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); -} - -namespace Kernels { -namespace ElementUnary { -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { - miopenActivationMode_t mode; - switch (m->op_type) { - case OP_SIGMOID: - mode = miopenActivationLOGISTIC; - break; - case OP_RELU: - mode = miopenActivationRELU; - break; - case OP_TANH: - mode = miopenActivationTANH; - break; - case OP_ELU: - mode = miopenActivationELU; - break; - default: - assert(false); + if (use_cudnn(attrs.op_type)) { + switch (attrs.op_type) { + case OP_SIGMOID: + mode = miopenActivationLOGISTIC; + break; + case OP_RELU: + mode = miopenActivationRELU; + break; + case OP_TANH: + mode = miopenActivationTANH; + break; + case OP_ELU: + mode = miopenActivationELU; + break; + default: + assert(false); + } + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); + // input_domain == output_domain + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } - checkCUDNN(miopenSetActivationDescriptor(m->actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); - // input_domain == output_domain - checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + + ElementUnaryPerDeviceState per_device_state = { + inputTensor, outputTensor, actiDesc}; + + return per_device_state; } bool use_cudnn(OperatorType type) { @@ -79,19 +84,21 @@ bool use_cudnn(OperatorType type) { template struct 
ForwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(miopenSetStream(handle.dnn, stream)); + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenActivationForward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationForward(handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { size_t num_elements = input.shape.num_elements(); @@ -101,8 +108,8 @@ struct ForwardKernel { 0, stream, num_elements, - (T)m->scalar, - m->op_type, + (T)attrs.scalar, + attrs.op_type, input.get(), output.get()); } @@ -112,27 +119,29 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f; float beta = 0.0f; - checkCUDNN(miopenActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationBackward(handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->inputTensor, + m.inputTensor, input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); @@ 
-142,8 +151,8 @@ struct BackwardKernel { 0, stream, num_elements, - m->scalar, - m->op_type, + attrs.scalar, + attrs.op_type, output.get(), output_grad.get(), input.get(), @@ -151,21 +160,32 @@ struct BackwardKernel { } } } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + DataTypeDispatch1{}( + input.data_type, stream, device_state, attrs, handle, input, output); +} - void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) - DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + attrs, + handle, + input, + input_grad, + output, + output_grad); } template diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index af13a75720..421c464843 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -14,8 +14,8 @@ OperatorType get_op_type(ConcatAttrs const &); OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); OperatorType get_op_type(ElementBinaryAttrs const &); -OperatorType get_op_type(ElementScalarUnaryAttrs const &); OperatorType 
get_op_type(ElementUnaryAttrs const &); +OperatorType get_op_type(ElementScalarUnaryAttrs const &); OperatorType get_op_type(EmbeddingAttrs const &); OperatorType get_op_type(FlatAttrs const &); OperatorType get_op_type(GatherAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 3daa97ba3e..5f78ec2d3f 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -130,6 +130,8 @@ ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, @@ -238,6 +240,8 @@ bool is_valid_internal(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); bool is_valid_internal(GatherAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index f429facf6f..a7ba84624c 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -42,8 +42,8 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); 
+static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1b72e83cb5..5e19b81c8c 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -8,20 +8,19 @@ namespace FlexFlow { +struct ElementUnaryAttrs { + req op_type; +}; +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); +CHECK_VALID_OP_ATTR(ElementUnaryAttrs); + struct ElementScalarUnaryAttrs { - req op; - /* bool inplace; */ + Op op_type; req scalar; }; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); +FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); -struct ElementUnaryAttrs { - req op; -}; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc index 7b6bf9eddd..3fa401b647 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/get_op_type.cc @@ -27,7 +27,10 @@ OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; } OperatorType get_op_type(ElementUnaryAttrs const &attrs) { - return attrs.op; + return attrs.op_type; +} +OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { + return attrs.op_type; } OperatorType get_op_type(EmbeddingAttrs const &) { return Op::EMBEDDING; @@ -86,9 +89,6 @@ OperatorType get_op_type(RepartitionAttrs const &) { OperatorType get_op_type(ReplicateAttrs const &) { return Op::REPLICATE; } -OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { - return attrs.op; -} OperatorType get_op_type(ReverseAttrs const &attrs) { return Op::REVERSE; } diff --git a/lib/pcg/include/pcg/computation_graph_builder.h 
b/lib/pcg/include/pcg/computation_graph_builder.h index 2d65a37a2d..ae937c590d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -246,10 +246,12 @@ struct ComputationGraphBuilder Tensor const &input, float scalar, optional const &name = nullopt); - Tensor - element_unary(variant const &, - Tensor const &input, - optional const &name = nullopt); + Tensor element_unary(ElementUnaryAttrs const &, + Tensor const &input, + optional const &name = nullopt); + Tensor element_scalar_unary(ElementScalarUnaryAttrs const &attrs, + Tensor const &x, + optional const &maybe_name); public: ComputationGraph computation_graph; diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index 61007aba79..9f8e930919 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -65,14 +65,28 @@ static std::string get_default_name(variant const &attrs) { } Tensor ComputationGraphBuilder::element_unary( - variant const &attrs, + ElementUnaryAttrs const &attrs, Tensor const &x, optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(attrs)); Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {widen(attrs), name}; + Layer layer = {attrs, name}; + TensorShape output_shape = get_output_shape(attrs, input); + + return this->add_layer(layer, {input}, {}, output_shape); +} + +Tensor ComputationGraphBuilder::element_scalar_unary( + ElementScalarUnaryAttrs const &attrs, + Tensor const &x, + optional const &maybe_name) { + std::string name = maybe_name.value_or(get_default_name(attrs)); + + Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + Layer layer = {attrs, name}; TensorShape output_shape = get_output_shape(attrs, input); return this->add_layer(layer, {input}, {}, output_shape); @@ -92,7 +106,7 @@ Tensor ComputationGraphBuilder::element_scalar_unary( 
float scalar, optional const &name) { ElementScalarUnaryAttrs attrs = {op_type, scalar}; - return this->element_unary(attrs, input, name); + return this->element_scalar_unary(attrs, input, name); } Tensor ComputationGraphBuilder::element_binary( diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/runtime/src/ops/element_unary.cc index 07959bd6da..f41a8b3551 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/runtime/src/ops/element_unary.cc @@ -6,665 +6,234 @@ namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::ElementUnary; -Tensor FFModel::unary(OperatorType op, - const Tensor x, - bool inplace, - char const *name, - float scalar) { - Layer *ele = nullptr; - DataType dtype; - // FIXME: currently cast input to float if it has a lower type - if (x->data_type < DT_FLOAT) { - dtype = DT_FLOAT; - std::string str(name); - Tensor new_x = cast(x, dtype, (str + "input_pre_cast").c_str()); - ele = new Layer(this, - op, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - new_x); - } else { - dtype = x->data_type; - ele = new Layer( - this, op, dtype, name, 1 /*inputs*/, 0 /*weights*/, 1 /*outputs*/, x); - } - int numdims = x->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = x->dims[i]; - } - ele->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, ele, 0, true /*create_grad*/); - ele->add_int_property("inplace", inplace); - ele->add_float_property("scalar", scalar); - layers.push_back(ele); - return ele->outputs[0]; -} +enum Slots { + INPUT, + INPUT_SHAPE, + OUTPUT, + ATTRS, + HANDLE, + PROFILING, + 
PER_DEVICE_STATE +}; -Op *ElementUnary::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("inplace", value); - bool inplace = (bool)value; - float scalar; - layer->get_float_property("scalar", scalar); - return new ElementUnary( - model, layer->op_type, inputs[0], inplace, layer->name, scalar); -} +/* ElementUnary */ +OpTaskInvocation init(ElementUnaryUnifiedAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::exp(const Tensor x, char const *name) { - return this->unary(OP_EXP, x, false /*inplace*/, name); -} + b.bind_arg(HANDLE, ff_handle()); + b.bind_arg(ATTRS, attrs); + b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -Tensor FFModel::scalar_multiply(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_MULTIPLY, x, inplace, name, scalar); + return {ELEMENTUNARY_INIT_TASK_ID, b}; } -Tensor FFModel::scalar_add(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_ADD, x, inplace, name, scalar); -} +OpTaskInvocation forward(ElementUnaryUnifiedAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::scalar_sub(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_SUB, x, inplace, name, scalar); -} + b.bind(INPUT, input_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -Tensor FFModel::scalar_truediv(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_TRUE_DIV, x, inplace, name, scalar); -} + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); -Tensor FFModel::relu(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RELU, x, inplace, name); + return {ELEMENTUNARY_FWD_TASK_ID, b}; } -Tensor FFModel::sigmoid(const Tensor x, char const *name) { - return this->unary(OP_SIGMOID, x, 
false /*inplace*/, name); -} +OpTaskInvocation backward(ElementUnaryUnifiedAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -Tensor FFModel::tanh(const Tensor x, char const *name) { - return this->unary(OP_TANH, x, false /*inplace*/, name); + return {ELEMENTUNARY_BWD_TASK_ID, b}; } -Tensor FFModel::identity(const Tensor x, char const *name) { - return this->unary(OP_IDENTITY, x, false /*inplace*/, name); -} +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { -Tensor FFModel::gelu(const Tensor x, char const *name) { - return this->unary(OP_GELU, x, false /*inplace*/, name); -} + auto const &attrs = acc.get_argument(ATTRS); + ProfilingSettings profiling = acc.get_argument(PROFILING); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ParallelTensorShape input_shape = + acc.get_argument(INPUT_SHAPE); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); -Tensor FFModel::elu(const Tensor x, bool inplace, char const *name) { - // Currently assume inplace is false - assert(!inplace); - return this->unary(OP_ELU, x, inplace, name); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(input_shape, output_shape, attrs)); + return per_device_state; } -Tensor FFModel::rsqrt(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RSQRT, x, inplace, name); +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -Tensor FFModel::pow(const Tensor x, - float const exponent, - bool inplace, - char const *name) { - return this->unary(OP_POW, x, inplace, name, exponent); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); -Tensor FFModel::sin(const Tensor x, char 
const *name) { - return this->unary(OP_SIN, x, false /*inplace*/, name); -} + auto &handle = acc.get_argument(HANDLE); -Tensor FFModel::cos(const Tensor x, char const *name) { - return this->unary(OP_COS, x, false /*inplace*/, name); -} + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); -bool ElementUnaryParams::is_valid(ParallelTensorShape const &input) const { - return input.is_valid(); + return profile(forward_kernel, + profiling, + "[ElementUnary] forward_time = %.2lfms\n", + per_device_state, + attrs, + handle, + input, + output); } -bool operator==(ElementUnaryParams const &lhs, ElementUnaryParams const &rhs) { - return lhs.op_type == rhs.op_type && lhs.scalar == rhs.scalar && - lhs.inplace == rhs.inplace; +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -ElementUnary::ElementUnary(FFModel &model, - OperatorType _op_type, - const ParallelTensor x, - bool _inplace, - char const *name, - float _scalar) - : Op(model, - _op_type, - x->data_type, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - x), - inplace(_inplace), scalar(_scalar) { - numOutputs = 1; - int numdim = x->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = x->dims[i]; - } - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, inputs[0]->data_type, this); - // Disable inplace if shape mismatch - if (outputs[0]->get_shape() != inputs[0]->get_shape()) { - inplace = false; - } -} +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); -ElementUnary::ElementUnary(FFModel &model, - ElementUnaryParams const ¶ms, - const 
ParallelTensor input, - char const *name) - : ElementUnary( - model, params.op_type, input, params.inplace, name, params.scalar) {} - -void ElementUnary::map_output_tensors(FFModel &ff) { - if (has_inplace_output()) { - assert(numOutputs == 1); - assert(outputs[0]->get_volume() == inputs[0]->get_volume()); - outputs[0]->parallel_is = inputs[0]->parallel_is; - outputs[0]->region = inputs[0]->region; - outputs[0]->part = inputs[0]->part; - outputs[0]->region_grad = inputs[0]->region_grad; - outputs[0]->part_grad = inputs[0]->part_grad; - } else { - Op::map_output_tensors(ff); - } -} + auto const &attrs = acc.get_argument(ATTRS); + auto &handle = acc.get_argument(HANDLE); -bool ElementUnary::can_inplace_output(void) { - return outputs[0]->get_shape() == inputs[0]->get_shape(); -} + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -bool ElementUnary::has_inplace_output(void) { - return inplace; + return profile(backward_kernel, + profiling, + "[ElementUnary] backward_time = %.2lfms\n", + per_device_state, + attrs, + handle, + input, + input_grad, + output, + output_grad); } -void ElementUnary::do_inplace_output(void) { - inplace = true; +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -void ElementUnary::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher init_launcher(ELEMENTUNARY_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(ElementUnary)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (!inplace) { - 
init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - init_launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - init_launcher.add_field(1, FID_DATA); - } else { - init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - } - FutureMap fm = runtime->execute_index_space(ctx, init_launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ElementUnaryUnifiedAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); -PerDeviceOpState * - ElementUnary::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnary *eu = (ElementUnary *)task->args; - FFHandler handle = *((FFHandler *)task->local_args); - ElementUnaryMeta *m = new ElementUnaryMeta(handle); - m->op_type = eu->op_type; - m->data_type = eu->outputs[0]->data_type; - // Input and output should have the same data type - assert(eu->outputs[0]->data_type == eu->inputs[0]->data_type); - m->profiling = eu->profiling; - m->inplace = eu->inplace; - m->scalar = eu->scalar; - std::strcpy(m->op_name, eu->name); - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - } - - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - init_kernel(m, input_domain, input_domain); - return m; -} + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); -void 
ElementUnary::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(outputs[0]->part == inputs[0]->part); - assert(outputs[0]->region == inputs[0]->region); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - } else { - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); -} + SimTaskBinding init_binding; + init_binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -void ElementUnary::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } -} + auto init_accessor = + 
env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); -/* - regions[0](I): input - regions[1](O): output -*/ -template -void ElementUnary::forward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - const DT *input_ptr = NULL; - DT *output_ptr = NULL; - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - output_ptr = helperGetTensorPointerRW
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_ptr = output_ptr; - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerWO
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - } - - forward_kernel_wrapper
( - m, input_ptr, output_ptr, input_domain.get_volume()); -} + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); -void ElementUnary::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(inputs[0]->part == outputs[0]->part); - assert(inputs[0]->part_grad == outputs[0]->part_grad); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - } else { - // regions[0](I): input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1](I/O): input_grad - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - 
READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(3, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); -} + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); -void ElementUnary::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } -} + auto fwd_accessor = + env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I): output_grad -*/ -template -void ElementUnary::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - const DT *input_ptr = NULL, *output_ptr = NULL, *output_grad_ptr = NULL; - DT *input_grad_ptr = NULL; - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - if (m->inplace) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(input_grad_domain == input_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = input_ptr; - output_grad_ptr = input_grad_ptr; - } else { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(output_grad_domain == input_domain); - assert(output_grad_domain == output_domain); - assert(output_grad_domain == input_grad_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerRO
( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - output_grad_ptr = helperGetTensorPointerRO
( - regions[3], task->regions[3], FID_DATA, ctx, runtime); - } - - backward_kernel_wrapper
(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - input_domain.get_volume()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void ElementUnary::serialize(Legion::Serializer &sez) const { - sez.serialize(this->op_type); - sez.serialize(this->inplace); - sez.serialize(scalar); +template <> +OpTaskSignature init_signature() { + OpTaskSignature init(OpTaskType::INIT); + init.add_arg_slot(INPUT_SHAPE); + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); + + init.add_return_value(); + + return init; } -bool ElementUnary::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - ElementUnaryMeta *m = sim->ele_unary_meta; - m->op_type = op_type; - Domain input_domain, output_domain; - input_domain.dim = sub_input.num_dims; - for (int i = 0; i < sub_input.num_dims; i++) { - input_domain.rect_data[i] = 0; - input_domain.rect_data[i + input_domain.dim] = sub_input.dims[i].size - 1; - } - output_domain.dim = sub_output.num_dims; - for (int i = 0; i < sub_output.num_dims; i++) { - output_domain.rect_data[i] = 0; - output_domain.rect_data[i + input_domain.dim] = sub_output.dims[i].size - 1; - } - init_kernel(m, input_domain, output_domain); - sim->free_all(); - float *input_ptr = - (float *)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = NULL; - if (inplace) { - output_ptr = input_ptr; - } else { - output_ptr = - (float *)sim->allocate(sub_output.get_volume(), 
outputs[0]->data_type); - } - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, input_ptr, output_ptr, sub_output.get_volume()); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = NULL; - if (inplace) { - output_grad_ptr = input_grad_ptr; - } else { - output_grad_ptr = (float *)sim->allocate(sub_output.get_volume(), - outputs[0]->data_type); - } - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&] { - backward_kernel_wrapper(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf) backward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time); - } - return true; +template <> +void register_task() { + register_task(ELEMENTUNARY_INIT_TASK_ID, + "ElementUnary Init", + init_signature(), + init_task); } -using PCG::Node; -/*static*/ -Node ElementUnary::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - OperatorType op_type; - float scalar; - bool inplace; - dez.deserialize(op_type); - 
dez.deserialize(inplace); - dez.deserialize(scalar); - - ElementUnaryParams params; - params.op_type = op_type; - params.inplace = inplace; - params.scalar = scalar; - return ff.get_or_create_node(inputs[0], params); +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + return fwd; } -Op *ElementUnary::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new ElementUnary( - ff, this->op_type, inputs[0], this->inplace, this->name, this->scalar); +template <> +void register_task() { + register_task(ELEMENTUNARY_FWD_TASK_ID, + "ElementUnary Fwd", + fwd_signature(), + forward_task); } -}; // namespace FlexFlow +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = + infer_bwd_signature(fwd_signature()); -namespace std { -size_t hash::operator()( - FlexFlow::ElementUnaryParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.op_type); - hash_combine(key, params.scalar); - hash_combine(key, params.inplace); - return key; + return bwd; } -}; // namespace std + +template <> +void register_task() { + register_task(ELEMENTUNARY_BWD_TASK_ID, + "ElementUnary Bwd", + bwd_signature(), + backward_task); +} + +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/element_unary.h b/lib/runtime/src/ops/element_unary.h index ae661f1177..f44efc28db 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/runtime/src/ops/element_unary.h @@ -7,6 +7,9 @@ namespace FlexFlow { +using ElementUnaryUnifiedAttrs = + variant; + template <> void register_task(); template <> @@ -14,86 +17,16 @@ void register_task(); template <> void register_task(); -OpTaskInvocation init(ElementUnaryAttrs const &); -OpTaskInvocation forward(ElementUnaryAttrs const &); -OpTaskInvocation backward(ElementUnaryAttrs const &); - 
-OpTaskInvocation init(ElementScalarUnaryAttrs const &); -OpTaskInvocation forward(ElementScalarUnaryAttrs const &); -OpTaskInvocation backward(ElementScalarUnaryAttrs const &); - -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, - ProfilingSettings const &settings, - MachineView const &machine_view); +OpTaskInvocation init(ElementUnaryUnifiedAttrs const &); +OpTaskInvocation forward(ElementUnaryUnifiedAttrs const &); +OpTaskInvocation backward(ElementUnaryUnifiedAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementScalarUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, + ElementUnaryUnifiedAttrs const &attrs, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -/* class ElementUnary : public Op { */ -/* public: */ -/* ElementUnary(FFModel &model, */ -/* OperatorType type, */ -/* const ParallelTensor x, */ -/* bool inplace, */ -/* char const *name, */ -/* float scalar); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* void map_output_tensors(FFModel &model) override; */ -/* bool can_inplace_output() override; */ -/* bool has_inplace_output() override; */ -/* void do_inplace_output() override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context 
ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void */ -/* forward_task_with_type(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void backward_task_with_type( */ -/* Legion::Task const *task, */ -/* std::vector const ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* private: */ -/* bool inplace; */ - -/* public: */ -/* float scalar; */ -/* }; */ - } // namespace FlexFlow #endif diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index 127b332ccf..d80808b7fb 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -19,10 +19,10 @@ TEST_CASE("Serialization") { BroadcastAttrs broadcast_attrs, CastAttrs cast_attrs, CombineAttrs combine_attrs, ConcatAttrs concat_attrs, Conv2DAttrs conv2d_attrs, DropoutAttrs dropout_attrs, - ElementBinaryAttrs elem_bin_attrs, + ElementBinaryAttrs elem_bin_attrs, ElementUnaryAttrs elem_unary_attrs, ElementScalarUnaryAttrs elem_scalar_unary_attrs, - ElementUnaryAttrs elem_unary_attrs, EmbeddingAttrs embedding_attrs, - FlatAttrs flat_attrs, GatherAttrs gather_attrs, InputAttrs input_attrs, + EmbeddingAttrs embedding_attrs, FlatAttrs flat_attrs, + GatherAttrs gather_attrs, InputAttrs input_attrs, LayerNormAttrs layer_norm_attrs, LinearAttrs linear_attrs, MultiHeadAttentionAttrs mha_attrs, NoopAttrs noop_attrs, Pool2DAttrs pool2d_attrs, ReduceAttrs reduce_attrs, diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index f35145133e..50c4108a67 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -25,9 +25,7 @@ 
optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey); optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ElementUnaryAttrs const &p, +optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index eef0464c42..3922b091a7 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -85,6 +85,14 @@ optional get_attribute(ElementUnaryAttrs const &p, } } +optional get_attribute(ElementScalarUnaryAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + default: + return nullopt; + } +} + optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 8e99624acb..72c9248e6c 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -186,6 +186,14 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, op_type, get(assignments.at(OperatorAttributeKey::SCALAR))}, nullopt); + case Op::EXP: + case Op::IDENTITY: + case Op::GELU: + case Op::RSQRT: + case Op::POW: + case Op::SIN: + case Op::COS: + return Operator(ElementUnaryAttrs{op_type}, nullopt); case Op::EMBEDDING: return Operator( EmbeddingAttrs{