Merged
23 changes: 23 additions & 0 deletions src/operator/tensor/la_op.h
@@ -157,6 +157,14 @@ struct LaTrianParam : public dmlc::Parameter<LaTrianParam> {
  }
};

// check if any dim will overflow 32-bit int
inline void check_large_dim(const std::vector<dim_t>& dims) {
  for (dim_t dim : dims) {
    CHECK_LE(dim, INT_MAX)
      << "Large matrix dimensions (>= 2^31) are not supported";
  }
}
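// Illustration (not part of this patch): the guard above matters because
// BLAS-style interfaces typically take matrix dimensions as 32-bit ints, so a
// larger dim_t would be silently narrowed to a negative value, e.g.:
//   dim_t big = dim_t(1) << 31;            // 2147483648
//   int narrowed = static_cast<int>(big);  // -2147483648 on two's-complement
// The CHECK_LE fails fast instead of handing such a dimension to the backend.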

// Common function for shape inference for matrix mult and matrix mac.
inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
                                   mxnet::ShapeVector* in_attrs,
@@ -181,6 +189,11 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
  const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
  CHECK(axis >= 0 && axis < ndim-1)
    << "Invalid row axis (" << axis_param << ")";
  // Check if input matrix dims are too large
  check_large_dim({(*in_attrs)[0][axis],
                   (*in_attrs)[0][ndim-1],
                   (*in_attrs)[1][axis],
                   (*in_attrs)[1][ndim-1]});
  std::vector<int> oshape(ndim);
  for ( int i = 0; i < ndim-1; ++i ) {
    if (i != axis) {
@@ -225,6 +238,10 @@ inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
<< "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
oshape[i] = (*in_attrs)[0][i];
}
// Check if the input matrix dims are too large; it suffices to check the second
// input only because the first is square whose size is bounded by memory
check_large_dim({(*in_attrs)[1][ndim-1],
(*in_attrs)[1][ndim-2]});
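  // Worked bound (illustrative, not part of this patch): a square first input
  // with side >= 2^31 would hold >= 2^62 elements, i.e. >= 16 EiB even at
  // 4 bytes per element, so it could never materialize in memory.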
  if ( param.rightside ) {
    // We compute B * A where A is the first and B the second input.
    CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[1][ndim-1])
@@ -341,6 +358,9 @@ inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs,
  bool transpose = nnvm::get<LaSyrkParam>(attrs.parsed).transpose;
  const int ndim = in_attr.ndim();
  if ( ndim >= 2 ) {
    // Check if input matrix dims are too large
    check_large_dim({in_attr[ndim-1],
                     in_attr[ndim-2]});
    // Forward shape inference.
    std::vector<int> oshape(ndim);
    for ( int i = 0; i < ndim-2; ++i ) {
@@ -371,6 +391,9 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
  const int ndim(in_a.ndim());
  CHECK_LE(in_a[ndim-2], in_a[ndim-1])
    << "Input A shape wrong: last dimension must be >= second-to-last dimension";
  // Check if the last dimension is too large; it suffices to check the last
  // dim only, since the second-to-last dim is <= the last dim.
  check_large_dim({in_a[ndim-1]});
  // Q must have same shape as A
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a);
  std::vector<int> oshape_l(ndim);
51 changes: 50 additions & 1 deletion tests/nightly/test_large_vector.py
@@ -27,13 +27,15 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
from tests.python.unittest.common import with_seed, assertRaises
from mxnet.base import MXNetError
from nose.tools import with_setup
import unittest

# dimension constants
LARGE_X = 4300000000
MEDIUM_X = 1000000000
INT32_MAX = 2**31-1
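# For scale (illustrative, not part of the patch): one float32 operand of shape
# (1, INT32_MAX + 1, 1), as used in test_linalg_large_dim below, holds 2**31
# elements, i.e. 2**31 * 4 bytes = 8 GiB of memory.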


def test_nn():
@@ -1064,6 +1066,53 @@ def check_minimum():
    check_minimum()


# OpenBLAS and cuBLAS are known to misbehave when matrix dimensions
# overflow a 32-bit int under the current configuration; checks were
# added so that such use cases fail early with a clear error
def test_linalg_large_dim():
    def check_gemm():
        A = nd.ones(shape=(1, INT32_MAX + 1, 1))
        B = nd.ones(shape=(1, INT32_MAX + 1, 1))
        C = nd.ones(shape=(1, 1, 1))
        assertRaises(MXNetError, nd.linalg.gemm,
                     A, B, C, transpose_b=True)

    def check_gemm2():
        A = nd.ones(shape=(1, 1, INT32_MAX + 1))
        B = nd.ones(shape=(1, 1, INT32_MAX + 1))
        assertRaises(MXNetError, nd.linalg.gemm2,
                     A, B, transpose_b=True)

    def check_trmm():
        A = nd.ones(shape=(1, 1, 1))
        B = nd.ones(shape=(1, INT32_MAX + 1, 1))
        assertRaises(MXNetError, nd.linalg.trmm,
                     A, B, rightside=True)

    def check_trsm():
        A = nd.ones(shape=(1, 1, 1))
        B = nd.ones(shape=(1, 1, INT32_MAX + 1))
        assertRaises(MXNetError, nd.linalg.trsm,
                     A, B, rightside=False)

    def check_syrk():
        A = nd.ones(shape=(1, INT32_MAX + 1, 1))
        assertRaises(MXNetError, nd.linalg.syrk, A)
        assertRaises(MXNetError, nd.linalg.syrk, A, transpose=True)

    def check_gelqf():
        A = nd.ones(shape=(1, 1, INT32_MAX + 1))
        assertRaises(MXNetError, nd.linalg.gelqf, A)

    # batch input
    check_gemm()
    check_gemm2()
    check_trmm()
    check_trsm()
    check_syrk()
    check_gelqf()
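# Usage note (an assumption, not part of the patch): running this test requires
# an MXNet build with large-tensor support enabled; with nose it can be invoked
# on its own as:
#   nosetests tests/nightly/test_large_vector.py:test_linalg_large_dim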


if __name__ == '__main__':
    import nose
    nose.runmodule()