1515# specific language governing permissions and limitations
1616
1717import sys
18- import numpy as _np
18+ import numpy as onp
1919import pytest
2020import mxnet as mx
2121from mxnet import np
@@ -45,87 +45,87 @@ def dbg(name, data):
4545 print ('{} = {}' .format (name , data ))
4646
4747 configs = [
48- ('ii' , [(5 , 5 )], lambda * args : (_np .eye (5 ),)),
49- ('ii->i' , [(5 , 5 )], lambda * args : (_np .eye (5 ),)),
50- ('ij->i' , [(5 , 5 )], lambda * args : (_np .ones ((5 , 5 )),)),
51- ('...j->...' , [(5 , 5 )], lambda * args : (_np .ones ((5 , 5 )),)),
52- ('ji' , [(2 , 3 )], lambda * args : (_np .ones ((2 , 3 )),)),
53- ('ij->ji' , [(2 , 3 )], lambda * args : (_np .ones ((2 , 3 )),)),
54- ('ij, jk' , [(5 , 0 ), (0 , 4 )], lambda * args : (_np .empty ((5 , 0 )), _np .empty ((0 , 4 )))),
48+ ('ii' , [(5 , 5 )], lambda * args : (onp .eye (5 ),)),
49+ ('ii->i' , [(5 , 5 )], lambda * args : (onp .eye (5 ),)),
50+ ('ij->i' , [(5 , 5 )], lambda * args : (onp .ones ((5 , 5 )),)),
51+ ('...j->...' , [(5 , 5 )], lambda * args : (onp .ones ((5 , 5 )),)),
52+ ('ji' , [(2 , 3 )], lambda * args : (onp .ones ((2 , 3 )),)),
53+ ('ij->ji' , [(2 , 3 )], lambda * args : (onp .ones ((2 , 3 )),)),
54+ ('ij, jk' , [(5 , 0 ), (0 , 4 )], lambda * args : (onp .empty ((5 , 0 )), onp .empty ((0 , 4 )))),
5555
5656 ('i, i' , [(5 ,), (5 ,)], lambda * args : (args [1 ], args [0 ])),
57- ('ij, j' , [(5 , 5 ), (5 ,)], lambda * args : (_np .tile (args [1 ][None , :], [5 , 1 ]),
57+ ('ij, j' , [(5 , 5 ), (5 ,)], lambda * args : (onp .tile (args [1 ][None , :], [5 , 1 ]),
5858 args [0 ].sum (axis = 0 ))),
59- ('...j, j' , [(5 , 5 ), (5 ,)], lambda * args : (_np .tile (args [1 ][None , :], [5 , 1 ]),
60- _np .sum (args [0 ], axis = 0 ))),
61- ('..., ...' , [(), (2 , 3 )], lambda * args : (_np .sum (args [1 ], axis = None ),
62- args [0 ] * _np .ones ((2 , 3 )))),
63- (', ij' , [(), (2 , 3 )], lambda * args : (_np .sum (args [1 ], axis = None ),
64- args [0 ] * _np .ones ((2 , 3 )))),
65- ('i, j' , [(2 ,), (5 , )], lambda * args : (_np .sum (args [1 ], axis = None ) * _np .ones (2 ),
66- _np .sum (args [0 ], axis = None ) * _np .ones (5 ))),
67- ('ijk, jil->kl' , [(3 , 4 , 5 ), (4 , 3 , 2 )], lambda * args : (_np .tile (_np .transpose (_np .sum (args [1 ],
59+ ('...j, j' , [(5 , 5 ), (5 ,)], lambda * args : (onp .tile (args [1 ][None , :], [5 , 1 ]),
60+ onp .sum (args [0 ], axis = 0 ))),
61+ ('..., ...' , [(), (2 , 3 )], lambda * args : (onp .sum (args [1 ], axis = None ),
62+ args [0 ] * onp .ones ((2 , 3 )))),
63+ (', ij' , [(), (2 , 3 )], lambda * args : (onp .sum (args [1 ], axis = None ),
64+ args [0 ] * onp .ones ((2 , 3 )))),
65+ ('i, j' , [(2 ,), (5 , )], lambda * args : (onp .sum (args [1 ], axis = None ) * onp .ones (2 ),
66+ onp .sum (args [0 ], axis = None ) * onp .ones (5 ))),
67+ ('ijk, jil->kl' , [(3 , 4 , 5 ), (4 , 3 , 2 )], lambda * args : (onp .tile (onp .transpose (onp .sum (args [1 ],
6868 axis = - 1 ))[:, :, None ], [1 , 1 , 5 ]),
69- _np .tile (_np .transpose (_np .sum (args [0 ],
69+ onp .tile (onp .transpose (onp .sum (args [0 ],
7070 axis = - 1 ))[:, :, None ], [1 , 1 , 2 ]))),
71- ('ijk, jil->kl' , [(33 , 44 , 55 ), (44 , 33 , 22 )], lambda * args : (_np .tile (_np .transpose (_np .sum (args [1 ],
71+ ('ijk, jil->kl' , [(33 , 44 , 55 ), (44 , 33 , 22 )], lambda * args : (onp .tile (onp .transpose (onp .sum (args [1 ],
7272 axis = - 1 ))[:, :, None ], [1 , 1 , 55 ]),
73- _np .tile (_np .transpose (_np .sum (args [0 ],
73+ onp .tile (onp .transpose (onp .sum (args [0 ],
7474 axis = - 1 ))[:, :, None ], [1 , 1 , 22 ]))),
75- ('ki, jk->ij' , [(3 , 2 ), (4 , 3 )], lambda * args : (_np .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
76- _np .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
77- ('ki, ...k->i...' , [(3 , 2 ), (4 , 3 )], lambda * args : (_np .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
78- _np .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
79- ('k..., jk' , [(3 , 2 ), (4 , 3 )], lambda * args : (_np .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
80- _np .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
75+ ('ki, jk->ij' , [(3 , 2 ), (4 , 3 )], lambda * args : (onp .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
76+ onp .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
77+ ('ki, ...k->i...' , [(3 , 2 ), (4 , 3 )], lambda * args : (onp .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
78+ onp .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
79+ ('k..., jk' , [(3 , 2 ), (4 , 3 )], lambda * args : (onp .tile (args [1 ].sum (axis = 0 )[:, None ], [1 , 2 ]),
80+ onp .tile (args [0 ].sum (axis = 1 )[None , :], [4 , 1 ]))),
8181 (('ij,jk' ), [(2 , 5 ), (5 , 2 )],
82- lambda * args : (_np .dot (_np .ones ((2 , 2 )), args [1 ].T ),
83- _np .dot (args [0 ].T , _np .ones ((2 , 2 ))))),
82+ lambda * args : (onp .dot (onp .ones ((2 , 2 )), args [1 ].T ),
83+ onp .dot (args [0 ].T , onp .ones ((2 , 2 ))))),
8484 (('ij,jk,kl' ), [(2 , 2 ), (2 , 5 ), (5 , 2 )],
85- lambda * args : (_np .dot (_np .ones ((2 , 2 )), _np .dot (args [1 ], args [2 ]).T ),
86- _np .dot (args [0 ].T , _np .dot (_np .ones ((2 , 2 )), args [2 ].T )),
87- _np .dot (_np .dot (args [0 ], args [1 ]).T , _np .ones ((2 , 2 ))))),
85+ lambda * args : (onp .dot (onp .ones ((2 , 2 )), onp .dot (args [1 ], args [2 ]).T ),
86+ onp .dot (args [0 ].T , onp .dot (onp .ones ((2 , 2 )), args [2 ].T )),
87+ onp .dot (onp .dot (args [0 ], args [1 ]).T , onp .ones ((2 , 2 ))))),
8888 (('ij,jk,kl->il' ), [(2 , 2 ), (2 , 5 ), (5 , 2 )],
89- lambda * args : (_np .dot (_np .ones ((2 , 2 )), _np .dot (args [1 ], args [2 ]).T ),
90- _np .dot (args [0 ].T , _np .dot (_np .ones ((2 , 2 )), args [2 ].T )),
91- _np .dot (_np .dot (args [0 ], args [1 ]).T , _np .ones ((2 , 2 ))))),
89+ lambda * args : (onp .dot (onp .ones ((2 , 2 )), onp .dot (args [1 ], args [2 ]).T ),
90+ onp .dot (args [0 ].T , onp .dot (onp .ones ((2 , 2 )), args [2 ].T )),
91+ onp .dot (onp .dot (args [0 ], args [1 ]).T , onp .ones ((2 , 2 ))))),
9292 (('ij,jk,kl->il' ), [(67 , 89 ), (89 , 55 ), (55 , 99 )],
93- lambda * args : (_np .dot (_np .ones ((67 , 99 )), _np .dot (args [1 ], args [2 ]).T ),
94- _np .dot (args [0 ].T , _np .dot (_np .ones ((67 , 99 )), args [2 ].T )),
95- _np .dot (_np .dot (args [0 ], args [1 ]).T , _np .ones ((67 , 99 ))))),
93+ lambda * args : (onp .dot (onp .ones ((67 , 99 )), onp .dot (args [1 ], args [2 ]).T ),
94+ onp .dot (args [0 ].T , onp .dot (onp .ones ((67 , 99 )), args [2 ].T )),
95+ onp .dot (onp .dot (args [0 ], args [1 ]).T , onp .ones ((67 , 99 ))))),
9696 (('ij,jk,kl, lm->im' ), [(12 , 54 ), (54 , 32 ), (32 , 45 ), (45 , 67 )],
97- lambda * args : (_np .dot (_np .ones ((12 , 67 )), _np .dot (args [1 ], _np .dot (args [2 ], args [3 ])).T ),
98- _np .dot (args [0 ].T , _np .dot (_np .ones ((12 , 67 )), _np .dot (args [2 ], args [3 ]).T )),
99- _np .dot (_np .dot (args [0 ], args [1 ]).T , _np .dot (_np .ones ((12 , 67 )), args [3 ].T )),
100- _np .dot (_np .dot (args [0 ], _np .dot (args [1 ], args [2 ])).T , _np .ones ((12 , 67 ))))),
97+ lambda * args : (onp .dot (onp .ones ((12 , 67 )), onp .dot (args [1 ], onp .dot (args [2 ], args [3 ])).T ),
98+ onp .dot (args [0 ].T , onp .dot (onp .ones ((12 , 67 )), onp .dot (args [2 ], args [3 ]).T )),
99+ onp .dot (onp .dot (args [0 ], args [1 ]).T , onp .dot (onp .ones ((12 , 67 )), args [3 ].T )),
100+ onp .dot (onp .dot (args [0 ], onp .dot (args [1 ], args [2 ])).T , onp .ones ((12 , 67 ))))),
101101
102102 # broadcast axis
103- ('ij, ij -> i' , [(1 , 4 ), (2 , 4 )], lambda * args : (_np .sum (args [1 ], axis = 0 )[None , :],
104- _np .tile (args [0 ], [2 , 1 ]))),
103+ ('ij, ij -> i' , [(1 , 4 ), (2 , 4 )], lambda * args : (onp .sum (args [1 ], axis = 0 )[None , :],
104+ onp .tile (args [0 ], [2 , 1 ]))),
105105 ('...ij, ...jk -> ...ik' , [(1 , 4 ), (4 , 2 )], lambda * args : (args [1 ].sum (axis = 1 )[None , :],
106- _np .tile (args [0 ].sum (axis = 0 )[: ,None ], [1 , 2 ]))),
107- ('...ij, ...jk -> ...ik' , [(2 , 4 ), (4 , 2 )], lambda * args : (_np .tile (args [1 ].sum (axis = 1 )[None , :], [2 , 1 ]),
108- _np .tile (args [0 ].sum (axis = 0 )[: ,None ], [1 , 2 ]))),
106+ onp .tile (args [0 ].sum (axis = 0 )[: ,None ], [1 , 2 ]))),
107+ ('...ij, ...jk -> ...ik' , [(2 , 4 ), (4 , 2 )], lambda * args : (onp .tile (args [1 ].sum (axis = 1 )[None , :], [2 , 1 ]),
108+ onp .tile (args [0 ].sum (axis = 0 )[: ,None ], [1 , 2 ]))),
109109 ('...ij, ...jk -> ...ik' , [(3 , 2 , 1 , 4 ), (3 , 2 , 4 , 2 )], lambda * args : (
110110 args [1 ].sum (axis = 3 )[:, :, None , :],
111- _np .tile (args [0 ].sum (axis = 2 )[:, :, :, None ], [1 , 1 , 1 , 2 ]))),
111+ onp .tile (args [0 ].sum (axis = 2 )[:, :, :, None ], [1 , 1 , 1 , 2 ]))),
112112 ('...ij, ...ik -> ...jk' , [(1 , 1 , 1 , 4 ), (1 , 1 , 1 , 3 )], lambda * args : (
113- _np .tile (args [1 ].sum (axis = 3 )[:, :, :, None ], [1 , 1 , 1 , 4 ]),
114- _np .tile (args [0 ].sum (axis = 3 )[:, :, : ,None ], [1 , 1 , 1 , 3 ]))),
113+ onp .tile (args [1 ].sum (axis = 3 )[:, :, :, None ], [1 , 1 , 1 , 4 ]),
114+ onp .tile (args [0 ].sum (axis = 3 )[:, :, : ,None ], [1 , 1 , 1 , 3 ]))),
115115 ('...ij, ...jc -> ...ic' , [(1 , 1 , 5 , 3 ), (1 , 1 , 3 , 2 )], lambda * args : (
116- _np .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
117- _np .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
116+ onp .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
117+ onp .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
118118 ('...ij, ...jc -> ...ic' , [(1 , 2 , 5 , 4 ), (1 , 2 , 4 , 2 )], lambda * args : (
119- _np .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
120- _np .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
119+ onp .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
120+ onp .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
121121 ('...ij, ...jc -> ...ic' , [(2 , 1 , 5 , 4 ), (2 , 1 , 4 , 2 )], lambda * args : (
122- _np .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
123- _np .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
122+ onp .tile (args [1 ].sum (axis = 3 )[:, :, None , :], [1 , 1 , 5 , 1 ]),
123+ onp .tile (args [0 ].sum (axis = 2 )[:, :, : ,None ], [1 , 1 , 1 , 2 ]))),
124124 # test with cuTensor using workspace
125125 (('ij,jk,kl->il' ), [(64 , 200 ), (200 , 64 ), (64 , 64 )],
126- lambda * args : (_np .dot (_np .ones ((64 , 64 )), _np .dot (args [1 ], args [2 ]).T ),
127- _np .dot (args [0 ].T , _np .dot (_np .ones ((64 , 64 )), args [2 ].T )),
128- _np .dot (_np .dot (args [0 ], args [1 ]).T , _np .ones ((64 , 64 )))))
126+ lambda * args : (onp .dot (onp .ones ((64 , 64 )), onp .dot (args [1 ], args [2 ]).T ),
127+ onp .dot (args [0 ].T , onp .dot (onp .ones ((64 , 64 )), args [2 ].T )),
128+ onp .dot (onp .dot (args [0 ], args [1 ]).T , onp .ones ((64 , 64 )))))
129129 ]
130130
131131 dtypes = ['float16' , 'float32' , 'float64' , 'int32' ]
@@ -144,11 +144,11 @@ def dbg(name, data):
144144 x = []
145145 x_np = []
146146 for shape in operands :
147- tmp = _np .array (_np .random .uniform (- 0.3 , 0.3 , shape ), dtype = dtype )
147+ tmp = onp .array (onp .random .uniform (- 0.3 , 0.3 , shape ), dtype = dtype )
148148 x_np .append (tmp )
149149 x .append (np .array (tmp , dtype = dtype ))
150150 x [- 1 ].attach_grad ()
151- expected_np = _np .einsum (subscripts , * x_np , optimize = False , dtype = dtype ).astype (dtype )
151+ expected_np = onp .einsum (subscripts , * x_np , optimize = False , dtype = dtype ).astype (dtype )
152152 with mx .autograd .record ():
153153 out_mx = test_einsum (* x )
154154 assert out_mx .shape == expected_np .shape
0 commit comments