Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/operator/contrib/deformable_convolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class DeformableConvolutionOp : public Operator {
// calculate the shape of col_buffer
mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, -1);
col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size();
for (size_t i = 1; i < col_buffer_shape.ndim(); ++i) {
for (int i = 1; i < col_buffer_shape.ndim(); ++i) {
col_buffer_shape[i] = out_data[0].shape_[i + 1];
}
// create a column buffer using workspace and col_buffer_shape
Expand Down Expand Up @@ -347,9 +347,9 @@ class DeformableConvolutionProp : public OperatorProperty {
param_.Init(kwargs);
if (param_.kernel.ndim() == 2) {
param_.layout = param_.layout ? param_.layout.value() : mshadow::kNCHW;
if (mxnet::op::shape_is_none(param_.stride)) param_.stride = Shape2(1, 1);
if (mxnet::op::shape_is_none(param_.dilate)) param_.dilate = Shape2(1, 1);
if (mxnet::op::shape_is_none(param_.pad)) param_.pad = Shape2(0, 0);
if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1);
if (param_.dilate.ndim() == 0) param_.dilate = Shape2(1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0);
} else {
LOG(FATAL) << "not implemented";
}
Expand Down
33 changes: 17 additions & 16 deletions src/operator/pooling_v1-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct PoolingV1Param : public dmlc::Parameter<PoolingV1Param> {
int pooling_convention;
bool global_pool;
DMLC_DECLARE_PARAMETER(PoolingV1Param) {
DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape())
DMLC_DECLARE_FIELD(kernel).set_default(mxnet::TShape(0))
.enforce_nonzero()
.describe("pooling kernel size: (y, x) or (d, y, x)");

Expand All @@ -73,11 +73,11 @@ struct PoolingV1Param : public dmlc::Parameter<PoolingV1Param> {
.add_enum("valid", pool_v1_enum::kValid)
.describe("Pooling convention to be applied.");

DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape())
DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0))
.enforce_nonzero()
.describe("stride: for pooling (y, x) or (d, y, x)");

DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape())
DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0))
.describe("pad for pooling: (y, x) or (d, y, x)");
}
};
Expand Down Expand Up @@ -217,19 +217,20 @@ class PoolingV1Prop : public OperatorProperty {
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
using namespace mshadow;
param_.Init(kwargs);
if (!param_.global_pool) {
if (param_.kernel.ndim() == 2) {
if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0);
} else {
CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() << "D pooling not supported";
if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
}
CHECK_EQ(param_.stride.ndim(), param_.kernel.ndim())
<< "stride and kernel should have the same length";
CHECK_EQ(param_.pad.ndim(), param_.kernel.ndim())
<< "pad and kernel should have the same length";
if (param_.kernel.ndim() == 1) {
if (param_.stride.ndim() == 0) param_.stride = Shape1(1);
if (param_.pad.ndim() == 0) param_.pad = Shape1(0);
} else if (param_.kernel.ndim() == 2) {
if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0);
} else {
// ignore kernel size only if global_pool not assigned false
if (param_.global_pool == false) {
CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim()
<< "D pooling not supported";
}
if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
}
}

Expand Down