From 4ad515c64c9c69b0f4d78796c97022f193752aba Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Sun, 17 May 2020 04:45:31 -0700
Subject: [PATCH 1/2] Fix FInferShape for some ops to support partial type inference

Signed-off-by: Serge Panev
---
 src/operator/contrib/batch_norm_relu.cc | 7 +++----
 src/operator/nn/batch_norm.cc           | 7 +++----
 src/operator/nn/group_norm.cc           | 8 ++++----
 src/operator/nn/layer_norm.cc           | 7 ++++---
 src/operator/nn/pooling.cc              | 9 +++++++--
 5 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 14452cc96729..0bb2f0b43693 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 4U);
   const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
       : param.axis);
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
 
   const int channelCount = dshape[channelAxis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 815288cfe554..b865269fc6f5 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -323,6 +323,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 3U);
   const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
       : param.axis);
@@ -331,10 +334,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
 
   const index_t channelCount = dshape[channelAxis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index c939b4499c94..53fab3161426 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
-  CHECK_GE(dshape.ndim(), 3U);
-  const int num_groups = param.num_groups;
-  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
-
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
   }
 
+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
   in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
   in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));
 
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index e3d641af4015..11178b358c2d 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   int axis = GetRealAxis(param.axis, dshape.ndim());
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;
 
   const index_t channelCount = dshape[axis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
   SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kGamma, mxnet::TShape(Shape1(channelCount)));
   SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kBeta, mxnet::TShape(Shape1(channelCount)));
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 03787f42b038..56edf74ee67a 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -95,10 +95,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
                          mxnet::ShapeVector *out_shape) {
   const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), 1U);
+  const mxnet::TShape &dshape = (*in_shape)[0];
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   if (param.pool_type == pool_enum::kLpPooling) {
     CHECK(param.p_value.has_value());
   }
-  const mxnet::TShape &dshape = (*in_shape)[0];
+
   if (param.pooling_convention == pool_enum::kSame) {
     CHECK_EQ(dshape.ndim(), 3U)
       << "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +119,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       << "Pooling: Input data should be 3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
      << " Or 5D in (batch, channel, d, y, x)";
-  if (!mxnet::ndim_is_known(dshape)) return false;
+
   int layout = param.GetLayout(dshape.ndim());
   if (param.global_pool) {
     mxnet::TShape oshape = dshape;
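
Note on the ordering fixed above: an FInferShape function in MXNet returns
true once every shape is fully inferred, and false to ask the graph pass to
revisit the node on a later iteration. The ndim_is_known() guard therefore
has to run before anything touches the data shape: with the old ordering,
each of these functions read dshape (dshape[channelAxis], dshape[1],
dshape.ndim()) while the rank could still be -1, instead of simply
deferring. Below is a minimal standalone sketch of that contract; TShape
and ndim_is_known are simplified stand-ins for the MXNet types of the same
names, not the real implementation:

    #include <vector>

    // Simplified stand-in for mxnet::TShape: ndim == -1 means the rank is
    // not yet inferred (this is what mxnet::ndim_is_known tests).
    struct TShape {
      int ndim = -1;
      std::vector<long> dims;  // meaningful only when ndim >= 0
    };

    bool ndim_is_known(const TShape& s) { return s.ndim != -1; }

    // Toy shape function with the patched ordering: bail out *before*
    // touching dshape. Returning false is not an error; it means "not
    // enough information yet, revisit this node later".
    bool ToyBatchNormShape(std::vector<TShape>* in_shape) {
      const TShape& dshape = (*in_shape)[0];  // data
      if (!ndim_is_known(dshape)) {
        return false;  // defer: partial inference will try again
      }
      // Safe only after the guard: dshape.dims is valid now.
      const long channel_count = dshape.dims[1];
      (*in_shape)[1] = TShape{1, {channel_count}};  // gamma
      (*in_shape)[2] = TShape{1, {channel_count}};  // beta
      return true;
    }
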
From 4c85693f0f6cd6f6e9d1be8dc06c1a665a1052f3 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Tue, 19 May 2020 02:06:09 -0700
Subject: [PATCH 2/2] Add missing ndim check in matrix_op-inl.h

Signed-off-by: Serge Panev
---
 src/operator/tensor/matrix_op-inl.h | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 50501220462b..f038c20cd2c2 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -456,9 +456,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+  if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
     return false;  // none of the shapes is known
+  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
   if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
@@ -513,12 +513,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  mxnet::TShape& oshape = (*out_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
     return false;
   }
 
-  mxnet::TShape& ishape = (*in_attrs)[0];
-  mxnet::TShape& oshape = (*out_attrs)[0];
   int indim = ishape.ndim();
   bool unknown_ishape = false;
   if (-1 == indim) {
@@ -1441,6 +1441,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& ishape = (*in_attrs)[0];
   mxnet::TShape& from_shape = (*in_attrs)[1];
+  if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+    return false;
+  }
   if (param.axes.ndim() == 0) {
     CHECK_EQ(ishape.ndim(), from_shape.ndim())
       << "By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1749,6 +1752,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   const mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
   int repeats = 0;
   dmlc::optional<int> axisOpt;
   GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2427,6 +2433,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(4, -1);
 
   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2591,6 +2600,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
 
   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[0], 0)
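
One detail in this second patch deserves a remark: TransposeShape defers
only when *both* the input and the output rank are unknown (&&), because
transpose can propagate a known rank across the node in either direction,
whereas SliceLikeShape defers as soon as *either* input rank is unknown
(||), because it must compare the two ranks axis by axis before doing
anything. A standalone sketch of the two guard styles, using the same
simplified TShape stand-in as above (not the real MXNet code):

    #include <vector>

    struct TShape {
      int ndim = -1;                 // -1 means "rank not yet inferred"
      std::vector<long> dims;
    };

    bool ndim_is_known(const TShape& s) { return s.ndim != -1; }

    // Transpose-style guard: rank can flow in->out or out->in, so we only
    // have to give up when *both* sides are unknown.
    bool ToyTransposeShape(TShape* in, TShape* out) {
      if (!ndim_is_known(*in) && !ndim_is_known(*out))
        return false;  // none of the shapes is known
      if (!ndim_is_known(*in))
        *in = TShape{out->ndim, std::vector<long>(out->ndim, -1)};
      if (!ndim_is_known(*out))
        *out = TShape{in->ndim, std::vector<long>(in->ndim, -1)};
      // Simplified: the real function also fills in and cross-checks dims.
      return true;
    }

    // slice_like-style guard: both ranks are needed before any axis
    // bookkeeping, so give up as soon as *either* one is unknown.
    bool ToySliceLikeShape(const TShape& data, const TShape& like) {
      if (!ndim_is_known(data) || !ndim_is_known(like))
        return false;
      return data.ndim == like.ndim;  // stand-in for the real axis checks
    }

ExpandDimShape keeps the same && pattern and only hoists the ishape/oshape
references above the guard so the condition can use them; the single-input
guards added to RepeatOpShape, DepthToSpaceOpShape and SpaceToDepthOpShape
are the degenerate one-shape case of the same fix.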