diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index c1a7b50d3f45..b6d6c921a8a1 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -276,7 +276,7 @@ def add_rewrite(ref_call, new_args, ctx):
         assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
         lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-        return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
     if lhs_kind is not None and rhs_kind is None:
         if _analysis.check_constant(rhs_expr):
@@ -290,7 +290,7 @@ def add_rewrite(ref_call, new_args, ctx):
     if lhs_kind is not None and rhs_kind is not None:
         if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-            return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+            return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
         if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 488866ab6ff8..30d4c3650215 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -440,6 +440,77 @@ def _check_dense(node):
     relay.analysis.post_order_visit(qnn_mod["main"], _check_dense)
 
 
+def test_add_lhs_is_none_annotate():
+    data_conv = relay.var("data_conv", shape=(1, 16, 64, 64))
+    conv2d_w = relay.const(np.random.random((16, 16, 3, 3)))
+    conv2d = relay.nn.conv2d(data_conv, conv2d_w, padding=(1, 1), kernel_size=(3, 3))
+    data_add = relay.var("data_add", shape=(16, 1, 1))
+    add = relay.add(data_add, conv2d)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {"data_conv": np.random.random((1, 16, 64, 64)), "data_add": np.random.random((16, 1, 1))}
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)
+
+
+def test_add_lhs_rhs_is_input_annotate():
+    data_conv_r = relay.var("data_conv_r", shape=(1, 16, 64, 64))
+    conv2d_r = relay.nn.conv2d(
+        data_conv_r,
+        relay.const(np.random.random((16, 16, 3, 3))),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    data_conv_l = relay.var("data_conv_l", shape=(1, 16, 64, 64))
+    conv2d_l = relay.nn.conv2d(
+        data_conv_l,
+        relay.const(np.random.random((16, 16, 3, 3))),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    add = relay.add(conv2d_l, conv2d_r)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {
+            "data_conv_l": np.random.random((1, 16, 64, 64)),
+            "data_conv_r": np.random.random((1, 16, 64, 64)),
+        }
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)
+
+
 if __name__ == "__main__":
     test_mul_rewrite()
     test_batch_flatten_rewrite()
@@ -460,3 +531,6 @@ def _check_dense(node):
 
     test_skip_conv()
     test_stop_quantize()
+
+    test_add_lhs_is_none_annotate()
+    test_add_lhs_rhs_is_input_annotate()
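
Context note (my reading of the change, not part of the patch): in `_annotate.py`, `QAnnotateKind.INPUT` marks a value that is already in the quantized input domain and may be consumed directly, while `QAnnotateKind.ACTIVATION` marks an intermediate result whose consumers must first attach a `simulated_quantize`. The sum of two quantized operands is a fresh intermediate tensor with its own range and scale, so annotating it `INPUT` let downstream ops skip requantization; both changed branches now return `ACTIVATION` instead. A minimal sketch of the patched INPUT+INPUT branch, with explanatory comments that are mine rather than the source's:

    # Inside add_rewrite, once both operands are known to be INPUT kind:
    if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        # The add produces a new intermediate value whose range differs from
        # either operand's, so tag it ACTIVATION; consumers then insert a
        # simulated_quantize (requantize) rather than assuming the value is
        # already scaled as input data, which the old INPUT tag implied.
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)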