From 1434e87c8a8457de8888e08d3c248a1f821af5af Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Tue, 15 Jan 2019 09:21:05 -0800 Subject: [PATCH] enable use_synthetic for calibration --- tftrt/examples/image-classification/image_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index a02a0f059..4db426a4d 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -53,7 +53,7 @@ def after_run(self, run_context, run_values): self.batch_size / self.iter_times[-1])) def run(frozen_graph, model, data_files, batch_size, - num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False): + num_iterations, num_warmup_iterations, use_synthetic=False, display_every=100, run_calibration=False): """Evaluates a frozen graph This function evaluates a graph on the ImageNet validation set. @@ -489,7 +489,7 @@ def get_frozen_graph( print('Calibrating INT8...') start_time = time.time() run(calib_graph, model, calib_files, batch_size, - num_calib_inputs // batch_size, 0, False, run_calibration=True) + num_calib_inputs // batch_size, 0, use_synthetic=use_synthetic, run_calibration=True) times['trt_calibration'] = time.time() - start_time start_time = time.time()