Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
<< arr->shape();
TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
Expand Down
36 changes: 30 additions & 6 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
if (d1 == -1) d1 = d0 / d2;
if (d2 == -1) d2 = d0 / d1;
CHECK_EQ(d1 * d2, static_cast<IType>(d0)) <<
CHECK(d1 * d2 == static_cast<IType>(d0) || static_cast<IType>(d0) == IType(0)) <<
"Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
tmp.push_back(d1);
tmp.push_back(d2);
Expand Down Expand Up @@ -151,13 +151,36 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
return oshape;
}

// Infers the value of a single held-out (zero) dimension of *in from the
// known output shape, for reshape inference with a 0 placeholder dim.
//
// \param in  input shape, possibly containing one 0 (unknown) dimension;
//            modified in place when inference succeeds.
// \param out fully inferred output shape used as the size reference.
// \return true  - *in was already fully known, or exactly one zero dim was
//                 found and has been filled in from out.Size();
//         false - out's size is unknown (0), or *in has more than one zero
//                 dim, so the missing value cannot be inferred uniquely.
//
// NOTE(review): assumes out.Size() is divisible by the product of the known
// dims of *in; a non-divisible pair silently truncates -- confirm callers
// validate total sizes (ReshapeShape CHECKs size equality afterwards).
inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) {
  if (in->Size() && out.Size()) {
    return true;   // both shapes fully known; nothing to infer
  } else if (!out.Size()) {
    return false;  // output size unknown; cannot infer anything
  } else {
    int zero_axis = -1;
    // Use a 64-bit accumulator: the product of the known dims can exceed
    // INT_MAX for large tensors, which would overflow a plain `int`.
    int64_t non_zero_prod = 1;
    for (index_t i = 0; i < in->ndim(); i++) {
      if ((*in)[i] == 0) {
        if (zero_axis != -1)
          return false;  // more than one zero dim: inference is ambiguous
        else
          zero_axis = i;
      } else {
        non_zero_prod *= (*in)[i];
      }
    }
    (*in)[zero_axis] = out.Size() / non_zero_prod;
    return true;
  }
}

inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
const TShape &dshape = (*in_attrs)[0];
TShape &dshape = (*in_attrs)[0];
if (dshape.ndim() == 0) return false;
TShape oshape;
if (param_.shape.ndim() != 0) {
Expand All @@ -182,14 +205,15 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
oshape[inf_idx] = dshape.Size() / oshape.Size();
}
} else {
return (*out_attrs)[0].ndim();
return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}
ReverseReshapeInferShape(&dshape, oshape);
CHECK_EQ(oshape.Size(), dshape.Size())
<< "Target shape size is different to source. "
<< "Target: " << oshape
<< "\nSource: " << dshape;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}

inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
Expand Down
35 changes: 26 additions & 9 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,11 +1943,11 @@ def test_bxor(a, b):
test_bmul(a, b)
test_bdiv(a, b)
'''
Flaky Test Disabled due to master build failure:
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
Flaky Test Disabled due to master build failure:
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
Github Issue: https://github.com/apache/incubator-mxnet/issues/11838
test_bmod(a, b)

test_bmod(a, b)
'''
test_bmod_int(a, b)
test_bpow(a, b)
Expand Down Expand Up @@ -2065,6 +2065,23 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7, \
'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s'\
%(str(src_shape), str(shape_args), str(reverse), str(dst_shape))

for i in range(len(src_shape)):
holdout_src_shape = list(src_shape)
holdout_src_shape[i] = 0
holdout_src_shape = tuple(holdout_src_shape)
net = mx.sym.Variable('data')
net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
input_shape, output_shape, __ = net.infer_shape(data=holdout_src_shape)
assert output_shape[0] == dst_shape, \
'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
str(dst_shape), str(output_shape[0]))
assert input_shape[0] == src_shape, \
'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
str(dst_shape), str(output_shape[0]))

# Test new api (Using shape)
test_cases = [
[(2, 3, 5, 5), (0, -1), False, (2, 75)],
Expand Down Expand Up @@ -6614,7 +6631,7 @@ def test_diag():
w = np.random.randint(2,9)
a_np = np.random.random((h, w)).astype(np.float32)
a = mx.nd.array(a_np).astype('float32')

# k == 0
r = mx.nd.diag(a)
assert_almost_equal(r.asnumpy(), np.diag(a_np))
Expand Down Expand Up @@ -6657,7 +6674,7 @@ def test_diag():
d = np.random.randint(2,9)
a_np = np.random.random((d))
a = mx.nd.array(a_np)

# k is random
k = np.random.randint(-d,d)
r = mx.nd.diag(a, k=k)
Expand Down Expand Up @@ -6724,7 +6741,7 @@ def test_invalid_block_size():
invalid_shape_inp = (n , c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)

test_invalid_depth_dim()
test_invalid_space_dim()
test_invalid_block_size()
Expand Down Expand Up @@ -6770,12 +6787,12 @@ def test_invalid_block_size():
invalid_shape_inp = (n, c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

def test_invalid_depth_dim():
invalid_shape_inp = (n, 0, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

test_invalid_space_dim()
test_invalid_block_size()
test_invalid_depth_dim()
Expand Down