41 changes: 28 additions & 13 deletions src/operator/contrib/batch_norm_relu.cc
@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 4U);
const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,

const int channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }

in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
@@ -84,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 4;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];

+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+          dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -101,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
-  const size_t n_out = 4;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
return true;
}

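The new control flow in `BatchNormWithReLUType` follows one pattern: infer forward from the input dtype when it is defined, fall back to the output dtype (backward inference) when it is not, and return `false` only when neither side is known, so the graph-level pass can retry later. Below is a minimal standalone sketch of that pattern, assuming a bare `-1` undefined sentinel in place of `type_is_none` and a precomputed `dtype_param` in place of the `MSHADOW_REAL_TYPE_SWITCH_EX` dispatch; it is an illustration, not the operator's actual API.

```cpp
#include <cassert>
#include <vector>

// Sketch only: -1 mirrors mxnet's "type is none" sentinel; dtype_param stands
// in for the accumulation type normally derived via MSHADOW_REAL_TYPE_SWITCH_EX.
bool InferDtypeBidirectional(std::vector<int>* in_type,
                             std::vector<int>* out_type,
                             size_t n_out, int dtype_param) {
  int dtype = (*in_type)[0];
  if (dtype == -1) {
    // Backward inference: recover the input dtype from the output, if defined.
    if (out_type->empty() || (*out_type)[0] == -1) return false;  // defer
    dtype = (*out_type)[0];
    (*in_type)[0] = dtype;
  } else {
    // Forward inference: first output gets dtype, the rest the accumulation type.
    out_type->clear();
    out_type->push_back(dtype);
    for (size_t i = 1; i < n_out; ++i) out_type->push_back(dtype_param);
  }
  return true;
}

int main() {
  std::vector<int> in{-1}, out{0};  // input dtype unknown, output dtype known
  assert(InferDtypeBidirectional(&in, &out, 4, 0));
  assert(in[0] == 0);               // the dtype flowed backward into the input
}
```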
40 changes: 27 additions & 13 deletions src/operator/nn/batch_norm.cc
@@ -365,6 +365,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 3U);
const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
@@ -373,10 +376,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,

const int channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }

in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
@@ -394,14 +393,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 3;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+          dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -411,12 +431,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
-  const size_t n_out = 3;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
return true;
}

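The `BatchNormShape` hunks (like the `BatchNormWithReLUShape` ones above) change ordering, not logic: the `ndim_is_known` guard used to run after `dshape[channelAxis]` had already been read, i.e. after indexing into a shape that may still be unknown. The guard now runs first. A toy sketch of the corrected order, using stand-in types (not mxnet's real `TShape`):

```cpp
#include <vector>

// Toy stand-ins, for illustration only.
struct Shape {
  int ndim = -1;               // -1 means "not known yet"
  std::vector<int> dims;
};
bool ndim_is_known(const Shape& s) { return s.ndim != -1; }

// Bail out on an unknown shape *before* computing the channel axis or
// indexing into the shape; this is the order the hunks above establish.
bool InferChannelCount(const Shape& dshape, int axis, int* channel_count) {
  if (!ndim_is_known(dshape)) return false;    // defer inference
  const int channel_axis = axis < 0 ? dshape.ndim + axis : axis;
  *channel_count = dshape.dims[channel_axis];  // safe only after the guard
  return true;
}
```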
13 changes: 10 additions & 3 deletions src/operator/nn/convolution.cc
@@ -285,16 +285,23 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
-  out_type->clear();
-  out_type->push_back(dtype);
return true;
}

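`ConvolutionType` above, and `DeconvolutionType` and `SoftmaxOutputType` below, each open-code the same single-dtype fallback. A hypothetical shared helper (an assumption for illustration; this PR does not introduce one) makes the shape of the pattern explicit:

```cpp
#include <vector>

// Hypothetical helper, not part of this PR: resolve the operator dtype from
// in_type[0], or from out_type[0] when the input is undefined (-1), then
// propagate it into every undefined input slot. The real operators would
// additionally run UNIFORM_TYPE_CHECK on the inputs that are already defined.
inline bool InferUniformType(std::vector<int>* in_type,
                             std::vector<int>* out_type) {
  int dtype = (*in_type)[0];
  if (dtype == -1) {
    if (out_type->empty() || (*out_type)[0] == -1) return false;
    dtype = (*out_type)[0];    // backward inference from the output
  } else {
    out_type->clear();         // forward inference to the single output
    out_type->push_back(dtype);
  }
  for (int& t : *in_type) {
    if (t == -1) t = dtype;    // fill in undefined inputs
  }
  return true;
}
```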
18 changes: 15 additions & 3 deletions src/operator/nn/deconvolution.cc
@@ -332,16 +332,28 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
-  out_type->clear();
-  out_type->push_back(dtype);
return true;
}

8 changes: 4 additions & 4 deletions src/operator/nn/group_norm.cc
@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
-  CHECK_GE(dshape.ndim(), 3U);
-  const int num_groups = param.num_groups;
-  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";

if (!mxnet::ndim_is_known(dshape)) {
return false;
}

+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";

in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));

7 changes: 4 additions & 3 deletions src/operator/nn/layer_norm.cc
@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

int axis = GetRealAxis(param.axis, dshape.ndim());
CHECK(axis >= 0 && axis < dshape.ndim())
<< "Channel axis out of range: axis=" << param.axis;

const int channelCount = dshape[axis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
SHAPE_ASSIGN_CHECK(*in_shape,
layernorm::kGamma,
mxnet::TShape(Shape1(channelCount)));
7 changes: 5 additions & 2 deletions src/operator/nn/pooling.cc
@@ -95,10 +95,14 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
mxnet::ShapeVector *out_shape) {
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
+  const mxnet::TShape &dshape = (*in_shape)[0];
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
if (param.pool_type == pool_enum::kLpPooling) {
CHECK(param.p_value.has_value());
}
-  const mxnet::TShape &dshape = (*in_shape)[0];

if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +118,6 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
-  if (!mxnet::ndim_is_known(dshape)) return false;
int layout = param.GetLayout(dshape.ndim());
if (param.global_pool) {
mxnet::TShape oshape = dshape;
18 changes: 15 additions & 3 deletions src/operator/softmax_output.cc
@@ -66,16 +66,28 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_type) {
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined, we try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output are defined,
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
-  out_type->clear();
-  out_type->push_back(dtype);
return true;
}

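What the removed `CHECK_NE(dtype, -1)` lines changed in practice: an undefined input dtype used to abort type inference outright via a fatal `CHECK`, even when the output dtype was known and the answer was recoverable. Returning `false` instead lets the caller retry once more types are known. A before/after sketch under those assumptions, modeling the fatal `CHECK` as an exception (dmlc's `CHECK_NE` actually aborts through its logging machinery):

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

// Old behavior (sketch): an undefined input dtype is fatal, regardless of
// what is known about the output.
bool InferTypeOld(std::vector<int>* in_type, std::vector<int>* /*out_type*/) {
  if ((*in_type)[0] == -1)
    throw std::runtime_error("First input must have specified type");
  return true;
}

// New behavior (sketch): recover the dtype from the output when possible,
// otherwise return false so graph-level inference can try again later.
bool InferTypeNew(std::vector<int>* in_type, std::vector<int>* out_type) {
  if ((*in_type)[0] == -1) {
    if (out_type->empty() || (*out_type)[0] == -1) return false;
    (*in_type)[0] = (*out_type)[0];  // backward inference
  }
  return true;
}

int main() {
  std::vector<int> in{-1}, out{0};  // input dtype unknown, output dtype known
  try { InferTypeOld(&in, &out); }
  catch (const std::exception& e) { std::cout << "old: " << e.what() << '\n'; }
  std::cout << "new: " << (InferTypeNew(&in, &out) ? "inferred" : "deferred")
            << ", in[0]=" << in[0] << '\n';
}
```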
22 changes: 17 additions & 5 deletions src/operator/tensor/matrix_op-inl.h
@@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+  if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
return false; // none of the shapes is known
+  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
CHECK_EQ(out_shp.ndim(), shp.ndim());
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
@@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
-  if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  mxnet::TShape& oshape = (*out_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
return false;
}

-  mxnet::TShape& ishape = (*in_attrs)[0];
-  mxnet::TShape& oshape = (*out_attrs)[0];
int indim = ishape.ndim();
bool unknown_ishape = false;
if (-1 == indim) {
@@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
mxnet::TShape& from_shape = (*in_attrs)[1];
+  if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+    return false;
+  }
if (param.axes.ndim() == 0) {
CHECK_EQ(ishape.ndim(), from_shape.ndim())
<< "By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
int repeats = 0;
dmlc::optional<int> axisOpt;
GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(4, -1);

mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[0], 0)