From 0e3825eb59b334c6d8804a86956bf907764b0853 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 14 May 2020 13:21:41 +0100 Subject: [PATCH] [RELAY][BYOC] Preserve type information in Merge Composite Keep the type information when extracting patterns so that it can be used as part of 'check' functions. Change-Id: I16cc70c3d013a794d2ceefb5bec815129c7b8825 --- src/relay/transforms/merge_composite.cc | 13 ++++-- .../python/relay/test_pass_merge_composite.py | 41 +++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 596e2a1a29eb..027e5123365e 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -46,7 +46,8 @@ class MergeCompositeWrapper : public ExprMutator { if (var_map->find(pattern->name_hint()) == var_map->end()) { // if we haven't encountered this var yet, make a new free var and associate // it with the value at 'root' - auto free_var = Var(pattern->name_hint(), Type()); + auto free_var = Var(pattern->name_hint(), root->checked_type()); + free_var->checked_type_ = root->checked_type(); var_map->Set(pattern->name_hint(), Array({free_var, root})); return std::move(free_var); } else { @@ -147,7 +148,9 @@ class MergeCompositeWrapper : public ExprMutator { new_args.push_back(new_arg); i++; } - return Call(root->op, new_args, root->attrs); + Call new_call = Call(root->op, new_args, root->attrs); + new_call->checked_type_ = root->checked_type(); + return std::move(new_call); } Expr VisitExpr_(const CallNode* cn) { @@ -163,12 +166,15 @@ class MergeCompositeWrapper : public ExprMutator { auto new_e = this->Mutate(arg); new_args.push_back(new_e); } - return Call(call->op, new_args, call->attrs); + Call new_call = Call(call->op, new_args, call->attrs); + new_call->checked_type_ = call->checked_type(); + return std::move(new_call); } } Expr expr = ExprMutator::VisitExpr_(cn); call = Downcast(expr); + call->checked_type_ = cn->checked_type(); if (!call->op->IsInstance()) return std::move(call); // only call patterns are supported @@ -189,6 +195,7 @@ class MergeCompositeWrapper : public ExprMutator { args.push_back(args_map[free_var->name_hint()][1]); } auto new_call = Call(f, args); + new_call->checked_type_ = call->checked_type(); return std::move(new_call); } return std::move(call); diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 317bb421477c..3a79f6ad860f 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -803,6 +803,46 @@ def get_net(): assert tvm.ir.structural_equal(result, expected, map_free_vars=True) +def test_type_check(): + """Test that we can query tensor types in the 'check' function.""" + def before(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + conv = relay.nn.conv2d(x, + w, + kernel_size=(3, 3), + kernel_layout="OIHW", + data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + return relay.Function([x, w, b], relu) + + def _check_type_true(extract): + conv = extract.args[0].args[0] + typ = conv.checked_type + return bool(typ.shape[0] == 1) + + def _check_type_false(extract): + conv = extract.args[0].args[0] + typ = conv.checked_type + return bool(typ.shape[0] != 1) + + pattern_table_true = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true) + ] + pattern_table_false = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false) + ] + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) + expected = run_opt_pass(before(), relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) + assert result.body.op.attrs["Composite"] == "conv_bias_relu" + + if __name__ == "__main__": test_simple_merge() test_branch_merge() @@ -814,3 +854,4 @@ def get_net(): test_tuple_get_item_merge() test_pattern_with_check() test_diamond_not_merge() + test_type_check()