diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index a030a056f7cd..40fcee68edc7 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -416,11 +416,13 @@ class PrettyPrinter : Doc VisitExpr_(const CallNode* op) final { Doc doc; - doc << Print(op->op); + // visit args first so they are lifted before the op + // this places op closer to its call site std::vector args; for (Expr arg : op->args) { args.push_back(Print(arg)); } + doc << Print(op->op); return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")"; } diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 626436d9573f..d3e2d005a4e1 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -3,9 +3,10 @@ import numpy as np from tvm import relay - do_print = [False] +SEMVER = "v0.0.1\n" + def show(text): if do_print[0]: print("---------------------------") @@ -152,6 +153,19 @@ def test_densenet(): net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) net.astext() +def test_call_node_order(): + x = relay.var("x") + y = relay.var("y") + assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \ + ("%0 = fn (%y) {\n" + " %y\n" + "}\n" + "%1 = %0(1)\n" + "%2 = fn (%x) {\n" + " %x\n" + "}\n" + "%3 = %2(%1)\n" + "%3") if __name__ == "__main__": do_print[0] = True @@ -170,3 +184,4 @@ def test_densenet(): test_call_attrs() test_let_if_scope() test_variable_name() + test_call_node_order()