diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index bee6137829c5..bbdb3b9c4f12 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -319,6 +319,55 @@ struct LayoutTransformParam : public dmlc::Parameter { } }; +struct MultiBoxPriorParam : public dmlc::Parameter { + Tuple sizes; + Tuple ratios; + Tuple steps; + Tuple offsets; + bool clip; + + DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { + DMLC_DECLARE_FIELD(sizes).set_default(Tuple({1.0})) + .describe("List of sizes of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(ratios).set_default(Tuple({1.0})) + .describe("List of aspect ratios of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(steps).set_default(Tuple({-1.0, -1.0})) + .describe("Priorbox step across y and x, -1 for auto calculation."); + DMLC_DECLARE_FIELD(offsets).set_default(Tuple({0.5, 0.5})) + .describe("Priorbox center offsets, y and x respectively."); + DMLC_DECLARE_FIELD(clip).set_default(false) + .describe("Whether to clip out-of-boundary boxes."); + } +}; + +struct MultiBoxTransformLocParam : public dmlc::Parameter { + bool clip; + float threshold; + Tuple variances; + DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) { + DMLC_DECLARE_FIELD(clip).set_default(true) + .describe("Clip out-of-boundary boxes."); + DMLC_DECLARE_FIELD(threshold).set_default(0.01) + .describe("Threshold to be a positive prediction."); + DMLC_DECLARE_FIELD(variances).set_default(Tuple{0.1, 0.1, 0.2, 0.2}) + .describe("Variances to be decoded from box regression output."); + } +}; + +struct NMSParam : public dmlc::Parameter { + float nms_threshold; + bool force_suppress; + int nms_topk; + DMLC_DECLARE_PARAMETER(NMSParam) { + DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) + .describe("Non-maximum suppression threshold."); + DMLC_DECLARE_FIELD(force_suppress).set_default(false) + .describe("Suppress all detections regardless of class_id."); + DMLC_DECLARE_FIELD(nms_topk).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + } +}; + } // namespace top } // namespace nnvm diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 82b8e555d5ec..deae3112bf5f 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -71,7 +71,7 @@ def _batch_norm(inputs, attrs): new_attrs['axis'] = attrs.get('axis', 1) new_attrs['epsilon'] = attrs.get('eps', 0.001) new_attrs['center'] = True - new_attrs['scale'] = True + new_attrs['scale'] = not _parse_bool_str(attrs, 'fix_gamma', default="False") return _get_nnvm_op(op_name)(*inputs, **new_attrs) def _concat(inputs, attrs): @@ -195,6 +195,12 @@ def _split(inputs, attrs): new_attrs['axis'] = attrs.get('axis', 1) return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _softmax_activation(inputs, attrs): + op_name, new_attrs = 'softmax', {} + mode = attrs.get('mode', 'instance') + new_attrs['axis'] = 0 if mode == 'instance' else 1 + return _get_nnvm_op(op_name)(inputs[0], **new_attrs) + def _softmax_output(inputs, attrs): op_name, new_attrs = 'softmax', {} if _parse_bool_str(attrs, 'multi_output'): @@ -212,6 +218,25 @@ def _clip(inputs, attrs): new_attrs['a_max'] = _required_attr(attrs, 'a_max') return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _contrib_multibox_detection(inputs, attrs): + clip = _parse_bool_str(attrs, 'clip', default='True') + threshold = attrs.get('threshold') or 0.01 + nms_threshold = attrs.get('nms_threshold') or 0.5 + force_suppress = _parse_bool_str(attrs, 'force_suppress', default='False') + variances = tuple([float(x.strip()) for x in attrs.get('variances').strip('()').split(',')]) \ + if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) + nms_topk = attrs.get('nms_topk') or -1 + new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} + new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, + 'nms_topk': int(nms_topk)} + data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], + inputs[2], **new_attrs0) + return _get_nnvm_op('nms')(data, valid_count, **new_attrs1) + +def _elemwise_sum(inputs, _): + new_attrs = {'num_args':len(inputs)} + return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) + _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', @@ -224,12 +249,15 @@ def _clip(inputs, attrs): 'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose'] _convert_map = { + '_copy' : _rename('copy'), '_div_scalar' : _rename('__div_scalar__'), '_minus_scalar' : _rename('__sub_scalar__'), '_mul_scalar' : _rename('__mul_scalar__'), '_plus_scalar' : _rename('__add_scalar__'), '_rdiv_scalar' : _rename('__rdiv_scalar__'), '_rminus_scalar': _rename('__rsub_scalar__'), + '_contrib_MultiBoxPrior' : _rename('multibox_prior'), + '_contrib_MultiBoxDetection' : _contrib_multibox_detection, 'Activation' : _activations, 'BatchNorm' : _batch_norm, 'BatchNorm_v1' : _batch_norm, @@ -248,7 +276,9 @@ def _clip(inputs, attrs): 'SliceChannel' : _split, 'split' : _split, 'Softmax' : _rename('softmax'), + 'SoftmaxActivation' : _softmax_activation, 'SoftmaxOutput' : _softmax_output, + 'add_n' : _elemwise_sum, 'concat' : _concat, 'max_axis' : _rename('max'), 'min_axis' : _rename('min'), diff --git a/nnvm/python/nnvm/testing/download.py b/nnvm/python/nnvm/testing/download.py new file mode 100644 index 000000000000..849c18bcf0f6 --- /dev/null +++ b/nnvm/python/nnvm/testing/download.py @@ -0,0 +1,69 @@ +# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except +"""Helper utility for downloading""" +from __future__ import print_function +from __future__ import absolute_import as _abs + +import os +import sys +import time +import urllib +import requests + +if sys.version_info >= (3,): + import urllib.request as urllib2 +else: + import urllib2 + +def _download_progress(count, block_size, total_size): + """Show the download progress. + """ + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + +def download(url, path, overwrite=False, size_compare=False): + """Downloads the file from the internet. + Set the input options correctly to overwrite or do the size comparison + + Parameters + ---------- + url : str + Download url. + + path : str + Local file path to save downloaded file + + overwrite : bool, optional + Whether to overwrite existing file + + size_compare : bool, optional + Whether to do size compare to check downloaded file. + """ + if os.path.isfile(path) and not overwrite: + if size_compare: + file_size = os.path.getsize(path) + res_head = requests.head(url) + res_get = requests.get(url, stream=True) + if 'Content-Length' not in res_head.headers: + res_get = urllib2.urlopen(url) + url_file_size = int(res_get.headers['Content-Length']) + if url_file_size != file_size: + print("exist file got corrupted, downloading %s file freshly..." % path) + download(url, path, True, False) + return + print('File {} exists, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + urllib.request.urlretrieve(url, path, reporthook=_download_progress) + print('') + except: + urllib.urlretrieve(url, path, reporthook=_download_progress) diff --git a/nnvm/python/nnvm/testing/resnet.py b/nnvm/python/nnvm/testing/resnet.py index 64eb63c29b7a..243cc1b65144 100644 --- a/nnvm/python/nnvm/testing/resnet.py +++ b/nnvm/python/nnvm/testing/resnet.py @@ -108,7 +108,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, num_unit = len(units) assert num_unit == num_stages data = sym.Variable(name='data') - data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data') + data = sym.batch_norm(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, height, _) = image_shape if height <= 32: # such as cifar10 body = sym.conv2d( diff --git a/nnvm/python/nnvm/top/attr_dict.py b/nnvm/python/nnvm/top/attr_dict.py index a913a92552b2..efd439fa75fc 100644 --- a/nnvm/python/nnvm/top/attr_dict.py +++ b/nnvm/python/nnvm/top/attr_dict.py @@ -83,6 +83,21 @@ def get_int(self, key): """ return int(self[key]) + def get_float_tuple(self, key): + """Get tuple of float from attr dict + + Parameters + ---------- + key : str + The attr key + + Returns + ------- + tuple : tuple of float + The result tuple + """ + return tuple(float(x) for x in self[key][1:-1].split(",") if x) + def get_float(self, key): """Get float from attr dict diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 89409de6263b..edbf72320a26 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -1,10 +1,9 @@ - # pylint: disable=invalid-name, unused-argument """Definition of nn ops""" from __future__ import absolute_import -import topi import tvm +import topi from . import registry as reg from .registry import OpPattern @@ -38,3 +37,62 @@ def schedule_region(attrs, outs, target): return topi.generic.vision.schedule_region(outs) reg.register_pattern("yolo2_region", OpPattern.OPAQUE) + +# multibox_prior +@reg.register_schedule("multibox_prior") +def schedule_multibox_prior(_, outs, target): + """Schedule definition of multibox_prior""" + with tvm.target.create(target): + return topi.generic.schedule_multibox_prior(outs) + +@reg.register_compute("multibox_prior") +def compute_multibox_prior(attrs, inputs, _): + """Compute definition of multibox_prior""" + sizes = attrs.get_float_tuple('sizes') + ratios = attrs.get_float_tuple('ratios') + steps = attrs.get_float_tuple('steps') + offsets = attrs.get_float_tuple('offsets') + clip = attrs.get_bool('clip') + + return topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, + steps, offsets, clip) + +reg.register_pattern("multibox_prior", OpPattern.OPAQUE) + +# multibox_transform_loc +@reg.register_schedule("multibox_transform_loc") +def schedule_multibox_transform_loc(_, outs, target): + """Schedule definition of multibox_detection""" + with tvm.target.create(target): + return topi.generic.schedule_multibox_transform_loc(outs) + +@reg.register_compute("multibox_transform_loc") +def compute_multibox_transform_loc(attrs, inputs, _): + """Compute definition of multibox_detection""" + clip = attrs.get_bool('clip') + threshold = attrs.get_float('threshold') + variance = attrs.get_float_tuple('variances') + + return topi.vision.ssd.multibox_transform_loc(inputs[0], inputs[1], inputs[2], + clip, threshold, variance) + +reg.register_pattern("multibox_detection", OpPattern.OPAQUE) + +# non-maximum suppression +@reg.register_schedule("nms") +def schedule_nms(_, outs, target): + """Schedule definition of nms""" + with tvm.target.create(target): + return topi.generic.schedule_nms(outs) + +@reg.register_compute("nms") +def compute_nms(attrs, inputs, _): + """Compute definition of nms""" + nms_threshold = attrs.get_float('nms_threshold') + force_suppress = attrs.get_bool('force_suppress') + nms_topk = attrs.get_int('nms_topk') + + return topi.vision.nms(inputs[0], inputs[1], nms_threshold, + force_suppress, nms_topk) + +reg.register_pattern("nms", OpPattern.OPAQUE) diff --git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc new file mode 100644 index 000000000000..2680b894255b --- /dev/null +++ b/nnvm/src/top/vision/nms.cc @@ -0,0 +1,80 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nms.cc + * \brief Property def of SSD non-maximum suppression operator. + */ + +#include +#include +#include +#include +#include +#include +#include "../op_common.h" +#include "../elemwise_op_common.h" + +namespace nnvm { +namespace top { +using compiler::FTVMCompute; +using tvm::Tensor; +using tvm::Array; + +DMLC_REGISTER_PARAMETER(NMSParam); + +bool NMSShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; + TShape dshape = in_attrs->at(0); + TShape vshape = in_attrs->at(1); + CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D."; + CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D."; + CHECK_EQ(dshape[2], 6U) << "Data input should have shape " + "(batch_size, num_anchors, 6)."; + CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); + return true; +} + +inline bool NMSInferType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); + return true; +} + +inline bool NMSInferLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + static const Layout kNCHW("NCHW"); + CHECK_EQ(ilayouts->size(), 2U); + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW); + NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kNCHW); + return true; +} + +NNVM_REGISTER_OP(nms) + .describe(R"doc("Non-maximum suppression." +)doc" NNVM_ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", + ParamGetAttrDict) +.add_arguments(NMSParam::__FIELDS__()) +.add_argument("data", "Tensor", "Input data.") +.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"data", "valid_count"}; +}) +.set_attr("FInferShape", NMSShape) +.set_attr("FInferType", NMSInferType) +.set_attr("FCorrectLayout", NMSInferLayout) +.set_support_level(4); + +} // namespace top +} // namespace nnvm + diff --git a/nnvm/src/top/vision/ssd/mutibox_op.cc b/nnvm/src/top/vision/ssd/mutibox_op.cc new file mode 100644 index 000000000000..7f1aca5d2b82 --- /dev/null +++ b/nnvm/src/top/vision/ssd/mutibox_op.cc @@ -0,0 +1,158 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file multibox_op.cc + * \brief Property def of SSD multibox related operators. + */ + +#include +#include +#include +#include +#include +#include +#include "../../op_common.h" +#include "../../elemwise_op_common.h" + +namespace nnvm { +namespace top { +using compiler::FTVMCompute; +using tvm::Tensor; +using tvm::Array; + +DMLC_REGISTER_PARAMETER(MultiBoxPriorParam); + +bool MultiBoxPriorShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const MultiBoxPriorParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U) << "Inputs: [data]" << in_attrs->size(); + TShape dshape = in_attrs->at(0); + CHECK_GE(dshape.ndim(), 4U) << "Input data should be 4D: " + "[batch, channel, height, width]"; + int in_height = dshape[2]; + CHECK_GT(in_height, 0) << "Input height should > 0"; + int in_width = dshape[3]; + CHECK_GT(in_width, 0) << "Input width should > 0"; + // since input sizes are same in each batch, we could share MultiBoxPrior + TShape oshape = TShape(3); + int num_sizes = param.sizes.ndim(); + int num_ratios = param.ratios.ndim(); + oshape[0] = 1; + oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1); + oshape[2] = 4; + CHECK_EQ(param.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)"; + CHECK_GE(param.steps[0] * param.steps[1], 0) << "Must specify both " + "step_y and step_x"; + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + return true; +} + +inline bool MultiBoxPriorLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + static const Layout kNCHW("NCHW"); + CHECK_EQ(ilayouts->size(), 1U); + CHECK_EQ(olayouts->size(), 1U); + NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW); + return true; +} + +NNVM_REGISTER_OP(multibox_prior) + .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." +)doc" NNVM_ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.add_arguments(MultiBoxPriorParam::__FIELDS__()) +.add_argument("data", "Tensor", "Input data") +.set_attr("FInferShape", MultiBoxPriorShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCorrectLayout", MultiBoxPriorLayout) +.set_attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds) { + return std::vector{ + MakeNode("zeros_like", n->attrs.name + "_zero_grad", + {n->inputs[0]}), + ograds[0] + }; +}) +.set_support_level(4); + +DMLC_REGISTER_PARAMETER(MultiBoxTransformLocParam); + +bool MultiBoxTransformLocShape(const NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]"; + TShape cshape = in_attrs->at(0); + TShape lshape = in_attrs->at(1); + TShape ashape = in_attrs->at(2); + CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D."; + CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D."; + CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D."; + CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch."; + CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc."; + CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0."; + CHECK_EQ(ashape[2], 4U); + TShape oshape0 = TShape(3); + oshape0[0] = cshape[0]; + oshape0[1] = ashape[1]; + oshape0[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] + TShape oshape1 = TShape(1); + oshape1[0] = cshape[0]; + out_attrs->clear(); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape0); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 1, oshape1); + return true; +} + +inline bool MultiBoxTransformLocLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), 3U); + CHECK_EQ(last_ilayouts->size(), 3U); + CHECK_EQ(olayouts->size(), 2U); + for (size_t i = 0; i < last_ilayouts->size(); ++i) { + const Layout& last_layout = last_ilayouts->at(i); + if (last_layout.defined()) { + NNVM_ASSIGN_LAYOUT(*ilayouts, i, last_layout); + } + } + return true; +} + +inline bool MultiBoxTransformLocInferType(const NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); + DTYPE_ASSIGN(out_attrs->at(1), 4U); + return true; +} + +NNVM_REGISTER_OP(multibox_transform_loc) + .describe(R"doc("Location transformation for multibox detection." +)doc" NNVM_ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", + ParamGetAttrDict) +.add_arguments(MultiBoxTransformLocParam::__FIELDS__()) +.add_argument("cls_prob", "Tensor", "Class probabilities.") +.add_argument("loc_pred", "Tensor", "Location regression predictions.") +.add_argument("anchor", "Tensor", "Multibox prior anchor boxes") +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"cls_prob", "loc_pred", "anchor"}; +}) +.set_attr("FInferShape", MultiBoxTransformLocShape) +.set_attr("FInferType", MultiBoxTransformLocInferType) +.set_attr("FCorrectLayout", MultiBoxTransformLocLayout) +.set_support_level(4); + +} // namespace top +} // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 819768cfb341..b202d1aad862 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -1,3 +1,4 @@ +import math import numpy as np import tvm from tvm.contrib import graph_runtime @@ -356,6 +357,118 @@ def test_full(): np.full(shape, fill_value=0, dtype=dtype), atol=1e-5, rtol=1e-5) +def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), + offsets=(0.5, 0.5), clip=False): + data = sym.Variable("data") + out = sym.multibox_prior(data=data, sizes=sizes, ratios=ratios, steps=steps, + offsets=offsets, clip=clip) + + in_height = dshape[2] + in_width = dshape[3] + num_sizes = len(sizes) + num_ratios = len(ratios) + size_ratio_concat = sizes + ratios + steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + offset_h = offsets[0] + offset_w = offsets[1] + + oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4) + dtype = "float32" + np_out = np.zeros(oshape).astype(dtype) + + for i in range(in_height): + center_h = (i + offset_h) * steps_h + for j in range(in_width): + center_w = (j + offset_w) * steps_w + for k in range(num_sizes + num_ratios - 1): + w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \ + size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0 + h = size_ratio_concat[k] / 2.0 if k < num_sizes else \ + size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0 + count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k + np_out[0][count][0] = center_w - w + np_out[0][count][1] = center_h - h + np_out[0][count][2] = center_w + w + np_out[0][count][3] = center_h + h + if clip: + np_out = np.clip(np_out, 0, 1) + + target = "llvm" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input("data", np.random.uniform(size=dshape).astype(dtype)) + m.run() + out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5) + +def test_multibox_prior(): + verify_multibox_prior((1, 3, 50, 50)) + verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5)) + verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True) + +def test_multibox_transform_loc(): + batch_size = 1 + num_anchors = 3 + num_classes = 3 + cls_prob = sym.Variable("cls_prob") + loc_preds = sym.Variable("loc_preds") + anchors = sym.Variable("anchors") + transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds, + anchor=anchors) + out = sym.nms(data=transform_loc_data, valid_count=valid_count) + + # Manually create test case + np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) + np_loc_preds = np.array([[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]]) + np_anchors = np.array([[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]]) + + expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108], + [0, 0.44999999, 1, 1, 1, 1], + [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]]) + + target = "llvm" + dtype = "float32" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes), + "loc_preds": (batch_size, num_anchors * 4), + "anchors": (1, num_anchors, 4)}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)}) + m.run() + out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) + np.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) + +def test_nms(): + dshape = (1, 5, 6) + data = sym.Variable("data") + valid_count = sym.Variable("valid_count", dtype="int32") + nms_threshold = 0.7 + force_suppress = True + nms_topk = 2 + out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold, + force_suppress=force_suppress, nms_topk=nms_topk) + + np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1]]]) + + target = "llvm" + ctx = tvm.cpu() + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)}, + dtype={"data": "float32", "valid_count": "int32"}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"data": np_data, "valid_count": np_valid_count}) + m.run() + out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32")) + np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_reshape() @@ -370,4 +483,7 @@ def test_full(): test_block_grad() test_full() test_flip() + test_multibox_prior() + test_multibox_transform_loc() + test_nms() print(nnvm.compiler.engine.dump()) diff --git a/python/tvm/contrib/rpc/proxy.py b/python/tvm/contrib/rpc/proxy.py index c7f66d68e492..e1e81d20b611 100644 --- a/python/tvm/contrib/rpc/proxy.py +++ b/python/tvm/contrib/rpc/proxy.py @@ -333,7 +333,7 @@ def _update_tracker(self, period_update=False): rpc_key = key.split(":")[0] base.sendjson(self._tracker_conn, [TrackerCode.PUT, rpc_key, - (self._listen_port, key), None]) + (self._listen_port, key), None]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS if rpc_key not in self._key_set: self._key_set.add(rpc_key) diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index 31a090e345d5..a8f97146519b 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -27,7 +27,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): ratios : tuple of float Tuple of ratios for anchor boxes. - steps : Tuple of int + steps : Tuple of float Priorbox step across y and x, -1 for auto calculation. offsets : tuple of int @@ -86,7 +86,7 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, ratios : tuple of float Tuple of ratios for anchor boxes. - steps : Tuple of int + steps : Tuple of float Priorbox step across y and x, -1 for auto calculation. offsets : tuple of int @@ -211,8 +211,8 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, @tvm.target.generic_func -def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, - variances=(0.1, 0.1, 0.2, 0.2)): +def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, + variances=(0.1, 0.1, 0.2, 0.2)): """Location transformation for multibox detection Parameters @@ -237,11 +237,7 @@ def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, Returns ------- - out : tvm.Tensor - 3-D tensor with shape (batch_size, num_anchors, 6) - - valid_count : tvm.Tensor - 1-D tensor with shape (batch_size,), number of valid anchor boxes. + ret : tuple of tvm.Tensor """ batch_size = cls_prob.shape[0] num_anchors = anchor.shape[1] @@ -259,7 +255,7 @@ def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, dtype=[valid_count_dtype, cls_prob.dtype], out_buffers=[valid_count_buf, out_buf], tag="multibox_transform_loc") - return out, valid_count + return [out, valid_count] @tvm.target.generic_func @@ -301,7 +297,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm out : tvm.Tensor 3-D tensor with shape (batch_size, num_anchors, 6) """ - inter_out, valid_count = mutibox_transform_loc(cls_prob, loc_pred, anchor, - clip, threshold, variances) - out = nms(inter_out, valid_count, nms_threshold, force_suppress, nms_topk) + inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, + clip, threshold, variances) + out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) return out diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py new file mode 100644 index 000000000000..b3b5072f28c9 --- /dev/null +++ b/tutorials/nnvm/deploy_ssd.py @@ -0,0 +1,137 @@ +""" +Deploy Single Shot Multibox Detector(SSD) model +=============================================== +**Author**: `Yao Wang `_ + +This article is an introductory tutorial to deploy SSD models with TVM. +We will use mxnet pretrained SSD model with Resnet50 as body network and +convert it to NNVM graph. +""" +import os +import zipfile +import tvm +import mxnet as mx +import cv2 +import numpy as np + +from nnvm import compiler +from nnvm.frontend import from_mxnet +from nnvm.testing.download import download +from tvm.contrib import graph_runtime +from mxnet.model import load_checkpoint + + +###################################################################### +# Set the parameters here +# ----------------------- +# .. note:: +# +# Currently we support compiling SSD on CPU only. +# GPU support is in progress. + +model_name = "ssd_resnet50_512" +model_file = "%s.zip" % model_name +test_image = "dog.jpg" +dshape = (1, 3, 512, 512) +dtype = "float32" +target = "llvm" +ctx = tvm.cpu() + +###################################################################### +# Download MXNet SSD pre-trained model and demo image +# --------------------------------------------------- +# Pre-trained model available at +# https://github.com/apache/incubator-\mxnet/tree/master/example/ssd + +model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ + "resnet50_ssd_512_voc0712_trainval.zip" +image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ + "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" +inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" +inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ + "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" + +dir = "ssd_model" +if not os.path.exists(dir): + os.makedirs(dir) +model_file_path = "%s/%s" % (dir, model_file) +test_image_path = "%s/%s" % (dir, test_image) +inference_symbol_path = "%s/inference_model.zip" % dir +download(model_url, model_file_path) +download(image_url, test_image_path) +download(inference_symbol_url, inference_symbol_path) + +zip_ref = zipfile.ZipFile(model_file_path, 'r') +zip_ref.extractall(dir) +zip_ref.close() +zip_ref = zipfile.ZipFile(inference_symbol_path) +zip_ref.extractall(dir) +zip_ref.close() + +###################################################################### +# Convert and compile model with NNVM for CPU. + +sym = mx.sym.load("%s/%s/ssd_resnet50_inference.json" % (dir, inference_symbol_folder)) +_, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) +net, params = from_mxnet(sym, arg_params, aux_params) +with compiler.build_config(opt_level=3): + graph, lib, params = compiler.build(net, target, {"data": dshape}, params=params) + +###################################################################### +# Create TVM runtime and do inference + +# Preprocess image +image = cv2.imread(test_image_path) +img_data = cv2.resize(image, (dshape[2], dshape[3])) +img_data = img_data[:, :, (2, 1, 0)].astype(np.float32) +img_data -= np.array([123, 117, 104]) +img_data = np.transpose(np.array(img_data), (2, 0, 1)) +img_data = np.expand_dims(img_data, axis=0) +# Build TVM runtime +m = graph_runtime.create(graph, lib, ctx) +m.set_input('data', tvm.nd.array(img_data.astype(dtype))) +m.set_input(**params) +# execute +m.run() +# get outputs +_, oshape = compiler.graph_util.infer_shape(graph, shape={"data": dshape}) +tvm_output = m.get_output(0, tvm.nd.empty(tuple(oshape[0]), dtype)) + + +###################################################################### +# Display result + +class_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", + "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", + "sheep", "sofa", "train", "tvmonitor"] +def display(img, out, thresh=0.5): + import random + import matplotlib as mpl + import matplotlib.pyplot as plt + mpl.rcParams['figure.figsize'] = (10,10) + pens = dict() + plt.clf() + plt.imshow(img) + for det in out: + cid = int(det[0]) + if cid < 0: + continue + score = det[1] + if score < thresh: + continue + if cid not in pens: + pens[cid] = (random.random(), random.random(), random.random()) + scales = [img.shape[1], img.shape[0]] * 2 + xmin, ymin, xmax, ymax = [int(p * s) for p, s in zip(det[2:6].tolist(), scales)] + rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, + edgecolor=pens[cid], linewidth=3) + plt.gca().add_patch(rect) + text = class_names[cid] + plt.gca().text(xmin, ymin-2, '{:s} {:.3f}'.format(text, score), + bbox=dict(facecolor=pens[cid], alpha=0.5), + fontsize=12, color='white') + plt.show() + +image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +display(image, tvm_output.asnumpy()[0], thresh=0.45) + diff --git a/tutorials/nnvm/from_darknet.py b/tutorials/nnvm/from_darknet.py index 9613f023c1e9..2cd681b624ad 100644 --- a/tutorials/nnvm/from_darknet.py +++ b/tutorials/nnvm/from_darknet.py @@ -14,23 +14,18 @@ pip install cffi pip install opencv-python """ -from ctypes import * -import math -import random + import nnvm import nnvm.frontend.darknet import nnvm.testing.darknet -from nnvm.testing.darknet import __darknetffi__ import matplotlib.pyplot as plt import numpy as np import tvm -import os, sys, time, urllib, requests -if sys.version_info >= (3,): - import urllib.request as urllib2 - import urllib.parse as urlparse -else: - import urllib2 - import urlparse +import os + +from ctypes import * +from nnvm.testing.download import download +from nnvm.testing.darknet import __darknetffi__ ###################################################################### # Set the parameters here. @@ -41,62 +36,6 @@ target = 'llvm' ctx = tvm.cpu(0) -def dlProgress(count, block_size, total_size): - """Show the download progress.""" - global start_time - if count == 0: - start_time = time.time() - return - duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) - percent = int(count * block_size * 100 / total_size) - sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % - (percent, progress_size / (1024 * 1024), speed, duration)) - sys.stdout.flush() - -def download(url, path, overwrite=False, sizecompare=False): - """Downloads the file from the internet. - Set the input options correctly to overwrite or do the size comparison - - Parameters - ---------- - url : str - Operator name, such as Convolution, Connected, etc - path : str - List of input symbols. - overwrite : dict - Dict of operator attributes - sizecompare : dict - Dict of operator attributes - - Returns - ------- - out_name : converted out name of operation - sym : nnvm.Symbol - Converted nnvm Symbol - """ - if os.path.isfile(path) and not overwrite: - if (sizecompare): - fileSize = os.path.getsize(path) - resHead = requests.head(url) - resGet = requests.get(url,stream=True) - if 'Content-Length' not in resHead.headers : - resGet = urllib2.urlopen(url) - urlFileSize = int(resGet.headers['Content-Length']) - if urlFileSize != fileSize: - print ("exist file got corrupted, downloading", path , " file freshly") - download(url, path, True, False) - return - print('File {} exists, skip.'.format(path)) - return - print('Downloading from url {} to {}'.format(url, path)) - try: - urllib.request.urlretrieve(url, path, reporthook=dlProgress) - print('') - except: - urllib.urlretrieve(url, path, reporthook=dlProgress) - ###################################################################### # Prepare cfg and weights file # ----------------------------