From 5210104d1a1935499569114be163baa3ec9e7f37 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 25 Jan 2021 14:44:03 -0800
Subject: [PATCH 1/2] fix numpy pad operator

---
 src/api/operator/numpy/np_pad_op.cc    | 65 ++++++++++++++++++++++----
 tests/python/unittest/test_numpy_op.py |  2 +-
 2 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/src/api/operator/numpy/np_pad_op.cc b/src/api/operator/numpy/np_pad_op.cc
index 317076d23d48..23f9f84f2232 100644
--- a/src/api/operator/numpy/np_pad_op.cc
+++ b/src/api/operator/numpy/np_pad_op.cc
@@ -51,20 +51,70 @@ inline int String2MXNetPadType(const std::string& s) {
   return 0;
 }
 
+inline Tuple<Tuple<int>> BroadcastPadWidth(int ndim, runtime::ADT adt) {
+  std::vector<Tuple<int>> temp;
+  int adt_size = adt.size();
+  if (const runtime::IntegerObj* pad = adt[0].as<runtime::IntegerObj>()) {
+    if (adt_size == 1) {
+      int pad_width = static_cast<int>(pad->value);
+      if (ndim == 1) {
+        temp.emplace_back(mxnet::Tuple<int>({pad_width}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_width}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>({pad_width, pad_width}));
+        }
+      }
+    } else {
+      CHECK_EQ(adt_size, 2) << "Invalid Input pad_width";
+      int pad_before = static_cast<int>(pad->value);
+      int pad_after = static_cast<int>(Downcast<runtime::Integer>(adt[1])->value);
+      if (ndim == 1) {
+        temp.emplace_back(mxnet::Tuple<int>({pad_before}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_after}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>({pad_before, pad_after}));
+        }
+      }
+    }
+  } else {
+    if (adt_size == 1) {
+      if (ndim == 1) {
+        runtime::ADT pad_adt = Downcast<runtime::ADT>(adt[0]);
+        int pad_before = 
+          static_cast<int>(Downcast<runtime::Integer>(pad_adt[0])->value);
+        int pad_after = 
+          static_cast<int>(Downcast<runtime::Integer>(pad_adt[1])->value);
+        temp.emplace_back(mxnet::Tuple<int>({pad_before}));
+        temp.emplace_back(mxnet::Tuple<int>({pad_after}));
+      } else {
+        for (int dim = 0; dim < ndim; dim++) {
+          temp.emplace_back(mxnet::Tuple<int>(adt[0]));
+        }
+      }
+    } else {
+      CHECK_EQ(adt_size, ndim) << "Invalid Input pad_width";
+      for (int dim = 0; dim < ndim; dim++) {
+        temp.emplace_back(mxnet::Tuple<int>(adt[dim]));
+      }
+    }
+  }
+  return Tuple<Tuple<int>>(temp.begin(), temp.end());
+}
+
 MXNET_REGISTER_API("_npi.pad")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_pad");
   nnvm::NodeAttrs attrs;
   op::NumpyPadParam param;
+  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+  mxnet::TShape ashape = inputs[0]->shape();
+  int ndim = ashape.ndim();
   ADT adt = Downcast<ADT, ObjectRef>(args[1].operator ObjectRef());
-  int ndim = adt.size();
-  std::vector<Tuple<int>> temp;
-  int counter = 0;
-  for (counter = 0; counter < ndim; counter++) {
-    temp.emplace_back(mxnet::Tuple<int>(adt[counter]));
-  }
-  param.pad_width = Tuple<Tuple<int>>(temp.begin(), temp.end());
+  // broadcast pad_width to (ndim, 2)
+  param.pad_width = BroadcastPadWidth(ndim, adt);
   param.mode = String2MXNetPadType(args[2].operator std::string());
   if (args[3].type_code() != kNull) {
     param.constant_values = args[3].operator double();
@@ -77,7 +127,6 @@ MXNET_REGISTER_API("_npi.pad")
   SetAttrDict<op::NumpyPadParam>(&attrs);
   int num_inputs = 1;
   int num_outputs = 0;
-  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
   auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
   *ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
 });
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index e66ee313ddf0..650c420941e3 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -8325,7 +8325,7 @@ def __init__(self, pad_width, mode='constant'):
         def hybrid_forward(self,F,A,**kwargs):
             return F.np.pad(A, self._pad_width, mode=self._mode, **kwargs)
 
-    shapes = [(1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
+    shapes = [6, (1,5), (2,2), (2,2), (3,3), (2,3), (3,4,5)]
     dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
     mode = ['constant', 'reflect', 'symmetric', 'edge', 'minimum', 'maximum']
     for hybridize, shape, dtype, in itertools.product([False,True], shapes, dtypes):

From d11ea0bc6da0c26300c5000a811dd3c302b97e60 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 25 Jan 2021 15:30:33 -0800
Subject: [PATCH 2/2] fix sanity

---
 src/api/operator/numpy/np_pad_op.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/api/operator/numpy/np_pad_op.cc b/src/api/operator/numpy/np_pad_op.cc
index 23f9f84f2232..1351c26dd5ba 100644
--- a/src/api/operator/numpy/np_pad_op.cc
+++ b/src/api/operator/numpy/np_pad_op.cc
@@ -82,9 +82,9 @@ inline Tuple<Tuple<int>> BroadcastPadWidth(int ndim, runtime::ADT adt) {
     if (adt_size == 1) {
       if (ndim == 1) {
         runtime::ADT pad_adt = Downcast<runtime::ADT>(adt[0]);
-        int pad_before = 
+        int pad_before =
           static_cast<int>(Downcast<runtime::Integer>(pad_adt[0])->value);
-        int pad_after = 
+        int pad_after =
           static_cast<int>(Downcast<runtime::Integer>(pad_adt[1])->value);
         temp.emplace_back(mxnet::Tuple<int>({pad_before}));
         temp.emplace_back(mxnet::Tuple<int>({pad_after}));