diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 3bd7ac6524ff..0fda3cbb98ef 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -40,6 +40,7 @@ LARGE_SIZE = LARGE_X * SMALL_Y LARGE_TENSOR_SHAPE = 2**32 RNN_LARGE_TENSOR = 2**28 +LARGE_SQ_X = 80000 @pytest.mark.timeout(0) @@ -1755,3 +1756,21 @@ def test_sparse_dot(): assert out.asnumpy()[0][0] == 2 assert out.shape == (2, 2) + +def test_linalg_operators(): + def check_syrk_batch(): + # test both forward and backward + # batch syrk will be applied to the last two dimensions + A = nd.zeros((1, LARGE_SQ_X, LARGE_SQ_X)) + for i in range(LARGE_SQ_X): + A[0,i,i] = 1 + A.attach_grad() + with mx.autograd.record(): + out = nd.linalg.syrk(A, alpha=2, transpose=False) + for i in range(LARGE_SQ_X): + assert out[0,i,i] == 2 + out.backward() + for i in range(LARGE_SQ_X): + assert A.grad[0,0,i] == 4 + + check_syrk_batch()