From 4aaf27c673b77884d75cf857f46ebdb2394ed30f Mon Sep 17 00:00:00 2001 From: Wang Date: Wed, 30 May 2018 17:05:41 -0700 Subject: [PATCH 01/22] Add SSD tutorial for CPU --- nnvm/include/nnvm/top/nn.h | 58 +++++++ nnvm/python/nnvm/compiler/build_module.py | 42 ++++- nnvm/python/nnvm/compiler/graph_util.py | 36 +++++ nnvm/python/nnvm/frontend/mxnet.py | 36 +++++ nnvm/python/nnvm/top/attr_dict.py | 15 ++ nnvm/python/nnvm/top/transform.py | 16 ++ nnvm/python/nnvm/top/vision.py | 63 ++++++++ nnvm/src/top/tensor/transform.cc | 25 +++ nnvm/src/top/vision/nms.cc | 76 +++++++++ nnvm/src/top/vision/ssd/mutibox_op.cc | 147 ++++++++++++++++++ nnvm/tests/python/compiler/test_top_level4.py | 115 ++++++++++++++ tutorials/nnvm/deploy_ssd.py | 115 ++++++++++++++ 12 files changed, 743 insertions(+), 1 deletion(-) create mode 100644 nnvm/src/top/vision/nms.cc create mode 100644 nnvm/src/top/vision/ssd/mutibox_op.cc create mode 100644 tutorials/nnvm/deploy_ssd.py diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index bee6137829c5..8575ebe63fdc 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -319,6 +319,64 @@ struct LayoutTransformParam : public dmlc::Parameter { } }; +struct MultiBoxPriorParam : public dmlc::Parameter { + Tuple sizes; + Tuple ratios; + Tuple steps; + Tuple offsets; + bool clip; + + DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { + DMLC_DECLARE_FIELD(sizes).set_default(Tuple({1.0})) + .describe("List of sizes of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(ratios).set_default(Tuple({1.0})) + .describe("List of aspect ratios of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(steps).set_default(Tuple({-1.0, -1.0})) + .describe("Priorbox step across y and x, -1 for auto calculation."); + DMLC_DECLARE_FIELD(offsets).set_default(Tuple({0.5, 0.5})) + .describe("Priorbox center offsets, y and x respectively."); + DMLC_DECLARE_FIELD(clip).set_default(false) + .describe("Whether to clip out-of-boundary boxes."); + } +}; + 
+struct MultiBoxDetectionParam : public dmlc::Parameter { + bool clip; + float threshold; + float nms_threshold; + bool force_suppress; + int nms_topk; + Tuple variances; + DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { + DMLC_DECLARE_FIELD(clip).set_default(true) + .describe("Clip out-of-boundary boxes."); + DMLC_DECLARE_FIELD(threshold).set_default(0.01) + .describe("Threshold to be a positive prediction."); + DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) + .describe("Non-maximum suppression threshold."); + DMLC_DECLARE_FIELD(force_suppress).set_default(false) + .describe("Suppress all detections regardless of class_id."); + DMLC_DECLARE_FIELD(variances).set_default(Tuple{0.1, 0.1, 0.2, 0.2}) + .describe("Variances to be decoded from box regression output."); + DMLC_DECLARE_FIELD(nms_topk).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + } +}; + +struct NMSParam : public dmlc::Parameter { + float nms_threshold; + bool force_suppress; + int nms_topk; + DMLC_DECLARE_PARAMETER(NMSParam) { + DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) + .describe("Non-maximum suppression threshold."); + DMLC_DECLARE_FIELD(force_suppress).set_default(false) + .describe("Suppress all detections regardless of class_id."); + DMLC_DECLARE_FIELD(nms_topk).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + } +}; + } // namespace top } // namespace nnvm diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index ed75b10414c7..d2f48c8981e7 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -32,6 +32,8 @@ class BuildConfig(object): defaults = { "opt_level": 2, "add_pass": None, + "extra_lib_op": None, + "extra_lib_target": None, } def __init__(self, **kwargs): self._old_scope = None @@ -232,6 +234,11 @@ def build(graph, target=None, shape=None, dtype="float32", params : dict of str to NDArray The 
updated parameters of graph if params is passed. This can be different from the params passed in. + + extra_lib : tuple of (Graph, tvm.Module, dict of str to NDArray) + Extra runtime library for the last operator of the graph. + This return value only exists when extra_lib_op and + extra_lib_target are set in build_config. """ target = target if target else tvm.target.current_target() if target is None: @@ -247,6 +254,36 @@ def build(graph, target=None, shape=None, dtype="float32", cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) + # Build extra operator runtime library + extra_lib = () + build_extra = False + if cfg.extra_lib_op is not None: + build_extra = True + graph, extra_op_graph = graph_util.split_last_op(graph) + last_op_name = extra_op_graph.index.nodes[-1]["op"] + if cfg.extra_lib_op != last_op_name: + raise RuntimeError("Currently only supports splitting the " + "last operator of the input graph. " + "extra_lib_op in build_config is %s, " + "but the last op of the graph is %s." % + (cfg.extra_lib_op, last_op_name)) + extra_op_params = {} + if params is not None: + for input_name in extra_op_graph.symbol.list_input_names(): + if input_name in params: + extra_op_params[input_name] = params[input_name] + params.remove(input_name) + _, graph_oshape = graph_util.infer_shape(graph, **shape) + extra_op_ishape = {} + shape_idx = 0 + for input_name in extra_op_graph.symbol.list_input_names(): + if input_name not in extra_op_params: + extra_op_ishape[input_name] = graph_oshape[shape_idx] + shape_idx += 1 + # Disable extra_lib option in cfg to ensure extra_op only built once. 
+ cfg.extra_lib_op = None + extra_lib = build(extra_op_graph, cfg.extra_lib_target, + shape=extra_op_ishape, params=extra_op_params) shape, dtype = _update_shape_dtype(shape, dtype, params) # correct layout if necessary @@ -298,7 +335,10 @@ def build(graph, target=None, shape=None, dtype="float32", if params is None: params = {} params.update(init_var) - return graph, libmod, params + if not build_extra: + return graph, libmod, params + else: + return graph, libmod, params, extra_lib def _remove_noref_params(params, graph): """ Helper to clear non referenced params diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index e831298b27d9..802c0fc2cd7b 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -146,3 +146,39 @@ def gradients(ys, xs, grad_ys=None): if isinstance(xs, list) else len(xs.list_output_names()) ret = [grad_g.symbol[i] for i in range(nx)] return ret + +def split_last_op(graph): + """Split graph into the last operator + and all other parts before. + + Parameters + ---------- + graph : Graph + The original graph. + + Returns + ------- + main_graph: Graph + The graph before last operator. + + last_op_graph: Graph + The graph for the last operator. 
+ """ + graph_idx = graph.index + last_op_node = graph_idx.nodes[-1] + last_op_func = getattr(sym, last_op_node["op"]) + if "attrs" in last_op_node: + last_op_attr = last_op_node["attrs"] + else: + last_op_attr = {} + last_op_num_inputs = len(last_op_node["inputs"]) + last_op_inputs = [] + for i in range(last_op_num_inputs): + input_idx = last_op_node["inputs"][i][0] + input_name = graph_idx.nodes[input_idx]["name"] + last_op_inputs.append(sym.Variable(input_name)) + last_op_sym = last_op_func(*last_op_inputs, **last_op_attr) + last_op_graph = create(last_op_sym) + main_graph_sym = graph.symbol.get_children() + main_graph = create(main_graph_sym) + return main_graph, last_op_graph diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 82b8e555d5ec..912649ab75a1 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -195,6 +195,12 @@ def _split(inputs, attrs): new_attrs['axis'] = attrs.get('axis', 1) return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _softmax_activation(inputs, attrs): + op_name, new_attrs = 'softmax', {} + mode = attrs.get('mode', 'instance') + new_attrs['axis'] = 0 if mode == 'instance' else 1 + return _get_nnvm_op(op_name)(inputs[0], **new_attrs) + def _softmax_output(inputs, attrs): op_name, new_attrs = 'softmax', {} if _parse_bool_str(attrs, 'multi_output'): @@ -212,6 +218,30 @@ def _clip(inputs, attrs): new_attrs['a_max'] = _required_attr(attrs, 'a_max') return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _contrib_multibox_detection(inputs, attrs): + clip = _parse_bool_str(attrs, 'clip', default='True') + threshold = attrs.get('threshold') or 0.01 + nms_threshold = attrs.get('nms_threshold') or 0.5 + force_suppress = _parse_bool_str(attrs, 'force_suppress', default='False') + variances = tuple([float(x.strip()) for x in attrs.get('variances').strip('()').split(',')]) \ + if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) + nms_topk = 
attrs.get('nms_topk') or -1 + new_attrs = {'clip': clip, 'threshold': float(threshold), + 'nms_threshold': float(nms_threshold), + 'force_suppress': force_suppress, + 'variances': variances, 'nms_topk': int(nms_topk)} + return _get_nnvm_op('multibox_detection')(inputs[0],inputs[1], + inputs[2], **new_attrs) + +def _crop_like(inputs, _): + if len(inputs) < 2: + raise RuntimeError("Only support crop_like pattern.") + return _get_nnvm_op('crop_like')(inputs[0], inputs[1]) + +def _elemwise_sum(inputs, _): + new_attrs = {'num_args':len(inputs)} + return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) + _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', @@ -224,12 +254,15 @@ def _clip(inputs, attrs): 'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose'] _convert_map = { + '_copy' : _rename('copy'), '_div_scalar' : _rename('__div_scalar__'), '_minus_scalar' : _rename('__sub_scalar__'), '_mul_scalar' : _rename('__mul_scalar__'), '_plus_scalar' : _rename('__add_scalar__'), '_rdiv_scalar' : _rename('__rdiv_scalar__'), '_rminus_scalar': _rename('__rsub_scalar__'), + '_contrib_MultiBoxPrior' : _rename('multibox_prior'), + '_contrib_MultiBoxDetection' : _contrib_multibox_detection, 'Activation' : _activations, 'BatchNorm' : _batch_norm, 'BatchNorm_v1' : _batch_norm, @@ -237,6 +270,7 @@ def _clip(inputs, attrs): 'Concat' : _concat, 'Convolution' : _conv2d, 'Convolution_v1': _conv2d, + 'Crop' : _crop_like, 'Deconvolution' : _conv2d_transpose, 'Dropout' : _dropout, 'Flatten' : _rename('flatten'), @@ -248,7 +282,9 @@ def _clip(inputs, attrs): 'SliceChannel' : _split, 'split' : _split, 'Softmax' : _rename('softmax'), + 'SoftmaxActivation' : _softmax_activation, 'SoftmaxOutput' : _softmax_output, + 'add_n' : _elemwise_sum, 'concat' : _concat, 'max_axis' : _rename('max'), 'min_axis' : _rename('min'), diff --git a/nnvm/python/nnvm/top/attr_dict.py b/nnvm/python/nnvm/top/attr_dict.py index 
a913a92552b2..efd439fa75fc 100644 --- a/nnvm/python/nnvm/top/attr_dict.py +++ b/nnvm/python/nnvm/top/attr_dict.py @@ -83,6 +83,21 @@ def get_int(self, key): """ return int(self[key]) + def get_float_tuple(self, key): + """Get tuple of float from attr dict + + Parameters + ---------- + key : str + The attr key + + Returns + ------- + tuple : tuple of float + The result tuple + """ + return tuple(float(x) for x in self[key][1:-1].split(",") if x) + def get_float(self, key): """Get float from attr dict diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index b4b8779f2a68..faf6cd0314cc 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -37,6 +37,22 @@ def compute_reshape_like(attrs, inputs, out_info): reg.register_pattern("reshape_like", OpPattern.INJECTIVE) reg.register_schedule("reshape_like", _fschedule_injective) +# crop_like +@reg.register_compute("crop_like") +def compute_crop_like(_, inputs, out_info): + """Compute definition of crop_like""" + data0 = inputs[0] + data1 = inputs[1] + h0, w0 = data0.shape[2], data0.shape[3] + h1, w1 = data1.shape[2], data1.shape[3] + if h0.value <= h1.value and w0.value <= w1.value: + return data0 + out = tvm.compute(data1.shape, lambda *shape: data0(*shape)) + return out + +reg.register_pattern("crop_like", OpPattern.INJECTIVE) +reg.register_schedule("crop_like", _fschedule_injective) + # transpose reg.register_pattern("transpose", OpPattern.INJECTIVE) reg.register_schedule("transpose", _fschedule_injective) diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 89409de6263b..089a7697bb8d 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -38,3 +38,66 @@ def schedule_region(attrs, outs, target): return topi.generic.vision.schedule_region(outs) reg.register_pattern("yolo2_region", OpPattern.OPAQUE) + +# multibox_prior +@reg.register_schedule("multibox_prior") +def schedule_multibox_prior(_, outs, 
target): + """Schedule definition of multibox_prior""" + with tvm.target.create(target): + return topi.generic.schedule_multibox_prior(outs) + +@reg.register_compute("multibox_prior") +def compute_multibox_prior(attrs, inputs, _): + """Compute definition of multibox_prior""" + sizes = attrs.get_float_tuple('sizes') + ratios = attrs.get_float_tuple('ratios') + steps = attrs.get_float_tuple('steps') + offsets = attrs.get_float_tuple('offsets') + clip = attrs.get_bool('clip') + + return topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, + steps, offsets, clip) + +reg.register_pattern("multibox_prior", OpPattern.OPAQUE) + +# multibox_detection +@reg.register_schedule("multibox_detection") +def schedule_multibox_detection(_, outs, target): + """Schedule definition of multibox_detection""" + with tvm.target.create(target): + return topi.generic.schedule_multibox_detection(outs) + +@reg.register_compute("multibox_detection") +def compute_multibox_detection(attrs, inputs, _): + """Compute definition of multibox_detection""" + clip = attrs.get_bool('clip') + threshold = attrs.get_float('threshold') + nms_threshold = attrs.get_float('nms_threshold') + force_suppress = attrs.get_bool('force_suppress') + variance = attrs.get_float_tuple('variances') + nms_topk = attrs.get_int('nms_topk') + + return topi.vision.ssd.multibox_detection(inputs[0], inputs[1], inputs[2], + clip, threshold, nms_threshold, + force_suppress, variance, nms_topk) + +reg.register_pattern("multibox_detection", OpPattern.OPAQUE) + +# non-maximum suppression +@reg.register_schedule("nms") +def schedule_nms(_, outs, target): + """Schedule definition of nms""" + with tvm.target.create(target): + return topi.generic.schedule_nms(outs) + +@reg.register_compute("nms") +def compute_nms(attrs, inputs, _): + """Compute definition of nms""" + nms_threshold = attrs.get_float('nms_threshold') + force_suppress = attrs.get_bool('force_suppress') + nms_topk = attrs.get_int('nms_topk') + + return 
topi.vision.nms(inputs[0], inputs[1], nms_threshold, + force_suppress, nms_topk) + +reg.register_pattern("nms", OpPattern.OPAQUE) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index bdc8dc5a9c40..0375d71d1539 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -616,6 +616,31 @@ the input array into an output array with the same shape as the second input arr }) .set_support_level(4); +NNVM_REGISTER_OP(crop_like) +.describe(R"code(Crop the 3rd and 4th dim of the first input data, with the corresponding +size of the second input data. +.. note:: + Input arrays should have the same number of dimensions. +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Input data.") +.add_argument("shape_like", "Tensor", "Input data.") +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr( + "FInferShape", [](const NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->at(0).ndim(), in_attrs->at(1).ndim()) + << "Input arrays should have the same number of dimensions."; + TShape oshape = in_attrs->at(0); + oshape[2] = std::min(in_attrs->at(0)[2], in_attrs->at(1)[2]); + oshape[3] = std::min(in_attrs->at(0)[3], in_attrs->at(1)[3]); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +}) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_support_level(4); + // squeeze DMLC_REGISTER_PARAMETER(SqueezeParam); diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc new file mode 100644 index 000000000000..352fbe0220e0 --- /dev/null +++ b/nnvm/src/top/vision/nms.cc @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nms.cc + * \brief Property def of SSD non-maximum suppression operator. 
+ */ + +#include +#include +#include +#include +#include +#include +#include "../op_common.h" +#include "../elemwise_op_common.h" + +namespace nnvm { +namespace top { +using compiler::FTVMCompute; +using tvm::Tensor; +using tvm::Array; + +DMLC_REGISTER_PARAMETER(NMSParam); + +bool NMSShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; + TShape dshape = in_attrs->at(0); + TShape vshape = in_attrs->at(1); + CHECK_EQ(dshape.ndim(), 3U) << "Provided: " << dshape; + CHECK_EQ(vshape.ndim(), 1U) << "Provided: " << vshape; + CHECK_EQ(dshape[2], 6U) << "Data input should have shape " + "(batch_size, num_anchors, 6)."; + CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; + TShape oshape = TShape(3); + oshape[0] = dshape[0]; + oshape[1] = dshape[1]; + oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +} + +inline bool NMSInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + static const Layout kNCHW("NCHW"); + CHECK_EQ(ilayouts->size(), 2U); + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kNCHW); + return true; +} + +NNVM_REGISTER_OP(nms) + .describe(R"doc("Non-maximum suppression." 
+)doc" NNVM_ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", + ParamGetAttrDict) +.add_arguments(NMSParam::__FIELDS__()) +.add_argument("data", "Tensor", "Input data.") +.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"data", "valid_count"}; +}) +.set_attr("FInferShape", NMSShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCorrectLayout", NMSInferLayout) +.set_support_level(4); + +} // namespace top +} // namespace nnvm \ No newline at end of file diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc new file mode 100644 index 000000000000..2417e9748f21 --- /dev/null +++ b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -0,0 +1,147 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file multibox_op.cc + * \brief Property def of SSD multibox related operators. + */ + +#include +#include +#include +#include +#include +#include +#include "../../op_common.h" +#include "../../elemwise_op_common.h" + +namespace nnvm { +namespace top { +using compiler::FTVMCompute; +using tvm::Tensor; +using tvm::Array; + +DMLC_REGISTER_PARAMETER(MultiBoxPriorParam); + +bool MultiBoxPriorShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const MultiBoxPriorParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U) << "Inputs: [data]" << in_attrs->size(); + TShape dshape = in_attrs->at(0); + CHECK_GE(dshape.ndim(), 4U) << "Input data should be 4D: " + "[batch, channel, height, width]"; + int in_height = dshape[2]; + CHECK_GT(in_height, 0) << "Input height should > 0"; + int in_width = dshape[3]; + CHECK_GT(in_width, 0) << "Input width should > 0"; + // since input sizes are same in each batch, we could share MultiBoxPrior + TShape oshape = TShape(3); + int num_sizes = param.sizes.ndim(); + int num_ratios = param.ratios.ndim(); + 
oshape[0] = 1; + oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1); + oshape[2] = 4; + CHECK_EQ(param.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)"; + CHECK_GE(param.steps[0] * param.steps[1], 0) << "Must specify both " + "step_y and step_x"; + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +} + +inline bool MultiBoxPriorInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + static const Layout kNCHW("NCHW"); + CHECK_EQ(ilayouts->size(), 1U); + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW); + return true; +} + +NNVM_REGISTER_OP(multibox_prior) + .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." +)doc" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.add_arguments(MultiBoxPriorParam::__FIELDS__()) +.add_argument("data", "Tensor", "Input data") +.set_attr("FInferShape", MultiBoxPriorShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCorrectLayout", MultiBoxPriorInferLayout) +.set_attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds) { + return std::vector{ + MakeNode("zeros_like", n->attrs.name + "_zero_grad", + {n->inputs[0]}), + ograds[0] + }; +}) +.set_support_level(4); + +DMLC_REGISTER_PARAMETER(MultiBoxDetectionParam); + +bool MultiBoxDetectionShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]"; + TShape cshape = in_attrs->at(0); + TShape lshape = in_attrs->at(1); + TShape ashape = in_attrs->at(2); + CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape; + CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape; + CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape; + CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; + 
CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; + CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0"; + CHECK_EQ(ashape[2], 4U); + TShape oshape = TShape(3); + oshape[0] = cshape[0]; + oshape[1] = ashape[1]; + oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +} + +inline bool MultiBoxDetectionInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), 3U); + CHECK_EQ(last_ilayouts->size(), 3U); + CHECK_EQ(olayouts->size(), 1U); + for (size_t i = 0; i < last_ilayouts->size(); ++i) { + const Layout& last_layout = last_ilayouts->at(i); + if (last_layout.defined()) { + NNVM_ASSIGN_LAYOUT(*ilayouts, i, last_layout); + } + } + return true; +} + +NNVM_REGISTER_OP(multibox_detection) + .describe(R"doc("Convert multibox detection predictions." +)doc" NNVM_ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", + ParamGetAttrDict) +.add_arguments(MultiBoxDetectionParam::__FIELDS__()) +.add_argument("cls_prob", "Tensor", "Class probabilities.") +.add_argument("loc_pred", "Tensor", "Location regression predictions.") +.add_argument("anchor", "Tensor", "Multibox prior anchor boxes") +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"cls_prob", "loc_pred", "anchor"}; +}) +.set_attr("FInferShape", MultiBoxDetectionShape) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FCorrectLayout", MultiBoxDetectionInferLayout) +.set_support_level(4); + +} // namespace top +} // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 819768cfb341..54ac021bb9b9 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -1,3 +1,4 @@ +import math import numpy as 
np import tvm from tvm.contrib import graph_runtime @@ -356,6 +357,117 @@ def test_full(): np.full(shape, fill_value=0, dtype=dtype), atol=1e-5, rtol=1e-5) +def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), + offsets=(0.5, 0.5), clip=False): + data = sym.Variable("data") + out = sym.multibox_prior(data=data, sizes=sizes, ratios=ratios, steps=steps, + offsets=offsets, clip=clip) + + in_height = dshape[2] + in_width = dshape[3] + num_sizes = len(sizes) + num_ratios = len(ratios) + size_ratio_concat = sizes + ratios + steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + offset_h = offsets[0] + offset_w = offsets[1] + + oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4) + dtype = "float32" + np_out = np.zeros(oshape).astype(dtype) + + for i in range(in_height): + center_h = (i + offset_h) * steps_h + for j in range(in_width): + center_w = (j + offset_w) * steps_w + for k in range(num_sizes + num_ratios - 1): + w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \ + size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0 + h = size_ratio_concat[k] / 2.0 if k < num_sizes else \ + size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0 + count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k + np_out[0][count][0] = center_w - w + np_out[0][count][1] = center_h - h + np_out[0][count][2] = center_w + w + np_out[0][count][3] = center_h + h + if clip: + np_out = np.clip(np_out, 0, 1) + + target = "llvm" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input("data", np.random.uniform(size=dshape).astype(dtype)) + m.run() + out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5) + +def test_multibox_prior(): + 
verify_multibox_prior((1, 3, 50, 50)) + verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5)) + verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True) + +def test_multibox_detection(): + batch_size = 1 + num_anchors = 3 + num_classes = 3 + cls_prob = sym.Variable("cls_prob") + loc_preds = sym.Variable("loc_preds") + anchors = sym.Variable("anchors") + out = sym.multibox_detection(cls_prob=cls_prob, loc_pred=loc_preds, anchor=anchors) + + # Manually create test case + np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) + np_loc_preds = np.array([[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]]) + np_anchors = np.array([[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]]) + + expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108], + [0, 0.44999999, 1, 1, 1, 1], + [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]]) + + target = "llvm" + dtype = "float32" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes), + "loc_preds": (batch_size, num_anchors * 4), + "anchors": (1, num_anchors, 4)}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)}) + m.run() + out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) + + +def test_nms(): + dshape = (1, 5, 6) + data = sym.Variable("data") + valid_count = sym.Variable("valid_count", dtype="int32") + nms_threshold = 0.7 + force_suppress = True + nms_topk = 2 + out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold, + force_suppress=force_suppress, nms_topk=nms_topk) + + np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], [2, 0.9, 
35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1]]]) + + target = "llvm" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)}, + dtype={"data": "float32", "valid_count": "int32"}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"data": np_data, "valid_count": np_valid_count}) + m.run() + out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32")) + np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_reshape() @@ -370,4 +482,7 @@ def test_full(): test_block_grad() test_full() test_flip() + test_multibox_prior() + test_multibox_detection() + #test_nms() print(nnvm.compiler.engine.dump()) diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py new file mode 100644 index 000000000000..99e549190608 --- /dev/null +++ b/tutorials/nnvm/deploy_ssd.py @@ -0,0 +1,115 @@ +""" +Deploy Single Shot Multibox Detector(SSD) model +====================================== +**Author**: `Yao Wang `_ + +This article is an introductory tutorial to deploy SSD models with TVM. +We will use mxnet pretrained SSD model with Resnet50 as body network and +convert it to NNVM graph. +""" +import os +import urllib +import zipfile +import tvm +import cv2 +import numpy as np +import mxnet as mx + +from nnvm import compiler +from nnvm.frontend import from_mxnet +from tvm.contrib import graph_runtime +from mxnet.model import load_checkpoint + +###################################################################### +# To get started, clone mxnet repo from github +# and extract ssd symbol directory: +# +# .. 
code-block:: bash +# +# git clone https://github.com/apache/incubator-mxnet mxnet +# mkdir symbol && cp -a mxnet/example/ssd/symbol/* symbol + + +###################################################################### +# Set the parameters here. +# + +model_name = "ssd_resnet50_512" +model_file = "%s.zip" % model_name +test_image = "person.jpg" +target = "llvm -mcpu=core-avx2" +dshape = (1, 3, 512, 512) +dtype = "float32" +ctx = tvm.cpu() + +def download(url, path, overwrite=False): + """Downloads the file from the internet. + Set the input options correctly to overwrite or do the size comparison + + Parameters + ---------- + url : str + Download file url + path : str + File saved path. + overwrite : boolean + Dict of operator attributes + """ + if os.path.isfile(path) and not overwrite: + print('File {} exists, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + urllib.request.urlretrieve(url, path) + print('') + except: + urllib.urlretrieve(url, path) + +###################################################################### +# Download MXNet SSD pre-trained model and demo image. +# ---------------------------- +# Pre-trained model available at +# https://github.com/apache/incubator-\mxnet/tree/master/example/ssd + +model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ + "resnet50_ssd_512_voc0712_trainval.zip" +image_url = "https://cloud.githubusercontent.com/assets/3307514/20012563/" \ + "cbb41382-a27d-11e6-92a9-18dab4fd1ad3.jpg" +dir = "ssd_model" +if not os.path.exists(dir): + os.makedirs(dir) +model_file_path = "%s/%s" % (dir, model_file) +test_image_path = "%s/%s" % (dir, test_image) +download(model_url, model_file_path) +download(image_url, test_image_path) +zip_ref = zipfile.ZipFile(model_file_path, 'r') +zip_ref.extractall(dir) +zip_ref.close() + +###################################################################### +# Convert and compile model with NNVM for CPU. 
+ +from symbol.symbol_factory import get_symbol +sym = get_symbol("resnet50", dshape[2], num_classes=20) +_, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) +net, params = from_mxnet(sym, arg_params, aux_params) +with compiler.build_config(opt_level=3): + graph, lib, params = compiler.build(net, target, {"data": dshape}, params=params) + +###################################################################### +# Create TVM runtime and do inference + +img_data = cv2.imread(test_image_path) +img_data = cv2.resize(img_data, (dshape[2], dshape[3])) +img_data = np.transpose(np.array(img_data), (2, 0, 1)) +img_data = np.expand_dims(img_data, axis=0) +np_data = np.random.uniform(0, 255, size=dshape).astype(dtype) +m = graph_runtime.create(graph, lib, ctx) +m.set_input('data', tvm.nd.array(img_data.astype(dtype))) +m.set_input(**params) +# execute +m.run() +# get outputs +_, oshape = compiler.graph_util.infer_shape(graph, {"data": dshape}) +tvm_output = m.get_output(0, tvm.nd.empty(oshape, dtype)) +print(tvm_output.shape) From 32190705053cb112fa3c24ddae98eee4d6bcde5e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 18:48:02 +0000 Subject: [PATCH 02/22] Add test for build with extra_lib --- nnvm/python/nnvm/frontend/mxnet.py | 2 +- nnvm/python/nnvm/top/vision.py | 2 +- nnvm/src/top/vision/nms.cc | 3 +- nnvm/src/top/vision/ssd/mutibox_op.cc | 20 ++++---- nnvm/tests/python/compiler/test_build.py | 35 +++++++++++++ tutorials/nnvm/deploy_ssd.py | 63 +++++++++++++++++++----- 6 files changed, 101 insertions(+), 24 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 912649ab75a1..21e7bbbcf001 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -71,7 +71,7 @@ def _batch_norm(inputs, attrs): new_attrs['axis'] = attrs.get('axis', 1) new_attrs['epsilon'] = attrs.get('eps', 0.001) new_attrs['center'] = True - new_attrs['scale'] = True + new_attrs['scale'] = 
not _parse_bool_str(attrs, 'fix_gamma', default="False") return _get_nnvm_op(op_name)(*inputs, **new_attrs) def _concat(inputs, attrs): diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 089a7697bb8d..7e5e641f340f 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -3,8 +3,8 @@ """Definition of nn ops""" from __future__ import absolute_import -import topi import tvm +import topi from . import registry as reg from .registry import OpPattern diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 352fbe0220e0..22b2136341ef 100644 --- a/nnvm/src/top/vision/nms.cc +++ b/nnvm/src/top/vision/nms.cc @@ -73,4 +73,5 @@ NNVM_REGISTER_OP(nms) .set_support_level(4); } // namespace top -} // namespace nnvm \ No newline at end of file +} // namespace nnvm + diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc index 2417e9748f21..577657ecb304 100644 --- a/nnvm/src/top/vision/ssd/mutibox_op.cc +++ b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -48,10 +48,10 @@ bool MultiBoxPriorShape(const NodeAttrs& attrs, return true; } -inline bool MultiBoxPriorInferLayout(const NodeAttrs& attrs, - std::vector *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts) { +inline bool MultiBoxPriorLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { static const Layout kNCHW("NCHW"); CHECK_EQ(ilayouts->size(), 1U); CHECK_EQ(olayouts->size(), 1U); @@ -70,7 +70,7 @@ NNVM_REGISTER_OP(multibox_prior) .add_argument("data", "Tensor", "Input data") .set_attr("FInferShape", MultiBoxPriorShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCorrectLayout", MultiBoxPriorInferLayout) +.set_attr("FCorrectLayout", MultiBoxPriorLayout) .set_attr( "FGradient", [](const NodePtr& n, const std::vector& ograds) { @@ -107,10 +107,10 @@ bool MultiBoxDetectionShape(const NodeAttrs& attrs, return true; } -inline bool 
MultiBoxDetectionInferLayout(const NodeAttrs& attrs, - std::vector *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts) { +inline bool MultiBoxDetectionLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { CHECK_EQ(ilayouts->size(), 3U); CHECK_EQ(last_ilayouts->size(), 3U); CHECK_EQ(olayouts->size(), 1U); @@ -140,7 +140,7 @@ NNVM_REGISTER_OP(multibox_detection) }) .set_attr("FInferShape", MultiBoxDetectionShape) .set_attr("FInferType", ElemwiseType<3, 1>) -.set_attr("FCorrectLayout", MultiBoxDetectionInferLayout) +.set_attr("FCorrectLayout", MultiBoxDetectionLayout) .set_support_level(4); } // namespace top diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py index 5e1f0337c293..b2e2c5e67dc2 100644 --- a/nnvm/tests/python/compiler/test_build.py +++ b/nnvm/tests/python/compiler/test_build.py @@ -94,9 +94,44 @@ def test_dtypes(): out = m.get_output(0, tvm.nd.empty(oshape, dtype)) np.testing.assert_allclose(out.asnumpy(), data, atol=1e-5, rtol=1e-5) +def test_compile_extra_lib(): + data = sym.Variable("data") + net = sym.relu(data) + net = sym.sqrt(net) + out = sym.flatten(net) + + target = "cuda" + extra_lib_target = "llvm" + dshape = (1, 3, 56, 56) + dtype = "float32" + in_data = np.random.uniform(size=dshape).astype(dtype) + opt_level = 2 + with nnvm.compiler.build_config(opt_level=opt_level): + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) + m = graph_runtime.create(graph, lib, tvm.gpu(0)) + m.set_input("data", in_data) + m.run() + _, oshape = nnvm.compiler.graph_util.infer_shape(graph, shape={"data": dshape}) + expected_out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + + with nnvm.compiler.build_config(opt_level=opt_level, extra_lib_op="flatten", extra_lib_target=extra_lib_target): + graph, lib, _, extra_libmod = nnvm.compiler.build(out, target, {"data": dshape}) + major_m = graph_runtime.create(graph, lib, 
tvm.gpu(0)) + major_m.set_input("data", in_data) + major_m.run() + major_out = major_m.get_output(0, tvm.nd.empty(dshape, dtype)) + extra_graph, extra_lib, _ = extra_libmod + extra_m = graph_runtime.create(extra_graph, extra_lib, tvm.cpu()) + extra_input_name = extra_graph.symbol.list_input_names()[0] + extra_m.set_input(extra_input_name, major_out) + extra_m.run() + final_out = extra_m.get_output(0, tvm.nd.empty(oshape, dtype)) + np.testing.assert_allclose(major_out.asnumpy(), final_out.asnumpy(), atol=1e-5, rtol=1e-5) + if __name__ == "__main__": test_precompute_prune() test_compile() test_run() test_dtypes() + test_compile_extra_lib() diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index 99e549190608..fcc0c60a1311 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -13,7 +13,6 @@ import tvm import cv2 import numpy as np -import mxnet as mx from nnvm import compiler from nnvm.frontend import from_mxnet @@ -36,8 +35,8 @@ model_name = "ssd_resnet50_512" model_file = "%s.zip" % model_name -test_image = "person.jpg" -target = "llvm -mcpu=core-avx2" +test_image = "dog.jpg" +target = "llvm" dshape = (1, 3, 512, 512) dtype = "float32" ctx = tvm.cpu() @@ -73,8 +72,9 @@ def download(url, path, overwrite=False): model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ "resnet50_ssd_512_voc0712_trainval.zip" -image_url = "https://cloud.githubusercontent.com/assets/3307514/20012563/" \ - "cbb41382-a27d-11e6-92a9-18dab4fd1ad3.jpg" +image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ + "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" + dir = "ssd_model" if not os.path.exists(dir): os.makedirs(dir) @@ -99,17 +99,58 @@ def download(url, path, overwrite=False): ###################################################################### # Create TVM runtime and do inference -img_data = cv2.imread(test_image_path) -img_data = cv2.resize(img_data, (dshape[2], dshape[3])) +# Preprocess image 
+image = cv2.imread(test_image_path) +img_data = cv2.resize(image, (dshape[2], dshape[3])) +img_data = img_data[:, :, (2, 1, 0)].astype(np.float32) +img_data -= np.array([123, 117, 104]) img_data = np.transpose(np.array(img_data), (2, 0, 1)) img_data = np.expand_dims(img_data, axis=0) -np_data = np.random.uniform(0, 255, size=dshape).astype(dtype) +# Build TVM runtime m = graph_runtime.create(graph, lib, ctx) m.set_input('data', tvm.nd.array(img_data.astype(dtype))) m.set_input(**params) # execute m.run() # get outputs -_, oshape = compiler.graph_util.infer_shape(graph, {"data": dshape}) -tvm_output = m.get_output(0, tvm.nd.empty(oshape, dtype)) -print(tvm_output.shape) +_, oshape = compiler.graph_util.infer_shape(graph, shape={"data": dshape}) +tvm_output = m.get_output(0, tvm.nd.empty(tuple(oshape[0]), dtype)) + + +###################################################################### +# Display result + +class_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", + "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", + "sheep", "sofa", "train", "tvmonitor"] +def display(img, out, thresh=0.5): + import random + import matplotlib as mpl + import matplotlib.pyplot as plt + mpl.rcParams['figure.figsize'] = (10,10) + pens = dict() + plt.clf() + plt.imshow(img) + for det in out: + cid = int(det[0]) + if cid < 0: + continue + score = det[1] + if score < thresh: + continue + if cid not in pens: + pens[cid] = (random.random(), random.random(), random.random()) + scales = [img.shape[1], img.shape[0]] * 2 + xmin, ymin, xmax, ymax = [int(p * s) for p, s in zip(det[2:6].tolist(), scales)] + rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, + edgecolor=pens[cid], linewidth=3) + plt.gca().add_patch(rect) + text = class_names[cid] + plt.gca().text(xmin, ymin-2, '{:s} {:.3f}'.format(text, score), + bbox=dict(facecolor=pens[cid], alpha=0.5), + fontsize=12, color='white') + plt.show() + +image = 
cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +display(image, tvm_output.asnumpy()[0], thresh=0.45) + From bb9cdce0c5251bdece4bf4a9e011666bdb3ac83a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 19:58:17 +0000 Subject: [PATCH 03/22] Fix ssd tutorial --- nnvm/python/nnvm/compiler/graph_util.py | 1 + nnvm/tests/python/compiler/test_build.py | 6 +++--- tutorials/nnvm/deploy_ssd.py | 13 +++---------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 802c0fc2cd7b..621872ead98b 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -7,6 +7,7 @@ from ..graph import create from ..symbol import Group, ones_like +from .. import symbol as sym def infer_shape(graph, **shape): """Infer the shape given the shape of inputs. diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py index b2e2c5e67dc2..fcc882dfdb19 100644 --- a/nnvm/tests/python/compiler/test_build.py +++ b/nnvm/tests/python/compiler/test_build.py @@ -112,7 +112,7 @@ def test_compile_extra_lib(): m.set_input("data", in_data) m.run() _, oshape = nnvm.compiler.graph_util.infer_shape(graph, shape={"data": dshape}) - expected_out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + expected_out = m.get_output(0, tvm.nd.empty(oshape[0], dtype)) with nnvm.compiler.build_config(opt_level=opt_level, extra_lib_op="flatten", extra_lib_target=extra_lib_target): graph, lib, _, extra_libmod = nnvm.compiler.build(out, target, {"data": dshape}) @@ -125,8 +125,8 @@ def test_compile_extra_lib(): extra_input_name = extra_graph.symbol.list_input_names()[0] extra_m.set_input(extra_input_name, major_out) extra_m.run() - final_out = extra_m.get_output(0, tvm.nd.empty(oshape, dtype)) - np.testing.assert_allclose(major_out.asnumpy(), final_out.asnumpy(), atol=1e-5, rtol=1e-5) + final_out = extra_m.get_output(0, tvm.nd.empty(oshape[0], dtype)) + 
np.testing.assert_allclose(expected_out.asnumpy(), final_out.asnumpy(), atol=1e-5, rtol=1e-5) if __name__ == "__main__": diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index fcc0c60a1311..ea90ea64fc5d 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -19,19 +19,9 @@ from tvm.contrib import graph_runtime from mxnet.model import load_checkpoint -###################################################################### -# To get started, clone mxnet repo from github -# and extract ssd symbol directory: -# -# .. code-block:: bash -# -# git clone https://github.com/apache/incubator-mxnet mxnet -# mkdir symbol && cp -a mxnet/example/ssd/symbol/* symbol - ###################################################################### # Set the parameters here. -# model_name = "ssd_resnet50_512" model_file = "%s.zip" % model_name @@ -88,7 +78,10 @@ def download(url, path, overwrite=False): ###################################################################### # Convert and compile model with NNVM for CPU. +# First we need to download MXNet SSD example and create inference model. 
+os.system("git clone https://github.com/apache/incubator-mxnet mxnet") +os.system("cp -avr mxnet/example/ssd/symbol symbol") from symbol.symbol_factory import get_symbol sym = get_symbol("resnet50", dshape[2], num_classes=20) _, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) From b24755c4cae65a8709f67c81410e54bfa46a4ea3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 20:20:00 +0000 Subject: [PATCH 04/22] Small fix --- nnvm/python/nnvm/frontend/mxnet.py | 6 ------ nnvm/python/nnvm/top/transform.py | 16 ---------------- nnvm/src/top/tensor/transform.cc | 25 ------------------------- 3 files changed, 47 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 21e7bbbcf001..1aed998203c1 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -233,11 +233,6 @@ def _contrib_multibox_detection(inputs, attrs): return _get_nnvm_op('multibox_detection')(inputs[0],inputs[1], inputs[2], **new_attrs) -def _crop_like(inputs, _): - if len(inputs) < 2: - raise RuntimeError("Only support crop_like pattern.") - return _get_nnvm_op('crop_like')(inputs[0], inputs[1]) - def _elemwise_sum(inputs, _): new_attrs = {'num_args':len(inputs)} return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) @@ -270,7 +265,6 @@ def _elemwise_sum(inputs, _): 'Concat' : _concat, 'Convolution' : _conv2d, 'Convolution_v1': _conv2d, - 'Crop' : _crop_like, 'Deconvolution' : _conv2d_transpose, 'Dropout' : _dropout, 'Flatten' : _rename('flatten'), diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index faf6cd0314cc..b4b8779f2a68 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -37,22 +37,6 @@ def compute_reshape_like(attrs, inputs, out_info): reg.register_pattern("reshape_like", OpPattern.INJECTIVE) reg.register_schedule("reshape_like", _fschedule_injective) -# crop_like -@reg.register_compute("crop_like") -def 
compute_crop_like(_, inputs, out_info): - """Compute definition of crop_like""" - data0 = inputs[0] - data1 = inputs[1] - h0, w0 = data0.shape[2], data0.shape[3] - h1, w1 = data1.shape[2], data1.shape[3] - if h0.value <= h1.value and w0.value <= w1.value: - return data0 - out = tvm.compute(data1.shape, lambda *shape: data0(*shape)) - return out - -reg.register_pattern("crop_like", OpPattern.INJECTIVE) -reg.register_schedule("crop_like", _fschedule_injective) - # transpose reg.register_pattern("transpose", OpPattern.INJECTIVE) reg.register_schedule("transpose", _fschedule_injective) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 0375d71d1539..bdc8dc5a9c40 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -616,31 +616,6 @@ the input array into an output array with the same shape as the second input arr }) .set_support_level(4); -NNVM_REGISTER_OP(crop_like) -.describe(R"code(Crop the 3rd and 4th dim of the first input data, with the corresponding -size of the second input data. -.. note:: - Input arrays should have the same number of dimensions. 
-)code" NNVM_ADD_FILELINE) -.add_argument("data", "Tensor", "Input data.") -.add_argument("shape_like", "Tensor", "Input data.") -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr( - "FInferShape", [](const NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->at(0).ndim(), in_attrs->at(1).ndim()) - << "Input arrays should have the same number of dimensions."; - TShape oshape = in_attrs->at(0); - oshape[2] = std::min(in_attrs->at(0)[2], in_attrs->at(1)[2]); - oshape[3] = std::min(in_attrs->at(0)[3], in_attrs->at(1)[3]); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); - return true; -}) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_support_level(4); - // squeeze DMLC_REGISTER_PARAMETER(SqueezeParam); From baa731df4b05f8211fdc0156db928c7359ea6ba8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 20:41:48 +0000 Subject: [PATCH 05/22] Fix pylint --- nnvm/python/nnvm/compiler/build_module.py | 3 +-- nnvm/python/nnvm/frontend/mxnet.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index d2f48c8981e7..9f40e2a80fe8 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -337,8 +337,7 @@ def build(graph, target=None, shape=None, dtype="float32", params.update(init_var) if not build_extra: return graph, libmod, params - else: - return graph, libmod, params, extra_lib + return graph, libmod, params, extra_lib def _remove_noref_params(params, graph): """ Helper to clear non referenced params diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 1aed998203c1..fa7c44418348 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -230,7 +230,7 @@ def _contrib_multibox_detection(inputs, attrs): 'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'variances': variances, 
'nms_topk': int(nms_topk)} - return _get_nnvm_op('multibox_detection')(inputs[0],inputs[1], + return _get_nnvm_op('multibox_detection')(inputs[0], inputs[1], inputs[2], **new_attrs) def _elemwise_sum(inputs, _): From 60c7f52d859c3e92ba4b936b9ffd48e4e335ba4f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 21:25:42 +0000 Subject: [PATCH 06/22] Store ssd inference symbol json file --- tutorials/nnvm/deploy_ssd.py | 7 +- .../nnvm/ssd/ssd_resnet50_inference.json | 6180 +++++++++++++++++ 2 files changed, 6182 insertions(+), 5 deletions(-) create mode 100644 tutorials/nnvm/ssd/ssd_resnet50_inference.json diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index ea90ea64fc5d..8ad92443b7a5 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -11,6 +11,7 @@ import urllib import zipfile import tvm +import mxnet as mx import cv2 import numpy as np @@ -78,12 +79,8 @@ def download(url, path, overwrite=False): ###################################################################### # Convert and compile model with NNVM for CPU. -# First we need to download MXNet SSD example and create inference model. 
-os.system("git clone https://github.com/apache/incubator-mxnet mxnet") -os.system("cp -avr mxnet/example/ssd/symbol symbol") -from symbol.symbol_factory import get_symbol -sym = get_symbol("resnet50", dshape[2], num_classes=20) +sym = mx.sym.load("ssd/ssd_resnet50_inference.json") _, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) net, params = from_mxnet(sym, arg_params, aux_params) with compiler.build_config(opt_level=3): diff --git a/tutorials/nnvm/ssd/ssd_resnet50_inference.json b/tutorials/nnvm/ssd/ssd_resnet50_inference.json new file mode 100644 index 000000000000..3af9a9023a72 --- /dev/null +++ b/tutorials/nnvm/ssd/ssd_resnet50_inference.json @@ -0,0 +1,6180 @@ +{ + "nodes": [ + { + "op": "null", + "name": "data", + "inputs": [] + }, + { + "op": "_copy", + "name": "id", + "inputs": [[0, 0, 0]] + }, + { + "op": "null", + "name": "bn_data_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "True", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn_data_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "True", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn_data_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "True", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn_data_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "True", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "bn_data", + "attrs": { + "eps": "2e-05", + "fix_gamma": "True", + "momentum": "0.9" + }, + "inputs": [[1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 1], [5, 0, 1]] + }, + { + "op": "null", + "name": "conv0_weight", + "attrs": { + "kernel": "(7, 7)", + "no_bias": "True", + "num_filter": "64", + "pad": "(3, 3)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "conv0", + "attrs": { + "kernel": "(7, 7)", + 
"no_bias": "True", + "num_filter": "64", + "pad": "(3, 3)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[6, 0, 0], [7, 0, 0]] + }, + { + "op": "null", + "name": "bn0_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn0_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn0_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "bn0_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "bn0", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[8, 0, 0], [9, 0, 0], [10, 0, 0], [11, 0, 1], [12, 0, 1]] + }, + { + "op": "Activation", + "name": "relu0", + "attrs": {"act_type": "relu"}, + "inputs": [[13, 0, 0]] + }, + { + "op": "Pooling", + "name": "pooling0", + "attrs": { + "kernel": "(3, 3)", + "pad": "(1, 1)", + "pool_type": "max", + "stride": "(2, 2)" + }, + "inputs": [[14, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + 
"name": "stage1_unit1_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[15, 0, 0], [16, 0, 0], [17, 0, 0], [18, 0, 1], [19, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit1_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[20, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit1_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[21, 0, 0], [22, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit1_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[23, 0, 0], [24, 0, 0], [25, 0, 0], [26, 0, 1], [27, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit1_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[28, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + 
"workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit1_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[29, 0, 0], [30, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit1_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit1_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[31, 0, 0], [32, 0, 0], [33, 0, 0], [34, 0, 1], [35, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit1_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[36, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit1_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[37, 0, 0], [38, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit1_sc_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { 
+ "op": "Convolution", + "name": "stage1_unit1_sc", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[21, 0, 0], [40, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus0", + "inputs": [[39, 0, 0], [41, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit2_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[42, 0, 0], [43, 0, 0], [44, 0, 0], [45, 0, 1], [46, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit2_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[47, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit2_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[48, 0, 0], [49, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": 
"stage1_unit2_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit2_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[50, 0, 0], [51, 0, 0], [52, 0, 0], [53, 0, 1], [54, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit2_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[55, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit2_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[56, 0, 0], [57, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit2_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + 
"inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit2_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[58, 0, 0], [59, 0, 0], [60, 0, 0], [61, 0, 1], [62, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit2_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[63, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit2_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit2_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[64, 0, 0], [65, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus1", + "inputs": [[66, 0, 0], [42, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit3_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit3_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[67, 0, 0], [68, 0, 0], [69, 0, 0], [70, 0, 1], [71, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit3_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[72, 0, 0]] + }, + { + "op": "null", + "name": 
"stage1_unit3_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit3_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "64", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[73, 0, 0], [74, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit3_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit3_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[75, 0, 0], [76, 0, 0], [77, 0, 0], [78, 0, 1], [79, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit3_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[80, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit3_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit3_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "64", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[81, 0, 0], [82, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit3_bn3_gamma", + 
"attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage1_unit3_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage1_unit3_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[83, 0, 0], [84, 0, 0], [85, 0, 0], [86, 0, 1], [87, 0, 1]] + }, + { + "op": "Activation", + "name": "stage1_unit3_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[88, 0, 0]] + }, + { + "op": "null", + "name": "stage1_unit3_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage1_unit3_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[89, 0, 0], [90, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus2", + "inputs": [[91, 0, 0], [67, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + 
"momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit1_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[92, 0, 0], [93, 0, 0], [94, 0, 0], [95, 0, 1], [96, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit1_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[97, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit1_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[98, 0, 0], [99, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit1_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[100, 0, 0], [101, 0, 0], [102, 0, 0], [103, 0, 1], [104, 0, 1]] + }, + { + "op": "Activation", + "name": 
"stage2_unit1_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[105, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit1_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[106, 0, 0], [107, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit1_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit1_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[108, 0, 0], [109, 0, 0], [110, 0, 0], [111, 0, 1], [112, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit1_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[113, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit1_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 
1)", + "workspace": "256" + }, + "inputs": [[114, 0, 0], [115, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit1_sc_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit1_sc", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[98, 0, 0], [117, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus3", + "inputs": [[116, 0, 0], [118, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit2_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[119, 0, 0], [120, 0, 0], [121, 0, 0], [122, 0, 1], [123, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit2_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[124, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit2_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + 
"pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[125, 0, 0], [126, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit2_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[127, 0, 0], [128, 0, 0], [129, 0, 0], [130, 0, 1], [131, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit2_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[132, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit2_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[133, 0, 0], [134, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn3_moving_mean", + 
"attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit2_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit2_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[135, 0, 0], [136, 0, 0], [137, 0, 0], [138, 0, 1], [139, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit2_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[140, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit2_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit2_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[141, 0, 0], [142, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus4", + "inputs": [[143, 0, 0], [119, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit3_bn1", + "attrs": { + 
"eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[144, 0, 0], [145, 0, 0], [146, 0, 0], [147, 0, 1], [148, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit3_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[149, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit3_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[150, 0, 0], [151, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit3_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[152, 0, 0], [153, 0, 0], [154, 0, 0], [155, 0, 1], [156, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit3_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[157, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": 
[] + }, + { + "op": "Convolution", + "name": "stage2_unit3_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[158, 0, 0], [159, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit3_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit3_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[160, 0, 0], [161, 0, 0], [162, 0, 0], [163, 0, 1], [164, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit3_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[165, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit3_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit3_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[166, 0, 0], [167, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus5", + "inputs": [[168, 0, 0], [144, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + 
"inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit4_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[169, 0, 0], [170, 0, 0], [171, 0, 0], [172, 0, 1], [173, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit4_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[174, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit4_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[175, 0, 0], [176, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": 
"2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit4_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[177, 0, 0], [178, 0, 0], [179, 0, 0], [180, 0, 1], [181, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit4_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[182, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit4_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[183, 0, 0], [184, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage2_unit4_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage2_unit4_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[185, 0, 0], [186, 0, 0], [187, 0, 0], [188, 0, 1], [189, 0, 1]] + }, + { + "op": "Activation", + "name": "stage2_unit4_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[190, 0, 0]] + }, + { + "op": "null", + "name": "stage2_unit4_conv3_weight", + "attrs": 
{ + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage2_unit4_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[191, 0, 0], [192, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus6", + "inputs": [[193, 0, 0], [169, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit1_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[194, 0, 0], [195, 0, 0], [196, 0, 0], [197, 0, 1], [198, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit1_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[199, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit1_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[200, 0, 0], [201, 0, 
0]] + }, + { + "op": "null", + "name": "stage3_unit1_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit1_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[202, 0, 0], [203, 0, 0], [204, 0, 0], [205, 0, 1], [206, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit1_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[207, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit1_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[208, 0, 0], [209, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" 
+ }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit1_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit1_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[210, 0, 0], [211, 0, 0], [212, 0, 0], [213, 0, 1], [214, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit1_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[215, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit1_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[216, 0, 0], [217, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit1_sc_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit1_sc", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[200, 0, 0], [219, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus7", + "inputs": [[218, 0, 0], [220, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + 
"fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit2_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[221, 0, 0], [222, 0, 0], [223, 0, 0], [224, 0, 1], [225, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit2_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[226, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit2_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[227, 0, 0], [228, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit2_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[229, 0, 0], [230, 0, 0], [231, 0, 0], [232, 0, 1], [233, 0, 1]] + }, + { + "op": 
"Activation", + "name": "stage3_unit2_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[234, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit2_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[235, 0, 0], [236, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit2_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit2_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[237, 0, 0], [238, 0, 0], [239, 0, 0], [240, 0, 1], [241, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit2_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[242, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit2_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit2_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 
0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[243, 0, 0], [244, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus8", + "inputs": [[245, 0, 0], [221, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit3_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[246, 0, 0], [247, 0, 0], [248, 0, 0], [249, 0, 1], [250, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit3_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[251, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit3_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[252, 0, 0], [253, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + 
"inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit3_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[254, 0, 0], [255, 0, 0], [256, 0, 0], [257, 0, 1], [258, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit3_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[259, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit3_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[260, 0, 0], [261, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit3_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit3_bn3", + "attrs": { + "eps": 
"2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[262, 0, 0], [263, 0, 0], [264, 0, 0], [265, 0, 1], [266, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit3_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[267, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit3_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit3_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[268, 0, 0], [269, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus9", + "inputs": [[270, 0, 0], [246, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit4_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[271, 0, 0], [272, 0, 0], [273, 0, 0], [274, 0, 1], [275, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit4_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[276, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": 
"256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit4_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[277, 0, 0], [278, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit4_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[279, 0, 0], [280, 0, 0], [281, 0, 0], [282, 0, 1], [283, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit4_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[284, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit4_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[285, 0, 0], [286, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + 
"inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit4_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit4_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[287, 0, 0], [288, 0, 0], [289, 0, 0], [290, 0, 1], [291, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit4_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[292, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit4_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit4_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[293, 0, 0], [294, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus10", + "inputs": [[295, 0, 0], [271, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + 
"name": "stage3_unit5_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit5_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[296, 0, 0], [297, 0, 0], [298, 0, 0], [299, 0, 1], [300, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit5_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[301, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit5_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[302, 0, 0], [303, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit5_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[304, 0, 0], [305, 0, 0], [306, 0, 0], [307, 0, 1], [308, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit5_relu2", + "attrs": {"act_type": "relu"}, + 
"inputs": [[309, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit5_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[310, 0, 0], [311, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit5_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit5_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[312, 0, 0], [313, 0, 0], [314, 0, 0], [315, 0, 1], [316, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit5_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[317, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit5_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit5_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[318, 0, 0], 
[319, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus11", + "inputs": [[320, 0, 0], [296, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit6_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[321, 0, 0], [322, 0, 0], [323, 0, 0], [324, 0, 1], [325, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit6_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[326, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit6_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[327, 0, 0], [328, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn2_moving_mean", + 
"attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit6_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[329, 0, 0], [330, 0, 0], [331, 0, 0], [332, 0, 1], [333, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit6_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[334, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit6_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[335, 0, 0], [336, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage3_unit6_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage3_unit6_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[337, 0, 0], [338, 
0, 0], [339, 0, 0], [340, 0, 1], [341, 0, 1]] + }, + { + "op": "Activation", + "name": "stage3_unit6_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[342, 0, 0]] + }, + { + "op": "null", + "name": "stage3_unit6_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage3_unit6_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "1024", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[343, 0, 0], [344, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus12", + "inputs": [[345, 0, 0], [321, 0, 0]] + }, + { + "op": "null", + "name": "_plus12_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "_plus12_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "_plus12_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[346, 0, 0], [347, 0, 0], [348, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose1", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[349, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten2", + "inputs": [[350, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + 
"inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit1_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[346, 0, 0], [352, 0, 0], [353, 0, 0], [354, 0, 1], [355, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit1_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[356, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit1_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[357, 0, 0], [358, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit1_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[359, 0, 0], [360, 0, 0], [361, 0, 0], [362, 0, 1], [363, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit1_relu2", + 
"attrs": {"act_type": "relu"}, + "inputs": [[364, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit1_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[365, 0, 0], [366, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit1_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit1_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[367, 0, 0], [368, 0, 0], [369, 0, 0], [370, 0, 1], [371, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit1_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[372, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit1_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" 
+ }, + "inputs": [[373, 0, 0], [374, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit1_sc_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit1_sc", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "stride": "(2, 2)", + "workspace": "256" + }, + "inputs": [[357, 0, 0], [376, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus13", + "inputs": [[375, 0, 0], [377, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit2_bn1", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[378, 0, 0], [379, 0, 0], [380, 0, 0], [381, 0, 1], [382, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit2_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[383, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit2_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + 
"stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[384, 0, 0], [385, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit2_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[386, 0, 0], [387, 0, 0], [388, 0, 0], [389, 0, 1], [390, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit2_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[391, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit2_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[392, 0, 0], [393, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn3_moving_mean", + "attrs": { + 
"__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit2_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit2_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[394, 0, 0], [395, 0, 0], [396, 0, 0], [397, 0, 1], [398, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit2_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[399, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit2_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit2_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[400, 0, 0], [401, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus14", + "inputs": [[402, 0, 0], [378, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_bn1_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn1_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn1_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn1_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit3_bn1", + "attrs": { + "eps": 
"2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[403, 0, 0], [404, 0, 0], [405, 0, 0], [406, 0, 1], [407, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit3_relu1", + "attrs": {"act_type": "relu"}, + "inputs": [[408, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_conv1_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit3_conv1", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "512", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[409, 0, 0], [410, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_bn2_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn2_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn2_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn2_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit3_bn2", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[411, 0, 0], [412, 0, 0], [413, 0, 0], [414, 0, 1], [415, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit3_relu2", + "attrs": {"act_type": "relu"}, + "inputs": [[416, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_conv2_weight", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, 
+ { + "op": "Convolution", + "name": "stage4_unit3_conv2", + "attrs": { + "kernel": "(3, 3)", + "no_bias": "True", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[417, 0, 0], [418, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_bn3_gamma", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn3_beta", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn3_moving_mean", + "attrs": { + "__init__": "[\"zero\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "null", + "name": "stage4_unit3_bn3_moving_var", + "attrs": { + "__init__": "[\"one\", {}]", + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [] + }, + { + "op": "BatchNorm", + "name": "stage4_unit3_bn3", + "attrs": { + "eps": "2e-05", + "fix_gamma": "False", + "momentum": "0.9" + }, + "inputs": [[419, 0, 0], [420, 0, 0], [421, 0, 0], [422, 0, 1], [423, 0, 1]] + }, + { + "op": "Activation", + "name": "stage4_unit3_relu3", + "attrs": {"act_type": "relu"}, + "inputs": [[424, 0, 0]] + }, + { + "op": "null", + "name": "stage4_unit3_conv3_weight", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "stage4_unit3_conv3", + "attrs": { + "kernel": "(1, 1)", + "no_bias": "True", + "num_filter": "2048", + "pad": "(0, 0)", + "stride": "(1, 1)", + "workspace": "256" + }, + "inputs": [[425, 0, 0], [426, 0, 0]] + }, + { + "op": "elemwise_add", + "name": "_plus15", + "inputs": [[427, 0, 0], [403, 0, 0]] + }, + { + "op": "null", + "name": "_plus15_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + 
"stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "_plus15_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "_plus15_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[428, 0, 0], [429, 0, 0], [430, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose3", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[431, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten5", + "inputs": [[432, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_2_conv_1x1_conv_weight", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_2_conv_1x1_conv_bias", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_2_conv_1x1_conv", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "256", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [[428, 0, 0], [434, 0, 0], [435, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_2_conv_1x1_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[436, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_2_conv_3x3_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_2_conv_3x3_conv_bias", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_2_conv_3x3_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "512", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [[437, 0, 0], [438, 0, 0], [439, 0, 0]] + }, + { + "op": "Activation", + 
"name": "multi_feat_2_conv_3x3_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[440, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[441, 0, 0], [442, 0, 0], [443, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose5", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[444, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten8", + "inputs": [[445, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_3_conv_1x1_conv_weight", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_3_conv_1x1_conv_bias", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_3_conv_1x1_conv", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [[441, 0, 0], [447, 0, 0], [448, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_3_conv_1x1_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[449, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_conv_bias", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + 
"stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_3_conv_3x3_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [[450, 0, 0], [451, 0, 0], [452, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_3_conv_3x3_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[453, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "126", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[454, 0, 0], [455, 0, 0], [456, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose7", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[457, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten11", + "inputs": [[458, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_4_conv_1x1_conv_weight", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_4_conv_1x1_conv_bias", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_4_conv_1x1_conv", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [[454, 0, 0], [460, 0, 0], [461, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_4_conv_1x1_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[462, 0, 0]] + }, + { + "op": "null", + "name": 
"multi_feat_4_conv_3x3_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_4_conv_3x3_conv_bias", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_4_conv_3x3_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "256", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [[463, 0, 0], [464, 0, 0], [465, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_4_conv_3x3_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[466, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[467, 0, 0], [468, 0, 0], [469, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose9", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[470, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten14", + "inputs": [[471, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_5_conv_1x1_conv_weight", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_5_conv_1x1_conv_bias", + "attrs": { + "kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_5_conv_1x1_conv", + "attrs": { + 
"kernel": "(1, 1)", + "num_filter": "128", + "pad": "(0, 0)", + "stride": "(1, 1)" + }, + "inputs": [[467, 0, 0], [473, 0, 0], [474, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_5_conv_1x1_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[475, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_conv_bias", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_5_conv_3x3_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "128", + "pad": "(1, 1)", + "stride": "(2, 2)" + }, + "inputs": [[476, 0, 0], [477, 0, 0], [478, 0, 0]] + }, + { + "op": "Activation", + "name": "multi_feat_5_conv_3x3_relu", + "attrs": {"act_type": "relu"}, + "inputs": [[479, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "84", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[480, 0, 0], [481, 0, 0], [482, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose11", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[483, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten17", + "inputs": [[484, 0, 0]] + }, + { + "op": "Concat", + "name": "concat0", + "attrs": { + "dim": "1", + "num_args": "6" + }, + "inputs": [[351, 0, 0], [433, 0, 0], [446, 0, 0], [459, 0, 0], 
[472, 0, 0], [485, 0, 0]] + }, + { + "op": "Reshape", + "name": "reshape0", + "attrs": {"shape": "(0, -1, 21)"}, + "inputs": [[486, 0, 0]] + }, + { + "op": "transpose", + "name": "multibox_cls_pred", + "attrs": {"axes": "(0, 2, 1)"}, + "inputs": [[487, 0, 0]] + }, + { + "op": "SoftmaxActivation", + "name": "cls_prob", + "attrs": {"mode": "channel"}, + "inputs": [[488, 0, 0]] + }, + { + "op": "null", + "name": "_plus12_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "_plus12_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "_plus12_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[346, 0, 0], [490, 0, 0], [491, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose0", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[492, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten1", + "inputs": [[493, 0, 0]] + }, + { + "op": "null", + "name": "_plus15_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "_plus15_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "_plus15_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[428, 0, 0], [495, 0, 0], [496, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose2", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[497, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten4", + "inputs": [[498, 0, 0]] + }, + { + "op": "null", + "name": 
"multi_feat_2_conv_3x3_relu_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_2_conv_3x3_relu_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_2_conv_3x3_relu_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[441, 0, 0], [500, 0, 0], [501, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose4", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[502, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten7", + "inputs": [[503, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "24", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[454, 0, 0], [505, 0, 0], [506, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose6", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[507, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten10", + "inputs": [[508, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + 
"inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[467, 0, 0], [510, 0, 0], [511, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose8", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[512, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten13", + "inputs": [[513, 0, 0]] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv_weight", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [] + }, + { + "op": "null", + "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv_bias", + "attrs": { + "__init__": "[\"constant\", {\"value\": 0.0}]", + "__lr_mult__": "2.0" + }, + "inputs": [] + }, + { + "op": "Convolution", + "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv", + "attrs": { + "kernel": "(3, 3)", + "num_filter": "16", + "pad": "(1, 1)", + "stride": "(1, 1)" + }, + "inputs": [[480, 0, 0], [515, 0, 0], [516, 0, 0]] + }, + { + "op": "transpose", + "name": "transpose10", + "attrs": {"axes": "(0, 2, 3, 1)"}, + "inputs": [[517, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten16", + "inputs": [[518, 0, 0]] + }, + { + "op": "Concat", + "name": "multibox_loc_pred", + "attrs": { + "dim": "1", + "num_args": "6" + }, + "inputs": [[494, 0, 0], [499, 0, 0], [504, 0, 0], [509, 0, 0], [514, 0, 0], [519, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "_plus12_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5)", + "sizes": "(0.1,0.141)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[346, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten3", + "inputs": [[521, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "_plus15_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5,3,0.333333333333)", + "sizes": "(0.2,0.272)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[428, 0, 0]] + }, + { 
+ "op": "Flatten", + "name": "flatten6", + "inputs": [[523, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "multi_feat_2_conv_3x3_relu_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5,3,0.333333333333)", + "sizes": "(0.37,0.447)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[441, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten9", + "inputs": [[525, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "multi_feat_3_conv_3x3_relu_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5,3,0.333333333333)", + "sizes": "(0.54,0.619)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[454, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten12", + "inputs": [[527, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "multi_feat_4_conv_3x3_relu_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5)", + "sizes": "(0.71,0.79)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[467, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten15", + "inputs": [[529, 0, 0]] + }, + { + "op": "_contrib_MultiBoxPrior", + "name": "multi_feat_5_conv_3x3_relu_anchors", + "attrs": { + "clip": "False", + "ratios": "(1,2,0.5)", + "sizes": "(0.88,0.961)", + "steps": "(-1.0, -1.0)" + }, + "inputs": [[480, 0, 0]] + }, + { + "op": "Flatten", + "name": "flatten18", + "inputs": [[531, 0, 0]] + }, + { + "op": "Concat", + "name": "concat1", + "attrs": { + "dim": "1", + "num_args": "6" + }, + "inputs": [[522, 0, 0], [524, 0, 0], [526, 0, 0], [528, 0, 0], [530, 0, 0], [532, 0, 0]] + }, + { + "op": "Reshape", + "name": "multibox_anchors", + "attrs": {"shape": "(0, -1, 4)"}, + "inputs": [[533, 0, 0]] + }, + { + "op": "_contrib_MultiBoxDetection", + "name": "detection", + "attrs": { + "force_suppress": "False", + "nms_threshold": "0.5", + "nms_topk": "400", + "variances": "(0.1, 0.1, 0.2, 0.2)" + }, + "inputs": [[489, 0, 0], [520, 0, 0], [534, 0, 0]] + } + ], + "arg_nodes": [ + 0, + 2, + 3, + 4, + 5, + 7, + 9, + 10, + 11, + 12, + 16, + 17, + 18, + 
19, + 22, + 24, + 25, + 26, + 27, + 30, + 32, + 33, + 34, + 35, + 38, + 40, + 43, + 44, + 45, + 46, + 49, + 51, + 52, + 53, + 54, + 57, + 59, + 60, + 61, + 62, + 65, + 68, + 69, + 70, + 71, + 74, + 76, + 77, + 78, + 79, + 82, + 84, + 85, + 86, + 87, + 90, + 93, + 94, + 95, + 96, + 99, + 101, + 102, + 103, + 104, + 107, + 109, + 110, + 111, + 112, + 115, + 117, + 120, + 121, + 122, + 123, + 126, + 128, + 129, + 130, + 131, + 134, + 136, + 137, + 138, + 139, + 142, + 145, + 146, + 147, + 148, + 151, + 153, + 154, + 155, + 156, + 159, + 161, + 162, + 163, + 164, + 167, + 170, + 171, + 172, + 173, + 176, + 178, + 179, + 180, + 181, + 184, + 186, + 187, + 188, + 189, + 192, + 195, + 196, + 197, + 198, + 201, + 203, + 204, + 205, + 206, + 209, + 211, + 212, + 213, + 214, + 217, + 219, + 222, + 223, + 224, + 225, + 228, + 230, + 231, + 232, + 233, + 236, + 238, + 239, + 240, + 241, + 244, + 247, + 248, + 249, + 250, + 253, + 255, + 256, + 257, + 258, + 261, + 263, + 264, + 265, + 266, + 269, + 272, + 273, + 274, + 275, + 278, + 280, + 281, + 282, + 283, + 286, + 288, + 289, + 290, + 291, + 294, + 297, + 298, + 299, + 300, + 303, + 305, + 306, + 307, + 308, + 311, + 313, + 314, + 315, + 316, + 319, + 322, + 323, + 324, + 325, + 328, + 330, + 331, + 332, + 333, + 336, + 338, + 339, + 340, + 341, + 344, + 347, + 348, + 352, + 353, + 354, + 355, + 358, + 360, + 361, + 362, + 363, + 366, + 368, + 369, + 370, + 371, + 374, + 376, + 379, + 380, + 381, + 382, + 385, + 387, + 388, + 389, + 390, + 393, + 395, + 396, + 397, + 398, + 401, + 404, + 405, + 406, + 407, + 410, + 412, + 413, + 414, + 415, + 418, + 420, + 421, + 422, + 423, + 426, + 429, + 430, + 434, + 435, + 438, + 439, + 442, + 443, + 447, + 448, + 451, + 452, + 455, + 456, + 460, + 461, + 464, + 465, + 468, + 469, + 473, + 474, + 477, + 478, + 481, + 482, + 490, + 491, + 495, + 496, + 500, + 501, + 505, + 506, + 510, + 511, + 515, + 516 + ], + "node_row_ptr": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 9, + 10, + 11, + 12, 
+ 13, + 14, + 15, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 288, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 374, + 375, + 376, + 377, + 
378, + 379, + 380, + 381, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 425, + 426, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 441, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 451, + 452, + 453, + 454, + 455, + 456, + 457, + 458, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 505, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 515, + 516, + 517, + 518, + 519, + 520, + 521, + 522, + 525, + 526, + 527, + 528, + 529, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + 538, + 539, + 540, + 541, + 542, + 543, + 544, + 545, + 546, + 547, + 548, + 549, + 550, + 551, + 552, + 553, + 554, + 555, + 556, + 557, + 558, + 559, + 560, + 561, + 562, + 563, + 564, + 565, + 566, + 567, + 568, + 569, + 570, + 571, + 572, + 573, + 574, + 575, + 576, + 577, + 578, + 579, + 580, + 581, + 582, + 583, + 584, + 585, + 586, + 587, + 588, + 589, + 590, + 591, + 592, + 593, + 594, + 595, + 596, + 597, + 598, + 599, + 600, + 601, + 602, + 603, + 604, + 605, + 606, + 607, + 608, + 609, + 610, + 611, + 612, + 613, + 614, + 615, + 616, + 617, + 618, + 619, + 620, + 621, + 622, + 623, + 624, + 625, + 626, + 627, + 628, + 629, + 630, + 631, + 632, + 633, + 634, + 635, + 636 + ], + "heads": [[535, 0, 0]], + "attrs": {"mxnet_version": ["int", 10200]} +} \ No newline at end of file From 64ffb4a93aa646540b41e0d2379b370c4be6848f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 23:01:11 +0000 Subject: [PATCH 07/22] Remove json file --- nnvm/src/top/vision/nms.cc | 16 +- nnvm/tests/python/compiler/test_top_level4.py | 2 +- 
tutorials/nnvm/deploy_ssd.py | 11 +- .../nnvm/ssd/ssd_resnet50_inference.json | 6180 ----------------- 4 files changed, 23 insertions(+), 6186 deletions(-) delete mode 100644 tutorials/nnvm/ssd/ssd_resnet50_inference.json diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 22b2136341ef..9cbdcd4d5095 100644 --- a/nnvm/src/top/vision/nms.cc +++ b/nnvm/src/top/vision/nms.cc @@ -41,10 +41,18 @@ bool NMSShape(const NodeAttrs& attrs, return true; } +inline bool NMSInferType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); + return true; + +} + inline bool NMSInferLayout(const NodeAttrs& attrs, - std::vector *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts) { + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { static const Layout kNCHW("NCHW"); CHECK_EQ(ilayouts->size(), 2U); CHECK_EQ(olayouts->size(), 1U); @@ -68,7 +76,7 @@ NNVM_REGISTER_OP(nms) return std::vector{"data", "valid_count"}; }) .set_attr("FInferShape", NMSShape) -.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferType", NMSInferType) .set_attr("FCorrectLayout", NMSInferLayout) .set_support_level(4); diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 54ac021bb9b9..421f0879826f 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -484,5 +484,5 @@ def test_nms(): test_flip() test_multibox_prior() test_multibox_detection() - #test_nms() + test_nms() print(nnvm.compiler.engine.dump()) diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index 8ad92443b7a5..d1d8bcc6b1ed 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -65,22 +65,31 @@ def download(url, path, overwrite=False): "resnet50_ssd_512_voc0712_trainval.zip" image_url = 
"https://cloud.githubusercontent.com/assets/3307514/20012567/" \ "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" +inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" +inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ + "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" dir = "ssd_model" if not os.path.exists(dir): os.makedirs(dir) model_file_path = "%s/%s" % (dir, model_file) test_image_path = "%s/%s" % (dir, test_image) +inference_symbol_path = "%s/inference_model.zip" % dir download(model_url, model_file_path) download(image_url, test_image_path) +download(inference_symbol_url, inference_symbol_path) + zip_ref = zipfile.ZipFile(model_file_path, 'r') zip_ref.extractall(dir) zip_ref.close() +zip_ref = zipfile.ZipFile(inference_symbol_path) +zip_ref.extractall(dir) +zip_ref.close() ###################################################################### # Convert and compile model with NNVM for CPU. 
-sym = mx.sym.load("ssd/ssd_resnet50_inference.json") +sym = mx.sym.load("%s/%s/ssd_resnet50_inference.json" % (dir, inference_symbol_folder)) _, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) net, params = from_mxnet(sym, arg_params, aux_params) with compiler.build_config(opt_level=3): diff --git a/tutorials/nnvm/ssd/ssd_resnet50_inference.json b/tutorials/nnvm/ssd/ssd_resnet50_inference.json deleted file mode 100644 index 3af9a9023a72..000000000000 --- a/tutorials/nnvm/ssd/ssd_resnet50_inference.json +++ /dev/null @@ -1,6180 +0,0 @@ -{ - "nodes": [ - { - "op": "null", - "name": "data", - "inputs": [] - }, - { - "op": "_copy", - "name": "id", - "inputs": [[0, 0, 0]] - }, - { - "op": "null", - "name": "bn_data_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "True", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn_data_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "True", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn_data_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "True", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn_data_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "True", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "bn_data", - "attrs": { - "eps": "2e-05", - "fix_gamma": "True", - "momentum": "0.9" - }, - "inputs": [[1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 1], [5, 0, 1]] - }, - { - "op": "null", - "name": "conv0_weight", - "attrs": { - "kernel": "(7, 7)", - "no_bias": "True", - "num_filter": "64", - "pad": "(3, 3)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "conv0", - "attrs": { - "kernel": "(7, 7)", - "no_bias": "True", - "num_filter": "64", - "pad": "(3, 3)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[6, 0, 0], [7, 0, 0]] 
- }, - { - "op": "null", - "name": "bn0_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn0_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn0_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "bn0_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "bn0", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[8, 0, 0], [9, 0, 0], [10, 0, 0], [11, 0, 1], [12, 0, 1]] - }, - { - "op": "Activation", - "name": "relu0", - "attrs": {"act_type": "relu"}, - "inputs": [[13, 0, 0]] - }, - { - "op": "Pooling", - "name": "pooling0", - "attrs": { - "kernel": "(3, 3)", - "pad": "(1, 1)", - "pool_type": "max", - "stride": "(2, 2)" - }, - "inputs": [[14, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit1_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[15, 0, 0], [16, 0, 
0], [17, 0, 0], [18, 0, 1], [19, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit1_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[20, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit1_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[21, 0, 0], [22, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit1_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[23, 0, 0], [24, 0, 0], [25, 0, 0], [26, 0, 1], [27, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit1_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[28, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit1_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": 
"True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[29, 0, 0], [30, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit1_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit1_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[31, 0, 0], [32, 0, 0], [33, 0, 0], [34, 0, 1], [35, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit1_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[36, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit1_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[37, 0, 0], [38, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit1_sc_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit1_sc", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "stride": "(1, 
1)", - "workspace": "256" - }, - "inputs": [[21, 0, 0], [40, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus0", - "inputs": [[39, 0, 0], [41, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit2_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[42, 0, 0], [43, 0, 0], [44, 0, 0], [45, 0, 1], [46, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit2_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[47, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit2_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[48, 0, 0], [49, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - 
"name": "stage1_unit2_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit2_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[50, 0, 0], [51, 0, 0], [52, 0, 0], [53, 0, 1], [54, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit2_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[55, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit2_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[56, 0, 0], [57, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit2_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit2_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, 
- "inputs": [[58, 0, 0], [59, 0, 0], [60, 0, 0], [61, 0, 1], [62, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit2_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[63, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit2_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit2_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[64, 0, 0], [65, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus1", - "inputs": [[66, 0, 0], [42, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit3_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[67, 0, 0], [68, 0, 0], [69, 0, 0], [70, 0, 1], [71, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit3_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[72, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - 
"inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit3_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "64", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[73, 0, 0], [74, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit3_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[75, 0, 0], [76, 0, 0], [77, 0, 0], [78, 0, 1], [79, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit3_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[80, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit3_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "64", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[81, 0, 0], [82, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn3_beta", - "attrs": { - "eps": 
"2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage1_unit3_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage1_unit3_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[83, 0, 0], [84, 0, 0], [85, 0, 0], [86, 0, 1], [87, 0, 1]] - }, - { - "op": "Activation", - "name": "stage1_unit3_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[88, 0, 0]] - }, - { - "op": "null", - "name": "stage1_unit3_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage1_unit3_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[89, 0, 0], [90, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus2", - "inputs": [[91, 0, 0], [67, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - 
"fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit1_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[92, 0, 0], [93, 0, 0], [94, 0, 0], [95, 0, 1], [96, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit1_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[97, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit1_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[98, 0, 0], [99, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit1_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[100, 0, 0], [101, 0, 0], [102, 0, 0], [103, 0, 1], [104, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit1_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[105, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_conv2_weight", - "attrs": { - "kernel": "(3, 
3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit1_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[106, 0, 0], [107, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit1_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit1_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[108, 0, 0], [109, 0, 0], [110, 0, 0], [111, 0, 1], [112, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit1_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[113, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit1_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[114, 0, 0], [115, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit1_sc_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": 
"True", - "num_filter": "512", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit1_sc", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[98, 0, 0], [117, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus3", - "inputs": [[116, 0, 0], [118, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit2_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[119, 0, 0], [120, 0, 0], [121, 0, 0], [122, 0, 1], [123, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit2_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[124, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit2_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[125, 0, 0], [126, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_bn2_gamma", - "attrs": { - 
"eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit2_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[127, 0, 0], [128, 0, 0], [129, 0, 0], [130, 0, 1], [131, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit2_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[132, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit2_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[133, 0, 0], [134, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit2_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": 
"stage2_unit2_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit2_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[135, 0, 0], [136, 0, 0], [137, 0, 0], [138, 0, 1], [139, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit2_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[140, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit2_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit2_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[141, 0, 0], [142, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus4", - "inputs": [[143, 0, 0], [119, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit3_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[144, 0, 0], [145, 0, 0], [146, 0, 0], [147, 0, 1], [148, 0, 1]] - }, - { - "op": 
"Activation", - "name": "stage2_unit3_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[149, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit3_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[150, 0, 0], [151, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit3_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[152, 0, 0], [153, 0, 0], [154, 0, 0], [155, 0, 1], [156, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit3_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[157, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit3_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 
1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[158, 0, 0], [159, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit3_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit3_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[160, 0, 0], [161, 0, 0], [162, 0, 0], [163, 0, 1], [164, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit3_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[165, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit3_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit3_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[166, 0, 0], [167, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus5", - "inputs": [[168, 0, 0], [144, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - 
"inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit4_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[169, 0, 0], [170, 0, 0], [171, 0, 0], [172, 0, 1], [173, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit4_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[174, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit4_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[175, 0, 0], [176, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit4_bn2", - "attrs": { - "eps": 
"2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[177, 0, 0], [178, 0, 0], [179, 0, 0], [180, 0, 1], [181, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit4_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[182, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage2_unit4_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[183, 0, 0], [184, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage2_unit4_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage2_unit4_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[185, 0, 0], [186, 0, 0], [187, 0, 0], [188, 0, 1], [189, 0, 1]] - }, - { - "op": "Activation", - "name": "stage2_unit4_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[190, 0, 0]] - }, - { - "op": "null", - "name": "stage2_unit4_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, 
- { - "op": "Convolution", - "name": "stage2_unit4_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[191, 0, 0], [192, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus6", - "inputs": [[193, 0, 0], [169, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit1_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[194, 0, 0], [195, 0, 0], [196, 0, 0], [197, 0, 1], [198, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit1_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[199, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit1_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[200, 0, 0], [201, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": 
[] - }, - { - "op": "null", - "name": "stage3_unit1_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit1_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[202, 0, 0], [203, 0, 0], [204, 0, 0], [205, 0, 1], [206, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit1_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[207, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit1_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[208, 0, 0], [209, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit1_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - 
"fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit1_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[210, 0, 0], [211, 0, 0], [212, 0, 0], [213, 0, 1], [214, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit1_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[215, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit1_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[216, 0, 0], [217, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit1_sc_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit1_sc", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[200, 0, 0], [219, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus7", - "inputs": [[218, 0, 0], [220, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit2_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn1_moving_var", - "attrs": { - "__init__": 
"[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit2_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[221, 0, 0], [222, 0, 0], [223, 0, 0], [224, 0, 1], [225, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit2_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[226, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit2_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit2_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[227, 0, 0], [228, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit2_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit2_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[229, 0, 0], [230, 0, 0], [231, 0, 0], [232, 0, 1], [233, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit2_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[234, 0, 0]] - }, - { - "op": "null", - "name": 
"stage3_unit2_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit2_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[235, 0, 0], [236, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit2_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit2_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit2_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[237, 0, 0], [238, 0, 0], [239, 0, 0], [240, 0, 1], [241, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit2_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[242, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit2_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit2_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[243, 0, 0], [244, 0, 0]] - }, - { - "op": "elemwise_add", - "name": 
"_plus8", - "inputs": [[245, 0, 0], [221, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit3_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[246, 0, 0], [247, 0, 0], [248, 0, 0], [249, 0, 1], [250, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit3_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[251, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit3_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[252, 0, 0], [253, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": 
"2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit3_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[254, 0, 0], [255, 0, 0], [256, 0, 0], [257, 0, 1], [258, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit3_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[259, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit3_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[260, 0, 0], [261, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit3_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit3_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[262, 0, 0], [263, 0, 0], [264, 0, 0], [265, 0, 1], [266, 0, 1]] - }, 
- { - "op": "Activation", - "name": "stage3_unit3_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[267, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit3_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit3_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[268, 0, 0], [269, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus9", - "inputs": [[270, 0, 0], [246, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit4_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[271, 0, 0], [272, 0, 0], [273, 0, 0], [274, 0, 1], [275, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit4_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[276, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": 
"stage3_unit4_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[277, 0, 0], [278, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit4_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[279, 0, 0], [280, 0, 0], [281, 0, 0], [282, 0, 1], [283, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit4_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[284, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit4_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[285, 0, 0], [286, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - 
"momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit4_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit4_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[287, 0, 0], [288, 0, 0], [289, 0, 0], [290, 0, 1], [291, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit4_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[292, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit4_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit4_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[293, 0, 0], [294, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus10", - "inputs": [[295, 0, 0], [271, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - 
"momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit5_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[296, 0, 0], [297, 0, 0], [298, 0, 0], [299, 0, 1], [300, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit5_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[301, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit5_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[302, 0, 0], [303, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit5_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[304, 0, 0], [305, 0, 0], [306, 0, 0], [307, 0, 1], [308, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit5_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[309, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - 
"no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit5_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[310, 0, 0], [311, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit5_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit5_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[312, 0, 0], [313, 0, 0], [314, 0, 0], [315, 0, 1], [316, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit5_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[317, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit5_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit5_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[318, 0, 0], [319, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus11", - "inputs": [[320, 0, 0], [296, 0, 0]] - }, - { - "op": 
"null", - "name": "stage3_unit6_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit6_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[321, 0, 0], [322, 0, 0], [323, 0, 0], [324, 0, 1], [325, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit6_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[326, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit6_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit6_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[327, 0, 0], [328, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit6_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] 
- }, - { - "op": "null", - "name": "stage3_unit6_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit6_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[329, 0, 0], [330, 0, 0], [331, 0, 0], [332, 0, 1], [333, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit6_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[334, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit6_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit6_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[335, 0, 0], [336, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit6_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage3_unit6_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage3_unit6_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[337, 0, 0], [338, 0, 0], [339, 0, 0], [340, 0, 1], [341, 0, 1]] - }, - { - "op": "Activation", - "name": "stage3_unit6_relu3", - "attrs": 
{"act_type": "relu"}, - "inputs": [[342, 0, 0]] - }, - { - "op": "null", - "name": "stage3_unit6_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage3_unit6_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "1024", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[343, 0, 0], [344, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus12", - "inputs": [[345, 0, 0], [321, 0, 0]] - }, - { - "op": "null", - "name": "_plus12_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "_plus12_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "_plus12_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[346, 0, 0], [347, 0, 0], [348, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose1", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[349, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten2", - "inputs": [[350, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit1_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", 
{}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit1_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[346, 0, 0], [352, 0, 0], [353, 0, 0], [354, 0, 1], [355, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit1_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[356, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit1_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit1_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[357, 0, 0], [358, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit1_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit1_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[359, 0, 0], [360, 0, 0], [361, 0, 0], [362, 0, 1], [363, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit1_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[364, 0, 0]] - }, - { - "op": "null", - "name": 
"stage4_unit1_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit1_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[365, 0, 0], [366, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit1_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit1_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit1_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[367, 0, 0], [368, 0, 0], [369, 0, 0], [370, 0, 1], [371, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit1_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[372, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit1_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit1_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[373, 0, 0], [374, 0, 0]] - }, - { - "op": "null", - "name": 
"stage4_unit1_sc_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit1_sc", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "stride": "(2, 2)", - "workspace": "256" - }, - "inputs": [[357, 0, 0], [376, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus13", - "inputs": [[375, 0, 0], [377, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit2_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit2_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[378, 0, 0], [379, 0, 0], [380, 0, 0], [381, 0, 1], [382, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit2_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[383, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit2_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit2_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[384, 0, 0], [385, 0, 
0]] - }, - { - "op": "null", - "name": "stage4_unit2_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit2_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[386, 0, 0], [387, 0, 0], [388, 0, 0], [389, 0, 1], [390, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit2_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[391, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit2_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit2_conv2", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[392, 0, 0], [393, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit2_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" 
- }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit2_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit2_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[394, 0, 0], [395, 0, 0], [396, 0, 0], [397, 0, 1], [398, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit2_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[399, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit2_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit2_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[400, 0, 0], [401, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus14", - "inputs": [[402, 0, 0], [378, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_bn1_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn1_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn1_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn1_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit3_bn1", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[403, 0, 0], [404, 0, 0], 
[405, 0, 0], [406, 0, 1], [407, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit3_relu1", - "attrs": {"act_type": "relu"}, - "inputs": [[408, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_conv1_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit3_conv1", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "512", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[409, 0, 0], [410, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_bn2_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn2_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn2_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn2_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit3_bn2", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[411, 0, 0], [412, 0, 0], [413, 0, 0], [414, 0, 1], [415, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit3_relu2", - "attrs": {"act_type": "relu"}, - "inputs": [[416, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_conv2_weight", - "attrs": { - "kernel": "(3, 3)", - "no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit3_conv2", - "attrs": { - "kernel": "(3, 3)", - 
"no_bias": "True", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[417, 0, 0], [418, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_bn3_gamma", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn3_beta", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn3_moving_mean", - "attrs": { - "__init__": "[\"zero\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "null", - "name": "stage4_unit3_bn3_moving_var", - "attrs": { - "__init__": "[\"one\", {}]", - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [] - }, - { - "op": "BatchNorm", - "name": "stage4_unit3_bn3", - "attrs": { - "eps": "2e-05", - "fix_gamma": "False", - "momentum": "0.9" - }, - "inputs": [[419, 0, 0], [420, 0, 0], [421, 0, 0], [422, 0, 1], [423, 0, 1]] - }, - { - "op": "Activation", - "name": "stage4_unit3_relu3", - "attrs": {"act_type": "relu"}, - "inputs": [[424, 0, 0]] - }, - { - "op": "null", - "name": "stage4_unit3_conv3_weight", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "stage4_unit3_conv3", - "attrs": { - "kernel": "(1, 1)", - "no_bias": "True", - "num_filter": "2048", - "pad": "(0, 0)", - "stride": "(1, 1)", - "workspace": "256" - }, - "inputs": [[425, 0, 0], [426, 0, 0]] - }, - { - "op": "elemwise_add", - "name": "_plus15", - "inputs": [[427, 0, 0], [403, 0, 0]] - }, - { - "op": "null", - "name": "_plus15_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": 
"_plus15_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "_plus15_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[428, 0, 0], [429, 0, 0], [430, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose3", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[431, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten5", - "inputs": [[432, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_2_conv_1x1_conv_weight", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_2_conv_1x1_conv_bias", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_2_conv_1x1_conv", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "256", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [[428, 0, 0], [434, 0, 0], [435, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_2_conv_1x1_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[436, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_conv_bias", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_2_conv_3x3_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "512", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [[437, 0, 0], [438, 0, 0], [439, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_2_conv_3x3_relu", - "attrs": {"act_type": "relu"}, - 
"inputs": [[440, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_2_conv_3x3_relu_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[441, 0, 0], [442, 0, 0], [443, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose5", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[444, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten8", - "inputs": [[445, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_3_conv_1x1_conv_weight", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_3_conv_1x1_conv_bias", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_3_conv_1x1_conv", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [[441, 0, 0], [447, 0, 0], [448, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_3_conv_1x1_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[449, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_conv_bias", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "Convolution", - 
"name": "multi_feat_3_conv_3x3_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [[450, 0, 0], [451, 0, 0], [452, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_3_conv_3x3_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[453, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_3_conv_3x3_relu_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "126", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[454, 0, 0], [455, 0, 0], [456, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose7", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[457, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten11", - "inputs": [[458, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_4_conv_1x1_conv_weight", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_4_conv_1x1_conv_bias", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_4_conv_1x1_conv", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [[454, 0, 0], [460, 0, 0], [461, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_4_conv_1x1_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[462, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_conv_weight", - "attrs": { - "kernel": "(3, 3)", - 
"num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_conv_bias", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_4_conv_3x3_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "256", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [[463, 0, 0], [464, 0, 0], [465, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_4_conv_3x3_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[466, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_4_conv_3x3_relu_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[467, 0, 0], [468, 0, 0], [469, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose9", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[470, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten14", - "inputs": [[471, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_5_conv_1x1_conv_weight", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_5_conv_1x1_conv_bias", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_5_conv_1x1_conv", - "attrs": { - "kernel": "(1, 1)", - "num_filter": "128", - "pad": "(0, 0)", - "stride": "(1, 
1)" - }, - "inputs": [[467, 0, 0], [473, 0, 0], [474, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_5_conv_1x1_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[475, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_conv_bias", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_5_conv_3x3_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "128", - "pad": "(1, 1)", - "stride": "(2, 2)" - }, - "inputs": [[476, 0, 0], [477, 0, 0], [478, 0, 0]] - }, - { - "op": "Activation", - "name": "multi_feat_5_conv_3x3_relu", - "attrs": {"act_type": "relu"}, - "inputs": [[479, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_5_conv_3x3_relu_cls_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "84", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[480, 0, 0], [481, 0, 0], [482, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose11", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[483, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten17", - "inputs": [[484, 0, 0]] - }, - { - "op": "Concat", - "name": "concat0", - "attrs": { - "dim": "1", - "num_args": "6" - }, - "inputs": [[351, 0, 0], [433, 0, 0], [446, 0, 0], [459, 0, 0], [472, 0, 0], [485, 0, 0]] - }, - { - "op": "Reshape", - "name": "reshape0", - 
"attrs": {"shape": "(0, -1, 21)"}, - "inputs": [[486, 0, 0]] - }, - { - "op": "transpose", - "name": "multibox_cls_pred", - "attrs": {"axes": "(0, 2, 1)"}, - "inputs": [[487, 0, 0]] - }, - { - "op": "SoftmaxActivation", - "name": "cls_prob", - "attrs": {"mode": "channel"}, - "inputs": [[488, 0, 0]] - }, - { - "op": "null", - "name": "_plus12_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "_plus12_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "_plus12_loc_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[346, 0, 0], [490, 0, 0], [491, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose0", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[492, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten1", - "inputs": [[493, 0, 0]] - }, - { - "op": "null", - "name": "_plus15_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "_plus15_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "_plus15_loc_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[428, 0, 0], [495, 0, 0], [496, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose2", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[497, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten4", - "inputs": [[498, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_relu_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", 
- "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_2_conv_3x3_relu_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_2_conv_3x3_relu_loc_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[441, 0, 0], [500, 0, 0], [501, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose4", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[502, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten7", - "inputs": [[503, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_3_conv_3x3_relu_loc_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "24", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[454, 0, 0], [505, 0, 0], [506, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose6", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[507, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten10", - "inputs": [[508, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_4_conv_3x3_relu_loc_pred_conv", - "attrs": { - "kernel": "(3, 
3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[467, 0, 0], [510, 0, 0], [511, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose8", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[512, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten13", - "inputs": [[513, 0, 0]] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv_weight", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [] - }, - { - "op": "null", - "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv_bias", - "attrs": { - "__init__": "[\"constant\", {\"value\": 0.0}]", - "__lr_mult__": "2.0" - }, - "inputs": [] - }, - { - "op": "Convolution", - "name": "multi_feat_5_conv_3x3_relu_loc_pred_conv", - "attrs": { - "kernel": "(3, 3)", - "num_filter": "16", - "pad": "(1, 1)", - "stride": "(1, 1)" - }, - "inputs": [[480, 0, 0], [515, 0, 0], [516, 0, 0]] - }, - { - "op": "transpose", - "name": "transpose10", - "attrs": {"axes": "(0, 2, 3, 1)"}, - "inputs": [[517, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten16", - "inputs": [[518, 0, 0]] - }, - { - "op": "Concat", - "name": "multibox_loc_pred", - "attrs": { - "dim": "1", - "num_args": "6" - }, - "inputs": [[494, 0, 0], [499, 0, 0], [504, 0, 0], [509, 0, 0], [514, 0, 0], [519, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": "_plus12_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5)", - "sizes": "(0.1,0.141)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[346, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten3", - "inputs": [[521, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": "_plus15_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5,3,0.333333333333)", - "sizes": "(0.2,0.272)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[428, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten6", - "inputs": [[523, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": 
"multi_feat_2_conv_3x3_relu_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5,3,0.333333333333)", - "sizes": "(0.37,0.447)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[441, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten9", - "inputs": [[525, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": "multi_feat_3_conv_3x3_relu_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5,3,0.333333333333)", - "sizes": "(0.54,0.619)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[454, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten12", - "inputs": [[527, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": "multi_feat_4_conv_3x3_relu_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5)", - "sizes": "(0.71,0.79)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[467, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten15", - "inputs": [[529, 0, 0]] - }, - { - "op": "_contrib_MultiBoxPrior", - "name": "multi_feat_5_conv_3x3_relu_anchors", - "attrs": { - "clip": "False", - "ratios": "(1,2,0.5)", - "sizes": "(0.88,0.961)", - "steps": "(-1.0, -1.0)" - }, - "inputs": [[480, 0, 0]] - }, - { - "op": "Flatten", - "name": "flatten18", - "inputs": [[531, 0, 0]] - }, - { - "op": "Concat", - "name": "concat1", - "attrs": { - "dim": "1", - "num_args": "6" - }, - "inputs": [[522, 0, 0], [524, 0, 0], [526, 0, 0], [528, 0, 0], [530, 0, 0], [532, 0, 0]] - }, - { - "op": "Reshape", - "name": "multibox_anchors", - "attrs": {"shape": "(0, -1, 4)"}, - "inputs": [[533, 0, 0]] - }, - { - "op": "_contrib_MultiBoxDetection", - "name": "detection", - "attrs": { - "force_suppress": "False", - "nms_threshold": "0.5", - "nms_topk": "400", - "variances": "(0.1, 0.1, 0.2, 0.2)" - }, - "inputs": [[489, 0, 0], [520, 0, 0], [534, 0, 0]] - } - ], - "arg_nodes": [ - 0, - 2, - 3, - 4, - 5, - 7, - 9, - 10, - 11, - 12, - 16, - 17, - 18, - 19, - 22, - 24, - 25, - 26, - 27, - 30, - 32, - 33, - 34, - 35, - 38, - 40, - 43, - 44, - 45, - 46, - 49, - 51, - 52, - 
53, - 54, - 57, - 59, - 60, - 61, - 62, - 65, - 68, - 69, - 70, - 71, - 74, - 76, - 77, - 78, - 79, - 82, - 84, - 85, - 86, - 87, - 90, - 93, - 94, - 95, - 96, - 99, - 101, - 102, - 103, - 104, - 107, - 109, - 110, - 111, - 112, - 115, - 117, - 120, - 121, - 122, - 123, - 126, - 128, - 129, - 130, - 131, - 134, - 136, - 137, - 138, - 139, - 142, - 145, - 146, - 147, - 148, - 151, - 153, - 154, - 155, - 156, - 159, - 161, - 162, - 163, - 164, - 167, - 170, - 171, - 172, - 173, - 176, - 178, - 179, - 180, - 181, - 184, - 186, - 187, - 188, - 189, - 192, - 195, - 196, - 197, - 198, - 201, - 203, - 204, - 205, - 206, - 209, - 211, - 212, - 213, - 214, - 217, - 219, - 222, - 223, - 224, - 225, - 228, - 230, - 231, - 232, - 233, - 236, - 238, - 239, - 240, - 241, - 244, - 247, - 248, - 249, - 250, - 253, - 255, - 256, - 257, - 258, - 261, - 263, - 264, - 265, - 266, - 269, - 272, - 273, - 274, - 275, - 278, - 280, - 281, - 282, - 283, - 286, - 288, - 289, - 290, - 291, - 294, - 297, - 298, - 299, - 300, - 303, - 305, - 306, - 307, - 308, - 311, - 313, - 314, - 315, - 316, - 319, - 322, - 323, - 324, - 325, - 328, - 330, - 331, - 332, - 333, - 336, - 338, - 339, - 340, - 341, - 344, - 347, - 348, - 352, - 353, - 354, - 355, - 358, - 360, - 361, - 362, - 363, - 366, - 368, - 369, - 370, - 371, - 374, - 376, - 379, - 380, - 381, - 382, - 385, - 387, - 388, - 389, - 390, - 393, - 395, - 396, - 397, - 398, - 401, - 404, - 405, - 406, - 407, - 410, - 412, - 413, - 414, - 415, - 418, - 420, - 421, - 422, - 423, - 426, - 429, - 430, - 434, - 435, - 438, - 439, - 442, - 443, - 447, - 448, - 451, - 452, - 455, - 456, - 460, - 461, - 464, - 465, - 468, - 469, - 473, - 474, - 477, - 478, - 481, - 482, - 490, - 491, - 495, - 496, - 500, - 501, - 505, - 506, - 510, - 511, - 515, - 516 - ], - "node_row_ptr": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 37, - 38, 
- 39, - 40, - 41, - 42, - 43, - 44, - 47, - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 55, - 56, - 57, - 60, - 61, - 62, - 63, - 64, - 65, - 66, - 67, - 70, - 71, - 72, - 73, - 74, - 75, - 76, - 77, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 122, - 123, - 124, - 125, - 126, - 127, - 128, - 129, - 132, - 133, - 134, - 135, - 136, - 137, - 138, - 139, - 142, - 143, - 144, - 145, - 146, - 147, - 148, - 149, - 150, - 151, - 152, - 155, - 156, - 157, - 158, - 159, - 160, - 161, - 162, - 165, - 166, - 167, - 168, - 169, - 170, - 171, - 172, - 175, - 176, - 177, - 178, - 179, - 180, - 181, - 182, - 183, - 186, - 187, - 188, - 189, - 190, - 191, - 192, - 193, - 196, - 197, - 198, - 199, - 200, - 201, - 202, - 203, - 206, - 207, - 208, - 209, - 210, - 211, - 212, - 213, - 214, - 217, - 218, - 219, - 220, - 221, - 222, - 223, - 224, - 227, - 228, - 229, - 230, - 231, - 232, - 233, - 234, - 237, - 238, - 239, - 240, - 241, - 242, - 243, - 244, - 245, - 248, - 249, - 250, - 251, - 252, - 253, - 254, - 255, - 258, - 259, - 260, - 261, - 262, - 263, - 264, - 265, - 268, - 269, - 270, - 271, - 272, - 273, - 274, - 275, - 276, - 277, - 278, - 281, - 282, - 283, - 284, - 285, - 286, - 287, - 288, - 291, - 292, - 293, - 294, - 295, - 296, - 297, - 298, - 301, - 302, - 303, - 304, - 305, - 306, - 307, - 308, - 309, - 312, - 313, - 314, - 315, - 316, - 317, - 318, - 319, - 322, - 323, - 324, - 325, - 326, - 327, - 328, - 329, - 332, - 333, - 334, - 335, - 336, - 337, - 338, - 339, - 340, - 343, - 344, - 345, - 346, - 347, - 348, - 349, - 350, - 353, - 354, - 355, - 356, - 357, - 358, - 359, - 360, - 363, - 364, - 365, - 366, - 367, - 368, - 369, - 370, - 371, - 374, - 375, - 376, - 377, - 378, - 379, - 380, - 381, - 384, - 385, - 386, - 387, - 388, - 389, - 390, - 391, - 394, - 395, - 396, - 397, - 398, - 
399, - 400, - 401, - 402, - 405, - 406, - 407, - 408, - 409, - 410, - 411, - 412, - 415, - 416, - 417, - 418, - 419, - 420, - 421, - 422, - 425, - 426, - 427, - 428, - 429, - 430, - 431, - 432, - 433, - 434, - 435, - 436, - 437, - 438, - 441, - 442, - 443, - 444, - 445, - 446, - 447, - 448, - 451, - 452, - 453, - 454, - 455, - 456, - 457, - 458, - 461, - 462, - 463, - 464, - 465, - 466, - 467, - 468, - 469, - 470, - 471, - 474, - 475, - 476, - 477, - 478, - 479, - 480, - 481, - 484, - 485, - 486, - 487, - 488, - 489, - 490, - 491, - 494, - 495, - 496, - 497, - 498, - 499, - 500, - 501, - 502, - 505, - 506, - 507, - 508, - 509, - 510, - 511, - 512, - 515, - 516, - 517, - 518, - 519, - 520, - 521, - 522, - 525, - 526, - 527, - 528, - 529, - 530, - 531, - 532, - 533, - 534, - 535, - 536, - 537, - 538, - 539, - 540, - 541, - 542, - 543, - 544, - 545, - 546, - 547, - 548, - 549, - 550, - 551, - 552, - 553, - 554, - 555, - 556, - 557, - 558, - 559, - 560, - 561, - 562, - 563, - 564, - 565, - 566, - 567, - 568, - 569, - 570, - 571, - 572, - 573, - 574, - 575, - 576, - 577, - 578, - 579, - 580, - 581, - 582, - 583, - 584, - 585, - 586, - 587, - 588, - 589, - 590, - 591, - 592, - 593, - 594, - 595, - 596, - 597, - 598, - 599, - 600, - 601, - 602, - 603, - 604, - 605, - 606, - 607, - 608, - 609, - 610, - 611, - 612, - 613, - 614, - 615, - 616, - 617, - 618, - 619, - 620, - 621, - 622, - 623, - 624, - 625, - 626, - 627, - 628, - 629, - 630, - 631, - 632, - 633, - 634, - 635, - 636 - ], - "heads": [[535, 0, 0]], - "attrs": {"mxnet_version": ["int", 10200]} -} \ No newline at end of file From d20606a106f254e6383c7dfe91feae6527274278 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 31 May 2018 23:29:44 +0000 Subject: [PATCH 08/22] Fix cpplint --- nnvm/src/top/vision/nms.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 9cbdcd4d5095..4a6723222b21 100644 --- a/nnvm/src/top/vision/nms.cc +++ 
b/nnvm/src/top/vision/nms.cc @@ -46,7 +46,6 @@ inline bool NMSInferType(const NodeAttrs &attrs, std::vector *out_attrs) { DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); return true; - } inline bool NMSInferLayout(const NodeAttrs& attrs, From 23f49b469f9aa258fa9d3243e84f0f7e701d37c5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 1 Jun 2018 00:57:01 +0000 Subject: [PATCH 09/22] Fix BatchNorm scale issue --- nnvm/python/nnvm/testing/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/testing/resnet.py b/nnvm/python/nnvm/testing/resnet.py index 64eb63c29b7a..243cc1b65144 100644 --- a/nnvm/python/nnvm/testing/resnet.py +++ b/nnvm/python/nnvm/testing/resnet.py @@ -108,7 +108,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, num_unit = len(units) assert num_unit == num_stages data = sym.Variable(name='data') - data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data') + data = sym.batch_norm(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, height, _) = image_shape if height <= 32: # such as cifar10 body = sym.conv2d( From 58f47a7fa07a4503d9958a6136e8e5c3fe859592 Mon Sep 17 00:00:00 2001 From: Wang Date: Mon, 4 Jun 2018 12:06:38 -0700 Subject: [PATCH 10/22] Address comments --- nnvm/python/nnvm/compiler/build_module.py | 6 +++--- nnvm/python/nnvm/compiler/graph_util.py | 3 +++ nnvm/src/top/vision/nms.cc | 10 +++------- nnvm/src/top/vision/ssd/mutibox_op.cc | 12 ++++++------ 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 9f40e2a80fe8..817214544e1c 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -335,9 +335,9 @@ def build(graph, target=None, shape=None, dtype="float32", if params is None: params = {} params.update(init_var) - if not build_extra: - return graph, libmod, params - return graph, libmod, params, extra_lib + if 
build_extra: + return graph, libmod, params, extra_lib + return graph, libmod, params def _remove_noref_params(params, graph): """ Helper to clear non referenced params diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 621872ead98b..3b2915b38b6d 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -167,6 +167,9 @@ def split_last_op(graph): """ graph_idx = graph.index last_op_node = graph_idx.nodes[-1] + if last_op_node["op"] == "null": + raise RuntimeError("split_last_op doesn't support sast operator " + "to be null.") last_op_func = getattr(sym, last_op_node["op"]) if "attrs" in last_op_node: last_op_attr = last_op_node["attrs"] diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 4a6723222b21..2680b894255b 100644 --- a/nnvm/src/top/vision/nms.cc +++ b/nnvm/src/top/vision/nms.cc @@ -27,17 +27,13 @@ bool NMSShape(const NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; TShape dshape = in_attrs->at(0); TShape vshape = in_attrs->at(1); - CHECK_EQ(dshape.ndim(), 3U) << "Provided: " << dshape; - CHECK_EQ(vshape.ndim(), 1U) << "Provided: " << vshape; + CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D."; + CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D."; CHECK_EQ(dshape[2], 6U) << "Data input should have shape " "(batch_size, num_anchors, 6)."; CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; - TShape oshape = TShape(3); - oshape[0] = dshape[0]; - oshape[1] = dshape[1]; - oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] out_attrs->clear(); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); return true; } diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc index 577657ecb304..d02ae802c636 100644 --- a/nnvm/src/top/vision/ssd/mutibox_op.cc +++ 
b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -91,12 +91,12 @@ bool MultiBoxDetectionShape(const NodeAttrs& attrs, TShape cshape = in_attrs->at(0); TShape lshape = in_attrs->at(1); TShape ashape = in_attrs->at(2); - CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape; - CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape; - CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape; - CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; - CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; - CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0"; + CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D."; + CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D."; + CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D."; + CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch."; + CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc."; + CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0."; CHECK_EQ(ashape[2], 4U); TShape oshape = TShape(3); oshape[0] = cshape[0]; From 307032cb19ae45ee0e38ade96e93e8a157256cce Mon Sep 17 00:00:00 2001 From: Wang Date: Mon, 4 Jun 2018 12:15:50 -0700 Subject: [PATCH 11/22] Fix typo --- nnvm/python/nnvm/compiler/graph_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 3b2915b38b6d..58782ce374b2 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -168,7 +168,7 @@ def split_last_op(graph): graph_idx = graph.index last_op_node = graph_idx.nodes[-1] if last_op_node["op"] == "null": - raise RuntimeError("split_last_op doesn't support sast operator " + raise RuntimeError("split_last_op doesn't support last operator " "to be null.") last_op_func = getattr(sym, last_op_node["op"]) if "attrs" in last_op_node: From 1bba66a17cdec2500ff61efbbd4ed858d6e3166d Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 9 Jun 2018 
18:58:10 -0700 Subject: [PATCH 12/22] Split multibox_detection --- nnvm/include/nnvm/top/nn.h | 13 +---- nnvm/python/nnvm/frontend/mxnet.py | 12 ++-- nnvm/python/nnvm/top/vision.py | 19 +++---- nnvm/src/top/vision/ssd/mutibox_op.cc | 57 +++++++++++-------- nnvm/tests/python/compiler/test_top_level4.py | 11 ++-- topi/python/topi/vision/ssd/multibox.py | 4 +- 6 files changed, 57 insertions(+), 59 deletions(-) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 8575ebe63fdc..bbdb3b9c4f12 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -340,26 +340,17 @@ struct MultiBoxPriorParam : public dmlc::Parameter { } }; -struct MultiBoxDetectionParam : public dmlc::Parameter { +struct MultiBoxTransformLocParam : public dmlc::Parameter { bool clip; float threshold; - float nms_threshold; - bool force_suppress; - int nms_topk; Tuple variances; - DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { + DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) { DMLC_DECLARE_FIELD(clip).set_default(true) .describe("Clip out-of-boundary boxes."); DMLC_DECLARE_FIELD(threshold).set_default(0.01) .describe("Threshold to be a positive prediction."); - DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) - .describe("Non-maximum suppression threshold."); - DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); DMLC_DECLARE_FIELD(variances).set_default(Tuple{0.1, 0.1, 0.2, 0.2}) .describe("Variances to be decoded from box regression output."); - DMLC_DECLARE_FIELD(nms_topk).set_default(-1) - .describe("Keep maximum top k detections before nms, -1 for no limit."); } }; diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index fa7c44418348..51b2c0b01618 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -226,12 +226,12 @@ def _contrib_multibox_detection(inputs, attrs): variances = tuple([float(x.strip()) for x in 
attrs.get('variances').strip('()').split(',')]) \ if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) nms_topk = attrs.get('nms_topk') or -1 - new_attrs = {'clip': clip, 'threshold': float(threshold), - 'nms_threshold': float(nms_threshold), - 'force_suppress': force_suppress, - 'variances': variances, 'nms_topk': int(nms_topk)} - return _get_nnvm_op('multibox_detection')(inputs[0], inputs[1], - inputs[2], **new_attrs) + new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} + new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, + 'nms_topk': int(nms_topk)} + data, valid_count = _get_nnvm_op('multibox_detection')(inputs[0], inputs[1], + inputs[2], **new_attrs0) + return _get_nnvm_op('nms')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): new_attrs = {'num_args':len(inputs)} diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 7e5e641f340f..16d07496f333 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -1,4 +1,3 @@ - # pylint: disable=invalid-name, unused-argument """Definition of nn ops""" from __future__ import absolute_import @@ -60,26 +59,22 @@ def compute_multibox_prior(attrs, inputs, _): reg.register_pattern("multibox_prior", OpPattern.OPAQUE) -# multibox_detection -@reg.register_schedule("multibox_detection") -def schedule_multibox_detection(_, outs, target): +# multibox_transform_loc +@reg.register_schedule("multibox_transform_loc") +def schedule_multibox_transform_loc(_, outs, target): """Schedule definition of multibox_detection""" with tvm.target.create(target): - return topi.generic.schedule_multibox_detection(outs) + return topi.generic.schedule_multibox_transform_loc(outs) -@reg.register_compute("multibox_detection") +@reg.register_compute("multibox_transform_loc") def compute_multibox_detection(attrs, inputs, _): """Compute definition of multibox_detection""" clip = attrs.get_bool('clip') threshold = 
attrs.get_float('threshold') - nms_threshold = attrs.get_float('nms_threshold') - force_suppress = attrs.get_bool('force_suppress') variance = attrs.get_float_tuple('variances') - nms_topk = attrs.get_int('nms_topk') - return topi.vision.ssd.multibox_detection(inputs[0], inputs[1], inputs[2], - clip, threshold, nms_threshold, - force_suppress, variance, nms_topk) + return topi.vision.ssd.multibox_transform_loc(inputs[0], inputs[1], inputs[2], + clip, threshold, variance) reg.register_pattern("multibox_detection", OpPattern.OPAQUE) diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc index d02ae802c636..7f1aca5d2b82 100644 --- a/nnvm/src/top/vision/ssd/mutibox_op.cc +++ b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -82,11 +82,11 @@ NNVM_REGISTER_OP(multibox_prior) }) .set_support_level(4); -DMLC_REGISTER_PARAMETER(MultiBoxDetectionParam); +DMLC_REGISTER_PARAMETER(MultiBoxTransformLocParam); -bool MultiBoxDetectionShape(const NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { +bool MultiBoxTransformLocShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]"; TShape cshape = in_attrs->at(0); TShape lshape = in_attrs->at(1); @@ -98,22 +98,25 @@ bool MultiBoxDetectionShape(const NodeAttrs& attrs, CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc."; CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0."; CHECK_EQ(ashape[2], 4U); - TShape oshape = TShape(3); - oshape[0] = cshape[0]; - oshape[1] = ashape[1]; - oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] + TShape oshape0 = TShape(3); + oshape0[0] = cshape[0]; + oshape0[1] = ashape[1]; + oshape0[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] + TShape oshape1 = TShape(1); + oshape1[0] = cshape[0]; out_attrs->clear(); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape0); + 
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 1, oshape1); return true; } -inline bool MultiBoxDetectionLayout(const NodeAttrs& attrs, - std::vector *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts) { +inline bool MultiBoxTransformLocLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { CHECK_EQ(ilayouts->size(), 3U); CHECK_EQ(last_ilayouts->size(), 3U); - CHECK_EQ(olayouts->size(), 1U); + CHECK_EQ(olayouts->size(), 2U); for (size_t i = 0; i < last_ilayouts->size(); ++i) { const Layout& last_layout = last_ilayouts->at(i); if (last_layout.defined()) { @@ -123,24 +126,32 @@ inline bool MultiBoxDetectionLayout(const NodeAttrs& attrs, return true; } -NNVM_REGISTER_OP(multibox_detection) - .describe(R"doc("Convert multibox detection predictions." +inline bool MultiBoxTransformLocInferType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); + DTYPE_ASSIGN(out_attrs->at(1), 4U); + return true; +} + +NNVM_REGISTER_OP(multibox_transform_loc) + .describe(R"doc("Location transformation for multibox detection." 
)doc" NNVM_ADD_FILELINE) .set_num_inputs(3) -.set_num_outputs(1) -.set_attr_parser(ParamParser) +.set_num_outputs(2) +.set_attr_parser(ParamParser) .set_attr("FGetAttrDict", - ParamGetAttrDict) -.add_arguments(MultiBoxDetectionParam::__FIELDS__()) + ParamGetAttrDict) +.add_arguments(MultiBoxTransformLocParam::__FIELDS__()) .add_argument("cls_prob", "Tensor", "Class probabilities.") .add_argument("loc_pred", "Tensor", "Location regression predictions.") .add_argument("anchor", "Tensor", "Multibox prior anchor boxes") .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"cls_prob", "loc_pred", "anchor"}; }) -.set_attr("FInferShape", MultiBoxDetectionShape) -.set_attr("FInferType", ElemwiseType<3, 1>) -.set_attr("FCorrectLayout", MultiBoxDetectionLayout) +.set_attr("FInferShape", MultiBoxTransformLocShape) +.set_attr("FInferType", MultiBoxTransformLocInferType) +.set_attr("FCorrectLayout", MultiBoxTransformLocLayout) .set_support_level(4); } // namespace top diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 421f0879826f..c734fc7f1f2e 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -408,14 +408,16 @@ def test_multibox_prior(): verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5)) verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True) -def test_multibox_detection(): +def test_multibox_transform_loc(): batch_size = 1 num_anchors = 3 num_classes = 3 cls_prob = sym.Variable("cls_prob") loc_preds = sym.Variable("loc_preds") anchors = sym.Variable("anchors") - out = sym.multibox_detection(cls_prob=cls_prob, loc_pred=loc_preds, anchor=anchors) + transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds, + anchor=anchors) + out = sym.nms(data=transform_loc_data, valid_count=valid_count) # Manually create test case 
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) @@ -438,7 +440,6 @@ def test_multibox_detection(): out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) np.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) - def test_nms(): dshape = (1, 5, 6) data = sym.Variable("data") @@ -483,6 +484,6 @@ def test_nms(): test_full() test_flip() test_multibox_prior() - test_multibox_detection() + test_multibox_transform_loc() test_nms() - print(nnvm.compiler.engine.dump()) + #print(nnvm.compiler.engine.dump()) diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index 31a090e345d5..1c121a96a670 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -211,7 +211,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, @tvm.target.generic_func -def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, +def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)): """Location transformation for multibox detection @@ -301,7 +301,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm out : tvm.Tensor 3-D tensor with shape (batch_size, num_anchors, 6) """ - inter_out, valid_count = mutibox_transform_loc(cls_prob, loc_pred, anchor, + inter_out, valid_count = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) out = nms(inter_out, valid_count, nms_threshold, force_suppress, nms_topk) return out From bb30396e6fcf325a0f735dde19d0c5afbcf72f5f Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 9 Jun 2018 19:51:15 -0700 Subject: [PATCH 13/22] Fix output formant of multibox_transform_loc --- nnvm/python/nnvm/top/vision.py | 2 +- nnvm/tests/python/compiler/test_top_level4.py | 2 +- src/contrib/sort/sort.cc | 10 ++++------ topi/python/topi/vision/ssd/multibox.py | 16 ++++++---------- 4 
files changed, 12 insertions(+), 18 deletions(-) diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 16d07496f333..edbf72320a26 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -67,7 +67,7 @@ def schedule_multibox_transform_loc(_, outs, target): return topi.generic.schedule_multibox_transform_loc(outs) @reg.register_compute("multibox_transform_loc") -def compute_multibox_detection(attrs, inputs, _): +def compute_multibox_transform_loc(attrs, inputs, _): """Compute definition of multibox_detection""" clip = attrs.get_bool('clip') threshold = attrs.get_float('threshold') diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index c734fc7f1f2e..b202d1aad862 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -486,4 +486,4 @@ def test_nms(): test_multibox_prior() test_multibox_transform_loc() test_nms() - #print(nnvm.compiler.engine.dump()) + print(nnvm.compiler.engine.dump()) diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index 160e479b86b5..2ddc47f161e3 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -78,12 +78,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") int64_t full_idx = base_idx + k * axis_mul_after; sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } - if (is_descend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); - } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); - } - for (int32_t k = 0; k < input->shape[axis]; ++k) { + std::stable_sort(sorter.begin(), sorter.end(), + is_descend ? CompareDescend + : CompareAscend); + for (uint32_t k = 0; k < input->shape[axis]; ++k) { *(static_cast(output->data) + base_idx + k * axis_mul_after) = k < static_cast(sorter.size()) ? 
sorter[k].first : k; } diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index 1c121a96a670..7cfb21fbff12 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -27,7 +27,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): ratios : tuple of float Tuple of ratios for anchor boxes. - steps : Tuple of int + steps : Tuple of float Priorbox step across y and x, -1 for auto calculation. offsets : tuple of int @@ -86,7 +86,7 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, ratios : tuple of float Tuple of ratios for anchor boxes. - steps : Tuple of int + steps : Tuple of float Priorbox step across y and x, -1 for auto calculation. offsets : tuple of int @@ -237,11 +237,7 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 Returns ------- - out : tvm.Tensor - 3-D tensor with shape (batch_size, num_anchors, 6) - - valid_count : tvm.Tensor - 1-D tensor with shape (batch_size,), number of valid anchor boxes. 
+ ret : tuple of tvm.Tensor """ batch_size = cls_prob.shape[0] num_anchors = anchor.shape[1] @@ -259,7 +255,7 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 dtype=[valid_count_dtype, cls_prob.dtype], out_buffers=[valid_count_buf, out_buf], tag="multibox_transform_loc") - return out, valid_count + return [out, valid_count] @tvm.target.generic_func @@ -301,7 +297,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm out : tvm.Tensor 3-D tensor with shape (batch_size, num_anchors, 6) """ - inter_out, valid_count = multibox_transform_loc(cls_prob, loc_pred, anchor, + inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = nms(inter_out, valid_count, nms_threshold, force_suppress, nms_topk) + out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) return out From 837fc7a682686fe4cd0535112316dedc1deea1e0 Mon Sep 17 00:00:00 2001 From: Wang Date: Sat, 9 Jun 2018 20:02:37 -0700 Subject: [PATCH 14/22] Remove build extra lib --- nnvm/python/nnvm/compiler/build_module.py | 39 ----------------------- nnvm/python/nnvm/frontend/mxnet.py | 2 +- nnvm/tests/python/compiler/test_build.py | 35 -------------------- 3 files changed, 1 insertion(+), 75 deletions(-) diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 817214544e1c..ed75b10414c7 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -32,8 +32,6 @@ class BuildConfig(object): defaults = { "opt_level": 2, "add_pass": None, - "extra_lib_op": None, - "extra_lib_target": None, } def __init__(self, **kwargs): self._old_scope = None @@ -234,11 +232,6 @@ def build(graph, target=None, shape=None, dtype="float32", params : dict of str to NDArray The updated parameters of graph if params is passed. This can be different from the params passed in. 
- - extra_lib : tuple of (Graph, tvm.Module, dict of str to NDArray) - Extra runtime library for the last operator of the graph. - This return value only exists when extra_lib_op and - extra_lib_target are set in build_config. """ target = target if target else tvm.target.current_target() if target is None: @@ -254,36 +247,6 @@ def build(graph, target=None, shape=None, dtype="float32", cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) - # Build extra operator runtime library - extra_lib = () - build_extra = False - if cfg.extra_lib_op is not None: - build_extra = True - graph, extra_op_graph = graph_util.split_last_op(graph) - last_op_name = extra_op_graph.index.nodes[-1]["op"] - if cfg.extra_lib_op != last_op_name: - raise RuntimeError("Currently only supports splitting the " - "last operator of the input graph. " - "extra_lib_op in build_config is %s, " - "but the last op of the graph is %s." % - (cfg.extra_lib_op, last_op_name)) - extra_op_params = {} - if params is not None: - for input_name in extra_op_graph.symbol.list_input_names(): - if input_name in params: - extra_op_params[input_name] = params[input_name] - params.remove(input_name) - _, graph_oshape = graph_util.infer_shape(graph, **shape) - extra_op_ishape = {} - shape_idx = 0 - for input_name in extra_op_graph.symbol.list_input_names(): - if input_name not in extra_op_params: - extra_op_ishape[input_name] = graph_oshape[shape_idx] - shape_idx += 1 - # Disable extra_lib option in cfg to ensure extra_op only built once. 
- cfg.extra_lib_op = None - extra_lib = build(extra_op_graph, cfg.extra_lib_target, - shape=extra_op_ishape, params=extra_op_params) shape, dtype = _update_shape_dtype(shape, dtype, params) # correct layout if necessary @@ -335,8 +298,6 @@ def build(graph, target=None, shape=None, dtype="float32", if params is None: params = {} params.update(init_var) - if build_extra: - return graph, libmod, params, extra_lib return graph, libmod, params def _remove_noref_params(params, graph): diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 51b2c0b01618..06757568443a 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -229,7 +229,7 @@ def _contrib_multibox_detection(inputs, attrs): new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'nms_topk': int(nms_topk)} - data, valid_count = _get_nnvm_op('multibox_detection')(inputs[0], inputs[1], + data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], inputs[2], **new_attrs0) return _get_nnvm_op('nms')(data, valid_count, **new_attrs1) diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py index fcc882dfdb19..5e1f0337c293 100644 --- a/nnvm/tests/python/compiler/test_build.py +++ b/nnvm/tests/python/compiler/test_build.py @@ -94,44 +94,9 @@ def test_dtypes(): out = m.get_output(0, tvm.nd.empty(oshape, dtype)) np.testing.assert_allclose(out.asnumpy(), data, atol=1e-5, rtol=1e-5) -def test_compile_extra_lib(): - data = sym.Variable("data") - net = sym.relu(data) - net = sym.sqrt(net) - out = sym.flatten(net) - - target = "cuda" - extra_lib_target = "llvm" - dshape = (1, 3, 56, 56) - dtype = "float32" - in_data = np.random.uniform(size=dshape).astype(dtype) - opt_level = 2 - with nnvm.compiler.build_config(opt_level=opt_level): - graph, lib, _ = nnvm.compiler.build(out, target, 
{"data": dshape}) - m = graph_runtime.create(graph, lib, tvm.gpu(0)) - m.set_input("data", in_data) - m.run() - _, oshape = nnvm.compiler.graph_util.infer_shape(graph, shape={"data": dshape}) - expected_out = m.get_output(0, tvm.nd.empty(oshape[0], dtype)) - - with nnvm.compiler.build_config(opt_level=opt_level, extra_lib_op="flatten", extra_lib_target=extra_lib_target): - graph, lib, _, extra_libmod = nnvm.compiler.build(out, target, {"data": dshape}) - major_m = graph_runtime.create(graph, lib, tvm.gpu(0)) - major_m.set_input("data", in_data) - major_m.run() - major_out = major_m.get_output(0, tvm.nd.empty(dshape, dtype)) - extra_graph, extra_lib, _ = extra_libmod - extra_m = graph_runtime.create(extra_graph, extra_lib, tvm.cpu()) - extra_input_name = extra_graph.symbol.list_input_names()[0] - extra_m.set_input(extra_input_name, major_out) - extra_m.run() - final_out = extra_m.get_output(0, tvm.nd.empty(oshape[0], dtype)) - np.testing.assert_allclose(expected_out.asnumpy(), final_out.asnumpy(), atol=1e-5, rtol=1e-5) - if __name__ == "__main__": test_precompute_prune() test_compile() test_run() test_dtypes() - test_compile_extra_lib() From 3b315419929fb786ecb58c877b42a9d4aade35c1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 10 Jun 2018 03:44:36 +0000 Subject: [PATCH 15/22] Fix lint --- nnvm/python/nnvm/frontend/mxnet.py | 2 +- topi/python/topi/vision/ssd/multibox.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 06757568443a..deae3112bf5f 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -230,7 +230,7 @@ def _contrib_multibox_detection(inputs, attrs): new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'nms_topk': int(nms_topk)} data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], - inputs[2], **new_attrs0) + inputs[2], **new_attrs0) return 
_get_nnvm_op('nms')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index 7cfb21fbff12..a8f97146519b 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -212,7 +212,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, @tvm.target.generic_func def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, - variances=(0.1, 0.1, 0.2, 0.2)): + variances=(0.1, 0.1, 0.2, 0.2)): """Location transformation for multibox detection Parameters @@ -298,6 +298,6 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm 3-D tensor with shape (batch_size, num_anchors, 6) """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, - clip, threshold, variances) + clip, threshold, variances) out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) return out From e99887855c07ddab59babe019a028e528051df1c Mon Sep 17 00:00:00 2001 From: Wang Date: Tue, 12 Jun 2018 17:16:10 -0700 Subject: [PATCH 16/22] Fix tutorial title --- tutorials/nnvm/deploy_ssd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index d1d8bcc6b1ed..ed9a21b463c1 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -1,6 +1,6 @@ """ Deploy Single Shot Multibox Detector(SSD) model -====================================== +=============================================== **Author**: `Yao Wang `_ This article is an introductory tutorial to deploy SSD models with TVM. 
From a9e2c7ddd1ba629b7480736b388ca8b8a956bd05 Mon Sep 17 00:00:00 2001 From: Wang Date: Tue, 12 Jun 2018 17:27:20 -0700 Subject: [PATCH 17/22] Remove split_last_op --- nnvm/python/nnvm/compiler/graph_util.py | 39 ------------------------- 1 file changed, 39 deletions(-) diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 58782ce374b2..878b413bcd84 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -147,42 +147,3 @@ def gradients(ys, xs, grad_ys=None): if isinstance(xs, list) else len(xs.list_output_names()) ret = [grad_g.symbol[i] for i in range(nx)] return ret - -def split_last_op(graph): - """Split graph into the last operator - and all other parts before. - - Parameters - ---------- - graph : Graph - The original graph. - - Returns - ------- - main_graph: Graph - The graph before last operator. - - last_op_graph: Graph - The graph for the last operator. - """ - graph_idx = graph.index - last_op_node = graph_idx.nodes[-1] - if last_op_node["op"] == "null": - raise RuntimeError("split_last_op doesn't support last operator " - "to be null.") - last_op_func = getattr(sym, last_op_node["op"]) - if "attrs" in last_op_node: - last_op_attr = last_op_node["attrs"] - else: - last_op_attr = {} - last_op_num_inputs = len(last_op_node["inputs"]) - last_op_inputs = [] - for i in range(last_op_num_inputs): - input_idx = last_op_node["inputs"][i][0] - input_name = graph_idx.nodes[input_idx]["name"] - last_op_inputs.append(sym.Variable(input_name)) - last_op_sym = last_op_func(*last_op_inputs, **last_op_attr) - last_op_graph = create(last_op_sym) - main_graph_sym = graph.symbol.get_children() - main_graph = create(main_graph_sym) - return main_graph, last_op_graph From 355fae166b9f8cd054583fa8485d0a5ef32d1247 Mon Sep 17 00:00:00 2001 From: Wang Date: Wed, 13 Jun 2018 23:27:25 -0700 Subject: [PATCH 18/22] Move download to testing --- nnvm/python/nnvm/compiler/graph_util.py | 1 - 
nnvm/python/nnvm/testing/__init__.py | 3 +- nnvm/python/nnvm/testing/download.py | 74 +++++++++++++++++++++++++ src/contrib/sort/sort.cc | 10 ++-- tutorials/nnvm/deploy_ssd.py | 37 ++++--------- tutorials/nnvm/from_darknet.py | 65 +--------------------- 6 files changed, 94 insertions(+), 96 deletions(-) create mode 100644 nnvm/python/nnvm/testing/download.py diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py index 878b413bcd84..e831298b27d9 100644 --- a/nnvm/python/nnvm/compiler/graph_util.py +++ b/nnvm/python/nnvm/compiler/graph_util.py @@ -7,7 +7,6 @@ from ..graph import create from ..symbol import Group, ones_like -from .. import symbol as sym def infer_shape(graph, **shape): """Infer the shape given the shape of inputs. diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 56d5a9a48b59..784d4c2dd60f 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -3,8 +3,9 @@ from .config import ctx_list from .utils import create_workload +from .download import download from . import mobilenet from . import mlp from . import resnet from . import vgg -from . import yolo2_detection +from . import yolo2_detection \ No newline at end of file diff --git a/nnvm/python/nnvm/testing/download.py b/nnvm/python/nnvm/testing/download.py new file mode 100644 index 000000000000..f7cc263a8b58 --- /dev/null +++ b/nnvm/python/nnvm/testing/download.py @@ -0,0 +1,74 @@ +# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except +"""Helper utility for downloading""" +from __future__ import print_function + +import os +import sys +import time +import urllib +import requests + +if sys.version_info >= (3,): + import urllib.request as urllib2 +else: + import urllib2 + +def _download_progress(count, block_size, total_size): + """Show the download progress. 
+ """ + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + +def download(url, path, overwrite=False, size_compare=False): + """Downloads the file from the internet. + Set the input options correctly to overwrite or do the size comparison + + Parameters + ---------- + url : str + Download url. + + path : str + Local file path to save downloaded file + + overwrite : bool, optional + Whether to overwrite existing file + + size_compare : bool, optional + Whether to do size compare to check downloaded file. + + Returns + ------- + out_name : converted out name of operation + sym : nnvm.Symbol + Converted nnvm Symbol + """ + if os.path.isfile(path) and not overwrite: + if size_compare: + file_size = os.path.getsize(path) + res_head = requests.head(url) + res_get = requests.get(url, stream=True) + if 'Content-Length' not in res_head.headers: + res_get = urllib2.urlopen(url) + url_file_size = int(res_get.headers['Content-Length']) + if url_file_size != file_size: + print("exist file got corrupted, downloading %s file freshly..." 
% path) + download(url, path, True, False) + return + print('File {} exists, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + urllib.request.urlretrieve(url, path, reporthook=_download_progress) + print('') + except: + urllib.urlretrieve(url, path, reporthook=_download_progress) diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index 2ddc47f161e3..160e479b86b5 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -78,10 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") int64_t full_idx = base_idx + k * axis_mul_after; sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } - std::stable_sort(sorter.begin(), sorter.end(), - is_descend ? CompareDescend - : CompareAscend); - for (uint32_t k = 0; k < input->shape[axis]; ++k) { + if (is_descend) { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + } else { + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + } + for (int32_t k = 0; k < input->shape[axis]; ++k) { *(static_cast(output->data) + base_idx + k * axis_mul_after) = k < static_cast(sorter.size()) ? sorter[k].first : k; } diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index ed9a21b463c1..15669479e51f 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -14,6 +14,7 @@ import mxnet as mx import cv2 import numpy as np +import nnvm.testing.download as download from nnvm import compiler from nnvm.frontend import from_mxnet @@ -22,42 +23,24 @@ ###################################################################### -# Set the parameters here. +# Set the parameters here +# ----------------------- +# .. note:: +# +# Currently we support compiling SSD on CPU only. +# GPU support is in progress. 
model_name = "ssd_resnet50_512" model_file = "%s.zip" % model_name test_image = "dog.jpg" -target = "llvm" dshape = (1, 3, 512, 512) dtype = "float32" +target = "llvm" ctx = tvm.cpu() -def download(url, path, overwrite=False): - """Downloads the file from the internet. - Set the input options correctly to overwrite or do the size comparison - - Parameters - ---------- - url : str - Download file url - path : str - File saved path. - overwrite : boolean - Dict of operator attributes - """ - if os.path.isfile(path) and not overwrite: - print('File {} exists, skip.'.format(path)) - return - print('Downloading from url {} to {}'.format(url, path)) - try: - urllib.request.urlretrieve(url, path) - print('') - except: - urllib.urlretrieve(url, path) - ###################################################################### -# Download MXNet SSD pre-trained model and demo image. -# ---------------------------- +# Download MXNet SSD pre-trained model and demo image +# --------------------------------------------------- # Pre-trained model available at # https://github.com/apache/incubator-\mxnet/tree/master/example/ssd diff --git a/tutorials/nnvm/from_darknet.py b/tutorials/nnvm/from_darknet.py index 9613f023c1e9..ced1c03c7192 100644 --- a/tutorials/nnvm/from_darknet.py +++ b/tutorials/nnvm/from_darknet.py @@ -20,17 +20,12 @@ import nnvm import nnvm.frontend.darknet import nnvm.testing.darknet +import nnvm.testing.download as download from nnvm.testing.darknet import __darknetffi__ import matplotlib.pyplot as plt import numpy as np import tvm -import os, sys, time, urllib, requests -if sys.version_info >= (3,): - import urllib.request as urllib2 - import urllib.parse as urlparse -else: - import urllib2 - import urlparse +import os ###################################################################### # Set the parameters here. 
@@ -41,62 +36,6 @@ target = 'llvm' ctx = tvm.cpu(0) -def dlProgress(count, block_size, total_size): - """Show the download progress.""" - global start_time - if count == 0: - start_time = time.time() - return - duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) - percent = int(count * block_size * 100 / total_size) - sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % - (percent, progress_size / (1024 * 1024), speed, duration)) - sys.stdout.flush() - -def download(url, path, overwrite=False, sizecompare=False): - """Downloads the file from the internet. - Set the input options correctly to overwrite or do the size comparison - - Parameters - ---------- - url : str - Operator name, such as Convolution, Connected, etc - path : str - List of input symbols. - overwrite : dict - Dict of operator attributes - sizecompare : dict - Dict of operator attributes - - Returns - ------- - out_name : converted out name of operation - sym : nnvm.Symbol - Converted nnvm Symbol - """ - if os.path.isfile(path) and not overwrite: - if (sizecompare): - fileSize = os.path.getsize(path) - resHead = requests.head(url) - resGet = requests.get(url,stream=True) - if 'Content-Length' not in resHead.headers : - resGet = urllib2.urlopen(url) - urlFileSize = int(resGet.headers['Content-Length']) - if urlFileSize != fileSize: - print ("exist file got corrupted, downloading", path , " file freshly") - download(url, path, True, False) - return - print('File {} exists, skip.'.format(path)) - return - print('Downloading from url {} to {}'.format(url, path)) - try: - urllib.request.urlretrieve(url, path, reporthook=dlProgress) - print('') - except: - urllib.urlretrieve(url, path, reporthook=dlProgress) - ###################################################################### # Prepare cfg and weights file # ---------------------------- From 5db4abaad7c91848858a911c913a5991092b68ff Mon Sep 17 00:00:00 2001 From: 
Wang Date: Wed, 13 Jun 2018 23:44:23 -0700 Subject: [PATCH 19/22] Fix lint --- nnvm/python/nnvm/testing/__init__.py | 2 +- nnvm/python/nnvm/testing/download.py | 1 + tutorials/nnvm/deploy_ssd.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 784d4c2dd60f..06da4854d12f 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -8,4 +8,4 @@ from . import mlp from . import resnet from . import vgg -from . import yolo2_detection \ No newline at end of file +from . import yolo2_detection diff --git a/nnvm/python/nnvm/testing/download.py b/nnvm/python/nnvm/testing/download.py index f7cc263a8b58..e4a4d0059bc2 100644 --- a/nnvm/python/nnvm/testing/download.py +++ b/nnvm/python/nnvm/testing/download.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except """Helper utility for downloading""" from __future__ import print_function +from __future__ import absolute_import as _abs import os import sys diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index 15669479e51f..64af44cd66b5 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -8,7 +8,6 @@ convert it to NNVM graph. 
""" import os -import urllib import zipfile import tvm import mxnet as mx From 2983847524d39b386f3483eb6fe28d0358ba6209 Mon Sep 17 00:00:00 2001 From: Wang Date: Thu, 14 Jun 2018 11:04:06 -0700 Subject: [PATCH 20/22] Minor fix --- nnvm/python/nnvm/testing/__init__.py | 1 - tutorials/nnvm/deploy_ssd.py | 2 +- tutorials/nnvm/from_darknet.py | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 06da4854d12f..56d5a9a48b59 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -3,7 +3,6 @@ from .config import ctx_list from .utils import create_workload -from .download import download from . import mobilenet from . import mlp from . import resnet diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index 64af44cd66b5..b3b5072f28c9 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -13,10 +13,10 @@ import mxnet as mx import cv2 import numpy as np -import nnvm.testing.download as download from nnvm import compiler from nnvm.frontend import from_mxnet +from nnvm.testing.download import download from tvm.contrib import graph_runtime from mxnet.model import load_checkpoint diff --git a/tutorials/nnvm/from_darknet.py b/tutorials/nnvm/from_darknet.py index ced1c03c7192..2cd681b624ad 100644 --- a/tutorials/nnvm/from_darknet.py +++ b/tutorials/nnvm/from_darknet.py @@ -14,19 +14,19 @@ pip install cffi pip install opencv-python """ -from ctypes import * -import math -import random + import nnvm import nnvm.frontend.darknet import nnvm.testing.darknet -import nnvm.testing.download as download -from nnvm.testing.darknet import __darknetffi__ import matplotlib.pyplot as plt import numpy as np import tvm import os +from ctypes import * +from nnvm.testing.download import download +from nnvm.testing.darknet import __darknetffi__ + 
###################################################################### # Set the parameters here. # Supported models alexnet, resnet50, resnet152, extraction, yolo From 78e5b456218e6554353895d1686fd57e8e87841b Mon Sep 17 00:00:00 2001 From: Wang Date: Thu, 14 Jun 2018 14:48:14 -0700 Subject: [PATCH 21/22] Update download docstring --- nnvm/python/nnvm/testing/download.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/nnvm/python/nnvm/testing/download.py b/nnvm/python/nnvm/testing/download.py index e4a4d0059bc2..849c18bcf0f6 100644 --- a/nnvm/python/nnvm/testing/download.py +++ b/nnvm/python/nnvm/testing/download.py @@ -46,12 +46,6 @@ def download(url, path, overwrite=False, size_compare=False): size_compare : bool, optional Whether to do size compare to check downloaded file. - - Returns - ------- - out_name : converted out name of operation - sym : nnvm.Symbol - Converted nnvm Symbol """ if os.path.isfile(path) and not overwrite: if size_compare: From 3cc2d430fe98d10362e42c6664dcd28dd638d320 Mon Sep 17 00:00:00 2001 From: Wang Date: Thu, 14 Jun 2018 14:50:02 -0700 Subject: [PATCH 22/22] Fix lint --- python/tvm/contrib/rpc/proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/rpc/proxy.py b/python/tvm/contrib/rpc/proxy.py index c7f66d68e492..e1e81d20b611 100644 --- a/python/tvm/contrib/rpc/proxy.py +++ b/python/tvm/contrib/rpc/proxy.py @@ -333,7 +333,7 @@ def _update_tracker(self, period_update=False): rpc_key = key.split(":")[0] base.sendjson(self._tracker_conn, [TrackerCode.PUT, rpc_key, - (self._listen_port, key), None]) + (self._listen_port, key), None]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS if rpc_key not in self._key_set: self._key_set.add(rpc_key)