From 22b47c4e7699edb34095991c3b175648b92745e9 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Mon, 1 Jan 2024 12:01:56 -0800 Subject: [PATCH 01/18] Element Unary --- .../include/kernels/element_unary_kernels.h | 36 +- lib/kernels/src/cuda/element_unary_kernels.cu | 103 +-- lib/kernels/src/hip/element_unary_kernels.cpp | 89 +- .../include/op-attrs/ops/element_unary.h | 50 +- lib/runtime/src/ops/element_unary.cc | 829 +++++------------- 5 files changed, 380 insertions(+), 727 deletions(-) diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 428c0ed897..5fac73132d 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -3,42 +3,50 @@ #include "kernels/accessor.h" #include "kernels/device.h" -#include "legion.h" +#include "kernels/ff_handle.h" +#include "op-attrs/ops/element_unary.h" #include namespace FlexFlow { -class ElementUnaryPerDeviceState : public PerDeviceOpState { -public: - ElementUnaryPerDeviceState(FFHandler handle); +struct ElementUnaryPerDeviceState { + PerDeviceFFHandle handle; ffTensorDescriptor_t inputTensor, outputTensor; ffActivationDescriptor_t actiDesc; OperatorType op_type; DataType data_type; - bool inplace; float scalar; - char op_name[MAX_OPNAME]; }; +FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState, + handle, + inputTensor, + outputTensor, + actiDesc, + op_type, + data_type, + scalar); + namespace Kernels { namespace ElementUnary { -void init_kernel(ElementUnaryPerDeviceState *m, - Legion::Domain const &input_domain, - Legion::Domain const &output_domain); +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + DataType data_type); void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad); } // namespace ElementUnary } // namespace Kernels diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index 1251f75603..568e1f319c 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -18,18 +18,6 @@ #include "kernels/element_unary_kernels.h" namespace FlexFlow { - -// declare Legion names -using Legion::coord_t; -using Legion::Domain; - -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { - checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); - checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); - checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); -} - namespace Kernels { namespace ElementUnary { @@ -45,13 +33,23 @@ static bool use_cudnn(OperatorType op_type) { } } -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type, + DataType data_type) { + + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffActivationDescriptor_t actiDesc; - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); + checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); + checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); + + if (use_cudnn(op_type)) { cudnnActivationMode_t mode; - switch (m->op_type) { + switch (op_type) { case OP_SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; @@ -67,32 +65,37 @@ void init_kernel(ElementUnaryPerDeviceState *m, default: assert(false); } - checkCUDNN(cudnnSetActivationDescriptor( - m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); - // input_domain == output_domain + cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); + // input_shape == output_shape + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } + + ElementUnaryPerDeviceState per_device_state = { + handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar}; + + return per_device_state; } template struct ForwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); + if (use_cudnn(m.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(cudnnActivationForward(m->handle.dnn, - m->actiDesc, + checkCUDNN(cudnnActivationForward(m.handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { size_t num_elements = input.shape.num_elements(); @@ -100,8 +103,8 @@ struct ForwardKernel { CUDA_NUM_THREADS, 0, stream>>>(num_elements, - (T)m->scalar, - m->op_type, + (T)m.scalar, + m.op_type, input.get(), output.get()); } @@ -111,34 +114,34 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + if (use_cudnn(m.op_type)) { float alpha = 1.0f; - checkCUDNN(cudnnActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(cudnnActivationBackward(m.handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &alpha, - m->inputTensor, + m.inputTensor, input_grad.get())); } else { size_t num_elements = input.shape.num_elements(); elewise_unary_backward_kernel <<>>( num_elements, - m->scalar, - m->op_type, + m.scalar, + m.op_type, output.get(), output_grad.get(), input.get(), @@ -148,21 +151,19 @@ struct BackwardKernel { } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); + m.data_type, stream, m, input, input_grad, output, output_grad); } template diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp index 58bec1b262..01bb7ca5f9 100644 --- a/lib/kernels/src/hip/element_unary_kernels.cpp +++ b/lib/kernels/src/hip/element_unary_kernels.cpp @@ -18,26 +18,24 @@ #include namespace FlexFlow { +namespace Kernels { +namespace ElementUnary { -// declare Legion names -using Legion::coord_t; -using Legion::Domain; +ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type, + DataType data_type) { + miopenTensorDescriptor_t inputTensor; + miopenTensorDescriptor_t outputTensor; + miopenActivationDescriptor_t actiDesc; + miopenActivationMode_t mode; -ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); -} -namespace Kernels { -namespace ElementUnary { - -void init_kernel(ElementUnaryPerDeviceState *m, - Domain const &input_domain, - Domain const &output_domain) { - miopenActivationMode_t mode; - switch (m->op_type) { + switch (op_type) { case OP_SIGMOID: mode = miopenActivationLOGISTIC; break; @@ -53,11 +51,16 @@ void init_kernel(ElementUnaryPerDeviceState *m, default: assert(false); } - checkCUDNN(miopenSetActivationDescriptor(m->actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain)); + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); // input_domain == output_domain checkCUDNN( - cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain)); + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); + + ElementUnaryPerDeviceState per_device_state = { + handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar}; + + return per_device_state; } bool use_cudnn(OperatorType type) { @@ -82,16 +85,16 @@ struct ForwardKernel { ElementUnaryPerDeviceState const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); + if (use_cudnn(m.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenActivationForward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationForward(m.handle.dnn, + m.actiDesc, &alpha, - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->outputTensor, + m.outputTensor, output.get())); } else { size_t num_elements = input.shape.num_elements(); @@ -101,8 +104,8 @@ struct ForwardKernel { 0, stream, num_elements, - (T)m->scalar, - m->op_type, + (T)m.scalar, + m.op_type, input.get(), output.get()); } @@ -117,22 +120,22 @@ struct BackwardKernel { GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); - if (use_cudnn(m->op_type)) { + if (use_cudnn(m.op_type)) { float alpha = 1.0f; float beta = 0.0f; - checkCUDNN(miopenActivationBackward(m->handle.dnn, - m->actiDesc, + checkCUDNN(miopenActivationBackward(m.handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output.get(), - m->outputTensor, + m.outputTensor, output_grad.get()), - m->inputTensor, + m.inputTensor, input.get(), &beta, - m->inputTensor, + m.inputTensor, input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); @@ -142,8 +145,8 @@ struct BackwardKernel { 0, stream, num_elements, - m->scalar, - m->op_type, + m.scalar, + m.op_type, output.get(), output_grad.get(), input.get(), @@ -151,21 +154,19 @@ struct BackwardKernel { } } } void forward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { - DataTypeDispatch1{}(m->data_type, stream, m, input, output); - } + { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &device_state, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) DataTypeDispatch1{}( - m->data_type, stream, m, input, input_grad, output, output_grad); + m.data_type, stream, m, input, input_grad, output, output_grad); } template diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1b72e83cb5..aa393821bd 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,26 +1,38 @@ -#ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H +#ifndef _ELEMENT_UNARY_H +#define _ELEMENT_UNARY_H -#include "core.h" -#include "op-attrs/op.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/element_unary.h" +#include "op_task_invocation.h" +#include "sim_environment.h" namespace FlexFlow { -struct ElementScalarUnaryAttrs { - req op; - /* bool inplace; */ - req scalar; -}; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); -CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); - -struct ElementUnaryAttrs { - req op; -}; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +template <> +void register_task(); +template <> +void register_task(); +template <> +void register_task(); + +OpTaskInvocation init(ElementUnaryAttrs const &); +OpTaskInvocation forward(ElementUnaryAttrs const &); +OpTaskInvocation backward(ElementUnaryAttrs const &); + +OpTaskInvocation init(ElementScalarUnaryAttrs const &); +OpTaskInvocation forward(ElementScalarUnaryAttrs const &); +OpTaskInvocation backward(ElementScalarUnaryAttrs const &); + +CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, + ElementUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &machine_view); + +CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, + ElementScalarUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/runtime/src/ops/element_unary.cc index 07959bd6da..b5864749c9 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/runtime/src/ops/element_unary.cc @@ -6,665 +6,296 @@ namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; using Legion::Runtime; using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::ElementUnary; -Tensor FFModel::unary(OperatorType op, - const Tensor x, - bool inplace, - char const *name, - float scalar) { - Layer *ele = nullptr; - DataType dtype; - // FIXME: currently cast input to float if it has a lower type - if (x->data_type < DT_FLOAT) { - dtype = DT_FLOAT; - std::string str(name); - Tensor new_x = cast(x, dtype, (str + "input_pre_cast").c_str()); - ele = new Layer(this, - op, - dtype, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - new_x); - } else { - dtype = x->data_type; - ele = new Layer( - this, op, dtype, name, 1 /*inputs*/, 0 /*weights*/, 1 /*outputs*/, x); - } - int numdims = x->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = x->dims[i]; - } - ele->outputs[0] = create_tensor_legion_ordering( - numdims, dims, dtype, ele, 0, true /*create_grad*/); - ele->add_int_property("inplace", inplace); - ele->add_float_property("scalar", scalar); - layers.push_back(ele); - return ele->outputs[0]; -} +enum Slots { + INPUT, + INPUT_SHAPE, + OUTPUT, + ATTRS, + HANDLE, + PROFILING, + PER_DEVICE_STATE +}; -Op *ElementUnary::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("inplace", value); - bool inplace = (bool)value; - float scalar; - layer->get_float_property("scalar", scalar); - return new ElementUnary( - model, layer->op_type, inputs[0], inplace, layer->name, scalar); -} +/* ElementUnary */ +OpTaskInvocation init(ElementUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::exp(const Tensor x, char const *name) { - return this->unary(OP_EXP, x, false /*inplace*/, name); -} + b.bind_arg(HANDLE, ff_handle()); + b.bind_arg(ATTRS, attrs); + b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -Tensor FFModel::scalar_multiply(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_MULTIPLY, x, inplace, name, scalar); + return {ELEMENTUNARY_INIT_TASK_ID, b}; } -Tensor FFModel::scalar_add(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_ADD, x, inplace, name, scalar); -} +OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::scalar_sub(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_SUB, x, inplace, name, scalar); -} + b.bind(INPUT, input_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -Tensor FFModel::scalar_truediv(const Tensor x, - float const scalar, - bool inplace, - char const *name) { - return this->unary(OP_SCALAR_TRUE_DIV, x, inplace, name, scalar); -} + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); -Tensor FFModel::relu(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RELU, x, inplace, name); + return {ELEMENTUNARY_FWD_TASK_ID, b}; } -Tensor FFModel::sigmoid(const Tensor x, char const *name) { - return this->unary(OP_SIGMOID, x, false /*inplace*/, name); -} +OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -Tensor FFModel::tanh(const Tensor x, char const *name) { - return this->unary(OP_TANH, x, false /*inplace*/, name); + return {ELEMENTUNARY_BWD_TASK_ID, b}; } -Tensor FFModel::identity(const Tensor x, char const *name) { - return this->unary(OP_IDENTITY, x, false /*inplace*/, name); -} +/* ElementScalarUnary */ +OpTaskInvocation init(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::gelu(const Tensor x, char const *name) { - return this->unary(OP_GELU, x, false /*inplace*/, name); -} + b.bind_arg(HANDLE, ff_handle()); + b.bind_arg(ATTRS, attrs); + b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); -Tensor FFModel::elu(const Tensor x, bool inplace, char const *name) { - // Currently assume inplace is false - assert(!inplace); - return this->unary(OP_ELU, x, inplace, name); + return {ELEMENTUNARY_INIT_TASK_ID, b}; } -Tensor FFModel::rsqrt(const Tensor x, bool inplace, char const *name) { - return this->unary(OP_RSQRT, x, inplace, name); -} +OpTaskInvocation forward(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b; -Tensor FFModel::pow(const Tensor x, - float const exponent, - bool inplace, - char const *name) { - return this->unary(OP_POW, x, inplace, name, exponent); -} + b.bind(INPUT, input_tensor(0)); + b.bind(OUTPUT, output_tensor(0)); -Tensor FFModel::sin(const Tensor x, char const *name) { - return this->unary(OP_SIN, x, false /*inplace*/, name); -} + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); -Tensor FFModel::cos(const Tensor x, char const *name) { - return this->unary(OP_COS, x, false /*inplace*/, name); + return {ELEMENTUNARY_FWD_TASK_ID, b}; } -bool ElementUnaryParams::is_valid(ParallelTensorShape const &input) const { - return input.is_valid(); -} +OpTaskInvocation backward(ElementScalarUnaryAttrs const &attrs) { + OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); -bool operator==(ElementUnaryParams const &lhs, ElementUnaryParams const &rhs) { - return lhs.op_type == rhs.op_type && lhs.scalar == rhs.scalar && - lhs.inplace == rhs.inplace; + return {ELEMENTUNARY_BWD_TASK_ID, b}; } -ElementUnary::ElementUnary(FFModel &model, - OperatorType _op_type, - const ParallelTensor x, - bool _inplace, - char const *name, - float _scalar) - : Op(model, - _op_type, - x->data_type, - name, - 1 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - x), - inplace(_inplace), scalar(_scalar) { - numOutputs = 1; - int numdim = x->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = x->dims[i]; - } - outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, inputs[0]->data_type, this); - // Disable inplace if shape mismatch - if (outputs[0]->get_shape() != inputs[0]->get_shape()) { - inplace = false; - } -} +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { -ElementUnary::ElementUnary(FFModel &model, - ElementUnaryParams const ¶ms, - const ParallelTensor input, - char const *name) - : ElementUnary( - model, params.op_type, input, params.inplace, name, params.scalar) {} - -void ElementUnary::map_output_tensors(FFModel &ff) { - if (has_inplace_output()) { - assert(numOutputs == 1); - assert(outputs[0]->get_volume() == inputs[0]->get_volume()); - outputs[0]->parallel_is = inputs[0]->parallel_is; - outputs[0]->region = inputs[0]->region; - outputs[0]->part = inputs[0]->part; - outputs[0]->region_grad = inputs[0]->region_grad; - outputs[0]->part_grad = inputs[0]->part_grad; - } else { - Op::map_output_tensors(ff); - } -} + auto const &attrs = acc.get_argument(ATTRS); + ProfilingSettings profiling = acc.get_argument(PROFILING); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + ParallelTensorShape input_shape = + acc.get_argument(INPUT_SHAPE); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); -bool ElementUnary::can_inplace_output(void) { - return outputs[0]->get_shape() == inputs[0]->get_shape(); + DeviceSpecific per_device_state = + acc.create_device_specific( + init_kernel(handle, + {input_shape.dims}, + {output_shape.dims}, + input_shape.data_type)); + return per_device_state; } -bool ElementUnary::has_inplace_output(void) { - return inplace; +static DeviceSpecific + init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + return init_task_impl(acc); } -void ElementUnary::do_inplace_output(void) { - inplace = true; -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); -void ElementUnary::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher init_launcher(ELEMENTUNARY_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(ElementUnary)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (!inplace) { - init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - init_launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - init_launcher.add_field(1, FID_DATA); - } else { - init_launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region)); - init_launcher.add_field(0, FID_DATA); - } - FutureMap fm = runtime->execute_index_space(ctx, init_launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); -PerDeviceOpState * - ElementUnary::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnary *eu = (ElementUnary *)task->args; - FFHandler handle = *((FFHandler *)task->local_args); - ElementUnaryMeta *m = new ElementUnaryMeta(handle); - m->op_type = eu->op_type; - m->data_type = eu->outputs[0]->data_type; - // Input and output should have the same data type - assert(eu->outputs[0]->data_type == eu->inputs[0]->data_type); - m->profiling = eu->profiling; - m->inplace = eu->inplace; - m->scalar = eu->scalar; - std::strcpy(m->op_name, eu->name); - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - } - - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - init_kernel(m, input_domain, input_domain); - return m; + return profile(forward_kernel, + profiling, + "[ElementUnary] forward_time = %.2lfms\n", + per_device_state, + input, + output); } -void ElementUnary::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_FWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(outputs[0]->part == inputs[0]->part); - assert(outputs[0]->region == inputs[0]->region); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - } else { - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(1, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + forward_task_impl(acc); } -void ElementUnary::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - forward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - forward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } -} +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + ProfilingSettings profiling = acc.get_argument(PROFILING); -/* - regions[0](I): input - regions[1](O): output -*/ -template -void ElementUnary::forward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - const DT *input_ptr = NULL; - DT *output_ptr = NULL; - if (m->inplace) { - assert(regions.size() == 1); - assert(task->regions.size() == 1); - output_ptr = helperGetTensorPointerRW
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_ptr = output_ptr; - } else { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(output_domain == input_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerWO
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - } - - forward_kernel_wrapper
( - m, input_ptr, output_ptr, input_domain.get_volume()); + return profile(backward_kernel, + profiling, + "[ElementUnary] backward_time = %.2lfms\n", + per_device_state, + input, + input_grad, + output, + output_grad); } -void ElementUnary::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(ELEMENTUNARY_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - if (inplace) { - assert(inputs[0]->part == outputs[0]->part); - assert(inputs[0]->part_grad == outputs[0]->part_grad); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - } else { - // regions[0](I): input - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - // regions[1](I/O): input_grad - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(1, FID_DATA); - // regions[2](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - // regions[3](I): output_grad - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(3, FID_DATA); - } - runtime->execute_index_space(ctx, launcher); +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -void ElementUnary::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - if (m->data_type == DT_FLOAT) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_DOUBLE) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT32) { - backward_task_with_type(task, regions, ctx, runtime); - } else if (m->data_type == DT_INT64) { - backward_task_with_type(task, regions, ctx, runtime); - } else { - assert(false && "Unsupported data type in Embedding forward"); - } +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ElementUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); + + auto init_accessor = + env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -/* - regions[0](I): input - regions[1](I/O): input_grad - regions[2](I): output - regions[3](I): output_grad -*/ -template -void ElementUnary::backward_task_with_type( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // const ElementUnary* ele = (const ElementUnary*) task->args; - ElementUnaryMeta const *m = *((ElementUnaryMeta **)task->local_args); - const DT *input_ptr = NULL, *output_ptr = NULL, *output_grad_ptr = NULL; - DT *input_grad_ptr = NULL; - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - if (m->inplace) { - assert(regions.size() == 2); - assert(task->regions.size() == 2); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - assert(input_grad_domain == input_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = input_ptr; - output_grad_ptr = input_grad_ptr; - } else { - assert(regions.size() == 4); - assert(task->regions.size() == 4); - Domain input_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(output_grad_domain == input_domain); - assert(output_grad_domain == output_domain); - assert(output_grad_domain == input_grad_domain); - input_ptr = helperGetTensorPointerRO
( - regions[0], task->regions[0], FID_DATA, ctx, runtime); - input_grad_ptr = helperGetTensorPointerRW
( - regions[1], task->regions[1], FID_DATA, ctx, runtime); - output_ptr = helperGetTensorPointerRO
( - regions[2], task->regions[2], FID_DATA, ctx, runtime); - output_grad_ptr = helperGetTensorPointerRO
( - regions[3], task->regions[3], FID_DATA, ctx, runtime); - } - - backward_kernel_wrapper
(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - input_domain.get_volume()); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + ElementScalarUnaryAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + auto env = sim.new_environment(); + + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + SimTaskBinding init_binding; + init_binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(ATTRS, attrs); + init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); + + auto init_accessor = + env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); + + SimTaskBinding fwd_binding; + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = + env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void ElementUnary::serialize(Legion::Serializer &sez) const { - sez.serialize(this->op_type); - sez.serialize(this->inplace); - sez.serialize(scalar); +template <> +OpTaskSignature init_signature() { + OpTaskSignature init(OpTaskType::INIT); + init.add_arg_slot(INPUT_SHAPE); + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); + + init.add_return_value(); + + return init; } -bool ElementUnary::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - ElementUnaryMeta *m = sim->ele_unary_meta; - m->op_type = op_type; - Domain input_domain, output_domain; - input_domain.dim = sub_input.num_dims; - for (int i = 0; i < sub_input.num_dims; i++) { - input_domain.rect_data[i] = 0; - input_domain.rect_data[i + input_domain.dim] = sub_input.dims[i].size - 1; - } - output_domain.dim = sub_output.num_dims; - for (int i = 0; i < sub_output.num_dims; i++) { - output_domain.rect_data[i] = 0; - output_domain.rect_data[i + input_domain.dim] = sub_output.dims[i].size - 1; - } - init_kernel(m, input_domain, output_domain); - sim->free_all(); - float *input_ptr = - (float *)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_ptr = NULL; - if (inplace) { - output_ptr = input_ptr; - } else { - output_ptr = - (float *)sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - } - assert(output_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - assert(m->profiling == false); - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, input_ptr, output_ptr, sub_output.get_volume()); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - float *input_grad_ptr = - (float *)sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - assert(input_grad_ptr != NULL); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *output_grad_ptr = NULL; - if (inplace) { - output_grad_ptr = input_grad_ptr; - } else { - output_grad_ptr = (float *)sim->allocate(sub_output.get_volume(), - outputs[0]->data_type); - } - assert(output_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&] { - backward_kernel_wrapper(m, - input_ptr, - input_grad_ptr, - output_ptr, - output_grad_ptr, - sub_output.get_volume()); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf) backward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - log_measure.debug("[Measure Elewise Unary] name(%s) num_elements(%zu) " - "forward_time(%.4lf)\n", - name, - sub_output.get_volume(), - cost_metrics.forward_time); - } - return true; +template <> +void register_task() { + register_task(ELEMENTUNARY_INIT_TASK_ID, + "ElementUnary Init", + init_signature(), + init_task); } -using PCG::Node; -/*static*/ -Node ElementUnary::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 1); - OperatorType op_type; - float scalar; - bool inplace; - dez.deserialize(op_type); - dez.deserialize(inplace); - dez.deserialize(scalar); - - ElementUnaryParams params; - params.op_type = op_type; - params.inplace = inplace; - params.scalar = scalar; - return ff.get_or_create_node(inputs[0], params); +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + + return fwd; } -Op *ElementUnary::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - assert(num_inputs == 1); - return new ElementUnary( - ff, this->op_type, inputs[0], this->inplace, this->name, this->scalar); +template <> +void register_task() { + register_task(ELEMENTUNARY_FWD_TASK_ID, + "ElementUnary Fwd", + fwd_signature(), + forward_task); } -}; // namespace FlexFlow +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = + infer_bwd_signature(fwd_signature()); -namespace std { -size_t hash::operator()( - FlexFlow::ElementUnaryParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.op_type); - hash_combine(key, params.scalar); - hash_combine(key, params.inplace); - return key; + return bwd; } -}; // namespace std + +template <> +void register_task() { + register_task(ELEMENTUNARY_BWD_TASK_ID, + "ElementUnary Bwd", + bwd_signature(), + backward_task); +} + +} // namespace FlexFlow From 247c0196c807fdf6b95386b9a86e3645fea86562 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 3 Jan 2024 18:09:26 -0800 Subject: [PATCH 02/18] Fix --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index aa393821bd..5544981dd6 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -2,7 +2,6 @@ #define _ELEMENT_UNARY_H #include "op-attrs/ops/element_unary.h" -#include "op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { From 57e996c1bf01111161c623fc2ee9b19131e0f8e6 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 3 Jan 2024 18:34:49 -0800 Subject: [PATCH 03/18] Fix attrs --- .../include/op-attrs/ops/element_unary.h | 53 ++++++--------- lib/runtime/src/ops/element_unary.h | 64 +------------------ 2 files changed, 21 insertions(+), 96 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 5544981dd6..ec8f6062d6 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,38 +1,23 @@ -#ifndef _ELEMENT_UNARY_H -#define _ELEMENT_UNARY_H +#ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H +#define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#include "op-attrs/ops/element_unary.h" -#include "sim_environment.h" +#include "core.h" +#include "op-attrs/op.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/visitable.h" namespace FlexFlow { -template <> -void register_task(); -template <> -void register_task(); -template <> -void register_task(); - -OpTaskInvocation init(ElementUnaryAttrs const &); -OpTaskInvocation forward(ElementUnaryAttrs const &); -OpTaskInvocation backward(ElementUnaryAttrs const &); - -OpTaskInvocation init(ElementScalarUnaryAttrs const &); -OpTaskInvocation forward(ElementScalarUnaryAttrs const &); -OpTaskInvocation backward(ElementScalarUnaryAttrs const &); - -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementUnaryAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &machine_view); - -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementScalarUnaryAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &machine_view); - -} // namespace FlexFlow - -#endif +struct ElementScalarUnaryAttrs { + req op; + /* bool inplace; */ + req scalar; +}; +FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); +CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); + +struct ElementUnaryAttrs { + req op; +}; +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); +CHECK_VALID_OP_ATTR(ElementUnaryAttrs); \ No newline at end of file diff --git a/lib/runtime/src/ops/element_unary.h b/lib/runtime/src/ops/element_unary.h index ae661f1177..aa393821bd 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/runtime/src/ops/element_unary.h @@ -24,76 +24,16 @@ OpTaskInvocation backward(ElementScalarUnaryAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ElementUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ElementScalarUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -/* class ElementUnary : public Op { */ -/* public: */ -/* ElementUnary(FFModel &model, */ -/* OperatorType type, */ -/* const ParallelTensor x, */ -/* bool inplace, */ -/* char const *name, */ -/* float scalar); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ -/* void map_output_tensors(FFModel &model) override; */ -/* bool can_inplace_output() override; */ -/* bool has_inplace_output() override; */ -/* void do_inplace_output() override; */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void */ -/* forward_task_with_type(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* template */ -/* static void backward_task_with_type( */ -/* Legion::Task const *task, */ -/* std::vector const ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* private: */ -/* bool inplace; */ - -/* public: */ -/* float scalar; */ -/* }; */ - } // namespace FlexFlow #endif From d13e7101e97e4c6e440a1531fb2158d9d6f169a5 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 3 Jan 2024 18:35:06 -0800 Subject: [PATCH 04/18] Format --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index ec8f6062d6..3be7e8c761 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -20,4 +20,4 @@ struct ElementUnaryAttrs { req op; }; FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); \ No newline at end of file +CHECK_VALID_OP_ATTR(ElementUnaryAttrs); From 34c00f5ddf6c31a920fdabf85be80b648cc7aa0c Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 17 Jan 2024 13:22:02 -0800 Subject: [PATCH 05/18] Remove comment --- lib/kernels/src/cuda/element_unary_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index 568e1f319c..f1e91fa2ed 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -69,7 +69,6 @@ ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); checkCUDNN( cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); - // input_shape == output_shape checkCUDNN( cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } From 93b44bba61cbcae485ad4280df16bce12bb31142 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 17 Jan 2024 13:42:49 -0800 Subject: [PATCH 06/18] Fix endif --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 3be7e8c761..82cb2a9a84 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -21,3 +21,9 @@ struct ElementUnaryAttrs { }; FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); + +} + +// namespace FlexFlow + +#endif \ No newline at end of file From 15c895f694fd7bb5fd05bb727ea5e60e12d6101d Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 24 Jan 2024 17:28:13 -0800 Subject: [PATCH 07/18] Move arguments and remove ElementScalarUnary --- .../include/kernels/element_unary_kernels.h | 20 ++-- lib/kernels/src/cuda/element_unary_kernels.cu | 67 +++++++---- lib/kernels/src/hip/element_unary_kernels.cpp | 111 ++++++++++-------- .../include/op-attrs/ops/element_unary.h | 19 +-- lib/runtime/src/ops/element_unary.cc | 84 ++----------- 5 files changed, 129 insertions(+), 172 deletions(-) diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 5fac73132d..407ff3ebfe 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -10,39 +10,33 @@ namespace FlexFlow { struct ElementUnaryPerDeviceState { - PerDeviceFFHandle handle; ffTensorDescriptor_t inputTensor, outputTensor; ffActivationDescriptor_t actiDesc; - - OperatorType op_type; - DataType data_type; - float scalar; }; FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState, - handle, inputTensor, outputTensor, - actiDesc, - op_type, - data_type, - scalar); + actiDesc); namespace Kernels { namespace ElementUnary { -ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - ArrayShape const &input_shape, +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - DataType data_type); + ElementUnaryAttrs const &attrs); void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output, diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index f1e91fa2ed..079aa35172 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -33,11 +33,9 @@ static bool use_cudnn(OperatorType op_type) { } } -ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - ArrayShape const &input_shape, +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - OperatorType op_type, - DataType data_type) { + ElementUnaryAttrs const &attrs) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -47,9 +45,9 @@ ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); - if (use_cudnn(op_type)) { + if (use_cudnn(attrs.op_type)) { cudnnActivationMode_t mode; - switch (op_type) { + switch (attrs.op_type) { case OP_SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; @@ -74,7 +72,7 @@ ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, } ElementUnaryPerDeviceState per_device_state = { - handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar}; + inputTensor, outputTensor, actiDesc}; return per_device_state; } @@ -83,12 +81,14 @@ template struct ForwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { - checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); - if (use_cudnn(m.op_type)) { + checkCUDNN(cudnnSetStream(handle.dnn, stream)); + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(cudnnActivationForward(m.handle.dnn, + checkCUDNN(cudnnActivationForward(handle.dnn, m.actiDesc, &alpha, m.inputTensor, @@ -102,8 +102,8 @@ struct ForwardKernel { CUDA_NUM_THREADS, 0, stream>>>(num_elements, - (T)m.scalar, - m.op_type, + (T)attrs.scalar, + attrs.op_type, input.get(), output.get()); } @@ -114,15 +114,17 @@ template struct BackwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output, GenericTensorAccessorR const &output_grad) { - checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (use_cudnn(m.op_type)) { + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f; - checkCUDNN(cudnnActivationBackward(m.handle.dnn, + checkCUDNN(cudnnActivationBackward(handle.dnn, m.actiDesc, &alpha, m.outputTensor, @@ -139,8 +141,8 @@ struct BackwardKernel { elewise_unary_backward_kernel <<>>( num_elements, - m.scalar, - m.op_type, + attrs.scalar, + attrs.op_type, output.get(), output_grad.get(), input.get(), @@ -151,18 +153,31 @@ struct BackwardKernel { void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } + DataTypeDispatch1{}( + input.data_type, stream, m, attrs, handle, input, output); +} - void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const &device_state, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &input_grad, - GenericTensorAccessorW const &output, - GenericTensorAccessorW const &output_grad) - DataTypeDispatch1{}( - m.data_type, stream, m, input, input_grad, output, output_grad); +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorR const &input_grad, + GenericTensorAccessorW const &output, + GenericTensorAccessorW const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + m, + attrs, + handle, + input, + input_grad, + output, + output_grad); } template diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp index 01bb7ca5f9..d03a8ddfb8 100644 --- a/lib/kernels/src/hip/element_unary_kernels.cpp +++ b/lib/kernels/src/hip/element_unary_kernels.cpp @@ -14,6 +14,7 @@ */ #include "kernels/element_unary_kernels.h" +#include "kernels/datatype_dispatch.h" #include "kernels/hip_helper.h" #include @@ -21,11 +22,9 @@ namespace FlexFlow { namespace Kernels { namespace ElementUnary { -ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, - ArrayShape const &input_shape, +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - OperatorType op_type, - DataType data_type) { + ElementUnaryAttrs const &attrs) { miopenTensorDescriptor_t inputTensor; miopenTensorDescriptor_t outputTensor; miopenActivationDescriptor_t actiDesc; @@ -35,30 +34,33 @@ ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle, checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); - switch (op_type) { - case OP_SIGMOID: - mode = miopenActivationLOGISTIC; - break; - case OP_RELU: - mode = miopenActivationRELU; - break; - case OP_TANH: - mode = miopenActivationTANH; - break; - case OP_ELU: - mode = miopenActivationELU; - break; - default: - assert(false); + if (use_cudnn(attrs.op_type)) { + switch (attrs.op_type) { + case OP_SIGMOID: + mode = miopenActivationLOGISTIC; + break; + case OP_RELU: + mode = miopenActivationRELU; + break; + case OP_TANH: + mode = miopenActivationTANH; + break; + case OP_ELU: + mode = miopenActivationELU; + break; + default: + assert(false); + } + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); + // input_domain == output_domain + checkCUDNN( + cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); } - checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape)); - // input_domain == output_domain - checkCUDNN( - cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape)); ElementUnaryPerDeviceState per_device_state = { - handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar}; + inputTensor, outputTensor, actiDesc}; return per_device_state; } @@ -82,13 +84,15 @@ bool use_cudnn(OperatorType type) { template struct ForwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - checkCUDNN(miopenSetStream(m.handle.dnn, stream)); - if (use_cudnn(m.op_type)) { + checkCUDNN(miopenSetStream(handle.dnn, stream)); + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenActivationForward(m.handle.dnn, + checkCUDNN(miopenActivationForward(handle.dnn, m.actiDesc, &alpha, m.inputTensor, @@ -104,8 +108,8 @@ struct ForwardKernel { 0, stream, num_elements, - (T)m.scalar, - m.op_type, + (T)attrs.scalar, + attrs.op_type, input.get(), output.get()); } @@ -115,17 +119,19 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(ffStream_t stream, - ElementUnaryPerDeviceState const *m, + ElementUnaryPerDeviceState const &m, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorR const &input_grad, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) { - checkCUDNN(miopenSetStream(m.handle.dnn, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); - if (use_cudnn(m.op_type)) { + if (use_cudnn(attrs.op_type)) { float alpha = 1.0f; float beta = 0.0f; - checkCUDNN(miopenActivationBackward(m.handle.dnn, + checkCUDNN(miopenActivationBackward(handle.dnn, m.actiDesc, &alpha, m.outputTensor, @@ -145,8 +151,8 @@ struct BackwardKernel { 0, stream, num_elements, - m.scalar, - m.op_type, + attrs.scalar, + attrs.op_type, output.get(), output_grad.get(), input.get(), @@ -155,18 +161,31 @@ struct BackwardKernel { } } void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - { DataTypeDispatch1{}(m.data_type, stream, m, input, output); } + DataTypeDispatch1{}( + input.data_type, stream, m, attrs, handle, input, output); +} - void backward_kernel(ffStream_t stream, - ElementUnaryPerDeviceState const &device_state, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &input_grad, - GenericTensorAccessorR const &output, - GenericTensorAccessorR const &output_grad) - DataTypeDispatch1{}( - m.data_type, stream, m, input, input_grad, output, output_grad); +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + m, + attrs, + hanlde, + input, + input_grad, + output, + output_grad); } template diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 82cb2a9a84..e908fffff6 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -8,22 +8,13 @@ namespace FlexFlow { -struct ElementScalarUnaryAttrs { - req op; - /* bool inplace; */ - req scalar; -}; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); -CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); - struct ElementUnaryAttrs { - req op; + req op_type; + req scalar; }; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); -} - -// namespace FlexFlow +} // namespace FlexFlow -#endif \ No newline at end of file +#endif diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/runtime/src/ops/element_unary.cc index b5864749c9..80130147c7 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/runtime/src/ops/element_unary.cc @@ -53,36 +53,6 @@ OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { return {ELEMENTUNARY_BWD_TASK_ID, b}; } -/* ElementScalarUnary */ -OpTaskInvocation init(ElementScalarUnaryAttrs const &attrs) { - OpTaskBinding b; - - b.bind_arg(HANDLE, ff_handle()); - b.bind_arg(ATTRS, attrs); - b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); - - return {ELEMENTUNARY_INIT_TASK_ID, b}; -} - -OpTaskInvocation forward(ElementScalarUnaryAttrs const &attrs) { - OpTaskBinding b; - - b.bind(INPUT, input_tensor(0)); - b.bind(OUTPUT, output_tensor(0)); - - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(PER_DEVICE_STATE, - per_device_op_state()); - - return {ELEMENTUNARY_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(ElementScalarUnaryAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {ELEMENTUNARY_BWD_TASK_ID, b}; -} - static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { @@ -95,10 +65,7 @@ static DeviceSpecific DeviceSpecific per_device_state = acc.create_device_specific( - init_kernel(handle, - {input_shape.dims}, - {output_shape.dims}, - input_shape.data_type)); + init_kernel(input_shape, output_shape, attrs)); return per_device_state; } @@ -114,6 +81,9 @@ static DeviceSpecific static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + + auto &handle = acc.get_argument(HANDLE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = @@ -123,6 +93,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { profiling, "[ElementUnary] forward_time = %.2lfms\n", per_device_state, + attrs, + handle, input, output); } @@ -141,6 +113,9 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); + auto &handle = acc.get_argument(HANDLE); + auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -149,6 +124,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { profiling, "[ElementUnary] backward_time = %.2lfms\n", per_device_state, + attrs, + handle, input, input_grad, output, @@ -202,45 +179,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, return make_metrics(forward_time, backward_time, sync_time, env); } -CostMetrics measure_operator_cost(SimEnvFactory const &sim, - ElementScalarUnaryAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &mv) { - auto env = sim.new_environment(); - - ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); - - SimTaskBinding init_binding; - init_binding.bind_arg(HANDLE, ff_handle()); - init_binding.bind_arg(ATTRS, attrs); - init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); - - auto init_accessor = - env.get_init_accessor(ELEMENTUNARY_INIT_TASK_ID, init_binding); - DeviceSpecific per_device_state = - init_task_impl(init_accessor); - - SimTaskBinding fwd_binding; - fwd_binding.bind(INPUT, input_shape); - fwd_binding.bind(OUTPUT, output_shape); - fwd_binding.bind_arg(PROFILING, settings); - fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); - - SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - - auto fwd_accessor = - env.get_fwd_accessor(ELEMENTUNARY_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = - env.get_bwd_accessor(ELEMENTUNARY_BWD_TASK_ID, bwd_binding); - - float forward_time = forward_task_impl(fwd_accessor).value(); - float backward_time = backward_task_impl(bwd_accessor).value(); - - float sync_time = default_estimate_sync_time(env); - return make_metrics(forward_time, backward_time, sync_time, env); -} - template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); From 317c41e11e279d99e9e9859ff066e87d6dfb367b Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 24 Jan 2024 17:36:31 -0800 Subject: [PATCH 08/18] Remove ElementScalarUnary --- lib/compiler/test/test_dp.cc | 4 ++-- lib/op-attrs/include/op-attrs/get_op_type.h | 1 - lib/op-attrs/include/op-attrs/operator_attrs.h | 2 -- lib/pcg/include/pcg/computation_graph_builder.h | 2 +- lib/pcg/src/computation_graph_builder.cc | 4 ++-- lib/runtime/src/ops/element_unary.h | 10 ---------- lib/runtime/test/src/test_serialization.cc | 1 - lib/substitutions/src/substitution.cc | 2 +- 8 files changed, 6 insertions(+), 20 deletions(-) diff --git a/lib/compiler/test/test_dp.cc b/lib/compiler/test/test_dp.cc index 01e4189839..1878ade0b6 100644 --- a/lib/compiler/test/test_dp.cc +++ b/lib/compiler/test/test_dp.cc @@ -22,8 +22,8 @@ TEST_CASE("optimal_cost") { Node n0 = g.add_node(InputAttrs()); Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2)); - Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0)); - Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1)); + Node n2 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 0)); + Node n3 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 1)); Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1))); Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2)); diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index af13a75720..4558584189 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -14,7 +14,6 @@ OperatorType get_op_type(ConcatAttrs const &); OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); OperatorType get_op_type(ElementBinaryAttrs const &); -OperatorType get_op_type(ElementScalarUnaryAttrs const &); OperatorType get_op_type(ElementUnaryAttrs const &); OperatorType get_op_type(EmbeddingAttrs const &); OperatorType get_op_type(FlatAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index f429facf6f..c4eb78f9dd 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -42,7 +42,6 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 2d65a37a2d..3cea2845b1 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -247,7 +247,7 @@ struct ComputationGraphBuilder float scalar, optional const &name = nullopt); Tensor - element_unary(variant const &, + element_unary(ElementUnaryAttrs const &, Tensor const &input, optional const &name = nullopt); diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index 46a7ea421e..0864b84410 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -34,7 +34,7 @@ static std::string get_default_name(variant const &attrs) { } Tensor ComputationGraphBuilder::element_unary( - variant const &attrs, + ElementUnaryAttrs const &attrs, Tensor const &x, optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(attrs)); @@ -60,7 +60,7 @@ Tensor ComputationGraphBuilder::element_scalar_unary( Tensor const &input, float scalar, optional const &name) { - ElementScalarUnaryAttrs attrs = {op_type, scalar}; + ElementUnaryAttrs attrs = {op_type, scalar}; return this->element_unary(attrs, input, name); } diff --git a/lib/runtime/src/ops/element_unary.h b/lib/runtime/src/ops/element_unary.h index aa393821bd..d41cb65c7b 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/runtime/src/ops/element_unary.h @@ -18,22 +18,12 @@ OpTaskInvocation init(ElementUnaryAttrs const &); OpTaskInvocation forward(ElementUnaryAttrs const &); OpTaskInvocation backward(ElementUnaryAttrs const &); -OpTaskInvocation init(ElementScalarUnaryAttrs const &); -OpTaskInvocation forward(ElementScalarUnaryAttrs const &); -OpTaskInvocation backward(ElementScalarUnaryAttrs const &); - CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ElementUnaryAttrs const &attrs, InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementScalarUnaryAttrs const &attrs, - InputParallelTensorDesc const &input_shape, - ProfilingSettings const &settings, - MachineView const &machine_view); - } // namespace FlexFlow #endif diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index 127b332ccf..44c5bc320b 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -20,7 +20,6 @@ TEST_CASE("Serialization") { CombineAttrs combine_attrs, ConcatAttrs concat_attrs, Conv2DAttrs conv2d_attrs, DropoutAttrs dropout_attrs, ElementBinaryAttrs elem_bin_attrs, - ElementScalarUnaryAttrs elem_scalar_unary_attrs, ElementUnaryAttrs elem_unary_attrs, EmbeddingAttrs embedding_attrs, FlatAttrs flat_attrs, GatherAttrs gather_attrs, InputAttrs input_attrs, LayerNormAttrs layer_norm_attrs, LinearAttrs linear_attrs, diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 797272b13b..e08b813715 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -178,7 +178,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: return Operator( - ElementScalarUnaryAttrs{ + ElementUnaryAttrs{ op_type, get(assignments.at(OperatorAttributeKey::SCALAR))}, nullopt); From d0c704ca2c6740136656c8250ad20b22c8863012 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 7 Feb 2024 16:17:42 -0800 Subject: [PATCH 09/18] Scalar Unary with inheritance --- lib/kernels/src/hip/element_unary_kernels.cpp | 2 +- lib/op-attrs/include/op-attrs/ops/element_unary.h | 10 ++++++++-- lib/op-attrs/src/get_op_type.cc | 3 --- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp index d03a8ddfb8..e79ef57592 100644 --- a/lib/kernels/src/hip/element_unary_kernels.cpp +++ b/lib/kernels/src/hip/element_unary_kernels.cpp @@ -181,7 +181,7 @@ void backward_kernel(ffStream_t stream, stream, m, attrs, - hanlde, + handle, input, input_grad, output, diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index e908fffff6..31f46ff73f 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -10,11 +10,17 @@ namespace FlexFlow { struct ElementUnaryAttrs { req op_type; - req scalar; + float scalar; }; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +struct ElementScalarUnaryAttrs : ElementUnaryAttrs { + req scalar; +} +FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, sclar); +CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc index 7b6bf9eddd..b7a4116092 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/get_op_type.cc @@ -86,9 +86,6 @@ OperatorType get_op_type(RepartitionAttrs const &) { OperatorType get_op_type(ReplicateAttrs const &) { return Op::REPLICATE; } -OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { - return attrs.op; -} OperatorType get_op_type(ReverseAttrs const &attrs) { return Op::REVERSE; } From 95b13e79b90d76f21dc2c544f61f2d4e54c76c56 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 7 Feb 2024 16:18:31 -0800 Subject: [PATCH 10/18] Format --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 5 ++--- lib/pcg/include/pcg/computation_graph_builder.h | 7 +++---- lib/runtime/test/src/test_serialization.cc | 6 +++--- lib/substitutions/src/substitution.cc | 9 ++++----- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 31f46ff73f..10617cd637 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -12,13 +12,12 @@ struct ElementUnaryAttrs { req op_type; float scalar; }; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs : ElementUnaryAttrs { req scalar; -} -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, sclar); +} FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 3cea2845b1..1a0cb6c461 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -246,10 +246,9 @@ struct ComputationGraphBuilder Tensor const &input, float scalar, optional const &name = nullopt); - Tensor - element_unary(ElementUnaryAttrs const &, - Tensor const &input, - optional const &name = nullopt); + Tensor element_unary(ElementUnaryAttrs const &, + Tensor const &input, + optional const &name = nullopt); public: ComputationGraph computation_graph; diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index 44c5bc320b..ac8dcae50b 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -19,9 +19,9 @@ TEST_CASE("Serialization") { BroadcastAttrs broadcast_attrs, CastAttrs cast_attrs, CombineAttrs combine_attrs, ConcatAttrs concat_attrs, Conv2DAttrs conv2d_attrs, DropoutAttrs dropout_attrs, - ElementBinaryAttrs elem_bin_attrs, - ElementUnaryAttrs elem_unary_attrs, EmbeddingAttrs embedding_attrs, - FlatAttrs flat_attrs, GatherAttrs gather_attrs, InputAttrs input_attrs, + ElementBinaryAttrs elem_bin_attrs, ElementUnaryAttrs elem_unary_attrs, + EmbeddingAttrs embedding_attrs, FlatAttrs flat_attrs, + GatherAttrs gather_attrs, InputAttrs input_attrs, LayerNormAttrs layer_norm_attrs, LinearAttrs linear_attrs, MultiHeadAttentionAttrs mha_attrs, NoopAttrs noop_attrs, Pool2DAttrs pool2d_attrs, ReduceAttrs reduce_attrs, diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 61ef0f976b..d59f52054c 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -181,11 +181,10 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::SCALAR_MULTIPLY: case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: - return Operator( - ElementUnaryAttrs{ - op_type, - get(assignments.at(OperatorAttributeKey::SCALAR))}, - nullopt); + return Operator(ElementUnaryAttrs{op_type, + get(assignments.at( + OperatorAttributeKey::SCALAR))}, + nullopt); case Op::EMBEDDING: return Operator( EmbeddingAttrs{ From 3b22ce30b66be63d7f72b9ef908f0bf40c65a9a7 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 7 Feb 2024 19:36:03 -0800 Subject: [PATCH 11/18] Fix --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 10617cd637..8563b320ad 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -17,7 +17,8 @@ CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs : ElementUnaryAttrs { req scalar; -} FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); +}; +FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); } // namespace FlexFlow From ba2053e44fb885ccc0cf9f7f26a8aea55b51caab Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 7 Feb 2024 20:02:31 -0800 Subject: [PATCH 12/18] add optional --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 8563b320ad..c50feed323 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -10,7 +10,7 @@ namespace FlexFlow { struct ElementUnaryAttrs { req op_type; - float scalar; + req> scalar; }; FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); From 2a567d71d1c8bbc9437481376c04568e1f74af82 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Thu, 8 Feb 2024 00:27:18 -0800 Subject: [PATCH 13/18] Add op type --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index c50feed323..6b5f7e19d9 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -16,6 +16,7 @@ FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs : ElementUnaryAttrs { + req op_type; req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); From 2ce13d98743859bac0da15dff78c33668fd5484b Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 21 Feb 2024 14:49:40 -0800 Subject: [PATCH 14/18] Add element scalar unary --- .../include/kernels/element_unary_kernels.h | 9 ++- lib/kernels/src/cuda/element_unary_kernels.cu | 72 +++++++++++-------- lib/op-attrs/include/op-attrs/get_op_type.h | 1 + .../include/op-attrs/get_output_shapes.h | 4 ++ .../include/op-attrs/operator_attrs.h | 2 + .../include/op-attrs/ops/element_unary.h | 5 +- lib/op-attrs/src/get_op_type.cc | 5 +- .../include/pcg/computation_graph_builder.h | 6 +- lib/pcg/src/computation_graph_builder.cc | 18 ++++- lib/runtime/src/ops/element_unary.cc | 16 ++--- lib/runtime/src/ops/element_unary.h | 11 +-- lib/runtime/test/src/test_serialization.cc | 1 + .../include/substitutions/get_attribute.h | 4 +- lib/substitutions/src/operator_attributes.cc | 8 +++ lib/substitutions/src/substitution.cc | 17 +++-- 15 files changed, 121 insertions(+), 58 deletions(-) diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 407ff3ebfe..826be47e32 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -9,6 +9,9 @@ namespace FlexFlow { +using ElementUnaryUnifiedAttrs = + variant; + struct ElementUnaryPerDeviceState { ffTensorDescriptor_t inputTensor, outputTensor; ffActivationDescriptor_t actiDesc; @@ -24,18 +27,18 @@ namespace ElementUnary { ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - ElementUnaryAttrs const &attrs); + ElementUnaryUnifiedAttrs const &attrs); void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, diff --git a/lib/kernels/src/cuda/element_unary_kernels.cu b/lib/kernels/src/cuda/element_unary_kernels.cu index 079aa35172..808ff0ef97 100644 --- a/lib/kernels/src/cuda/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/element_unary_kernels.cu @@ -33,9 +33,17 @@ static bool use_cudnn(OperatorType op_type) { } } +template +optional get_scalar(ElementUnaryAttrs const &attrs) {} + +template +optional get_scalar(ElementScalarUnaryAttrs const &attrs) { + return (T)attrs.scalar; +} + ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - ElementUnaryAttrs const &attrs) { + ElementUnaryUnifiedAttrs const &attrs) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -45,9 +53,11 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); - if (use_cudnn(attrs.op_type)) { + Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs); + + if (use_cudnn(op_type)) { cudnnActivationMode_t mode; - switch (attrs.op_type) { + switch (op_type) { case OP_SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; @@ -81,12 +91,13 @@ template struct ForwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (use_cudnn(attrs.op_type)) { + Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs); + if (use_cudnn(op_type)) { float alpha = 1.0f, beta = 0.0f; checkCUDNN(cudnnActivationForward(handle.dnn, m.actiDesc, @@ -97,15 +108,14 @@ struct ForwardKernel { m.outputTensor, output.get())); } else { + optional scalar = + std::visit([](auto &&arg) { get_scalar(arg); }, attrs); size_t num_elements = input.shape.num_elements(); elewise_unary_forward_kernel<<>>(num_elements, - (T)attrs.scalar, - attrs.op_type, - input.get(), - output.get()); + stream>>>( + num_elements, scalar, op_type, input.get(), output.get()); } } } @@ -114,7 +124,7 @@ template struct BackwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -122,7 +132,8 @@ struct BackwardKernel { GenericTensorAccessorR const &output_grad) { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (use_cudnn(attrs.op_type)) { + Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs); + if (use_cudnn(op_type)) { float alpha = 1.0f; checkCUDNN(cudnnActivationBackward(handle.dnn, m.actiDesc, @@ -137,12 +148,14 @@ struct BackwardKernel { m.inputTensor, input_grad.get())); } else { + optional scalar = + std::visit([](auto &&arg) { get_scalar(arg); }, attrs); size_t num_elements = input.shape.num_elements(); elewise_unary_backward_kernel <<>>( num_elements, - attrs.scalar, - attrs.op_type, + scalar, + op_type, output.get(), output_grad.get(), input.get(), @@ -153,7 +166,7 @@ struct BackwardKernel { void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { @@ -163,7 +176,7 @@ void forward_kernel(ffStream_t stream, void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorR const &input_grad, @@ -181,8 +194,11 @@ void backward_kernel(ffStream_t stream, } template -__global__ void elewise_unary_forward_kernel( - coord_t volume, const T scalar, OperatorType type, T const *in, T *out) { +__global__ void elewise_unary_forward_kernel(coord_t volume, + optional const scalar, + OperatorType type, + T const *in, + T *out) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { case OP_EXP: { @@ -194,19 +210,19 @@ __global__ void elewise_unary_forward_kernel( break; } case OP_SCALAR_MULTIPLY: { - out[i] = in[i] * scalar; + out[i] = in[i] * scalar.value(); break; } case OP_SCALAR_ADD: { - out[i] = in[i] + scalar; + out[i] = in[i] + scalar.value(); break; } case OP_SCALAR_SUB: { - out[i] = in[i] - scalar; + out[i] = in[i] - scalar.value(); break; } case OP_SCALAR_TRUE_DIV: { - out[i] = in[i] / scalar; + out[i] = in[i] / scalar.value(); break; } case OP_GELU: { @@ -218,7 +234,7 @@ __global__ void elewise_unary_forward_kernel( break; } case OP_POW: { - out[i] = (T)(powf(in[i], scalar)); + out[i] = (T)(powf(in[i], scalar.value())); break; } case OP_SIN: { @@ -237,7 +253,7 @@ __global__ void elewise_unary_forward_kernel( template __global__ void elewise_unary_backward_kernel(coord_t volume, - const T scalar, + optional const scalar, OperatorType type, T const *output, T const *output_grad, @@ -255,7 +271,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_SCALAR_MULTIPLY: { - input_grad[i] += output_grad[i] * scalar; + input_grad[i] += output_grad[i] * scalar.value(); break; } case OP_SCALAR_ADD: { @@ -267,7 +283,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_SCALAR_TRUE_DIV: { - input_grad[i] += output_grad[i] / scalar; + input_grad[i] += output_grad[i] / scalar.value(); break; } case OP_GELU: { @@ -283,8 +299,8 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, break; } case OP_POW: { - input_grad[i] = - (T)(output_grad[i] * scalar * powf(input[i], scalar - 1)); + input_grad[i] = (T)(output_grad[i] * scalar.value() * + powf(input[i], scalar.value() - 1)); break; } case OP_SIN: { diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 4558584189..421c464843 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -15,6 +15,7 @@ OperatorType get_op_type(Conv2DAttrs const &); OperatorType get_op_type(DropoutAttrs const &); OperatorType get_op_type(ElementBinaryAttrs const &); OperatorType get_op_type(ElementUnaryAttrs const &); +OperatorType get_op_type(ElementScalarUnaryAttrs const &); OperatorType get_op_type(EmbeddingAttrs const &); OperatorType get_op_type(FlatAttrs const &); OperatorType get_op_type(GatherAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 3daa97ba3e..5f78ec2d3f 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -130,6 +130,8 @@ ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, @@ -238,6 +240,8 @@ bool is_valid_internal(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); bool is_valid_internal(GatherAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index c4eb78f9dd..a7ba84624c 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -43,6 +43,7 @@ using SharedOperatorAttrs = variant::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); +static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 6b5f7e19d9..cccc3e2a3e 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -10,12 +10,11 @@ namespace FlexFlow { struct ElementUnaryAttrs { req op_type; - req> scalar; }; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type, scalar); +FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); -struct ElementScalarUnaryAttrs : ElementUnaryAttrs { +struct ElementScalarUnaryAttrs { req op_type; req scalar; }; diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc index b7a4116092..3fa401b647 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/get_op_type.cc @@ -27,7 +27,10 @@ OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; } OperatorType get_op_type(ElementUnaryAttrs const &attrs) { - return attrs.op; + return attrs.op_type; +} +OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { + return attrs.op_type; } OperatorType get_op_type(EmbeddingAttrs const &) { return Op::EMBEDDING; diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 1a0cb6c461..5fc58079a2 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -249,9 +249,11 @@ struct ComputationGraphBuilder Tensor element_unary(ElementUnaryAttrs const &, Tensor const &input, optional const &name = nullopt); + Tensor element_scalar_unary(ElementScalarUnaryAttrs const &attrs, + Tensor const &x, + optional const &maybe_name) -public: - ComputationGraph computation_graph; + public : ComputationGraph computation_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index d62403187f..55e7111873 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -78,6 +78,20 @@ Tensor ComputationGraphBuilder::element_unary( return this->add_layer(layer, {input}, {}, output_shape); } +Tensor ComputationGraphBuilder::element_scalar_unary( + ElementScalarUnaryAttrs const &attrs, + Tensor const &x, + optional const &maybe_name) { + std::string name = maybe_name.value_or(get_default_name(attrs)); + + Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + Layer layer = {widen(attrs), name}; + TensorShape output_shape = get_output_shape(attrs, input); + + return this->add_layer(layer, {input}, {}, output_shape); +} + Tensor ComputationGraphBuilder::element_unary(OperatorType op_type, Tensor const &input, @@ -91,8 +105,8 @@ Tensor ComputationGraphBuilder::element_scalar_unary( Tensor const &input, float scalar, optional const &name) { - ElementUnaryAttrs attrs = {op_type, scalar}; - return this->element_unary(attrs, input, name); + ElementScalarUnaryAttrs attrs = {op_type, scalar}; + return this->element_scalar_unary(attrs, input, name); } Tensor ComputationGraphBuilder::element_binary( diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/runtime/src/ops/element_unary.cc index 80130147c7..f41a8b3551 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/runtime/src/ops/element_unary.cc @@ -24,7 +24,7 @@ enum Slots { }; /* ElementUnary */ -OpTaskInvocation init(ElementUnaryAttrs const &attrs) { +OpTaskInvocation init(ElementUnaryUnifiedAttrs const &attrs) { OpTaskBinding b; b.bind_arg(HANDLE, ff_handle()); @@ -34,7 +34,7 @@ OpTaskInvocation init(ElementUnaryAttrs const &attrs) { return {ELEMENTUNARY_INIT_TASK_ID, b}; } -OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { +OpTaskInvocation forward(ElementUnaryUnifiedAttrs const &attrs) { OpTaskBinding b; b.bind(INPUT, input_tensor(0)); @@ -47,7 +47,7 @@ OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { return {ELEMENTUNARY_FWD_TASK_ID, b}; } -OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { +OpTaskInvocation backward(ElementUnaryUnifiedAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); return {ELEMENTUNARY_BWD_TASK_ID, b}; @@ -56,7 +56,7 @@ OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); + auto const &attrs = acc.get_argument(ATTRS); ProfilingSettings profiling = acc.get_argument(PROFILING); PerDeviceFFHandle handle = acc.get_argument(HANDLE); ParallelTensorShape input_shape = @@ -81,7 +81,7 @@ static DeviceSpecific static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); - auto const &attrs = acc.get_argument(ATTRS); + auto const &attrs = acc.get_argument(ATTRS); auto &handle = acc.get_argument(HANDLE); @@ -113,7 +113,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); - auto const &attrs = acc.get_argument(ATTRS); + auto const &attrs = acc.get_argument(ATTRS); auto &handle = acc.get_argument(HANDLE); auto per_device_state = @@ -141,7 +141,7 @@ static void backward_task(Task const *task, } CostMetrics measure_operator_cost(SimEnvFactory const &sim, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &mv) { @@ -183,7 +183,7 @@ template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); init.add_arg_slot(INPUT_SHAPE); - init.add_arg_slot(ATTRS); + init.add_arg_slot(ATTRS); init.add_unchecked_arg_slot(HANDLE); init.add_return_value(); diff --git a/lib/runtime/src/ops/element_unary.h b/lib/runtime/src/ops/element_unary.h index d41cb65c7b..f44efc28db 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/runtime/src/ops/element_unary.h @@ -7,6 +7,9 @@ namespace FlexFlow { +using ElementUnaryUnifiedAttrs = + variant; + template <> void register_task(); template <> @@ -14,12 +17,12 @@ void register_task(); template <> void register_task(); -OpTaskInvocation init(ElementUnaryAttrs const &); -OpTaskInvocation forward(ElementUnaryAttrs const &); -OpTaskInvocation backward(ElementUnaryAttrs const &); +OpTaskInvocation init(ElementUnaryUnifiedAttrs const &); +OpTaskInvocation forward(ElementUnaryUnifiedAttrs const &); +OpTaskInvocation backward(ElementUnaryUnifiedAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - ElementUnaryAttrs const &attrs, + ElementUnaryUnifiedAttrs const &attrs, InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index ac8dcae50b..d80808b7fb 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -20,6 +20,7 @@ TEST_CASE("Serialization") { CombineAttrs combine_attrs, ConcatAttrs concat_attrs, Conv2DAttrs conv2d_attrs, DropoutAttrs dropout_attrs, ElementBinaryAttrs elem_bin_attrs, ElementUnaryAttrs elem_unary_attrs, + ElementScalarUnaryAttrs elem_scalar_unary_attrs, EmbeddingAttrs embedding_attrs, FlatAttrs flat_attrs, GatherAttrs gather_attrs, InputAttrs input_attrs, LayerNormAttrs layer_norm_attrs, LinearAttrs linear_attrs, diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index f35145133e..50c4108a67 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -25,9 +25,7 @@ optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey); optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ElementUnaryAttrs const &p, +optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index eef0464c42..3922b091a7 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -85,6 +85,14 @@ optional get_attribute(ElementUnaryAttrs const &p, } } +optional get_attribute(ElementScalarUnaryAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + default: + return nullopt; + } +} + optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index d59f52054c..72c9248e6c 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -181,10 +181,19 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::SCALAR_MULTIPLY: case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: - return Operator(ElementUnaryAttrs{op_type, - get(assignments.at( - OperatorAttributeKey::SCALAR))}, - nullopt); + return Operator( + ElementScalarUnaryAttrs{ + op_type, + get(assignments.at(OperatorAttributeKey::SCALAR))}, + nullopt); + case Op::EXP: + case Op::IDENTITY: + case Op::GELU: + case Op::RSQRT: + case Op::POW: + case Op::SIN: + case Op::COS: + return Operator(ElementUnaryAttrs{op_type}, nullopt); case Op::EMBEDDING: return Operator( EmbeddingAttrs{ From c85c8dc5fbbec59931468d8de9a7487b21fcfc6d Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Thu, 22 Feb 2024 19:38:16 -0800 Subject: [PATCH 15/18] Fix format and build --- lib/pcg/include/pcg/computation_graph_builder.h | 5 +++-- lib/pcg/src/computation_graph_builder.cc | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 5fc58079a2..ae937c590d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -251,9 +251,10 @@ struct ComputationGraphBuilder optional const &name = nullopt); Tensor element_scalar_unary(ElementScalarUnaryAttrs const &attrs, Tensor const &x, - optional const &maybe_name) + optional const &maybe_name); - public : ComputationGraph computation_graph; +public: + ComputationGraph computation_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index 55e7111873..9f8e930919 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -72,7 +72,7 @@ Tensor ComputationGraphBuilder::element_unary( Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {widen(attrs), name}; + Layer layer = {attrs, name}; TensorShape output_shape = get_output_shape(attrs, input); return this->add_layer(layer, {input}, {}, output_shape); @@ -86,7 +86,7 @@ Tensor ComputationGraphBuilder::element_scalar_unary( Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {widen(attrs), name}; + Layer layer = {attrs, name}; TensorShape output_shape = get_output_shape(attrs, input); return this->add_layer(layer, {input}, {}, output_shape); From 10a95061b124d502290f97938208111ea27dc580 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 6 Mar 2024 17:15:34 -0800 Subject: [PATCH 16/18] Update lib/op-attrs/include/op-attrs/ops/element_unary.h --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index cccc3e2a3e..dc7424a490 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -15,7 +15,7 @@ FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs { - req op_type; +Op op_type; req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); From 8b31b6887a95c6970dddfc8b847438ed512a77bd Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 6 Mar 2024 17:16:39 -0800 Subject: [PATCH 17/18] Update lib/op-attrs/include/op-attrs/ops/element_unary.h --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index dc7424a490..9a5a4dcc1c 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -15,7 +15,7 @@ FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs { -Op op_type; + Op op_type; req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); From 8425ed546a034e09942ba010592d5f8aa4185961 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Wed, 6 Mar 2024 17:16:53 -0800 Subject: [PATCH 18/18] Update lib/op-attrs/include/op-attrs/ops/element_unary.h --- lib/op-attrs/include/op-attrs/ops/element_unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 9a5a4dcc1c..5e19b81c8c 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -15,7 +15,7 @@ FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); struct ElementScalarUnaryAttrs { - Op op_type; + Op op_type; req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar);