 # under the License.

 import os
+import sys
 import ctypes
 import mxnet as mx
 from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str, mx_real_t
 from mxnet.symbol import Symbol
 import numpy as np
 from mxnet.test_utils import assert_almost_equal
+from mxnet.numpy_extension import get_cuda_compute_capability
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet import nd
 from mxnet.gluon.model_zoo import vision

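+# Make the shared unittest helpers (setup_module, with_seed, teardown) importable.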
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from common import setup_module, with_seed, teardown
+
 ####################################
 ######### FP32/FP16 tests ##########
 ####################################
@@ -60,7 +66,7 @@ def get_baseline(input_data):
     return output


-def check_tensorrt_symbol(baseline, input_data, fp16_mode, tol):
+def check_tensorrt_symbol(baseline, input_data, fp16_mode, rtol=None, atol=None):
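+    # Leaving rtol/atol as None lets assert_almost_equal fall back to its default tolerances.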
     sym, arg_params, aux_params = get_model(batch_shape=input_data.shape)
     trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0),
                                precision='fp16' if fp16_mode else 'fp32')
@@ -69,17 +75,18 @@ def check_tensorrt_symbol(baseline, input_data, fp16_mode, tol):
                                    grad_req='null', force_rebind=True)

     output = executor.forward(is_train=False, data=input_data)
-    assert_almost_equal(output[0].asnumpy(), baseline[0].asnumpy(), atol=tol[0], rtol=tol[1])
+    assert_almost_equal(output[0], baseline[0], rtol=rtol, atol=atol)

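+# with_seed (from common) seeds the test RNGs and logs the seed so failures are reproducible.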
+@with_seed()
 def test_tensorrt_symbol():
     batch_shape = (32, 3, 224, 224)
     input_data = mx.nd.random.uniform(shape=(batch_shape), ctx=mx.gpu(0))
     baseline = get_baseline(input_data)
     print("Testing resnet50 with TensorRT backend numerical accuracy...")
     print("FP32")
-    check_tensorrt_symbol(baseline, input_data, fp16_mode=False, tol=(1e-4, 1e-4))
+    check_tensorrt_symbol(baseline, input_data, fp16_mode=False)
     print("FP16")
-    check_tensorrt_symbol(baseline, input_data, fp16_mode=True, tol=(1e-1, 1e-2))
+    check_tensorrt_symbol(baseline, input_data, fp16_mode=True, rtol=1e-2, atol=1e-1)

 ##############################
 ######### INT8 tests ##########
@@ -135,17 +142,25 @@ def get_top1(logits):


 def test_tensorrt_symbol_int8():
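+    # Skip below compute capability 70 (Volta): this INT8 path targets recent GPUs.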
+    ctx = mx.gpu(0)
+    cuda_arch = get_cuda_compute_capability(ctx)
+    cuda_arch_min = 70
+    if cuda_arch < cuda_arch_min:
+        print('Bypassing test_tensorrt_symbol_int8 on cuda arch {}, need arch >= {}.'.format(
+            cuda_arch, cuda_arch_min))
+        return
+
     # INT8 engine outputs are not lossless, so we don't expect numerical equality;
     # instead we compare the TOP1 metric.

     batch_shape = (1, 3, 224, 224)
     sym, arg_params, aux_params = get_model(batch_shape=batch_shape)
     calibration_iters = 700
-    trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0),
+    trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=ctx,
                                precision='int8',
                                calibration_iters=calibration_iters)

-    executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
+    executor = trt_sym.simple_bind(ctx=ctx, data=batch_shape,
                                    grad_req='null', force_rebind=True)

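+    # DALI iterator supplying validation batches for INT8 calibration and the TOP1 comparison.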
     dali_val_iter = get_dali_iter()