From 21eeaaf3c71604cee41c13d189d246562de3096e Mon Sep 17 00:00:00 2001 From: dtracz Date: Wed, 31 Jul 2019 10:50:03 -0700 Subject: [PATCH 1/4] make TransposeShape infer shape form both sides --- src/operator/tensor/matrix_op-inl.h | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5cd7bf6652d3..0589c5fa2387 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -344,19 +344,40 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; + mxnet::TShape& out_shp = (*out_attrs)[0]; CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - mxnet::TShape ret(shp.ndim(), -1); + CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0"; + CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0"; + if (shp.ndim() == -1 and out_shp.ndim() == -1) + return false; // none of the shapes is known + if (out_shp.ndim() > 0 and shp.ndim() > 0) + CHECK_EQ(out_shp.ndim(), shp.ndim()); + mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); + mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1); + auto ifUnknownAssignElseCheck = [](mxnet::TShape& arr, int i, int val) { + if (arr[i] == -1) arr[i] = val; + else CHECK_EQ(arr[i], val); }; if (param.axes.ndim() == 0) { for (int i = 0; i < shp.ndim(); ++i) { + get[i] = shp[i]; ret[i] = shp[shp.ndim()-1-i]; } + for (int i = 0; i < out_shp.ndim(); ++i) { + ifUnknownAssignElseCheck(ret, i, out_shp[i]); + ifUnknownAssignElseCheck(get, shp.ndim()-1-i, out_shp[i]); + } } else { - CHECK_EQ(shp.ndim(), param.axes.ndim()); + CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim()); for (int i = 0; i < shp.ndim(); ++i) { CHECK(param.axes[i] < static_cast(shp.ndim())); ret[i] = shp[param.axes[i]]; } + for (int i = 0; i < out_shp.ndim(); ++i) { + ifUnknownAssignElseCheck(ret, i, out_shp[i]); + ifUnknownAssignElseCheck(get, param.axes[i], out_shp[i]); + } } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, get); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); return shape_is_known(ret); } From 71e66768ff8143b3cb91c2502ba4ca9aa642b498 Mon Sep 17 00:00:00 2001 From: dtracz Date: Wed, 31 Jul 2019 15:07:59 -0700 Subject: [PATCH 2/4] small fixes --- src/operator/tensor/matrix_op-inl.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 0589c5fa2387..ef4929b91ff4 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -348,15 +348,15 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0"; CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0"; - if (shp.ndim() == -1 and out_shp.ndim() == -1) - return false; // none of the shapes is known - if (out_shp.ndim() > 0 and shp.ndim() > 0) + if (shp.ndim() == -1 && out_shp.ndim() == -1) + return false; // none of the shapes is known + if (out_shp.ndim() > 0 && shp.ndim() > 0) CHECK_EQ(out_shp.ndim(), shp.ndim()); mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1); auto ifUnknownAssignElseCheck = [](mxnet::TShape& arr, int i, int val) { - if (arr[i] == -1) arr[i] = val; - else CHECK_EQ(arr[i], val); }; + if (arr[i] == -1) { arr[i] = val; + } else { CHECK_EQ(arr[i], val); } }; if (param.axes.ndim() == 0) { for (int i = 0; i < shp.ndim(); ++i) { get[i] = shp[i]; From 5db11b2b2ca8aa9bdea79650cf39bf06a4e0a3e8 Mon Sep 17 00:00:00 2001 From: dtracz Date: Wed, 31 Jul 2019 19:55:50 -0700 Subject: [PATCH 3/4] remove redundant lines --- src/operator/tensor/matrix_op-inl.h | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index ef4929b91ff4..cd98cb020c6b 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -354,17 +354,12 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_shp.ndim(), shp.ndim()); mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1); - auto ifUnknownAssignElseCheck = [](mxnet::TShape& arr, int i, int val) { - if (arr[i] == -1) { arr[i] = val; - } else { CHECK_EQ(arr[i], val); } }; if (param.axes.ndim() == 0) { for (int i = 0; i < shp.ndim(); ++i) { - get[i] = shp[i]; ret[i] = shp[shp.ndim()-1-i]; } for (int i = 0; i < out_shp.ndim(); ++i) { - ifUnknownAssignElseCheck(ret, i, out_shp[i]); - ifUnknownAssignElseCheck(get, shp.ndim()-1-i, out_shp[i]); + get[shp.ndim()-1-i] = out_shp[i]; } } else { CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim()); @@ -373,8 +368,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, ret[i] = shp[param.axes[i]]; } for (int i = 0; i < out_shp.ndim(); ++i) { - ifUnknownAssignElseCheck(ret, i, out_shp[i]); - ifUnknownAssignElseCheck(get, param.axes[i], out_shp[i]); + get[param.axes[i]] = out_shp[i]; } } SHAPE_ASSIGN_CHECK(*in_attrs, 0, get); From dfb466cabc737f084a99447cdb6ab09a4e9eab5b Mon Sep 17 00:00:00 2001 From: dtracz Date: Thu, 1 Aug 2019 11:37:25 -0700 Subject: [PATCH 4/4] unit tests --- tests/python/unittest/test_operator.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e8c9d6cbd061..8f1c2533c62c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8970,6 +8970,26 @@ def test_get_operator_arguments(): ok_(operator_arguments.narg == 2) +def test_transpose_infer_shape_back(): + o1 = mx.sym.ones(shape=[2,3]) + o2 = mx.sym.ones(shape=[-1,-1]) + t = mx.sym.transpose(o2) + b = o1 + t + x = b.bind(mx.cpu(), args={}) + y = x.forward() + assert(y[0].shape == (2,3)) + + +def test_transpose_infer_shape_mixed(): + o1 = mx.sym.ones(shape=[2,-1]) + o2 = mx.sym.ones(shape=[3,-1]) + t = mx.sym.transpose(o2) + b = o1 + t + x = b.bind(mx.cpu(), args={}) + y = x.forward() + assert(y[0].shape == (2,3)) + + if __name__ == '__main__': import nose nose.runmodule()