diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index fc4809d7f1cb..96e0544af31d 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -143,6 +143,7 @@ std::tuple,
       builder_config->setInt8Calibrator(calibrator);
     } else {
       LOG(WARNING) << "TensorRT can't use int8 on this platform";
+      calibrator->setDone();
       calibrator = nullptr;
     }
   }
@@ -177,6 +178,7 @@ std::tuple,
       trt_builder->setInt8Calibrator(calibrator);
     } else {
       LOG(WARNING) << "TensorRT can't use int8 on this platform";
+      calibrator->setDone();
       calibrator = nullptr;
     }
   }
diff --git a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc
index 8ba7a3aecb63..d5ee350e5d3f 100644
--- a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc
+++ b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc
@@ -118,6 +118,10 @@ void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
             << " length=" << length;
 }
 
+void TRTInt8Calibrator::setDone() {
+  done_ = true;
+}
+
 void TRTInt8Calibrator::waitAndSetDone() {
   std::unique_lock<std::mutex> lk(mutex_);
   cv_.wait(lk, [&]{ return (!batch_is_set_ && !calib_running_) || done_; });
diff --git a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h
index e6a5efbdd8c4..bb81c9e13880 100644
--- a/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h
+++ b/src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h
@@ -75,6 +75,8 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
   // TODO(spanev): determine if we need to serialize it
   const std::string& getCalibrationTableAsString() { return calibration_table_; }
 
+  void setDone();
+
   void waitAndSetDone();
 
   bool isCacheEmpty();
diff --git a/tests/python/tensorrt/test_tensorrt.py b/tests/python/tensorrt/test_tensorrt.py
index 20b84d0ef7c6..1aecf94e1ddc 100644
--- a/tests/python/tensorrt/test_tensorrt.py
+++ b/tests/python/tensorrt/test_tensorrt.py
@@ -143,12 +143,6 @@ def get_top1(logits):
 
 def test_tensorrt_symbol_int8():
     ctx = mx.gpu(0)
-    cuda_arch = get_cuda_compute_capability(ctx)
-    cuda_arch_min = 70
-    if cuda_arch < cuda_arch_min:
-        print('Bypassing test_tensorrt_symbol_int8 on cuda arch {}, need arch >= {}).'.format(
-            cuda_arch, cuda_arch_min))
-        return
 
     # INT8 engine output are not lossless, so we don't expect numerical uniformity,
     # but we have to compare the TOP1 metric
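
The C++ side of this patch addresses a hang on the int8-unavailable fallback path: the calibrator was dropped (`calibrator = nullptr`) while `waitAndSetDone()` could still block on `cv_`, waiting for a batch that would never arrive, so the new `setDone()` flips `done_` and satisfies the wait predicate up front. Below is a minimal, self-contained sketch of that condition-variable handshake, not MXNet code: the `MiniCalibrator` class, `main`, and the `notify_all()` call are illustrative assumptions for the demo (the actual `setDone()` in the diff only sets the flag, since no calibration thread has started on this fallback path), while the member names `mutex_`, `cv_`, `done_`, and `batch_is_set_` mirror the fields visible in `TRTInt8Calibrator`.

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Sketch of the done-flag pattern the patch relies on. A waiter blocks on a
// condition variable until either the pending batch clears or the calibrator
// is marked done.
class MiniCalibrator {
 public:
  // Blocks until no batch is pending, or until the calibrator is marked done.
  void waitAndSetDone() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return !batch_is_set_ || done_; });
    done_ = true;
  }

  // Marks the calibrator finished so waiters can make progress. This sketch
  // notifies under the lock for safety; the real setDone() in the diff only
  // flips the flag, because no consumer thread exists on the fallback path.
  void setDone() {
    {
      std::lock_guard<std::mutex> lk(mutex_);
      done_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool batch_is_set_{true};  // simulate a batch that never clears
  bool done_{false};
};

int main() {
  MiniCalibrator calib;
  // Without setDone(), this thread would wait forever on the pending batch.
  std::thread waiter([&] { calib.waitAndSetDone(); });
  calib.setDone();  // analogous to the int8-unavailable fallback in the diff
  waiter.join();
  std::cout << "waiter unblocked\n";
  return 0;
}
```

Calling `setDone()` before discarding the calibrator means any code path that later waits on the condition variable sees `done_ == true` and returns immediately instead of deadlocking.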