diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index 167f04c8165f..2f386dd820e8 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -153,6 +153,10 @@ class SequenceLastOp : public Operator { auto d1 = in_data[seq_last::kData].size(1); auto dsize = in_data[seq_last::kData].Size(); + if (dsize == 0) { + return; // noop if any input dimension is zero-sized, out_data is of a right shape + } + auto batch = (axis != 0) ? d0 : d1; auto max_seq_len = in_data[seq_last::kData].size(axis); auto rest_size = dsize / (d0 * d1); diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h index 0934036f23a2..2fc698aab205 100644 --- a/src/operator/sequence_mask-inl.h +++ b/src/operator/sequence_mask-inl.h @@ -97,6 +97,11 @@ class SequenceMaskOp : public Operator { auto d0 = in_data[seq_mask::kData].size(0); auto d1 = in_data[seq_mask::kData].size(1); auto dsize = in_data[seq_mask::kData].Size(); + + if (dsize == 0) { + return; // noop if any input dimension is zero-sized, out_data is of a right shape + } + auto rest_size = dsize / (d0 * d1); Shape<3> s3 = Shape3(d0, d1, rest_size); diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h index 68d596778b4a..e19e643744b8 100644 --- a/src/operator/sequence_reverse-inl.h +++ b/src/operator/sequence_reverse-inl.h @@ -136,6 +136,11 @@ class SequenceReverseOp : public Operator { auto max_seq_len = in_data[seq_reverse::kData].size(0); auto n = in_data[seq_reverse::kData].size(1); auto total_size = in_data[seq_reverse::kData].Size(); + + if (total_size == 0) { + return; // noop if any input dimension is zero-sized, out_data is of a right shape + } + auto rest_dim = static_cast(total_size / n / max_seq_len); Shape<3> s3 = Shape3(max_seq_len, n, rest_dim); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7a8536474211..cc94db5b0efc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9428,3 +9428,28 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads, test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d) test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d) +def test_zero_sized_dim(): + + mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown' + + def seq_last(): + """Test for issue: https://github.com/apache/incubator-mxnet/issues/18938""" + data = mx.nd.array(np.random.rand(1, 0, 0)) + res = mx.nd.op.SequenceLast(data) + assert data.shape[1:] == res.shape + + def seq_mask(): + """Test for issue: https://github.com/apache/incubator-mxnet/issues/18939""" + data = mx.nd.array(np.random.rand(0, 1, 1)) + res = mx.nd.op.SequenceMask(data) + assert data.shape == res.shape + + def seq_reverse(): + """Test for issue: https://github.com/apache/incubator-mxnet/issues/18940""" + data = mx.nd.array(np.random.rand(0, 1, 1)) + res = mx.nd.op.SequenceReverse(data) + assert data.shape == res.shape + + seq_last() + seq_reverse() + seq_mask() \ No newline at end of file