28 changes: 28 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
@@ -476,6 +476,34 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
});
}

struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
  dmlc::optional<int> lhs_begin, rhs_begin, lhs_end, rhs_end;
  DMLC_DECLARE_PARAMETER(ReshapeLikeParam) {
    DMLC_DECLARE_FIELD(lhs_begin)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to 0. "
                  "The beginning index along which the lhs dimensions are to be "
                  "reshaped. Supports negative indices.");
    DMLC_DECLARE_FIELD(lhs_end)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to None. "
                  "The ending index along which the lhs dimensions are to be "
                  "used for reshaping. Supports negative indices.");
    DMLC_DECLARE_FIELD(rhs_begin)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to 0. "
                  "The beginning index along which the rhs dimensions are to "
                  "be used for reshaping. Supports negative indices.");
    DMLC_DECLARE_FIELD(rhs_end)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to None. "
                  "The ending index along which the rhs dimensions are to be "
                  "used for reshaping. Supports negative indices.");
  }
};
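
As a minimal usage sketch of what these four parameters select (assuming the generated Python NDArray frontend exposes them as keyword arguments accepting `None`, as generated operator bindings normally do; this mirrors the (30, 12) example from the operator docstring below)::

    import mxnet as mx

    x = mx.nd.arange(360).reshape((30, 12))
    y = mx.nd.zeros((4, 2, 2, 3))
    # Reshape only the last lhs axis (size 12) to rhs dims 1..end, i.e. (2, 2, 3).
    z = mx.nd.reshape_like(x, y, lhs_begin=-1, lhs_end=None,
                           rhs_begin=1, rhs_end=None)
    assert z.shape == (30, 2, 2, 3)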

/*! \brief Unary compute */
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \
NNVM_REGISTER_OP(__name$) \
118 changes: 103 additions & 15 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -350,10 +350,109 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
.add_argument("lhs", "NDArray-or-Symbol", "First input.")
.add_argument("rhs", "NDArray-or-Symbol", "Second input.");

void ReshapeLikeRangeCanonicalize(int ndims, const char *side,
                                  const dmlc::optional<int> &begin,
                                  const dmlc::optional<int> &end, int *cbegin,
                                  int *cend) {
  *cbegin = begin.has_value() ? begin.value() : 0;
  if (*cbegin < 0)
    *cbegin += ndims;

  if (!end.has_value()) {
    *cend = ndims;
  } else {
    *cend = end.value();
    if (*cend < 0) {
      *cend += ndims;
    }
  }
  CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end
                        << " as dimension number is " << ndims;
  CHECK((*cbegin < *cend)) << "Invalid begin, end, got " << side
                           << "_begin=" << begin << ", " << side
                           << "_end=" << end;

  CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end;
  CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin;
}
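
A minimal pure-Python sketch of this canonicalization (the helper name `canonical_range` is hypothetical, used only for illustration): `begin` falls back to 0, `end` falls back to the rank, negative values count from the back, and the same validity constraints are enforced::

    def canonical_range(ndims, begin, end):
        # begin defaults to 0; end defaults to ndims; negatives wrap around.
        cbegin = 0 if begin is None else begin
        if cbegin < 0:
            cbegin += ndims
        cend = ndims if end is None else end
        if cend < 0:
            cend += ndims
        # The same constraints the C++ CHECKs enforce.
        assert 0 <= cbegin < cend <= ndims, (begin, end, ndims)
        return cbegin, cend

    print(canonical_range(6, 1, -1))    # (1, 5)
    print(canonical_range(4, 1, None))  # (1, 4)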

void GetReshapeLikeParams(const ReshapeLikeParam &param, const TShape &lshape,
                          const TShape &rshape, int *lhs_begin, int *lhs_end,
                          int *rhs_begin, int *rhs_end) {
  // LHS params
  ReshapeLikeRangeCanonicalize(lshape.ndim(), "lhs", param.lhs_begin,
                               param.lhs_end, lhs_begin, lhs_end);
  // RHS params
  ReshapeLikeRangeCanonicalize(rshape.ndim(), "rhs", param.rhs_begin,
                               param.rhs_end, rhs_begin, rhs_end);
}

bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs,
                             std::vector<TShape> *in_attrs,
                             std::vector<TShape> *out_attrs) {
  const ReshapeLikeParam &param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
  const TShape &lshape = (*in_attrs)[0];
  const TShape &rshape = (*in_attrs)[1];
  int lhs_begin, lhs_end, rhs_begin, rhs_end;
  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin,
                       &rhs_end);

  int lhsrank = static_cast<int>(lshape.ndim());
  int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
  TShape oshape(orank);

  // Copy the untouched leading lhs dimensions.
  for (int i = 0; i < lhs_begin; ++i)
    oshape[i] = lshape[i];

  // Splice in the selected rhs dimensions.
  int opos = lhs_begin;
  for (int i = rhs_begin; i < rhs_end; ++i) {
    oshape[opos] = rshape[i];
    opos += 1;
  }

  // Copy the untouched trailing lhs dimensions.
  for (int i = lhs_end; i < lhsrank; ++i) {
    oshape[opos] = lshape[i];
    opos += 1;
  }

  CHECK_EQ((*in_attrs)[0].Size(), oshape.Size())
      << "Cannot reshape lhs with shape " << (*in_attrs)[0] << " to new "
      << "shape " << oshape << " because they have different "
      << "sizes.";
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return true;
}
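
The same splice-and-check logic can be modeled in a few lines of Python; this is an illustrative sketch under the canonicalization rule above, not the actual implementation::

    from functools import reduce

    def _canon(n, b, e):
        # None -> defaults; negatives wrap (see ReshapeLikeRangeCanonicalize).
        b = 0 if b is None else (b + n if b < 0 else b)
        e = n if e is None else (e + n if e < 0 else e)
        return b, e

    def reshape_like_shape(lshape, rshape, lhs_begin, lhs_end, rhs_begin, rhs_end):
        lb, le = _canon(len(lshape), lhs_begin, lhs_end)
        rb, re = _canon(len(rshape), rhs_begin, rhs_end)
        # Keep lhs dims outside [lb, le); splice in rhs dims from [rb, re).
        oshape = tuple(lshape[:lb]) + tuple(rshape[rb:re]) + tuple(lshape[le:])
        size = lambda s: reduce(lambda a, b: a * b, s, 1)
        assert size(lshape) == size(oshape), "lhs and output sizes must match"
        return oshape

    print(reshape_like_shape((30, 7), (15, 2, 4), 0, 1, 0, 2))  # (15, 2, 7)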

DMLC_REGISTER_PARAMETER(ReshapeLikeParam);
NNVM_REGISTER_OP(reshape_like)
.describe("Reshape lhs to have the same shape as rhs.")
.describe(R"code(Reshape some or all dimensions of `lhs` to have the same shape as some or all dimensions of `rhs`.

Returns a **view** of the `lhs` array with a new shape without altering any data.

Example::

x = [1, 2, 3, 4, 5, 6]
y = [[0, -4], [3, 2], [2, 2]]
reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]]

More precise control over how dimensions are inherited is achieved by specifying \
slices over the `lhs` and `rhs` array dimensions. Only the sliced `lhs` dimensions \
are reshaped to the `rhs` sliced dimensions, with the non-sliced `lhs` dimensions staying the same.

Examples::

- lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
- lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15,)

Negative indices are supported, and `None` can be used for either `lhs_end` or `rhs_end` to indicate the end of the range.

Example::

- lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)

)code" ADD_FILELINE)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ReshapeLikeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) { return std::vector<std::string>{"lhs", "rhs"}; })
.set_attr<nnvm::FInplaceOption>(
@@ -365,19 +464,7 @@ NNVM_REGISTER_OP(reshape_like)
.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if ((*in_attrs)[0].ndim()) {
CHECK_EQ((*in_attrs)[0].Size(), (*in_attrs)[1].Size())
<< "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to rhs "
<< "with shape " << (*in_attrs)[1] << " because they have different "
<< "size.";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]);
return true;
})
.set_attr<nnvm::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>(
"FGradient", [](const nnvm::NodePtr& n,
@@ -438,7 +525,8 @@ Example::
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
})
.add_argument("data", "NDArray-or-Symbol", "Input Array.");
.add_argument("data", "NDArray-or-Symbol", "Input Array.")
.add_arguments(ReshapeLikeParam::__FIELDS__());

void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
53 changes: 53 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -2114,6 +2114,59 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
assert_allclose(exe.grad_arrays[0].asnumpy(), out_grad_npy.reshape((5, 4, 3, 7)))


@with_seed()
def test_reshape_like():
    def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shape):
        lhs = mx.sym.Variable("lhs")
        rhs = mx.sym.Variable("rhs")
        net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
        js = net.tojson()
        net = mx.sym.load_json(js)
        _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)

        assert output_shape[0] == dst_shape, \
            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin = %s, rhs_end = %s'\
            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))

        lhs_npy = np.random.rand(*lhs_shape)
        rhs_npy = np.random.rand(*rhs_shape)
        grad_npy = np.random.rand(*dst_shape)

        exe = net.simple_bind(default_context(), lhs=lhs_shape, rhs=rhs_shape)
        exe.arg_dict['lhs'][:] = lhs_npy
        exe.arg_dict['rhs'][:] = rhs_npy
        exe.forward(is_train=True)
        assert np.square(exe.outputs[0].asnumpy() - lhs_npy.reshape(dst_shape)).mean() < 1E-7, \
            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin = %s, rhs_end = %s'\
            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))
        exe.backward(out_grads=mx.nd.array(grad_npy))
        assert np.square(exe.grad_dict['lhs'].asnumpy() - grad_npy.reshape(lhs_shape)).mean() < 1E-7, \
            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin = %s, rhs_end = %s'\
            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))

    # Test the new API (slicing lhs and rhs dimensions via begin/end)
    test_cases = [
        [(30,), (15,2,4), 0, None, 0, 2, (15,2)],
        [(30,), (15,2,4), None, 1, None, 2, (15,2)],
        [(30,7), (15,2,4), 0, 1, 0, 2, (15,2,7)],
        [(3,5), (1,15,4), 0, 2, 1, 2, (15,)],
        [(3,5), (1,15,4), 0, None, 1, -1, (15,)],
        [(30,12), (4,2,2,3), -1, None, 1, None, (30,2,2,3)],
        [(1,1,7,3,1,1), (81,1,1,21), 1, -1, 1, None, (1,1,1,21,1)]
    ]
    for test_case in test_cases:
        test_reshape_like_new(*test_case)

    # Test the old API (no begin/end: inherit the full rhs shape)
    lhs = mx.sym.Variable("lhs")
    rhs = mx.sym.Variable("rhs")
    net = mx.sym.reshape_like(lhs, rhs)
    js = net.tojson()
    net = mx.sym.load_json(js)
    _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
    assert(output_shape[0] == (30,20,2))
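
For reference, an illustrative sketch of the old-API behavior checked above (assuming a pre-2.0 MXNet NDArray frontend): with no begin/end arguments, `reshape_like` simply inherits the full `rhs` shape while leaving the data untouched::

    import mxnet as mx

    a = mx.nd.arange(24).reshape((4, 6))
    b = mx.nd.zeros((2, 3, 4))
    out = mx.nd.reshape_like(a, b)  # no slicing args: take rhs's whole shape
    assert out.shape == (2, 3, 4)
    # Only the view changes; the underlying values are identical.
    assert (out.asnumpy().reshape((4, 6)) == a.asnumpy()).all()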


@with_seed()
def test_reduce():
sample_num = 500