diff --git a/src/parser/parser.cc b/src/parser/parser.cc index afcf70737933..3061735eff7c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1334,6 +1334,8 @@ class Parser { case TokenType::kBoolean: case TokenType::kStringLiteral: return Match(next->token_type)->data; + case TokenType::kMetaReference: + return ParseMetaRef(); case TokenType::kLSquare: { return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseAttributeValue(); }); @@ -1408,7 +1410,7 @@ class Parser { auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen; auto is_meta_attrs = is_meta_next && last_meta; - if (is_op && (is_pretty_attrs || is_meta_attrs)) { + if (is_pretty_attrs || is_meta_attrs) { if (is_meta_attrs) { auto meta_ref = ParseMetaRef(); if (meta_ref.as()) { @@ -1420,13 +1422,23 @@ class Parser { } } else { auto raw_attrs = ParseAttrs(); - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - ICHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } else if (raw_attrs.count("attrs_type_key")) { + String attr_key = Downcast(raw_attrs["attrs_type_key"]); + if (attr_key.size()) { + raw_attrs.erase("attrs_type_key"); + auto tbl = tvm::ReflectionVTable::Global(); + auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } + } } return true; } - return false; }); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index da4f8cadfb3d..cbee04f96096 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -827,6 +827,11 @@ std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr } else { AttrPrinter printer(&docs, this); const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + if (!op_node) { + // print call attr type key to restore expr for relay parser + std::string s = std::string(attrs->GetTypeKey()); + printer.Visit("attrs_type_key", &s); + } return docs; } } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 70fb56049873..62e52abefeb4 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -910,6 +910,45 @@ def test_load_prelude(): tvm.parser.parse(mod.astext()) +def test_call_attrs(): + def get_func(shape, dtype): + x0 = relay.var("data", shape=shape, dtype=dtype) + w0 = relay.var("weight", shape=shape, dtype=dtype) + a = relay.nn.dense(x0, w0) + b = relay.nn.relu(a) + d = relay.add(b, relay.const(1.0, dtype=dtype)) + return relay.Function([x0, w0], d) + + # build relay graph + shape = (2, 4) + dtype = "float32" + sub_func = get_func(shape, dtype) + p0 = relay.var("p0", shape=shape, dtype=dtype) + p1 = relay.var("p1", shape=shape, dtype=dtype) + attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs") + call = relay.Call(sub_func, [p0, p1], attrs=attr) + func = relay.Function([p0, p1], call) + + # build relay module + mod = tvm.IRModule() + mod["main"] = func + mod = tvm.relay.transform.InferType()(mod) + + # assert equal + program = """ + def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) { + %2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) { + %0 = nn.dense(%data, %weight, units=None); + %1 = nn.relu(%0); + add(%1, 1f) + }; + %2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs") + } + """ + parsed = parse_module(program) + assert_graph_equal(parsed, mod) + + def test_tokenize_inf(): x = relay.var("x", shape=(3, 4), dtype="float32") y = relay.clip(x, -np.inf, np.inf)