diff --git a/src/parser/parser.cc b/src/parser/parser.cc index b72a632635d9..793d6bb9a43d 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1371,7 +1371,7 @@ class Parser { DLOG(INFO) << "Parser::ParseAttrs"; Map kwargs; while (Peek()->token_type == TokenType::kIdentifier) { - auto key = Match(TokenType::kIdentifier).ToString(); + auto key = GetHierarchicalName(ParseHierarchicalName().data); Match(TokenType::kEqual); // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. auto value = ParseAttributeValue(); @@ -1545,18 +1545,7 @@ class Parser { auto spanned_idents = ParseHierarchicalName(); auto idents = spanned_idents.data; auto span = spanned_idents.span; - ICHECK_NE(idents.size(), 0); - std::stringstream op_name; - int i = 0; - int periods = idents.size() - 1; - for (auto ident : idents) { - op_name << ident; - if (i < periods) { - op_name << "."; - i++; - } - } - return GetOp(op_name.str(), span); + return GetOp(GetHierarchicalName(idents), span); } } case TokenType::kGraph: { @@ -1696,6 +1685,21 @@ class Parser { return Spanned>(idents, span); } + std::string GetHierarchicalName(Array idents) { + ICHECK_NE(idents.size(), 0); + std::stringstream hierarchical_name; + int i = 0; + int periods = idents.size() - 1; + for (auto ident : idents) { + hierarchical_name << ident; + if (i < periods) { + hierarchical_name << "."; + i++; + } + } + return hierarchical_name.str(); + } + /*! \brief Parse a shape. */ Array ParseShape() { auto dims = ParseSequence( diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 8b6b39e3df15..04cc2c0e79e4 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -959,6 +959,13 @@ def test_tokenize_inf(): mod = relay.transform.AnnotateSpans()(mod) +def test_func_attrs(): + attrs = tvm.ir.make_node("DictAttrs", **{"Primitive": 1, "relay.reshape_only": 1}) + x = relay.var("x", shape=(2, 3)) + func = relay.Function([x], relay.reshape(x, (-1,)), attrs=attrs) + assert_parses_as(func.astext(), func) + + if __name__ == "__main__": import sys