Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,3 +2047,30 @@ def convert_broadcast_to(node, **kwargs):
)

return [tensor_node, expand_node]


@mx_op.register("UpSampling")
def convert_upsample(node, **kwargs):
"""Map MXNet's UpSampling operator attributes to onnx's Upsample operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

sample_type = attrs.get('sample_type', 'nearest')
sample_type = 'linear' if sample_type == 'bilinear' else sample_type
scale = convert_string_to_list(attrs.get('scale'))
scaleh = scalew = float(scale[0])
if len(scale) > 1:
scaleh = float(scale[0])
scalew = float(scale[1])
scale = [1.0, 1.0, scaleh, scalew]

node = onnx.helper.make_node(
'Upsample',
input_nodes,
[name],
scales=scale,
mode=sample_type,
name=name
)
return [node]
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat, hardmax
from ._op_translations import concat, hardmax, upsampling
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
Expand Down Expand Up @@ -147,5 +147,6 @@
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth,
'Hardmax' : hardmax,
'LpNormalization' : lpnormalization
'LpNormalization' : lpnormalization,
'Upsample' : upsampling
}
21 changes: 21 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,24 @@ def lpnormalization(attrs, inputs, proto_obj):
axis = int(attrs.get("axis", -1))
new_attrs.update(axis=axis)
return 'norm', new_attrs, inputs


def upsampling(attrs, inputs, proto_obj):
"""Rearranges blocks of spatial data into depth."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale',
'mode': 'sample_type'})
sample_type = new_attrs.get('sample_type', 'nearest')
if sample_type != 'nearest':
raise NotImplementedError("Operator {} in ONNX supports 'linear' mode "
"for linear, bilinear, trilinear etc. There is no "
"way to distinguish these so far. Therefore, supporting "
"import of only nearest neighbor upsampling for now. "
"https://github.com/onnx/onnx/issues/1774. "
"Use contrib.BilinearResize2D for bilinear mode."
.format('UpSample'))

scale = tuple(new_attrs.get('scale'))[2:]
scale = tuple([int(s) for s in scale])
mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type)

return mx_op, new_attrs, inputs
141 changes: 116 additions & 25 deletions src/operator/nn/upsampling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,18 @@ enum UpSamplingMultiInputMode {kConcat, kSum};
} // namespace up_enum

struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
TShape scale;
int num_filter;
int sample_type;
int num_args;
int multi_input_mode;
uint64_t workspace;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.set_range(1, 1000)
.describe("Up sampling scale");
.set_default(TShape())
.describe("Up sampling scale. Integer or tuple of integers. "
"Different scale per dimension is allowed only for "
"nearest neighbor upsampling.");
DMLC_DECLARE_FIELD(num_filter)
.describe("Input filter. Only used by bilinear sample_type.")
.set_default(0);
Expand All @@ -82,6 +84,57 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
}
}; // struct UpSamplingParam

template<typename xpu, typename DTyp, typename AccReal>
void SpatialUpSamplingNearestUpdateOutput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &in_data,
std::vector<TBlob> *out_data) {
Tensor<xpu, 4, DTyp> itensor = in_data[0].get<xpu, 4, DTyp>(s);
Tensor<xpu, 4, DTyp> otensor = (*out_data)[0].get<xpu, 4, DTyp>(s);

int outputHeight = otensor.size(2);
int outputWidth = otensor.size(3);
int inputHeight = itensor.size(2);
int inputWidth = itensor.size(3);

int dW = outputWidth / inputWidth;
int dH = outputHeight / inputHeight;
int idim = itensor.shape_.kDimension;

// dims
int osz0 = otensor.size(0);
int osz1 = otensor.size(1);
int osz2 = otensor.size(2);
int osz3 = 1;
if (idim > 3) {
osz3 = otensor.size(3);
}

// perform the upsampling
int i0, i1, i2, i3;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about using index_t as datatype for all of them
Since for large operator support it was found to be useful -#13418

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll make this change

int iout[4]; // Output indices
int iin[4]; // Input indices

for (i0 = 0; i0 < osz0; i0++) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this nested for loop be vectorized?

iout[0] = i0;
iin[0] = i0;
for (i1 = 0; i1 < osz1; i1++) {
iout[1] = i1;
iin[1] = i1;
for (i2 = 0; i2 < osz2; i2++) {
iout[2] = i2;
iin[2] = i2;
int in_y = i2 / dH;
for (i3 = 0; i3 < osz3; i3++) {
iout[3] = i3;
iin[3] = i3;
int in_x = i3 / dW;
otensor[i0][i1][i2][i3] = itensor[i0][i1][in_y][in_x];
}
}
}
}
}

