2323import pytest
2424import tempfile
2525
26- def op_export_test (op_name , Model , inputs , tmp_path ):
27- def export_to_onnx (model , op_name , inputs ):
28- model_path = '{}/{}' .format (tmp_path , op_name )
26+ def def_model (op_name , ** params ):
27+ class Model (HybridBlock ):
28+ def __init__ (self , ** kwargs ):
29+ super (Model , self ).__init__ (** kwargs )
30+
31+ def hybrid_forward (self , F , * inputs ):
32+ names = op_name .split ('.' )
33+ func = F
34+ for name in names :
35+ func = getattr (func , name )
36+ out = func (* inputs , ** params )
37+ return out
38+ return Model
39+
40+ def op_export_test (model_name , Model , inputs , tmp_path ):
41+ def export_to_onnx (model , model_name , inputs ):
42+ model_path = '{}/{}' .format (tmp_path , model_name )
2943 model .export (model_path , epoch = 0 )
3044 sym_file = '{}-symbol.json' .format (model_path )
3145 params_file = '{}-0000.params' .format (model_path )
3246 dtype = inputs [0 ].dtype
33- onnx_file = '{}/{}.onnx' .format (tmp_path , op_name )
34- mx .contrib .onnx .export_model (sym_file , params_file , [i .shape for i in inputs ],
47+ onnx_file = '{}/{}.onnx' .format (tmp_path , model_name )
48+ mx .contrib .onnx .export_model (sym_file , params_file , [inp .shape for inp in inputs ],
3549 dtype , onnx_file )
3650 return onnx_file
51+
3752 def onnx_rt (onnx_file , inputs ):
3853 sess = rt .InferenceSession (onnx_file )
3954 input_dict = dict ((sess .get_inputs ()[i ].name , inputs [i ].asnumpy ()) for i in range (len (inputs )))
@@ -45,90 +60,47 @@ def onnx_rt(onnx_file, inputs):
4560 model .initialize (ctx = mx .cpu (0 ))
4661 model .hybridize ()
4762 pred_nat = model (* inputs )
48- onnx_file = export_to_onnx (model , op_name , inputs )
63+ onnx_file = export_to_onnx (model , model_name , inputs )
4964 pred_onx = onnx_rt (onnx_file , inputs )
5065 assert_almost_equal (pred_nat , pred_onx )
5166
5267
53- def test_onnx_export_abs ():
54- with tempfile .TemporaryDirectory () as tmp_path :
55- class Model (HybridBlock ):
56- def __init__ (self , ** kwargs ):
57- super (Model , self ).__init__ (** kwargs )
58- def hybrid_forward (self , F , x ):
59- out = F .abs (x )
60- return out
61- x = mx .nd .array ([[- 2 , - 1 ], [0 , 99 ]], dtype = 'float32' )
62- op_export_test ('abs' , Model , [x ], tmp_path )
63-
64- def test_onnx_export_slice ():
65- with tempfile .TemporaryDirectory () as tmp_path :
66- class Model (HybridBlock ):
67- def __init__ (self , ** kwargs ):
68- super (Model , self ).__init__ (** kwargs )
69- def hybrid_forward (self , F , x ):
70- out = F .slice (x , begin = (0 ,1 ), end = (2 ,4 ))
71- return out
72- x = mx .nd .array ([[1 ,2 ,3 ,4 ],[5 ,6 ,7 ,8 ],[9 ,10 ,11 ,12 ]], dtype = 'float32' )
73- op_export_test ('slice' , Model , [x ], tmp_path )
74-
75- def test_onnx_export_stack ():
76- with tempfile .TemporaryDirectory () as tmp_path :
77- dtype = 'float32'
78- class Model (HybridBlock ):
79- def __init__ (self , ** kwargs ):
80- super (Model , self ).__init__ (** kwargs )
81- def hybrid_forward (self , F , x , y ):
82- out = F .stack (x , y )
83- return out
84- x = mx .nd .array ([1 , 2 ], dtype = dtype )
85- y = mx .nd .array ([3 , 4 ], dtype = dtype )
86- op_export_test ('stack' , Model , [x , y ], tmp_path )
87-
88- def test_onnx_export_zeros_like ():
89- with tempfile .TemporaryDirectory () as tmp_path :
90- class Model (HybridBlock ):
91- def __init__ (self , ** kwargs ):
92- super (Model , self ).__init__ (** kwargs )
93- def hybrid_forward (self , F , x ):
94- out = F .zeros_like (x )
95- return out
96- x = mx .nd .array ([[- 2 ,- 1 ,0 ],[0 ,50 ,99 ],[4 ,5 ,6 ],[7 ,8 ,9 ]], dtype = 'float32' )
97- op_export_test ('zeros_like' , Model , [x ], tmp_path )
68+ def test_onnx_export_abs (tmp_path ):
69+ M = def_model ('abs' )
70+ x = mx .nd .array ([[- 2 , - 1 ], [0 , 99 ]], dtype = 'float32' )
71+ op_export_test ('abs' , M , [x ], tmp_path )
72+
73+
74+ def test_onnx_export_slice (tmp_path ):
75+ M = def_model ('slice' , begin = (0 ,1 ), end = (2 ,4 ))
76+ x = mx .nd .array ([[1 ,2 ,3 ,4 ],[5 ,6 ,7 ,8 ],[9 ,10 ,11 ,12 ]], dtype = 'float32' )
77+ op_export_test ('slice' , M , [x ], tmp_path )
78+
79+
80+ def test_onnx_export_stack (tmp_path ):
81+ M = def_model ('stack' )
82+ x = mx .nd .array ([1 , 2 ], dtype = 'float32' )
83+ y = mx .nd .array ([3 , 4 ], dtype = 'float32' )
84+ op_export_test ('stack' , M , [x , y ], tmp_path )
85+
86+
87+ def test_onnx_export_zeros_like (tmp_path ):
88+ M = def_model ('zeros_like' )
89+ x = mx .nd .array ([[- 2 ,- 1 ,0 ],[0 ,50 ,99 ],[4 ,5 ,6 ],[7 ,8 ,9 ]], dtype = 'float32' )
90+ op_export_test ('zeros_like' , M , [x ], tmp_path )
91+
9892
9993@pytest .mark .parametrize ("dtype" , ["float32" , "double" ])
100- def test_onnx_export_arange_like (dtype ):
101- with tempfile .TemporaryDirectory () as tmp_path :
102- class Model (HybridBlock ):
103- def __init__ (self , ** kwargs ):
104- super (Model , self ).__init__ (** kwargs )
105- def hybrid_forward (self , F , x ):
106- out = F .contrib .arange_like (x )
107- return out
108- x = mx .nd .array ([[- 2 ,- 1 ,0 ],[0 ,50 ,99 ],[4 ,5 ,6 ],[7 ,8 ,9 ]], dtype = dtype )
109- op_export_test ('arange_like' , Model , [x ], tmp_path )
110-
111- def test_onnx_export_layernorm ():
112- with tempfile .TemporaryDirectory () as tmp_path :
113- dtype = 'float32'
114- class Model (HybridBlock ):
115- def __init__ (self , ** kwargs ):
116- super (Model , self ).__init__ (** kwargs )
117- def hybrid_forward (self , F , x , gamma , beta ):
118- out = F .LayerNorm (x , gamma , beta , axis = 1 )
119- return out
120- x = mx .nd .array ([[1 ,3 ],[2 ,4 ]], dtype = dtype )
121- gamma = mx .random .uniform (0 , 1 , x [0 ].shape ).astype (dtype )
122- beta = mx .random .uniform (0 , 1 , x [0 ].shape ).astype (dtype )
123- op_export_test ('LayerNorm' , Model , [x , gamma , beta ], tmp_path )
124-
125-
126- if __name__ == '__main__' :
127- test_onnx_export_abs ()
128- test_onnx_export_slice ()
129- test_onnx_export_stack ()
130- test_onnx_export_zeros_like ()
131- test_onnx_export_arange_like ('float32' )
132- test_onnx_export_arange_like ('double' )
133- test_onnx_export_layernorm ()
94+ def test_onnx_export_arange_like (tmp_path , dtype ):
95+ M = def_model ('contrib.arange_like' )
96+ x = mx .nd .array ([[- 2 ,- 1 ,0 ],[0 ,50 ,99 ],[4 ,5 ,6 ],[7 ,8 ,9 ]], dtype = dtype )
97+ op_export_test ('arange_like' , M , [x ], tmp_path )
98+
99+
100+ def test_onnx_export_layernorm (tmp_path ):
101+ M = def_model ('LayerNorm' , axis = 1 )
102+ x = mx .nd .array ([[1 ,3 ],[2 ,4 ]], dtype = 'float32' )
103+ gamma = mx .random .uniform (0 , 1 , x [0 ].shape , dtype = 'float32' )
104+ beta = mx .random .uniform (0 , 1 , x [0 ].shape , dtype = 'float32' )
105+ op_export_test ('LayerNorm' , M , [x , gamma , beta ], tmp_path )
134106
0 commit comments