diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index aae50519b132..21254fa61e8e 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -53,11 +53,18 @@ def collect_stats(mod, dataset): logging.info("collecting statistics for calibration...") func = mod['main'] func = _quantize.CreateStatsCollector(func) - target = tvm.target.current_target() or 'llvm' + + if tvm.target.current_target(): + target = tvm.target.current_target() + ctx = tvm.context(target.target_name) + else: + target = 'llvm' + ctx = tvm.context(target) + with _transform.build_config(opt_level=3): graph, lib, params = _build_module.build(func, target=target) outputs = [] - runtime = graph_runtime.create(graph, lib, tvm.context(target)) + runtime = graph_runtime.create(graph, lib, ctx) runtime.set_input(**params) num_outputs = runtime.get_num_outputs() diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index e4aa36bf9f70..5b2e368f9dbd 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import tvm from tvm import relay from tvm.relay import testing @@ -45,5 +46,28 @@ def test_mul_rewrite(): quantize_and_build(act * pool) + +def get_calibration_dataset(input_name): + dataset = [] + for i in range(5): + data = np.random.uniform(size=(1, 3, 224, 224)) + dataset.append({input_name: data}) + return dataset + + +def test_calibrate_target(create_target=False): + mod, params = testing.resnet.get_workload(num_layers=18) + dataset = get_calibration_dataset("data") + with relay.quantize.qconfig(calibrate_mode="kl_divergence"): + if create_target: + with tvm.target.create("llvm"): + relay.quantize.quantize(mod, params, dataset) + else: + # current_target = None + relay.quantize.quantize(mod, params, dataset) + + if __name__ == "__main__": test_mul_rewrite() + test_calibrate_target(False) + test_calibrate_target(True)