diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 74c236ae3280..015489dd0857 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -108,6 +108,25 @@ class AnnotateTargetRewriter : public ExprRewriter { return new_op; } + Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) { + /*! + * \brief This function inserts compiler end to expr and maps the corresponding target to the + * new expression. + * + * This function checks for expr existence within the map and inserts the annotation + * Further, it propagates the target to the new expression and returns it + * + * \param expr A relay expression + * \return An annotated and target-propagated relay expression. + */ + Expr new_expr = expr; + if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { + new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); + op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; + } + return std::move(new_expr); + } + Expr Rewrite_(const CallNode* pre, const Expr& post) final { // Supported targets for this node. The order implies the priority. std::vector supported_targets; @@ -127,14 +146,16 @@ class AnnotateTargetRewriter : public ExprRewriter { CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); } - - // Peek the first argument. If it is compiler begin then this node had annotated by - // another target before, so we also consider that target as a supported target. - const CallNode* first_arg_call = pre->args[0].as(); - if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { - std::string arg_target = first_arg_call->attrs.as()->compiler; - if (arg_target != "default") { - supported_targets.push_back(arg_target); + // Check prior to peeking first argument + if (pre->args.size()) { + // Peek the first argument. If it is compiler begin then this node had annotated by + // another target before, so we also consider that target as a supported target. + const CallNode* first_arg_call = pre->args[0].as(); + if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { + std::string arg_target = first_arg_call->attrs.as()->compiler; + if (arg_target != "default") { + supported_targets.push_back(arg_target); + } } } @@ -222,11 +243,7 @@ class AnnotateTargetRewriter : public ExprRewriter { new_body = func->body; } else { func = Downcast(post); - new_body = func->body; - if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { - new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); - op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; - } + new_body = InsertCompilerEndAndPropogateTarget(func->body); } return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } @@ -234,20 +251,27 @@ class AnnotateTargetRewriter : public ExprRewriter { Expr Rewrite_(const LetNode* op, const Expr& post) final { auto let = Downcast(post); - auto target_n_args = AnnotateArgs({let->value, let->body}); - auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); - op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + Expr new_expr; + std::pair> target_n_args; + Expr new_body = InsertCompilerEndAndPropogateTarget(let->body); + // Do not annotate function literal with let binding. + if (let->value->IsInstance()) { + new_expr = Let(let->var, let->value, new_body); + } else { + target_n_args = AnnotateArgs({let->value}); + new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body); + } + return std::move(new_expr); } Expr Rewrite_(const IfNode* op, const Expr& post) final { auto expr = Downcast(post); + Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); + Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); + Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch); - auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); - CHECK_EQ(std::get<1>(target_n_args).size(), 3U); - auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], - std::get<1>(target_n_args)[2]); - op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + auto new_expr = If(new_cond, new_true_branch, new_false_branch); return std::move(new_expr); } diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index b7c43498a69a..ba1de7416384 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -353,6 +353,161 @@ def before(): assert tvm.ir.structural_equal(expected, mod) +def test_if_else(): + target = "test_if_else" + + @tvm.ir.register_op_attr("equal", "target." + target) + def relu(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("tanh", "target." + target) + def tanh(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("sigmoid", "target." + target) + def sigmoid(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("erf", "target." + target) + def erf(attrs, args): # pylint: disable=unused-variable + return True + + """Test that If-else nodes compiles correctly when surrounded by supported nodes.""" + + def before(): + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") + eq = relay.equal(eq1, eq2) + + true_branch = relay.tanh(data) + false_branch = relay.sigmoid(data) + ife = relay.If(eq, true_branch, false_branch) + out = relay.erf(ife) + func = relay.Function([data, eq1, eq2], out) + mod = tvm.IRModule.from_expr(func) + + return mod + + def after(): + + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") + + cb_1 = relay.annotation.compiler_begin(eq1, target) + cb_2 = relay.annotation.compiler_begin(eq2, target) + + equality_condition = relay.equal(cb_1, cb_2) + ce_1 = relay.annotation.compiler_end(equality_condition, target) + + # if condition + cb_3 = relay.annotation.compiler_begin(data, target) + true_branch = relay.tanh(cb_3) + ce_2 = relay.annotation.compiler_end(true_branch, target) + + # else condition + cb_4 = relay.annotation.compiler_begin(data, target) + false_branch = relay.sigmoid(cb_4) + ce_3 = relay.annotation.compiler_end(false_branch, target) + + if_condition = relay.If(ce_1, ce_2, ce_3) + cb_5 = relay.annotation.compiler_begin(if_condition, target) + erf_out = relay.erf(cb_5) + ce_4 = relay.annotation.compiler_end(erf_out, target) + func = relay.Function([data, eq1, eq2], ce_4) + mod = tvm.IRModule.from_expr(func) + return mod + + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + +def test_while_let(): + target = "test_while_let" + + @tvm.ir.register_op_attr("less", "target." + target) + def less(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("add", "target." + target) + def add(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("zeros_like", "target." + target) + def zeros_like(attrs, args): # pylint: disable=unused-variable + return True + + """Test that let nodes compiles correctly when surrounded by other nodes.""" + + def before(): + + var1 = relay.var("var1", shape=(2,)) + var2 = relay.var("var2", shape=(), dtype="int32") + var3 = relay.var("var3", shape=(2,)) + cond = relay.less(var2, relay.const(10, dtype="int32")) + + loop = relay.var("while_loop") + ii = var2 + relay.const(1, dtype="int32") + ss = var3 + var1 + true_branch = loop(ii, ss) + ife = relay.If(cond, true_branch, var3) + func_1 = relay.Function([var2, var3], ife) + + ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1))) + func_2 = relay.Function([var1], ret) + mod = tvm.IRModule.from_expr(func_2) + return mod + + def after(): + var1 = relay.var("var1", shape=(2,)) + var2 = relay.var("var2", shape=(), dtype="int32") + var3 = relay.var("var3", shape=(2,)) + var4 = relay.const(10, dtype="int32") + + cb_1 = relay.annotation.compiler_begin(var2, target) + cb_2 = relay.annotation.compiler_begin(var4, target) + + less_condition = relay.less(cb_1, cb_2) + ce_1 = relay.annotation.compiler_end(less_condition, target) + + loop = relay.var("while_loop") + + # if condition + cb_3 = relay.annotation.compiler_begin(var2, target) + cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target) + add_op_1 = relay.add(cb_3, cb_4) + ce_2 = relay.annotation.compiler_end(add_op_1, target) + cb_5 = relay.annotation.compiler_begin(ce_2, "default") + cb_6 = relay.annotation.compiler_begin(var3, target) + cb_7 = relay.annotation.compiler_begin(var1, target) + add_op_2 = relay.add(cb_6, cb_7) + ce_3 = relay.annotation.compiler_end(add_op_2, target) + cb_8 = relay.annotation.compiler_begin(ce_3, "default") + true_branch = loop(cb_5, cb_8) # while loop + ce_4 = relay.annotation.compiler_end(true_branch, "default") + if_condition = relay.If(ce_1, ce_4, var3) + + cb_9 = relay.annotation.compiler_begin(relay.const(0, dtype="int32"), "default") + cb_10 = relay.annotation.compiler_begin(var1, target) + zeros_like = relay.zeros_like(cb_10) + ce_5 = relay.annotation.compiler_end(zeros_like, target) + cb_11 = relay.annotation.compiler_begin(ce_5, "default") + while_condition = loop(cb_9, cb_11) + ce_6 = relay.annotation.compiler_end(while_condition, "default") + + func_1 = relay.Function([var2, var3], if_condition) + ret = relay.Let(loop, func_1, ce_6) + func_2 = relay.Function([var1], ret) + mod = tvm.IRModule.from_expr(func_2) + return mod + + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_extern_dnnl() test_composite_function() @@ -361,3 +516,5 @@ def before(): test_type_propagation() test_tuple() test_multiple_runs() + test_if_else() + test_while_let()