Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c68dffe

Browse files
authored
Add interleaved_matmul_* to npx namespace (#20375)
1 parent 0716dc0 commit c68dffe

File tree

6 files changed

+1442
-1049
lines changed

6 files changed

+1442
-1049
lines changed

src/operator/contrib/transformer.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ void BackwardInterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs,
648648
}
649649

650650
NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
651+
.add_alias("_npx_interleaved_matmul_selfatt_qk")
651652
.describe(R"code(Compute the matrix multiplication between the projections of
652653
queries and keys in multihead attention used as self attention.
653654
@@ -691,6 +692,7 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
691692
.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttQKCPU);
692693

693694
NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt)
695+
.add_alias("_npx_interleaved_matmul_selfatt_valatt")
694696
.describe(R"code(Compute the matrix multiplication between the projections of
695697
values and the attention weights in multihead attention used as self attention.
696698
@@ -738,6 +740,7 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
738740
.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttValAttCPU);
739741

740742
NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk)
743+
.add_alias("_npx_interleaved_matmul_encdec_qk")
741744
.describe(R"code(Compute the matrix multiplication between the projections of
742745
queries and keys in multihead attention used as encoder-decoder attention.
743746
@@ -784,6 +787,7 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
784787
.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulEncDecQKCPU);
785788

