From 804f3a79aea1957225f50033a59c22cf126f8979 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Wed, 9 Dec 2020 15:23:06 -0800
Subject: [PATCH] Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM

Signed-off-by: Serge Panev
---
 .../subgraph/tensorrt/nnvm_to_onnx-inl.h |  69 +++-
 .../subgraph/tensorrt/nnvm_to_onnx.cc    | 389 ++++++++++++++----
 2 files changed, 352 insertions(+), 106 deletions(-)

diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index d444e7a6239a..be6ebd05350b 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -47,7 +47,8 @@ using namespace nnvm;
 using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
 
-std::unordered_map<std::string, TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs,
     const nnvm::IndexedGraph& ig);
 
 std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
     const nnvm::IndexedGraph& ig);
@@ -70,7 +71,12 @@ void ConvertOutput(GraphProto* graph_proto,
                    const std::string& node_name,
                    const ShapeVector& shapes,
                    const DTypeVector& dtypes,
                    const nnvm::IndexedGraph &ig);
 
-typedef void (*ConverterFunction)(NodeProto *node_proto,
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+                                 const array_view<IndexedGraph::NodeEntry>& inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name);
+
+typedef void (*ConverterFunction)(GraphProto *graph_proto,
+                                  const std::string& node_name,
                                   const NodeAttrs &attrs,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -84,88 +90,112 @@ void ConvDeconvConvertHelper(NodeProto *node_proto,
                              ConvDeconvType type);
 
 // Forward declarations
-void ConvertIdentity(NodeProto* node_proto,
+void ConvertIdentity(GraphProto *graph_proto,
+                     const std::string& node_name,
                      const NodeAttrs &attrs,
                      const nnvm::IndexedGraph& ig,
                      const array_view<IndexedGraph::NodeEntry> &inputs);
 
 void ConvertConvolution(
-    NodeProto *node_proto,
+    GraphProto *graph_proto,
+    const std::string& node_name,
     const NodeAttrs &attrs,
     const nnvm::IndexedGraph &ig,
     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDeconvolution(NodeProto *node_proto,
+void ConvertDeconvolution(GraphProto *graph_proto,
+                          const std::string& node_name,
                           const NodeAttrs &attrs,
                           const nnvm::IndexedGraph &ig,
                           const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPooling(NodeProto *node_proto,
+void ConvertPooling(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertRelu(NodeProto *node_proto,
+void ConvertRelu(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertActivation(NodeProto *node_proto,
+void ConvertActivation(GraphProto *graph_proto,
+                       const std::string& node_name,
                        const NodeAttrs &attrs,
                        const nnvm::IndexedGraph &ig,
                        const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFullyConnected(NodeProto *node_proto,
+void ConvertFullyConnected(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertSoftmaxOutput(NodeProto *node_proto,
+
+void ConvertSlice(GraphProto *graph_proto,
+                  const std::string& node_name,
+                  const NodeAttrs &attrs,
+                  const nnvm::IndexedGraph &ig,
+                  const array_view<IndexedGraph::NodeEntry> &inputs);
+
+void ConvertSoftmaxOutput(GraphProto *graph_proto,
+                          const std::string& node_name,
                           const NodeAttrs &attrs,
                           const nnvm::IndexedGraph &ig,
                           const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertFlatten(NodeProto *node_proto,
+void ConvertFlatten(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertDropout(NodeProto *node_proto,
+void ConvertDropout(GraphProto *graph_proto,
+                    const std::string& node_name,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
                     const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertBatchNorm(NodeProto *node_proto,
+void ConvertBatchNorm(GraphProto *graph_proto,
+                      const std::string& node_name,
                       const NodeAttrs &attrs,
                       const nnvm::IndexedGraph &ig,
                       const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseAdd(NodeProto *node_proto,
+void ConvertElementwiseAdd(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseMul(NodeProto *node_proto,
+void ConvertElementwiseMul(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertElementwiseSub(NodeProto *node_proto,
+void ConvertElementwiseSub(GraphProto *graph_proto,
+                           const std::string& node_name,
                            const NodeAttrs &attrs,
                            const nnvm::IndexedGraph &ig,
                            const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertConcatenate(NodeProto *node_proto,
+void ConvertConcatenate(GraphProto *graph_proto,
+                        const std::string& node_name,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertClip(NodeProto *node_proto,
+void ConvertClip(GraphProto *graph_proto,
+                 const std::string& node_name,
                  const NodeAttrs &attrs,
                  const nnvm::IndexedGraph &ig,
                  const array_view<IndexedGraph::NodeEntry> &inputs);
 
-void ConvertPad(NodeProto* node_proto,
+void ConvertPad(GraphProto *graph_proto,
+                const std::string& node_name,
                 const NodeAttrs & attrs,
                 const nnvm::IndexedGraph &ig,
                 const array_view<IndexedGraph::NodeEntry> &inputs);
@@ -190,6 +220,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"Pad", ConvertPad},
   {"Pooling", ConvertPooling},
   {"relu", ConvertRelu},
+  {"slice", ConvertSlice},
   {"SoftmaxOutput", ConvertSoftmaxOutput}
 };
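With the registry above, each converter now owns node creation: instead of filling in a single pre-allocated NodeProto, it receives the whole GraphProto and may emit one or many ONNX nodes. A minimal sketch of a converter written against the new signature (the op "Foo" and the converter name are hypothetical, shown only to illustrate the calling convention):

    // Hypothetical 1:1 converter under the new 1:many signature.
    void ConvertFoo(GraphProto *graph_proto,
                    const std::string& node_name,
                    const NodeAttrs &attrs,
                    const nnvm::IndexedGraph &ig,
                    const array_view<IndexedGraph::NodeEntry> &inputs) {
      // The converter allocates its own node(s) in the graph...
      NodeProto* node_proto = graph_proto->add_node();
      node_proto->set_name(node_name);
      node_proto->set_op_type("Foo");
      // ...and wires graph edges itself via the shared helper.
      DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
    }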
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 4f80d277cad8..cdc715176dd0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -130,8 +130,6 @@ std::string ConvertNnvmGraphToOnnx(
       }  // is_placeholder
     } else {
       // It's an op, rather than a "variable" (constant or placeholder)
-      NodeProto* node_proto = graph_proto->add_node();
-      node_proto->set_name(node_name);
       if (converter_map.count(op->name) == 0) {
         LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
                    << node_name << ") "
@@ -140,19 +138,7 @@ std::string ConvertNnvmGraphToOnnx(
       // Find function ptr to a converter based on the op name, and invoke the converter. This
       // looks unsafe because find may not succeed, but it does because we're in the operator
       // logic after testing that this node name does not represent a variable.
-      converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs);
-      // Add all inputs to the current node (i.e. add graph edges)
-      for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) {
-        std::string in_node_name = ig[entry.node_id].source->attrs.name;
-        // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
-        // hacky way to do it than name matching.
-        if (in_node_name.find("label") != std::string::npos) {
-          continue;
-        }
-        node_proto->add_input(in_node_name);
-      }
-      // The node's output will have the same name as the node name.
-      node_proto->add_output(node_name);
+      converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, node.inputs);
       // See if the current node is an output node
       auto out_iter = output_lookup.find(node_name);
       // We found an output
@@ -171,16 +157,113 @@ std::string ConvertNnvmGraphToOnnx(
   return serialized_onnx_graph;
 }
 
-void ConvertIdentity(NodeProto* node_proto, const NodeAttrs& attrs,
-                     const nnvm::IndexedGraph& /*ig*/,
-                     const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void DefaultConnectInputsOutputs(NodeProto *node_proto,
+                                 const array_view<IndexedGraph::NodeEntry>& inputs,
+                                 const nnvm::IndexedGraph& ig,
+                                 const std::string& node_name) {
+  for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
+    std::string in_node_name = ig[entry.node_id].source->attrs.name;
+    // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
+    // hacky way to do it than name matching.
+    if (in_node_name.find("label") != std::string::npos) {
+      continue;
+    }
+    node_proto->add_input(in_node_name);
+  }
+  // The node's output will have the same name as the node name.
+  node_proto->add_output(node_name);
+}
+
+TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& size,
+                                const std::string& name, const TensorProto_DataType& dtype) {
+  TensorProto* const initializer_proto = graph_proto->add_initializer();
+  initializer_proto->set_name(name);
+  initializer_proto->set_data_type(dtype);
+  initializer_proto->add_dims(static_cast<int64>(size));
+
+  ValueInfoProto* const input_proto = graph_proto->add_input();
+  input_proto->set_name(name);
+  auto var = input_proto->mutable_type()->mutable_tensor_type();
+  var->set_elem_type(dtype);
+  var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
+  return initializer_proto;
+}
+
+// Keep for when ONNX version will be updated
+/*
+void ConvertSlice(GraphProto* const graph_proto, const Node* node, const Graph& g) {
+  const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
+  int64 nb_slices = static_cast<int64>(params.begin.ndim());
+
+  // starts
+  auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_starts",
+                                  TensorProto_DataType_INT64);
+  for (auto& opt : params.begin) {
+    if (opt.has_value()) {
+      init_starts->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_starts->add_int64_data(static_cast<int64>(0));
+    }
+  }
+
+  // ends
+  auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_ends",
+                                TensorProto_DataType_INT64);
+  for (auto& opt : params.end) {
+    if (opt.has_value()) {
+      init_ends->add_int64_data(static_cast<int64>(opt.value()));
+    } else {
+      init_ends->add_int64_data(static_cast<int64>(INT_MAX));
+    }
+  }
+
+  // axes
+  auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_axes",
+                                TensorProto_DataType_INT64);
+  for (int64_t i = 0; i < nb_slices; ++i) {
+    init_axes->add_int64_data(static_cast<int64>(i));
+  }
+
+  // slice node
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node->attrs.name);
+  node_proto->set_op_type("Slice");
+  node_proto->add_input(node->inputs[0].node->attrs.name);
+  node_proto->add_input(node->attrs.name + "_starts");
+  node_proto->add_input(node->attrs.name + "_ends");
+  node_proto->add_input(node->attrs.name + "_axes");
+
+  // steps
+  if (params.step.ndim() != 0) {
+    auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_steps",
+                                   TensorProto_DataType_INT64);
+    for (auto& opt : params.step) {
+      if (opt.has_value()) {
+        init_steps->add_int64_data(static_cast<int64>(opt.value()));
+      } else {
+        init_steps->add_int64_data(static_cast<int64>(1));
+      }
+    }
+    node_proto->add_input(node->attrs.name + "_steps");
+  }
+
+  node_proto->add_output(node->attrs.name);
+}
+*/
+
+void ConvertIdentity(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                     const nnvm::IndexedGraph& ig,
+                     const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Identity");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 template <class ConvDeconvParam>
-void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
-                             const nnvm::IndexedGraph& /*ig*/,
-                             const array_view<IndexedGraph::NodeEntry>& /*input*/,
+void ConvDeconvConvertHelper(NodeProto *node_proto, const NodeAttrs& attrs,
+                             const nnvm::IndexedGraph& ig,
+                             const array_view<IndexedGraph::NodeEntry>& inputs,
                              const ConvDeconvParam& param,
                              ConvDeconvType type) {
   if (type == ConvDeconvType::Convolution) {
@@ -240,25 +323,36 @@ void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
   }
 }
 
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertConvolution(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
                         const nnvm::IndexedGraph& ig,
                         const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
                           ConvDeconvType::Convolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertConvolution
 
-void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+void ConvertDeconvolution(GraphProto *graph_proto, const std::string& node_name,
+                          const NodeAttrs& attrs,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& deconv_param = nnvm::get<DeconvolutionParam>(attrs.parsed);
   ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
                           ConvDeconvType::Deconvolution);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertDeconvolution
 
-void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPooling(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& pooling_param = nnvm::get<PoolingParam>(attrs.parsed);
 
   const mxnet::TShape kernel = pooling_param.kernel;
@@ -275,6 +369,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
     } else {
       LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: " << attrs.name;
     }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
     return;
   }
@@ -329,17 +424,24 @@
   } else {
     count_include_pad->set_i(1);
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }  // end ConvertPooling
 
-void ConvertRelu(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertRelu(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& /*attrs*/,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Relu");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs,
-                       const nnvm::IndexedGraph& /*ig*/,
-                       const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertActivation(GraphProto *graph_proto, const std::string& node_name,
+                       const NodeAttrs& attrs,
+                       const nnvm::IndexedGraph& ig,
+                       const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& act_param = nnvm::get<ActivationParam>(attrs.parsed);
   std::string act_type;
   switch (act_param.act_type) {
@@ -361,42 +463,120 @@
   }
 
   node_proto->set_op_type(act_type);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertFullyConnected(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& attrs,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
   const auto& act_param = nnvm::get<FullyConnectedParam>(attrs.parsed);
-  if (act_param.no_bias) {
-    node_proto->set_op_type("MatMul");
+  // ONNX spec doesn't support GEMMs with inputs of different dims, so we need to replace it
+  // by Transpose+MatMul+Add
+  if (!act_param.flatten && !act_param.no_bias) {
+    NodeProto* transpose_node_proto = graph_proto->add_node();
+    NodeProto* matmul_node_proto = graph_proto->add_node();
+    NodeProto* add_node_proto = graph_proto->add_node();
+    transpose_node_proto->set_name(node_name+"_Transpose");
+    matmul_node_proto->set_name(node_name+"_MatMul");
+    add_node_proto->set_name(node_name+"_Add");
+
+    transpose_node_proto->set_op_type("Transpose");
+    matmul_node_proto->set_op_type("MatMul");
+    add_node_proto->set_op_type("Add");
+
+    std::string input_node_name = ig[inputs[op::conv::kData].node_id].source->attrs.name;
+    std::string weight_node_name = ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
+    std::string bias_node_name = ig[inputs[op::conv::kBias].node_id].source->attrs.name;
+
+    transpose_node_proto->add_input(weight_node_name);
+    transpose_node_proto->add_output(node_name+"_Transpose");
+
+    matmul_node_proto->add_input(input_node_name);
+    matmul_node_proto->add_input(node_name+"_Transpose");
+    matmul_node_proto->add_output(node_name+"_MatMul");
+
+    add_node_proto->add_input(node_name+"_MatMul");
+    add_node_proto->add_input(bias_node_name);
+    // Add's output is the output of the Transpose+MatMul+Add subgraph
+    add_node_proto->add_output(node_name);
   } else {
-    node_proto->set_op_type("Gemm");
-
-    AttributeProto* const alpha = node_proto->add_attribute();
-    alpha->set_name("alpha");
-    alpha->set_type(AttributeProto::FLOAT);
-    alpha->set_f(1.0f);
-
-    AttributeProto* const beta = node_proto->add_attribute();
-    beta->set_name("beta");
-    beta->set_type(AttributeProto::FLOAT);
-    beta->set_f(1.0f);
-
-    AttributeProto* const transA = node_proto->add_attribute();
-    transA->set_name("transA");
-    transA->set_type(AttributeProto::INT);
-    transA->set_i(0);
+    NodeProto* node_proto = graph_proto->add_node();
+    node_proto->set_name(node_name);
+    if (act_param.no_bias) {
+      node_proto->set_op_type("MatMul");
+    } else {
+      node_proto->set_op_type("Gemm");
+
+      AttributeProto* const alpha = node_proto->add_attribute();
+      alpha->set_name("alpha");
+      alpha->set_type(AttributeProto::FLOAT);
+      alpha->set_f(1.0f);
+
+      AttributeProto* const beta = node_proto->add_attribute();
+      beta->set_name("beta");
+      beta->set_type(AttributeProto::FLOAT);
+      beta->set_f(1.0f);
+
+      AttributeProto* const transA = node_proto->add_attribute();
+      transA->set_name("transA");
+      transA->set_type(AttributeProto::INT);
+      transA->set_i(0);
+
+      AttributeProto* const transB = node_proto->add_attribute();
+      transB->set_name("transB");
+      transB->set_type(AttributeProto::INT);
+      transB->set_i(1);
+    }
+    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
+  }
+}
-    AttributeProto* const transB = node_proto->add_attribute();
-    transB->set_name("transB");
-    transB->set_type(AttributeProto::INT);
-    transB->set_i(1);
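As a shape check on the branch above: with flatten=False, MXNet's FullyConnected applies the projection to the last axis only, which ONNX Gemm (defined on 2-D inputs) cannot express, while a batched MatMul can. The shapes below are illustrative:

    // X: (B, T, C_in), W: (C_out, C_in), b: (C_out,)
    // Transpose(W)   -> W^T : (C_in, C_out)
    // MatMul(X, W^T) -> Y   : (B, T, C_out)  MatMul accepts rank > 2
    // Add(Y, b)      -> out : (B, T, C_out)  bias broadcasts over leading axes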
+void ConvertSlice(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                  const nnvm::IndexedGraph& ig,
+                  const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  const auto& params = nnvm::get<SliceParam>(attrs.parsed);
+  node_proto->set_op_type("Slice");
+
+  // starts
+  AttributeProto* const starts = node_proto->add_attribute();
+  starts->set_name("starts");
+  starts->set_type(AttributeProto::INTS);
+
+  // ends
+  AttributeProto* const ends = node_proto->add_attribute();
+  ends->set_name("ends");
+  ends->set_type(AttributeProto::INTS);
+
+  // axes
+  AttributeProto* const axes = node_proto->add_attribute();
+  axes->set_name("axes");
+  axes->set_type(AttributeProto::INTS);
+
+  for (int64_t i = 1; i < params.begin.ndim(); ++i) {
+    if (params.begin[i].has_value()) {
+      starts->add_ints(static_cast<int64>(params.begin[i].value()));
+    } else {
+      starts->add_ints(static_cast<int64>(0));
+    }
+    if (params.end[i].has_value()) {
+      ends->add_ints(static_cast<int64>(params.end[i].value()));
+    } else {
+      ends->add_ints(static_cast<int64>(INT_MAX));
+    }
+    axes->add_ints(static_cast<int64>(i));
   }
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
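Note that the attribute loop above starts at i = 1, skipping dimension 0 (presumably left untouched as the batch axis on the TensorRT path). A worked mapping, with illustrative values:

    // MXNet:  slice(data, begin=(None, 2, None), end=(None, 5, None))
    // ONNX Slice attributes emitted (axis 0 skipped):
    //   starts = [2, 0]          missing begin -> 0
    //   ends   = [5, INT_MAX]    missing end   -> INT_MAX
    //   axes   = [1, 2]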
 
-void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                          const nnvm::IndexedGraph& /*ig*/,
-                          const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertSoftmaxOutput(GraphProto *graph_proto, const std::string& node_name,
+                          const NodeAttrs& /*attrs*/,
+                          const nnvm::IndexedGraph& ig,
+                          const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Softmax");
 
   // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its
@@ -406,11 +586,16 @@
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(1);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+
+void ConvertFlatten(GraphProto *graph_proto, const std::string& node_name,
+                    const NodeAttrs& /*attrs*/,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Flatten");
 
   // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
@@ -420,11 +605,15 @@
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(1);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
-                      const nnvm::IndexedGraph& /*ig*/,
-                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertBatchNorm(GraphProto *graph_proto, const std::string& node_name,
+                      const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& ig,
+                      const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("BatchNormalization");
   const auto& param = nnvm::get<BatchNormParam>(attrs.parsed);
@@ -445,29 +634,45 @@
   // (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
   // disable this for MXNet's BatchNorm.
   spatial->set_i(0);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseAdd(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Add");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseSub(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseSub(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Sub");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertElementwiseMul(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
-                           const nnvm::IndexedGraph& /*ig*/,
-                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertElementwiseMul(GraphProto *graph_proto, const std::string& node_name,
+                           const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& ig,
+                           const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Mul");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs,
-                        const nnvm::IndexedGraph& /*ig*/,
-                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertConcatenate(GraphProto *graph_proto, const std::string& node_name,
+                        const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
   node_proto->set_op_type("Concat");
   node_proto->set_name(attrs.name);
@@ -476,6 +681,7 @@
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
   axis->set_i(static_cast<int64_t>(_param.dim));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 inline TensorProto_DataType ConvertDType(int dtype) {
@@ -630,9 +836,11 @@ void ConvertOutput(
   }
 }
 
-void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
-                 const nnvm::IndexedGraph& /*ig*/,
-                 const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertClip(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                 const nnvm::IndexedGraph& ig,
+                 const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<ClipParam>(attrs.parsed);
   node_proto->set_op_type("Clip");
@@ -648,11 +856,14 @@ void ConvertClip(NodeProto* node_proto, const NodeAttrs& attrs,
   a_min->set_name("min");
   a_min->set_type(AttributeProto::FLOAT);
   a_min->set_f(static_cast<float>(param.a_min));
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
-                const nnvm::IndexedGraph& /*ig*/,
-                const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertPad(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                const nnvm::IndexedGraph& ig,
+                const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   const auto& param = nnvm::get<PadParam>(attrs.parsed);
   node_proto->set_op_type("Pad");
@@ -694,12 +905,16 @@ void ConvertPad(NodeProto* node_proto, const NodeAttrs& attrs,
   value->set_name("value");
   value->set_type(AttributeProto::FLOAT);
   value->set_f(param.constant_value);
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
-void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
-                    const nnvm::IndexedGraph& /*ig*/,
-                    const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+void ConvertDropout(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
+                    const nnvm::IndexedGraph& ig,
+                    const array_view<IndexedGraph::NodeEntry>& inputs) {
+  NodeProto* node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
   node_proto->set_op_type("Dropout");
+  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
 void PreprocessBatchNorm(const NodeAttrs &attrs,