From bea3b69169e1157a2c402bb0b37de5eaa37495cc Mon Sep 17 00:00:00 2001 From: Dongming Yang Date: Wed, 27 Jan 2021 15:38:49 +0800 Subject: [PATCH 1/4] [RELAY][Parser] Optimize relay parser to restore attrs for non-Operator calls * To avoid too much modification to the native class, only print out the attrs type key of non-Operator Call in relay printer. Then reconstruct the attrs object after parsing this attrs type key value in Relay parser. --- src/parser/parser.cc | 21 ++++++++++++++++----- src/printer/relay_text_printer.cc | 3 +++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index afcf70737933..25c093486807 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1334,6 +1334,8 @@ class Parser { case TokenType::kBoolean: case TokenType::kStringLiteral: return Match(next->token_type)->data; + case TokenType::kMetaReference: + return ParseMetaRef(); case TokenType::kLSquare: { return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseAttributeValue(); }); @@ -1408,7 +1410,7 @@ class Parser { auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen; auto is_meta_attrs = is_meta_next && last_meta; - if (is_op && (is_pretty_attrs || is_meta_attrs)) { + if (is_pretty_attrs || is_meta_attrs) { if (is_meta_attrs) { auto meta_ref = ParseMetaRef(); if (meta_ref.as()) { @@ -1420,13 +1422,22 @@ class Parser { } } else { auto raw_attrs = ParseAttrs(); - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - ICHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } else if (raw_attrs.count("attrs_type_key")) { + String attr_key = Downcast(raw_attrs["attrs_type_key"]); + if (attr_key.size()) { + raw_attrs.erase("attrs_type_key"); + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } + } } return true; } - return false; }); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index da4f8cadfb3d..a87b8716c143 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -827,6 +827,9 @@ std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr } else { AttrPrinter printer(&docs, this); const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + // print call attr type key to restore expr for relay parser + std::string s = std::string(attrs->GetTypeKey()); + printer.Visit("attrs_type_key", &s); return docs; } } From 2889f791efc941e20628fe93b234754747d76846 Mon Sep 17 00:00:00 2001 From: Dongming Yang Date: Wed, 27 Jan 2021 16:47:01 +0800 Subject: [PATCH 2/4] fix lint --- src/parser/parser.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 25c093486807..3061735eff7c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1430,7 +1430,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); + auto tbl = tvm::ReflectionVTable::Global(); + auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } From c484dd8ddcc85f4fbac8769715269268a6671fa6 Mon Sep 17 00:00:00 2001 From: Dongming Yang Date: Wed, 27 Jan 2021 18:27:21 +0800 Subject: [PATCH 3/4] fix ci --- src/printer/relay_text_printer.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index a87b8716c143..cbee04f96096 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -827,9 +827,11 @@ std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr } else { AttrPrinter printer(&docs, this); const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - // print call attr type key to restore expr for relay parser - std::string s = std::string(attrs->GetTypeKey()); - printer.Visit("attrs_type_key", &s); + if (!op_node) { + // print call attr type key to restore expr for relay parser + std::string s = std::string(attrs->GetTypeKey()); + printer.Visit("attrs_type_key", &s); + } return docs; } } From aaf4bb01dadb6ad34ba0bf92c61929d78677a459 Mon Sep 17 00:00:00 2001 From: Dongming Yang Date: Thu, 28 Jan 2021 15:31:33 +0800 Subject: [PATCH 4/4] add test case --- tests/python/relay/test_ir_parser.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 162271756557..7412bb261367 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -910,6 +910,45 @@ def test_load_prelude(): tvm.parser.parse(mod.astext()) +def test_call_attrs(): + def get_func(shape, dtype): + x0 = relay.var("data", shape=shape, dtype=dtype) + w0 = relay.var("weight", shape=shape, dtype=dtype) + a = relay.nn.dense(x0, w0) + b = relay.nn.relu(a) + d = relay.add(b, relay.const(1.0, dtype=dtype)) + return relay.Function([x0, w0], d) + + # build relay graph + shape = (2, 4) + dtype = "float32" + sub_func = get_func(shape, dtype) + p0 = relay.var("p0", shape=shape, dtype=dtype) + p1 = relay.var("p1", shape=shape, dtype=dtype) + attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs") + call = relay.Call(sub_func, [p0, p1], attrs=attr) + func = relay.Function([p0, p1], call) + + # build relay module + mod = tvm.IRModule() + mod["main"] = func + mod = tvm.relay.transform.InferType()(mod) + + # assert equal + program = """ + def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) { + %2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) { + %0 = nn.dense(%data, %weight, units=None); + %1 = nn.relu(%0); + add(%1, 1f) + }; + %2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs") + } + """ + parsed = parse_module(program) + assert_graph_equal(parsed, mod) + + if __name__ == "__main__": import sys