786789
NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt)
790+
.add_alias("_npx_interleaved_matmul_encdec_valatt")
787791
.describe(R"code(Compute the matrix multiplication between the projections of
788792
values and the attention weights in multihead attention used as encoder-decoder attention.
789793

src/operator/tensor/elemwise_sum.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ NNVM_REGISTER_OP(add_n)
139139
MXNET_ADD_SPARSE_OP_ALIAS(add_n)
140140
MXNET_ADD_SPARSE_OP_ALIAS(ElementWiseSum)
141141
.add_alias("ElementWiseSum")
142+
.add_alias("_npx_add_n")
142143
.describe(R"doc(Adds all input arguments element-wise.
143144
144145
.. math::

src/operator/tensor/matrix_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ NNVM_REGISTER_OP(_backward_slice_axis)
612612
.set_attr<FCompute>("FCompute<cpu>", SliceAxisGrad_<cpu>);
613613

614614
NNVM_REGISTER_OP(slice_like)
615+
.add_alias("_npx_slice_like")
615616
.describe(R"code(Slices a region of the array like the shape of another array.
616617
This function is similar to ``slice``, however, the `begin` are always `0`s
617618
and `end` of specific axes are inferred from the second input `shape_like`.

tests/python/gpu/test_numpy_op.py

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616

1717
import sys
18-
import numpy as _np
18+
import numpy as onp
1919
import pytest
2020
import mxnet as mx
2121
from mxnet import np
@@ -45,87 +45,87 @@ def dbg(name, data):
4545
print('{} = {}'.format(name, data))
4646

4747
configs = [
48-
('ii', [(5, 5)], lambda *args: (_np.eye(5),)),
49-
('ii->i', [(5, 5)], lambda *args: (_np.eye(5),)),
50-
('ij->i', [(5, 5)], lambda *args: (_np.ones((5, 5)),)),
51-
('...j->...', [(5, 5)], lambda *args: (_np.ones((5, 5)),)),
52-
('ji', [(2, 3)], lambda *args: (_np.ones((2, 3)),)),
53-
('ij->ji', [(2, 3)], lambda *args: (_np.ones((2, 3)),)),
54-
('ij, jk', [(5, 0), (0, 4)], lambda *args: (_np.empty((5, 0)), _np.empty((0, 4)))),
48+
('ii', [(5, 5)], lambda *args: (onp.eye(5),)),
49+
('ii->i', [(5, 5)], lambda *args: (onp.eye(5),)),
50+
('ij->i', [(5, 5)], lambda *args: (onp.ones((5, 5)),)),
51+
('...j->...', [(5, 5)], lambda *args: (onp.ones((5, 5)),)),
52+
('ji', [(2, 3)], lambda *args: (onp.ones((2, 3)),)),
53+
('ij->ji', [(2, 3)], lambda *args: (onp.ones((2, 3)),)),
54+
('ij, jk', [(5, 0), (0, 4)], lambda *args: (onp.empty((5, 0)), onp.empty((0, 4)))),
5555

5656
('i, i', [(5,), (5,)], lambda *args: (args[1], args[0])),
57-
('ij, j', [(5, 5), (5,)], lambda *args: (_np.tile(args[1][None, :], [5, 1]),
57+
('ij, j', [(5, 5), (5,)], lambda *args: (onp.tile(args[1][None, :], [5, 1]),
5858
args[0].sum(axis=0))),
59-
('...j, j', [(5, 5), (5,)], lambda *args: (_np.tile(args[1][None, :], [5, 1]),
60-
_np.sum(args[0], axis=0))),
61-
('..., ...', [(), (2, 3)], lambda *args: (_np.sum(args[1], axis=None),
62-
args[0] * _np.ones((2, 3)))),
63-
(', ij', [(), (2, 3)], lambda *args: (_np.sum(args[1], axis=None),
64-
args[0] * _np.ones((2, 3)))),
65-
('i, j', [(2,), (5, )], lambda *args: (_np.sum(args[1], axis=None) * _np.ones(2),
66-
_np.sum(args[0], axis=None) * _np.ones(5))),
67-
('ijk, jil->kl', [(3, 4, 5), (4, 3, 2)], lambda *args: (_np.tile(_np.transpose(_np.sum(args[1],
59+
('...j, j', [(5, 5), (5,)], lambda *args: (onp.tile(args[1][None, :], [5, 1]),
60+
onp.sum(args[0], axis=0))),
61+
('..., ...', [(), (2, 3)], lambda *args: (onp.sum(args[1], axis=None),
62+
args[0] * onp.ones((2, 3)))),
63+
(', ij', [(), (2, 3)], lambda *args: (onp.sum(args[1], axis=None),
64+
args[0] * onp.ones((2, 3)))),
65+
('i, j', [(2,), (5, )], lambda *args: (onp.sum(args[1], axis=None) * onp.ones(2),
66+
onp.sum(args[0], axis=None) * onp.ones(5))),
67+
('ijk, jil->kl', [(3, 4, 5), (4, 3, 2)], lambda *args: (onp.tile(onp.transpose(onp.sum(args[1],
6868
axis=-1))[:, :, None], [1, 1, 5]),
69-
_np.tile(_np.transpose(_np.sum(args[0],
69+
onp.tile(onp.transpose(onp.sum(args[0],
7070
axis=-1))[:, :, None], [1, 1, 2]))),
71-
('ijk, jil->kl', [(33, 44, 55), (44, 33, 22)], lambda *args: (_np.tile(_np.transpose(_np.sum(args[1],
71+
('ijk, jil->kl', [(33, 44, 55), (44, 33, 22)], lambda *args: (onp.tile(onp.transpose(onp.sum(args[1],
7272
axis=-1))[:, :, None], [1, 1, 55]),
73-
_np.tile(_np.transpose(_np.sum(args[0],
73+
onp.tile(onp.transpose(onp.sum(args[0],
7474
axis=-1))[:, :, None], [1, 1, 22]))),
75-
('ki, jk->ij', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
76-
_np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
77-
('ki, ...k->i...', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
78-
_np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
79-
('k..., jk', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
80-
_np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
75+
('ki, jk->ij', [(3, 2), (4, 3)], lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
76+
onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
77+
('ki, ...k->i...', [(3, 2), (4, 3)], lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
78+
onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
79+
('k..., jk', [(3, 2), (4, 3)], lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
80+
onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
8181
(('ij,jk'), [(2, 5), (5, 2)],
82-
lambda *args: (_np.dot(_np.ones((2, 2)), args[1].T),
83-
_np.dot(args[0].T, _np.ones((2, 2))))),
82+
lambda *args: (onp.dot(onp.ones((2, 2)), args[1].T),
83+
onp.dot(args[0].T, onp.ones((2, 2))))),
8484
(('ij,jk,kl'), [(2, 2), (2, 5), (5, 2)],
85-
lambda *args: (_np.dot(_np.ones((2, 2)), _np.dot(args[1], args[2]).T),
86-
_np.dot(args[0].T, _np.dot(_np.ones((2, 2)), args[2].T)),
87-
_np.dot(_np.dot(args[0], args[1]).T, _np.ones((2, 2))))),
85+
lambda *args: (onp.dot(onp.ones((2, 2)), onp.dot(args[1], args[2]).T),
86+
onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
87+
onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
8888
(('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)],
89-
lambda *args: (_np.dot(_np.ones((2, 2)), _np.dot(args[1], args[2]).T),
90-
_np.dot(args[0].T, _np.dot(_np.ones((2, 2)), args[2].T)),
91-
_np.dot(_np.dot(args[0], args[1]).T, _np.ones((2, 2))))),
89+
lambda *args: (onp.dot(onp.ones((2, 2)), onp.dot(args[1], args[2]).T),
90+
onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
91+
onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
9292
(('ij,jk,kl->il'), [(67, 89), (89, 55), (55, 99)],
93-
lambda *args: (_np.dot(_np.ones((67, 99)), _np.dot(args[1], args[2]).T),
94-
_np.dot(args[0].T, _np.dot(_np.ones((67, 99)), args[2].T)),
95-
_np.dot(_np.dot(args[0], args[1]).T, _np.ones((67, 99))))),
93+
lambda *args: (onp.dot(onp.ones((67, 99)), onp.dot(args[1], args[2]).T),
94+
onp.dot(args[0].T, onp.dot(onp.ones((67, 99)), args[2].T)),
95+
onp.dot(onp.dot(args[0], args[1]).T, onp.ones((67, 99))))),
9696
(('ij,jk,kl, lm->im'), [(12, 54), (54, 32), (32, 45), (45, 67)],
97-
lambda *args: (_np.dot(_np.ones((12, 67)), _np.dot(args[1], _np.dot(args[2], args[3])).T),
98-
_np.dot(args[0].T, _np.dot(_np.ones((12, 67)), _np.dot(args[2], args[3]).T)),
99-
_np.dot(_np.dot(args[0], args[1]).T, _np.dot(_np.ones((12, 67)), args[3].T)),
100-
_np.dot(_np.dot(args[0], _np.dot(args[1], args[2])).T, _np.ones((12, 67))))),
97+
lambda *args: (onp.dot(onp.ones((12, 67)), onp.dot(args[1], onp.dot(args[2], args[3])).T),
98+
onp.dot(args[0].T, onp.dot(onp.ones((12, 67)), onp.dot(args[2], args[3]).T)),
99+
onp.dot(onp.dot(args[0], args[1]).T, onp.dot(onp.ones((12, 67)), args[3].T)),
100+
onp.dot(onp.dot(args[0], onp.dot(args[1], args[2])).T, onp.ones((12, 67))))),
101101

102102
# broadcast axis
103-
('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :],
104-
_np.tile(args[0], [2, 1]))),
103+
('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (onp.sum(args[1], axis=0)[None, :],
104+
onp.tile(args[0], [2, 1]))),
105105
('...ij, ...jk -> ...ik', [(1, 4), (4, 2)], lambda *args: (args[1].sum(axis=1)[None, :],
106-
_np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
107-
('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args: (_np.tile(args[1].sum(axis=1)[None, :], [2, 1]),
108-
_np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
106+
onp.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
107+
('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args: (onp.tile(args[1].sum(axis=1)[None, :], [2, 1]),
108+
onp.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
109109
('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)], lambda *args: (
110110
args[1].sum(axis=3)[:, :, None, :],
111-
_np.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
111+
onp.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
112112
('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)], lambda *args: (
113-
_np.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]),
114-
_np.tile(args[0].sum(axis=3)[:, :, : ,None], [1, 1, 1, 3]))),
113+
onp.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]),
114+
onp.tile(args[0].sum(axis=3)[:, :, : ,None], [1, 1, 1, 3]))),
115115
('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)], lambda *args: (
116-
_np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
117-
_np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
116+
onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
117+
onp.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
118118
('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)], lambda *args: (
119-
_np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
120-
_np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
119+
onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
120+
onp.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
121121
('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)], lambda *args: (
122-
_np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
123-
_np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
122+
onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
123+
onp.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
124124
# test with cuTensor using workspace
125125
(('ij,jk,kl->il'), [(64, 200), (200, 64), (64, 64)],
126-
lambda *args: (_np.dot(_np.ones((64, 64)), _np.dot(args[1], args[2]).T),
127-
_np.dot(args[0].T, _np.dot(_np.ones((64, 64)), args[2].T)),
128-
_np.dot(_np.dot(args[0], args[1]).T, _np.ones((64, 64)))))
126+
lambda *args: (onp.dot(onp.ones((64, 64)), onp.dot(args[1], args[2]).T),
127+
onp.dot(args[0].T, onp.dot(onp.ones((64, 64)), args[2].T)),
128+
onp.dot(onp.dot(args[0], args[1]).T, onp.ones((64, 64)))))
129129
]
130130

131131
dtypes = ['float16', 'float32', 'float64', 'int32']
@@ -144,11 +144,11 @@ def dbg(name, data):
144144
x = []
145145
x_np = []
146146
for shape in operands:
147-
tmp = _np.array(_np.random.uniform(-0.3, 0.3, shape), dtype=dtype)
147+
tmp = onp.array(onp.random.uniform(-0.3, 0.3, shape), dtype=dtype)
148148
x_np.append(tmp)
149149
x.append(np.array(tmp, dtype=dtype))
150150
x[-1].attach_grad()
151-
expected_np = _np.einsum(subscripts, *x_np, optimize=False, dtype=dtype).astype(dtype)
151+
expected_np = onp.einsum(subscripts, *x_np, optimize=False, dtype=dtype).astype(dtype)
152152
with mx.autograd.record():
153153
out_mx = test_einsum(*x)
154154
assert out_mx.shape == expected_np.shape

0 commit comments

Comments
 (0)