template<typename xpu, typename DType>
void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
const std::vector<TBlob> &in_data,
Expand All @@ -96,26 +149,37 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
}
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4, DType> out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
std::vector<TBlob> outdata = out_data;
if (param.num_args > 1) {
int begin = 0;
for (int i = 0; i < param.num_args; ++i) {
Tensor<xpu, 4, DType> data = in_data[i].get<xpu, 4, DType>(s);
int end = begin + data.size(1);
int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2);
if (param.multi_input_mode == up_enum::kSum) {
if (i == 0) {
Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
} else {
out += upsampling_nearest(data, scale);
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out += out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
} else {
Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
slice<1>(out, begin, end) = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> data = in_data[up_enum::kData].get<xpu, 4, DType>(s);
Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));
MSHADOW_REAL_TYPE_SWITCH_EX(in_data[0].type_flag_, DTyp, AccReal, {
SpatialUpSamplingNearestUpdateOutput<xpu, DTyp, AccReal>(s, in_data, &outdata);
out = out_data[up_enum::kOut].get<xpu, 4, DType>(s);
});
}
}

Expand All @@ -134,44 +198,71 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam &param,
Tensor<xpu, 4, DType> input_grad = in_grad[i].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
int end = begin + input_grad.size(1);
int scale = grad.size(2)/in_shape[0];
int scale_h = grad.size(2)/in_shape[0];
int scale_w = grad.size(3)/in_shape[1];
if (param.multi_input_mode == up_enum::kSum) {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(grad,
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
} else {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(slice<1>(grad, begin, end),
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> input_grad = in_grad[up_enum::kData].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
int scale_h = 1;
int scale_w = 1;
if (param.scale.ndim() == 1) {
scale_h = param.scale[0];
scale_w = param.scale[0];
} else if (param.scale.ndim() == 2) {
scale_h = param.scale[0];
scale_w = param.scale[1];
} else if (param.scale.ndim() == 4) {
scale_h = param.scale[2];
scale_w = param.scale[3];
}
Assign(input_grad, req[up_enum::kData],
pool<mshadow::red::sum>(grad,
in_shape,
param.scale,
param.scale,
param.scale,
param.scale));
scale_h,
scale_w,
scale_h,
scale_w));
}
}

static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) {
DeconvolutionParam p = DeconvolutionParam();
int kernel = 2 * param.scale - param.scale % 2;
int stride = param.scale;
int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
int scale_h = 1;
int scale_w = 1;
if (param.scale.ndim() == 1) {
scale_h = param.scale[0];
scale_w = param.scale[0];
} else if (param.scale.ndim() == 2) {
scale_h = param.scale[0];
scale_w = param.scale[1];
} else if (param.scale.ndim() == 4) {
scale_h = param.scale[2];
scale_w = param.scale[3];
}
CHECK_EQ(scale_h, scale_w) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
int stride = scale_h;
int pad = static_cast<int>(ceil((scale_h - 1) / 2.));
p.workspace = param.workspace;
p.num_group = param.num_filter;
p.num_filter = param.num_filter;
Expand Down
22 changes: 18 additions & 4 deletions src/operator/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,25 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
CHECK_GE(in_shape->size(), 1U);
const TShape &dshape = (*in_shape)[0];
TShape oshape = dshape;
int scale_h = 1;
int scale_w = 1;
if (param_.scale.ndim() == 1) {
scale_h = param_.scale[0];
scale_w = param_.scale[0];
} else if (param_.scale.ndim() == 2) {
scale_h = param_.scale[0];
scale_w = param_.scale[1];
} else if (param_.scale.ndim() == 4) {
scale_h = param_.scale[2];
scale_w = param_.scale[3];
}
if (param_.sample_type == up_enum::kNearest) {
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
oshape[1] = 0;
for (auto& shape : *in_shape) {
CHECK_EQ(shape.ndim(), 4U) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale;
int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w;
CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \
"does not divide output height of " << oh;
CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \
Expand All @@ -58,17 +70,19 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
}
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
CHECK_EQ(scale_h, scale_w) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
CHECK_EQ(dshape.ndim(), 4U) << \
"UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
if (dshape.ndim() == 0) return false;
int kernel = 2 * param_.scale - param_.scale % 2;
int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
SHAPE_ASSIGN_CHECK(*in_shape,
up_enum::kWeight,
mshadow::Shape4(dshape[1], 1, kernel, kernel));
oshape = dshape;
}
oshape[2] = dshape[2] * param_.scale;
oshape[3] = dshape[3] * param_.scale;
oshape[2] = dshape[2] * scale_h;
oshape[3] = dshape[3] * scale_w;
out_shape->clear();
out_shape->push_back(oshape);
return true;
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
'test_max_',
'test_softplus',
'test_reduce_',
'test_split_equal'
'test_split_equal',
'test_upsample_n'
],
'import': ['test_gather',
'test_softsign',
Expand Down
Loading