From fbe3a5b4b2814fd66ea9f66353a2cad60d80537f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 10 Jul 2020 17:00:02 -0700 Subject: [PATCH 01/48] Add code from livestream with JK --- src/parser/parser.cc | 90 +++++++++++++++++++++++---- tests/python/relay/test_ir_parser2.py | 27 ++++++-- 2 files changed, 100 insertions(+), 17 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 0aaa698be45e..2aca371b2a93 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -542,7 +542,7 @@ class Parser { */ template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, - std::function before_stop = nullptr) { + std::function before_stop = nullptr) { Match(start); if (WhenMatch(stop)) { return Array(); @@ -550,23 +550,28 @@ class Parser { auto data = parse(); Array elements = {data}; - // parse '(' expr ')' + // parse '(' expr ','? ')' // if we are at the end invoke leftover parser - if (Peek()->token_type == stop && before_stop) { - before_stop(); - } + // if (Peek()->token_type == sep && before_stop) { + // before_stop(); + // } + if (WhenMatch(stop)) { return elements; // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { - // if we are at the end invoke leftover parser - if (Peek()->token_type == stop && before_stop) { - before_stop(); - } while (true) { if (WhenMatch(stop)) { break; } else { + // If before stop is + if (before_stop) { + auto did_parse = before_stop(); + if (did_parse) { + Match(stop); + return elements; + } + } auto data = parse(); WhenMatch(sep); elements.push_back(data); @@ -757,6 +762,7 @@ class Parser { // NB: Might need to optimize to remove deep recursion. // Stack should only grow proportionally to the number of // nested scopes. + // Parses `{` expression `}`. 
return Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { PushScope(); auto expr = ParseExpr(); @@ -764,6 +770,7 @@ class Parser { return expr; }); } + // Parses `let ...`; case TokenType::Let: exprs.push_back(ParseBindingExpr()); break; @@ -778,6 +785,7 @@ class Parser { exprs.push_back(ParseIf()); break; } + // %x ... case TokenType::Graph: if (Lookahead(2)->token_type == TokenType::Equal) { exprs.push_back(ParseBindingExpr()); @@ -843,9 +851,11 @@ class Parser { while (true) { auto next = Peek(); + std::cout << "right here about to parse graph" << std::endl; if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) { Match(TokenType::Graph); Match(TokenType::Equal); + DisplayNextN(10); auto val = this->ParseExprBinOp(); Match(TokenType::Semicolon); AddGraphBinding(next, val); @@ -1098,11 +1108,30 @@ class Parser { }); } + ObjectRef ParseAtributeValue() { + auto next = Peek(); + switch (next->token_type) { + case TokenType::Float: + case TokenType::Integer: + case TokenType::Boolean: + return Match(next->token_type)->data; + default: + return ParseAtomicExpr(); + } + } + Attrs ParseAttrs(const std::string& type_key) { Map kwargs; + while (Peek()->token_type == TokenType::Identifier) { + auto key = Match(TokenType::Identifier).ToString(); + Match(TokenType::Equal); + // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. 
+ auto value = ParseAtributeValue(); + kwargs.Set(key, value); + WhenMatch(TokenType::Comma); + } auto attrs = tvm::ReflectionVTable::Global()->CreateObject(type_key, kwargs); - LOG(FATAL) << Attrs(); - return Attrs(); + return Downcast(attrs); } Expr ParseCallArgs(Expr op) { @@ -1118,8 +1147,11 @@ class Parser { if (is_ident && next_is_equal) { if (auto op_node = op.as()) { call_attrs = ParseAttrs(op_node->attrs_type_key); + return true; } } + + return false; }); return Expr(Call(op, args, call_attrs, {})); } else { @@ -1152,6 +1184,17 @@ class Parser { }); } + Expr GetOp(const std::string& op_name, const Token& tok) { + try { + return Op::Get(op_name); + } catch (dmlc::Error e) { + std::stringstream msg; + msg << "operator `" << op_name << "` not found, perhaps you forgot to register it?"; + this->diag_ctx.Emit({ tok->line, tok->column, msg.str() }); + return Expr(); + } + } + Expr ParseAtomicExpr() { return ConsumeWhitespace([this] { auto next = Peek(); @@ -1170,10 +1213,12 @@ class Parser { Expr e = Constant(boolean); return e; } + // Parse a local of the form `%x`. case TokenType::Local: { Consume(TokenType::Local); return Expr(LookupLocal(next)); } + // Parse a local of the form `@x`. case TokenType::Global: { auto string = next.ToString(); Consume(TokenType::Global); @@ -1186,6 +1231,8 @@ class Parser { return Expr(global.value()); } } + // Parse a local of the form `x`. + // Right now we fail to parse `x.y`. case TokenType::Identifier: { auto string = next.ToString(); Consume(TokenType::Identifier); @@ -1193,7 +1240,26 @@ class Parser { if (ctor) { return Expr(ctor.value()); } else { - return Expr(Op::Get(string)); + // id(x) ^ ';' this is works + // id(x) ^ '.' id(y) this is works + + // ^ id(x) '.' id(y) this is works + // id(x) ^ '.' id(y) this is works + // id(x) '.' ^ id(y) this is works + // id(x) '.' 
id(y) ^ this is works + // DisplayNextN(10); + + if (Peek()->token_type == TokenType::Period) { + std::stringstream hier_id; + Consume(TokenType::Period); + auto second_id = Match(TokenType::Identifier); + hier_id << string << "." + << second_id.ToString(); + return GetOp(hier_id.str(), next); + } else { + auto op = GetOp(string, next); + return op; + } } } case TokenType::Graph: { diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index 23ba1fa850e5..e18c61bd260c 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ -17,6 +17,7 @@ import tvm from tvm import te from tvm import relay +import tvm.relay.testing import pytest from numpy import isclose from typing import Union @@ -73,7 +74,6 @@ def assert_graph_equal(lhs, rhs): def graph_equal(lhs, rhs): return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) - def roundtrip_expr(expr): x = tvm.parser.parse_expr(str(str(expr))) assert_graph_equal(x, expr) @@ -99,7 +99,6 @@ def parse_module(code): roundtrip(mod) return mod - def assert_parses_as(code, expr): parsed = parse_text(code) assert_graph_equal(parsed, expr) @@ -210,7 +209,6 @@ def test_op_assoc(): assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) - def test_vars(): # var var = parse_text("let %foo = (); %foo") @@ -886,6 +884,25 @@ def test_import_grad(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") +# hiearchy id, i.e parse nn.conv2d +# do with multiple levels +# +# call attributes not correctly parsing +# convert error from attribute construction to real error message +# lexing issue with projection of graph variables + +def test_hierarchical_identifiers(): + assert False + +def test_resnet(): + mod, params = relay.testing.resnet.get_workload() + text = str(mod.astext()) + print(text) + parse_module(text) + import pdb; 
pdb.set_trace() + if __name__ == "__main__": - import sys - pytest.main(sys.argv) + # import sys + # pytest.main(sys.argv) + test_hierarchical_identifiers() + # test_resnet() From 4c61844b94dba7a1792dc0e6abe51f3ed0f64507 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 10 Jul 2020 20:57:10 -0700 Subject: [PATCH 02/48] Fix errors parsing ResNet --- src/parser/parser.cc | 58 +++++++++++++++++++-------- src/parser/tokenizer.h | 12 ++++++ tests/python/relay/test_ir_parser2.py | 18 +++++---- 3 files changed, 65 insertions(+), 23 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 2aca371b2a93..b45b6ef3c8f4 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -763,12 +763,14 @@ class Parser { // Stack should only grow proportionally to the number of // nested scopes. // Parses `{` expression `}`. - return Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { + auto block = Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { PushScope(); auto expr = ParseExpr(); PopScopes(1); return expr; }); + exprs.push_back(block); + break; } // Parses `let ...`; case TokenType::Let: @@ -851,11 +853,9 @@ class Parser { while (true) { auto next = Peek(); - std::cout << "right here about to parse graph" << std::endl; if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) { Match(TokenType::Graph); Match(TokenType::Equal); - DisplayNextN(10); auto val = this->ParseExprBinOp(); Match(TokenType::Semicolon); AddGraphBinding(next, val); @@ -1108,34 +1108,46 @@ class Parser { }); } - ObjectRef ParseAtributeValue() { + ObjectRef ParseAttributeValue() { auto next = Peek(); switch (next->token_type) { case TokenType::Float: case TokenType::Integer: case TokenType::Boolean: return Match(next->token_type)->data; + case TokenType::LSquare: { + return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { + return ParseAttributeValue(); + }); + } default: return ParseAtomicExpr(); } } - Attrs 
ParseAttrs(const std::string& type_key) { + Map ParseAttrs() { Map kwargs; while (Peek()->token_type == TokenType::Identifier) { auto key = Match(TokenType::Identifier).ToString(); Match(TokenType::Equal); // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. - auto value = ParseAtributeValue(); + auto value = ParseAttributeValue(); kwargs.Set(key, value); WhenMatch(TokenType::Comma); } - auto attrs = tvm::ReflectionVTable::Global()->CreateObject(type_key, kwargs); - return Downcast(attrs); + return kwargs; } Expr ParseCallArgs(Expr op) { - Attrs call_attrs; + Map raw_attrs; + std::string op_key; + bool is_op = false; + + if (auto op_node = op.as()) { + is_op = true; + op_key = op_node->attrs_type_key; + } + if (Peek()->token_type == TokenType::OpenParen) { Array args = ParseSequence( TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, @@ -1144,16 +1156,23 @@ class Parser { auto is_ident = Lookahead(1)->token_type == TokenType::Identifier; auto next_is_equal = Lookahead(2)->token_type == TokenType::Equal; - if (is_ident && next_is_equal) { - if (auto op_node = op.as()) { - call_attrs = ParseAttrs(op_node->attrs_type_key); - return true; - } + if (is_op && is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; } return false; }); - return Expr(Call(op, args, call_attrs, {})); + + Attrs attrs; + + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } + + return Expr(Call(op, args, attrs, {})); } else { return Expr(); } @@ -1196,7 +1215,7 @@ class Parser { } Expr ParseAtomicExpr() { - return ConsumeWhitespace([this] { + auto expr = ConsumeWhitespace([this] { auto next = Peek(); switch (next->token_type) { case TokenType::Integer: @@ -1305,6 +1324,13 @@ class Parser { } } }); + + if (WhenMatch(TokenType::Period)) { + auto index = Match(TokenType::Integer).ToNumber(); + expr = 
relay::TupleGetItem(expr, index); + } + + return expr; } /*! \brief Parse a shape. */ diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index f6c27340e09a..7c4b99fd0eee 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -301,6 +301,18 @@ struct Tokenizer { } else if (next == '%') { auto token = NewToken(TokenType::Percent); Next(); + + std::stringstream number; + while (More() && IsDigit(Peek())) { + number << Next(); + } + + auto number_str = number.str(); + if (number_str.size()) { + auto num_tok = ParseNumber(true, false, number_str); + token = Token(token->line, token->column, TokenType::Graph, num_tok->data); + } + return token; } else if (next == '/') { Next(); diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index e18c61bd260c..aa785d2f61f2 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ -225,6 +225,11 @@ def test_vars(): assert isinstance(op, tvm.ir.Op) assert op.name == "add" + # operator id with prefix + op = parse_text("nn.global_avg_pool2d") + assert isinstance(op, tvm.ir.Op) + assert op.name == "nn.global_avg_pool2d" + def test_let(): assert_parses_as( @@ -891,18 +896,17 @@ def test_import_grad(): # convert error from attribute construction to real error message # lexing issue with projection of graph variables -def test_hierarchical_identifiers(): - assert False +# def test_hierarchical_identifiers(): +# assert False def test_resnet(): mod, params = relay.testing.resnet.get_workload() text = str(mod.astext()) - print(text) - parse_module(text) - import pdb; pdb.set_trace() + parsed_mod = parse_module(text) + tvm.ir.assert_structural_equal(mod, parsed_mod) if __name__ == "__main__": # import sys # pytest.main(sys.argv) - test_hierarchical_identifiers() - # test_resnet() + # test_hierarchical_identifiers() + test_resnet() From 432e0d96078248b6a3d75e9647faddece70656ce Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Jul 2020 
01:05:37 -0700 Subject: [PATCH 03/48] Parse metadata section efficiently and do most of plumbing to resolve metadata section references. --- python/tvm/relay/expr.py | 2 +- src/parser/meta_ref.h | 94 +++++++++++++++++++++ src/parser/parser.cc | 112 +++++++++----------------- src/parser/token.h | 15 ++++ src/parser/tokenizer.h | 73 +++++++++++++++++ src/printer/text_printer.cc | 10 --- src/printer/text_printer.h | 8 +- tests/python/relay/test_ir_parser2.py | 30 ++++++- 8 files changed, 253 insertions(+), 91 deletions(-) create mode 100644 src/parser/meta_ref.h diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index fbb98fcf9e3c..106edc25c5ee 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -518,7 +518,7 @@ def bind(expr, binds): expr : tvm.relay.Expr The input expression. - binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]] + binds : Map[tvm.relay.Var, tvm.relay.Expr] The specific bindings. Returns diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h new file mode 100644 index 000000000000..b8d118195bfa --- /dev/null +++ b/src/parser/meta_ref.h @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file meta_ref.h + * \brief A reference into the metadata section of the Relay text format. + */ + +#ifndef TVM_PARSER_META_REF_H_ +#define TVM_PARSER_META_REF_H_ + +#include + +#include + +namespace tvm { +namespace parser { + +using namespace relay; + +/*! \brief A reference to a "meta-expression". + * + * In the text format we allow referencing metadata which + * uses a compact serialization that proceeds the main + * program body. + * + * We can reference this table using an expression of + * the form `meta[Type][index]`. + * + * We must later resolve these references to actual in-memory + * AST nodes but this requires first parsing the full program + * then expanding these temporary AST nodes into their corresponding + * nodes. + * + * For example the nth large constant will be pretty-printed as meta[relay.Constant][n] + * with its compact binary serialization residing in the metadata section at the end + * of the program. + */ +class MetaRefExprNode : public TempExprNode { + public: + /*! \brief The type key of the meta expression. */ + std::string type_key; + /*! \brief The index into the type key's table. */ + uint64_t node_index; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + // TODO(@jroesch): we probably will need to manually + // expand these with a pass. + Expr Realize() const final { return Expr(); } + + static constexpr const char* _type_key = "relay.MetaRefExpr"; + TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode); +}; + +class MetaRefExpr : public TempExpr { + public: + /*! + * \brief The constructor for MetaRefExpr + * \param type_key The type key of the object in the meta section. + * \param kind The index into that subfield. 
+ */ + TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index); + + TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode); +}; + +MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { + auto rnode = make_object(); + rnode->type_key = type_key; + rnode->node_index = node_index; + data_ = std::move(rnode); +} + +} // namespace parser +} // namespace tvm + +#endif // TVM_PARSER_META_REF_H_ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index b45b6ef3c8f4..bc6a7704dde6 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -31,6 +31,7 @@ #include +#include "./meta_ref.h" #include "./diagnostic.h" #include "./op_table.h" #include "./tokenizer.h" @@ -87,60 +88,6 @@ class SemVer { patch_version(other.patch_version) {} }; -/*! \brief A reference to a "meta-expression". - * - * In the text format we allow referencing metadata which - * uses a compact serialization that proceeds the main - * program body. - * - * We can reference this table using an expression of - * the form `meta[Type][index]`. - * - * We must later resolve these references to actual in-memory - * AST nodes but this requires first parsing the full program - * then expanding these temporary AST nodes into their corresponding - * nodes. - * - * For example the nth large constant will be pretty-printed as meta[relay.Constant][n] - * with its compact binary serialization residing in the metadata section at the end - * of the program. - */ -class MetaRefExprNode : public TempExprNode { - public: - /*! \brief The type key of the meta expression. */ - std::string type_key; - /*! \brief The index into the type key's table. */ - uint64_t node_index; - - void VisitAttrs(tvm::AttrVisitor* v) {} - - // TODO(@jroesch): we probably will need to manually - // expand these with a pass. 
- Expr Realize() const final { return Expr(); } - - static constexpr const char* _type_key = "relay.MetaRefExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode); -}; - -class MetaRefExpr : public TempExpr { - public: - /*! - * \brief The constructor for MetaRefExpr - * \param type_key The type key of the object in the meta section. - * \param kind The index into that subfield. - */ - TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index); - - TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode); -}; - -MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { - auto rnode = make_object(); - rnode->type_key = type_key; - rnode->node_index = node_index; - data_ = std::move(rnode); -} - /*! \brief A simple wrapper around a mapping from raw string names * to a TVM variable, type variable or other binder type. */ @@ -1253,38 +1200,33 @@ class Parser { // Parse a local of the form `x`. // Right now we fail to parse `x.y`. case TokenType::Identifier: { - auto string = next.ToString(); - Consume(TokenType::Identifier); - auto ctor = ctors.Get(string); + auto ctor = ctors.Get(next.ToString()); if (ctor) { + Consume(TokenType::Identifier); return Expr(ctor.value()); } else { - // id(x) ^ ';' this is works - // id(x) ^ '.' id(y) this is works - - // ^ id(x) '.' id(y) this is works - // id(x) ^ '.' id(y) this is works - // id(x) '.' ^ id(y) this is works - // id(x) '.' id(y) ^ this is works - // DisplayNextN(10); - - if (Peek()->token_type == TokenType::Period) { - std::stringstream hier_id; - Consume(TokenType::Period); - auto second_id = Match(TokenType::Identifier); - hier_id << string << "." 
- << second_id.ToString(); - return GetOp(hier_id.str(), next); - } else { - auto op = GetOp(string, next); - return op; + auto idents = ParseHierName(); + std::stringstream op_name; + int i = 0; + int periods = idents.size() - 1; + for (auto ident : idents) { + op_name << ident; + if (i < periods) { + op_name << "."; + i++; + } } + return GetOp(op_name.str(), next); } } case TokenType::Graph: { Consume(TokenType::Graph); return LookupGraphBinding(next); } + case TokenType::MetaRef: { + Consume(TokenType::MetaRef); + return Downcast(next->data); + } case TokenType::Fn: { Consume(TokenType::Fn); return Expr(ParseFunctionDef()); @@ -1333,6 +1275,24 @@ class Parser { return expr; } + /*! \brief Parse a hierarchical name. */ + Array ParseHierName() { + Array idents; + while (Peek()->token_type == TokenType::Identifier) { + idents.push_back(Peek().ToString()); + Consume(TokenType::Identifier); + + if (Peek()->token_type == TokenType::Period) { + Consume(TokenType::Period); + continue; + } else { + break; + } + } + + return idents; + } + /*! \brief Parse a shape. 
*/ Array ParseShape() { auto dims = ParseSequence(TokenType::OpenParen, TokenType::Comma, diff --git a/src/parser/token.h b/src/parser/token.h index d7aac23ca350..0d38e09213d4 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -85,6 +85,8 @@ enum TokenType { Extern, Match, PartialMatch, + Metadata, + MetaRef, Unknown, EndOfFile, Null, @@ -186,6 +188,10 @@ std::string ToString(const TokenType& token_type) { return "Question"; case TokenType::Boolean: return "Boolean"; + case TokenType::Metadata: + return "Metadata"; + case TokenType::MetaRef: + return "MetaRef"; case TokenType::Unknown: return "Unknown"; case TokenType::EndOfFile: @@ -289,6 +295,10 @@ std::string Pretty(const TokenType& token_type) { return "`extern`"; case TokenType::Boolean: return "boolean"; + case TokenType::Metadata: + return "metadata section"; + case TokenType::MetaRef: + return "`meta`"; case TokenType::Match: return "`match`"; case TokenType::PartialMatch: @@ -339,6 +349,7 @@ class Token : public ObjectRef { static Token Null(); int64_t ToNumber() const; std::string ToString() const; + Map> ToMetadata() const; TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode); }; @@ -357,6 +368,10 @@ int64_t Token::ToNumber() const { return Downcast(this->operator-> std::string Token::ToString() const { return Downcast(this->operator->()->data); } + Map> Token::ToMetadata() const { + return Downcast>>(this->operator->()->data); + } + } // namespace parser } // namespace tvm #endif // TVM_PARSER_TOKEN_H_ diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 7c4b99fd0eee..9333ed357de0 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -171,6 +172,48 @@ struct Tokenizer { } } + bool MatchString(const std::string& string) { + int start = this->pos; + + for (auto c : string) { + if (Peek() != c) { + this->pos = start; + return false; + } else { + Next(); + } + } + + return true; + } + + Token 
TokenizeMetaRef() { + int line = this->line; + int column = this->col; + + CHECK_EQ(Peek(), '['); + Next(); + std::stringstream type_key; + while (More() && Peek() != ']') { + type_key << Next(); + } + CHECK_EQ(Peek(), ']'); + Next(); + + CHECK_EQ(Peek(), '['); + Next(); + std::stringstream str_index; + while (More() && Peek() != ']') { + str_index << Next(); + } + CHECK_EQ(Peek(), ']'); + Next(); + std::cout << "NUmber: " << str_index.str() << std::endl; + // todo: add error handling around bad indices + auto index = ParseNumber(true, false, str_index.str()).ToNumber(); + return Token(line, column, TokenType::MetaRef, MetaRefExpr(type_key.str(), index)); + } + inline Token TokenizeOnce() { auto next = Peek(); if (next == '\n') { @@ -298,6 +341,35 @@ struct Tokenizer { auto token = NewToken(TokenType::Question); Next(); return token; + } else if (MatchString("meta")) { + return TokenizeMetaRef(); + } else if (next == '#') { + Next(); + int line = this->line; + int column = this->col; + if (Peek() == '[') { + Next(); + std::stringstream attribute; + while (More() && Peek() != ']') { + attribute << Next(); + } + CHECK_EQ(Next(), ']'); + // Metadata can only appear at the bottom of a file and goes to EOF. 
+ if (attribute.str() == "metadata") { + std::stringstream metadata; + while (More()) { + metadata << Next(); + } + ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); + return Token(line, column, TokenType::Metadata, metadata_map); + } else { + LOG(FATAL) << "unsupported " << attribute.str(); + return Token(); + } + } else { + LOG(FATAL) << "lex error"; + return Token(); + } } else if (next == '%') { auto token = NewToken(TokenType::Percent); Next(); @@ -461,6 +533,7 @@ std::vector Tokenize(std::string source) { auto tokens = Condense(tokenizer.tokens); for (auto token : tokens) { CHECK(token.defined()); + std::cout << token << std::endl; } return tokens; } diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 2993d38234ea..b887d174618f 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -33,16 +33,6 @@ namespace tvm { static const char* kSemVer = "v0.0.4"; -// TODO(tvm-team): split into files, related: arith/analyzer.h -// -// - text_printer.h (common header) -// - text_printer.cc (prints modules dispatch into relay and tir files) -// - type_text_printer.cc(specific printing logics for types, -// can also consider put under type_text_printer) -// - Implements AsText -// - relay_text_printer.cc (specific printing logics for relay) -// - tir_text_printer.cc (specific printing logics for TIR) - Doc TextPrinter::PrintMod(const IRModule& mod) { Doc doc; int counter = 0; diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7baa3878ff72..867d48b0ef7b 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -385,10 +385,12 @@ class TextPrinter { if (!meta_.empty()) { doc << Doc::NewLine(); if (show_meta_data_) { - // append meta data in the end. - doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); + doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); } else { - doc << "// meta data omitted. 
you can use show_meta_data=True to include meta data"; + doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() + << " * If you would like to see the full metadata section you can set the `show_meta_data`" << Doc::NewLine() + << " * option to `True` when invoking `astext`. " << Doc::NewLine() + << " */"; } } return doc; diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index aa785d2f61f2..d2cfd6194ebc 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ -230,6 +230,10 @@ def test_vars(): assert isinstance(op, tvm.ir.Op) assert op.name == "nn.global_avg_pool2d" +def test_meta_ref(): + # var = parse_text("meta[type_key][index]") + var = parse_text("meta[type_key][0]") + import pdb; pdb.set_trace() def test_let(): assert_parses_as( @@ -905,8 +909,32 @@ def test_resnet(): parsed_mod = parse_module(text) tvm.ir.assert_structural_equal(mod, parsed_mod) +def inline_params(mod, params): + main_fn = mod["main"] + str_to_var = {} + for param in main_fn.params: + str_to_var[param.name_hint] = param + + bind_map = {} + for param in params: + bind_map[str_to_var[param]] = relay.const(params[param]) + + body = relay.bind(main_fn.body, bind_map) + main_fn = relay.Function(relay.analysis.free_vars(body), body) + mod["main_fn"] = main_fn + return mod + +def test_resnet_inlined_params(): + mod, params = relay.testing.resnet.get_workload() + mod = inline_params(mod, params) + text = str(mod.astext()) + parsed_mod = parse_module(text) + import pdb; pdb.set_trace() + tvm.ir.assert_structural_equal(mod, parsed_mod) + if __name__ == "__main__": # import sys # pytest.main(sys.argv) # test_hierarchical_identifiers() - test_resnet() + # test_resnet_inlined_params() + test_meta_ref() From 71dc90c208325e35ad842235d51e6fecffa9ae9f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Jul 2020 14:39:37 -0700 Subject: [PATCH 04/48] WIP --- 
include/tvm/relay/expr_functor.h | 2 ++ src/parser/meta_ref.h | 8 -------- src/parser/parser.cc | 13 ++++++++++++- src/printer/relay_text_printer.cc | 7 ++++++- tests/python/relay/test_ir_parser2.py | 4 +++- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1189643c8181..a175fdf4cd4a 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -107,6 +107,7 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TempExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; @@ -132,6 +133,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); + RELAY_EXPR_FUNCTOR_DISPATCH(TempExprNode); return vtable; } }; diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index b8d118195bfa..0ca8cad6eb59 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,7 +26,6 @@ #define TVM_PARSER_META_REF_H_ #include - #include namespace tvm { @@ -81,13 +80,6 @@ class MetaRefExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode); }; -MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { - auto rnode = make_object(); - rnode->type_key = type_key; - rnode->node_index = node_index; - data_ = std::move(rnode); -} - } // namespace parser } // namespace tvm diff --git a/src/parser/parser.cc b/src/parser/parser.cc index bc6a7704dde6..96e871ed926b 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -42,6 +42,13 @@ namespace parser { using 
namespace relay; using Expr = relay::Expr; +MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { + auto rnode = make_object(); + rnode->type_key = type_key; + rnode->node_index = node_index; + data_ = std::move(rnode); +} + /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * @@ -526,7 +533,11 @@ class Parser { } return elements; } else { - LOG(FATAL) << "issue"; + auto next = Peek(); + std::stringstream msg; + msg << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type); + diag_ctx.Emit({next->line, next->column, msg.str()}); + diag_ctx.Render(std::cout); return Array(nullptr); } } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ee11548edf29..632ded4eecb2 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -40,6 +40,7 @@ #include "../ir/attr_functor.h" #include "../relay/analysis/dependency_graph.h" +#include "../parser/meta_ref.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -246,6 +247,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { // determine whether to inline bool inline_expr = AlwaysInline(expr); + if (try_inline) { inline_expr |= IsUnique(expr); } @@ -254,7 +256,10 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { if (it != memo_.end()) return it->second; Doc printed_expr; - if (meta) { + + if (auto meta_ref = expr.as()) { + printed_expr << "meta[" << meta_ref->type_key << "]" << "[" << meta_ref->node_index << "]"; + } else if (meta) { printed_expr = meta_->GetMetaNode(GetRef(expr.get())); } else if (!inline_expr && expr.as()) { // wrap GNFed let in brackets diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index d2cfd6194ebc..28b5660424d2 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ 
-75,7 +75,9 @@ def graph_equal(lhs, rhs): return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) def roundtrip_expr(expr): - x = tvm.parser.parse_expr(str(str(expr))) + text = tvm.relay.Expr.astext(expr, show_meta_data=False) + import pdb; pdb.set_trace() + x = tvm.parser.parse_expr(str(text)) assert_graph_equal(x, expr) def roundtrip(expr): From b940e8b73e5cdd9cb11d1158e00ed55de6536f8b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jul 2020 14:09:24 -0700 Subject: [PATCH 05/48] Change meta reference to an operator --- include/tvm/relay/expr_functor.h | 2 - src/parser/meta_ref.cc | 75 +++++++++++++++++++++++++++++++ src/parser/meta_ref.h | 50 +++++++++------------ src/parser/parser.cc | 11 +---- src/parser/token.h | 8 ++-- src/parser/tokenizer.h | 3 +- src/printer/relay_text_printer.cc | 4 +- 7 files changed, 106 insertions(+), 47 deletions(-) create mode 100644 src/parser/meta_ref.cc diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index a175fdf4cd4a..1189643c8181 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -107,7 +107,6 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const TempExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) 
{ LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; @@ -133,7 +132,6 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); - RELAY_EXPR_FUNCTOR_DISPATCH(TempExprNode); return vtable; } }; diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc new file mode 100644 index 000000000000..e98ecada72fd --- /dev/null +++ b/src/parser/meta_ref.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/parser/meta_ref.cc + * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference. 
+ */ + +#include +#include +#include + +#include "./meta_ref.h" + +namespace tvm { +namespace parser { + +TVM_REGISTER_NODE_TYPE(MetaRefAttrs); + +bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + LOG(FATAL) << "need to expand before type checking"; +// CHECK_EQ(types.size(), 3u); +// auto size_type = types[0]; +// auto tensor_type = size_type.as(); +// CHECK(tensor_type != nullptr); +// CHECK_EQ(tensor_type->dtype, DataType::Int(64)); +// CHECK_EQ(tensor_type->shape.size(), 0); +// auto align_type = types[1]; +// auto align_ttype = align_type.as(); +// CHECK(align_ttype != nullptr); +// CHECK_EQ(align_ttype->dtype, DataType::Int(64)); +// CHECK_EQ(align_ttype->shape.size(), 0); +// auto mod = reporter->GetModule(); +// CHECK(mod.defined()); +// auto storage_name = mod->GetGlobalTypeVar("Storage"); +// auto storage = TypeCall(storage_name, {}); +// reporter->Assign(types[2], storage); +// return true; +} + +RELAY_REGISTER_OP("parser.MetaRef") + .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE) + .set_num_inputs(0) + .set_support_level(10) + .add_type_rel("MetaRef", MetaRefRel) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true); + +Expr MetaRef(std::string type_key, uint64_t node_index) { + static const Op& op = Op::Get("parser.MetaRef"); + auto attrs = make_object(); + attrs->type_key = type_key; + attrs->node_index = node_index; + return Call(op, {}, Attrs(attrs), {}); +} + +} // namespace parser +} // namespace tvm diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 0ca8cad6eb59..2b4bf647a1e6 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -25,7 +25,10 @@ #ifndef TVM_PARSER_META_REF_H_ #define TVM_PARSER_META_REF_H_ + +#include #include + #include namespace tvm { @@ -33,6 +36,20 @@ namespace parser { using namespace relay; +/*! + * \brief Options for allocating storage. 
+ */ +struct MetaRefAttrs : public tvm::AttrsNode { + std::string type_key; + uint64_t node_index; + + TVM_DECLARE_ATTRS(MetaRefAttrs, "relay.attrs.MetaRefAttrs") { + TVM_ATTR_FIELD(type_key) + .describe("The type_key representing the type of the node referenced."); + TVM_ATTR_FIELD(node_index).describe("The index into the type specific node array."); + } +}; + /*! \brief A reference to a "meta-expression". * * In the text format we allow referencing metadata which @@ -50,35 +67,12 @@ using namespace relay; * For example the nth large constant will be pretty-printed as meta[relay.Constant][n] * with its compact binary serialization residing in the metadata section at the end * of the program. + * + * \param type_key The type key of the object in the meta section. + * \param kind The index into that subfield. + * \returns The meta table reference. */ -class MetaRefExprNode : public TempExprNode { - public: - /*! \brief The type key of the meta expression. */ - std::string type_key; - /*! \brief The index into the type key's table. */ - uint64_t node_index; - - void VisitAttrs(tvm::AttrVisitor* v) {} - - // TODO(@jroesch): we probably will need to manually - // expand these with a pass. - Expr Realize() const final { return Expr(); } - - static constexpr const char* _type_key = "relay.MetaRefExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode); -}; - -class MetaRefExpr : public TempExpr { - public: - /*! - * \brief The constructor for MetaRefExpr - * \param type_key The type key of the object in the meta section. - * \param kind The index into that subfield. 
- */ - TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index); - - TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode); -}; +Expr MetaRef(std::string type_key, uint64_t node_index); } // namespace parser } // namespace tvm diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 96e871ed926b..8a157f325dcd 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -42,13 +42,6 @@ namespace parser { using namespace relay; using Expr = relay::Expr; -MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) { - auto rnode = make_object(); - rnode->type_key = type_key; - rnode->node_index = node_index; - data_ = std::move(rnode); -} - /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * @@ -1234,8 +1227,8 @@ class Parser { Consume(TokenType::Graph); return LookupGraphBinding(next); } - case TokenType::MetaRef: { - Consume(TokenType::MetaRef); + case TokenType::MetaReference: { + Consume(TokenType::MetaReference); return Downcast(next->data); } case TokenType::Fn: { diff --git a/src/parser/token.h b/src/parser/token.h index 0d38e09213d4..8222389fa7a6 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -86,7 +86,7 @@ enum TokenType { Match, PartialMatch, Metadata, - MetaRef, + MetaReference, Unknown, EndOfFile, Null, @@ -190,8 +190,8 @@ std::string ToString(const TokenType& token_type) { return "Boolean"; case TokenType::Metadata: return "Metadata"; - case TokenType::MetaRef: - return "MetaRef"; + case TokenType::MetaReference: + return "MetaReference"; case TokenType::Unknown: return "Unknown"; case TokenType::EndOfFile: @@ -297,7 +297,7 @@ std::string Pretty(const TokenType& token_type) { return "boolean"; case TokenType::Metadata: return "metadata section"; - case TokenType::MetaRef: + case TokenType::MetaReference: return "`meta`"; case TokenType::Match: return "`match`"; diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 
9333ed357de0..5994b8697f68 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -34,6 +34,7 @@ #include #include "./token.h" +#include "./meta_ref.h" namespace tvm { namespace parser { @@ -211,7 +212,7 @@ struct Tokenizer { std::cout << "NUmber: " << str_index.str() << std::endl; // todo: add error handling around bad indices auto index = ParseNumber(true, false, str_index.str()).ToNumber(); - return Token(line, column, TokenType::MetaRef, MetaRefExpr(type_key.str(), index)); + return Token(line, column, TokenType::MetaReference, MetaRef(type_key.str(), index)); } inline Token TokenizeOnce() { diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 632ded4eecb2..70c4453bfff5 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -257,9 +257,7 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { Doc printed_expr; - if (auto meta_ref = expr.as()) { - printed_expr << "meta[" << meta_ref->type_key << "]" << "[" << meta_ref->node_index << "]"; - } else if (meta) { + if (meta) { printed_expr = meta_->GetMetaNode(GetRef(expr.get())); } else if (!inline_expr && expr.as()) { // wrap GNFed let in brackets From 089b34295483c3eb2c28a514cadec1379224dfcf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jul 2020 18:01:57 -0700 Subject: [PATCH 06/48] Meta references now work --- include/tvm/ir/attrs.h | 3 +- src/parser/meta_ref.cc | 21 +---- src/parser/meta_ref.h | 4 +- src/parser/parser.cc | 65 ++++++++++----- src/parser/token.h | 5 ++ src/parser/tokenizer.h | 109 ++++++++++++++++++-------- src/printer/relay_text_printer.cc | 2 + src/printer/text_printer.cc | 4 +- src/runtime/graph/graph_runtime.h | 1 + tests/python/relay/test_ir_parser2.py | 12 +-- 10 files changed, 145 insertions(+), 81 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 749274acbb96..7981d58b0ead 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ 
-353,7 +353,8 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } } diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index e98ecada72fd..8d65a6f5248f 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -36,27 +36,12 @@ TVM_REGISTER_NODE_TYPE(MetaRefAttrs); bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { LOG(FATAL) << "need to expand before type checking"; -// CHECK_EQ(types.size(), 3u); -// auto size_type = types[0]; -// auto tensor_type = size_type.as(); -// CHECK(tensor_type != nullptr); -// CHECK_EQ(tensor_type->dtype, DataType::Int(64)); -// CHECK_EQ(tensor_type->shape.size(), 0); -// auto align_type = types[1]; -// auto align_ttype = align_type.as(); -// CHECK(align_ttype != nullptr); -// CHECK_EQ(align_ttype->dtype, DataType::Int(64)); -// CHECK_EQ(align_ttype->shape.size(), 0); -// auto mod = reporter->GetModule(); -// CHECK(mod.defined()); -// auto storage_name = mod->GetGlobalTypeVar("Storage"); -// auto storage = TypeCall(storage_name, {}); -// reporter->Assign(types[2], storage); -// return true; + return true; } RELAY_REGISTER_OP("parser.MetaRef") .describe(R"code(A reference into the meta table.)code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(0) .set_support_level(10) .add_type_rel("MetaRef", MetaRefRel) @@ -66,7 +51,7 @@ RELAY_REGISTER_OP("parser.MetaRef") Expr MetaRef(std::string type_key, uint64_t node_index) { static const Op& op = Op::Get("parser.MetaRef"); auto attrs = make_object(); - attrs->type_key = type_key; + attrs->node_type_key = tvm::String(type_key); attrs->node_index = node_index; return 
Call(op, {}, Attrs(attrs), {}); } diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 2b4bf647a1e6..6a273ff9d3d4 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -40,11 +40,11 @@ using namespace relay; * \brief Options for allocating storage. */ struct MetaRefAttrs : public tvm::AttrsNode { - std::string type_key; + tvm::String node_type_key; uint64_t node_index; TVM_DECLARE_ATTRS(MetaRefAttrs, "relay.attrs.MetaRefAttrs") { - TVM_ATTR_FIELD(type_key) + TVM_ATTR_FIELD(node_type_key) .describe("The type_key representing the type of the node referenced."); TVM_ATTR_FIELD(node_index).describe("The index into the type specific node array."); } diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 8a157f325dcd..c0b28cac83db 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -490,7 +490,20 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { + DLOG(INFO) << "Parser::ParseSequence: start=" << start << "sep=" << sep << "stop=" << stop; Match(start); + + // This is for the empty arguments list case, if we have token stream + // we must parse leftovers, then match a stop token. + if (before_stop) { + auto did_parse = before_stop(); + if (did_parse) { + Match(stop); + return {}; + } + } + + // This is the case in which we find an empty arguments lists and no leftovers. if (WhenMatch(stop)) { return Array(); } else { @@ -562,24 +575,23 @@ class Parser { } /*! \brief Parse the semantic versioning header. */ - SemVer ParseSemVer() { - // TODO(@jroesch): convert semver to module level attribute. 
- auto id = Peek(); - if (id->token_type == TokenType::Identifier && id.ToString() == "v0") { - auto id = Match(TokenType::Identifier); - Consume(TokenType::Period); - Consume(TokenType::Float); + SemVer ParseSemVer(bool required=true) { + if (Peek()->token_type == TokenType::Version) { + auto version = Match(TokenType::Version); + // TODO(@jroesch): we currently only support 0.0.5. + if (version.ToString() != "\"0.0.5\"") { + std::stringstream msg; + msg << "invalid semantic version `"; + msg << version.ToString() << "`"; + this->diag_ctx.Emit({version->line, version->column, msg.str() }); + } + } else if (required) { + std::stringstream msg; + msg << "expected text format semantic version "; + msg << "you can annotate it as #[version = \"0.0.5\"]"; + this->diag_ctx.Emit({Peek()->line, Peek()->column, msg.str() }); } - // TODO(@jroesch): the current lexing makes it hard to parse this - // in a way that doesnt feel like a hack. - // - // We should move to module level attributes instead - // so we can tag modules with top-level data. - // - // #[text_version = "0.0.4"] - // - // For now we only support current version. - return SemVer(0, 0, 4); + return SemVer(0, 0, 5); } /*! \brief Parse zero or more Relay definitions. */ @@ -701,10 +713,12 @@ class Parser { /*! \brief Parse a single Relay expression. */ Expr ParseExpr() { + DLOG(INFO) << "Parser::ParseExpr"; return ConsumeWhitespace([this] { std::vector exprs; while (true) { + DLOG(INFO) << "Parser::ParseExpr: parsing a single expression"; auto next = Peek(); switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr @@ -799,6 +813,7 @@ class Parser { // This ensures for n sequential bindings // the call depth will be the same before // and after parsing the n bindings. + DLOG(INFO) << "Parser::ParseBindingExpr"; std::vector> bindings; int scopes = 0; @@ -869,6 +884,7 @@ class Parser { * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }. 
*/ Function ParseFunctionDef() { + DLOG(INFO) << "Parser::ParseFunctionDef"; PushScope(); PushTypeScope(); @@ -910,6 +926,7 @@ class Parser { /*! \brief Parse an if-expression. */ Expr ParseIf() { + DLOG(INFO) << "Parser::ParseIf"; Consume(TokenType::If); auto guard = Parens([&] { return ParseExpr(); }); @@ -936,6 +953,7 @@ class Parser { * This function recursively parses a pattern. */ Pattern ParsePattern() { + DLOG(INFO) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { case TokenType::Underscore: { @@ -987,6 +1005,7 @@ class Parser { } Expr ParseExprBinOp() { + DLOG(INFO) << "Parser::ParseExprBinOp"; return ConsumeWhitespace([this] { // We must parse at least one expression, the default // case is that there is no operator and we will fall @@ -1060,11 +1079,13 @@ class Parser { } ObjectRef ParseAttributeValue() { + DLOG(INFO) << "Parser::ParseAttributeValue"; auto next = Peek(); switch (next->token_type) { case TokenType::Float: case TokenType::Integer: case TokenType::Boolean: + case TokenType::StringLiteral: return Match(next->token_type)->data; case TokenType::LSquare: { return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { @@ -1077,6 +1098,7 @@ class Parser { } Map ParseAttrs() { + DLOG(INFO) << "Parser::ParseAttrs"; Map kwargs; while (Peek()->token_type == TokenType::Identifier) { auto key = Match(TokenType::Identifier).ToString(); @@ -1086,10 +1108,12 @@ class Parser { kwargs.Set(key, value); WhenMatch(TokenType::Comma); } + DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs; return kwargs; } Expr ParseCallArgs(Expr op) { + DLOG(INFO) << "Parser::ParseCallArgs"; Map raw_attrs; std::string op_key; bool is_op = false; @@ -1118,6 +1142,7 @@ class Parser { Attrs attrs; if (is_op && op_key.size()) { + // raw_attrs.Set("type_key", tvm::String("hello")); auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); CHECK(attr_obj.defined()); attrs = Downcast(attr_obj); @@ -1130,13 
+1155,14 @@ class Parser { } Expr ParseCallExpr() { + DLOG(INFO) << "Parser::ParseCallExpr"; return ConsumeWhitespace([this] { Expr expr = ParseAtomicExpr(); // Parse as many call args as possible, building up expression // // NB(@jroesch): this seems like a hack but in order to parse curried functions // and avoid complex grammar we will parse multiple call lists in a row. - while (true) { + while (Peek()->token_type == TokenType::OpenParen) { auto new_expr = ParseCallArgs(expr); if (new_expr.defined()) { expr = new_expr; @@ -1166,6 +1192,7 @@ class Parser { } Expr ParseAtomicExpr() { + DLOG(INFO) << "Parser::ParseAtomicExpr"; auto expr = ConsumeWhitespace([this] { auto next = Peek(); switch (next->token_type) { @@ -1442,8 +1469,10 @@ IRModule ParseModule(std::string file_name, std::string file_content) { } Expr ParseExpr(std::string file_name, std::string file_content) { + DLOG(INFO) << "ParseExpr"; auto tokens = Tokenize(file_content); Parser parser(tokens, DefaultOpTable(), Source(file_content)); + parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); parser.Match(TokenType::EndOfFile); diff --git a/src/parser/token.h b/src/parser/token.h index 8222389fa7a6..60a936852b89 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -87,6 +87,7 @@ enum TokenType { PartialMatch, Metadata, MetaReference, + Version, Unknown, EndOfFile, Null, @@ -192,6 +193,8 @@ std::string ToString(const TokenType& token_type) { return "Metadata"; case TokenType::MetaReference: return "MetaReference"; + case TokenType::Version: + return "Version"; case TokenType::Unknown: return "Unknown"; case TokenType::EndOfFile: @@ -311,6 +314,8 @@ std::string Pretty(const TokenType& token_type) { return "end of file"; case TokenType::Null: return "null"; + case TokenType::Version: + return "version attribute"; // Older compilers warn even though the above code is exhaustive. 
default: LOG(FATAL) << "unreachable code"; diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 5994b8697f68..5d55ba2bb879 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -41,6 +41,20 @@ namespace parser { using namespace runtime; +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { + return !std::isspace(ch); + })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { + return !std::isspace(ch); + }).base(), s.end()); +} + bool IsDigit(char c) { return '0' <= c && c <= '9'; } bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } @@ -106,7 +120,10 @@ struct Tokenizer { CommentParserState state = CommentParserState::Proceed; int nesting = 1; - while (true) { + while (More()) { + std::cout << "In comment state machine" << std::endl; + std::cout << "Buffer: " << *buffer << std::endl; + std::cout << "State: " << state << std::endl; switch (state) { case CommentParserState::Proceed: { if (Peek() == '/') { @@ -132,11 +149,11 @@ struct Tokenizer { Next(); buffer->pop_back(); return; - } else { - buffer->operator+=(Next()); - state = CommentParserState::Proceed; } } + + buffer->operator+=(Next()); + state = CommentParserState::Proceed; continue; } } @@ -215,8 +232,51 @@ struct Tokenizer { return Token(line, column, TokenType::MetaReference, MetaRef(type_key.str(), index)); } + Token TokenizeAttr() { + int line = this->line; + int column = this->col; + Next(); + if (Peek() == '[') { + Next(); + std::stringstream raw_attribute; + + while (More() && Peek() != ']') { + raw_attribute << Next(); + } + + CHECK_EQ(Next(), ']'); + + auto attribute = raw_attribute.str(); + // Clean up the white-space on both sides. + ltrim(attribute); + rtrim(attribute); + + // Metadata can only appear at the bottom of a file and goes to EOF. 
+ if (attribute == "metadata") { + std::stringstream metadata; + while (More()) { + metadata << Next(); + } + ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); + return Token(line, column, TokenType::Metadata, metadata_map); + } if (attribute.rfind("version", 0) == 0) { + std::string version = attribute.substr(attribute.find("=") + 1); + ltrim(version); + rtrim(version); + return Token(line, column, TokenType::Version, tvm::String(version)); + } else { + LOG(FATAL) << "unsupported " << attribute; + return Token(); + } + } else { + LOG(FATAL) << "lex error"; + return Token(); + } + } + inline Token TokenizeOnce() { auto next = Peek(); + DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { auto token = NewToken(TokenType::Newline); Next(); @@ -228,12 +288,20 @@ struct Tokenizer { return token; } else { // TODO(@jroesch): have lexer use diagnostic context too. + // see https://github.com/apache/incubator-tvm/issues/6153. LOG(FATAL) << "lexer error"; return Token(); } } else if (next == '"') { - LOG(FATAL) << "string not working yet"; - return NewToken(TokenType::Unknown); + // TODO(@jroesch): Properly tokenize escape sequences in strings. + // see https://github.com/apache/incubator-tvm/issues/6153. + Next(); + std::stringstream string_content; + while (More() && Peek() != '"') { + string_content << Next(); + } + Next(); + return NewToken(TokenType::StringLiteral, tvm::String(string_content.str())); } else if (IsWhitespace(next)) { auto token = NewToken(TokenType::Whitespace); Next(); @@ -345,32 +413,7 @@ struct Tokenizer { } else if (MatchString("meta")) { return TokenizeMetaRef(); } else if (next == '#') { - Next(); - int line = this->line; - int column = this->col; - if (Peek() == '[') { - Next(); - std::stringstream attribute; - while (More() && Peek() != ']') { - attribute << Next(); - } - CHECK_EQ(Next(), ']'); - // Metadata can only appear at the bottom of a file and goes to EOF. 
- if (attribute.str() == "metadata") { - std::stringstream metadata; - while (More()) { - metadata << Next(); - } - ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); - return Token(line, column, TokenType::Metadata, metadata_map); - } else { - LOG(FATAL) << "unsupported " << attribute.str(); - return Token(); - } - } else { - LOG(FATAL) << "lex error"; - return Token(); - } + return TokenizeAttr(); } else if (next == '%') { auto token = NewToken(TokenType::Percent); Next(); @@ -451,6 +494,7 @@ struct Tokenizer { } void Tokenize() { + DLOG(INFO) << "tvm::parser::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); CHECK(token.defined()); @@ -531,6 +575,7 @@ std::vector Condense(const std::vector& tokens) { std::vector Tokenize(std::string source) { auto tokenizer = Tokenizer(source); tokenizer.Tokenize(); + std::cout << "Done tokenization" << std::endl; auto tokens = Condense(tokenizer.tokens); for (auto token : tokens) { CHECK(token.defined()); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 70c4453bfff5..25f027c3aa5c 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -724,6 +724,8 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { Doc printed_attr; if (value.as()) { printed_attr << "?"; + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(GetRef(str_obj)); } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index b887d174618f..302c68c20c2a 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -31,7 +31,7 @@ namespace tvm { -static const char* kSemVer = "v0.0.4"; +static const char* kSemVer = "0.0.5"; Doc TextPrinter::PrintMod(const IRModule& mod) { Doc doc; @@ -74,7 +74,7 @@ String PrettyPrint(const ObjectRef& node) { String AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc 
annotate) { Doc doc; - doc << kSemVer << Doc::NewLine(); + doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine(); runtime::TypedPackedFunc ftyped = nullptr; if (annotate != nullptr) { ftyped = runtime::TypedPackedFunc( diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index dcef1e4071be..617ff3e25662 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -216,6 +216,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector inputs; // control deps std::vector control_deps; + // JSON Loader void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { int bitmask = 0; diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index 28b5660424d2..c3038771fa57 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ -76,7 +76,6 @@ def graph_equal(lhs, rhs): def roundtrip_expr(expr): text = tvm.relay.Expr.astext(expr, show_meta_data=False) - import pdb; pdb.set_trace() x = tvm.parser.parse_expr(str(text)) assert_graph_equal(x, expr) @@ -233,9 +232,10 @@ def test_vars(): assert op.name == "nn.global_avg_pool2d" def test_meta_ref(): - # var = parse_text("meta[type_key][index]") - var = parse_text("meta[type_key][0]") - import pdb; pdb.set_trace() + meta_op = parse_text("meta[type_key][1337]") + assert meta_op.attrs.node_type_key == "type_key" + assert meta_op.attrs.node_index == "1337" + def test_let(): assert_parses_as( @@ -935,8 +935,4 @@ def test_resnet_inlined_params(): tvm.ir.assert_structural_equal(mod, parsed_mod) if __name__ == "__main__": - # import sys - # pytest.main(sys.argv) - # test_hierarchical_identifiers() - # test_resnet_inlined_params() test_meta_ref() From a4cfb3f065f184b9deaed6d197ab587cf3288144 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jul 2020 20:34:53 -0700 Subject: [PATCH 07/48] MetaReference expansion now works --- src/parser/meta_ref.cc | 45 +++++++++++++++++++++++++++ 
src/parser/meta_ref.h | 5 +++ src/parser/parser.cc | 12 +++++-- src/parser/tokenizer.h | 6 ---- tests/python/relay/test_ir_parser2.py | 8 +++-- 5 files changed, 66 insertions(+), 10 deletions(-) diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index 8d65a6f5248f..7c47d8635864 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -23,6 +23,8 @@ */ #include +#include +#include #include #include @@ -31,6 +33,9 @@ namespace tvm { namespace parser { +using tvm::transform::PassContext; +using tvm::relay::transform::CreateFunctionPass; + TVM_REGISTER_NODE_TYPE(MetaRefAttrs); bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -56,5 +61,45 @@ Expr MetaRef(std::string type_key, uint64_t node_index) { return Call(op, {}, Attrs(attrs), {}); } + +// class MetaRefAttrExpander : AttrFunctor { +// ObjectRef VisitAttrDefault_(const Object* node) final { + +// } +// } + +struct MetaRefExpander : public ExprMutator { + MetaTable table; + + MetaRefExpander(const MetaTable& table) : table(table) {} + + Expr VisitExpr_(const CallNode* call) final { + if (auto op_node = call->op.as()) { + if (op_node->name == "parser.MetaRef") { + auto meta_attrs = call->attrs.as(); + CHECK(meta_attrs) << "an internal error has occurred"; + auto nodes = table.at(meta_attrs->node_type_key); + CHECK_LT(meta_attrs->node_index, nodes.size()); + return Downcast(nodes[meta_attrs->node_index]); + } + } + + return ExprMutator::VisitExpr_(call); + } +}; + +Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) { + MetaRefExpander expander(meta_table); + return Downcast(expander.VisitExpr(func)); +} + +IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { + auto pass = CreateFunctionPass([&](Function func, IRModule module, PassContext ctx) { + return ExpandMetaRefs(meta_table, func); + }, 1337, "ExpandMetaRefs", {}); + + return pass(mod, PassContext::Create()); +} + } // namespace parser } // namespace tvm diff --git 
a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 6a273ff9d3d4..c3bccce355e5 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -36,6 +36,8 @@ namespace parser { using namespace relay; +using MetaTable = Map>; + /*! * \brief Options for allocating storage. */ @@ -74,6 +76,9 @@ struct MetaRefAttrs : public tvm::AttrsNode { */ Expr MetaRef(std::string type_key, uint64_t node_index); +Function ExpandMetaRefs(const MetaTable& meta_table, const Function& mod); +IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); + } // namespace parser } // namespace tvm diff --git a/src/parser/parser.cc b/src/parser/parser.cc index c0b28cac83db..ce7f3fb408c4 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -557,6 +557,7 @@ class Parser { auto defs = ParseDefinitions(); // Parse the metadata section at the end. auto metadata = ParseMetadata(); + Match(TokenType::EndOfFile); Map funcs; Map types; @@ -568,7 +569,8 @@ class Parser { auto mod = IRModule({}, types); for (auto func : defs.funcs) { - mod->Add(func.global, func.function); + auto function = ExpandMetaRefs(metadata.value(), func.function); + mod->Add(func.global, function); } return mod; @@ -1434,7 +1436,13 @@ class Parser { } // TODO(@jroesch): this is the final remaining feature. - ObjectRef ParseMetadata() { return ObjectRef(); } + Optional>> ParseMetadata() { + if (Peek()->token_type == TokenType::Metadata) { + return Match(TokenType::Metadata).ToMetadata(); + } else { + return Optional>>(); + } + } /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. 
*/ void DisplayNextN(int n) { diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 5d55ba2bb879..3e3049c05f0b 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -121,9 +121,6 @@ struct Tokenizer { int nesting = 1; while (More()) { - std::cout << "In comment state machine" << std::endl; - std::cout << "Buffer: " << *buffer << std::endl; - std::cout << "State: " << state << std::endl; switch (state) { case CommentParserState::Proceed: { if (Peek() == '/') { @@ -226,7 +223,6 @@ struct Tokenizer { } CHECK_EQ(Peek(), ']'); Next(); - std::cout << "NUmber: " << str_index.str() << std::endl; // todo: add error handling around bad indices auto index = ParseNumber(true, false, str_index.str()).ToNumber(); return Token(line, column, TokenType::MetaReference, MetaRef(type_key.str(), index)); @@ -575,11 +571,9 @@ std::vector Condense(const std::vector& tokens) { std::vector Tokenize(std::string source) { auto tokenizer = Tokenizer(source); tokenizer.Tokenize(); - std::cout << "Done tokenization" << std::endl; auto tokens = Condense(tokenizer.tokens); for (auto token : tokens) { CHECK(token.defined()); - std::cout << token << std::endl; } return tokens; } diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py index c3038771fa57..7ffe85c5d5df 100644 --- a/tests/python/relay/test_ir_parser2.py +++ b/tests/python/relay/test_ir_parser2.py @@ -928,11 +928,15 @@ def inline_params(mod, params): def test_resnet_inlined_params(): mod, params = relay.testing.resnet.get_workload() + print("here") mod = inline_params(mod, params) + print("here") text = str(mod.astext()) + print("here") parsed_mod = parse_module(text) - import pdb; pdb.set_trace() + print("here") tvm.ir.assert_structural_equal(mod, parsed_mod) + print("here") if __name__ == "__main__": - test_meta_ref() + test_resnet_inlined_params() From 0cb6c859ee01dd4db8679f5c2b05e74427ebe3c1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jul 2020 21:03:21 -0700 
Subject: [PATCH 08/48] Start working on source map and move diagnostic context --- include/tvm/parser/source_map.h | 73 +++++++++++++++++++++++++++++ python/tvm/parser/__init__.py | 1 + src/parser/parser.cc | 20 ++++---- src/parser/source_map.cc | 83 +++++++++++++++++++++++++++++++++ to_json.py | 59 +++++++++++++++++++++++ 5 files changed, 226 insertions(+), 10 deletions(-) create mode 100644 include/tvm/parser/source_map.h create mode 100644 src/parser/source_map.cc create mode 100644 to_json.py diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h new file mode 100644 index 000000000000..41b0d385b939 --- /dev/null +++ b/include/tvm/parser/source_map.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_PARSER_SOURCE_MAP_H_ +#define TVM_PARSER_SOURCE_MAP_H_ +/*! + * \file source_map.h + * \brief A map from source names to source code. + */ +#include +#include +#include + +#include +#include + +namespace tvm { +namespace parser { + +/*! + * \brief A mapping from a unique source name to source fragment. + */ +class SourceMap; +/*! + * \brief Stores locations in frontend source that generated a node. 
+ */ +class SourceMapNode : public Object { + public: + /*! \brief The source mapping. */ + Map source_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("source_map", &source_map); + } + + bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + return equal(source_map, other->source_map); + } + + static constexpr const char* _type_key = "Span"; + TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); +}; + +class SourceMap : public ObjectRef { + public: + TVM_DLL SourceMap(Map source_map); + + TVM_DLL static SourceMap* Get(); + + TVM_DEFINE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); +}; + +} // namespace parser +} // namespace tvm + +#endif // TVM_PARSER_SOURCE_MAP_H_ diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 071c464dae51..696af362e03f 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -24,4 +24,5 @@ def parse_expr(source): return _ffi_api.ParseExpr("string", source) def fromtext(source, source_name="from_string"): + # TODO(@tqchen): currently we have to invoke `str` which dramatically reduces performance. return parse(str(source), str(source_name)) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ce7f3fb408c4..69f7e527ec8c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -211,7 +211,7 @@ class Parser { SemVer version; /*! \brief The diagnostic context used for error reporting. */ - DiagnosticContext diag_ctx; + DiagnosticContext* diag_ctx; /*! \brief The current position in the token stream. */ int pos; @@ -243,8 +243,8 @@ class Parser { /*! \brief The set of expression scopes used for lexical scope. 
*/ ScopeStack expr_scopes; - Parser(std::vector tokens, OperatorTable op_table, Source source) - : diag_ctx(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} + Parser(DiagnosticContext* ctx, std::vector tokens, OperatorTable op_table, Source source) + : diag_ctx(ctx), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ @@ -294,8 +294,8 @@ class Parser { if (tokens[pos]->token_type != token_type) { std::string message = "expected a " + Pretty(token_type) + " found " + Pretty(Peek()->token_type); - this->diag_ctx.Emit({tokens[pos]->line, tokens[pos]->column, message}); - this->diag_ctx.Render(std::cout); + this->diag_ctx->Emit({tokens[pos]->line, tokens[pos]->column, message}); + this->diag_ctx->Render(std::cout); } pos++; } @@ -374,7 +374,7 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx.Emit( + diag_ctx->Emit( {local->line, local->column, "this local variable has not been previously declared"}); } return var; @@ -387,7 +387,7 @@ class Parser { TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { - diag_ctx.Emit( + diag_ctx->Emit( {ident->line, ident->column, "this type variable has not been previously declared anywhere, perhaps a typo?"}); } @@ -585,13 +585,13 @@ class Parser { std::stringstream msg; msg << "invalid semantic version `"; msg << version.ToString() << "`"; - this->diag_ctx.Emit({version->line, version->column, msg.str() }); + this->diag_ctx->Emit({version->line, version->column, msg.str() }); } } else if (required) { std::stringstream msg; msg << "expected text format semantic version "; msg << "you can annotate it as #[version = \"0.0.5\"]"; - this->diag_ctx.Emit({Peek()->line, 
Peek()->column, msg.str() }); + this->diag_ctx->Emit({Peek()->line, Peek()->column, msg.str() }); } return SemVer(0, 0, 5); } @@ -620,7 +620,7 @@ class Parser { Consume(TokenType::Extern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { - diag_ctx.Emit( + diag_ctx->Emit( {next->line, next->column, "an external type may not have any constructors"}); } defs.types.push_back(type_def); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc new file mode 100644 index 000000000000..b57278518f46 --- /dev/null +++ b/src/parser/source_map.cc @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file source_map.cc + * \brief The implementation of the source map data structure. + */ +#include +#include + +namespace tvm { + +ObjectPtr GetSourceNameNode(const String& name) { + // always return pointer as the reference can change as map re-allocate. 
+ // or use another level of indirection by creating a unique_ptr + static std::unordered_map > source_map; + + auto sn = source_map.find(name); + if (sn == source_map.end()) { + ObjectPtr n = make_object(); + source_map[name] = n; + n->name = std::move(name); + return n; + } else { + return sn->second; + } +} + +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { + return GetSourceNameNode(name); +} + +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } + +TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); + +TVM_REGISTER_NODE_TYPE(SourceNameNode) + .set_creator(GetSourceNameNodeByStr) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); + +Span::Span(SourceName source, int lineno, int col_offset) { + auto n = make_object(); + n->source = std::move(source); + n->line = lineno; + n->column = col_offset; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(SpanNode); + +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { + return Span(source, lineno, col_offset); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")"; + }); +} // namespace tvm diff --git a/to_json.py b/to_json.py new file mode 100644 index 000000000000..17516bead554 --- /dev/null +++ b/to_json.py @@ -0,0 +1,59 @@ +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprFunctor + +import json + +class ToTypeScriptJSON(ExprFunctor): + def visit_function(self, func): + json_params = [self.visit(param) for param in func.params] + fn_json 
= { "type": "Function", "params": json_params, + "body": self.visit(func.body) + } + return fn_json + + + def visit_var(self, var): + # TODO(@jroesch): we need to add a type field + var_json = { "type": "Var", "name_hint": var.name_hint } + return var_json + + def visit_call(self, call): + return {} + +x = relay.var('x', shape=(10, 5)) +y = relay.var('x', shape=(10, 5)) +f: relay.Function = relay.Function([x, y], x + y) + +to_json = ToTypeScriptJSON() +program_ser = to_json.visit(f) +program_json = json.dumps(program_ser) + +from tvm.relay.build_module import build + +mod = tvm.IRModule.from_expr(f) + +graph_json, rt_mod, params = build(mod, target='llvm') + +import pdb; pdb.set_trace() + +span_to_perf_data = { ... } +span_to_source_code = { ... } +span_to_graph = { ... } + +a = g + h + j +if (x + y) { + z + w + ......... +} else { + a + b + c + d .. +} + + +x y + \ / + + + | + if +| ....| | ....| + \ / + | From 4502adce050898b34a2327a54c974fbf659ab1f7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 14:39:49 -0700 Subject: [PATCH 09/48] Convert tokenizer and parser to use new machinery --- include/tvm/parser/source_map.h | 33 ++++++++ src/parser/diagnostic.h | 142 ++++++++++++++------------------ src/parser/parser.cc | 71 ++++++++-------- src/parser/source_map.cc | 117 +++++++++++++++----------- src/parser/tokenizer.h | 25 ++++-- 5 files changed, 221 insertions(+), 167 deletions(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 41b0d385b939..b121b19df012 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -33,6 +33,39 @@ namespace tvm { namespace parser { + +/*! \brief A program source in any language. + * + * Could represent the source from an ML framework or the internal + * source of a TVM program. + */ +struct Source { + /*! \brief The raw source. */ + std::string source; + /*! \brief A mapping of line breaks into the raw source. */ + std::vector> line_map; + + /*! 
\brief An empty source. */ + Source() : source(), line_map() {} + + /*! \brief Construct a source from a string. */ + TVM_DLL explicit Source(const std::string& source); + + TVM_DLL Source(const Source& source) : source(source.source), line_map(source.line_map) {} + + /*! \brief Generate an error message at a specific line and column with the + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param line The line at which to report a diagnostic. + * \param column The column at which to report a diagnostic. + * \param msg The message to attach. + */ + TVM_DLL void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const; +}; + /*! * \brief A mapping from a unique source name to source fragment. */ diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 19f5d205126a..b1b4e5590a1f 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -31,9 +31,11 @@ #define TVM_PARSER_DIAGNOSTIC_H_ #include +#include #include #include + #include #include #include @@ -42,84 +44,6 @@ namespace tvm { namespace parser { -/*! \brief A program source in any language. - * - * Could represent the source from an ML framework or the internal - * source of a TVM program. - */ -struct Source { - /*! \brief The raw source. */ - std::string source; - /*! \brief A mapping of line breaks into the raw source. */ - std::vector> line_map; - - /*! \brief An empty source. */ - Source() : source(), line_map() {} - - /*! \brief Construct a source from a string. */ - explicit Source(const std::string& source) : source(source) { - int index = 0; - int length = 0; - line_map.push_back({index, length}); - for (auto c : source) { - if (c == '\n') { - // Record the length of the line. - line_map.back().second = length; - // Bump past the newline. - index += 1; - // Record the start of the next line, and put placeholder for length.
- line_map.push_back({index, 0}); - // Reset length to zero. - length = 0; - } else { - length += 1; - index += 1; - } - } - line_map.back().second = length; - } - - Source(const Source& source) : source(source.source), line_map(source.line_map) {} - - /*! \brief Generate an error message at a specific line and column with the - * annotated message. - * - * The error is written directly to the `out` std::ostream. - * - * \param out The output ostream. - * \param line The line at which to report a diagnostic. - * \param line The column at which to report a diagnostic. - * \param msg The message to attach. - */ - void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const { - CHECK(line - 1 <= static_cast(line_map.size())) - << "requested line: " << (line - 1) << "line_map size: " << line_map.size() - << "source: " << source; - - // Adjust for zero indexing, now have (line_start, line_length); - auto range = line_map.at(line - 1); - int line_start = range.first; - int line_length = range.second; - out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl; - out << " " << source.substr(line_start, line_length) << std::endl; - out << " "; - std::stringstream marker; - for (int i = 1; i <= line_length; i++) { - if (i == column) { - marker << "^"; - } else if ((column - i) < 3) { - marker << "~"; - } else if ((i - column) < 3) { - marker << "~"; - } else { - marker << " "; - } - } - out << marker.str(); - out << std::endl; - } -}; - /*! \brief The diagnostic level, controls the printing of the message. */ enum DiagnosticLevel { Bug, @@ -140,6 +64,58 @@ struct Diagnostic { Diagnostic(int line, int column, const std::string& message) : level(DiagnosticLevel::Error), span(SourceName(), line, column), message(message) {} + + Diagnostic(DiagnosticLevel level, Span span, const std::string& message) + : level(level), span(span), message(message) {} +}; + +/*! + * \brief A wrapper around std::stringstream to build a diagnostic. 
+ * + * \code + * + * void ReportError(const Error& err); + * + * void Test(int number) { + * // Use error reporter to construct an error. + * ReportError(ErrorBuilder() << "This is an error number=" << number); + * } + * + * \endcode + */ +struct DiagnosticBuilder { + public: + /*! \brief The level. */ + DiagnosticLevel level; + + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The line number. */ + int line; + + /*! \brief The column number. */ + int column; + + template + DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) + stream_ << val; + return *this; + } + + DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), line(0), column(0) {} + DiagnosticBuilder(const DiagnosticBuilder& builder) + : level(builder.level), source_name(builder.source_name), line(builder.line), column(builder.column) {} + DiagnosticBuilder(DiagnosticLevel level, SourceName source_name, int line, int column) + : level(level), source_name(source_name), line(line), column(column) {} + + operator Diagnostic() { + return Diagnostic(this->level, Span(this->source_name, this->line, this->column), this->stream_.str()); + } + + private: + std::stringstream stream_; + friend struct Diagnostic; }; /*! \brief A diagnostic context for recording errors against a source file. @@ -158,6 +134,14 @@ struct DiagnosticContext { /*! \brief Emit a diagnostic. */ void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); } + /*! \brief Emit a diagnostic. */ + void EmitFatal(const Diagnostic& diagnostic) { + diagnostics.push_back(diagnostic); + Render(std::cout); + // TODO(@jroesch): throw exception which is caught at the pass boundary and then rendered. + LOG(FATAL) << "error occurred"; + } + // TODO(@jroesch): eventually modularize the rendering interface to provide control of how to // format errors. 
void Render(std::ostream& ostream) { @@ -166,7 +150,7 @@ struct DiagnosticContext { } if (diagnostics.size()) { - LOG(FATAL) << "parse error occured"; + LOG(FATAL) << "parse error occurred"; } } }; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 69f7e527ec8c..2104084ec7fb 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -213,6 +213,8 @@ class Parser { /*! \brief The diagnostic context used for error reporting. */ DiagnosticContext* diag_ctx; + const SourceName& source_name; + /*! \brief The current position in the token stream. */ int pos; @@ -243,8 +245,8 @@ class Parser { /*! \brief The set of expression scopes used for lexical scope. */ ScopeStack expr_scopes; - Parser(DiagnosticContext* ctx, std::vector tokens, OperatorTable op_table, Source source) - : diag_ctx(ctx), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} + Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, OperatorTable op_table, Source source) + : diag_ctx(ctx), source_name(source_name), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. 
*/ @@ -292,10 +294,9 @@ class Parser { */ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { - std::string message = - "expected a " + Pretty(token_type) + " found " + Pretty(Peek()->token_type); - this->diag_ctx->Emit({tokens[pos]->line, tokens[pos]->column, message}); - this->diag_ctx->Render(std::cout); + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tokens[pos]->line, tokens[pos]->column) + << "expected a " << Pretty(token_type) << " found " << Pretty(Peek()->token_type)); } pos++; } @@ -540,10 +541,9 @@ class Parser { return elements; } else { auto next = Peek(); - std::stringstream msg; - msg << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type); - diag_ctx.Emit({next->line, next->column, msg.str()}); - diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), next->line, next->column) + << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type)); return Array(nullptr); } } @@ -582,16 +582,15 @@ class Parser { auto version = Match(TokenType::Version); // TODO(@jroesch): we currently only support 0.0.5. 
if (version.ToString() != "\"0.0.5\"") { - std::stringstream msg; - msg << "invalid semantic version `"; - msg << version.ToString() << "`"; - this->diag_ctx->Emit({version->line, version->column, msg.str() }); + this->diag_ctx->Emit( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), version->line, version->column) + << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { - std::stringstream msg; - msg << "expected text format semantic version "; - msg << "you can annotate it as #[version = \"0.0.5\"]"; - this->diag_ctx->Emit({Peek()->line, Peek()->column, msg.str() }); + this->diag_ctx->Emit( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), Peek()->line, Peek()->column) + << "expected text format semantic version " + << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } @@ -1186,9 +1185,9 @@ class Parser { try { return Op::Get(op_name); } catch (dmlc::Error e) { - std::stringstream msg; - msg << "operator `" << op_name << "` not found, perhaps you forgot to register it?"; - this->diag_ctx.Emit({ tok->line, tok->column, msg.str() }); + this->diag_ctx->Emit( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tok->line, tok->column) + << "operator `" << op_name << "` not found, perhaps you forgot to register it?"); return Expr(); } } @@ -1291,10 +1290,9 @@ class Parser { } } default: { - std::stringstream msg; - msg << "expected an expression found " << Pretty(next->token_type); - diag_ctx.Emit({next->line, next->column, msg.str()}); - diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), next->line, next->column) + << "expected an expression found " << Pretty(next->token_type)); return Expr(); } } @@ -1414,11 +1412,9 @@ class Parser { if (WhenMatch(TokenType::Underscore)) { return IncompleteType(); } else { - std::stringstream msg; - msg << "failed to parse type found "; - msg << tok; - diag_ctx.Emit({tok->line, tok->column, 
msg.str()}); - diag_ctx.Render(std::cout); + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tok->line, tok->column) + << "failed to parse type found " << tok); return Type(); } } @@ -1471,15 +1467,22 @@ class Parser { }; IRModule ParseModule(std::string file_name, std::string file_content) { - auto tokens = Tokenize(file_content); - Parser parser(tokens, DefaultOpTable(), Source(file_content)); + DLOG(INFO) << "ParseModule"; + SourceName src_name = SourceName::Get(file_name); + Source src(file_content); + DiagnosticContext ctx(src); + auto tokens = Tokenize(&ctx, src_name, file_content); + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content)); return parser.ParseModule(); } Expr ParseExpr(std::string file_name, std::string file_content) { DLOG(INFO) << "ParseExpr"; - auto tokens = Tokenize(file_content); - Parser parser(tokens, DefaultOpTable(), Source(file_content)); + SourceName src_name = SourceName::Get(file_name); + Source src(file_content); + DiagnosticContext ctx(src); + auto tokens = Tokenize(&ctx, src_name, file_content); + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content)); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index b57278518f46..f70df420ad9f 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -20,64 +20,89 @@ * \file source_map.cc * \brief The implementation of the source map data structure. */ -#include +#include #include namespace tvm { +namespace parser { -ObjectPtr GetSourceNameNode(const String& name) { - // always return pointer as the reference can change as map re-allocate. 
- // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; - - auto sn = source_map.find(name); - if (sn == source_map.end()) { - ObjectPtr n = make_object(); - source_map[name] = n; - n->name = std::move(name); - return n; - } else { - return sn->second; +/*! \brief Construct a source from a string. */ +Source::Source(const std::string& source) : source(source) { + int index = 0; + int length = 0; + line_map.push_back({index, length}); + for (auto c : source) { + if (c == '\n') { + // Record the length of the line. + line_map.back().second = length; + // Bump past the newline. + index += 1; + // Record the start of the next line, and put placeholder for length. + line_map.push_back({index, 0}); + // Reset length to zero. + length = 0; + } else { + length += 1; + index += 1; + } } + line_map.back().second = length; } -ObjectPtr GetSourceNameNodeByStr(const std::string& name) { - return GetSourceNameNode(name); -} +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } +/*! \brief Generate an error message at a specific line and column with the + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param line The line at which to report a diagnostic. + * \param column The column at which to report a diagnostic. + * \param msg The message to attach.
+ */ +void Source::ReportAt(std::ostream& out, int line, int column, const std::string& msg) const { + CHECK(line - 1 <= static_cast(line_map.size())) + << "requested line: " << (line - 1) << "line_map size: " << line_map.size() + << "source: " << source; -TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); + // Adjust for zero indexing, now have (line_start, line_length); + auto range = line_map.at(line - 1); + int line_start = range.first; + int line_length = range.second; + out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl; + out << " " << source.substr(line_start, line_length) << std::endl; + out << " "; + std::stringstream marker; + for (int i = 1; i <= line_length; i++) { + if (i == column) { + marker << "^"; + } else if ((column - i) < 3) { + marker << "~"; + } else if ((i - column) < 3) { + marker << "~"; + } else { + marker << " "; + } + } + out << marker.str(); + out << std::endl; +} -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SourceName(" << node->name << ", " << node << ")"; - }); +// TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); -TVM_REGISTER_NODE_TYPE(SourceNameNode) - .set_creator(GetSourceNameNodeByStr) - .set_repr_bytes([](const Object* n) -> std::string { - return static_cast(n)->name; - }); +// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { +// auto* node = static_cast(ref.get()); +// p->stream << "SourceName(" << node->name << ", " << node << ")"; +// }); -Span::Span(SourceName source, int lineno, int col_offset) { - auto n = make_object(); - n->source = std::move(source); - n->line = lineno; - n->column = col_offset; +TVM_REGISTER_NODE_TYPE(SourceMapNode); + +SourceMap::SourceMap(Map source_map) { + auto n = make_object(); + n->source_map = std::move(source_map); data_ = std::move(n); } 
-TVM_REGISTER_NODE_TYPE(SpanNode); - -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { - return Span(source, lineno, col_offset); -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")"; - }); -} // namespace tvm +} // namespace parser +} // namespace tvm diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 3e3049c05f0b..7a68c458de08 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -74,6 +74,9 @@ static std::unordered_map KEYWORD_TABLE = { {"match", TokenType::Match}, {"extern", TokenType::Extern}}; struct Tokenizer { + DiagnosticContext *diag_ctx; + const SourceName& source_name; + size_t pos; int col; int line; @@ -261,11 +264,17 @@ struct Tokenizer { rtrim(version); return Token(line, column, TokenType::Version, tvm::String(version)); } else { - LOG(FATAL) << "unsupported " << attribute; + // TODO(@jroesch): maybe make this a warning and continue parsing? + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) + << "unsupported attribute " << attribute); return Token(); } } else { - LOG(FATAL) << "lex error"; + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) + << "`#` denotes the start of an attribute can only be followed by `[`" + << " found `" << Peek() << "`"); return Token(); } } @@ -283,9 +292,9 @@ struct Tokenizer { auto token = NewToken(TokenType::Newline); return token; } else { - // TODO(@jroesch): have lexer use diagnostic context too. - // see https://github.com/apache/incubator-tvm/issues/6153.
- LOG(FATAL) << "lexer error"; + this->diag_ctx->EmitFatal( + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), this->line, this->col) + << "\\r carriage returns must be followed by a \\n in the TVM text format"); return Token(); } } else if (next == '"') { @@ -499,7 +508,7 @@ struct Tokenizer { this->tokens.push_back(NewToken(TokenType::EndOfFile)); } - explicit Tokenizer(std::string& source) : pos(0), col(1), line(1), source(source), tokens() {} + explicit Tokenizer(DiagnosticContext *ctx, const SourceName& source_name, const std::string& source) : diag_ctx(ctx), source_name(source_name), pos(0), col(1), line(1), source(source), tokens() {} }; std::vector Condense(const std::vector& tokens) { @@ -568,8 +577,8 @@ std::vector Condense(const std::vector& tokens) { return out; } -std::vector Tokenize(std::string source) { - auto tokenizer = Tokenizer(source); +std::vector Tokenize(DiagnosticContext *ctx, const SourceName& source_name, const std::string& source) { + auto tokenizer = Tokenizer(ctx, source_name, source); tokenizer.Tokenize(); auto tokens = Condense(tokenizer.tokens); for (auto token : tokens) { From 39df873d5d263e1e374d6983695b8db3fb5d119f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 14:46:42 -0700 Subject: [PATCH 10/48] Kill to_json --- to_json.py | 59 ------------------------------------------------------ 1 file changed, 59 deletions(-) delete mode 100644 to_json.py diff --git a/to_json.py b/to_json.py deleted file mode 100644 index 17516bead554..000000000000 --- a/to_json.py +++ /dev/null @@ -1,59 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.expr_functor import ExprFunctor - -import json - -class ToTypeScriptJSON(ExprFunctor): - def visit_function(self, func): - json_params = [self.visit(param) for param in func.params] - fn_json = { "type": "Function", "params": json_params, - "body": self.visit(func.body) - } - return fn_json - - - def visit_var(self, var): - # TODO(@jroesch): we need to add a type field - 
var_json = { "type": "Var", "name_hint": var.name_hint } - return var_json - - def visit_call(self, call): - return {} - -x = relay.var('x', shape=(10, 5)) -y = relay.var('x', shape=(10, 5)) -f: relay.Function = relay.Function([x, y], x + y) - -to_json = ToTypeScriptJSON() -program_ser = to_json.visit(f) -program_json = json.dumps(program_ser) - -from tvm.relay.build_module import build - -mod = tvm.IRModule.from_expr(f) - -graph_json, rt_mod, params = build(mod, target='llvm') - -import pdb; pdb.set_trace() - -span_to_perf_data = { ... } -span_to_source_code = { ... } -span_to_graph = { ... } - -a = g + h + j -if (x + y) { - z + w + ......... -} else { - a + b + c + d .. -} - - -x y - \ / - + - | - if -| ....| | ....| - \ / - | From b7dd570ded64dd1a0d8c5ca388694aed56f6cf5e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 21:30:02 -0700 Subject: [PATCH 11/48] Fix comment in type_infer.cc --- src/relay/transforms/type_infer.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 45e1af1c960f..7182f0e96f0f 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -26,16 +26,16 @@ * most efficient code we need to obtain type information for the * IR. * - * Like computation graphs the IR leaves most type information - * implicit and relies performing analysis of the program to - * generate this information. + * Similar to previous computation graph based IRs, the Relay IR leaves + * type information implicit and computes types by performing program + * analysis. * - * This pass given an expression `e` will infer a type `t` for - * the expression simultaneous checking the property `e : t` - * (i.e we can show e has type t). + * Given an expression `e` this pass infers a type `t` for + * the expression as well as simultaneously checking the property `e : t` + * (i.e., we can show e has type t). 
* - * If we can not infer a type or there are conflicting typing - * constraints we will trigger an error. + * If we can not infer a type or there is a conflicting + * constraint it will emit errors. */ #include #include From 9968fdf96de8bbf312f6483583af1adde6afad02 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 21:31:10 -0700 Subject: [PATCH 12/48] Remove old parser --- CMakeLists.txt | 3 - cmake/modules/ANTLR.cmake | 40 - cmake/util/FindANTLR.cmake | 65 - docker/Dockerfile.ci_cpu | 3 - docker/Dockerfile.ci_gpu | 3 - docker/Dockerfile.ci_wasm | 3 - docker/install/ubuntu_install_antlr.sh | 25 - .../install/ubuntu_install_python_package.sh | 2 +- docs/README.txt | 4 +- docs/install/from_source.rst | 7 - python/setup.py | 3 +- python/tvm/relay/_parser.py | 771 ---- python/tvm/relay/grammar/.gitignore | 1 - python/tvm/relay/grammar/Relay.g4 | 199 - python/tvm/relay/grammar/__init__.py | 16 - python/tvm/relay/grammar/py3/.gitattributes | 3 - python/tvm/relay/grammar/py3/RelayLexer.py | 256 -- python/tvm/relay/grammar/py3/RelayParser.py | 3732 ----------------- python/tvm/relay/grammar/py3/RelayVisitor.py | 343 -- python/tvm/relay/grammar/py3/__init__.py | 16 - python/tvm/relay/parser.py | 30 - tests/lint/rat-excludes | 5 - 22 files changed, 4 insertions(+), 5526 deletions(-) delete mode 100644 cmake/modules/ANTLR.cmake delete mode 100644 cmake/util/FindANTLR.cmake delete mode 100755 docker/install/ubuntu_install_antlr.sh delete mode 100644 python/tvm/relay/_parser.py delete mode 100644 python/tvm/relay/grammar/.gitignore delete mode 100644 python/tvm/relay/grammar/Relay.g4 delete mode 100644 python/tvm/relay/grammar/__init__.py delete mode 100644 python/tvm/relay/grammar/py3/.gitattributes delete mode 100644 python/tvm/relay/grammar/py3/RelayLexer.py delete mode 100644 python/tvm/relay/grammar/py3/RelayParser.py delete mode 100644 python/tvm/relay/grammar/py3/RelayVisitor.py delete mode 100644 python/tvm/relay/grammar/py3/__init__.py delete mode 
100644 python/tvm/relay/parser.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ef216336af2..19ab38b4c464 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,7 +7,6 @@ include(cmake/util/FindCUDA.cmake) include(cmake/util/FindVulkan.cmake) include(cmake/util/FindLLVM.cmake) include(cmake/util/FindROCM.cmake) -include(cmake/util/FindANTLR.cmake) if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) @@ -66,7 +65,6 @@ tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) -tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) @@ -311,7 +309,6 @@ include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) -include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake deleted file mode 100644 index d3c1b4218253..000000000000 --- a/cmake/modules/ANTLR.cmake +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -if(USE_ANTLR) - find_antlr(${USE_ANTLR}) - if(ANTLR4) - - set(RELAY_PARSER_DIR - ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) - - set(RELAY_PARSER - ${RELAY_PARSER_DIR}/py3/RelayVisitor.py - ${RELAY_PARSER_DIR}/py3/RelayParser.py - ${RELAY_PARSER_DIR}/py3/RelayLexer.py) - - - # Generate ANTLR grammar for parsing. - add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 - DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 - WORKING_DIRECTORY ${RELAY_PARSER_DIR}) - - add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) - else() - message(FATAL_ERROR "Can't find ANTLR4") - endif() -endif(USE_ANTLR) diff --git a/cmake/util/FindANTLR.cmake b/cmake/util/FindANTLR.cmake deleted file mode 100644 index 3e490187083e..000000000000 --- a/cmake/util/FindANTLR.cmake +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -####################################################### -# Enhanced version of find ANTLR. -# -# Usage: -# find_antlr(${USE_ANTLR}) -# -# - When USE_ANTLR=ON, use auto search by first trying to find antlr4 program, -# then trying to find antlr-*-complete.jar -# - When USE_ANTLR=/path/to/antlr-*-complete.jar, use provided jar -# -# Provide variables: -# - ANTLR4 -# -macro(find_antlr use_antlr) - set(JAVA_HOME $ENV{JAVA_HOME}) - if (NOT DEFINED JAVA_HOME) - # Hack to get system to search for Java itself. - message(STATUS "JAVA_HOME is not defined. Set it to ensure proper use") - set(JAVA_HOME "/usr") - endif() - if(MSVC) - set(JAVA_PROGRAM ${JAVA_HOME}/java.exe) - else() - set(JAVA_PROGRAM ${JAVA_HOME}/bin/java) - endif() - message(STATUS "Using Java at " ${JAVA_PROGRAM}) - - if (${use_antlr} STREQUAL "ON") - find_program(ANTLR4 antlr4) - if (NOT ANTLR4) - file(GLOB_RECURSE ANTLR4JAR - /usr/local/lib/antlr-*-complete.jar - /usr/local/Cellar/*antlr-*-complete.jar) - - # Get the first element of the list of antlr jars. - # Sort and reverse the list so the item selected is the highest - # version in lib or else in Cellar if no lib installation exists. 
- list(SORT ANTLR4JAR) - list(REVERSE ANTLR4JAR) - list(GET ANTLR4JAR 0 ANTLR4JAR) - - set(ANTLR4 ${JAVA_PROGRAM} -jar ${ANTLR4JAR}) - endif() - elseif(NOT ${use_antlr} STREQUAL "OFF") - set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr}) - endif() - message(STATUS "ANTLR4=${ANTLR4}") -endmacro(find_antlr) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 828fff4e5fc6..df416d48ce09 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -60,9 +60,6 @@ ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - # Chisel deps for TSIM COPY install/ubuntu_install_chisel.sh /install/ubuntu_install_chisel.sh RUN bash /install/ubuntu_install_chisel.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 6233cc53211f..7b8468ef7346 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -87,9 +87,6 @@ RUN bash /install/ubuntu_install_vulkan.sh COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - # NNPACK deps COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 965bc01d22d8..85f942d57ca3 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -33,9 +33,6 @@ RUN bash /install/ubuntu1804_install_llvm.sh COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh - COPY install/ubuntu_install_nodejs.sh 
/install/ubuntu_install_nodejs.sh RUN bash /install/ubuntu_install_nodejs.sh diff --git a/docker/install/ubuntu_install_antlr.sh b/docker/install/ubuntu_install_antlr.sh deleted file mode 100755 index de713a6f6a32..000000000000 --- a/docker/install/ubuntu_install_antlr.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -set -e -set -u -set -o pipefail - -cd /usr/local/lib -wget -q https://www.antlr.org/download/antlr-4.7.1-complete.jar -cd - diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 2eaf00e8fdd0..2ad55c0e521e 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -21,4 +21,4 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow packaging +pip3 install pylint==1.9.4 six numpy pytest cython decorator scipy tornado typed_ast pytest mypy orderedset attrs requests Pillow packaging diff --git a/docs/README.txt b/docs/README.txt index 281cafaeee89..87acd306b6b2 100644 --- a/docs/README.txt +++ b/docs/README.txt @@ -42,12 +42,12 @@ You can run the following script to reproduce the CI sphinx pre-check stage. This script skips the tutorial executions and is useful for quickly check the content. ```bash -./tests/scrpts/task_sphinx_precheck.sh +./tests/scripts/task_sphinx_precheck.sh ``` The following script runs the full build which includes tutorial executions. You will need a gpu CI environment. ```bash -./tests/scrpts/task_python_docs.sh +./tests/scripts/task_python_docs.sh ``` diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 26aec77e09e2..9fe6c5e0eb23 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -213,13 +213,6 @@ like ``virtualenv``. pip3 install --user tornado psutil xgboost - * If you want to build tvm to compile a model, you must use Python 3 and run the following - - .. 
code:: bash - - sudo apt install antlr4 - pip3 install --user mypy orderedset antlr4-python3-runtime - Install Contrib Libraries ------------------------- diff --git a/python/setup.py b/python/setup.py index 682589ef5e6f..3205a7cfa525 100644 --- a/python/setup.py +++ b/python/setup.py @@ -167,8 +167,7 @@ def get_package_data_files(): 'psutil', 'xgboost>=1.1.0', 'mypy', - 'orderedset', - 'antlr4-python3-runtime']}, + 'orderedset']}, packages=find_packages(), package_dir={'tvm': 'tvm'}, diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py deleted file mode 100644 index 0d3f86f6262d..000000000000 --- a/python/tvm/relay/_parser.py +++ /dev/null @@ -1,771 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# pylint: disable=invalid-name, unused-argument -"""A parser for Relay's text format.""" -from __future__ import absolute_import - -import sys -from ast import literal_eval -from collections import deque - -try: - # no typing.Deque in Python 3.5 - # https://bugs.python.org/issue29011 - from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, MutableSequence, T, Deque -except ImportError: - class Deque(deque, MutableSequence[T], extra=deque): - - def __new__(cls, *args, **kwds): - if _geqv(cls, Deque): - raise TypeError("Type Deque cannot be instantiated; " - "use deque() instead") - return deque.__new__(cls, *args, **kwds) - -import tvm -import tvm.ir._ffi_api -from tvm.ir import IRModule - -from .base import Span, SourceName -from . import adt -from . import expr -from . import function -from . import ty -from . import op - -PYTHON_VERSION = sys.version_info.major -try: - from antlr4 import InputStream, CommonTokenStream - from antlr4.error.ErrorListener import ErrorListener -except ImportError: - raise Exception("Couldn't find ANTLR runtime." + - "Try running `pip{version} install antlr4-python{version}-runtime`." - .format(version=PYTHON_VERSION)) - -try: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer -except ImportError: - raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") - - -sys.setrecursionlimit(10000) - -class ParseError(Exception): - """Exception type for parse errors.""" - - def __init__(self, message: str) -> None: - super(ParseError, self).__init__() - self.message = message - - def __repr__(self): - return "ParseError({})".format(self.message) - - def __str__(self): - return repr(self) - -class OpWrapper: - """Overload the __call__ for op.""" - - -class ExprOp(OpWrapper): - """Call an expr. 
The default, but does not handle attrs well.""" - def __init__(self, operator): - self.operator = operator - - def __call__(self, args, attrs, type_args): - try: - return expr.Call(self.operator, args, attrs, type_args) - except Exception: - raise Exception("Operator {} is not registered. It's attributes are {}" - .format(self.operator, attrs)) - -class FuncOp(OpWrapper): - """Convert the attrs, call the python function with the attrs passed in as keyword arguments. - Tvm should provide this in the future, as this is pretty similar to what op.get is providing. - """ - def __init__(self, operator): - self.operator = operator - - def convert(self, v): - if isinstance(v, tuple): - return tuple([self.convert(x) for x in v]) - if isinstance(v, expr.Constant): - return v.data.asnumpy().item() - if isinstance(v, str): - return v - raise Exception(v) - - def __call__(self, args, attrs, type_args): - if attrs is None: - attrs = {} - if self.operator in (op.strided_slice,): - x = self.operator(*args) - else: - x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) - if isinstance(x, expr.TupleWrapper): - x = x.astuple() - return x - -BINARY_OPS = { - RelayParser.MUL: op.multiply, - RelayParser.DIV: op.divide, - RelayParser.ADD: op.add, - RelayParser.SUB: op.subtract, - RelayParser.LT: op.less, - RelayParser.GT: op.greater, - RelayParser.LE: op.less_equal, - RelayParser.GE: op.greater_equal, - RelayParser.EQ: op.equal, - RelayParser.NE: op.not_equal, -} - -FUNC_OPS = { - "nn.conv2d": op.nn.conv2d, - "nn.batch_norm": op.nn.batch_norm, - "nn.dense": op.nn.dense, - "nn.bias_add": op.nn.bias_add, - "nn.max_pool2d": op.nn.max_pool2d, - "nn.max_pool3d": op.nn.max_pool3d, - "nn.global_max_pool2d": op.nn.global_max_pool2d, - "nn.avg_pool2d": op.nn.avg_pool2d, - "nn.avg_pool3d": op.nn.avg_pool3d, - "nn.global_avg_pool2d": op.nn.global_avg_pool2d, - "nn.softmax": op.nn.softmax, - "reshape": op.reshape, - "nn.conv2d_transpose": op.nn.conv2d_transpose, - 
"nn.conv1d_transpose": op.nn.conv1d_transpose, - "concatenate": op.concatenate, - "nn.dropout": op.nn.dropout_raw, - "zeros": op.zeros, - "split": op.split, - "cast": op.cast, - "clip": op.clip, - "right_shift": op.right_shift, -} - -TYPE_PREFIXES = [ - "int", - "uint", - "float", - "bool", -] - -T = TypeVar("T") -Scope = Deque[Tuple[str, T]] -Scopes = Deque[Scope[T]] - -def lookup(scopes: Scopes[T], name: str) -> Optional[T]: - """Look up `name` in `scopes`.""" - - for scope in scopes: - for key, val in scope: - if key == name: - return val - return None - -def spanify(f): - """A decorator which attaches span information - to the value returned by calling `f`. - - Intended for use with the below AST visiting - methods. The idea is that after we do the work - of constructing the AST we attach Span information. - """ - - def _wrapper(*args, **kwargs): - # Assumes 0th arg is self and gets source_name from object. - sn = args[0].source_name - # Assumes 1st arg is an ANTLR parser context. - ctx = args[1] - ast = f(*args, **kwargs) - line, col = ctx.getSourceInterval() - sp = Span(sn, line, col) - if isinstance(ast, tvm.relay.expr.TupleWrapper): - ast = ast.astuple() - tvm.ir._ffi_api.NodeSetSpan(ast, sp) - return ast - return _wrapper - -# TODO(@jmp): Use https://stackoverflow.com/q/13889941 -# to figure out how to get ANTLR4 to be more unhappy about syntax errors -class ParseTreeToRelayIR(RelayVisitor): - """Parse Relay text format into Relay IR.""" - - def __init__(self, source_name: str) -> None: - self.source_name = source_name - self.module = IRModule({}) # type: IRModule - - # Adding an empty scope allows naked lets without pain. 
- self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.global_vars = {} # type: Scope[expr.GlobalVar] - self.type_var_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] - self.global_type_vars = {} # type: Scope[expr.GlobalVar] - self.graph_expr = [] # type: List[expr.Expr] - - super(ParseTreeToRelayIR, self).__init__() - - - def enter_var_scope(self) -> None: - """Enter a new Var scope so it can be popped off later.""" - self.var_scopes.appendleft(deque()) - - def exit_var_scope(self) -> Scope[expr.Var]: - """Pop off the current Var scope and return it.""" - return self.var_scopes.popleft() - - def mk_var(self, name: str, typ: ty.Type = None): - """Create a new Var and add it to the Var scope.""" - var = expr.Var(name, typ) - self.var_scopes[0].appendleft((name, var)) - return var - - def mk_global_var(self, name: str) -> expr.GlobalVar: - """Create a new GlobalVar and add it to the GlobalVar scope.""" - if name in self.global_vars: - raise ParseError("duplicate global var \"{0}\"".format(name)) - var = expr.GlobalVar(name) - self.global_vars[name] = var - return var - - def enter_type_param_scope(self) -> None: - """Enter a new TypeVar scope so it can be popped off later.""" - self.type_var_scopes.appendleft(deque()) - - def exit_type_param_scope(self) -> Scope[ty.TypeVar]: - """Pop off the current TypeVar scope and return it.""" - return self.type_var_scopes.popleft() - - def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar: - """Create a new TypeVar and add it to the TypeVar scope.""" - typ = ty.TypeVar(name, kind) - self.type_var_scopes[0].append((name, typ)) - return typ - - def mk_global_typ_var(self, name, kind): - # (str, ty.Kind) -> ty.GlobalTypeVar - """Create a new TypeVar and add it to the TypeVar scope.""" - typ = ty.GlobalTypeVar(name, kind) - self._check_existing_typ_expr(name, typ) - self.global_type_vars[name] = typ - return typ - - # TODO(weberlo): rethink whether we should have type constructors mixed with type vars. 
- def mk_global_typ_cons(self, name, cons): - self._check_existing_typ_expr(name, cons) - self.global_type_vars[name] = cons - - def _check_existing_typ_expr(self, name, new_expr): - if name in self.global_type_vars: - new_typ_name = self._type_expr_name(new_expr) - existing_typ_name = self._type_expr_name(self.global_type_vars[name]) - raise ParseError( - "{0} `{1}` conflicts with existing {2}".format(new_typ_name,\ - name, existing_typ_name)) - - def _type_expr_name(self, e): - if isinstance(e, adt.Constructor): - return "`{0}` ADT constructor".format(e.belong_to.name_hint) - if isinstance(e, ty.GlobalTypeVar): - if e.kind == ty.TypeKind.AdtHandle: - return "ADT definition" - return "function definition" - - def visitProjection(self, ctx): - return expr.TupleGetItem(self.visit(ctx.expr()), self.visit(ctx.NAT())) - - def visitTerminal(self, node) -> Union[expr.Expr, int, float]: - """Visit lexer tokens that aren't ignored or visited by other functions.""" - node_type = node.getSymbol().type - node_text = node.getText() - - if node_type == RelayLexer.NAT: - return int(node_text) - if node_type == RelayLexer.FLOAT: - return float(node_text[:-1]) - if node_type == RelayLexer.BOOL_LIT: - if node_text == "True": - return True - if node_text == "False": - return False - raise ParseError("unrecognized BOOL_LIT: `{}`".format(node_text)) - if node_type == RelayLexer.QUOTED_STRING: - return literal_eval(node_text) - raise ParseError("unhandled terminal \"{0}\" of type `{1}`".format(node_text, node_type)) - - def visitGeneralIdent(self, ctx): - name = ctx.getText() - # Look through all type prefixes for a match. - for type_prefix in TYPE_PREFIXES: - if name.startswith(type_prefix): - return ty.scalar_type(name) - # Next, look it up in the local then global type params. 
- type_expr = lookup(self.type_var_scopes, name) - if type_expr is None: - type_expr = self.global_type_vars.get(name, None) - if type_expr is not None: - # Zero-arity constructor calls fall into the general ident case, so in that case, - # we construct a constructor call with no args. - if isinstance(type_expr, adt.Constructor) and not type_expr.inputs: - type_expr = expr.Call(type_expr, []) - return type_expr - # Check if it's an operator. - op_name = ".".join([name.getText() for name in ctx.CNAME()]) - if op_name in FUNC_OPS: - return FuncOp(FUNC_OPS[op_name]) - return ExprOp(op.get(op_name)) - - def visitGlobalVar(self, ctx): - var_name = ctx.CNAME().getText() - global_var = self.global_vars.get(var_name, None) - if global_var is None: - raise ParseError("unbound global var `{0}`".format(var_name)) - return global_var - - def visitLocalVar(self, ctx): - var_name = ctx.CNAME().getText() - local_var = lookup(self.var_scopes, var_name) - if local_var is None: - raise ParseError("unbound local var `{0}`".format(var_name)) - return local_var - - def visitGraphVar(self, ctx): - return self.graph_expr[int(ctx.NAT().getText())] - - def visit_list(self, ctx_list) -> List[Any]: - """"Visit a list of contexts.""" - assert isinstance(ctx_list, list) - - return [self.visit(ctx) for ctx in ctx_list] - - def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]: - """Return a (possibly None) Relay type.""" - if ctx is None: - return None - - return self.visit(ctx) - - def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]: - self.meta = None - if ctx.METADATA(): - header, data = str(ctx.METADATA()).split("\n", 1) - assert header == "METADATA:" - self.meta = tvm.ir.load_json(data) - if ctx.defn(): - self.visit_list(ctx.defn()) - return self.module - - if ctx.expr(): - return self.visit(ctx.expr()) - - return self.module - - # Exprs - def visitOpIdent(self, ctx) -> tvm.ir.Op: - op_name = ".".join([name.getText() for name in 
ctx.CNAME()]) - if op_name in FUNC_OPS: - return FuncOp(FUNC_OPS[op_name]) - return ExprOp(op.get(op_name)) - - # pass through - def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr: - return self.visit(ctx.expr()) - - # pass through - def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr: - return self.visit(ctx.typeExpr()) - - # pass through - def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr: - return self.visit(ctx.expr()) - - def visitScalarFloat(self, ctx: RelayParser.ScalarFloatContext) -> expr.Constant: - return expr.const(self.visit(ctx.FLOAT())) - - def visitScalarInt(self, ctx: RelayParser.ScalarIntContext) -> expr.Constant: - return expr.const(self.visit(ctx.NAT())) - - def visitScalarBool(self, ctx: RelayParser.ScalarBoolContext) -> expr.Constant: - return expr.const(self.visit(ctx.BOOL_LIT())) - - def visitNeg(self, ctx: RelayParser.NegContext) -> Union[expr.Constant, expr.Call]: - val = self.visit(ctx.expr()) - if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: - # fold Neg in for scalars - return expr.const(-val.data.asnumpy().item()) - - return op.negative(val) - - def visitTuple(self, ctx: RelayParser.TupleContext) -> expr.Tuple: - tup = self.visit_list(ctx.expr()) - return expr.Tuple(tup) - - def visitLet(self, ctx: RelayParser.LetContext) -> expr.Let: - """Desugar various sequence constructs to Relay Let nodes.""" - - if ctx.var() is None: - # anonymous identity - ident = "_" - typ = None - var = self.mk_var(ident, typ) - else: - var = self.visitVar(ctx.var()) - - self.enter_var_scope() - value = self.visit(ctx.expr(0)) - self.exit_var_scope() - - body = self.visit(ctx.expr(1)) - - return expr.Let(var, value, body) - - def visitBinOp(self, ctx: RelayParser.BinOpContext) -> expr.Call: - """Desugar binary operators.""" - arg0, arg1 = self.visit_list(ctx.expr()) - relay_op = BINARY_OPS.get(ctx.op.type) - - if relay_op is None: - raise ParseError("unimplemented binary op.") - - return 
relay_op(arg0, arg1) - - @spanify - def visitVar(self, ctx: RelayParser.VarContext) -> expr.Var: - """Visit a single variable.""" - ident = ctx.localVar() - - if ident is None: - raise ParseError("only local ids may be used in vars.") - - typeExpr = self.getTypeExpr(ctx.typeExpr()) - - return self.mk_var(ident.getText()[1:], typeExpr) - - def visitVarList(self, ctx: RelayParser.VarListContext) -> List[expr.Var]: - return self.visit_list(ctx.var()) - - # TODO: support a larger class of values than just Relay exprs - def visitAttr(self, ctx: RelayParser.AttrContext) -> Tuple[str, expr.Expr]: - return (ctx.CNAME().getText(), self.visit(ctx.expr())) - - def visitArgNoAttr(self, ctx: RelayParser.ArgNoAttrContext): - return (self.visit_list(ctx.varList().var()), None) - - def visitAttrSeq(self, ctx: RelayParser.AttrSeqContext) -> Dict[str, expr.Expr]: - return dict(self.visit_list(ctx.attr())) - - def visitArgWithAttr(self, ctx: RelayParser.AttrSeqContext) \ - -> Tuple[List[expr.Var], Dict[str, expr.Expr]]: - return (self.visit_list(ctx.var()), self.visitAttrSeq(ctx.attrSeq())) - - def visitArgList(self, ctx: RelayParser.ArgListContext) \ - -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]: - var_list = self.visit(ctx.varList()) if ctx.varList() else None - attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None - return (var_list, attr_list) - - def visitMeta(self, ctx: RelayParser.MetaContext): - type_key = str(ctx.CNAME()) - index = int(self.visit(ctx.NAT())) - return self.meta[type_key][index] - - def mk_func( - self, - ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \ - -> function.Function: - """Construct a function from either a Func or Defn.""" - # Enter var scope early to put params in scope. - self.enter_var_scope() - # Capture type params in params. 
- self.enter_type_param_scope() - type_params = ctx.typeParamList() - - if type_params is not None: - type_params = type_params.typeExpr() - assert type_params - for ty_param in type_params: - name = ty_param.getText() - self.mk_typ(name, ty.TypeKind.Type) - - var_list, attr_list = self.visit(ctx.argList()) - if var_list is None: - var_list = [] - ret_type = self.getTypeExpr(ctx.typeExpr()) - - body = self.visit(ctx.body()) - # NB(@jroesch): you must stay in the type parameter scope until - # after you exit the body, you can reference the type parameters - # of your parent scopes. - type_params = list(self.exit_type_param_scope()) - if type_params: - _, type_params = zip(*type_params) - self.exit_var_scope() - - attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None - return function.Function(var_list, body, ret_type, type_params, attrs) - - @spanify - def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function: - return self.mk_func(ctx) - - # TODO: how to set spans for definitions? 
- # @spanify - def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None: - ident_name = ctx.globalVar().getText()[1:] - ident = self.mk_global_var(ident_name) - func = self.mk_func(ctx) - self.module[ident] = func - - def handle_adt_header( - self, - ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]): - """Handles parsing of the name and type params of an ADT definition.""" - adt_name = ctx.generalIdent().getText() - adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle) - # parse type params - type_params = ctx.typeParamList() - if type_params is None: - type_params = [] - else: - type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type) - for type_ident in type_params.typeExpr()] - return adt_var, type_params - - def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext): - # TODO(weberlo): update this handler once extern is implemented - self.enter_type_param_scope() - adt_var, type_params = self.handle_adt_header(ctx) - # update module being built - self.module[adt_var] = adt.TypeData(adt_var, type_params, []) - self.exit_type_param_scope() - - def visitAdtDefn(self, ctx: RelayParser.AdtDefnContext): - self.enter_type_param_scope() - adt_var, type_params = self.handle_adt_header(ctx) - # parse constructors - adt_cons_defns = ctx.adtConsDefnList() - if adt_cons_defns is None: - adt_cons_defns = [] - else: - adt_cons_defns = adt_cons_defns.adtConsDefn() - parsed_constructors = [] - for cons_defn in adt_cons_defns: - inputs = [self.visit(inp) for inp in cons_defn.typeExpr()] - cons_defn_name = cons_defn.constructorName().getText() - cons_defn = adt.Constructor(cons_defn_name, inputs, adt_var) - self.mk_global_typ_cons(cons_defn_name, cons_defn) - parsed_constructors.append(cons_defn) - # update module being built - self.module[adt_var] = adt.TypeData(adt_var, type_params, parsed_constructors) - self.exit_type_param_scope() - - def visitMatch(self, ctx: RelayParser.MatchContext): - match_type = 
ctx.matchType().getText() - if match_type == "match": - complete_match = True - elif match_type == "match?": - complete_match = False - else: - raise RuntimeError("unknown match type {0}".format(match_type)) - - match_data = self.visit(ctx.expr()) - match_clauses = ctx.matchClauseList() - if match_clauses is None: - match_clauses = [] - else: - match_clauses = match_clauses.matchClause() - parsed_clauses = [] - for clause in match_clauses: - self.enter_var_scope() - pattern = self.visit(clause.pattern()) - clause_body = self.visit(clause.expr()) - self.exit_var_scope() - parsed_clauses.append(adt.Clause(pattern, clause_body)) - return adt.Match(match_data, parsed_clauses, complete=complete_match) - - def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext): - return adt.PatternWildcard() - - def visitVarPattern(self, ctx: RelayParser.VarPatternContext): - text = ctx.localVar().getText() - typ = ctx.typeExpr() - if typ is not None: - typ = self.visit(typ) - var = self.mk_var(text[1:], typ=typ) - return adt.PatternVar(var) - - def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext): - constructor_name = ctx.constructorName().getText() - constructor = self.global_type_vars[constructor_name] - pattern_list = ctx.patternList() - if pattern_list is None: - patterns = [] - else: - patterns = [self.visit(pattern) for pattern in pattern_list.pattern()] - return adt.PatternConstructor(constructor, patterns) - - def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext): - return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()]) - - def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext): - return (self.visit_list(ctx.exprList().expr()), None) - - def visitCallWithAttr(self, ctx: RelayParser.CallWithAttrContext): - return (self.visit_list(ctx.expr()), self.visit(ctx.attrSeq())) - - def call(self, func, args, attrs, type_args): - if isinstance(func, OpWrapper): - return func(args, attrs, 
type_args) - if isinstance(func, adt.Constructor): - return func(*args) - return expr.Call(func, args, attrs, type_args) - - @spanify - def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call: - func = self.visit(ctx.expr()) - args, attrs = self.visit(ctx.callList()) - res = self.call(func, args, attrs, []) - return res - - @spanify - def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If: - """Construct a Relay If node. Creates a new scope for each branch.""" - cond = self.visit(ctx.expr()) - - self.enter_var_scope() - true_branch = self.visit(ctx.body(0)) - self.exit_var_scope() - - self.enter_var_scope() - false_branch = self.visit(ctx.body(1)) - self.exit_var_scope() - - return expr.If(cond, true_branch, false_branch) - - @spanify - def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr: - """Visit a graph variable assignment.""" - graph_nid = int(ctx.graphVar().getText()[1:]) - - self.enter_var_scope() - value = self.visit(ctx.expr(0)) - self.exit_var_scope() - - if graph_nid != len(self.graph_expr): - raise ParseError( - "expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ - "but got `%{}`".format(graph_nid)) - self.graph_expr.append(value) - - kont = self.visit(ctx.expr(1)) - return kont - - # Types - - # pylint: disable=unused-argument - def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None: - return None - - def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext): - func = self.visit(ctx.generalIdent()) - args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()] - return ty.TypeCall(func, args) - - def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int: - return self.visit(ctx.shape()) - - def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]: - return self.visit_list(ctx.shape()) - - def visitTensor(self, ctx: RelayParser.TensorContext): - return tuple(self.visit_list(ctx.expr())) - - def visitTensorType(self, ctx: 
RelayParser.TensorTypeContext) -> ty.TensorType: - """Create a simple tensor type. No generics.""" - - shape = self.visit(ctx.shapeList()) - dtype = self.visit(ctx.typeExpr()) - - if not isinstance(dtype, ty.TensorType): - raise ParseError("expected dtype to be a Relay base type.") - - dtype = dtype.dtype - - return ty.TensorType(shape, dtype) - - def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType: - return ty.TupleType(self.visit_list(ctx.typeExpr())) - - def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType: - types = self.visit_list(ctx.typeExpr()) - - arg_types = types[:-1] - ret_type = types[-1] - - return ty.FuncType(arg_types, ret_type, [], None) - -def make_parser(data: str) -> RelayParser: - """Construct a RelayParser a given data stream.""" - input_stream = InputStream(data) - lexer = RelayLexer(input_stream) - lexer.addErrorListener(StrictErrorListener(data)) - token_stream = CommonTokenStream(lexer) - p = RelayParser(token_stream) - p.addErrorListener(StrictErrorListener(data)) - return p - -__source_name_counter__ = 0 - -class StrictErrorListener(ErrorListener): - """This ErrorListener fail eagerly on all error, and report the program.""" - def __init__(self, text): - self.text = text - - def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): - raise Exception("Syntax Error in:\n" + self.text) - - def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs): - raise Exception("Ambiguity Error in:\n" + self.text) - - def reportAttemptingFullContext(self, - recognizer, - dfa, - startIndex, - stopIndex, - conflictingAlts, - configs): - raise Exception("Attempting Full Context in:\n" + self.text) - - def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): - raise Exception("Context Sensitivity in:\n" + self.text) - -def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]: - """Parse a Relay 
program.""" - if data == "": - raise ParseError("cannot parse the empty string.") - - global __source_name_counter__ - - if source_name is None: - source_name = "source_file{0}".format(__source_name_counter__) - - if isinstance(source_name, str): - source_name = SourceName(source_name) - - tree = make_parser(data).prog() - return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/python/tvm/relay/grammar/.gitignore b/python/tvm/relay/grammar/.gitignore deleted file mode 100644 index cffe35e1a41a..000000000000 --- a/python/tvm/relay/grammar/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/.antlr/ diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 deleted file mode 100644 index bfcd18ffc98f..000000000000 --- a/python/tvm/relay/grammar/Relay.g4 +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for - * changes in this file to be reflected by the parser. - * NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules. - */ - -grammar Relay; - -SEMVER: 'v0.0.4' ; - -// Lexing -// comments -COMMENT : '/*' (COMMENT|.)*? 
'*/' -> skip; -WS : [ \t\n\r]+ -> skip; -LINE_COMMENT : '//' .*? '\n' -> skip; - -fragment ESCAPED_QUOTE : '\\"'; -QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"'; - -// operators -MUL: '*' ; -DIV: '/' ; -ADD: '+' ; -SUB: '-' ; -LT: '<' ; -GT: '>' ; -LE: '<=' ; -GE: '>=' ; -EQ: '==' ; -NE: '!=' ; - -BOOL_LIT - : 'True' - | 'False' - ; - -CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ; - -// non-negative floats -fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 - -FLOAT : PREFLOAT 'f'; - -// non-negative ints -NAT: DIGIT+ ; -fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] - -fragment LETTER: [a-zA-Z]; -fragment DIGIT: [0-9]; - -METADATA: 'METADATA:' .*; -// Parsing - -// A Relay program is a list of global definitions or an expression. -prog: SEMVER (defn* | expr) METADATA? EOF ; - -// Covers both operator and type idents -generalIdent: CNAME ('.' CNAME)*; -globalVar: '@' CNAME ; -localVar: '%' ('_' | CNAME) ; -graphVar: '%' NAT ; - -exprList: (expr (',' expr)*)?; -callList - : exprList # callNoAttr - | (expr ',')* attrSeq # callWithAttr - ; - -expr - // operators - : '(' expr ')' # paren - // function application - | expr '(' callList ')' # call - | '-' expr # neg - | expr op=('*'|'/') expr # binOp - | expr op=('+'|'-') expr # binOp - | expr op=('<'|'>'|'<='|'>=') expr # binOp - | expr op=('=='|'!=') expr # binOp - // function definition - | func # funcExpr - // tuples and tensors - | '(' ')' # tuple - | '(' expr ',' ')' # tuple - | '(' expr (',' expr)+ ')' # tuple - | '[' (expr (',' expr)*)? ']' # tensor - | 'if' '(' expr ')' body 'else' body # ifElse - | matchType expr '{' matchClauseList? '}' # match - | expr '.' 
NAT # projection - // sequencing - | 'let' var '=' expr ';' expr # let - // sugar for let %_ = expr; expr - | expr ';;' expr # let - | graphVar '=' expr ';' expr # graph - | ident # identExpr - | scalar # scalarExpr - | meta # metaExpr - | QUOTED_STRING # stringExpr - ; - -func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ; -defn - : 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn - | 'extern' 'type' generalIdent typeParamList? # externAdtDefn - | 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn - ; - -constructorName: CNAME ; - -adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ; -adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ; -matchClauseList: matchClause (',' matchClause)* ','? ; -matchClause: pattern '=>' ('{' expr '}' | expr) ; -// complete or incomplete match, respectively -matchType : 'match' | 'match?' ; - -patternList: '(' pattern (',' pattern)* ')'; -pattern - : '_' # wildcardPattern - | localVar (':' typeExpr)? # varPattern - | constructorName patternList? # constructorPattern - | patternList # tuplePattern - ; - -adtCons: constructorName adtConsParamList? ; -adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ; -adtConsParam: localVar | constructorName ; - -argList - : varList # argNoAttr - | (var ',')* attrSeq # argWithAttr - ; - -varList: (var (',' var)*)? ; -var: localVar (':' typeExpr)? ; - -attrSeq: attr (',' attr)* ; -attr: CNAME '=' expr ; - -typeExpr - : '(' ')' # tupleType - | '(' typeExpr ')' # typeParen - | '(' typeExpr ',' ')' # tupleType - | '(' typeExpr (',' typeExpr)+ ')' # tupleType - | generalIdent typeParamList # typeCallType - | generalIdent # typeIdentType - | 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType - | 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? 
')' '->' typeExpr # funcType - | '_' # incompleteType - ; - -typeParamList: '[' typeExpr (',' typeExpr)* ']' ; - -shapeList - : '(' ')' - | '(' shape (',' shape)+ ')' - | shape - ; - -meta : 'meta' '[' CNAME ']' '[' NAT ']'; - -shape - : meta # metaShape - | '(' shape ')' # parensShape - | NAT # intShape - ; - -body: '{' expr '}' ; - -scalar - : FLOAT # scalarFloat - | NAT # scalarInt - | BOOL_LIT # scalarBool - ; - -ident - : generalIdent - | globalVar - | localVar - | graphVar - ; diff --git a/python/tvm/relay/grammar/__init__.py b/python/tvm/relay/grammar/__init__.py deleted file mode 100644 index 13a83393a912..000000000000 --- a/python/tvm/relay/grammar/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
diff --git a/python/tvm/relay/grammar/py3/.gitattributes b/python/tvm/relay/grammar/py3/.gitattributes deleted file mode 100644 index 0eaf9078bc4f..000000000000 --- a/python/tvm/relay/grammar/py3/.gitattributes +++ /dev/null @@ -1,3 +0,0 @@ -Relay* binary -Relay* linguist-generated=true -Relay* linguist-detectable=false diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py deleted file mode 100644 index 76e988b45418..000000000000 --- a/python/tvm/relay/grammar/py3/RelayLexer.py +++ /dev/null @@ -1,256 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -from antlr4 import * -from io import StringIO -from typing.io import TextIO -import sys - - - -def serializedATN(): - with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2\62") - buf.write("\u0161\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") - buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") - buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") - buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") - buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") - buf.write("\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") - buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\4-\t-\4.") - buf.write("\t.\4/\t/\4\60\t\60\4\61\t\61\4\62\t\62\4\63\t\63\4\64") - buf.write("\t\64\4\65\t\65\4\66\t\66\3\2\3\2\3\3\3\3\3\4\3\4\3\5") - buf.write("\3\5\3\6\3\6\3\7\3\7\3\b\3\b\3\t\3\t\3\n\3\n\3\13\3\13") - buf.write("\3\13\3\f\3\f\3\f\3\f\3\f\3\r\3\r\3\16\3\16\3\17\3\17") - buf.write("\3\17\3\17\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\23\3\23") - buf.write("\3\23\3\24\3\24\3\24\3\25\3\25\3\25\3\25\3\26\3\26\3\26") - buf.write("\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\3\27\3\30\3\30") - buf.write("\3\30\3\31\3\31\3\31\3\31\3\31\3\31\3\32\3\32\3\32\3\32") - buf.write("\3\32\3\32\3\32\3\33\3\33\3\34\3\34\3\34\3\34\3\34\3\34") 
- buf.write("\3\34\3\35\3\35\3\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36") - buf.write("\3\36\3\36\3\37\3\37\3\37\3\37\3\37\7\37\u00d7\n\37\f") - buf.write("\37\16\37\u00da\13\37\3\37\3\37\3\37\3\37\3\37\3 \6 \u00e2") - buf.write("\n \r \16 \u00e3\3 \3 \3!\3!\3!\3!\7!\u00ec\n!\f!\16!") - buf.write("\u00ef\13!\3!\3!\3!\3!\3\"\3\"\3\"\3#\3#\3#\7#\u00fb\n") - buf.write("#\f#\16#\u00fe\13#\3#\3#\3$\3$\3%\3%\3&\3&\3\'\3\'\3(") - buf.write("\3(\3)\3)\3*\3*\3*\3+\3+\3+\3,\3,\3,\3-\3-\3-\3.\3.\3") - buf.write(".\3.\3.\3.\3.\3.\3.\5.\u0123\n.\3/\3/\5/\u0127\n/\3/\3") - buf.write("/\3/\7/\u012c\n/\f/\16/\u012f\13/\3/\3/\7/\u0133\n/\f") - buf.write("/\16/\u0136\13/\3\60\3\60\3\60\5\60\u013b\n\60\3\60\5") - buf.write("\60\u013e\n\60\3\61\3\61\3\61\3\62\6\62\u0144\n\62\r\62") - buf.write("\16\62\u0145\3\63\3\63\5\63\u014a\n\63\3\63\3\63\3\64") - buf.write("\3\64\3\65\3\65\3\66\3\66\3\66\3\66\3\66\3\66\3\66\3\66") - buf.write("\3\66\3\66\3\66\7\66\u015d\n\66\f\66\16\66\u0160\13\66") - buf.write("\5\u00d8\u00ed\u00fc\2\67\3\3\5\4\7\5\t\6\13\7\r\b\17") - buf.write("\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21!\22#\23") - buf.write("%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67\359\36") - buf.write(";\37= ?!A\"C\2E#G$I%K&M\'O(Q)S*U+W,Y-[.]/_\2a\60c\61e") - buf.write("\2g\2i\2k\62\3\2\b\5\2\13\f\17\17\"\"\4\2\f\f\17\17\4") - buf.write("\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u016c\2\3\3\2\2\2\2") - buf.write("\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") - buf.write("\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3\2") - buf.write("\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3\2\2") - buf.write("\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2\2\'\3") - buf.write("\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2\2\2\2\61") - buf.write("\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2\2\2\29\3\2") - buf.write("\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3\2\2\2\2E\3") - buf.write("\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2K\3\2\2\2\2M\3\2\2\2\2O") - 
buf.write("\3\2\2\2\2Q\3\2\2\2\2S\3\2\2\2\2U\3\2\2\2\2W\3\2\2\2\2") - buf.write("Y\3\2\2\2\2[\3\2\2\2\2]\3\2\2\2\2a\3\2\2\2\2c\3\2\2\2") - buf.write("\2k\3\2\2\2\3m\3\2\2\2\5o\3\2\2\2\7q\3\2\2\2\ts\3\2\2") - buf.write("\2\13u\3\2\2\2\rw\3\2\2\2\17y\3\2\2\2\21{\3\2\2\2\23}") - buf.write("\3\2\2\2\25\177\3\2\2\2\27\u0082\3\2\2\2\31\u0087\3\2") - buf.write("\2\2\33\u0089\3\2\2\2\35\u008b\3\2\2\2\37\u008f\3\2\2") - buf.write("\2!\u0091\3\2\2\2#\u0093\3\2\2\2%\u0096\3\2\2\2\'\u0099") - buf.write("\3\2\2\2)\u009c\3\2\2\2+\u00a0\3\2\2\2-\u00a7\3\2\2\2") - buf.write("/\u00ac\3\2\2\2\61\u00af\3\2\2\2\63\u00b5\3\2\2\2\65\u00bc") - buf.write("\3\2\2\2\67\u00be\3\2\2\29\u00c5\3\2\2\2;\u00ca\3\2\2") - buf.write("\2=\u00d1\3\2\2\2?\u00e1\3\2\2\2A\u00e7\3\2\2\2C\u00f4") - buf.write("\3\2\2\2E\u00f7\3\2\2\2G\u0101\3\2\2\2I\u0103\3\2\2\2") - buf.write("K\u0105\3\2\2\2M\u0107\3\2\2\2O\u0109\3\2\2\2Q\u010b\3") - buf.write("\2\2\2S\u010d\3\2\2\2U\u0110\3\2\2\2W\u0113\3\2\2\2Y\u0116") - buf.write("\3\2\2\2[\u0122\3\2\2\2]\u0126\3\2\2\2_\u0137\3\2\2\2") - buf.write("a\u013f\3\2\2\2c\u0143\3\2\2\2e\u0147\3\2\2\2g\u014d\3") - buf.write("\2\2\2i\u014f\3\2\2\2k\u0151\3\2\2\2mn\7\60\2\2n\4\3\2") - buf.write("\2\2op\7B\2\2p\6\3\2\2\2qr\7\'\2\2r\b\3\2\2\2st\7a\2\2") - buf.write("t\n\3\2\2\2uv\7.\2\2v\f\3\2\2\2wx\7*\2\2x\16\3\2\2\2y") - buf.write("z\7+\2\2z\20\3\2\2\2{|\7]\2\2|\22\3\2\2\2}~\7_\2\2~\24") - buf.write("\3\2\2\2\177\u0080\7k\2\2\u0080\u0081\7h\2\2\u0081\26") - buf.write("\3\2\2\2\u0082\u0083\7g\2\2\u0083\u0084\7n\2\2\u0084\u0085") - buf.write("\7u\2\2\u0085\u0086\7g\2\2\u0086\30\3\2\2\2\u0087\u0088") - buf.write("\7}\2\2\u0088\32\3\2\2\2\u0089\u008a\7\177\2\2\u008a\34") - buf.write("\3\2\2\2\u008b\u008c\7n\2\2\u008c\u008d\7g\2\2\u008d\u008e") - buf.write("\7v\2\2\u008e\36\3\2\2\2\u008f\u0090\7?\2\2\u0090 \3\2") - buf.write("\2\2\u0091\u0092\7=\2\2\u0092\"\3\2\2\2\u0093\u0094\7") - buf.write("=\2\2\u0094\u0095\7=\2\2\u0095$\3\2\2\2\u0096\u0097\7") - 
buf.write("h\2\2\u0097\u0098\7p\2\2\u0098&\3\2\2\2\u0099\u009a\7") - buf.write("/\2\2\u009a\u009b\7@\2\2\u009b(\3\2\2\2\u009c\u009d\7") - buf.write("f\2\2\u009d\u009e\7g\2\2\u009e\u009f\7h\2\2\u009f*\3\2") - buf.write("\2\2\u00a0\u00a1\7g\2\2\u00a1\u00a2\7z\2\2\u00a2\u00a3") - buf.write("\7v\2\2\u00a3\u00a4\7g\2\2\u00a4\u00a5\7t\2\2\u00a5\u00a6") - buf.write("\7p\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7v\2\2\u00a8\u00a9") - buf.write("\7{\2\2\u00a9\u00aa\7r\2\2\u00aa\u00ab\7g\2\2\u00ab.\3") - buf.write("\2\2\2\u00ac\u00ad\7?\2\2\u00ad\u00ae\7@\2\2\u00ae\60") - buf.write("\3\2\2\2\u00af\u00b0\7o\2\2\u00b0\u00b1\7c\2\2\u00b1\u00b2") - buf.write("\7v\2\2\u00b2\u00b3\7e\2\2\u00b3\u00b4\7j\2\2\u00b4\62") - buf.write("\3\2\2\2\u00b5\u00b6\7o\2\2\u00b6\u00b7\7c\2\2\u00b7\u00b8") - buf.write("\7v\2\2\u00b8\u00b9\7e\2\2\u00b9\u00ba\7j\2\2\u00ba\u00bb") - buf.write("\7A\2\2\u00bb\64\3\2\2\2\u00bc\u00bd\7<\2\2\u00bd\66\3") - buf.write("\2\2\2\u00be\u00bf\7V\2\2\u00bf\u00c0\7g\2\2\u00c0\u00c1") - buf.write("\7p\2\2\u00c1\u00c2\7u\2\2\u00c2\u00c3\7q\2\2\u00c3\u00c4") - buf.write("\7t\2\2\u00c48\3\2\2\2\u00c5\u00c6\7o\2\2\u00c6\u00c7") - buf.write("\7g\2\2\u00c7\u00c8\7v\2\2\u00c8\u00c9\7c\2\2\u00c9:\3") - buf.write("\2\2\2\u00ca\u00cb\7x\2\2\u00cb\u00cc\7\62\2\2\u00cc\u00cd") - buf.write("\7\60\2\2\u00cd\u00ce\7\62\2\2\u00ce\u00cf\7\60\2\2\u00cf") - buf.write("\u00d0\7\66\2\2\u00d0<\3\2\2\2\u00d1\u00d2\7\61\2\2\u00d2") - buf.write("\u00d3\7,\2\2\u00d3\u00d8\3\2\2\2\u00d4\u00d7\5=\37\2") - buf.write("\u00d5\u00d7\13\2\2\2\u00d6\u00d4\3\2\2\2\u00d6\u00d5") - buf.write("\3\2\2\2\u00d7\u00da\3\2\2\2\u00d8\u00d9\3\2\2\2\u00d8") - buf.write("\u00d6\3\2\2\2\u00d9\u00db\3\2\2\2\u00da\u00d8\3\2\2\2") - buf.write("\u00db\u00dc\7,\2\2\u00dc\u00dd\7\61\2\2\u00dd\u00de\3") - buf.write("\2\2\2\u00de\u00df\b\37\2\2\u00df>\3\2\2\2\u00e0\u00e2") - buf.write("\t\2\2\2\u00e1\u00e0\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3") - buf.write("\u00e1\3\2\2\2\u00e3\u00e4\3\2\2\2\u00e4\u00e5\3\2\2\2") - 
buf.write("\u00e5\u00e6\b \2\2\u00e6@\3\2\2\2\u00e7\u00e8\7\61\2") - buf.write("\2\u00e8\u00e9\7\61\2\2\u00e9\u00ed\3\2\2\2\u00ea\u00ec") - buf.write("\13\2\2\2\u00eb\u00ea\3\2\2\2\u00ec\u00ef\3\2\2\2\u00ed") - buf.write("\u00ee\3\2\2\2\u00ed\u00eb\3\2\2\2\u00ee\u00f0\3\2\2\2") - buf.write("\u00ef\u00ed\3\2\2\2\u00f0\u00f1\7\f\2\2\u00f1\u00f2\3") - buf.write("\2\2\2\u00f2\u00f3\b!\2\2\u00f3B\3\2\2\2\u00f4\u00f5\7") - buf.write("^\2\2\u00f5\u00f6\7$\2\2\u00f6D\3\2\2\2\u00f7\u00fc\7") - buf.write("$\2\2\u00f8\u00fb\5C\"\2\u00f9\u00fb\n\3\2\2\u00fa\u00f8") - buf.write("\3\2\2\2\u00fa\u00f9\3\2\2\2\u00fb\u00fe\3\2\2\2\u00fc") - buf.write("\u00fd\3\2\2\2\u00fc\u00fa\3\2\2\2\u00fd\u00ff\3\2\2\2") - buf.write("\u00fe\u00fc\3\2\2\2\u00ff\u0100\7$\2\2\u0100F\3\2\2\2") - buf.write("\u0101\u0102\7,\2\2\u0102H\3\2\2\2\u0103\u0104\7\61\2") - buf.write("\2\u0104J\3\2\2\2\u0105\u0106\7-\2\2\u0106L\3\2\2\2\u0107") - buf.write("\u0108\7/\2\2\u0108N\3\2\2\2\u0109\u010a\7>\2\2\u010a") - buf.write("P\3\2\2\2\u010b\u010c\7@\2\2\u010cR\3\2\2\2\u010d\u010e") - buf.write("\7>\2\2\u010e\u010f\7?\2\2\u010fT\3\2\2\2\u0110\u0111") - buf.write("\7@\2\2\u0111\u0112\7?\2\2\u0112V\3\2\2\2\u0113\u0114") - buf.write("\7?\2\2\u0114\u0115\7?\2\2\u0115X\3\2\2\2\u0116\u0117") - buf.write("\7#\2\2\u0117\u0118\7?\2\2\u0118Z\3\2\2\2\u0119\u011a") - buf.write("\7V\2\2\u011a\u011b\7t\2\2\u011b\u011c\7w\2\2\u011c\u0123") - buf.write("\7g\2\2\u011d\u011e\7H\2\2\u011e\u011f\7c\2\2\u011f\u0120") - buf.write("\7n\2\2\u0120\u0121\7u\2\2\u0121\u0123\7g\2\2\u0122\u0119") - buf.write("\3\2\2\2\u0122\u011d\3\2\2\2\u0123\\\3\2\2\2\u0124\u0127") - buf.write("\7a\2\2\u0125\u0127\5g\64\2\u0126\u0124\3\2\2\2\u0126") - buf.write("\u0125\3\2\2\2\u0127\u012d\3\2\2\2\u0128\u012c\7a\2\2") - buf.write("\u0129\u012c\5g\64\2\u012a\u012c\5i\65\2\u012b\u0128\3") - buf.write("\2\2\2\u012b\u0129\3\2\2\2\u012b\u012a\3\2\2\2\u012c\u012f") - buf.write("\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e\3\2\2\2\u012e") - 
buf.write("\u0134\3\2\2\2\u012f\u012d\3\2\2\2\u0130\u0131\7\60\2") - buf.write("\2\u0131\u0133\5]/\2\u0132\u0130\3\2\2\2\u0133\u0136\3") - buf.write("\2\2\2\u0134\u0132\3\2\2\2\u0134\u0135\3\2\2\2\u0135^") - buf.write("\3\2\2\2\u0136\u0134\3\2\2\2\u0137\u013a\5c\62\2\u0138") - buf.write("\u0139\7\60\2\2\u0139\u013b\5c\62\2\u013a\u0138\3\2\2") - buf.write("\2\u013a\u013b\3\2\2\2\u013b\u013d\3\2\2\2\u013c\u013e") - buf.write("\5e\63\2\u013d\u013c\3\2\2\2\u013d\u013e\3\2\2\2\u013e") - buf.write("`\3\2\2\2\u013f\u0140\5_\60\2\u0140\u0141\7h\2\2\u0141") - buf.write("b\3\2\2\2\u0142\u0144\5i\65\2\u0143\u0142\3\2\2\2\u0144") - buf.write("\u0145\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2") - buf.write("\u0146d\3\2\2\2\u0147\u0149\t\4\2\2\u0148\u014a\t\5\2") - buf.write("\2\u0149\u0148\3\2\2\2\u0149\u014a\3\2\2\2\u014a\u014b") - buf.write("\3\2\2\2\u014b\u014c\5c\62\2\u014cf\3\2\2\2\u014d\u014e") - buf.write("\t\6\2\2\u014eh\3\2\2\2\u014f\u0150\t\7\2\2\u0150j\3\2") - buf.write("\2\2\u0151\u0152\7O\2\2\u0152\u0153\7G\2\2\u0153\u0154") - buf.write("\7V\2\2\u0154\u0155\7C\2\2\u0155\u0156\7F\2\2\u0156\u0157") - buf.write("\7C\2\2\u0157\u0158\7V\2\2\u0158\u0159\7C\2\2\u0159\u015a") - buf.write("\7<\2\2\u015a\u015e\3\2\2\2\u015b\u015d\13\2\2\2\u015c") - buf.write("\u015b\3\2\2\2\u015d\u0160\3\2\2\2\u015e\u015c\3\2\2\2") - buf.write("\u015e\u015f\3\2\2\2\u015fl\3\2\2\2\u0160\u015e\3\2\2") - buf.write("\2\23\2\u00d6\u00d8\u00e3\u00ed\u00fa\u00fc\u0122\u0126") - buf.write("\u012b\u012d\u0134\u013a\u013d\u0145\u0149\u015e\3\b\2") - buf.write("\2") - return buf.getvalue() - - -class RelayLexer(Lexer): - - atn = ATNDeserializer().deserialize(serializedATN()) - - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] - - T__0 = 1 - T__1 = 2 - T__2 = 3 - T__3 = 4 - T__4 = 5 - T__5 = 6 - T__6 = 7 - T__7 = 8 - T__8 = 9 - T__9 = 10 - T__10 = 11 - T__11 = 12 - T__12 = 13 - T__13 = 14 - T__14 = 15 - T__15 = 16 - T__16 = 17 - T__17 = 18 - T__18 = 19 - T__19 = 20 - 
T__20 = 21 - T__21 = 22 - T__22 = 23 - T__23 = 24 - T__24 = 25 - T__25 = 26 - T__26 = 27 - T__27 = 28 - SEMVER = 29 - COMMENT = 30 - WS = 31 - LINE_COMMENT = 32 - QUOTED_STRING = 33 - MUL = 34 - DIV = 35 - ADD = 36 - SUB = 37 - LT = 38 - GT = 39 - LE = 40 - GE = 41 - EQ = 42 - NE = 43 - BOOL_LIT = 44 - CNAME = 45 - FLOAT = 46 - NAT = 47 - METADATA = 48 - - channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] - - modeNames = [ "DEFAULT_MODE" ] - - literalNames = [ "", - "'.'", "'@'", "'%'", "'_'", "','", "'('", "')'", "'['", "']'", - "'if'", "'else'", "'{'", "'}'", "'let'", "'='", "';'", "';;'", - "'fn'", "'->'", "'def'", "'extern'", "'type'", "'=>'", "'match'", - "'match?'", "':'", "'Tensor'", "'meta'", "'v0.0.4'", "'*'", - "'/'", "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", "'!='" ] - - symbolicNames = [ "", - "SEMVER", "COMMENT", "WS", "LINE_COMMENT", "QUOTED_STRING", - "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", - "BOOL_LIT", "CNAME", "FLOAT", "NAT", "METADATA" ] - - ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", - "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", - "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", - "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", - "T__26", "T__27", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", - "ESCAPED_QUOTE", "QUOTED_STRING", "MUL", "DIV", "ADD", - "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", - "CNAME", "PREFLOAT", "FLOAT", "NAT", "EXP", "LETTER", - "DIGIT", "METADATA" ] - - grammarFileName = "Relay.g4" - - def __init__(self, input=None, output:TextIO = sys.stdout): - super().__init__(input, output) - self.checkVersion("4.7.2") - self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) - self._actions = None - self._predicates = None - - diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py deleted file mode 100644 index f24eed4be92f..000000000000 --- 
a/python/tvm/relay/grammar/py3/RelayParser.py +++ /dev/null @@ -1,3732 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -# encoding: utf-8 -from antlr4 import * -from io import StringIO -from typing.io import TextIO -import sys - - -def serializedATN(): - with StringIO() as buf: - buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3\62") - buf.write("\u0200\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") - buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") - buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") - buf.write("\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30\4\31") - buf.write("\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36\t\36") - buf.write("\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\3\2\3\2\7\2I\n\2") - buf.write("\f\2\16\2L\13\2\3\2\5\2O\n\2\3\2\5\2R\n\2\3\2\3\2\3\3") - buf.write("\3\3\3\3\7\3Y\n\3\f\3\16\3\\\13\3\3\4\3\4\3\4\3\5\3\5") - buf.write("\3\5\3\6\3\6\3\6\3\7\3\7\3\7\7\7j\n\7\f\7\16\7m\13\7\5") - buf.write("\7o\n\7\3\b\3\b\3\b\3\b\7\bu\n\b\f\b\16\bx\13\b\3\b\5") - buf.write("\b{\n\b\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\6\t\u0090\n\t\r\t\16\t") - buf.write("\u0091\3\t\3\t\3\t\3\t\3\t\3\t\7\t\u009a\n\t\f\t\16\t") - buf.write("\u009d\13\t\5\t\u009f\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\5\t\u00ae\n\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\5\t\u00c3\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\7\t\u00dc\n\t\f\t\16\t\u00df\13\t\3\n\3\n\5\n\u00e3") - buf.write("\n\n\3\n\3\n\3\n\3\n\3\n\5\n\u00ea\n\n\3\n\3\n\3\13\3") - buf.write("\13\3\13\5\13\u00f1\n\13\3\13\3\13\3\13\3\13\3\13\5\13") - buf.write("\u00f8\n\13\3\13\3\13\3\13\3\13\3\13\3\13\5\13\u0100\n") - 
buf.write("\13\3\13\3\13\3\13\5\13\u0105\n\13\3\13\3\13\5\13\u0109") - buf.write("\n\13\3\13\3\13\5\13\u010d\n\13\3\f\3\f\3\r\3\r\3\r\7") - buf.write("\r\u0114\n\r\f\r\16\r\u0117\13\r\3\r\5\r\u011a\n\r\3\16") - buf.write("\3\16\3\16\3\16\3\16\7\16\u0121\n\16\f\16\16\16\u0124") - buf.write("\13\16\3\16\3\16\5\16\u0128\n\16\3\17\3\17\3\17\7\17\u012d") - buf.write("\n\17\f\17\16\17\u0130\13\17\3\17\5\17\u0133\n\17\3\20") - buf.write("\3\20\3\20\3\20\3\20\3\20\3\20\5\20\u013c\n\20\3\21\3") - buf.write("\21\3\22\3\22\3\22\3\22\7\22\u0144\n\22\f\22\16\22\u0147") - buf.write("\13\22\3\22\3\22\3\23\3\23\3\23\3\23\5\23\u014f\n\23\3") - buf.write("\23\3\23\5\23\u0153\n\23\3\23\5\23\u0156\n\23\3\24\3\24") - buf.write("\5\24\u015a\n\24\3\25\3\25\3\25\3\25\7\25\u0160\n\25\f") - buf.write("\25\16\25\u0163\13\25\3\25\3\25\3\26\3\26\5\26\u0169\n") - buf.write("\26\3\27\3\27\3\27\3\27\7\27\u016f\n\27\f\27\16\27\u0172") - buf.write("\13\27\3\27\5\27\u0175\n\27\3\30\3\30\3\30\7\30\u017a") - buf.write("\n\30\f\30\16\30\u017d\13\30\5\30\u017f\n\30\3\31\3\31") - buf.write("\3\31\5\31\u0184\n\31\3\32\3\32\3\32\7\32\u0189\n\32\f") - buf.write("\32\16\32\u018c\13\32\3\33\3\33\3\33\3\33\3\34\3\34\3") - buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") - buf.write("\3\34\3\34\6\34\u01a1\n\34\r\34\16\34\u01a2\3\34\3\34") - buf.write("\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") - buf.write("\3\34\3\34\5\34\u01b4\n\34\3\34\3\34\3\34\3\34\7\34\u01ba") - buf.write("\n\34\f\34\16\34\u01bd\13\34\5\34\u01bf\n\34\3\34\3\34") - buf.write("\3\34\3\34\5\34\u01c5\n\34\3\35\3\35\3\35\3\35\7\35\u01cb") - buf.write("\n\35\f\35\16\35\u01ce\13\35\3\35\3\35\3\36\3\36\3\36") - buf.write("\3\36\3\36\3\36\6\36\u01d8\n\36\r\36\16\36\u01d9\3\36") - buf.write("\3\36\3\36\5\36\u01df\n\36\3\37\3\37\3\37\3\37\3\37\3") - buf.write("\37\3\37\3\37\3 \3 \3 \3 \3 \3 \5 \u01ef\n \3!\3!\3!\3") - buf.write("!\3\"\3\"\3\"\5\"\u01f8\n\"\3#\3#\3#\3#\5#\u01fe\n#\3") - 
buf.write("#\2\3\20$\2\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$&") - buf.write("(*,.\60\62\64\668:<>@BD\2\b\4\2\6\6//\3\2$%\3\2&\'\3\2") - buf.write("(+\3\2,-\3\2\32\33\2\u0234\2F\3\2\2\2\4U\3\2\2\2\6]\3") - buf.write("\2\2\2\b`\3\2\2\2\nc\3\2\2\2\fn\3\2\2\2\16z\3\2\2\2\20") - buf.write("\u00c2\3\2\2\2\22\u00e0\3\2\2\2\24\u010c\3\2\2\2\26\u010e") - buf.write("\3\2\2\2\30\u0110\3\2\2\2\32\u011b\3\2\2\2\34\u0129\3") - buf.write("\2\2\2\36\u0134\3\2\2\2 \u013d\3\2\2\2\"\u013f\3\2\2\2") - buf.write("$\u0155\3\2\2\2&\u0157\3\2\2\2(\u015b\3\2\2\2*\u0168\3") - buf.write("\2\2\2,\u0174\3\2\2\2.\u017e\3\2\2\2\60\u0180\3\2\2\2") - buf.write("\62\u0185\3\2\2\2\64\u018d\3\2\2\2\66\u01c4\3\2\2\28\u01c6") - buf.write("\3\2\2\2:\u01de\3\2\2\2<\u01e0\3\2\2\2>\u01ee\3\2\2\2") - buf.write("@\u01f0\3\2\2\2B\u01f7\3\2\2\2D\u01fd\3\2\2\2FN\7\37\2") - buf.write("\2GI\5\24\13\2HG\3\2\2\2IL\3\2\2\2JH\3\2\2\2JK\3\2\2\2") - buf.write("KO\3\2\2\2LJ\3\2\2\2MO\5\20\t\2NJ\3\2\2\2NM\3\2\2\2OQ") - buf.write("\3\2\2\2PR\7\62\2\2QP\3\2\2\2QR\3\2\2\2RS\3\2\2\2ST\7") - buf.write("\2\2\3T\3\3\2\2\2UZ\7/\2\2VW\7\3\2\2WY\7/\2\2XV\3\2\2") - buf.write("\2Y\\\3\2\2\2ZX\3\2\2\2Z[\3\2\2\2[\5\3\2\2\2\\Z\3\2\2") - buf.write("\2]^\7\4\2\2^_\7/\2\2_\7\3\2\2\2`a\7\5\2\2ab\t\2\2\2b") - buf.write("\t\3\2\2\2cd\7\5\2\2de\7\61\2\2e\13\3\2\2\2fk\5\20\t\2") - buf.write("gh\7\7\2\2hj\5\20\t\2ig\3\2\2\2jm\3\2\2\2ki\3\2\2\2kl") - buf.write("\3\2\2\2lo\3\2\2\2mk\3\2\2\2nf\3\2\2\2no\3\2\2\2o\r\3") - buf.write("\2\2\2p{\5\f\7\2qr\5\20\t\2rs\7\7\2\2su\3\2\2\2tq\3\2") - buf.write("\2\2ux\3\2\2\2vt\3\2\2\2vw\3\2\2\2wy\3\2\2\2xv\3\2\2\2") - buf.write("y{\5\62\32\2zp\3\2\2\2zv\3\2\2\2{\17\3\2\2\2|}\b\t\1\2") - buf.write("}~\7\b\2\2~\177\5\20\t\2\177\u0080\7\t\2\2\u0080\u00c3") - buf.write("\3\2\2\2\u0081\u0082\7\'\2\2\u0082\u00c3\5\20\t\26\u0083") - buf.write("\u00c3\5\22\n\2\u0084\u0085\7\b\2\2\u0085\u00c3\7\t\2") - buf.write("\2\u0086\u0087\7\b\2\2\u0087\u0088\5\20\t\2\u0088\u0089") - 
buf.write("\7\7\2\2\u0089\u008a\7\t\2\2\u008a\u00c3\3\2\2\2\u008b") - buf.write("\u008c\7\b\2\2\u008c\u008f\5\20\t\2\u008d\u008e\7\7\2") - buf.write("\2\u008e\u0090\5\20\t\2\u008f\u008d\3\2\2\2\u0090\u0091") - buf.write("\3\2\2\2\u0091\u008f\3\2\2\2\u0091\u0092\3\2\2\2\u0092") - buf.write("\u0093\3\2\2\2\u0093\u0094\7\t\2\2\u0094\u00c3\3\2\2\2") - buf.write("\u0095\u009e\7\n\2\2\u0096\u009b\5\20\t\2\u0097\u0098") - buf.write("\7\7\2\2\u0098\u009a\5\20\t\2\u0099\u0097\3\2\2\2\u009a") - buf.write("\u009d\3\2\2\2\u009b\u0099\3\2\2\2\u009b\u009c\3\2\2\2") - buf.write("\u009c\u009f\3\2\2\2\u009d\u009b\3\2\2\2\u009e\u0096\3") - buf.write("\2\2\2\u009e\u009f\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0\u00c3") - buf.write("\7\13\2\2\u00a1\u00a2\7\f\2\2\u00a2\u00a3\7\b\2\2\u00a3") - buf.write("\u00a4\5\20\t\2\u00a4\u00a5\7\t\2\2\u00a5\u00a6\5@!\2") - buf.write("\u00a6\u00a7\7\r\2\2\u00a7\u00a8\5@!\2\u00a8\u00c3\3\2") - buf.write("\2\2\u00a9\u00aa\5 \21\2\u00aa\u00ab\5\20\t\2\u00ab\u00ad") - buf.write("\7\16\2\2\u00ac\u00ae\5\34\17\2\u00ad\u00ac\3\2\2\2\u00ad") - buf.write("\u00ae\3\2\2\2\u00ae\u00af\3\2\2\2\u00af\u00b0\7\17\2") - buf.write("\2\u00b0\u00c3\3\2\2\2\u00b1\u00b2\7\20\2\2\u00b2\u00b3") - buf.write("\5\60\31\2\u00b3\u00b4\7\21\2\2\u00b4\u00b5\5\20\t\2\u00b5") - buf.write("\u00b6\7\22\2\2\u00b6\u00b7\5\20\t\t\u00b7\u00c3\3\2\2") - buf.write("\2\u00b8\u00b9\5\n\6\2\u00b9\u00ba\7\21\2\2\u00ba\u00bb") - buf.write("\5\20\t\2\u00bb\u00bc\7\22\2\2\u00bc\u00bd\5\20\t\7\u00bd") - buf.write("\u00c3\3\2\2\2\u00be\u00c3\5D#\2\u00bf\u00c3\5B\"\2\u00c0") - buf.write("\u00c3\5<\37\2\u00c1\u00c3\7#\2\2\u00c2|\3\2\2\2\u00c2") - buf.write("\u0081\3\2\2\2\u00c2\u0083\3\2\2\2\u00c2\u0084\3\2\2\2") - buf.write("\u00c2\u0086\3\2\2\2\u00c2\u008b\3\2\2\2\u00c2\u0095\3") - buf.write("\2\2\2\u00c2\u00a1\3\2\2\2\u00c2\u00a9\3\2\2\2\u00c2\u00b1") - buf.write("\3\2\2\2\u00c2\u00b8\3\2\2\2\u00c2\u00be\3\2\2\2\u00c2") - buf.write("\u00bf\3\2\2\2\u00c2\u00c0\3\2\2\2\u00c2\u00c1\3\2\2\2") - 
buf.write("\u00c3\u00dd\3\2\2\2\u00c4\u00c5\f\25\2\2\u00c5\u00c6") - buf.write("\t\3\2\2\u00c6\u00dc\5\20\t\26\u00c7\u00c8\f\24\2\2\u00c8") - buf.write("\u00c9\t\4\2\2\u00c9\u00dc\5\20\t\25\u00ca\u00cb\f\23") - buf.write("\2\2\u00cb\u00cc\t\5\2\2\u00cc\u00dc\5\20\t\24\u00cd\u00ce") - buf.write("\f\22\2\2\u00ce\u00cf\t\6\2\2\u00cf\u00dc\5\20\t\23\u00d0") - buf.write("\u00d1\f\b\2\2\u00d1\u00d2\7\23\2\2\u00d2\u00dc\5\20\t") - buf.write("\t\u00d3\u00d4\f\27\2\2\u00d4\u00d5\7\b\2\2\u00d5\u00d6") - buf.write("\5\16\b\2\u00d6\u00d7\7\t\2\2\u00d7\u00dc\3\2\2\2\u00d8") - buf.write("\u00d9\f\n\2\2\u00d9\u00da\7\3\2\2\u00da\u00dc\7\61\2") - buf.write("\2\u00db\u00c4\3\2\2\2\u00db\u00c7\3\2\2\2\u00db\u00ca") - buf.write("\3\2\2\2\u00db\u00cd\3\2\2\2\u00db\u00d0\3\2\2\2\u00db") - buf.write("\u00d3\3\2\2\2\u00db\u00d8\3\2\2\2\u00dc\u00df\3\2\2\2") - buf.write("\u00dd\u00db\3\2\2\2\u00dd\u00de\3\2\2\2\u00de\21\3\2") - buf.write("\2\2\u00df\u00dd\3\2\2\2\u00e0\u00e2\7\24\2\2\u00e1\u00e3") - buf.write("\58\35\2\u00e2\u00e1\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3") - buf.write("\u00e4\3\2\2\2\u00e4\u00e5\7\b\2\2\u00e5\u00e6\5,\27\2") - buf.write("\u00e6\u00e9\7\t\2\2\u00e7\u00e8\7\25\2\2\u00e8\u00ea") - buf.write("\5\66\34\2\u00e9\u00e7\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea") - buf.write("\u00eb\3\2\2\2\u00eb\u00ec\5@!\2\u00ec\23\3\2\2\2\u00ed") - buf.write("\u00ee\7\26\2\2\u00ee\u00f0\5\6\4\2\u00ef\u00f1\58\35") - buf.write("\2\u00f0\u00ef\3\2\2\2\u00f0\u00f1\3\2\2\2\u00f1\u00f2") - buf.write("\3\2\2\2\u00f2\u00f3\7\b\2\2\u00f3\u00f4\5,\27\2\u00f4") - buf.write("\u00f7\7\t\2\2\u00f5\u00f6\7\25\2\2\u00f6\u00f8\5\66\34") - buf.write("\2\u00f7\u00f5\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") - buf.write("\3\2\2\2\u00f9\u00fa\5@!\2\u00fa\u010d\3\2\2\2\u00fb\u00fc") - buf.write("\7\27\2\2\u00fc\u00fd\7\30\2\2\u00fd\u00ff\5\4\3\2\u00fe") - buf.write("\u0100\58\35\2\u00ff\u00fe\3\2\2\2\u00ff\u0100\3\2\2\2") - buf.write("\u0100\u010d\3\2\2\2\u0101\u0102\7\30\2\2\u0102\u0104") - 
buf.write("\5\4\3\2\u0103\u0105\58\35\2\u0104\u0103\3\2\2\2\u0104") - buf.write("\u0105\3\2\2\2\u0105\u0106\3\2\2\2\u0106\u0108\7\16\2") - buf.write("\2\u0107\u0109\5\30\r\2\u0108\u0107\3\2\2\2\u0108\u0109") - buf.write("\3\2\2\2\u0109\u010a\3\2\2\2\u010a\u010b\7\17\2\2\u010b") - buf.write("\u010d\3\2\2\2\u010c\u00ed\3\2\2\2\u010c\u00fb\3\2\2\2") - buf.write("\u010c\u0101\3\2\2\2\u010d\25\3\2\2\2\u010e\u010f\7/\2") - buf.write("\2\u010f\27\3\2\2\2\u0110\u0115\5\32\16\2\u0111\u0112") - buf.write("\7\7\2\2\u0112\u0114\5\32\16\2\u0113\u0111\3\2\2\2\u0114") - buf.write("\u0117\3\2\2\2\u0115\u0113\3\2\2\2\u0115\u0116\3\2\2\2") - buf.write("\u0116\u0119\3\2\2\2\u0117\u0115\3\2\2\2\u0118\u011a\7") - buf.write("\7\2\2\u0119\u0118\3\2\2\2\u0119\u011a\3\2\2\2\u011a\31") - buf.write("\3\2\2\2\u011b\u0127\5\26\f\2\u011c\u011d\7\b\2\2\u011d") - buf.write("\u0122\5\66\34\2\u011e\u011f\7\7\2\2\u011f\u0121\5\66") - buf.write("\34\2\u0120\u011e\3\2\2\2\u0121\u0124\3\2\2\2\u0122\u0120") - buf.write("\3\2\2\2\u0122\u0123\3\2\2\2\u0123\u0125\3\2\2\2\u0124") - buf.write("\u0122\3\2\2\2\u0125\u0126\7\t\2\2\u0126\u0128\3\2\2\2") - buf.write("\u0127\u011c\3\2\2\2\u0127\u0128\3\2\2\2\u0128\33\3\2") - buf.write("\2\2\u0129\u012e\5\36\20\2\u012a\u012b\7\7\2\2\u012b\u012d") - buf.write("\5\36\20\2\u012c\u012a\3\2\2\2\u012d\u0130\3\2\2\2\u012e") - buf.write("\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0132\3\2\2\2") - buf.write("\u0130\u012e\3\2\2\2\u0131\u0133\7\7\2\2\u0132\u0131\3") - buf.write("\2\2\2\u0132\u0133\3\2\2\2\u0133\35\3\2\2\2\u0134\u0135") - buf.write("\5$\23\2\u0135\u013b\7\31\2\2\u0136\u0137\7\16\2\2\u0137") - buf.write("\u0138\5\20\t\2\u0138\u0139\7\17\2\2\u0139\u013c\3\2\2") - buf.write("\2\u013a\u013c\5\20\t\2\u013b\u0136\3\2\2\2\u013b\u013a") - buf.write("\3\2\2\2\u013c\37\3\2\2\2\u013d\u013e\t\7\2\2\u013e!\3") - buf.write("\2\2\2\u013f\u0140\7\b\2\2\u0140\u0145\5$\23\2\u0141\u0142") - buf.write("\7\7\2\2\u0142\u0144\5$\23\2\u0143\u0141\3\2\2\2\u0144") - 
buf.write("\u0147\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2") - buf.write("\u0146\u0148\3\2\2\2\u0147\u0145\3\2\2\2\u0148\u0149\7") - buf.write("\t\2\2\u0149#\3\2\2\2\u014a\u0156\7\6\2\2\u014b\u014e") - buf.write("\5\b\5\2\u014c\u014d\7\34\2\2\u014d\u014f\5\66\34\2\u014e") - buf.write("\u014c\3\2\2\2\u014e\u014f\3\2\2\2\u014f\u0156\3\2\2\2") - buf.write("\u0150\u0152\5\26\f\2\u0151\u0153\5\"\22\2\u0152\u0151") - buf.write("\3\2\2\2\u0152\u0153\3\2\2\2\u0153\u0156\3\2\2\2\u0154") - buf.write("\u0156\5\"\22\2\u0155\u014a\3\2\2\2\u0155\u014b\3\2\2") - buf.write("\2\u0155\u0150\3\2\2\2\u0155\u0154\3\2\2\2\u0156%\3\2") - buf.write("\2\2\u0157\u0159\5\26\f\2\u0158\u015a\5(\25\2\u0159\u0158") - buf.write("\3\2\2\2\u0159\u015a\3\2\2\2\u015a\'\3\2\2\2\u015b\u015c") - buf.write("\7\b\2\2\u015c\u0161\5*\26\2\u015d\u015e\7\7\2\2\u015e") - buf.write("\u0160\5*\26\2\u015f\u015d\3\2\2\2\u0160\u0163\3\2\2\2") - buf.write("\u0161\u015f\3\2\2\2\u0161\u0162\3\2\2\2\u0162\u0164\3") - buf.write("\2\2\2\u0163\u0161\3\2\2\2\u0164\u0165\7\t\2\2\u0165)") - buf.write("\3\2\2\2\u0166\u0169\5\b\5\2\u0167\u0169\5\26\f\2\u0168") - buf.write("\u0166\3\2\2\2\u0168\u0167\3\2\2\2\u0169+\3\2\2\2\u016a") - buf.write("\u0175\5.\30\2\u016b\u016c\5\60\31\2\u016c\u016d\7\7\2") - buf.write("\2\u016d\u016f\3\2\2\2\u016e\u016b\3\2\2\2\u016f\u0172") - buf.write("\3\2\2\2\u0170\u016e\3\2\2\2\u0170\u0171\3\2\2\2\u0171") - buf.write("\u0173\3\2\2\2\u0172\u0170\3\2\2\2\u0173\u0175\5\62\32") - buf.write("\2\u0174\u016a\3\2\2\2\u0174\u0170\3\2\2\2\u0175-\3\2") - buf.write("\2\2\u0176\u017b\5\60\31\2\u0177\u0178\7\7\2\2\u0178\u017a") - buf.write("\5\60\31\2\u0179\u0177\3\2\2\2\u017a\u017d\3\2\2\2\u017b") - buf.write("\u0179\3\2\2\2\u017b\u017c\3\2\2\2\u017c\u017f\3\2\2\2") - buf.write("\u017d\u017b\3\2\2\2\u017e\u0176\3\2\2\2\u017e\u017f\3") - buf.write("\2\2\2\u017f/\3\2\2\2\u0180\u0183\5\b\5\2\u0181\u0182") - buf.write("\7\34\2\2\u0182\u0184\5\66\34\2\u0183\u0181\3\2\2\2\u0183") - 
buf.write("\u0184\3\2\2\2\u0184\61\3\2\2\2\u0185\u018a\5\64\33\2") - buf.write("\u0186\u0187\7\7\2\2\u0187\u0189\5\64\33\2\u0188\u0186") - buf.write("\3\2\2\2\u0189\u018c\3\2\2\2\u018a\u0188\3\2\2\2\u018a") - buf.write("\u018b\3\2\2\2\u018b\63\3\2\2\2\u018c\u018a\3\2\2\2\u018d") - buf.write("\u018e\7/\2\2\u018e\u018f\7\21\2\2\u018f\u0190\5\20\t") - buf.write("\2\u0190\65\3\2\2\2\u0191\u0192\7\b\2\2\u0192\u01c5\7") - buf.write("\t\2\2\u0193\u0194\7\b\2\2\u0194\u0195\5\66\34\2\u0195") - buf.write("\u0196\7\t\2\2\u0196\u01c5\3\2\2\2\u0197\u0198\7\b\2\2") - buf.write("\u0198\u0199\5\66\34\2\u0199\u019a\7\7\2\2\u019a\u019b") - buf.write("\7\t\2\2\u019b\u01c5\3\2\2\2\u019c\u019d\7\b\2\2\u019d") - buf.write("\u01a0\5\66\34\2\u019e\u019f\7\7\2\2\u019f\u01a1\5\66") - buf.write("\34\2\u01a0\u019e\3\2\2\2\u01a1\u01a2\3\2\2\2\u01a2\u01a0") - buf.write("\3\2\2\2\u01a2\u01a3\3\2\2\2\u01a3\u01a4\3\2\2\2\u01a4") - buf.write("\u01a5\7\t\2\2\u01a5\u01c5\3\2\2\2\u01a6\u01a7\5\4\3\2") - buf.write("\u01a7\u01a8\58\35\2\u01a8\u01c5\3\2\2\2\u01a9\u01c5\5") - buf.write("\4\3\2\u01aa\u01ab\7\35\2\2\u01ab\u01ac\7\n\2\2\u01ac") - buf.write("\u01ad\5:\36\2\u01ad\u01ae\7\7\2\2\u01ae\u01af\5\66\34") - buf.write("\2\u01af\u01b0\7\13\2\2\u01b0\u01c5\3\2\2\2\u01b1\u01b3") - buf.write("\7\24\2\2\u01b2\u01b4\58\35\2\u01b3\u01b2\3\2\2\2\u01b3") - buf.write("\u01b4\3\2\2\2\u01b4\u01b5\3\2\2\2\u01b5\u01be\7\b\2\2") - buf.write("\u01b6\u01bb\5\66\34\2\u01b7\u01b8\7\7\2\2\u01b8\u01ba") - buf.write("\5\66\34\2\u01b9\u01b7\3\2\2\2\u01ba\u01bd\3\2\2\2\u01bb") - buf.write("\u01b9\3\2\2\2\u01bb\u01bc\3\2\2\2\u01bc\u01bf\3\2\2\2") - buf.write("\u01bd\u01bb\3\2\2\2\u01be\u01b6\3\2\2\2\u01be\u01bf\3") - buf.write("\2\2\2\u01bf\u01c0\3\2\2\2\u01c0\u01c1\7\t\2\2\u01c1\u01c2") - buf.write("\7\25\2\2\u01c2\u01c5\5\66\34\2\u01c3\u01c5\7\6\2\2\u01c4") - buf.write("\u0191\3\2\2\2\u01c4\u0193\3\2\2\2\u01c4\u0197\3\2\2\2") - buf.write("\u01c4\u019c\3\2\2\2\u01c4\u01a6\3\2\2\2\u01c4\u01a9\3") - 
buf.write("\2\2\2\u01c4\u01aa\3\2\2\2\u01c4\u01b1\3\2\2\2\u01c4\u01c3") - buf.write("\3\2\2\2\u01c5\67\3\2\2\2\u01c6\u01c7\7\n\2\2\u01c7\u01cc") - buf.write("\5\66\34\2\u01c8\u01c9\7\7\2\2\u01c9\u01cb\5\66\34\2\u01ca") - buf.write("\u01c8\3\2\2\2\u01cb\u01ce\3\2\2\2\u01cc\u01ca\3\2\2\2") - buf.write("\u01cc\u01cd\3\2\2\2\u01cd\u01cf\3\2\2\2\u01ce\u01cc\3") - buf.write("\2\2\2\u01cf\u01d0\7\13\2\2\u01d09\3\2\2\2\u01d1\u01d2") - buf.write("\7\b\2\2\u01d2\u01df\7\t\2\2\u01d3\u01d4\7\b\2\2\u01d4") - buf.write("\u01d7\5> \2\u01d5\u01d6\7\7\2\2\u01d6\u01d8\5> \2\u01d7") - buf.write("\u01d5\3\2\2\2\u01d8\u01d9\3\2\2\2\u01d9\u01d7\3\2\2\2") - buf.write("\u01d9\u01da\3\2\2\2\u01da\u01db\3\2\2\2\u01db\u01dc\7") - buf.write("\t\2\2\u01dc\u01df\3\2\2\2\u01dd\u01df\5> \2\u01de\u01d1") - buf.write("\3\2\2\2\u01de\u01d3\3\2\2\2\u01de\u01dd\3\2\2\2\u01df") - buf.write(";\3\2\2\2\u01e0\u01e1\7\36\2\2\u01e1\u01e2\7\n\2\2\u01e2") - buf.write("\u01e3\7/\2\2\u01e3\u01e4\7\13\2\2\u01e4\u01e5\7\n\2\2") - buf.write("\u01e5\u01e6\7\61\2\2\u01e6\u01e7\7\13\2\2\u01e7=\3\2") - buf.write("\2\2\u01e8\u01ef\5<\37\2\u01e9\u01ea\7\b\2\2\u01ea\u01eb") - buf.write("\5> \2\u01eb\u01ec\7\t\2\2\u01ec\u01ef\3\2\2\2\u01ed\u01ef") - buf.write("\7\61\2\2\u01ee\u01e8\3\2\2\2\u01ee\u01e9\3\2\2\2\u01ee") - buf.write("\u01ed\3\2\2\2\u01ef?\3\2\2\2\u01f0\u01f1\7\16\2\2\u01f1") - buf.write("\u01f2\5\20\t\2\u01f2\u01f3\7\17\2\2\u01f3A\3\2\2\2\u01f4") - buf.write("\u01f8\7\60\2\2\u01f5\u01f8\7\61\2\2\u01f6\u01f8\7.\2") - buf.write("\2\u01f7\u01f4\3\2\2\2\u01f7\u01f5\3\2\2\2\u01f7\u01f6") - buf.write("\3\2\2\2\u01f8C\3\2\2\2\u01f9\u01fe\5\4\3\2\u01fa\u01fe") - buf.write("\5\6\4\2\u01fb\u01fe\5\b\5\2\u01fc\u01fe\5\n\6\2\u01fd") - buf.write("\u01f9\3\2\2\2\u01fd\u01fa\3\2\2\2\u01fd\u01fb\3\2\2\2") - buf.write("\u01fd\u01fc\3\2\2\2\u01feE\3\2\2\28JNQZknvz\u0091\u009b") - buf.write("\u009e\u00ad\u00c2\u00db\u00dd\u00e2\u00e9\u00f0\u00f7") - buf.write("\u00ff\u0104\u0108\u010c\u0115\u0119\u0122\u0127\u012e") - 
buf.write("\u0132\u013b\u0145\u014e\u0152\u0155\u0159\u0161\u0168") - buf.write("\u0170\u0174\u017b\u017e\u0183\u018a\u01a2\u01b3\u01bb") - buf.write("\u01be\u01c4\u01cc\u01d9\u01de\u01ee\u01f7\u01fd") - return buf.getvalue() - - -class RelayParser ( Parser ): - - grammarFileName = "Relay.g4" - - atn = ATNDeserializer().deserialize(serializedATN()) - - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] - - sharedContextCache = PredictionContextCache() - - literalNames = [ "", "'.'", "'@'", "'%'", "'_'", "','", "'('", - "')'", "'['", "']'", "'if'", "'else'", "'{'", "'}'", - "'let'", "'='", "';'", "';;'", "'fn'", "'->'", "'def'", - "'extern'", "'type'", "'=>'", "'match'", "'match?'", - "':'", "'Tensor'", "'meta'", "'v0.0.4'", "", - "", "", "", "'*'", "'/'", - "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", - "'!='" ] - - symbolicNames = [ "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "", "", "", - "", "SEMVER", "COMMENT", "WS", "LINE_COMMENT", - "QUOTED_STRING", "MUL", "DIV", "ADD", "SUB", "LT", - "GT", "LE", "GE", "EQ", "NE", "BOOL_LIT", "CNAME", - "FLOAT", "NAT", "METADATA" ] - - RULE_prog = 0 - RULE_generalIdent = 1 - RULE_globalVar = 2 - RULE_localVar = 3 - RULE_graphVar = 4 - RULE_exprList = 5 - RULE_callList = 6 - RULE_expr = 7 - RULE_func = 8 - RULE_defn = 9 - RULE_constructorName = 10 - RULE_adtConsDefnList = 11 - RULE_adtConsDefn = 12 - RULE_matchClauseList = 13 - RULE_matchClause = 14 - RULE_matchType = 15 - RULE_patternList = 16 - RULE_pattern = 17 - RULE_adtCons = 18 - RULE_adtConsParamList = 19 - RULE_adtConsParam = 20 - RULE_argList = 21 - RULE_varList = 22 - RULE_var = 23 - RULE_attrSeq = 24 - RULE_attr = 25 - RULE_typeExpr = 26 - RULE_typeParamList = 27 - RULE_shapeList = 28 - RULE_meta = 29 - RULE_shape = 30 - RULE_body = 31 - RULE_scalar = 32 - RULE_ident = 33 - - ruleNames = [ "prog", "generalIdent", "globalVar", "localVar", "graphVar", - "exprList", 
"callList", "expr", "func", "defn", "constructorName", - "adtConsDefnList", "adtConsDefn", "matchClauseList", - "matchClause", "matchType", "patternList", "pattern", - "adtCons", "adtConsParamList", "adtConsParam", "argList", - "varList", "var", "attrSeq", "attr", "typeExpr", "typeParamList", - "shapeList", "meta", "shape", "body", "scalar", "ident" ] - - EOF = Token.EOF - T__0=1 - T__1=2 - T__2=3 - T__3=4 - T__4=5 - T__5=6 - T__6=7 - T__7=8 - T__8=9 - T__9=10 - T__10=11 - T__11=12 - T__12=13 - T__13=14 - T__14=15 - T__15=16 - T__16=17 - T__17=18 - T__18=19 - T__19=20 - T__20=21 - T__21=22 - T__22=23 - T__23=24 - T__24=25 - T__25=26 - T__26=27 - T__27=28 - SEMVER=29 - COMMENT=30 - WS=31 - LINE_COMMENT=32 - QUOTED_STRING=33 - MUL=34 - DIV=35 - ADD=36 - SUB=37 - LT=38 - GT=39 - LE=40 - GE=41 - EQ=42 - NE=43 - BOOL_LIT=44 - CNAME=45 - FLOAT=46 - NAT=47 - METADATA=48 - - def __init__(self, input:TokenStream, output:TextIO = sys.stdout): - super().__init__(input, output) - self.checkVersion("4.7.2") - self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) - self._predicates = None - - - - - class ProgContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def SEMVER(self): - return self.getToken(RelayParser.SEMVER, 0) - - def EOF(self): - return self.getToken(RelayParser.EOF, 0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def METADATA(self): - return self.getToken(RelayParser.METADATA, 0) - - def defn(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.DefnContext) - else: - return self.getTypedRuleContext(RelayParser.DefnContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_prog - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitProg" ): - return visitor.visitProg(self) - else: - return 
visitor.visitChildren(self) - - - - - def prog(self): - - localctx = RelayParser.ProgContext(self, self._ctx, self.state) - self.enterRule(localctx, 0, self.RULE_prog) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 68 - self.match(RelayParser.SEMVER) - self.state = 76 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.EOF, RelayParser.T__19, RelayParser.T__20, RelayParser.T__21, RelayParser.METADATA]: - self.state = 72 - self._errHandler.sync(self) - _la = self._input.LA(1) - while (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__19) | (1 << RelayParser.T__20) | (1 << RelayParser.T__21))) != 0): - self.state = 69 - self.defn() - self.state = 74 - self._errHandler.sync(self) - _la = self._input.LA(1) - - pass - elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]: - self.state = 75 - self.expr(0) - pass - else: - raise NoViableAltException(self) - - self.state = 79 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.METADATA: - self.state = 78 - self.match(RelayParser.METADATA) - - - self.state = 81 - self.match(RelayParser.EOF) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GeneralIdentContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self, i:int=None): - if i is None: - return self.getTokens(RelayParser.CNAME) - else: - return self.getToken(RelayParser.CNAME, i) - - def getRuleIndex(self): - return 
RelayParser.RULE_generalIdent - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGeneralIdent" ): - return visitor.visitGeneralIdent(self) - else: - return visitor.visitChildren(self) - - - - - def generalIdent(self): - - localctx = RelayParser.GeneralIdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 2, self.RULE_generalIdent) - try: - self.enterOuterAlt(localctx, 1) - self.state = 83 - self.match(RelayParser.CNAME) - self.state = 88 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 84 - self.match(RelayParser.T__0) - self.state = 85 - self.match(RelayParser.CNAME) - self.state = 90 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GlobalVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_globalVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGlobalVar" ): - return visitor.visitGlobalVar(self) - else: - return visitor.visitChildren(self) - - - - - def globalVar(self): - - localctx = RelayParser.GlobalVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 4, self.RULE_globalVar) - try: - self.enterOuterAlt(localctx, 1) - self.state = 91 - self.match(RelayParser.T__1) - self.state = 92 - self.match(RelayParser.CNAME) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - 
self.exitRule() - return localctx - - - class LocalVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_localVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitLocalVar" ): - return visitor.visitLocalVar(self) - else: - return visitor.visitChildren(self) - - - - - def localVar(self): - - localctx = RelayParser.LocalVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_localVar) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 94 - self.match(RelayParser.T__2) - self.state = 95 - _la = self._input.LA(1) - if not(_la==RelayParser.T__3 or _la==RelayParser.CNAME): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class GraphVarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def getRuleIndex(self): - return RelayParser.RULE_graphVar - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGraphVar" ): - return visitor.visitGraphVar(self) - else: - return visitor.visitChildren(self) - - - - - def graphVar(self): - - localctx = RelayParser.GraphVarContext(self, self._ctx, self.state) - self.enterRule(localctx, 8, self.RULE_graphVar) - try: - self.enterOuterAlt(localctx, 1) - self.state = 97 - self.match(RelayParser.T__2) - self.state = 98 - self.match(RelayParser.NAT) - except 
RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ExprListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_exprList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitExprList" ): - return visitor.visitExprList(self) - else: - return visitor.visitChildren(self) - - - - - def exprList(self): - - localctx = RelayParser.ExprListContext(self, self._ctx, self.state) - self.enterRule(localctx, 10, self.RULE_exprList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 108 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): - self.state = 100 - self.expr(0) - self.state = 105 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 101 - self.match(RelayParser.T__4) - self.state = 102 - self.expr(0) - self.state = 107 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - 
self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class CallListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_callList - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class CallWithAttrContext(CallListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext - super().__init__(parser) - self.copyFrom(ctx) - - def attrSeq(self): - return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCallWithAttr" ): - return visitor.visitCallWithAttr(self) - else: - return visitor.visitChildren(self) - - - class CallNoAttrContext(CallListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.CallListContext - super().__init__(parser) - self.copyFrom(ctx) - - def exprList(self): - return self.getTypedRuleContext(RelayParser.ExprListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCallNoAttr" ): - return visitor.visitCallNoAttr(self) - else: - return visitor.visitChildren(self) - - - - def callList(self): - - localctx = RelayParser.CallListContext(self, self._ctx, self.state) - self.enterRule(localctx, 12, self.RULE_callList) - try: - self.state = 120 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,7,self._ctx) - if la_ == 1: - localctx = RelayParser.CallNoAttrContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 110 - self.exprList() - pass - - elif la_ == 2: - localctx = 
RelayParser.CallWithAttrContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 116 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,6,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 111 - self.expr(0) - self.state = 112 - self.match(RelayParser.T__4) - self.state = 118 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,6,self._ctx) - - self.state = 119 - self.attrSeq() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ExprContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_expr - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - class FuncExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def func(self): - return self.getTypedRuleContext(RelayParser.FuncContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncExpr" ): - return visitor.visitFuncExpr(self) - else: - return visitor.visitChildren(self) - - - class MetaExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def meta(self): - return self.getTypedRuleContext(RelayParser.MetaContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMetaExpr" ): - return visitor.visitMetaExpr(self) - else: - return visitor.visitChildren(self) - - - class MatchContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a 
RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def matchType(self): - return self.getTypedRuleContext(RelayParser.MatchTypeContext,0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def matchClauseList(self): - return self.getTypedRuleContext(RelayParser.MatchClauseListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatch" ): - return visitor.visitMatch(self) - else: - return visitor.visitChildren(self) - - - class TensorContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTensor" ): - return visitor.visitTensor(self) - else: - return visitor.visitChildren(self) - - - class GraphContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def graphVar(self): - return self.getTypedRuleContext(RelayParser.GraphVarContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitGraph" ): - return visitor.visitGraph(self) - else: - return visitor.visitChildren(self) - - - class IdentExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def ident(self): - return self.getTypedRuleContext(RelayParser.IdentContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, 
"visitIdentExpr" ): - return visitor.visitIdentExpr(self) - else: - return visitor.visitChildren(self) - - - class StringExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def QUOTED_STRING(self): - return self.getToken(RelayParser.QUOTED_STRING, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitStringExpr" ): - return visitor.visitStringExpr(self) - else: - return visitor.visitChildren(self) - - - class CallContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def callList(self): - return self.getTypedRuleContext(RelayParser.CallListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCall" ): - return visitor.visitCall(self) - else: - return visitor.visitChildren(self) - - - class NegContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def SUB(self): - return self.getToken(RelayParser.SUB, 0) - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitNeg" ): - return visitor.visitNeg(self) - else: - return visitor.visitChildren(self) - - - class TupleContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTuple" ): - return 
visitor.visitTuple(self) - else: - return visitor.visitChildren(self) - - - class ParenContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitParen" ): - return visitor.visitParen(self) - else: - return visitor.visitChildren(self) - - - class ScalarExprContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def scalar(self): - return self.getTypedRuleContext(RelayParser.ScalarContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarExpr" ): - return visitor.visitScalarExpr(self) - else: - return visitor.visitChildren(self) - - - class LetContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def var(self): - return self.getTypedRuleContext(RelayParser.VarContext,0) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitLet" ): - return visitor.visitLet(self) - else: - return visitor.visitChildren(self) - - - class ProjectionContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitProjection" ): - return visitor.visitProjection(self) 
- else: - return visitor.visitChildren(self) - - - class IfElseContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - def body(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.BodyContext) - else: - return self.getTypedRuleContext(RelayParser.BodyContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIfElse" ): - return visitor.visitIfElse(self) - else: - return visitor.visitChildren(self) - - - class BinOpContext(ExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext - super().__init__(parser) - self.op = None # Token - self.copyFrom(ctx) - - def expr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ExprContext) - else: - return self.getTypedRuleContext(RelayParser.ExprContext,i) - - def MUL(self): - return self.getToken(RelayParser.MUL, 0) - def DIV(self): - return self.getToken(RelayParser.DIV, 0) - def ADD(self): - return self.getToken(RelayParser.ADD, 0) - def SUB(self): - return self.getToken(RelayParser.SUB, 0) - def LT(self): - return self.getToken(RelayParser.LT, 0) - def GT(self): - return self.getToken(RelayParser.GT, 0) - def LE(self): - return self.getToken(RelayParser.LE, 0) - def GE(self): - return self.getToken(RelayParser.GE, 0) - def EQ(self): - return self.getToken(RelayParser.EQ, 0) - def NE(self): - return self.getToken(RelayParser.NE, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitBinOp" ): - return visitor.visitBinOp(self) - else: - return visitor.visitChildren(self) - - - - def expr(self, _p:int=0): - _parentctx = self._ctx - _parentState = self.state - localctx = RelayParser.ExprContext(self, self._ctx, _parentState) - _prevctx = localctx - _startState = 14 - 
self.enterRecursionRule(localctx, 14, self.RULE_expr, _p) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 192 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,12,self._ctx) - if la_ == 1: - localctx = RelayParser.ParenContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - - self.state = 123 - self.match(RelayParser.T__5) - self.state = 124 - self.expr(0) - self.state = 125 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - localctx = RelayParser.NegContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 127 - self.match(RelayParser.SUB) - self.state = 128 - self.expr(20) - pass - - elif la_ == 3: - localctx = RelayParser.FuncExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 129 - self.func() - pass - - elif la_ == 4: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 130 - self.match(RelayParser.T__5) - self.state = 131 - self.match(RelayParser.T__6) - pass - - elif la_ == 5: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 132 - self.match(RelayParser.T__5) - self.state = 133 - self.expr(0) - self.state = 134 - self.match(RelayParser.T__4) - self.state = 135 - self.match(RelayParser.T__6) - pass - - elif la_ == 6: - localctx = RelayParser.TupleContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 137 - self.match(RelayParser.T__5) - self.state = 138 - self.expr(0) - self.state = 141 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 139 - self.match(RelayParser.T__4) - self.state = 140 - self.expr(0) - self.state = 143 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 145 - self.match(RelayParser.T__6) - pass - - elif la_ == 7: - localctx = 
RelayParser.TensorContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 147 - self.match(RelayParser.T__7) - self.state = 156 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__1) | (1 << RelayParser.T__2) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__9) | (1 << RelayParser.T__13) | (1 << RelayParser.T__17) | (1 << RelayParser.T__23) | (1 << RelayParser.T__24) | (1 << RelayParser.T__27) | (1 << RelayParser.QUOTED_STRING) | (1 << RelayParser.SUB) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.CNAME) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT))) != 0): - self.state = 148 - self.expr(0) - self.state = 153 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 149 - self.match(RelayParser.T__4) - self.state = 150 - self.expr(0) - self.state = 155 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - self.state = 158 - self.match(RelayParser.T__8) - pass - - elif la_ == 8: - localctx = RelayParser.IfElseContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 159 - self.match(RelayParser.T__9) - self.state = 160 - self.match(RelayParser.T__5) - self.state = 161 - self.expr(0) - self.state = 162 - self.match(RelayParser.T__6) - self.state = 163 - self.body() - self.state = 164 - self.match(RelayParser.T__10) - self.state = 165 - self.body() - pass - - elif la_ == 9: - localctx = RelayParser.MatchContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 167 - self.matchType() - self.state = 168 - self.expr(0) - self.state = 169 - self.match(RelayParser.T__11) - self.state = 171 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__2) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.CNAME))) != 0): - self.state = 170 - 
self.matchClauseList() - - - self.state = 173 - self.match(RelayParser.T__12) - pass - - elif la_ == 10: - localctx = RelayParser.LetContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 175 - self.match(RelayParser.T__13) - self.state = 176 - self.var() - self.state = 177 - self.match(RelayParser.T__14) - self.state = 178 - self.expr(0) - self.state = 179 - self.match(RelayParser.T__15) - self.state = 180 - self.expr(7) - pass - - elif la_ == 11: - localctx = RelayParser.GraphContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 182 - self.graphVar() - self.state = 183 - self.match(RelayParser.T__14) - self.state = 184 - self.expr(0) - self.state = 185 - self.match(RelayParser.T__15) - self.state = 186 - self.expr(5) - pass - - elif la_ == 12: - localctx = RelayParser.IdentExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 188 - self.ident() - pass - - elif la_ == 13: - localctx = RelayParser.ScalarExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 189 - self.scalar() - pass - - elif la_ == 14: - localctx = RelayParser.MetaExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 190 - self.meta() - pass - - elif la_ == 15: - localctx = RelayParser.StringExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 191 - self.match(RelayParser.QUOTED_STRING) - pass - - - self._ctx.stop = self._input.LT(-1) - self.state = 219 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,14,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - if self._parseListeners is not None: - self.triggerExitRuleEvent() - _prevctx = localctx - self.state = 217 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,13,self._ctx) - if la_ == 1: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, 
_parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 194 - if not self.precpred(self._ctx, 19): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 19)") - self.state = 195 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.MUL or _la==RelayParser.DIV): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 196 - self.expr(20) - pass - - elif la_ == 2: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 197 - if not self.precpred(self._ctx, 18): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") - self.state = 198 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.ADD or _la==RelayParser.SUB): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 199 - self.expr(19) - pass - - elif la_ == 3: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 200 - if not self.precpred(self._ctx, 17): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 17)") - self.state = 201 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 202 - self.expr(18) - 
pass - - elif la_ == 4: - localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 203 - if not self.precpred(self._ctx, 16): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") - self.state = 204 - localctx.op = self._input.LT(1) - _la = self._input.LA(1) - if not(_la==RelayParser.EQ or _la==RelayParser.NE): - localctx.op = self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 205 - self.expr(17) - pass - - elif la_ == 5: - localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 206 - if not self.precpred(self._ctx, 6): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") - self.state = 207 - self.match(RelayParser.T__16) - self.state = 208 - self.expr(7) - pass - - elif la_ == 6: - localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 209 - if not self.precpred(self._ctx, 21): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 21)") - self.state = 210 - self.match(RelayParser.T__5) - self.state = 211 - self.callList() - self.state = 212 - self.match(RelayParser.T__6) - pass - - elif la_ == 7: - localctx = RelayParser.ProjectionContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 214 - if not self.precpred(self._ctx, 8): - from antlr4.error.Errors import FailedPredicateException - raise 
FailedPredicateException(self, "self.precpred(self._ctx, 8)") - self.state = 215 - self.match(RelayParser.T__0) - self.state = 216 - self.match(RelayParser.NAT) - pass - - - self.state = 221 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,14,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.unrollRecursionContexts(_parentctx) - return localctx - - - class FuncContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def argList(self): - return self.getTypedRuleContext(RelayParser.ArgListContext,0) - - - def body(self): - return self.getTypedRuleContext(RelayParser.BodyContext,0) - - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_func - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFunc" ): - return visitor.visitFunc(self) - else: - return visitor.visitChildren(self) - - - - - def func(self): - - localctx = RelayParser.FuncContext(self, self._ctx, self.state) - self.enterRule(localctx, 16, self.RULE_func) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 222 - self.match(RelayParser.T__17) - self.state = 224 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 223 - self.typeParamList() - - - self.state = 226 - self.match(RelayParser.T__5) - self.state = 227 - self.argList() - self.state = 228 - self.match(RelayParser.T__6) - self.state = 231 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__18: - self.state = 229 - self.match(RelayParser.T__18) - self.state = 
230 - self.typeExpr() - - - self.state = 233 - self.body() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class DefnContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_defn - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ExternAdtDefnContext(DefnContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitExternAdtDefn" ): - return visitor.visitExternAdtDefn(self) - else: - return visitor.visitChildren(self) - - - class FuncDefnContext(DefnContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def globalVar(self): - return self.getTypedRuleContext(RelayParser.GlobalVarContext,0) - - def argList(self): - return self.getTypedRuleContext(RelayParser.ArgListContext,0) - - def body(self): - return self.getTypedRuleContext(RelayParser.BodyContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncDefn" ): - return visitor.visitFuncDefn(self) - else: - return visitor.visitChildren(self) - - - class AdtDefnContext(DefnContext): - - def 
__init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.DefnContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - def adtConsDefnList(self): - return self.getTypedRuleContext(RelayParser.AdtConsDefnListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtDefn" ): - return visitor.visitAdtDefn(self) - else: - return visitor.visitChildren(self) - - - - def defn(self): - - localctx = RelayParser.DefnContext(self, self._ctx, self.state) - self.enterRule(localctx, 18, self.RULE_defn) - self._la = 0 # Token type - try: - self.state = 266 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__19]: - localctx = RelayParser.FuncDefnContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 235 - self.match(RelayParser.T__19) - self.state = 236 - self.globalVar() - self.state = 238 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 237 - self.typeParamList() - - - self.state = 240 - self.match(RelayParser.T__5) - self.state = 241 - self.argList() - self.state = 242 - self.match(RelayParser.T__6) - self.state = 245 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__18: - self.state = 243 - self.match(RelayParser.T__18) - self.state = 244 - self.typeExpr() - - - self.state = 247 - self.body() - pass - elif token in [RelayParser.T__20]: - localctx = RelayParser.ExternAdtDefnContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 249 - self.match(RelayParser.T__20) - self.state = 250 - self.match(RelayParser.T__21) - self.state = 251 - self.generalIdent() - self.state = 253 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 252 - 
self.typeParamList() - - - pass - elif token in [RelayParser.T__21]: - localctx = RelayParser.AdtDefnContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 255 - self.match(RelayParser.T__21) - self.state = 256 - self.generalIdent() - self.state = 258 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 257 - self.typeParamList() - - - self.state = 260 - self.match(RelayParser.T__11) - self.state = 262 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.CNAME: - self.state = 261 - self.adtConsDefnList() - - - self.state = 264 - self.match(RelayParser.T__12) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ConstructorNameContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def getRuleIndex(self): - return RelayParser.RULE_constructorName - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitConstructorName" ): - return visitor.visitConstructorName(self) - else: - return visitor.visitChildren(self) - - - - - def constructorName(self): - - localctx = RelayParser.ConstructorNameContext(self, self._ctx, self.state) - self.enterRule(localctx, 20, self.RULE_constructorName) - try: - self.enterOuterAlt(localctx, 1) - self.state = 268 - self.match(RelayParser.CNAME) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsDefnListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, 
invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def adtConsDefn(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AdtConsDefnContext) - else: - return self.getTypedRuleContext(RelayParser.AdtConsDefnContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsDefnList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsDefnList" ): - return visitor.visitAdtConsDefnList(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsDefnList(self): - - localctx = RelayParser.AdtConsDefnListContext(self, self._ctx, self.state) - self.enterRule(localctx, 22, self.RULE_adtConsDefnList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 270 - self.adtConsDefn() - self.state = 275 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,23,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 271 - self.match(RelayParser.T__4) - self.state = 272 - self.adtConsDefn() - self.state = 277 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,23,self._ctx) - - self.state = 279 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__4: - self.state = 278 - self.match(RelayParser.T__4) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsDefnContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return 
self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsDefn - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsDefn" ): - return visitor.visitAdtConsDefn(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsDefn(self): - - localctx = RelayParser.AdtConsDefnContext(self, self._ctx, self.state) - self.enterRule(localctx, 24, self.RULE_adtConsDefn) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 281 - self.constructorName() - self.state = 293 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 282 - self.match(RelayParser.T__5) - self.state = 283 - self.typeExpr() - self.state = 288 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 284 - self.match(RelayParser.T__4) - self.state = 285 - self.typeExpr() - self.state = 290 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 291 - self.match(RelayParser.T__6) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchClauseListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def matchClause(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.MatchClauseContext) - else: - return self.getTypedRuleContext(RelayParser.MatchClauseContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_matchClauseList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchClauseList" ): - return visitor.visitMatchClauseList(self) - else: - return visitor.visitChildren(self) - - - - - def matchClauseList(self): - - localctx = 
RelayParser.MatchClauseListContext(self, self._ctx, self.state) - self.enterRule(localctx, 26, self.RULE_matchClauseList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 295 - self.matchClause() - self.state = 300 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,27,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - self.state = 296 - self.match(RelayParser.T__4) - self.state = 297 - self.matchClause() - self.state = 302 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,27,self._ctx) - - self.state = 304 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__4: - self.state = 303 - self.match(RelayParser.T__4) - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchClauseContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def pattern(self): - return self.getTypedRuleContext(RelayParser.PatternContext,0) - - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_matchClause - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchClause" ): - return visitor.visitMatchClause(self) - else: - return visitor.visitChildren(self) - - - - - def matchClause(self): - - localctx = RelayParser.MatchClauseContext(self, self._ctx, self.state) - self.enterRule(localctx, 28, self.RULE_matchClause) - try: - self.enterOuterAlt(localctx, 1) - self.state = 306 - self.pattern() - self.state = 307 - self.match(RelayParser.T__22) - self.state = 313 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__11]: - self.state = 308 - 
self.match(RelayParser.T__11) - self.state = 309 - self.expr(0) - self.state = 310 - self.match(RelayParser.T__12) - pass - elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]: - self.state = 312 - self.expr(0) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MatchTypeContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_matchType - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMatchType" ): - return visitor.visitMatchType(self) - else: - return visitor.visitChildren(self) - - - - - def matchType(self): - - localctx = RelayParser.MatchTypeContext(self, self._ctx, self.state) - self.enterRule(localctx, 30, self.RULE_matchType) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 315 - _la = self._input.LA(1) - if not(_la==RelayParser.T__23 or _la==RelayParser.T__24): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class PatternListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def 
pattern(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.PatternContext) - else: - return self.getTypedRuleContext(RelayParser.PatternContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_patternList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitPatternList" ): - return visitor.visitPatternList(self) - else: - return visitor.visitChildren(self) - - - - - def patternList(self): - - localctx = RelayParser.PatternListContext(self, self._ctx, self.state) - self.enterRule(localctx, 32, self.RULE_patternList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 317 - self.match(RelayParser.T__5) - self.state = 318 - self.pattern() - self.state = 323 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 319 - self.match(RelayParser.T__4) - self.state = 320 - self.pattern() - self.state = 325 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 326 - self.match(RelayParser.T__6) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class PatternContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_pattern - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class WildcardPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitWildcardPattern" ): - return visitor.visitWildcardPattern(self) - else: - return visitor.visitChildren(self) - - - class 
ConstructorPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - def patternList(self): - return self.getTypedRuleContext(RelayParser.PatternListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitConstructorPattern" ): - return visitor.visitConstructorPattern(self) - else: - return visitor.visitChildren(self) - - - class TuplePatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def patternList(self): - return self.getTypedRuleContext(RelayParser.PatternListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTuplePattern" ): - return visitor.visitTuplePattern(self) - else: - return visitor.visitChildren(self) - - - class VarPatternContext(PatternContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext - super().__init__(parser) - self.copyFrom(ctx) - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVarPattern" ): - return visitor.visitVarPattern(self) - else: - return visitor.visitChildren(self) - - - - def pattern(self): - - localctx = RelayParser.PatternContext(self, self._ctx, self.state) - self.enterRule(localctx, 34, self.RULE_pattern) - self._la = 0 # Token type - try: - self.state = 339 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__3]: - localctx = RelayParser.WildcardPatternContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 
328 - self.match(RelayParser.T__3) - pass - elif token in [RelayParser.T__2]: - localctx = RelayParser.VarPatternContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 329 - self.localVar() - self.state = 332 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__25: - self.state = 330 - self.match(RelayParser.T__25) - self.state = 331 - self.typeExpr() - - - pass - elif token in [RelayParser.CNAME]: - localctx = RelayParser.ConstructorPatternContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 334 - self.constructorName() - self.state = 336 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 335 - self.patternList() - - - pass - elif token in [RelayParser.T__5]: - localctx = RelayParser.TuplePatternContext(self, localctx) - self.enterOuterAlt(localctx, 4) - self.state = 338 - self.patternList() - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def adtConsParamList(self): - return self.getTypedRuleContext(RelayParser.AdtConsParamListContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_adtCons - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtCons" ): - return visitor.visitAdtCons(self) - else: - return visitor.visitChildren(self) - - - - - def adtCons(self): - - localctx = RelayParser.AdtConsContext(self, self._ctx, self.state) - self.enterRule(localctx, 36, self.RULE_adtCons) - self._la = 0 # Token type - try: - 
self.enterOuterAlt(localctx, 1) - self.state = 341 - self.constructorName() - self.state = 343 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 342 - self.adtConsParamList() - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsParamListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def adtConsParam(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AdtConsParamContext) - else: - return self.getTypedRuleContext(RelayParser.AdtConsParamContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsParamList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsParamList" ): - return visitor.visitAdtConsParamList(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsParamList(self): - - localctx = RelayParser.AdtConsParamListContext(self, self._ctx, self.state) - self.enterRule(localctx, 38, self.RULE_adtConsParamList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 345 - self.match(RelayParser.T__5) - self.state = 346 - self.adtConsParam() - self.state = 351 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 347 - self.match(RelayParser.T__4) - self.state = 348 - self.adtConsParam() - self.state = 353 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 354 - self.match(RelayParser.T__6) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AdtConsParamContext(ParserRuleContext): - - def 
__init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_adtConsParam - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAdtConsParam" ): - return visitor.visitAdtConsParam(self) - else: - return visitor.visitChildren(self) - - - - - def adtConsParam(self): - - localctx = RelayParser.AdtConsParamContext(self, self._ctx, self.state) - self.enterRule(localctx, 40, self.RULE_adtConsParam) - try: - self.state = 358 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__2]: - self.enterOuterAlt(localctx, 1) - self.state = 356 - self.localVar() - pass - elif token in [RelayParser.CNAME]: - self.enterOuterAlt(localctx, 2) - self.state = 357 - self.constructorName() - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ArgListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_argList - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ArgNoAttrContext(ArgListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext - super().__init__(parser) - self.copyFrom(ctx) - - def varList(self): - return self.getTypedRuleContext(RelayParser.VarListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, 
"visitArgNoAttr" ): - return visitor.visitArgNoAttr(self) - else: - return visitor.visitChildren(self) - - - class ArgWithAttrContext(ArgListContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ArgListContext - super().__init__(parser) - self.copyFrom(ctx) - - def attrSeq(self): - return self.getTypedRuleContext(RelayParser.AttrSeqContext,0) - - def var(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.VarContext) - else: - return self.getTypedRuleContext(RelayParser.VarContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitArgWithAttr" ): - return visitor.visitArgWithAttr(self) - else: - return visitor.visitChildren(self) - - - - def argList(self): - - localctx = RelayParser.ArgListContext(self, self._ctx, self.state) - self.enterRule(localctx, 42, self.RULE_argList) - self._la = 0 # Token type - try: - self.state = 370 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,38,self._ctx) - if la_ == 1: - localctx = RelayParser.ArgNoAttrContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 360 - self.varList() - pass - - elif la_ == 2: - localctx = RelayParser.ArgWithAttrContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 366 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__2: - self.state = 361 - self.var() - self.state = 362 - self.match(RelayParser.T__4) - self.state = 368 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 369 - self.attrSeq() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class VarListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def 
var(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.VarContext) - else: - return self.getTypedRuleContext(RelayParser.VarContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_varList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVarList" ): - return visitor.visitVarList(self) - else: - return visitor.visitChildren(self) - - - - - def varList(self): - - localctx = RelayParser.VarListContext(self, self._ctx, self.state) - self.enterRule(localctx, 44, self.RULE_varList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 380 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__2: - self.state = 372 - self.var() - self.state = 377 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 373 - self.match(RelayParser.T__4) - self.state = 374 - self.var() - self.state = 379 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class VarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_var - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVar" ): - return visitor.visitVar(self) - else: - return visitor.visitChildren(self) - - - - - def var(self): - - localctx = RelayParser.VarContext(self, self._ctx, self.state) - self.enterRule(localctx, 46, self.RULE_var) - self._la = 0 # Token type - try: - 
self.enterOuterAlt(localctx, 1) - self.state = 382 - self.localVar() - self.state = 385 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__25: - self.state = 383 - self.match(RelayParser.T__25) - self.state = 384 - self.typeExpr() - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AttrSeqContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def attr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.AttrContext) - else: - return self.getTypedRuleContext(RelayParser.AttrContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_attrSeq - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAttrSeq" ): - return visitor.visitAttrSeq(self) - else: - return visitor.visitChildren(self) - - - - - def attrSeq(self): - - localctx = RelayParser.AttrSeqContext(self, self._ctx, self.state) - self.enterRule(localctx, 48, self.RULE_attrSeq) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 387 - self.attr() - self.state = 392 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 388 - self.match(RelayParser.T__4) - self.state = 389 - self.attr() - self.state = 394 - self._errHandler.sync(self) - _la = self._input.LA(1) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class AttrContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return 
self.getToken(RelayParser.CNAME, 0) - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_attr - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAttr" ): - return visitor.visitAttr(self) - else: - return visitor.visitChildren(self) - - - - - def attr(self): - - localctx = RelayParser.AttrContext(self, self._ctx, self.state) - self.enterRule(localctx, 50, self.RULE_attr) - try: - self.enterOuterAlt(localctx, 1) - self.state = 395 - self.match(RelayParser.CNAME) - self.state = 396 - self.match(RelayParser.T__14) - self.state = 397 - self.expr(0) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class TypeExprContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_typeExpr - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class TypeParenContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self): - return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeParen" ): - return visitor.visitTypeParen(self) - else: - return visitor.visitChildren(self) - - - class TupleTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return 
self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTupleType" ): - return visitor.visitTupleType(self) - else: - return visitor.visitChildren(self) - - - class TypeCallTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeCallType" ): - return visitor.visitTypeCallType(self) - else: - return visitor.visitChildren(self) - - - class TypeIdentTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def generalIdent(self): - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeIdentType" ): - return visitor.visitTypeIdentType(self) - else: - return visitor.visitChildren(self) - - - class IncompleteTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIncompleteType" ): - return visitor.visitIncompleteType(self) - else: - return visitor.visitChildren(self) - - - class TensorTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def shapeList(self): - return self.getTypedRuleContext(RelayParser.ShapeListContext,0) - - def typeExpr(self): - return 
self.getTypedRuleContext(RelayParser.TypeExprContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTensorType" ): - return visitor.visitTensorType(self) - else: - return visitor.visitChildren(self) - - - class FuncTypeContext(TypeExprContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - def typeParamList(self): - return self.getTypedRuleContext(RelayParser.TypeParamListContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitFuncType" ): - return visitor.visitFuncType(self) - else: - return visitor.visitChildren(self) - - - - def typeExpr(self): - - localctx = RelayParser.TypeExprContext(self, self._ctx, self.state) - self.enterRule(localctx, 52, self.RULE_typeExpr) - self._la = 0 # Token type - try: - self.state = 450 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,47,self._ctx) - if la_ == 1: - localctx = RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 399 - self.match(RelayParser.T__5) - self.state = 400 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - localctx = RelayParser.TypeParenContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 401 - self.match(RelayParser.T__5) - self.state = 402 - self.typeExpr() - self.state = 403 - self.match(RelayParser.T__6) - pass - - elif la_ == 3: - localctx = RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 405 - self.match(RelayParser.T__5) - self.state = 406 - self.typeExpr() - self.state = 407 - self.match(RelayParser.T__4) - self.state = 408 - self.match(RelayParser.T__6) - pass - - elif la_ == 4: - localctx = 
RelayParser.TupleTypeContext(self, localctx) - self.enterOuterAlt(localctx, 4) - self.state = 410 - self.match(RelayParser.T__5) - self.state = 411 - self.typeExpr() - self.state = 414 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 412 - self.match(RelayParser.T__4) - self.state = 413 - self.typeExpr() - self.state = 416 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 418 - self.match(RelayParser.T__6) - pass - - elif la_ == 5: - localctx = RelayParser.TypeCallTypeContext(self, localctx) - self.enterOuterAlt(localctx, 5) - self.state = 420 - self.generalIdent() - self.state = 421 - self.typeParamList() - pass - - elif la_ == 6: - localctx = RelayParser.TypeIdentTypeContext(self, localctx) - self.enterOuterAlt(localctx, 6) - self.state = 423 - self.generalIdent() - pass - - elif la_ == 7: - localctx = RelayParser.TensorTypeContext(self, localctx) - self.enterOuterAlt(localctx, 7) - self.state = 424 - self.match(RelayParser.T__26) - self.state = 425 - self.match(RelayParser.T__7) - self.state = 426 - self.shapeList() - self.state = 427 - self.match(RelayParser.T__4) - self.state = 428 - self.typeExpr() - self.state = 429 - self.match(RelayParser.T__8) - pass - - elif la_ == 8: - localctx = RelayParser.FuncTypeContext(self, localctx) - self.enterOuterAlt(localctx, 8) - self.state = 431 - self.match(RelayParser.T__17) - self.state = 433 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__7: - self.state = 432 - self.typeParamList() - - - self.state = 435 - self.match(RelayParser.T__5) - self.state = 444 - self._errHandler.sync(self) - _la = self._input.LA(1) - if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__17) | (1 << RelayParser.T__26) | (1 << RelayParser.CNAME))) != 0): - self.state = 436 - self.typeExpr() - self.state = 441 - self._errHandler.sync(self) - _la = 
self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 437 - self.match(RelayParser.T__4) - self.state = 438 - self.typeExpr() - self.state = 443 - self._errHandler.sync(self) - _la = self._input.LA(1) - - - - self.state = 446 - self.match(RelayParser.T__6) - self.state = 447 - self.match(RelayParser.T__18) - self.state = 448 - self.typeExpr() - pass - - elif la_ == 9: - localctx = RelayParser.IncompleteTypeContext(self, localctx) - self.enterOuterAlt(localctx, 9) - self.state = 449 - self.match(RelayParser.T__3) - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class TypeParamListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def typeExpr(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.TypeExprContext) - else: - return self.getTypedRuleContext(RelayParser.TypeExprContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_typeParamList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTypeParamList" ): - return visitor.visitTypeParamList(self) - else: - return visitor.visitChildren(self) - - - - - def typeParamList(self): - - localctx = RelayParser.TypeParamListContext(self, self._ctx, self.state) - self.enterRule(localctx, 54, self.RULE_typeParamList) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 452 - self.match(RelayParser.T__7) - self.state = 453 - self.typeExpr() - self.state = 458 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la==RelayParser.T__4: - self.state = 454 - self.match(RelayParser.T__4) - self.state = 455 - self.typeExpr() - self.state = 460 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 461 - 
self.match(RelayParser.T__8) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ShapeListContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def shape(self, i:int=None): - if i is None: - return self.getTypedRuleContexts(RelayParser.ShapeContext) - else: - return self.getTypedRuleContext(RelayParser.ShapeContext,i) - - - def getRuleIndex(self): - return RelayParser.RULE_shapeList - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitShapeList" ): - return visitor.visitShapeList(self) - else: - return visitor.visitChildren(self) - - - - - def shapeList(self): - - localctx = RelayParser.ShapeListContext(self, self._ctx, self.state) - self.enterRule(localctx, 56, self.RULE_shapeList) - self._la = 0 # Token type - try: - self.state = 476 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,50,self._ctx) - if la_ == 1: - self.enterOuterAlt(localctx, 1) - self.state = 463 - self.match(RelayParser.T__5) - self.state = 464 - self.match(RelayParser.T__6) - pass - - elif la_ == 2: - self.enterOuterAlt(localctx, 2) - self.state = 465 - self.match(RelayParser.T__5) - self.state = 466 - self.shape() - self.state = 469 - self._errHandler.sync(self) - _la = self._input.LA(1) - while True: - self.state = 467 - self.match(RelayParser.T__4) - self.state = 468 - self.shape() - self.state = 471 - self._errHandler.sync(self) - _la = self._input.LA(1) - if not (_la==RelayParser.T__4): - break - - self.state = 473 - self.match(RelayParser.T__6) - pass - - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 475 - self.shape() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - 
self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class MetaContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def CNAME(self): - return self.getToken(RelayParser.CNAME, 0) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def getRuleIndex(self): - return RelayParser.RULE_meta - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMeta" ): - return visitor.visitMeta(self) - else: - return visitor.visitChildren(self) - - - - - def meta(self): - - localctx = RelayParser.MetaContext(self, self._ctx, self.state) - self.enterRule(localctx, 58, self.RULE_meta) - try: - self.enterOuterAlt(localctx, 1) - self.state = 478 - self.match(RelayParser.T__27) - self.state = 479 - self.match(RelayParser.T__7) - self.state = 480 - self.match(RelayParser.CNAME) - self.state = 481 - self.match(RelayParser.T__8) - self.state = 482 - self.match(RelayParser.T__7) - self.state = 483 - self.match(RelayParser.NAT) - self.state = 484 - self.match(RelayParser.T__8) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ShapeContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_shape - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ParensShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def shape(self): - return self.getTypedRuleContext(RelayParser.ShapeContext,0) - - - def accept(self, visitor:ParseTreeVisitor): 
- if hasattr( visitor, "visitParensShape" ): - return visitor.visitParensShape(self) - else: - return visitor.visitChildren(self) - - - class MetaShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def meta(self): - return self.getTypedRuleContext(RelayParser.MetaContext,0) - - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMetaShape" ): - return visitor.visitMetaShape(self) - else: - return visitor.visitChildren(self) - - - class IntShapeContext(ShapeContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext - super().__init__(parser) - self.copyFrom(ctx) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIntShape" ): - return visitor.visitIntShape(self) - else: - return visitor.visitChildren(self) - - - - def shape(self): - - localctx = RelayParser.ShapeContext(self, self._ctx, self.state) - self.enterRule(localctx, 60, self.RULE_shape) - try: - self.state = 492 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.T__27]: - localctx = RelayParser.MetaShapeContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 486 - self.meta() - pass - elif token in [RelayParser.T__5]: - localctx = RelayParser.ParensShapeContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 487 - self.match(RelayParser.T__5) - self.state = 488 - self.shape() - self.state = 489 - self.match(RelayParser.T__6) - pass - elif token in [RelayParser.NAT]: - localctx = RelayParser.IntShapeContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 491 - self.match(RelayParser.NAT) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - 
self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class BodyContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def expr(self): - return self.getTypedRuleContext(RelayParser.ExprContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_body - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitBody" ): - return visitor.visitBody(self) - else: - return visitor.visitChildren(self) - - - - - def body(self): - - localctx = RelayParser.BodyContext(self, self._ctx, self.state) - self.enterRule(localctx, 62, self.RULE_body) - try: - self.enterOuterAlt(localctx, 1) - self.state = 494 - self.match(RelayParser.T__11) - self.state = 495 - self.expr(0) - self.state = 496 - self.match(RelayParser.T__12) - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class ScalarContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - - def getRuleIndex(self): - return RelayParser.RULE_scalar - - - def copyFrom(self, ctx:ParserRuleContext): - super().copyFrom(ctx) - - - - class ScalarFloatContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - super().__init__(parser) - self.copyFrom(ctx) - - def FLOAT(self): - return self.getToken(RelayParser.FLOAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarFloat" ): - return visitor.visitScalarFloat(self) - else: - return visitor.visitChildren(self) - - - class ScalarBoolContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - 
super().__init__(parser) - self.copyFrom(ctx) - - def BOOL_LIT(self): - return self.getToken(RelayParser.BOOL_LIT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarBool" ): - return visitor.visitScalarBool(self) - else: - return visitor.visitChildren(self) - - - class ScalarIntContext(ScalarContext): - - def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext - super().__init__(parser) - self.copyFrom(ctx) - - def NAT(self): - return self.getToken(RelayParser.NAT, 0) - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitScalarInt" ): - return visitor.visitScalarInt(self) - else: - return visitor.visitChildren(self) - - - - def scalar(self): - - localctx = RelayParser.ScalarContext(self, self._ctx, self.state) - self.enterRule(localctx, 64, self.RULE_scalar) - try: - self.state = 501 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [RelayParser.FLOAT]: - localctx = RelayParser.ScalarFloatContext(self, localctx) - self.enterOuterAlt(localctx, 1) - self.state = 498 - self.match(RelayParser.FLOAT) - pass - elif token in [RelayParser.NAT]: - localctx = RelayParser.ScalarIntContext(self, localctx) - self.enterOuterAlt(localctx, 2) - self.state = 499 - self.match(RelayParser.NAT) - pass - elif token in [RelayParser.BOOL_LIT]: - localctx = RelayParser.ScalarBoolContext(self, localctx) - self.enterOuterAlt(localctx, 3) - self.state = 500 - self.match(RelayParser.BOOL_LIT) - pass - else: - raise NoViableAltException(self) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - class IdentContext(ParserRuleContext): - - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): - super().__init__(parent, invokingState) - self.parser = parser - - def generalIdent(self): - return 
self.getTypedRuleContext(RelayParser.GeneralIdentContext,0) - - - def globalVar(self): - return self.getTypedRuleContext(RelayParser.GlobalVarContext,0) - - - def localVar(self): - return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - - - def graphVar(self): - return self.getTypedRuleContext(RelayParser.GraphVarContext,0) - - - def getRuleIndex(self): - return RelayParser.RULE_ident - - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitIdent" ): - return visitor.visitIdent(self) - else: - return visitor.visitChildren(self) - - - - - def ident(self): - - localctx = RelayParser.IdentContext(self, self._ctx, self.state) - self.enterRule(localctx, 66, self.RULE_ident) - try: - self.state = 507 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,53,self._ctx) - if la_ == 1: - self.enterOuterAlt(localctx, 1) - self.state = 503 - self.generalIdent() - pass - - elif la_ == 2: - self.enterOuterAlt(localctx, 2) - self.state = 504 - self.globalVar() - pass - - elif la_ == 3: - self.enterOuterAlt(localctx, 3) - self.state = 505 - self.localVar() - pass - - elif la_ == 4: - self.enterOuterAlt(localctx, 4) - self.state = 506 - self.graphVar() - pass - - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx - - - - def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): - if self._predicates == None: - self._predicates = dict() - self._predicates[7] = self.expr_sempred - pred = self._predicates.get(ruleIndex, None) - if pred is None: - raise Exception("No predicate with index:" + str(ruleIndex)) - else: - return pred(localctx, predIndex) - - def expr_sempred(self, localctx:ExprContext, predIndex:int): - if predIndex == 0: - return self.precpred(self._ctx, 19) - - - if predIndex == 1: - return self.precpred(self._ctx, 18) - - - if predIndex == 2: - return 
self.precpred(self._ctx, 17) - - - if predIndex == 3: - return self.precpred(self._ctx, 16) - - - if predIndex == 4: - return self.precpred(self._ctx, 6) - - - if predIndex == 5: - return self.precpred(self._ctx, 21) - - - if predIndex == 6: - return self.precpred(self._ctx, 8) - - - - - diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py deleted file mode 100644 index c6a7b7a0558c..000000000000 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ /dev/null @@ -1,343 +0,0 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 -from antlr4 import * -if __name__ is not None and "." in __name__: - from .RelayParser import RelayParser -else: - from RelayParser import RelayParser - -# This class defines a complete generic visitor for a parse tree produced by RelayParser. - -class RelayVisitor(ParseTreeVisitor): - - # Visit a parse tree produced by RelayParser#prog. - def visitProg(self, ctx:RelayParser.ProgContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#generalIdent. - def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#globalVar. - def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#localVar. - def visitLocalVar(self, ctx:RelayParser.LocalVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#graphVar. - def visitGraphVar(self, ctx:RelayParser.GraphVarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#exprList. - def visitExprList(self, ctx:RelayParser.ExprListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#callNoAttr. 
- def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#callWithAttr. - def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcExpr. - def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#metaExpr. - def visitMetaExpr(self, ctx:RelayParser.MetaExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#match. - def visitMatch(self, ctx:RelayParser.MatchContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tensor. - def visitTensor(self, ctx:RelayParser.TensorContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#graph. - def visitGraph(self, ctx:RelayParser.GraphContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#identExpr. - def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#stringExpr. - def visitStringExpr(self, ctx:RelayParser.StringExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#call. - def visitCall(self, ctx:RelayParser.CallContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#neg. - def visitNeg(self, ctx:RelayParser.NegContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tuple. - def visitTuple(self, ctx:RelayParser.TupleContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#paren. - def visitParen(self, ctx:RelayParser.ParenContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarExpr. 
- def visitScalarExpr(self, ctx:RelayParser.ScalarExprContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#let. - def visitLet(self, ctx:RelayParser.LetContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#projection. - def visitProjection(self, ctx:RelayParser.ProjectionContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#ifElse. - def visitIfElse(self, ctx:RelayParser.IfElseContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#binOp. - def visitBinOp(self, ctx:RelayParser.BinOpContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#func. - def visitFunc(self, ctx:RelayParser.FuncContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcDefn. - def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#externAdtDefn. - def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtDefn. - def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#constructorName. - def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsDefnList. - def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsDefn. - def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchClauseList. 
- def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchClause. - def visitMatchClause(self, ctx:RelayParser.MatchClauseContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#matchType. - def visitMatchType(self, ctx:RelayParser.MatchTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#patternList. - def visitPatternList(self, ctx:RelayParser.PatternListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#wildcardPattern. - def visitWildcardPattern(self, ctx:RelayParser.WildcardPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#varPattern. - def visitVarPattern(self, ctx:RelayParser.VarPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#constructorPattern. - def visitConstructorPattern(self, ctx:RelayParser.ConstructorPatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tuplePattern. - def visitTuplePattern(self, ctx:RelayParser.TuplePatternContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtCons. - def visitAdtCons(self, ctx:RelayParser.AdtConsContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsParamList. - def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#adtConsParam. - def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#argNoAttr. - def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#argWithAttr. 
- def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#varList. - def visitVarList(self, ctx:RelayParser.VarListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#var. - def visitVar(self, ctx:RelayParser.VarContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#attrSeq. - def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#attr. - def visitAttr(self, ctx:RelayParser.AttrContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tupleType. - def visitTupleType(self, ctx:RelayParser.TupleTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeParen. - def visitTypeParen(self, ctx:RelayParser.TypeParenContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeCallType. - def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeIdentType. - def visitTypeIdentType(self, ctx:RelayParser.TypeIdentTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#tensorType. - def visitTensorType(self, ctx:RelayParser.TensorTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#funcType. - def visitFuncType(self, ctx:RelayParser.FuncTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#incompleteType. - def visitIncompleteType(self, ctx:RelayParser.IncompleteTypeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#typeParamList. 
- def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#shapeList. - def visitShapeList(self, ctx:RelayParser.ShapeListContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#meta. - def visitMeta(self, ctx:RelayParser.MetaContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#metaShape. - def visitMetaShape(self, ctx:RelayParser.MetaShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#parensShape. - def visitParensShape(self, ctx:RelayParser.ParensShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#intShape. - def visitIntShape(self, ctx:RelayParser.IntShapeContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#body. - def visitBody(self, ctx:RelayParser.BodyContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarFloat. - def visitScalarFloat(self, ctx:RelayParser.ScalarFloatContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarInt. - def visitScalarInt(self, ctx:RelayParser.ScalarIntContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#scalarBool. - def visitScalarBool(self, ctx:RelayParser.ScalarBoolContext): - return self.visitChildren(ctx) - - - # Visit a parse tree produced by RelayParser#ident. 
- def visitIdent(self, ctx:RelayParser.IdentContext): - return self.visitChildren(ctx) - - - -del RelayParser \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/__init__.py b/python/tvm/relay/grammar/py3/__init__.py deleted file mode 100644 index 13a83393a912..000000000000 --- a/python/tvm/relay/grammar/py3/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py deleted file mode 100644 index 6c4e3131e3c2..000000000000 --- a/python/tvm/relay/parser.py +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""A parser for Relay's text format.""" -from __future__ import absolute_import -from .. import register_func - - -@register_func("relay.fromtext") -def fromtext(data, source_name=None): - """Parse a Relay program.""" - # pylint: disable=import-outside-toplevel - from tvm.relay import _parser - x = _parser.fromtext(data + "\n", source_name) - if x is None: - raise Exception("cannot parse: ", data) - return x diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 0c3ab601e04a..5f0445134dea 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -37,11 +37,6 @@ dist .node_repl_history node_modules -# Relay parser: they are generated by ANTLR. 
-RelayLexer.py -RelayParser.py -RelayVisitor.py - # Specific files package-list MANIFEST From 11139536fb4df1518ff2fd39b8deacfc6405ee5a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 21:32:57 -0700 Subject: [PATCH 13/48] Rename parser tests and remove old ones --- tests/python/relay/test_ir_parser.py | 233 ++++--- tests/python/relay/test_ir_parser2.py | 942 -------------------------- 2 files changed, 130 insertions(+), 1045 deletions(-) delete mode 100644 tests/python/relay/test_ir_parser2.py diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 1e4fe6b66830..7ffe85c5d5df 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -17,6 +17,7 @@ import tvm from tvm import te from tvm import relay +import tvm.relay.testing import pytest from numpy import isclose from typing import Union @@ -73,15 +74,18 @@ def assert_graph_equal(lhs, rhs): def graph_equal(lhs, rhs): return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) +def roundtrip_expr(expr): + text = tvm.relay.Expr.astext(expr, show_meta_data=False) + x = tvm.parser.parse_expr(str(text)) + assert_graph_equal(x, expr) def roundtrip(expr): - x = relay.fromtext(expr.astext()) + x = tvm.parser.fromtext(expr.astext()) assert_graph_equal(x, expr) - def parse_text(code): - expr = relay.fromtext(SEMVER + "\n" + code) - roundtrip(expr) + expr = tvm.parser.parse_expr(code) + roundtrip_expr(expr) return expr @@ -91,11 +95,18 @@ def parses_as(code, expr): result = graph_equal(parsed, expr) return result +def parse_module(code): + mod = tvm.parser.parse(code) + roundtrip(mod) + return mod def assert_parses_as(code, expr): parsed = parse_text(code) assert_graph_equal(parsed, expr) +def assert_parse_module_as(code, mod): + parsed = parse_module(code) + assert_graph_equal(parsed, mod) def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) @@ -176,7 +187,8 @@ def test_bool_literal(): def test_negative(): - 
assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) + # need to handle parsing non-literal operations + # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) assert get_scalar(parse_text("--10")) == 10 assert get_scalar(parse_text("---10")) == -10 @@ -198,15 +210,7 @@ def test_op_assoc(): assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) - -@pytest.mark.skip def test_vars(): - # temp vars won't work b/c they start with a digit - # # temp var - # temp_var = parse_text("%1") - # assert isinstance(temp_var, relay.Var) - # assert temp_var.name == "1" - # var var = parse_text("let %foo = (); %foo") assert isinstance(var.body, relay.Var) @@ -218,9 +222,19 @@ def test_vars(): assert global_var.name_hint == "foo" # operator id - op = parse_text("foo") + op = parse_text("add") + assert isinstance(op, tvm.ir.Op) + assert op.name == "add" + + # operator id with prefix + op = parse_text("nn.global_avg_pool2d") assert isinstance(op, tvm.ir.Op) - assert op.name == "foo" + assert op.name == "nn.global_avg_pool2d" + +def test_meta_ref(): + meta_op = parse_text("meta[type_key][1337]") + assert meta_op.attrs.node_type_key == "type_key" + assert meta_op.attrs.node_index == "1337" def test_let(): @@ -253,7 +267,7 @@ def test_let(): def test_seq(): assert_parses_as( - "();; ()", + "(); ()", relay.Let( _, UNIT, @@ -348,16 +362,18 @@ def test_func(): ) ) - # attributes - assert_parses_as( - "fn (n=5) { () }", - relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) - ) + # Refactor the attribute syntax and printing. + # + # # attributes + # assert_parses_as( + # "fn (n=5) { () }", + # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) + # ) # TODO(@jmp): Crashes if %x isn't annnotated. 
def test_defn(): - id_defn = parse_text( + id_defn = parse_module( """ def @id(%x: int32) -> int32 { %x @@ -367,7 +383,7 @@ def @id(%x: int32) -> int32 { def test_recursive_call(): - id_defn = parse_text( + id_defn = parse_module( """ def @id(%x: int32) -> int32 { @id(%x) @@ -487,40 +503,39 @@ def test_call(): ) ) - # TODO(@jmp): re-enable after sequence parsing improvements # curried function - # curried_mult = relay.Var("curried_mult") - # assert_parses_as( - # """ - # let %curried_mult = - # fn (%x) { - # fn (%y) { - # %x * %y - # } - # }; - # %curried_mult(0); - # %curried_mult(0)(0) - # """, - # relay.Let( - # curried_mult, - # relay.Function( - # [X], - # relay.Function( - # [Y], - # relay.multiply(X, Y), - # None, - # [] - # ), - # None, - # [] - # ), - # relay.Let( - # _, - # relay.Call(curried_mult, [relay.const(0)], None, None), - # relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - # ) - # ) - # ) + curried_mult = relay.Var("curried_mult") + assert_parses_as( + """ + let %curried_mult = + fn (%x) { + fn (%y) { + %x * %y + } + }; + %curried_mult(0); + %curried_mult(0)(0) + """, + relay.Let( + curried_mult, + relay.Function( + [X], + relay.Function( + [Y], + relay.multiply(X, Y), + None, + [] + ), + None, + [] + ), + relay.Let( + _, + relay.Call(curried_mult, [relay.const(0)], None, None), + relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) + ) + ) + ) # op assert_parses_as( @@ -655,7 +670,7 @@ def test_adt_defn(): [], [relay.Constructor("Nil", [], glob_typ_var)]) mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Ayy { Nil } """, @@ -669,7 +684,7 @@ def test_empty_adt_defn(): glob_typ_var = relay.GlobalTypeVar("Ayy") prog = relay.TypeData(glob_typ_var, [], []) mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Ayy { } """, @@ -690,7 +705,7 @@ def test_multiple_cons_defn(): relay.Constructor("Nil", [], 
list_var), ]) mod[list_var] = prog - assert_parses_as(LIST_DEFN, mod) + assert_parse_module_as(LIST_DEFN, mod) def test_multiple_type_param_defn(): @@ -706,7 +721,7 @@ def test_multiple_type_param_defn(): ]) mod = tvm.IRModule() mod[glob_typ_var] = prog - assert_parses_as( + assert_parse_module_as( """ type Either[A, B] { Left(A), @@ -740,7 +755,7 @@ def test_match(): input_var = relay.Var("xs", input_type) rest_var = relay.Var("rest") cons_case = relay.Let( - _, + relay.var("", type_annotation=None), UNIT, relay.add(relay.const(1), relay.Call(length_var, [rest_var]))) body = relay.Match(input_var, @@ -762,14 +777,14 @@ def test_match(): ) mod[length_var] = length_func - assert_parses_as( + assert_parse_module_as( """ %s def @length[A](%%xs: List[A]) -> int32 { %s (%%xs) { - Cons(_, %%rest) => { - ();; + Cons(_, %%rest : List[A]) => { + (); 1 + @length(%%rest) }, Nil => 0, @@ -803,7 +818,7 @@ def test_adt_cons_expr(): ) mod[make_singleton_var] = make_singleton_func - assert_parses_as( + assert_parse_module_as( """ %s @@ -817,7 +832,7 @@ def @make_singleton(%%x: int32) -> List[int32] { @raises_parse_error def test_duplicate_adt_defn(): - parse_text( + parse_module( """ %s @@ -853,8 +868,8 @@ def test_duplicate_adt_cons_defn(): def test_duplicate_global_var(): parse_text( """ - def @id[A](%%x: A) -> A { x } - def @id[A](%%x: A) -> A { x } + def @id[A](%x: A) -> A { x } + def @id[A](%x: A) -> A { x } """ ) @@ -868,48 +883,60 @@ def test_extern_adt_defn(): extern_def = relay.TypeData(extern_var, [typ_var], []) mod[extern_var] = extern_def - assert_parses_as( + assert_parse_module_as( """ extern type T[A] """, mod ) + +@pytest.mark.skip("not yet tested on parser 2.0") def test_import_grad(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") +# hiearchy id, i.e parse nn.conv2d +# do with multiple levels +# +# call attributes not correctly parsing +# convert error from attribute construction to real error message +# lexing issue with projection of graph variables + 
+# def test_hierarchical_identifiers(): +# assert False + +def test_resnet(): + mod, params = relay.testing.resnet.get_workload() + text = str(mod.astext()) + parsed_mod = parse_module(text) + tvm.ir.assert_structural_equal(mod, parsed_mod) + +def inline_params(mod, params): + main_fn = mod["main"] + str_to_var = {} + for param in main_fn.params: + str_to_var[param.name_hint] = param + + bind_map = {} + for param in params: + bind_map[str_to_var[param]] = relay.const(params[param]) + + body = relay.bind(main_fn.body, bind_map) + main_fn = relay.Function(relay.analysis.free_vars(body), body) + mod["main_fn"] = main_fn + return mod + +def test_resnet_inlined_params(): + mod, params = relay.testing.resnet.get_workload() + print("here") + mod = inline_params(mod, params) + print("here") + text = str(mod.astext()) + print("here") + parsed_mod = parse_module(text) + print("here") + tvm.ir.assert_structural_equal(mod, parsed_mod) + print("here") + if __name__ == "__main__": - test_graph() - test_comments() - test_int_literal() - test_float_literal() - test_bool_literal() - test_negative() - test_bin_op() - test_parens() - test_op_assoc() - test_let() - test_seq() - test_tuple() - test_func() - test_defn() - test_recursive_call() - test_ifelse() - test_call() - test_incomplete_type() - test_builtin_types() - test_tensor_type() - test_function_type() - test_tuple_type() - test_adt_defn() - test_empty_adt_defn() - test_multiple_cons_defn() - test_multiple_type_param_defn() - test_match() - test_adt_cons_expr() - test_duplicate_adt_defn() - test_duplicate_adt_cons() - test_duplicate_adt_cons_defn() - test_duplicate_global_var() - test_extern_adt_defn() - test_import_grad() + test_resnet_inlined_params() diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py deleted file mode 100644 index 7ffe85c5d5df..000000000000 --- a/tests/python/relay/test_ir_parser2.py +++ /dev/null @@ -1,942 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) 
under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -from tvm import relay -import tvm.relay.testing -import pytest -from numpy import isclose -from typing import Union -from functools import wraps -raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError) - -SEMVER = "v0.0.4" - -BINARY_OPS = { - "*": relay.multiply, - "/": relay.divide, - "+": relay.add, - "-": relay.subtract, - "<": relay.less, - ">": relay.greater, - "<=": relay.less_equal, - ">=": relay.greater_equal, - "==": relay.equal, - "!=": relay.not_equal, -} - -TYPES = { - "int8", - "int16", - "int32", - "int64", - - "uint8", - "uint16", - "uint32", - "uint64", - - "float16", - "float32", - "float64", - - "bool", - - "int8x4", - "uint1x4", - "float16x4", -} - -LIST_DEFN = """ -type List[A] { - Cons(A, List[A]), - Nil, -} -""" - -def assert_graph_equal(lhs, rhs): - tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True) - -def graph_equal(lhs, rhs): - return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) - -def roundtrip_expr(expr): - text = tvm.relay.Expr.astext(expr, show_meta_data=False) - x = tvm.parser.parse_expr(str(text)) - assert_graph_equal(x, expr) - -def roundtrip(expr): - x = tvm.parser.fromtext(expr.astext()) - assert_graph_equal(x, expr) - -def 
parse_text(code): - expr = tvm.parser.parse_expr(code) - roundtrip_expr(expr) - return expr - - -def parses_as(code, expr): - # type: (str, relay.Expr) -> bool - parsed = parse_text(code) - result = graph_equal(parsed, expr) - return result - -def parse_module(code): - mod = tvm.parser.parse(code) - roundtrip(mod) - return mod - -def assert_parses_as(code, expr): - parsed = parse_text(code) - assert_graph_equal(parsed, expr) - -def assert_parse_module_as(code, mod): - parsed = parse_module(code) - assert_graph_equal(parsed, mod) - -def get_scalar(x): - # type: (relay.Constant) -> (Union[float, int, bool]) - return x.data.asnumpy().item() - -int32 = relay.scalar_type("int32") - -_ = relay.Var("_") -X = relay.Var("x") -Y = relay.Var("y") -X_ANNO = relay.Var("x", int32) -Y_ANNO = relay.Var("y", int32) - -UNIT = relay.Tuple([]) - - -def test_comments(): - assert_parses_as( - """ - // This is a line comment! - () - """, - UNIT - ) - - assert_parses_as( - """ - /* This is a block comment! - This is still a block comment! - */ - () - """, - UNIT - ) - - assert_parses_as( - """ - /* This is a block comment! 
- /*Block comment is recursive!*/ - */ - () - """, - UNIT - ) - - -def test_int_literal(): - assert isinstance(parse_text("1"), relay.Constant) - assert isinstance(parse_text("1").data, tvm.nd.NDArray) - - assert get_scalar(parse_text("1")) == 1 - assert get_scalar(parse_text("10")) == 10 - assert get_scalar(parse_text("0")) == 0 - assert get_scalar(parse_text("-100")) == -100 - assert get_scalar(parse_text("-05")) == -5 - - -def test_float_literal(): - assert get_scalar(parse_text("1.0f")) == 1.0 - assert isclose(get_scalar(parse_text("1.56667f")), 1.56667) - assert get_scalar(parse_text("0.0f")) == 0.0 - assert get_scalar(parse_text("-10.0f")) == -10.0 - - # scientific notation - assert isclose(get_scalar(parse_text("1e-1f")), 1e-1) - assert get_scalar(parse_text("1e+1f")) == 1e+1 - assert isclose(get_scalar(parse_text("1E-1f")), 1E-1) - assert get_scalar(parse_text("1E+1f")) == 1E+1 - assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1) - assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1 - assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1) - assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1 - - -def test_bool_literal(): - assert get_scalar(parse_text("True")) == True - assert get_scalar(parse_text("False")) == False - - -def test_negative(): - # need to handle parsing non-literal operations - # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call) - assert get_scalar(parse_text("--10")) == 10 - assert get_scalar(parse_text("---10")) == -10 - - -def test_bin_op(): - for bin_op in BINARY_OPS.keys(): - assert_parses_as( - "1 {} 1".format(bin_op), - BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) - ) - - -def test_parens(): - assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1")) - assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)")) - - -def test_op_assoc(): - assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) - assert graph_equal(parse_text("1 == 1 < 1 
+ 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) - -def test_vars(): - # var - var = parse_text("let %foo = (); %foo") - assert isinstance(var.body, relay.Var) - assert var.body.name_hint == "foo" - - # global var - global_var = parse_text("@foo") - assert isinstance(global_var, relay.GlobalVar) - assert global_var.name_hint == "foo" - - # operator id - op = parse_text("add") - assert isinstance(op, tvm.ir.Op) - assert op.name == "add" - - # operator id with prefix - op = parse_text("nn.global_avg_pool2d") - assert isinstance(op, tvm.ir.Op) - assert op.name == "nn.global_avg_pool2d" - -def test_meta_ref(): - meta_op = parse_text("meta[type_key][1337]") - assert meta_op.attrs.node_type_key == "type_key" - assert meta_op.attrs.node_index == "1337" - - -def test_let(): - assert_parses_as( - "let %x = 1; ()", - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - - assert_parses_as( - """ - let %x = 1; - let %y = 2; - () - """, - relay.Let( - X, - relay.const(1), - relay.Let( - Y, - relay.const(2), - UNIT - ) - ) - ) - - -def test_seq(): - assert_parses_as( - "(); ()", - relay.Let( - _, - UNIT, - UNIT) - ) - - assert_parses_as( - "let %_ = 1; ()", - relay.Let( - X, - relay.const(1), - UNIT - ) - ) - - -def test_graph(): - code = "%0 = (); %1 = 1; (%0, %0, %1)" - assert_parses_as( - code, - relay.Tuple([UNIT, UNIT, relay.const(1)]) - ) - - -@raises_parse_error -def test_graph_wrong_order(): - parse_text("%1 = (); %1") - - -@raises_parse_error -def test_let_global_var(): - parse_text("let @x = 1; ()") - - -@raises_parse_error -def test_let_op(): - parse_text("let x = 1; ()") - - -def test_tuple(): - assert_parses_as("()", relay.Tuple([])) - - assert_parses_as("(0,)", relay.Tuple([relay.const(0)])) - - assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) - - assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) - - -def test_func(): - # 0 args - assert_parses_as( - "fn () { 0 }", - relay.Function( - [], - 
relay.const(0), - None, - [] - ) - ) - - # 1 arg - assert_parses_as( - "fn (%x) { %x }", - relay.Function( - [X], - X, - None, - [] - ) - ) - - # 2 args - assert_parses_as( - "fn (%x, %y) { %x + %y }", - relay.Function( - [X, Y], - relay.add(X, Y), - None, - [] - ) - ) - - # annotations - assert_parses_as( - "fn (%x: int32) -> int32 { %x }", - relay.Function( - [X_ANNO], - X_ANNO, - int32, - [] - ) - ) - - # Refactor the attribute syntax and printing. - # - # # attributes - # assert_parses_as( - # "fn (n=5) { () }", - # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) - # ) - - -# TODO(@jmp): Crashes if %x isn't annnotated. -def test_defn(): - id_defn = parse_module( - """ - def @id(%x: int32) -> int32 { - %x - } - """) - assert isinstance(id_defn, tvm.IRModule) - - -def test_recursive_call(): - id_defn = parse_module( - """ - def @id(%x: int32) -> int32 { - @id(%x) - } - """) - assert isinstance(id_defn, tvm.IRModule) - - -def test_ifelse(): - assert_parses_as( - """ - if (True) { - 0 - } else { - 1 - } - """, - relay.If( - relay.const(True), - relay.const(0), - relay.const(1) - ) - ) - - -@raises_parse_error -def test_ifelse_scope(): - parse_text( - """ - if (True) { - let %x = (); - () - } else { - %x - } - """ - ) - - -def test_call(): - # select right function to call: simple ident case - id_func = relay.Var("id") - assert_parses_as( - """ - let %id = fn (%x) { %x }; - 10 * %id(10) - """, - relay.Let( - id_func, - relay.Function([X], X, None, []), - relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)])) - ) - ) - - # 0 args - constant = relay.Var("constant") - assert_parses_as( - """ - let %constant = fn () { 0 }; - %constant() - """, - relay.Let( - constant, - relay.Function([], relay.const(0), None, []), - relay.Call(constant, [], None, None) - ) - ) - - # 1 arg - id_var = relay.Var("id") - assert_parses_as( - """ - let %id = fn (%x) { %x }; - %id(1) - """, - relay.Let( - id_var, - relay.Function([X], X, 
None, []), - relay.Call(id_var, [relay.const(1)], None, None) - ) - ) - - # 2 args - multiply = relay.Var("multiply") - assert_parses_as( - """ - let %multiply = fn (%x, %y) { %x * %y }; - %multiply(0, 0) - """, - relay.Let( - multiply, - relay.Function( - [X, Y], - relay.multiply(X, Y), - None, - [] - ), - relay.Call(multiply, [relay.const(0), relay.const(0)], None, None) - ) - ) - - # anonymous function - assert_parses_as( - """ - (fn (%x) { %x })(0) - """, - relay.Call( - relay.Function( - [X], - X, - None, - [] - ), - [relay.const(0)], - None, - None - ) - ) - - # curried function - curried_mult = relay.Var("curried_mult") - assert_parses_as( - """ - let %curried_mult = - fn (%x) { - fn (%y) { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """, - relay.Let( - curried_mult, - relay.Function( - [X], - relay.Function( - [Y], - relay.multiply(X, Y), - None, - [] - ), - None, - [] - ), - relay.Let( - _, - relay.Call(curried_mult, [relay.const(0)], None, None), - relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - ) - ) - ) - - # op - assert_parses_as( - "abs(1)", - relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) - ) - -# Types - - -def test_incomplete_type(): - assert_parses_as( - "let %_ : _ = (); ()", - relay.Let( - _, - UNIT, - UNIT - ) - ) - - -def test_builtin_types(): - for builtin_type in TYPES: - parse_text("let %_ : {} = (); ()".format(builtin_type)) - - -def test_tensor_type(): - assert_parses_as( - "let %_ : Tensor[(), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((), "float32")), - UNIT, - UNIT - ) - ) - - assert_parses_as( - "let %_ : Tensor[(1), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((1,), "float32")), - UNIT, - UNIT - ) - ) - - assert_parses_as( - "let %_ : Tensor[(1, 1), float32] = (); ()", - relay.Let( - relay.Var("_", relay.TensorType((1, 1), "float32")), - UNIT, - UNIT - ) - ) - - -def test_function_type(): - 
assert_parses_as( - """ - let %_: fn () -> int32 = fn () -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([], int32, [], [])), - relay.Function([], relay.const(0), int32, []), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([int32], int32, [], [])), - relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () - """, - relay.Let( - relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), - relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), - UNIT - ) - ) - - -def test_tuple_type(): - assert_parses_as( - """ - let %_: () = (); () - """, - relay.Let( - relay.Var("_", relay.TupleType([])), - UNIT, - UNIT - ) - ) - - assert_parses_as( - """ - let %_: (int32,) = (0,); () - """, - relay.Let( - relay.Var("_", relay.TupleType([int32])), - relay.Tuple([relay.const(0)]), - UNIT - ) - ) - - assert_parses_as( - """ - let %_: (int32, int32) = (0, 1); () - """, - relay.Let( - relay.Var("_", relay.TupleType([int32, int32])), - relay.Tuple([relay.const(0), relay.const(1)]), - UNIT - ) - ) - - -def test_adt_defn(): - mod = tvm.IRModule() - - glob_typ_var = relay.GlobalTypeVar("Ayy") - prog = relay.TypeData( - glob_typ_var, - [], - [relay.Constructor("Nil", [], glob_typ_var)]) - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Ayy { Nil } - """, - mod - ) - - -def test_empty_adt_defn(): - mod = tvm.IRModule() - - glob_typ_var = relay.GlobalTypeVar("Ayy") - prog = relay.TypeData(glob_typ_var, [], []) - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Ayy { } - """, - mod - ) - - -def test_multiple_cons_defn(): - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - prog = relay.TypeData( - list_var, - 
[typ_var], - [ - relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var), - relay.Constructor("Nil", [], list_var), - ]) - mod[list_var] = prog - assert_parse_module_as(LIST_DEFN, mod) - - -def test_multiple_type_param_defn(): - glob_typ_var = relay.GlobalTypeVar("Either") - typ_var_a = relay.TypeVar("A") - typ_var_b = relay.TypeVar("B") - prog = relay.TypeData( - glob_typ_var, - [typ_var_a, typ_var_b], - [ - relay.Constructor("Left", [typ_var_a], glob_typ_var), - relay.Constructor("Right", [typ_var_b], glob_typ_var), - ]) - mod = tvm.IRModule() - mod[glob_typ_var] = prog - assert_parse_module_as( - """ - type Either[A, B] { - Left(A), - Right(B), - } - """, - mod - ) - - -def test_match(): - # pair each match keyword with whether it specifies a complete match or not - match_keywords = [("match", True), ("match?", False)] - for (match_keyword, is_complete) in match_keywords: - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - cons_constructor = relay.Constructor( - "Cons", [typ_var, list_var(typ_var)], list_var) - nil_constructor = relay.Constructor("Nil", [], list_var) - list_def = relay.TypeData( - list_var, - [typ_var], - [cons_constructor, nil_constructor]) - mod[list_var] = list_def - - length_var = relay.GlobalVar("length") - typ_var = relay.TypeVar("A") - input_type = list_var(typ_var) - input_var = relay.Var("xs", input_type) - rest_var = relay.Var("rest") - cons_case = relay.Let( - relay.var("", type_annotation=None), - UNIT, - relay.add(relay.const(1), relay.Call(length_var, [rest_var]))) - body = relay.Match(input_var, - [relay.Clause( - relay.PatternConstructor( - cons_constructor, - [relay.PatternWildcard(), relay.PatternVar(rest_var)]), - cons_case), - relay.Clause( - relay.PatternConstructor(nil_constructor, []), - relay.const(0))], - complete=is_complete - ) - length_func = relay.Function( - [input_var], - body, - int32, - [typ_var] - ) - mod[length_var] = length_func - - 
assert_parse_module_as( - """ - %s - - def @length[A](%%xs: List[A]) -> int32 { - %s (%%xs) { - Cons(_, %%rest : List[A]) => { - (); - 1 + @length(%%rest) - }, - Nil => 0, - } - } - """ % (LIST_DEFN, match_keyword), - mod - ) - - -def test_adt_cons_expr(): - mod = tvm.IRModule() - - list_var = relay.GlobalTypeVar("List") - typ_var = relay.TypeVar("A") - cons_constructor = relay.Constructor( - "Cons", [typ_var, list_var(typ_var)], list_var) - nil_constructor = relay.Constructor("Nil", [], list_var) - list_def = relay.TypeData( - list_var, - [typ_var], - [cons_constructor, nil_constructor]) - mod[list_var] = list_def - - make_singleton_var = relay.GlobalVar("make_singleton") - input_var = relay.Var("x", int32) - make_singleton_func = relay.Function( - [input_var], - cons_constructor(input_var, nil_constructor()), - list_var(int32) - ) - mod[make_singleton_var] = make_singleton_func - - assert_parse_module_as( - """ - %s - - def @make_singleton(%%x: int32) -> List[int32] { - Cons(%%x, Nil) - } - """ % LIST_DEFN, - mod - ) - - -@raises_parse_error -def test_duplicate_adt_defn(): - parse_module( - """ - %s - - type List[A] { - Cons(A, List[A]), - Nil, - } - """ % LIST_DEFN - ) - - -@raises_parse_error -def test_duplicate_adt_cons(): - parse_text( - """ - type Ayy { Lmao } - type Haha { Lmao } - """ - ) - - -@raises_parse_error -def test_duplicate_adt_cons_defn(): - parse_text( - """ - type Ayy { Lmao } - type Lmao { Ayy } - """ - ) - - -@raises_parse_error -def test_duplicate_global_var(): - parse_text( - """ - def @id[A](%x: A) -> A { x } - def @id[A](%x: A) -> A { x } - """ - ) - - -def test_extern_adt_defn(): - # TODO(weberlo): update this test once extern is implemented - mod = tvm.IRModule() - - extern_var = relay.GlobalTypeVar("T") - typ_var = relay.TypeVar("A") - extern_def = relay.TypeData(extern_var, [typ_var], []) - mod[extern_var] = extern_def - - assert_parse_module_as( - """ - extern type T[A] - """, - mod - ) - -@pytest.mark.skip("not yet tested on parser 
2.0") -def test_import_grad(): - mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - -# hiearchy id, i.e parse nn.conv2d -# do with multiple levels -# -# call attributes not correctly parsing -# convert error from attribute construction to real error message -# lexing issue with projection of graph variables - -# def test_hierarchical_identifiers(): -# assert False - -def test_resnet(): - mod, params = relay.testing.resnet.get_workload() - text = str(mod.astext()) - parsed_mod = parse_module(text) - tvm.ir.assert_structural_equal(mod, parsed_mod) - -def inline_params(mod, params): - main_fn = mod["main"] - str_to_var = {} - for param in main_fn.params: - str_to_var[param.name_hint] = param - - bind_map = {} - for param in params: - bind_map[str_to_var[param]] = relay.const(params[param]) - - body = relay.bind(main_fn.body, bind_map) - main_fn = relay.Function(relay.analysis.free_vars(body), body) - mod["main_fn"] = main_fn - return mod - -def test_resnet_inlined_params(): - mod, params = relay.testing.resnet.get_workload() - print("here") - mod = inline_params(mod, params) - print("here") - text = str(mod.astext()) - print("here") - parsed_mod = parse_module(text) - print("here") - tvm.ir.assert_structural_equal(mod, parsed_mod) - print("here") - -if __name__ == "__main__": - test_resnet_inlined_params() From 669aafc58fe54ba326ad6dc97729237c23bfb89c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 21:57:54 -0700 Subject: [PATCH 14/48] Record span end information --- include/tvm/ir/span.h | 14 ++++++++++++-- src/ir/span.cc | 11 ++++++----- src/parser/diagnostic.h | 12 ++++++++++-- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 40f854b027c6..49b139cb0f46 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -85,16 +85,26 @@ class SpanNode : public Object { int line; /*! \brief The column offset. */ int column; + /*! \brief The end line number. 
*/ + int end_line; + /*! \brief The end column number. */ + int end_column; // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("source", &source); v->Visit("line", &line); v->Visit("column", &column); + v->Visit("end_line", &line); + v->Visit("end_column", &column); } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return equal(source, other->source) && equal(line, other->line) && equal(column, other->column); + return equal(source, other->source) && + equal(line, other->line) && + equal(column, other->column) && + equal(end_line, other->end_line) && + equal(end_column, other->end_column); } static constexpr const char* _type_key = "Span"; @@ -103,7 +113,7 @@ class SpanNode : public Object { class Span : public ObjectRef { public: - TVM_DLL Span(SourceName source, int lineno, int col_offset); + TVM_DLL Span(SourceName source, int line, int column, int end_line, int end_column); TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/src/ir/span.cc b/src/ir/span.cc index 64b42ab4dc14..80c15d8a7676 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -61,18 +61,19 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) return static_cast(n)->name; }); -Span::Span(SourceName source, int lineno, int col_offset) { +Span::Span(SourceName source, int line, int column, int end_line, int end_column) { auto n = make_object(); n->source = std::move(source); - n->line = lineno; - n->column = col_offset; + n->line = end_line; + n->column = end_column; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { - return Span(source, lineno, col_offset); +TVM_REGISTER_GLOBAL("ir.Span") +.set_body_typed([](SourceName source, int line, int column, int end_line, int end_column) { + return Span(source, line, column, end_line, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/parser/diagnostic.h 
b/src/parser/diagnostic.h index b1b4e5590a1f..2e40ad43d7d9 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -62,8 +62,9 @@ struct Diagnostic { /*! \brief The diagnostic message. */ std::string message; + /*! \brief A diagnostic for a single character token. */ Diagnostic(int line, int column, const std::string& message) - : level(DiagnosticLevel::Error), span(SourceName(), line, column), message(message) {} + : level(DiagnosticLevel::Error), span(SourceName(), line, column, line, column + 1), message(message) {} Diagnostic(DiagnosticLevel level, Span span, const std::string& message) : level(level), span(span), message(message) {} @@ -97,6 +98,12 @@ struct DiagnosticBuilder { /*! \brief The column number. */ int column; + /*! \brief The line number. */ + int end_line; + + /*! \brief The column number. */ + int end_column; + template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; @@ -110,7 +117,8 @@ struct DiagnosticBuilder { : level(level), source_name(source_name), line(line), column(column) {} operator Diagnostic() { - return Diagnostic(this->level, Span(this->source_name, this->line, this->column), this->stream_.str()); + auto span = Span(this->source_name, this->line, this->column, this->end_line, this->end_column); + return Diagnostic(this->level, span, this->stream_.str()); } private: From 9be4e7b248f714bef2d97a8d61d9b884b42fb39b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jul 2020 22:17:27 -0700 Subject: [PATCH 15/48] Convert to using spans everywhere --- src/parser/diagnostic.h | 5 +++++ src/parser/parser.cc | 20 +++++++++---------- src/parser/token.h | 15 +++++++------- src/parser/tokenizer.h | 43 ++++++++++++++++++++++++++++------------- 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 2e40ad43d7d9..7dea00ec2948 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -92,6 +92,9 @@ struct DiagnosticBuilder { 
/*! \brief The source name. */ SourceName source_name; + /*! \brief The span of the diagnostic. */ + Span span; + /*! \brief The line number. */ int line; @@ -115,6 +118,8 @@ struct DiagnosticBuilder { : level(builder.level), source_name(builder.source_name), line(builder.line), column(builder.column) {} DiagnosticBuilder(DiagnosticLevel level, SourceName source_name, int line, int column) : level(level), source_name(source_name), line(line), column(column) {} + DiagnosticBuilder(DiagnosticLevel level, Span span) + : level(level), span(span) {} operator Diagnostic() { auto span = Span(this->source_name, this->line, this->column, this->end_line, this->end_column); diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 2104084ec7fb..8a793c02fb6c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -295,7 +295,7 @@ class Parser { void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tokens[pos]->line, tokens[pos]->column) + DiagnosticBuilder(DiagnosticLevel::Error, tokens[pos]->span) << "expected a " << Pretty(token_type) << " found " << Pretty(Peek()->token_type)); } pos++; @@ -376,7 +376,7 @@ class Parser { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { diag_ctx->Emit( - {local->line, local->column, "this local variable has not been previously declared"}); + {DiagnosticLevel::Error, local->span, "this local variable has not been previously declared"}); } return var; } @@ -389,7 +389,7 @@ class Parser { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { diag_ctx->Emit( - {ident->line, ident->column, + {DiagnosticLevel::Error, ident->span, "this type variable has not been previously declared anywhere, perhaps a typo?"}); } return var; @@ -542,7 +542,7 @@ class Parser { } else { auto next = Peek(); this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, 
SourceName(), next->line, next->column) + DiagnosticBuilder(DiagnosticLevel::Error, next->span) << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type)); return Array(nullptr); } @@ -583,12 +583,12 @@ class Parser { // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), version->line, version->column) + DiagnosticBuilder(DiagnosticLevel::Error, version->span) << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), Peek()->line, Peek()->column) + DiagnosticBuilder(DiagnosticLevel::Error, Peek()->span) << "expected text format semantic version " << "you can annotate it as #[version = \"0.0.5\"]"); } @@ -620,7 +620,7 @@ class Parser { auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { diag_ctx->Emit( - {next->line, next->column, "an external type may not have any constructors"}); + {DiagnosticLevel::Error, next->span, "an external type may not have any constructors"}); } defs.types.push_back(type_def); } @@ -1186,7 +1186,7 @@ class Parser { return Op::Get(op_name); } catch (dmlc::Error e) { this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tok->line, tok->column) + DiagnosticBuilder(DiagnosticLevel::Error, tok->span) << "operator `" << op_name << "` not found, perhaps you forgot to register it?"); return Expr(); } @@ -1291,7 +1291,7 @@ class Parser { } default: { this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), next->line, next->column) + DiagnosticBuilder(DiagnosticLevel::Error, next->span) << "expected an expression found " << Pretty(next->token_type)); return Expr(); } @@ -1413,7 +1413,7 @@ class Parser { return IncompleteType(); } else { this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), tok->line, 
tok->column) + DiagnosticBuilder(DiagnosticLevel::Error, tok->span) << "failed to parse type found " << tok); return Type(); } diff --git a/src/parser/token.h b/src/parser/token.h index 60a936852b89..4970667eb8bd 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -25,6 +25,7 @@ #ifndef TVM_PARSER_TOKEN_H_ #define TVM_PARSER_TOKEN_H_ +#include #include #include @@ -327,8 +328,7 @@ class Token; class TokenNode : public Object { public: - int line; - int column; + Span span; TokenType token_type; mutable runtime::ObjectRef data; @@ -341,7 +341,7 @@ class TokenNode : public Object { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Token(line=" << node->line << ", column=" << node->column + p->stream << "Token(span=" << node->span << ", token_type=" << ToString(node->token_type) << ", data=" << node->data << ")"; }); @@ -349,7 +349,7 @@ TVM_REGISTER_NODE_TYPE(TokenNode); class Token : public ObjectRef { public: - TVM_DLL explicit Token(int line, int column, TokenType token_type, ObjectRef data = ObjectRef()); + TVM_DLL explicit Token(Span span, TokenType token_type, ObjectRef data = ObjectRef()); static Token Null(); int64_t ToNumber() const; @@ -358,16 +358,15 @@ class Token : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode); }; -Token::Token(int line, int column, TokenType token_type, ObjectRef data) { +Token::Token(Span span, TokenType token_type, ObjectRef data) { ObjectPtr n = make_object(); - n->line = line; - n->column = column; + n->span = span; n->token_type = token_type; n->data = data; data_ = std::move(n); } -Token Token::Null() { return Token(0, 0, TokenType::Null); } +Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::Null); } int64_t Token::ToNumber() const { return Downcast(this->operator->()->data); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 
7a68c458de08..2ab80807405a 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -103,8 +103,15 @@ struct Tokenizer { return this->source.at(this->pos); } - Token NewToken(TokenType token_type, ObjectRef data = ObjectRef()) { - return Token(this->line, this->col, token_type, data); + Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) { + auto span = Span(this->source_name, this->line, this->col, this->line + lines, this->col + cols); + return Token(span, token_type, data); + } + + Span SpanFrom(int line, int column) { + int end_line = this->line; + int end_column = this->col; + return Span(this->source_name, line, column, end_line, end_column); } enum CommentParserState { @@ -228,7 +235,10 @@ struct Tokenizer { Next(); // todo: add error handling around bad indices auto index = ParseNumber(true, false, str_index.str()).ToNumber(); - return Token(line, column, TokenType::MetaReference, MetaRef(type_key.str(), index)); + int end_line = this->line; + int end_column = this->col; + auto span = Span(this->source_name, line, column, end_line, end_column); + return Token(span, TokenType::MetaReference, MetaRef(type_key.str(), index)); } Token TokenizeAttr() { @@ -257,12 +267,14 @@ struct Tokenizer { metadata << Next(); } ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); - return Token(line, column, TokenType::Metadata, metadata_map); + auto span = SpanFrom(line, column); + return Token(span, TokenType::Metadata, metadata_map); } if (attribute.rfind("version", 0) == 0) { std::string version = attribute.substr(attribute.find("=") + 1); ltrim(version); rtrim(version); - return Token(line, column, TokenType::Version, tvm::String(version)); + auto span = SpanFrom(line, column); + return Token(span, TokenType::Version, tvm::String(version)); } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? 
this->diag_ctx->EmitFatal( @@ -431,7 +443,8 @@ struct Tokenizer { auto number_str = number.str(); if (number_str.size()) { auto num_tok = ParseNumber(true, false, number_str); - token = Token(token->line, token->column, TokenType::Graph, num_tok->data); + auto span = SpanFrom(token->span->line, token->span->column); + token = Token(span, TokenType::Graph, num_tok->data); } return token; @@ -486,7 +499,8 @@ struct Tokenizer { token_type = TokenType::Identifier; } - return Token(line, col, token_type, tvm::String(ss.str())); + auto span = SpanFrom(line, col); + return Token(span, token_type, tvm::String(ss.str())); } else { std::stringstream ss; while (More() && !IsWhitespace(Peek())) { @@ -522,12 +536,13 @@ std::vector Condense(const std::vector& tokens) { if (next->token_type == TokenType::Identifier) { // Match this token. i += 1; - auto tok = Token(current->line, current->column, TokenType::Local, next->data); + // TODO(@jroesch): merge spans + auto tok = Token(current->span, TokenType::Local, next->data); CHECK(tok.defined()); out.push_back(tok); } else if (next->token_type == TokenType::Integer) { i += 1; - auto tok = Token(current->line, current->column, TokenType::Graph, next->data); + auto tok = Token(current->span, TokenType::Graph, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -541,7 +556,8 @@ std::vector Condense(const std::vector& tokens) { if (next->token_type == TokenType::Identifier) { // Match this token. 
i += 1; - auto tok = Token(current->line, current->column, TokenType::Global, next->data); + // TODO(@jroesch): merge spans + auto tok = Token(current->span, TokenType::Global, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -553,14 +569,15 @@ std::vector Condense(const std::vector& tokens) { case TokenType::Identifier: { std::string str = Downcast(current->data); Token tok; + // TODO(@jroesch): merge spans if (str == "True") { auto data = tvm::Integer(1); - tok = Token(current->line, current->column, TokenType::Boolean, data); + tok = Token(current->span, TokenType::Boolean, data); } else if (str == "False") { auto data = tvm::Integer(0); - tok = Token(current->line, current->column, TokenType::Boolean, data); + tok = Token(current->span, TokenType::Boolean, data); } else if (str == "_") { - tok = Token(current->line, current->column, TokenType::Underscore); + tok = Token(current->span, TokenType::Underscore); } else { tok = current; } From ab782e6750c4434a0f2693984984ee809ef9b374 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 02:34:48 -0700 Subject: [PATCH 16/48] Add span fields back to all Relay constructors --- include/tvm/ir/span.h | 3 +++ include/tvm/relay/expr.h | 23 ++++++++++++----------- src/ir/span.cc | 15 +++++++++++++-- src/relay/ir/expr.cc | 30 ++++++++++++++++++++---------- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 49b139cb0f46..494cdbf3a956 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -115,6 +115,9 @@ class Span : public ObjectRef { public: TVM_DLL Span(SourceName source, int line, int column, int end_line, int end_column); + /*! \brief Merge two spans into one which captures the combined regions. 
*/ + TVM_DLL Span Merge(const Span& other); + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 3c156dfd7481..2b5edfb248f5 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -89,7 +89,7 @@ class Constant : public Expr { * \brief The constructor * \param data The data of the constant tensor. */ - TVM_DLL explicit Constant(runtime::NDArray data); + TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); }; @@ -135,7 +135,7 @@ class Tuple : public Expr { * \brief The constructor * \param fields The fields of a tuple. */ - TVM_DLL explicit Tuple(tvm::Array fields); + TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); }; @@ -203,14 +203,15 @@ class Var : public Expr { * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. */ - TVM_DLL Var(String name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {} + TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span()) + : Var(Id(name_hint), type_annotation, span) {} /*! * \brief The constructor * \param vid The unique id of a variable. * \param type_annotation The type annotation of a variable. */ - TVM_DLL Var(Id vid, Type type_annotation); + TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); }; @@ -297,7 +298,7 @@ class Call : public Expr { * \param type_args The type arguments passed to a polymorphic function. */ TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), - Array type_args = Array()); + Array type_args = Array(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); }; @@ -357,7 +358,7 @@ class Let : public Expr { * \param value The value used to bind to the variable. 
* \param body The body of the let binding. */ - TVM_DLL Let(Var var, Expr value, Expr body); + TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); }; @@ -417,7 +418,7 @@ class If : public Expr { * \param true_branch The fall through branch * \param false_branch The branch for execution when condition is false. */ - TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch); + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); }; @@ -458,7 +459,7 @@ class TupleGetItem : public Expr { * \param tuple The tuple to get an element from. * \param index The index for extracting a value in the tuple. */ - TVM_DLL TupleGetItem(Expr tuple, int index); + TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); }; @@ -496,7 +497,7 @@ class RefCreate : public Expr { * \brief The constructor * \param value The initial value of the reference. */ - TVM_DLL explicit RefCreate(Expr value); + TVM_DLL explicit RefCreate(Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); }; @@ -534,7 +535,7 @@ class RefRead : public Expr { * \brief The constructor * \param ref The reference where to read data. */ - TVM_DLL explicit RefRead(Expr ref); + TVM_DLL explicit RefRead(Expr ref, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); }; @@ -578,7 +579,7 @@ class RefWrite : public Expr { * \param ref The reference where data is write to. * \param value The value to write. 
*/ - TVM_DLL RefWrite(Expr ref, Expr value); + TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); }; diff --git a/src/ir/span.cc b/src/ir/span.cc index 80c15d8a7676..29732c476714 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -64,11 +64,22 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) Span::Span(SourceName source, int line, int column, int end_line, int end_column) { auto n = make_object(); n->source = std::move(source); - n->line = end_line; - n->column = end_column; + n->line = line; + n->column = column; + n->end_line = end_line; + n->end_column = end_column; data_ = std::move(n); } +Span Span::Merge(const Span& other) { + CHECK((*this)->source == other->source); + return Span((*this)->source, + std::min((*this)->line, other->line), + std::min((*this)->column, other->column), + std::max((*this)->end_line, other->end_line), + std::max((*this)->end_column, other->end_column)); +} + TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_GLOBAL("ir.Span") diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 1d9e3cef12b7..237cb35d8455 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,9 +30,10 @@ namespace relay { using tvm::ReprPrinter; using namespace tvm::runtime; -Constant::Constant(runtime::NDArray data) { +Constant::Constant(runtime::NDArray data, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); + n->span = std::move(span); data_ = std::move(n); } @@ -63,9 +64,10 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } -Tuple::Tuple(tvm::Array fields) { +Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); + n->span = std::move(span); data_ = std::move(n); } @@ -81,10 +83,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Tuple(" << node->fields << ")"; }); -Var::Var(Id vid, Type type_annotation) { +Var::Var(Id vid, Type type_annotation, Span span) { 
ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); + n->span = std::move(span); data_ = std::move(n); } @@ -105,12 +108,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { +Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span span) { ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); + n->span = std::move(span); data_ = std::move(n); } @@ -128,11 +132,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->type_args << ")"; }); -Let::Let(Var var, Expr value, Expr body) { +Let::Let(Var var, Expr value, Expr body, Span span) { ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); + n->span = std::move(span); data_ = std::move(n); } @@ -148,11 +153,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; }); -If::If(Expr cond, Expr true_branch, Expr false_branch) { +If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr n = make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); + n->span = std::move(span); data_ = std::move(n); } @@ -170,10 +176,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->false_branch << ")"; }); -TupleGetItem::TupleGetItem(Expr tuple, int index) { +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; + n->span = std::move(span); data_ = std::move(n); } @@ -189,9 +196,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); -RefCreate::RefCreate(Expr value) { +RefCreate::RefCreate(Expr value, Span span) { 
ObjectPtr n = make_object(); n->value = std::move(value); + n->span = std::move(span); data_ = std::move(n); } @@ -207,9 +215,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefCreateNode(" << node->value << ")"; }); -RefRead::RefRead(Expr ref) { +RefRead::RefRead(Expr ref, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); + n->span = std::move(span); data_ = std::move(n); } @@ -223,10 +232,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefReadNode(" << node->ref << ")"; }); -RefWrite::RefWrite(Expr ref, Expr value) { +RefWrite::RefWrite(Expr ref, Expr value, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); + n->span = std::move(span); data_ = std::move(n); } From b7e7fc5c3f6fd094a085ba716d941076291d23c1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 02:41:13 -0700 Subject: [PATCH 17/48] Start passing spans --- src/parser/parser.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 8a793c02fb6c..e1a404a28664 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1201,14 +1201,14 @@ class Parser { case TokenType::Float: { Consume(next->token_type); auto number = NumberToNDArray(next); - Expr e = Constant(number); + Expr e = Constant(number, next->span); return e; } case TokenType::Boolean: { Consume(TokenType::Boolean); int value = Downcast(next->data); auto boolean = BooleanToNDarray(value); - Expr e = Constant(boolean); + Expr e = Constant(boolean, next->span); return e; } // Parse a local of the form `%x`. 
@@ -1222,6 +1222,7 @@ class Parser { Consume(TokenType::Global); auto global = global_names.Get(string); if (!global) { + // TODO(@jroesch): fix global's needing span information auto global_var = GlobalVar(string); global_names.Add(string, global_var); return Expr(global_var); From 486f779d814d28c3bebecae7e87a9f2674d7832b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 21:48:15 -0700 Subject: [PATCH 18/48] Pass spans around visitors --- include/tvm/relay/adt.h | 2 +- include/tvm/relay/expr_functor.h | 2 +- include/tvm/relay/function.h | 2 +- python/tvm/parser/__init__.py | 3 +- python/tvm/relay/__init__.py | 6 +-- src/parser/parser.cc | 4 +- src/relay/ir/adt.cc | 3 +- src/relay/ir/expr_functor.cc | 69 +++++++++++++++++++--------- src/relay/ir/function.cc | 3 +- tests/python/relay/test_ir_parser.py | 4 +- 10 files changed, 61 insertions(+), 37 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b2164ba8c1f7..4f58e957aaae 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -310,7 +310,7 @@ class Match : public Expr { * \param clauses The clauses for matching. * \param complete Indicate if this match is complete. 
*/ - TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true); + TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); }; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1189643c8181..c6eee15882fc 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -164,7 +164,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); - + virtual void VisitSpan(const Span& span); protected: // Internal visiting counter std::unordered_map visit_counter_; diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d52a66cdadeb..1ae1cbd6ceae 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -112,7 +112,7 @@ class Function : public BaseFunc { * \param attrs Additional function attributes. */ TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, - tvm::DictAttrs attrs = NullValue()); + tvm::DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 696af362e03f..8001cd416781 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -24,5 +24,4 @@ def parse_expr(source): return _ffi_api.ParseExpr("string", source) def fromtext(source, source_name="from_string"): - # TODO(@tqchen): currently we have to invoke `str` which dramatically reduces performance. 
- return parse(str(source), str(source_name)) + return parse(source, source_name) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index e3909d9d6378..cd96ecc7ee33 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -29,7 +29,6 @@ from . import prelude from . import loops from . import scope_builder -from . import parser from . import transform from . import analysis @@ -132,12 +131,9 @@ # Prelude Prelude = prelude.Prelude -# Scope builder +# Scope Builder ScopeBuilder = scope_builder.ScopeBuilder -# Parser -fromtext = parser.fromtext - # Param Serialization save_param_dict = param_dict.save_param_dict load_param_dict = param_dict.load_param_dict diff --git a/src/parser/parser.cc b/src/parser/parser.cc index e1a404a28664..b5f8e722c366 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1492,12 +1492,12 @@ Expr ParseExpr(std::string file_name, std::string file_content) { } TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](std::string file_name, std::string file_content) { + .set_body_typed([](tvm::String file_name, tvm::String file_content) { return ParseModule(file_name, file_content); }); TVM_REGISTER_GLOBAL("parser.ParseExpr") - .set_body_typed([](std::string file_name, std::string file_content) { + .set_body_typed([](tvm::String file_name, tvm::String file_content) { return ParseExpr(file_name, file_content); }); diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index d808351e841c..ba9743cc35bf 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -116,11 +116,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; }); -Match::Match(Expr data, tvm::Array clauses, bool complete) { +Match::Match(Expr data, tvm::Array clauses, bool complete, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); n->clauses = std::move(clauses); n->complete = complete; + n->span = std::move(span); data_ = std::move(n); 
} diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index ad15453e3058..c39b7f76818a 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -199,7 +199,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { if (op->type_annotation.defined()) { auto type = this->VisitType(op->type_annotation); if (!op->type_annotation.same_as(type)) { - return Var(op->vid, type); + return Var(op->vid, type, op->span); } } // default case return self. @@ -224,7 +224,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { if (all_fields_unchanged) { return GetRef(op); } else { - return Tuple(fields); + return Tuple(fields, op->span); } } @@ -253,7 +253,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { body.same_as(op->body)) { return GetRef(op); } else { - return Function(params, body, ret_type, ty_params, op->attrs); + return Function(params, body, ret_type, ty_params, op->attrs, op->span); } } @@ -278,7 +278,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { if (unchanged) { return GetRef(call_node); } else { - return Call(new_op, call_args, call_node->attrs, ty_args); + return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); } } @@ -290,7 +290,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return Let(var, value, body); + return Let(var, value, body, op->span); } } @@ -302,16 +302,16 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { op->false_branch.same_as(false_b)) { return GetRef(op); } else { - return If(guard, true_b, false_b); + return If(guard, true_b, false_b, op->span); } } -Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { - auto t = this->Mutate(g->tuple); - if (g->tuple == t) { - return GetRef(g); +Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { + auto t = this->Mutate(get_item->tuple); + if (get_item->tuple == t) { + return 
GetRef(get_item); } else { - return TupleGetItem(t, g->index); + return TupleGetItem(t, get_item->index, get_item->span); } } @@ -320,7 +320,7 @@ Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return RefCreate(value); + return RefCreate(value, op->span); } } @@ -329,7 +329,7 @@ Expr ExprMutator::VisitExpr_(const RefReadNode* op) { if (ref.same_as(op->ref)) { return GetRef(op); } else { - return RefRead(ref); + return RefRead(ref, op->span); } } @@ -339,7 +339,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { if (ref.same_as(op->ref) && value.same_as(op->value)) { return GetRef(op); } else { - return RefWrite(ref, value); + return RefWrite(ref, value, op->span); } } @@ -355,10 +355,11 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { } Expr data = Mutate(m->data); unchanged &= data.same_as(m->data); + if (unchanged) { return GetRef(m); } - return Match(data, clauses, m->complete); + return Match(data, clauses, m->complete, m->span); } Clause ExprMutator::VisitClause(const Clause& c) { @@ -386,22 +387,29 @@ void ExprVisitor::VisitExpr(const Expr& expr) { } void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); if (op->type_annotation.defined()) { this->VisitType(op->type_annotation); } } -void ExprVisitor::VisitExpr_(const GlobalVarNode* op) {} +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); +} -void ExprVisitor::VisitExpr_(const ConstantNode* op) {} +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); +} void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); for (auto field : op->fields) { this->VisitExpr(field); } } void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); for (auto param : op->params) { this->VisitExpr(param); } @@ -410,6 +418,7 @@ void ExprVisitor::VisitExpr_(const FunctionNode* op) { } void 
ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->op); for (auto ty_arg : op->type_args) { @@ -422,31 +431,45 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { } void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->value); this->VisitExpr(op->var); this->VisitExpr(op->body); } void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->cond); this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch); } -void ExprVisitor::VisitExpr_(const OpNode* op) { return; } +void ExprVisitor::VisitExpr_(const OpNode* op) { + return; +} -void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); +} -void ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } +void ExprVisitor::VisitExpr_(const RefCreateNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); } +void ExprVisitor::VisitExpr_(const RefReadNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->ref); +} void ExprVisitor::VisitExpr_(const RefWriteNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->ref); this->VisitExpr(op->value); } void ExprVisitor::VisitExpr_(const ConstructorNode* op) { + // TODO(@jroesch): visit spans for (const Type& t : op->inputs) { this->VisitType(t); } @@ -454,6 +477,7 @@ void ExprVisitor::VisitExpr_(const ConstructorNode* op) { } void ExprVisitor::VisitExpr_(const MatchNode* op) { + this->VisitSpan(op->span); this->VisitExpr(op->data); for (const Clause& c : op->clauses) { this->VisitClause(c); @@ -461,6 +485,7 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) { } void ExprVisitor::VisitClause(const Clause& op) { + // TODO(@jroesch): visit spans 
this->VisitPattern(op->lhs); this->VisitExpr(op->rhs); } @@ -469,6 +494,8 @@ void ExprVisitor::VisitPattern(const Pattern& p) { return; } void ExprVisitor::VisitType(const Type& t) { return; } +void ExprVisitor::VisitSpan(const Span& span) { return; } + // visitor to implement apply class ExprApplyVisit : public ExprVisitor { public: diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 5312e6d48447..1439e8b59cf0 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -27,7 +27,7 @@ namespace tvm { namespace relay { Function::Function(tvm::Array params, Expr body, Type ret_type, - tvm::Array type_params, DictAttrs attrs) { + tvm::Array type_params, DictAttrs attrs, Span span) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -36,6 +36,7 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->attrs = std::move(attrs); + n->span = std::move(span); data_ = std::move(n); } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7ffe85c5d5df..b3adbf5d9e56 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -76,7 +76,7 @@ def graph_equal(lhs, rhs): def roundtrip_expr(expr): text = tvm.relay.Expr.astext(expr, show_meta_data=False) - x = tvm.parser.parse_expr(str(text)) + x = tvm.parser.parse_expr(text) assert_graph_equal(x, expr) def roundtrip(expr): @@ -931,7 +931,7 @@ def test_resnet_inlined_params(): print("here") mod = inline_params(mod, params) print("here") - text = str(mod.astext()) + text = mod.astext() print("here") parsed_mod = parse_module(text) print("here") From 4daa1fec050fba90fc564329e4b81df33e7d147c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 22:05:48 -0700 Subject: [PATCH 19/48] Format --- include/tvm/ir/span.h | 6 +-- include/tvm/parser/source_map.h | 7 +-- include/tvm/relay/expr.h | 2 +- 
include/tvm/relay/expr_functor.h | 1 + src/ir/span.cc | 13 +++--- src/parser/diagnostic.h | 19 ++++---- src/parser/meta_ref.cc | 73 +++++++++++++++---------------- src/parser/meta_ref.h | 3 +- src/parser/parser.cc | 68 ++++++++++++++-------------- src/parser/source_map.cc | 23 +++++----- src/parser/token.h | 10 ++--- src/parser/tokenizer.h | 46 +++++++++++-------- src/printer/relay_text_printer.cc | 2 +- src/printer/text_printer.h | 5 ++- src/relay/ir/expr_functor.cc | 15 +++---- 15 files changed, 147 insertions(+), 146 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 494cdbf3a956..4f1006ebcb8a 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -100,10 +100,8 @@ class SpanNode : public Object { } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return equal(source, other->source) && - equal(line, other->line) && - equal(column, other->column) && - equal(end_line, other->end_line) && + return equal(source, other->source) && equal(line, other->line) && + equal(column, other->column) && equal(end_line, other->end_line) && equal(end_column, other->end_column); } diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index b121b19df012..00a6c09ebaf0 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -23,9 +23,9 @@ * \file source_map.h * \brief A map from source names to source code. */ +#include #include #include -#include #include #include @@ -33,7 +33,6 @@ namespace tvm { namespace parser { - /*! \brief A program source in any language. 
* * Could represent the source from an ML framework or the internal @@ -79,9 +78,7 @@ class SourceMapNode : public Object { Map source_map; // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source_map", &source_map); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { return equal(source_map, other->source_map); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2b5edfb248f5..aee2016d669c 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -204,7 +204,7 @@ class Var : public Expr { * \param type_annotation The type annotation of a variable. */ TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span()) - : Var(Id(name_hint), type_annotation, span) {} + : Var(Id(name_hint), type_annotation, span) {} /*! * \brief The constructor diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index c6eee15882fc..c3d2f724b736 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -165,6 +165,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); virtual void VisitSpan(const Span& span); + protected: // Internal visiting counter std::unordered_map visit_counter_; diff --git a/src/ir/span.cc b/src/ir/span.cc index 29732c476714..2a2601c3f3df 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -73,17 +73,16 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column Span Span::Merge(const Span& other) { CHECK((*this)->source == other->source); - return Span((*this)->source, - std::min((*this)->line, other->line), - std::min((*this)->column, other->column), - std::max((*this)->end_line, other->end_line), - std::max((*this)->end_column, other->end_column)); + return Span((*this)->source, std::min((*this)->line, other->line), + 
std::min((*this)->column, other->column), + std::max((*this)->end_line, other->end_line), + std::max((*this)->end_column, other->end_column)); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span") -.set_body_typed([](SourceName source, int line, int column, int end_line, int end_column) { +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int column, + int end_line, int end_column) { return Span(source, line, column, end_line, end_column); }); diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 7dea00ec2948..881a1903481b 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -35,7 +35,6 @@ #include #include - #include #include #include @@ -64,10 +63,12 @@ struct Diagnostic { /*! \brief A diagnostic for a single character token. */ Diagnostic(int line, int column, const std::string& message) - : level(DiagnosticLevel::Error), span(SourceName(), line, column, line, column + 1), message(message) {} + : level(DiagnosticLevel::Error), + span(SourceName(), line, column, line, column + 1), + message(message) {} Diagnostic(DiagnosticLevel level, Span span, const std::string& message) - : level(level), span(span), message(message) {} + : level(level), span(span), message(message) {} }; /*! @@ -101,7 +102,7 @@ struct DiagnosticBuilder { /*! \brief The column number. */ int column; - /*! \brief The line number. */ + /*! \brief The line number. */ int end_line; /*! \brief The column number. 
*/ @@ -115,11 +116,13 @@ struct DiagnosticBuilder { DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), line(0), column(0) {} DiagnosticBuilder(const DiagnosticBuilder& builder) - : level(builder.level), source_name(builder.source_name), line(builder.line), column(builder.column) {} + : level(builder.level), + source_name(builder.source_name), + line(builder.line), + column(builder.column) {} DiagnosticBuilder(DiagnosticLevel level, SourceName source_name, int line, int column) - : level(level), source_name(source_name), line(line), column(column) {} - DiagnosticBuilder(DiagnosticLevel level, Span span) - : level(level), span(span) {} + : level(level), source_name(source_name), line(line), column(column) {} + DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} operator Diagnostic() { auto span = Span(this->source_name, this->line, this->column, this->end_line, this->end_column); diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index 7c47d8635864..0bedf8a353fa 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -22,26 +22,26 @@ * \brief An operator which allows forward referencing a yet-to-be parsed meta table reference. 
*/ -#include +#include "./meta_ref.h" + +#include #include -#include +#include #include -#include - -#include "./meta_ref.h" +#include namespace tvm { namespace parser { -using tvm::transform::PassContext; using tvm::relay::transform::CreateFunctionPass; +using tvm::transform::PassContext; TVM_REGISTER_NODE_TYPE(MetaRefAttrs); bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - LOG(FATAL) << "need to expand before type checking"; - return true; + const TypeReporter& reporter) { + LOG(FATAL) << "need to expand before type checking"; + return true; } RELAY_REGISTER_OP("parser.MetaRef") @@ -54,14 +54,13 @@ RELAY_REGISTER_OP("parser.MetaRef") .set_attr("TNonComputational", true); Expr MetaRef(std::string type_key, uint64_t node_index) { - static const Op& op = Op::Get("parser.MetaRef"); - auto attrs = make_object(); - attrs->node_type_key = tvm::String(type_key); - attrs->node_index = node_index; - return Call(op, {}, Attrs(attrs), {}); + static const Op& op = Op::Get("parser.MetaRef"); + auto attrs = make_object(); + attrs->node_type_key = tvm::String(type_key); + attrs->node_index = node_index; + return Call(op, {}, Attrs(attrs), {}); } - // class MetaRefAttrExpander : AttrFunctor { // ObjectRef VisitAttrDefault_(const Object* node) final { @@ -69,36 +68,36 @@ Expr MetaRef(std::string type_key, uint64_t node_index) { // } struct MetaRefExpander : public ExprMutator { - MetaTable table; - - MetaRefExpander(const MetaTable& table) : table(table) {} - - Expr VisitExpr_(const CallNode* call) final { - if (auto op_node = call->op.as()) { - if (op_node->name == "parser.MetaRef") { - auto meta_attrs = call->attrs.as(); - CHECK(meta_attrs) << "an internal error has occurred"; - auto nodes = table.at(meta_attrs->node_type_key); - CHECK_LT(meta_attrs->node_index, nodes.size()); - return Downcast(nodes[meta_attrs->node_index]); - } - } - - return ExprMutator::VisitExpr_(call); + MetaTable table; + + MetaRefExpander(const 
MetaTable& table) : table(table) {} + + Expr VisitExpr_(const CallNode* call) final { + if (auto op_node = call->op.as()) { + if (op_node->name == "parser.MetaRef") { + auto meta_attrs = call->attrs.as(); + CHECK(meta_attrs) << "an internal error has occurred"; + auto nodes = table.at(meta_attrs->node_type_key); + CHECK_LT(meta_attrs->node_index, nodes.size()); + return Downcast(nodes[meta_attrs->node_index]); + } } + + return ExprMutator::VisitExpr_(call); + } }; Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func) { - MetaRefExpander expander(meta_table); - return Downcast(expander.VisitExpr(func)); + MetaRefExpander expander(meta_table); + return Downcast(expander.VisitExpr(func)); } IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { - auto pass = CreateFunctionPass([&](Function func, IRModule module, PassContext ctx) { - return ExpandMetaRefs(meta_table, func); - }, 1337, "ExpandMetaRefs", {}); + auto pass = CreateFunctionPass([&](Function func, IRModule module, + PassContext ctx) { return ExpandMetaRefs(meta_table, func); }, + 1337, "ExpandMetaRefs", {}); - return pass(mod, PassContext::Create()); + return pass(mod, PassContext::Create()); } } // namespace parser diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index c3bccce355e5..a51985ceef28 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -25,7 +25,6 @@ #ifndef TVM_PARSER_META_REF_H_ #define TVM_PARSER_META_REF_H_ - #include #include @@ -36,7 +35,7 @@ namespace parser { using namespace relay; -using MetaTable = Map>; +using MetaTable = Map>; /*! * \brief Options for allocating storage. diff --git a/src/parser/parser.cc b/src/parser/parser.cc index b5f8e722c366..05e518226441 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -31,8 +31,8 @@ #include -#include "./meta_ref.h" #include "./diagnostic.h" +#include "./meta_ref.h" #include "./op_table.h" #include "./tokenizer.h" @@ -245,8 +245,14 @@ class Parser { /*! 
\brief The set of expression scopes used for lexical scope. */ ScopeStack expr_scopes; - Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, OperatorTable op_table, Source source) - : diag_ctx(ctx), source_name(source_name), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {} + Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, + OperatorTable op_table, Source source) + : diag_ctx(ctx), + source_name(source_name), + pos(0), + tokens(tokens), + op_table(op_table), + ignore_whitespace(true) {} /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ @@ -294,9 +300,9 @@ class Parser { */ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { - this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, tokens[pos]->span) - << "expected a " << Pretty(token_type) << " found " << Pretty(Peek()->token_type)); + this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tokens[pos]->span) + << "expected a " << Pretty(token_type) << " found " + << Pretty(Peek()->token_type)); } pos++; } @@ -375,8 +381,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx->Emit( - {DiagnosticLevel::Error, local->span, "this local variable has not been previously declared"}); + diag_ctx->Emit({DiagnosticLevel::Error, local->span, + "this local variable has not been previously declared"}); } return var; } @@ -541,9 +547,9 @@ class Parser { return elements; } else { auto next = Peek(); - this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, next->span) - << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type)); + this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, next->span) + << "expected a " << Pretty(stop) << " 
found " + << Pretty(next->token_type)); return Array(nullptr); } } @@ -577,20 +583,18 @@ class Parser { } /*! \brief Parse the semantic versioning header. */ - SemVer ParseSemVer(bool required=true) { + SemVer ParseSemVer(bool required = true) { if (Peek()->token_type == TokenType::Version) { auto version = Match(TokenType::Version); // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { - this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, version->span) - << "invalid semantic version `" << version.ToString() << "`"); + this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, version->span) + << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { - this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, Peek()->span) - << "expected text format semantic version " - << "you can annotate it as #[version = \"0.0.5\"]"); + this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, Peek()->span) + << "expected text format semantic version " + << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } @@ -619,8 +623,8 @@ class Parser { Consume(TokenType::Extern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { - diag_ctx->Emit( - {DiagnosticLevel::Error, next->span, "an external type may not have any constructors"}); + diag_ctx->Emit({DiagnosticLevel::Error, next->span, + "an external type may not have any constructors"}); } defs.types.push_back(type_def); } @@ -1089,9 +1093,8 @@ class Parser { case TokenType::StringLiteral: return Match(next->token_type)->data; case TokenType::LSquare: { - return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - return ParseAttributeValue(); - }); + return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + [&]() { return ParseAttributeValue(); }); } default: return ParseAtomicExpr(); @@ -1185,9 +1188,9 @@ class Parser { try { return 
Op::Get(op_name); } catch (dmlc::Error e) { - this->diag_ctx->Emit( - DiagnosticBuilder(DiagnosticLevel::Error, tok->span) - << "operator `" << op_name << "` not found, perhaps you forgot to register it?"); + this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) + << "operator `" << op_name + << "` not found, perhaps you forgot to register it?"); return Expr(); } } @@ -1291,9 +1294,9 @@ class Parser { } } default: { - this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, next->span) - << "expected an expression found " << Pretty(next->token_type)); + this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, next->span) + << "expected an expression found " + << Pretty(next->token_type)); return Expr(); } } @@ -1413,9 +1416,8 @@ class Parser { if (WhenMatch(TokenType::Underscore)) { return IncompleteType(); } else { - this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, tok->span) - << "failed to parse type found " << tok); + this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) + << "failed to parse type found " << tok); return Type(); } } diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index f70df420ad9f..fe9587cd4ed7 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -49,17 +49,16 @@ Source::Source(const std::string& source) : source(source) { line_map.back().second = length; } - /*! \brief Generate an error message at a specific line and column with the - * annotated message. - * - * The error is written directly to the `out` std::ostream. - * - * \param out The output ostream. - * \param line The line at which to report a diagnostic. - * \param line The column at which to report a diagnostic. - * \param msg The message to attach. - */ + * annotated message. + * + * The error is written directly to the `out` std::ostream. + * + * \param out The output ostream. + * \param line The line at which to report a diagnostic. 
+ * \param line The column at which to report a diagnostic. + * \param msg The message to attach. + */ void Source::ReportAt(std::ostream& out, int line, int column, const std::string& msg) const { CHECK(line - 1 <= static_cast(line_map.size())) << "requested line: " << (line - 1) << "line_map size: " << line_map.size() @@ -104,5 +103,5 @@ SourceMap::SourceMap(Map source_map) { data_ = std::move(n); } -} // namespace parser -} // namespace tvm +} // namespace parser +} // namespace tvm diff --git a/src/parser/token.h b/src/parser/token.h index 4970667eb8bd..480872956b68 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -341,8 +341,8 @@ class TokenNode : public Object { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Token(span=" << node->span - << ", token_type=" << ToString(node->token_type) << ", data=" << node->data << ")"; + p->stream << "Token(span=" << node->span << ", token_type=" << ToString(node->token_type) + << ", data=" << node->data << ")"; }); TVM_REGISTER_NODE_TYPE(TokenNode); @@ -372,9 +372,9 @@ int64_t Token::ToNumber() const { return Downcast(this->operator-> std::string Token::ToString() const { return Downcast(this->operator->()->data); } - Map> Token::ToMetadata() const { - return Downcast>>(this->operator->()->data); - } +Map> Token::ToMetadata() const { + return Downcast>>(this->operator->()->data); +} } // namespace parser } // namespace tvm diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 2ab80807405a..43f23d231077 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -24,17 +24,17 @@ #ifndef TVM_PARSER_TOKENIZER_H_ #define TVM_PARSER_TOKENIZER_H_ +#include #include #include -#include #include #include #include #include -#include "./token.h" #include "./meta_ref.h" +#include "./token.h" namespace tvm { namespace parser { @@ -42,17 +42,14 @@ namespace parser { using namespace runtime; // trim from 
start (in place) -static inline void ltrim(std::string &s) { - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { - return !std::isspace(ch); - })); +static inline void ltrim(std::string& s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); })); } // trim from end (in place) -static inline void rtrim(std::string &s) { - s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { - return !std::isspace(ch); - }).base(), s.end()); +static inline void rtrim(std::string& s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(), + s.end()); } bool IsDigit(char c) { return '0' <= c && c <= '9'; } @@ -74,7 +71,7 @@ static std::unordered_map KEYWORD_TABLE = { {"match", TokenType::Match}, {"extern", TokenType::Extern}}; struct Tokenizer { - DiagnosticContext *diag_ctx; + DiagnosticContext* diag_ctx; const SourceName& source_name; size_t pos; @@ -104,7 +101,8 @@ struct Tokenizer { } Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) { - auto span = Span(this->source_name, this->line, this->col, this->line + lines, this->col + cols); + auto span = + Span(this->source_name, this->line, this->col, this->line + lines, this->col + cols); return Token(span, token_type, data); } @@ -269,7 +267,8 @@ struct Tokenizer { ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); auto span = SpanFrom(line, column); return Token(span, TokenType::Metadata, metadata_map); - } if (attribute.rfind("version", 0) == 0) { + } + if (attribute.rfind("version", 0) == 0) { std::string version = attribute.substr(attribute.find("=") + 1); ltrim(version); rtrim(version); @@ -278,15 +277,15 @@ struct Tokenizer { } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? 
this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) << "unsupported attribute " << attribute); return Token(); } } else { this->diag_ctx->EmitFatal( DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) - << "`#` denotes the start of an attribute can only be followed by `[`" - << " found `" << Peek() << "`"); + << "`#` denotes the start of an attribute can only be followed by `[`" + << " found `" << Peek() << "`"); return Token(); } } @@ -305,7 +304,7 @@ struct Tokenizer { return token; } else { this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), this->line, this->col) + DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), this->line, this->col) << "\\r carriage returns must be followed by a \\n in the TVM text format"); return Token(); } @@ -522,7 +521,15 @@ struct Tokenizer { this->tokens.push_back(NewToken(TokenType::EndOfFile)); } - explicit Tokenizer(DiagnosticContext *ctx, const SourceName& source_name, const std::string& source) : diag_ctx(ctx), source_name(source_name), pos(0), col(1), line(1), source(source), tokens() {} + explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name, + const std::string& source) + : diag_ctx(ctx), + source_name(source_name), + pos(0), + col(1), + line(1), + source(source), + tokens() {} }; std::vector Condense(const std::vector& tokens) { @@ -594,7 +601,8 @@ std::vector Condense(const std::vector& tokens) { return out; } -std::vector Tokenize(DiagnosticContext *ctx, const SourceName& source_name, const std::string& source) { +std::vector Tokenize(DiagnosticContext* ctx, const SourceName& source_name, + const std::string& source) { auto tokenizer = Tokenizer(ctx, source_name, source); tokenizer.Tokenize(); auto tokens = Condense(tokenizer.tokens); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 
25f027c3aa5c..90cf428f1ca1 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -39,8 +39,8 @@ #include #include "../ir/attr_functor.h" -#include "../relay/analysis/dependency_graph.h" #include "../parser/meta_ref.h" +#include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 867d48b0ef7b..716989fbbf75 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -388,8 +388,9 @@ class TextPrinter { doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); } else { doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() - << " * If you would like to see the full metadata section you can set the `show_meta_data`" << Doc::NewLine() - << " * option to `True` when invoking `astext`. " << Doc::NewLine() + << " * If you would like to see the full metadata section you can set the " + "`show_meta_data`" + << Doc::NewLine() << " * option to `True` when invoking `astext`. 
" << Doc::NewLine() << " */"; } } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c39b7f76818a..cbc41d225d4b 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -393,13 +393,9 @@ void ExprVisitor::VisitExpr_(const VarNode* op) { } } -void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } -void ExprVisitor::VisitExpr_(const ConstantNode* op) { - this->VisitSpan(op->span); -} +void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitSpan(op->span); @@ -444,9 +440,7 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->false_branch); } -void ExprVisitor::VisitExpr_(const OpNode* op) { - return; -} +void ExprVisitor::VisitExpr_(const OpNode* op) { return; } void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitSpan(op->span); @@ -455,7 +449,8 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitSpan(op->span); - this->VisitExpr(op->value); } + this->VisitExpr(op->value); +} void ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitSpan(op->span); From 07f87899c6d9265136e1a85204a6d730a7dc24fc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 22:16:24 -0700 Subject: [PATCH 20/48] Fix --- include/tvm/parser/source_map.h | 2 ++ src/parser/meta_ref.cc | 2 +- src/parser/meta_ref.h | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 00a6c09ebaf0..467616a83da5 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -29,6 +29,8 @@ #include #include +#include +#include namespace tvm { namespace parser { diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc 
index 0bedf8a353fa..f763e757c311 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -70,7 +70,7 @@ Expr MetaRef(std::string type_key, uint64_t node_index) { struct MetaRefExpander : public ExprMutator { MetaTable table; - MetaRefExpander(const MetaTable& table) : table(table) {} + explicit MetaRefExpander(const MetaTable& table) : table(table) {} Expr VisitExpr_(const CallNode* call) final { if (auto op_node = call->op.as()) { diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index a51985ceef28..a8960872e1eb 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,6 +26,7 @@ #define TVM_PARSER_META_REF_H_ #include +#include #include #include @@ -75,7 +76,7 @@ struct MetaRefAttrs : public tvm::AttrsNode { */ Expr MetaRef(std::string type_key, uint64_t node_index); -Function ExpandMetaRefs(const MetaTable& meta_table, const Function& mod); +relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& mod); IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); } // namespace parser From 1fcde6313d6fc2f605a3dc6768f4026be944be1e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 22:17:08 -0700 Subject: [PATCH 21/48] Fix --- src/parser/meta_ref.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index a8960872e1eb..40e3fdbb7a8b 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,8 +26,8 @@ #define TVM_PARSER_META_REF_H_ #include -#include #include +#include #include From 13a4e9642c2b4eff7934228f5fa52c61d07fb2f2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 22:28:13 -0700 Subject: [PATCH 22/48] disable reference lint from string helpers --- src/parser/tokenizer.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 43f23d231077..945c269a20eb 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h 
@@ -42,12 +42,12 @@ namespace parser { using namespace runtime; // trim from start (in place) -static inline void ltrim(std::string& s) { +static inline void ltrim(std::string& s) { // NOLINT(*) s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); })); } // trim from end (in place) -static inline void rtrim(std::string& s) { +static inline void rtrim(std::string& s) { // NOLINT(*) s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(), s.end()); } From 4314584eea49cdad81b5469776d7930a3959f8ec Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 29 Jul 2020 22:31:52 -0700 Subject: [PATCH 23/48] Fix tokenizer --- src/parser/tokenizer.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 945c269a20eb..2a2cf83cd4dd 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -42,12 +42,12 @@ namespace parser { using namespace runtime; // trim from start (in place) -static inline void ltrim(std::string& s) { // NOLINT(*) +static inline void ltrim(std::string& s) { // NOLINT(*) s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); })); } // trim from end (in place) -static inline void rtrim(std::string& s) { // NOLINT(*) +static inline void rtrim(std::string& s) { // NOLINT(*) s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(), s.end()); } From 46dbc418c198a6950c9e413b47ec9caf2ddc6779 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 11:24:52 -0700 Subject: [PATCH 24/48] Fix issue with empty metadata section --- src/parser/parser.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 05e518226441..470c2299b1cb 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -575,7 +575,7 @@ class Parser { auto mod = IRModule({}, types); for (auto func : 
defs.funcs) { - auto function = ExpandMetaRefs(metadata.value(), func.function); + auto function = ExpandMetaRefs(metadata, func.function); mod->Add(func.global, function); } @@ -1434,12 +1434,11 @@ class Parser { return res; } - // TODO(@jroesch): this is the final remaining feature. - Optional>> ParseMetadata() { + Map> ParseMetadata() { if (Peek()->token_type == TokenType::Metadata) { return Match(TokenType::Metadata).ToMetadata(); } else { - return Optional>>(); + return Map>(); } } From 9932b9d9c7dd573e37b666c85b6e40eac05e2966 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 11:57:28 -0700 Subject: [PATCH 25/48] Document new span fields and small tweaks --- include/tvm/parser/source_map.h | 7 ++++--- include/tvm/relay/expr.h | 13 +++++++++++-- src/parser/diagnostic.h | 2 +- src/parser/source_map.cc | 5 ++++- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 467616a83da5..9ca8a739fd81 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -60,11 +60,12 @@ struct Source { * The error is written directly to the `out` std::ostream. * * \param out The output ostream. - * \param line The line at which to report a diagnostic. - * \param line The column at which to report a diagnostic. + * \param span The span to report the error at. * \param msg The message to attach. + * + * TODO(@jroesch): replace the ostream with an interface for rendering errors. */ - TVM_DLL void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const; + TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const; }; /*! diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index aee2016d669c..2b714a388cac 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -88,6 +88,7 @@ class Constant : public Expr { /*! * \brief The constructor * \param data The data of the constant tensor. 
+ * \param span The source span of the expression. */ TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); @@ -134,6 +135,7 @@ class Tuple : public Expr { /*! * \brief The constructor * \param fields The fields of a tuple. + * \param span The source span of the expression. */ TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); @@ -202,6 +204,7 @@ class Var : public Expr { * \brief The constructor * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. + * \param span The source span of the expression. */ TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span()) : Var(Id(name_hint), type_annotation, span) {} @@ -210,6 +213,7 @@ class Var : public Expr { * \brief The constructor * \param vid The unique id of a variable. * \param type_annotation The type annotation of a variable. + * \param span The source span of the expression. */ TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); @@ -296,6 +300,7 @@ class Call : public Expr { * \param args The arguments of the call. * \param attrs The attributes of the call node. * \param type_args The type arguments passed to a polymorphic function. + * \param span The source span of the expression. */ TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), Array type_args = Array(), Span span = Span()); @@ -357,6 +362,7 @@ class Let : public Expr { * \param var The variable that is bound to. * \param value The value used to bind to the variable. * \param body The body of the let binding. + * \param span The source span of the expression. */ TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span()); @@ -417,6 +423,7 @@ class If : public Expr { * \param cond The condition of a if node. * \param true_branch The fall through branch * \param false_branch The branch for execution when condition is false. + * \param span The source span of the expression. 
*/ TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); @@ -458,6 +465,7 @@ class TupleGetItem : public Expr { * \brief The constructor * \param tuple The tuple to get an element from. * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. */ TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); @@ -496,6 +504,7 @@ class RefCreate : public Expr { /*! * \brief The constructor * \param value The initial value of the reference. + * \param span The source span of the expression. */ TVM_DLL explicit RefCreate(Expr value, Span span = Span()); @@ -534,6 +543,7 @@ class RefRead : public Expr { /*! * \brief The constructor * \param ref The reference where to read data. + * \param span The source span of the expression. */ TVM_DLL explicit RefRead(Expr ref, Span span = Span()); @@ -566,8 +576,6 @@ class RefWriteNode : public ExprNode { hash_reduce(value); } - TVM_DLL static RefWrite make(Expr ref, Expr value); - static constexpr const char* _type_key = "relay.RefWrite"; TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode); }; @@ -578,6 +586,7 @@ class RefWrite : public Expr { * \brief The constructor * \param ref The reference where data is write to. * \param value The value to write. + * \param span The source span of the expression. */ TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span()); diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 881a1903481b..4a0bf6936ef1 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -162,7 +162,7 @@ struct DiagnosticContext { // format errors. 
void Render(std::ostream& ostream) { for (auto diagnostic : diagnostics) { - source.ReportAt(ostream, diagnostic.span->line, diagnostic.span->column, diagnostic.message); + source.ReportAt(ostream, diagnostic.span, diagnostic.message); } if (diagnostics.size()) { diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index fe9587cd4ed7..cddf0c79a0e4 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -59,7 +59,10 @@ Source::Source(const std::string& source) : source(source) { * \param line The column at which to report a diagnostic. * \param msg The message to attach. */ -void Source::ReportAt(std::ostream& out, int line, int column, const std::string& msg) const { +void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { + int line = span->line; + int column= span->column; + CHECK(line - 1 <= static_cast(line_map.size())) << "requested line: " << (line - 1) << "line_map size: " << line_map.size() << "source: " << source; From bb8171cbdd8eeae0aa0538276c2d3ed1eb9483ff Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 12:47:47 -0700 Subject: [PATCH 26/48] Formatting --- include/tvm/relay/expr.h | 4 ---- src/parser/source_map.cc | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2b714a388cac..d0c9217958f0 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -190,10 +190,6 @@ class VarNode : public ExprNode { hash_reduce.FreeVarHashImpl(this); } - TVM_DLL static Var make(String name_hint, Type type_annotation); - - TVM_DLL static Var make(Id vid, Type type_annotation); - static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index cddf0c79a0e4..e77e53ecc1b4 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -61,7 +61,7 @@ Source::Source(const std::string& 
source) : source(source) { */ void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { int line = span->line; - int column= span->column; + int column = span->column; CHECK(line - 1 <= static_cast(line_map.size())) << "requested line: " << (line - 1) << "line_map size: " << line_map.size() From 597bf253336e1476e6a5c72bd4d5510d214613ef Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 13:02:45 -0700 Subject: [PATCH 27/48] Add span doc fields --- include/tvm/parser/source_map.h | 2 +- include/tvm/relay/adt.h | 1 + include/tvm/relay/function.h | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 9ca8a739fd81..a5ec3ef5e188 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -63,8 +63,8 @@ struct Source { * \param span The span to report the error at. * \param msg The message to attach. * - * TODO(@jroesch): replace the ostream with an interface for rendering errors. */ + // TODO(@jroesch): replace the ostream with an interface for rendering errors. TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const; }; diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4f58e957aaae..37182abb2681 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -309,6 +309,7 @@ class Match : public Expr { * \param data the input being deconstructed. * \param clauses The clauses for matching. * \param complete Indicate if this match is complete. + * \param span The span of the expression. */ TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true, Span span = Span()); diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 1ae1cbd6ceae..db973b91f92a 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -110,6 +110,7 @@ class Function : public BaseFunc { * \param ret_type The return type of the function. 
* \param ty_params The type parameters. * \param attrs Additional function attributes. + * \param span The span of the function. */ TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, tvm::DictAttrs attrs = NullValue(), Span span = Span()); From dc158c1927d1ec8af89aee389f178d912ec2697e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 13:07:36 -0700 Subject: [PATCH 28/48] Add format tweak --- include/tvm/parser/source_map.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index a5ec3ef5e188..98583ec549ba 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -64,7 +64,7 @@ struct Source { * \param msg The message to attach. * */ - // TODO(@jroesch): replace the ostream with an interface for rendering errors. + // TODO(@jroesch): replace the ostream with an interface for rendering errors. TVM_DLL void ReportAt(std::ostream& out, const Span& span, const std::string& msg) const; }; From 0e54e6f85bcbd2d6944cc1807130981c62a8d6c2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 16:42:10 -0700 Subject: [PATCH 29/48] Improve errors and fix the semantic version tags in Prelude --- .gitignore | 4 ---- include/tvm/parser/parser.h | 2 +- python/tvm/relay/std/core.rly | 3 ++- python/tvm/relay/std/prelude.rly | 2 +- src/ir/module.cc | 6 ++---- src/parser/parser.cc | 2 +- src/parser/source_map.cc | 1 + 7 files changed, 8 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 506e54d93067..66eb0cb4f866 100644 --- a/.gitignore +++ b/.gitignore @@ -230,7 +230,3 @@ conda/pkg # nix files .envrc *.nix - -# antlr files -*.tokens -*.interp diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 93803588383a..5c1239b1f59e 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -32,7 +32,7 @@ namespace tvm { namespace parser { -IRModule 
Parse(std::string file_name, std::string file_content); +IRModule ParseModule(std::string file_name, std::string file_content); } // namespace parser } // namespace tvm diff --git a/python/tvm/relay/std/core.rly b/python/tvm/relay/std/core.rly index 6a3facc3424c..f469491a56f1 100644 --- a/python/tvm/relay/std/core.rly +++ b/python/tvm/relay/std/core.rly @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 + +#[version = "0.0.5"] extern type Storage diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index fa05d1a7bd98..a22f46c6d18b 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 +#[version = "0.0.4"] // TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)? diff --git a/src/ir/module.cc b/src/ir/module.cc index 25ecab2455cb..1bcf7d928c51 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -31,6 +31,7 @@ // Rationale: We calls into relay's analysis module to verify correctness. 
#include #include +#include #include #include @@ -371,10 +372,7 @@ void IRModuleNode::ImportFromStd(const String& path) { std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { - auto* f = tvm::runtime::Registry::Get("relay.fromtext"); - CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; - IRModule mod = (*f)(text, source_path); - return mod; + return tvm::parser::ParseModule(source_path, text); } TVM_REGISTER_NODE_TYPE(IRModuleNode); diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 470c2299b1cb..375df1abc818 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -593,7 +593,7 @@ class Parser { } } else if (required) { this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, Peek()->span) - << "expected text format semantic version " + << "expected text format semantic version, found a " << PrettyPrint(Peek()) << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index e77e53ecc1b4..11789a5bf4d3 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -60,6 +60,7 @@ Source::Source(const std::string& source) : source(source) { * \param msg The message to attach. 
*/ void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { + DLOG(INFO) << "Source::ReportAt" << "span = " << span << "msg = " << msg; int line = span->line; int column = span->column; From 4557fede16e6534e0632277b521427f1cd8f7963 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 16:44:35 -0700 Subject: [PATCH 30/48] Update gradient.rly --- python/tvm/relay/std/gradient.rly | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly index ed81e4b2d454..7594f4ebc5f4 100644 --- a/python/tvm/relay/std/gradient.rly +++ b/python/tvm/relay/std/gradient.rly @@ -16,7 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -v0.0.4 + +#[version = "0.0.5"] /* * Store the Gradient Value of a Tensor of type T. From 7ef4e76c299d5d5113be50e8aa87181b75be86b3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 17:05:46 -0700 Subject: [PATCH 31/48] Clean up broken spans --- src/parser/diagnostic.h | 52 ++++++++++++++++++---------- src/parser/tokenizer.h | 14 +++++--- tests/python/relay/test_ir_parser.py | 8 ++--- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 4a0bf6936ef1..4208da52f000 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -52,6 +52,8 @@ enum DiagnosticLevel { Help, }; +struct DiagnosticBuilder; + /*! \brief A diagnostic message. */ struct Diagnostic { /*! \brief The level. */ @@ -69,6 +71,12 @@ struct Diagnostic { Diagnostic(DiagnosticLevel level, Span span, const std::string& message) : level(level), span(span), message(message) {} + + static DiagnosticBuilder Bug(Span span); + static DiagnosticBuilder Error(Span span); + static DiagnosticBuilder Warning(Span span); + static DiagnosticBuilder Note(Span span); + static DiagnosticBuilder Help(Span span); }; /*! 
@@ -96,37 +104,23 @@ struct DiagnosticBuilder { /*! \brief The span of the diagnostic. */ Span span; - /*! \brief The line number. */ - int line; - - /*! \brief The column number. */ - int column; - - /*! \brief The line number. */ - int end_line; - - /*! \brief The column number. */ - int end_column; - template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; return *this; } - DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), line(0), column(0) {} + DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), span(Span()) {} + DiagnosticBuilder(const DiagnosticBuilder& builder) : level(builder.level), source_name(builder.source_name), - line(builder.line), - column(builder.column) {} - DiagnosticBuilder(DiagnosticLevel level, SourceName source_name, int line, int column) - : level(level), source_name(source_name), line(line), column(column) {} + span(builder.span) {} + DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} operator Diagnostic() { - auto span = Span(this->source_name, this->line, this->column, this->end_line, this->end_column); - return Diagnostic(this->level, span, this->stream_.str()); + return Diagnostic(this->level, this->span, this->stream_.str()); } private: @@ -134,6 +128,26 @@ struct DiagnosticBuilder { friend struct Diagnostic; }; + DiagnosticBuilder Diagnostic::Bug(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Bug, span); + } + + DiagnosticBuilder Diagnostic::Error(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Error, span); + } + + DiagnosticBuilder Diagnostic::Warning(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Warning, span); + } + + DiagnosticBuilder Diagnostic::Note(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Note, span); + } + + DiagnosticBuilder Diagnostic::Help(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Note, span); + } + /*! 
\brief A diagnostic context for recording errors against a source file. * TODO(@jroesch): convert source map and improve in follow up PR, the parser * assumes a single global file for now. diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 2a2cf83cd4dd..d9ed4f308c96 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -276,14 +276,15 @@ struct Tokenizer { return Token(span, TokenType::Version, tvm::String(version)); } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? - this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) - << "unsupported attribute " << attribute); + auto span = SpanFrom(line, column); + this->diag_ctx->EmitFatal(Diagnostic::Error(span) + << "unsupported attribute " << attribute); return Token(); } } else { + auto span = SpanFrom(line, column); this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), line, column) + Diagnostic::Error(span) << "`#` denotes the start of an attribute can only be followed by `[`" << " found `" << Peek() << "`"); return Token(); @@ -291,6 +292,8 @@ struct Tokenizer { } inline Token TokenizeOnce() { + int line = this->line; + int col = this->col; auto next = Peek(); DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { @@ -303,8 +306,9 @@ struct Tokenizer { auto token = NewToken(TokenType::Newline); return token; } else { + auto span = SpanFrom(line, col); this->diag_ctx->EmitFatal( - DiagnosticBuilder(DiagnosticLevel::Error, SourceName(), this->line, this->col) + Diagnostic::Error(span) << "\\r carriage returns must be followed by a \\n in the TVM text format"); return Token(); } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index b3adbf5d9e56..b3cedb4d623e 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -928,15 +928,11 @@ def inline_params(mod, params): def 
test_resnet_inlined_params(): mod, params = relay.testing.resnet.get_workload() - print("here") mod = inline_params(mod, params) - print("here") text = mod.astext() - print("here") parsed_mod = parse_module(text) - print("here") tvm.ir.assert_structural_equal(mod, parsed_mod) - print("here") if __name__ == "__main__": - test_resnet_inlined_params() + import sys + pytest.main(sys.argv) From 10f2531b665a866697cf5805b64c5edcefee32c7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 18:03:43 -0700 Subject: [PATCH 32/48] Clean up parser tests and turn on previously skipped tests --- python/tvm/error.py | 4 + src/parser/diagnostic.h | 4 +- src/parser/parser.cc | 26 +++++- tests/python/relay/test_ir_parser.py | 122 ++++++++++++--------------- 4 files changed, 83 insertions(+), 73 deletions(-) diff --git a/python/tvm/error.py b/python/tvm/error.py index b3502f6b0ead..398d22ada25c 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -121,3 +121,7 @@ class OpAttributeUnImplemented(OpError, NotImplementedError): "Attribute {} is not supported in operator {}".format( attr_name, op_name)) """ + +@register_error +class DiagnosticError(TVMError): + pass diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 4208da52f000..61543a7cb2f1 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -168,8 +168,6 @@ struct DiagnosticContext { void EmitFatal(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); Render(std::cout); - // TODO(@jroesch): throw exception which is caught at the pass boundary and then rendered. 
- LOG(FATAL) << "error occurred"; } // TODO(@jroesch): eventually modularize the rendering interface to provide control of how to @@ -180,7 +178,7 @@ struct DiagnosticContext { } if (diagnostics.size()) { - LOG(FATAL) << "parse error occurred"; + LOG(FATAL) << "DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output."; } } }; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 375df1abc818..21c69b3afe60 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -145,6 +145,7 @@ template struct InternTable { /*! \brief The internal table mapping strings to a unique allocation. */ std::unordered_map table; + DiagnosticContext *ctx; /*! \brief Add the unique allocation. */ void Add(const std::string& name, const T& t) { @@ -935,11 +936,21 @@ class Parser { Consume(TokenType::If); auto guard = Parens([&] { return ParseExpr(); }); - auto true_branch = Block([&] { return ParseExpr(); }); + auto true_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); Match(TokenType::Else); - auto false_branch = Block([&] { return ParseExpr(); }); + auto false_branch = Block([&] { + this->PushScope(); + auto expr = ParseExpr(); + this->PopScopes(1); + return expr; + }); return relay::If(guard, true_branch, false_branch); } @@ -1475,7 +1486,12 @@ IRModule ParseModule(std::string file_name, std::string file_content) { DiagnosticContext ctx(src); auto tokens = Tokenize(&ctx, src_name, file_content); Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content)); - return parser.ParseModule(); + auto mod = parser.ParseModule(); + // NB(@jroesch): it is very important that we render any errors before we procede + // if there were any errors which allow the parser to procede we must render them + // here. 
+ parser.diag_ctx->Render(std::cout); + return mod; } Expr ParseExpr(std::string file_name, std::string file_content) { @@ -1489,6 +1505,10 @@ Expr ParseExpr(std::string file_name, std::string file_content) { parser.PushScope(); auto expr = parser.ParseExpr(); parser.Match(TokenType::EndOfFile); + // NB(@jroesch): it is very important that we render any errors before we procede + // if there were any errors which allow the parser to procede we must render them + // here. + parser.diag_ctx->Render(std::cout); return expr; } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index b3cedb4d623e..2ec24cc50c25 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -22,9 +22,10 @@ from numpy import isclose from typing import Union from functools import wraps -raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError) -SEMVER = "v0.0.4" + + +SEMVER = "#[version = \"0.0.5\"]\n" BINARY_OPS = { "*": relay.multiply, @@ -79,6 +80,7 @@ def roundtrip_expr(expr): x = tvm.parser.parse_expr(text) assert_graph_equal(x, expr) +# Testing Utilities for expressions. def roundtrip(expr): x = tvm.parser.fromtext(expr.astext()) assert_graph_equal(x, expr) @@ -88,15 +90,15 @@ def parse_text(code): roundtrip_expr(expr) return expr - def parses_as(code, expr): # type: (str, relay.Expr) -> bool parsed = parse_text(code) result = graph_equal(parsed, expr) return result +# Testing Utilities for full modules. 
def parse_module(code): - mod = tvm.parser.parse(code) + mod = tvm.parser.parse(SEMVER + code) roundtrip(mod) return mod @@ -234,7 +236,7 @@ def test_vars(): def test_meta_ref(): meta_op = parse_text("meta[type_key][1337]") assert meta_op.attrs.node_type_key == "type_key" - assert meta_op.attrs.node_index == "1337" + assert meta_op.attrs.node_index == 1337 def test_let(): @@ -292,19 +294,17 @@ def test_graph(): ) -@raises_parse_error -def test_graph_wrong_order(): - parse_text("%1 = (); %1") +def test_graph_single(): + assert_parses_as("%1 = (); %1", relay.Tuple([])) - -@raises_parse_error def test_let_global_var(): - parse_text("let @x = 1; ()") + with pytest.raises(tvm.error.DiagnosticError): + parse_text("let @x = 1; ()") -@raises_parse_error def test_let_op(): - parse_text("let x = 1; ()") + with pytest.raises(tvm.error.DiagnosticError): + parse_text("let x = 1; ()") def test_tuple(): @@ -409,18 +409,18 @@ def test_ifelse(): ) -@raises_parse_error def test_ifelse_scope(): - parse_text( - """ - if (True) { - let %x = (); - () - } else { - %x - } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + if (True) { + let %x = (); + () + } else { + %x + } + """ + ) def test_call(): @@ -830,52 +830,51 @@ def @make_singleton(%%x: int32) -> List[int32] { ) -@raises_parse_error def test_duplicate_adt_defn(): - parse_module( - """ - %s + with pytest.raises(tvm.error.DiagnosticError): + parse_module( + """ + %s - type List[A] { - Cons(A, List[A]), - Nil, - } - """ % LIST_DEFN - ) + type List[A] { + Cons(A, List[A]), + Nil, + } + """ % LIST_DEFN + ) -@raises_parse_error def test_duplicate_adt_cons(): - parse_text( - """ - type Ayy { Lmao } - type Haha { Lmao } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + type Ayy { Lmao } + type Haha { Lmao } + """ + ) -@raises_parse_error def test_duplicate_adt_cons_defn(): - parse_text( - """ - type Ayy { Lmao } - type Lmao { Ayy } - """ - ) + with 
pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + type Ayy { Lmao } + type Lmao { Ayy } + """ + ) -@raises_parse_error def test_duplicate_global_var(): - parse_text( - """ - def @id[A](%x: A) -> A { x } - def @id[A](%x: A) -> A { x } - """ - ) + with pytest.raises(tvm.error.DiagnosticError): + parse_text( + """ + def @id[A](%x: A) -> A { x } + def @id[A](%x: A) -> A { x } + """ + ) def test_extern_adt_defn(): - # TODO(weberlo): update this test once extern is implemented mod = tvm.IRModule() extern_var = relay.GlobalTypeVar("T") @@ -890,21 +889,10 @@ def test_extern_adt_defn(): mod ) -@pytest.mark.skip("not yet tested on parser 2.0") def test_import_grad(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") -# hiearchy id, i.e parse nn.conv2d -# do with multiple levels -# -# call attributes not correctly parsing -# convert error from attribute construction to real error message -# lexing issue with projection of graph variables - -# def test_hierarchical_identifiers(): -# assert False - def test_resnet(): mod, params = relay.testing.resnet.get_workload() text = str(mod.astext()) From 16c92760065a937c449ba1837350a7e45e1cd937 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 22:06:42 -0700 Subject: [PATCH 33/48] Update errors to handle skipped cases --- src/parser/parser.cc | 46 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 21c69b3afe60..fa90d4e00f96 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -140,6 +140,10 @@ class ScopeStack { void PopStack() { this->scope_stack.pop_back(); } }; +struct DuplicateKeyError : public dmlc::Error { + DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {} +}; + /*! \brief A table of interning strings as global function and type names. 
*/ template struct InternTable { @@ -151,7 +155,7 @@ struct InternTable { void Add(const std::string& name, const T& t) { auto it = table.find(name); if (it != table.end()) { - LOG(FATAL) << "duplicate name"; + throw DuplicateKeyError("duplicate key name in intern table"); } else { table.insert({name, t}); } @@ -609,9 +613,18 @@ class Parser { switch (next->token_type) { case TokenType::Defn: { Consume(TokenType::Defn); - auto global_name = Match(TokenType::Global).ToString(); + auto global_tok = Match(TokenType::Global); + auto global_name = global_tok.ToString(); auto global = GlobalVar(global_name); - global_names.Add(global_name, global); + try { + global_names.Add(global_name, global); + } catch (DuplicateKeyError e) { + this->diag_ctx->Emit( + Diagnostic::Error(global_tok->span) + << "a function with the name " + << "`@" << global_name << "` " + << "was previously defined"); + } auto func = ParseFunctionDef(); defs.funcs.push_back(GlobalFunc(global, func)); continue; @@ -640,9 +653,19 @@ class Parser { // Match the `type` keyword. Match(TokenType::TypeDef); // Parse the type's identifier. - auto type_id = Match(TokenType::Identifier).ToString(); + auto type_tok = Match(TokenType::Identifier); + auto type_id = type_tok.ToString(); auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle); - type_names.Add(type_id, type_global); + + try { + type_names.Add(type_id, type_global); + } catch (DuplicateKeyError e) { + this->diag_ctx->Emit( + Diagnostic::Error(type_tok->span) + << "a type definition with the name " + << "`" << type_id << "` " + << "was previously defined"); + } Array generics; @@ -664,7 +687,8 @@ class Parser { ctors = ParseSequence( TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() { // First match the name of the constructor. - auto ctor_name = Match(TokenType::Identifier).ToString(); + auto ctor_tok = Match(TokenType::Identifier); + auto ctor_name = ctor_tok.ToString(); Constructor ctor; // Match the optional field list. 
@@ -679,7 +703,15 @@ class Parser { CHECK(ctor.defined()); - this->ctors.Add(ctor_name, ctor); + try { + this->ctors.Add(ctor_name, ctor); + } catch (DuplicateKeyError e) { + this->diag_ctx->EmitFatal( + Diagnostic::Error(ctor_tok->span) + << "a constructor with the name " + << "`" << ctor_name << "` " + << "was previously defined"); + } return ctor; }); From d521757ead3f2f0eae3695ac2952c2af5629193e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 22:09:48 -0700 Subject: [PATCH 34/48] Tweak --- src/parser/parser.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index fa90d4e00f96..4fc98562d5a6 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -141,7 +141,7 @@ class ScopeStack { }; struct DuplicateKeyError : public dmlc::Error { - DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {} + explicit DuplicateKeyError(const std::string& msg) : dmlc::Error(msg) {} }; /*! \brief A table of interning strings as global function and type names. 
*/ @@ -597,9 +597,10 @@ class Parser { << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { - this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, Peek()->span) - << "expected text format semantic version, found a " << PrettyPrint(Peek()) - << "you can annotate it as #[version = \"0.0.5\"]"); + this->diag_ctx->Emit( + Diagnostic::Error(Peek()->span) + << "expected text format semantic version, found a " << PrettyPrint(Peek()) + << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } From aaea54a9d601ccbe85c5af7861c1c1d6509ce13d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 22:12:43 -0700 Subject: [PATCH 35/48] Tweak --- src/parser/diagnostic.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 61543a7cb2f1..6d0df6e88e6e 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -178,7 +178,8 @@ struct DiagnosticContext { } if (diagnostics.size()) { - LOG(FATAL) << "DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output."; + LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " + << "emitted, please check diagnostic render for output."; } } }; From 99bee2a3258db3a9987c2b638fa3ef46d2000f54 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Jul 2020 22:14:54 -0700 Subject: [PATCH 36/48] Format --- src/ir/module.cc | 2 +- src/parser/diagnostic.h | 38 +++++++++++++++++--------------------- src/parser/parser.cc | 35 +++++++++++++++-------------------- src/parser/source_map.cc | 3 ++- src/parser/tokenizer.h | 11 +++++------ 5 files changed, 40 insertions(+), 49 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 1bcf7d928c51..b34740865fc6 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -29,9 +29,9 @@ // and are only used in minimum cases where they are clearly marked. 
// // Rationale: We calls into relay's analysis module to verify correctness. +#include #include #include -#include #include #include diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index 6d0df6e88e6e..e43c4b295501 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -113,40 +113,36 @@ struct DiagnosticBuilder { DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), span(Span()) {} DiagnosticBuilder(const DiagnosticBuilder& builder) - : level(builder.level), - source_name(builder.source_name), - span(builder.span) {} + : level(builder.level), source_name(builder.source_name), span(builder.span) {} DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} - operator Diagnostic() { - return Diagnostic(this->level, this->span, this->stream_.str()); - } + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } private: std::stringstream stream_; friend struct Diagnostic; }; - DiagnosticBuilder Diagnostic::Bug(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Bug, span); - } +DiagnosticBuilder Diagnostic::Bug(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Bug, span); +} - DiagnosticBuilder Diagnostic::Error(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Error, span); - } +DiagnosticBuilder Diagnostic::Error(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Error, span); +} - DiagnosticBuilder Diagnostic::Warning(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Warning, span); - } +DiagnosticBuilder Diagnostic::Warning(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Warning, span); +} - DiagnosticBuilder Diagnostic::Note(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Note, span); - } +DiagnosticBuilder Diagnostic::Note(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Note, span); +} - DiagnosticBuilder Diagnostic::Help(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Note, span); - } 
+DiagnosticBuilder Diagnostic::Help(Span span) { + return DiagnosticBuilder(DiagnosticLevel::Note, span); +} /*! \brief A diagnostic context for recording errors against a source file. * TODO(@jroesch): convert source map and improve in follow up PR, the parser diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 4fc98562d5a6..dcf5abc0ddbd 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -149,7 +149,7 @@ template struct InternTable { /*! \brief The internal table mapping strings to a unique allocation. */ std::unordered_map table; - DiagnosticContext *ctx; + DiagnosticContext* ctx; /*! \brief Add the unique allocation. */ void Add(const std::string& name, const T& t) { @@ -597,10 +597,10 @@ class Parser { << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { - this->diag_ctx->Emit( - Diagnostic::Error(Peek()->span) - << "expected text format semantic version, found a " << PrettyPrint(Peek()) - << "you can annotate it as #[version = \"0.0.5\"]"); + this->diag_ctx->Emit(Diagnostic::Error(Peek()->span) + << "expected text format semantic version, found a " + << PrettyPrint(Peek()) + << "you can annotate it as #[version = \"0.0.5\"]"); } return SemVer(0, 0, 5); } @@ -620,11 +620,9 @@ class Parser { try { global_names.Add(global_name, global); } catch (DuplicateKeyError e) { - this->diag_ctx->Emit( - Diagnostic::Error(global_tok->span) - << "a function with the name " - << "`@" << global_name << "` " - << "was previously defined"); + this->diag_ctx->Emit(Diagnostic::Error(global_tok->span) << "a function with the name " + << "`@" << global_name << "` " + << "was previously defined"); } auto func = ParseFunctionDef(); defs.funcs.push_back(GlobalFunc(global, func)); @@ -661,11 +659,9 @@ class Parser { try { type_names.Add(type_id, type_global); } catch (DuplicateKeyError e) { - this->diag_ctx->Emit( - Diagnostic::Error(type_tok->span) - << "a type definition with the name " - << "`" << type_id << "` " - << "was 
previously defined"); + this->diag_ctx->Emit(Diagnostic::Error(type_tok->span) << "a type definition with the name " + << "`" << type_id << "` " + << "was previously defined"); } Array generics; @@ -707,11 +703,10 @@ class Parser { try { this->ctors.Add(ctor_name, ctor); } catch (DuplicateKeyError e) { - this->diag_ctx->EmitFatal( - Diagnostic::Error(ctor_tok->span) - << "a constructor with the name " - << "`" << ctor_name << "` " - << "was previously defined"); + this->diag_ctx->EmitFatal(Diagnostic::Error(ctor_tok->span) + << "a constructor with the name " + << "`" << ctor_name << "` " + << "was previously defined"); } return ctor; diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 11789a5bf4d3..beb32da7126c 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -60,7 +60,8 @@ Source::Source(const std::string& source) : source(source) { * \param msg The message to attach. */ void Source::ReportAt(std::ostream& out, const Span& span, const std::string& msg) const { - DLOG(INFO) << "Source::ReportAt" << "span = " << span << "msg = " << msg; + DLOG(INFO) << "Source::ReportAt" + << "span = " << span << "msg = " << msg; int line = span->line; int column = span->column; diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index d9ed4f308c96..0456ece4e293 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -277,16 +277,15 @@ struct Tokenizer { } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? 
auto span = SpanFrom(line, column); - this->diag_ctx->EmitFatal(Diagnostic::Error(span) - << "unsupported attribute " << attribute); + this->diag_ctx->EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute); return Token(); } } else { auto span = SpanFrom(line, column); - this->diag_ctx->EmitFatal( - Diagnostic::Error(span) - << "`#` denotes the start of an attribute can only be followed by `[`" - << " found `" << Peek() << "`"); + this->diag_ctx + ->EmitFatal(Diagnostic::Error(span) + << "`#` denotes the start of an attribute can only be followed by `[`" + << " found `" << Peek() << "`"); return Token(); } } From 4510899b9c9c142cc40dfddcd1f2cf9128454d10 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Jul 2020 14:26:15 -0700 Subject: [PATCH 37/48] Fix some minor issues with ADT tests --- python/tvm/relay/std/prelude.rly | 6 ++---- src/parser/parser.cc | 10 ++++++---- src/printer/text_printer.cc | 2 +- src/printer/text_printer.h | 16 ++++++++++------ 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index a22f46c6d18b..17c91283f4d2 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -16,9 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#[version = "0.0.4"] - -// TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)? +#[version = "0.0.5"] def @id[A](%x: A) -> A { %x @@ -298,7 +296,7 @@ def @size[A](%t: Tree[A]) -> Tensor[(), int32] { * Takes a number n and a function f; returns a closure that takes an argument * and applies f n times to its argument. 
*/ -def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) { +def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> fn(A) -> A { if (%n == 0) { @id } else { diff --git a/src/parser/parser.cc b/src/parser/parser.cc index dcf5abc0ddbd..368cd0f1afd2 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1215,11 +1215,13 @@ class Parser { } // We need a zero-arity case for constructors. - if (expr.as()) { - return Expr(Call(expr, {})); - } else { - return expr; + if (auto ctor_node = expr.as()) { + if (ctor_node->inputs.size() == 0) { + return Expr(Call(expr, {})); + } } + + return expr; }); } diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 302c68c20c2a..1e882db1fd61 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -67,7 +67,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { String PrettyPrint(const ObjectRef& node) { Doc doc; - doc << TextPrinter(false, nullptr).PrintFinal(node); + doc << TextPrinter(false, nullptr, false).PrintFinal(node); return doc.str(); } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 716989fbbf75..19a66b996040 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -355,14 +355,19 @@ namespace tvm { class TextPrinter { public: explicit TextPrinter(bool show_meta_data, - const runtime::TypedPackedFunc& annotate) + const runtime::TypedPackedFunc& annotate, + bool show_warning = true) : show_meta_data_(show_meta_data), + show_warning_(show_warning), annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), tir_text_printer_(show_meta_data, &meta_) {} /*! \brief whether show meta data */ bool show_meta_data_; + /*! \brief whether show meta data */ + bool show_warning_; + /*! \brief meta data context */ TextMetaDataContext meta_; /*! 
\brief additional comment function */ @@ -386,13 +391,12 @@ class TextPrinter { doc << Doc::NewLine(); if (show_meta_data_) { doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); - } else { + } else if (show_warning_) { doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() - << " * If you would like to see the full metadata section you can set the " - "`show_meta_data`" - << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine() + << " * If you would like to see the full metadata section you can set the " << Doc::NewLine() + << " * option to `True` when invoking `astext`. " << Doc::NewLine() << " */"; - } + } } return doc; } From c9914409631dd9d0a191ae897087bce2bcd39880 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Jul 2020 14:26:25 -0700 Subject: [PATCH 38/48] Format --- src/printer/text_printer.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 19a66b996040..c01457745e97 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -393,10 +393,10 @@ class TextPrinter { doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); } else if (show_warning_) { doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() - << " * If you would like to see the full metadata section you can set the " << Doc::NewLine() - << " * option to `True` when invoking `astext`. " << Doc::NewLine() + << " * If you would like to see the full metadata section you can set the " + << Doc::NewLine() << " * option to `True` when invoking `astext`. 
" << Doc::NewLine() << " */"; - } + } } return doc; } From 50b517cc00c46150ee4224ce3f0f9cbe68f2324e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 3 Aug 2020 12:26:00 -0700 Subject: [PATCH 39/48] Fix path --- src/parser/meta_ref.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index f763e757c311..2a81423b898b 100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -24,7 +24,6 @@ #include "./meta_ref.h" -#include #include #include #include From 6d1a38d5a53f719f70558416091fd377af644872 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 5 Aug 2020 22:55:12 -0700 Subject: [PATCH 40/48] WIP --- include/tvm/ir/span.h | 2 +- include/tvm/parser/source_map.h | 6 +- src/ir/span.cc | 12 +- src/parser/meta_ref.cc | 11 +- src/parser/meta_ref.h | 4 +- src/parser/op_table.h | 20 +- src/parser/parser.cc | 405 ++++++++++++--------- src/parser/token.h | 327 +++++++++-------- src/parser/tokenizer.h | 116 +++--- src/printer/relay_text_printer.cc | 2 +- src/printer/text_printer.h | 3 +- tests/python/relay/test_ir_text_printer.py | 64 ++-- 12 files changed, 508 insertions(+), 464 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 4f1006ebcb8a..be8799eaca19 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -111,7 +111,7 @@ class SpanNode : public Object { class Span : public ObjectRef { public: - TVM_DLL Span(SourceName source, int line, int column, int end_line, int end_column); + TVM_DLL Span(SourceName source, int line, int end_line, int column, int end_column); /*! \brief Merge two spans into one which captures the combined regions. */ TVM_DLL Span Merge(const Span& other); diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 98583ec549ba..ca0b7163aa66 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. 
*/ - -#ifndef TVM_PARSER_SOURCE_MAP_H_ -#define TVM_PARSER_SOURCE_MAP_H_ /*! * \file source_map.h * \brief A map from source names to source code. */ +#ifndef TVM_PARSER_SOURCE_MAP_H_ +#define TVM_PARSER_SOURCE_MAP_H_ + #include #include #include diff --git a/src/ir/span.cc b/src/ir/span.cc index 2a2601c3f3df..e936feae1723 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -65,8 +65,8 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column auto n = make_object(); n->source = std::move(source); n->line = line; - n->column = column; n->end_line = end_line; + n->column = column; n->end_column = end_column; data_ = std::move(n); } @@ -74,21 +74,21 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column Span Span::Merge(const Span& other) { CHECK((*this)->source == other->source); return Span((*this)->source, std::min((*this)->line, other->line), - std::min((*this)->column, other->column), std::max((*this)->end_line, other->end_line), + std::min((*this)->column, other->column), std::max((*this)->end_column, other->end_column)); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int column, - int end_line, int end_column) { - return Span(source, line, column, end_line, end_column); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, int column, + int end_column) { + return Span(source, line, end_line, column, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")"; + p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; }); } // namespace tvm diff --git a/src/parser/meta_ref.cc b/src/parser/meta_ref.cc index 2a81423b898b..d23892753c5f 
100644 --- a/src/parser/meta_ref.cc +++ b/src/parser/meta_ref.cc @@ -35,6 +35,9 @@ namespace parser { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; +/* Set to arbitrary high number, since we should never schedule in normal pass manager flow. */ +static int kMetaExpandOptLevel = 1337; + TVM_REGISTER_NODE_TYPE(MetaRefAttrs); bool MetaRefRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -60,12 +63,6 @@ Expr MetaRef(std::string type_key, uint64_t node_index) { return Call(op, {}, Attrs(attrs), {}); } -// class MetaRefAttrExpander : AttrFunctor { -// ObjectRef VisitAttrDefault_(const Object* node) final { - -// } -// } - struct MetaRefExpander : public ExprMutator { MetaTable table; @@ -94,7 +91,7 @@ Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { auto pass = CreateFunctionPass([&](Function func, IRModule module, PassContext ctx) { return ExpandMetaRefs(meta_table, func); }, - 1337, "ExpandMetaRefs", {}); + kMetaExpandOptLevel, "ExpandMetaRefs", {}); return pass(mod, PassContext::Create()); } diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 40e3fdbb7a8b..481f334cb0fe 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -71,12 +71,12 @@ struct MetaRefAttrs : public tvm::AttrsNode { * of the program. * * \param type_key The type key of the object in the meta section. - * \param kind The index into that subfield. + * \param node_index The index into that subfield. * \returns The meta table reference. 
*/ Expr MetaRef(std::string type_key, uint64_t node_index); -relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& mod); +relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func); IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); } // namespace parser diff --git a/src/parser/op_table.h b/src/parser/op_table.h index 5af10a0590b8..050904f23280 100644 --- a/src/parser/op_table.h +++ b/src/parser/op_table.h @@ -80,16 +80,16 @@ struct OperatorTable { OperatorTable DefaultOpTable() { return OperatorTable( - {Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true), - Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true), - Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true), - Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true), - Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true), - Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true), - Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true), - Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true), - Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true), - Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)}); + {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true), + Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true), + Rule({TokenType::kPlus}, Op::Get("add"), 10, 2, true), + Rule({TokenType::kMinus}, Op::Get("subtract"), 10, 2, true), + Rule({TokenType::kLAngle}, Op::Get("less"), 8, 2, true), + Rule({TokenType::kLAngle, TokenType::kEqual}, Op::Get("less_equal"), 8, 2, true), + Rule({TokenType::kRAngle}, Op::Get("greater"), 8, 2, true), + Rule({TokenType::kRAngle, TokenType::kEqual}, Op::Get("greater_equal"), 8, 2, true), + Rule({TokenType::kEqual, TokenType::kEqual}, Op::Get("equal"), 7, 2, true), + Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)}); } } // namespace 
parser diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 368cd0f1afd2..002ff8da5ffe 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -111,6 +111,7 @@ template class ScopeStack { private: std::vector> scope_stack; + std::unordered_map free_vars; public: /*! \brief Adds a variable binding to the current scope. */ @@ -121,6 +122,10 @@ class ScopeStack { this->scope_stack.back().name_map.insert({name, value}); } + void AddFreeVar(const std::string& name, const T& value) { + free_vars.insert({name, value}); + } + /*! \brief Looks up a variable name in the scope stack returning the matching variable * in most recent scope. */ T Lookup(const std::string& name) { @@ -130,6 +135,13 @@ class ScopeStack { return it->second; } } + + // Check if we bound a free variable declaration. + auto it = free_vars.find(name); + if (it != free_vars.end()) { + return it->second; + } + return T(); } @@ -265,10 +277,10 @@ class Parser { // For now we ignore all whitespace tokens and comments. // We can tweak this behavior later to enable white space sensitivity in the parser. while (pos < static_cast(tokens.size()) && ignore_whitespace && - (tokens.at(pos)->token_type == TokenType::Whitespace || - tokens.at(pos)->token_type == TokenType::Newline || - tokens.at(pos)->token_type == TokenType::LineComment || - tokens.at(pos)->token_type == TokenType::Comment)) { + (tokens.at(pos)->token_type == TokenType::kWhitespace || + tokens.at(pos)->token_type == TokenType::kNewline || + tokens.at(pos)->token_type == TokenType::kLineComment || + tokens.at(pos)->token_type == TokenType::kComment)) { pos++; } @@ -368,6 +380,17 @@ class Parser { return var; } + /*! \brief Bind a local variable in the expression scope. + * + * "x" -> Var("x"), these are needed to map from the raw string names + * to unique variable nodes. 
+ */ + Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { + auto var = Var(name, type_annotation); + this->expr_scopes.AddFreeVar(name, var); + return var; + } + /*! \brief Bind a type variable in the type scope. * * "A" -> TypeVar("A", ...), these are needed to map from raw string names @@ -386,8 +409,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx->Emit({DiagnosticLevel::Error, local->span, - "this local variable has not been previously declared"}); + diag_ctx->Emit(Diagnostic::Error(local->span) << + "this local variable has not been previously declared"); } return var; } @@ -399,9 +422,8 @@ class Parser { TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { - diag_ctx->Emit( - {DiagnosticLevel::Error, ident->span, - "this type variable has not been previously declared anywhere, perhaps a typo?"}); + diag_ctx->Emit(Diagnostic::Error(ident->span) + << "this type variable has not been previously declared anywhere, perhaps a typo?"); } return var; } @@ -428,7 +450,7 @@ class Parser { /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ NDArray NumberToNDArray(const Token& token) { - if (token->token_type == TokenType::Integer) { + if (token->token_type == TokenType::kInteger) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("int32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -437,7 +459,7 @@ class Parser { int64_t value = Downcast(token->data); array[0] = (int32_t)value; return data; - } else if (token->token_type == TokenType::Float) { + } else if (token->token_type == TokenType::kFloat) { DLContext ctx = {DLDeviceType::kDLCPU, 0}; auto dtype = String2DLDataType("float32"); auto data = NDArray::Empty({}, dtype, ctx); @@ -479,13 +501,13 @@ class Parser { /*! \brief Parse `(` parser() `)`. 
*/ template R Parens(std::function parser) { - return Bracket(TokenType::OpenParen, TokenType::CloseParen, parser); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); } /*! \brief Parse `{` parser() `}`. */ template R Block(std::function parser) { - return Bracket(TokenType::LCurly, TokenType::RCurly, parser); + return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and @@ -502,7 +524,7 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - DLOG(INFO) << "Parser::ParseSequence: start=" << start << "sep=" << sep << "stop=" << stop; + DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) << "stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream @@ -522,12 +544,6 @@ class Parser { auto data = parse(); Array elements = {data}; - // parse '(' expr ','? ')' - // if we are at the end invoke leftover parser - // if (Peek()->token_type == sep && before_stop) { - // before_stop(); - // } - if (WhenMatch(stop)) { return elements; // parse '( expr ',' * ')' @@ -569,7 +585,7 @@ class Parser { // Parse the metadata section at the end. auto metadata = ParseMetadata(); - Match(TokenType::EndOfFile); + Match(TokenType::kEndOfFile); Map funcs; Map types; @@ -589,8 +605,8 @@ class Parser { /*! \brief Parse the semantic versioning header. */ SemVer ParseSemVer(bool required = true) { - if (Peek()->token_type == TokenType::Version) { - auto version = Match(TokenType::Version); + if (Peek()->token_type == TokenType::kVersion) { + auto version = Match(TokenType::kVersion); // TODO(@jroesch): we currently only support 0.0.5. 
if (version.ToString() != "\"0.0.5\"") { this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, version->span) @@ -612,9 +628,9 @@ class Parser { while (true) { auto next = Peek(); switch (next->token_type) { - case TokenType::Defn: { - Consume(TokenType::Defn); - auto global_tok = Match(TokenType::Global); + case TokenType::kDefn: { + Consume(TokenType::kDefn); + auto global_tok = Match(TokenType::kGlobal); auto global_name = global_tok.ToString(); auto global = GlobalVar(global_name); try { @@ -628,12 +644,12 @@ class Parser { defs.funcs.push_back(GlobalFunc(global, func)); continue; } - case TokenType::TypeDef: { + case TokenType::kTypeDef: { defs.types.push_back(ParseTypeDef()); continue; } - case TokenType::Extern: { - Consume(TokenType::Extern); + case TokenType::kExtern: { + Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { diag_ctx->Emit({DiagnosticLevel::Error, next->span, @@ -650,9 +666,9 @@ class Parser { /*! \brief Parse zero or more Relay type definitions. */ TypeData ParseTypeDef() { // Match the `type` keyword. - Match(TokenType::TypeDef); + Match(TokenType::kTypeDef); // Parse the type's identifier. - auto type_tok = Match(TokenType::Identifier); + auto type_tok = Match(TokenType::kIdentifier); auto type_id = type_tok.ToString(); auto type_global = tvm::GlobalTypeVar(type_id, TypeKind::kAdtHandle); @@ -667,33 +683,33 @@ class Parser { Array generics; bool should_pop = false; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. 
PushTypeScope(); should_pop = true; generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } Array ctors; - if (Peek()->token_type == TokenType::LCurly) { + if (Peek()->token_type == TokenType::kLCurly) { // Parse the list of constructors. ctors = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() { + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { // First match the name of the constructor. - auto ctor_tok = Match(TokenType::Identifier); + auto ctor_tok = Match(TokenType::kIdentifier); auto ctor_name = ctor_tok.ToString(); Constructor ctor; // Match the optional field list. - if (Peek()->token_type != TokenType::OpenParen) { + if (Peek()->token_type != TokenType::kOpenParen) { ctor = tvm::Constructor(ctor_name, {}, type_global); } else { auto arg_types = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, + ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { return ParseType(); }); ctor = tvm::Constructor(ctor_name, arg_types, type_global); } @@ -757,12 +773,12 @@ class Parser { switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue - case TokenType::LCurly: { + case TokenType::kLCurly: { // NB: Might need to optimize to remove deep recursion. // Stack should only grow proportionally to the number of // nested scopes. // Parses `{` expression `}`. 
- auto block = Bracket(TokenType::LCurly, TokenType::RCurly, [&]() { + auto block = Bracket(TokenType::kLCurly, TokenType::kRCurly, [&]() { PushScope(); auto expr = ParseExpr(); PopScopes(1); @@ -771,24 +787,32 @@ class Parser { exprs.push_back(block); break; } + case TokenType::kFreeVar: { + Consume(TokenType::kFreeVar); + auto var_token = Match(TokenType::kLocal); + Match(TokenType::kColon); + auto type = ParseType(); + BindFreeVar(var_token.ToString(), type); + break; + } // Parses `let ...`; - case TokenType::Let: + case TokenType::kLet: exprs.push_back(ParseBindingExpr()); break; - case TokenType::Match: - case TokenType::PartialMatch: { - bool is_total = next->token_type == TokenType::Match; + case TokenType::kMatch: + case TokenType::kPartialMatch: { + bool is_total = next->token_type == TokenType::kMatch; Consume(next->token_type); exprs.push_back(ParseMatch(is_total)); break; } - case TokenType::If: { + case TokenType::kIf: { exprs.push_back(ParseIf()); break; } // %x ... - case TokenType::Graph: - if (Lookahead(2)->token_type == TokenType::Equal) { + case TokenType::kGraph: + if (Lookahead(2)->token_type == TokenType::kEqual) { exprs.push_back(ParseBindingExpr()); break; } @@ -799,7 +823,7 @@ class Parser { } } - if (!WhenMatch(TokenType::Semicolon)) { + if (!WhenMatch(TokenType::kSemicolon)) { break; } } @@ -853,34 +877,34 @@ class Parser { while (true) { auto next = Peek(); - if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) { - Match(TokenType::Graph); - Match(TokenType::Equal); + if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) { + Match(TokenType::kGraph); + Match(TokenType::kEqual); auto val = this->ParseExprBinOp(); - Match(TokenType::Semicolon); + Match(TokenType::kSemicolon); AddGraphBinding(next, val); - } else if (next->token_type == TokenType::Let) { + } else if (next->token_type == TokenType::kLet) { // Parse the 'let'. 
- Consume(TokenType::Let); + Consume(TokenType::kLet); // Parse the local '%'. - auto local_tok = Match(TokenType::Local); + auto local_tok = Match(TokenType::kLocal); auto string = local_tok.ToString(); // Parse the optional type annotation (':' ). Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } auto var = BindVar(string, type); // Parse the '='; - Match(TokenType::Equal); + Match(TokenType::kEqual); // Parse the body, and the ';'. auto val = this->ParseExprBinOp(); - Consume(TokenType::Semicolon); + Consume(TokenType::kSemicolon); // Add the bindings to the local data structure. bindings.push_back({var, val}); @@ -923,30 +947,30 @@ class Parser { PushTypeScope(); Array generics; - if (Peek()->token_type == TokenType::LSquare) { + if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. PushTypeScope(); generics = - ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() { - auto type_var_name = Match(TokenType::Identifier).ToString(); + ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); } auto params = - ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, [&]() { - auto token = Match(TokenType::Local); + ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { + auto token = Match(TokenType::kLocal); auto string = token.ToString(); Type type; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type = ParseType(); } return BindVar(string, type); }); Type ret_type; - if (WhenMatch(TokenType::Minus)) { - Match(TokenType::RAngle); + if (WhenMatch(TokenType::kMinus)) { + Match(TokenType::kRAngle); ret_type = ParseType(); } @@ -961,7 +985,7 @@ class Parser { /*! \brief Parse an if-expression. 
*/ Expr ParseIf() { DLOG(INFO) << "Parser::ParseIf"; - Consume(TokenType::If); + Consume(TokenType::kIf); auto guard = Parens([&] { return ParseExpr(); }); auto true_branch = Block([&] { @@ -971,7 +995,7 @@ class Parser { return expr; }); - Match(TokenType::Else); + Match(TokenType::kElse); auto false_branch = Block([&] { this->PushScope(); @@ -985,7 +1009,7 @@ class Parser { /* This factors parsing a list of patterns for both tuples, and constructors. */ Array ParsePatternList() { - return ParseSequence(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, + return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&] { return ParsePattern(); }); } @@ -1000,24 +1024,24 @@ class Parser { DLOG(INFO) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { - case TokenType::Underscore: { - Match(TokenType::Underscore); + case TokenType::kUnderscore: { + Match(TokenType::kUnderscore); return PatternWildcard(); } - case TokenType::Local: { - auto id = Match(TokenType::Local); + case TokenType::kLocal: { + auto id = Match(TokenType::kLocal); Type type_annotation; - if (WhenMatch(TokenType::Colon)) { + if (WhenMatch(TokenType::kColon)) { type_annotation = ParseType(); } auto var = BindVar(id.ToString(), type_annotation); return PatternVar(var); } - case TokenType::Identifier: { - auto id = Match(TokenType::Identifier); + case TokenType::kIdentifier: { + auto id = Match(TokenType::kIdentifier); auto ctor = ctors.Get(id.ToString()); CHECK(ctor) << "undefined identifier"; - if (Peek()->token_type == TokenType::OpenParen) { + if (Peek()->token_type == TokenType::kOpenParen) { auto fields = ParsePatternList(); return PatternConstructor(ctor.value(), fields); } else { @@ -1032,8 +1056,8 @@ class Parser { Clause ParseMatchArm() { PushScope(); auto pattern = ParsePattern(); - Match(TokenType::Equal); - Consume(TokenType::RAngle); + Match(TokenType::kEqual); + Consume(TokenType::kRAngle); auto expr = ParseExpr(); 
PopScopes(1); return Clause(pattern, expr); @@ -1043,7 +1067,7 @@ class Parser { Expr scrutinee = ParseExpr(); Array clauses = ParseSequence( - TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&] { return ParseMatchArm(); }); + TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&] { return ParseMatchArm(); }); return relay::Match(scrutinee, clauses, is_total); } @@ -1126,13 +1150,13 @@ class Parser { DLOG(INFO) << "Parser::ParseAttributeValue"; auto next = Peek(); switch (next->token_type) { - case TokenType::Float: - case TokenType::Integer: - case TokenType::Boolean: - case TokenType::StringLiteral: + case TokenType::kFloat: + case TokenType::kInteger: + case TokenType::kBoolean: + case TokenType::kStringLiteral: return Match(next->token_type)->data; - case TokenType::LSquare: { - return ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + case TokenType::kLSquare: { + return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseAttributeValue(); }); } default: @@ -1143,58 +1167,67 @@ class Parser { Map ParseAttrs() { DLOG(INFO) << "Parser::ParseAttrs"; Map kwargs; - while (Peek()->token_type == TokenType::Identifier) { - auto key = Match(TokenType::Identifier).ToString(); - Match(TokenType::Equal); + while (Peek()->token_type == TokenType::kIdentifier) { + auto key = Match(TokenType::kIdentifier).ToString(); + Match(TokenType::kEqual); // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. auto value = ParseAttributeValue(); + // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text format is bad. 
kwargs.Set(key, value); - WhenMatch(TokenType::Comma); + WhenMatch(TokenType::kComma); } DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs; return kwargs; } Expr ParseCallArgs(Expr op) { - DLOG(INFO) << "Parser::ParseCallArgs"; - Map raw_attrs; - std::string op_key; - bool is_op = false; - - if (auto op_node = op.as()) { - is_op = true; - op_key = op_node->attrs_type_key; - } + try { + DLOG(INFO) << "Parser::ParseCallArgs"; + Map raw_attrs; + std::string op_key; + bool is_op = false; + + if (auto op_node = op.as()) { + is_op = true; + op_key = op_node->attrs_type_key; + } - if (Peek()->token_type == TokenType::OpenParen) { - Array args = ParseSequence( - TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, - [&] { return ParseExpr(); }, - [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::Identifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::Equal; - - if (is_op && is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); - return true; - } + if (Peek()->token_type == TokenType::kOpenParen) { + Array args = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&] { return ParseExpr(); }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_op && is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } - return false; - }); + return false; + }); - Attrs attrs; + Attrs attrs; - if (is_op && op_key.size()) { - // raw_attrs.Set("type_key", tvm::String("hello")); - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - CHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); - } + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } - return Expr(Call(op, args, attrs, {})); - } else { - return Expr(); + return 
Expr(Call(op, args, attrs, {})); + } else { + return Expr(); + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->Emit( + Diagnostic::Error(Peek()->span)); + // << err.what()); } + + return Expr(); } Expr ParseCallExpr() { @@ -1205,12 +1238,20 @@ class Parser { // // NB(@jroesch): this seems like a hack but in order to parse curried functions // and avoid complex grammar we will parse multiple call lists in a row. - while (Peek()->token_type == TokenType::OpenParen) { - auto new_expr = ParseCallArgs(expr); - if (new_expr.defined()) { - expr = new_expr; - } else { - break; + while (Peek()->token_type == TokenType::kOpenParen) { + try { + auto new_expr = ParseCallArgs(expr); + + if (new_expr.defined()) { + expr = new_expr; + } else { + break; + } + } catch (...) { + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->EmitFatal( + Diagnostic::Error(Peek()->span)); + // << err.what()); } } @@ -1241,29 +1282,29 @@ class Parser { auto expr = ConsumeWhitespace([this] { auto next = Peek(); switch (next->token_type) { - case TokenType::Integer: - case TokenType::Float: { + case TokenType::kInteger: + case TokenType::kFloat: { Consume(next->token_type); auto number = NumberToNDArray(next); Expr e = Constant(number, next->span); return e; } - case TokenType::Boolean: { - Consume(TokenType::Boolean); + case TokenType::kBoolean: { + Consume(TokenType::kBoolean); int value = Downcast(next->data); auto boolean = BooleanToNDarray(value); Expr e = Constant(boolean, next->span); return e; } // Parse a local of the form `%x`. - case TokenType::Local: { - Consume(TokenType::Local); + case TokenType::kLocal: { + Consume(TokenType::kLocal); return Expr(LookupLocal(next)); } // Parse a local of the form `@x`. 
- case TokenType::Global: { + case TokenType::kGlobal: { auto string = next.ToString(); - Consume(TokenType::Global); + Consume(TokenType::kGlobal); auto global = global_names.Get(string); if (!global) { // TODO(@jroesch): fix global's needing span information @@ -1276,10 +1317,10 @@ class Parser { } // Parse a local of the form `x`. // Right now we fail to parse `x.y`. - case TokenType::Identifier: { + case TokenType::kIdentifier: { auto ctor = ctors.Get(next.ToString()); if (ctor) { - Consume(TokenType::Identifier); + Consume(TokenType::kIdentifier); return Expr(ctor.value()); } else { auto idents = ParseHierName(); @@ -1296,37 +1337,37 @@ class Parser { return GetOp(op_name.str(), next); } } - case TokenType::Graph: { - Consume(TokenType::Graph); + case TokenType::kGraph: { + Consume(TokenType::kGraph); return LookupGraphBinding(next); } - case TokenType::MetaReference: { - Consume(TokenType::MetaReference); + case TokenType::kMetaReference: { + Consume(TokenType::kMetaReference); return Downcast(next->data); } - case TokenType::Fn: { - Consume(TokenType::Fn); + case TokenType::kFn: { + Consume(TokenType::kFn); return Expr(ParseFunctionDef()); } - case TokenType::OpenParen: { - Consume(TokenType::OpenParen); + case TokenType::kOpenParen: { + Consume(TokenType::kOpenParen); // parse '(' ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return Expr(Tuple(Array())); } else { auto expr = ParseExpr(); // parse '(' expr ')' - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { return expr; // parse '( expr ',' * ')' - } else if (WhenMatch(TokenType::Comma)) { + } else if (WhenMatch(TokenType::kComma)) { Array exprs = {expr}; while (true) { - if (WhenMatch(TokenType::CloseParen)) { + if (WhenMatch(TokenType::kCloseParen)) { break; } else { auto expr = ParseExpr(); - WhenMatch(TokenType::Comma); + WhenMatch(TokenType::kComma); exprs.push_back(expr); } } @@ -1343,8 +1384,8 @@ class Parser { } }); - if 
(WhenMatch(TokenType::Period)) { - auto index = Match(TokenType::Integer).ToNumber(); + if (WhenMatch(TokenType::kPeriod)) { + auto index = Match(TokenType::kInteger).ToNumber(); expr = relay::TupleGetItem(expr, index); } @@ -1354,12 +1395,12 @@ class Parser { /*! \brief Parse a hierarchical name. */ Array ParseHierName() { Array idents; - while (Peek()->token_type == TokenType::Identifier) { + while (Peek()->token_type == TokenType::kIdentifier) { idents.push_back(Peek().ToString()); - Consume(TokenType::Identifier); + Consume(TokenType::kIdentifier); - if (Peek()->token_type == TokenType::Period) { - Consume(TokenType::Period); + if (Peek()->token_type == TokenType::kPeriod) { + Consume(TokenType::kPeriod); continue; } else { break; @@ -1371,9 +1412,9 @@ class Parser { /*! \brief Parse a shape. */ Array ParseShape() { - auto dims = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { - auto tok = Match(TokenType::Integer); + auto dims = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { + auto tok = Match(TokenType::kInteger); return Downcast(tok->data); }); return dims; @@ -1381,11 +1422,11 @@ class Parser { /*! \brief Parse a function type. 
*/ Type ParseFunctionType() { - auto ty_params = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + auto ty_params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); - Match(TokenType::Minus); - Match(TokenType::RAngle); + Match(TokenType::kMinus); + Match(TokenType::kRAngle); auto ret_type = ParseType(); return relay::FuncType(ty_params, ret_type, {}, {}); @@ -1406,8 +1447,8 @@ class Parser { CHECK(head_type.defined()) << "internal error: head type must be defined"; Array arg_types; - if (Peek()->token_type == TokenType::LSquare) { - arg_types = ParseSequence(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, + if (Peek()->token_type == TokenType::kLSquare) { + arg_types = ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseType(); }); } @@ -1426,21 +1467,21 @@ class Parser { Type ParseType() { auto tok = Peek(); - if (tok->token_type == TokenType::OpenParen) { - auto tys = ParseSequence(TokenType::OpenParen, TokenType::Comma, - TokenType::CloseParen, [&]() { return ParseType(); }); + if (tok->token_type == TokenType::kOpenParen) { + auto tys = ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); return relay::TupleType(tys); - } else if (WhenMatch(TokenType::Fn)) { + } else if (WhenMatch(TokenType::kFn)) { return ParseFunctionType(); - } else if (WhenMatch(TokenType::Identifier)) { + } else if (WhenMatch(TokenType::kIdentifier)) { auto id = tok.ToString(); if (id == "Tensor") { - Match(TokenType::LSquare); + Match(TokenType::kLSquare); auto shape = ParseShape(); - Match(TokenType::Comma); - auto dtype_tok = Match(TokenType::Identifier); + Match(TokenType::kComma); + auto dtype_tok = Match(TokenType::kIdentifier); auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); - Match(TokenType::RSquare); + 
Match(TokenType::kRSquare); return TensorType(shape, dtype); } else { auto ty = tok.ToString(); @@ -1454,7 +1495,7 @@ class Parser { } } } - if (WhenMatch(TokenType::Underscore)) { + if (WhenMatch(TokenType::kUnderscore)) { return IncompleteType(); } else { this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) @@ -1467,7 +1508,7 @@ class Parser { R ConsumeWhitespace(std::function func) { auto old = this->ignore_whitespace; this->ignore_whitespace = true; - while (tokens[pos]->token_type == TokenType::Whitespace) { + while (tokens[pos]->token_type == TokenType::kWhitespace) { pos++; } auto res = func(); @@ -1476,8 +1517,8 @@ class Parser { } Map> ParseMetadata() { - if (Peek()->token_type == TokenType::Metadata) { - return Match(TokenType::Metadata).ToMetadata(); + if (Peek()->token_type == TokenType::kMetadata) { + return Match(TokenType::kMetadata).ToMetadata(); } else { return Map>(); } @@ -1534,7 +1575,7 @@ Expr ParseExpr(std::string file_name, std::string file_content) { parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); - parser.Match(TokenType::EndOfFile); + parser.Match(TokenType::kEndOfFile); // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. 
diff --git a/src/parser/token.h b/src/parser/token.h index 480872956b68..86a26cbada52 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -38,169 +38,172 @@ namespace parser { using namespace runtime; -enum TokenType { - CommentStart, - CommentEnd, - LineComment, - Comment, - Whitespace, - Newline, - StringLiteral, - Identifier, - Local, - Global, - Op, - Graph, - OpenParen, - CloseParen, - AtSymbol, - Percent, - Comma, - Period, - Equal, - Semicolon, - Colon, - Integer, - Float, - Division, - Boolean, - Plus, - Star, - Minus, - RAngle, - LAngle, - RCurly, - LCurly, - RSquare, - LSquare, - Bang, - At, - Question, - If, - Else, - Underscore, - Let, - Fn, - Defn, - TypeDef, - Extern, - Match, - PartialMatch, - Metadata, - MetaReference, - Version, - Unknown, - EndOfFile, - Null, +enum class TokenType { + kCommentStart, + kCommentEnd, + kLineComment, + kComment, + kWhitespace, + kNewline, + kStringLiteral, + kIdentifier, + kLocal, + kGlobal, + kOp, + kGraph, + kOpenParen, + kCloseParen, + kAtSymbol, + kPercent, + kComma, + kPeriod, + kEqual, + kSemicolon, + kColon, + kInteger, + kFloat, + kDivision, + kBoolean, + kPlus, + kStar, + kMinus, + kRAngle, + kLAngle, + kRCurly, + kLCurly, + kRSquare, + kLSquare, + kBang, + kAt, + kQuestion, + kIf, + kElse, + kUnderscore, + kLet, + kFn, + kDefn, + kTypeDef, + kExtern, + kMatch, + kPartialMatch, + kMetadata, + kMetaReference, + kFreeVar, + kVersion, + kUnknown, + kEndOfFile, + kNull, }; std::string ToString(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "CommentStart"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "CommentEnd"; - case TokenType::LineComment: + case TokenType::kLineComment: return "LineComment"; - case TokenType::Comment: + case TokenType::kComment: return "Comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "WhiteSpace"; - case TokenType::Newline: + case TokenType::kNewline: 
return "Newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "StringLiteral"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "Identifier"; - case TokenType::Local: + case TokenType::kLocal: return "Local"; - case TokenType::Global: + case TokenType::kGlobal: return "Global"; - case TokenType::Graph: + case TokenType::kGraph: return "Graph"; - case TokenType::Op: + case TokenType::kOp: return "Op"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "OpenParen"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "CloseParen"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "AtSymbol"; - case TokenType::Percent: + case TokenType::kPercent: return "Percent"; - case TokenType::Comma: + case TokenType::kComma: return "Comma"; - case TokenType::Colon: + case TokenType::kColon: return "Colon"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "Semicolon"; - case TokenType::Period: + case TokenType::kPeriod: return "Period"; - case TokenType::Equal: + case TokenType::kEqual: return "Equal"; - case TokenType::Integer: + case TokenType::kInteger: return "Integer"; - case TokenType::Float: + case TokenType::kFloat: return "Float"; - case TokenType::Plus: + case TokenType::kPlus: return "Plus"; - case TokenType::Star: + case TokenType::kStar: return "Star"; - case TokenType::Minus: + case TokenType::kMinus: return "Minus"; - case TokenType::Division: + case TokenType::kDivision: return "Division"; - case TokenType::RAngle: + case TokenType::kRAngle: return "RAngle"; - case TokenType::LAngle: + case TokenType::kLAngle: return "LAngle"; - case TokenType::RCurly: + case TokenType::kRCurly: return "RCurly"; - case TokenType::LCurly: + case TokenType::kLCurly: return "LCurly"; - case TokenType::RSquare: + case TokenType::kRSquare: return "RSquare"; - case TokenType::LSquare: + case TokenType::kLSquare: return "LSquare"; - case TokenType::Bang: + case TokenType::kBang: 
return "Bang"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "Underscore"; - case TokenType::At: + case TokenType::kAt: return "At"; - case TokenType::Let: + case TokenType::kLet: return "Let"; - case TokenType::If: + case TokenType::kIf: return "If"; - case TokenType::Else: + case TokenType::kElse: return "Else"; - case TokenType::Fn: + case TokenType::kFn: return "Fn"; - case TokenType::Defn: + case TokenType::kDefn: return "Defn"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "TypeDef"; - case TokenType::Extern: + case TokenType::kExtern: return "Extern"; - case TokenType::Match: + case TokenType::kMatch: return "Match"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "PartialMatch"; - case TokenType::Question: + case TokenType::kQuestion: return "Question"; - case TokenType::Boolean: + case TokenType::kBoolean: return "Boolean"; - case TokenType::Metadata: + case TokenType::kMetadata: return "Metadata"; - case TokenType::MetaReference: + case TokenType::kMetaReference: return "MetaReference"; - case TokenType::Version: + case TokenType::kFreeVar: + return "FreeVar"; + case TokenType::kVersion: return "Version"; - case TokenType::Unknown: + case TokenType::kUnknown: return "Unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "EndOfFile"; - case TokenType::Null: + case TokenType::kNull: return "Null"; // Older compilers warn even though the above code is exhaustive. 
default: @@ -211,111 +214,113 @@ std::string ToString(const TokenType& token_type) { std::string Pretty(const TokenType& token_type) { switch (token_type) { - case TokenType::CommentStart: + case TokenType::kCommentStart: return "`/*`"; - case TokenType::CommentEnd: + case TokenType::kCommentEnd: return "`*/`"; - case TokenType::LineComment: + case TokenType::kLineComment: return "`//`"; - case TokenType::Comment: + case TokenType::kComment: return "comment"; - case TokenType::Whitespace: + case TokenType::kWhitespace: return "whitespace"; - case TokenType::Newline: + case TokenType::kNewline: return "newline"; - case TokenType::StringLiteral: + case TokenType::kStringLiteral: return "string literal"; - case TokenType::Identifier: + case TokenType::kIdentifier: return "identifier"; - case TokenType::Local: + case TokenType::kLocal: return "local variable"; - case TokenType::Global: + case TokenType::kGlobal: return "global variable"; - case TokenType::Graph: + case TokenType::kGraph: return "graph variable"; - case TokenType::Op: + case TokenType::kOp: return "operator"; - case TokenType::OpenParen: + case TokenType::kOpenParen: return "`(`"; - case TokenType::CloseParen: + case TokenType::kCloseParen: return "`)`"; - case TokenType::AtSymbol: + case TokenType::kAtSymbol: return "`@`"; - case TokenType::Percent: + case TokenType::kPercent: return "`%`"; - case TokenType::Comma: + case TokenType::kComma: return "`,`"; - case TokenType::Colon: + case TokenType::kColon: return "`:`"; - case TokenType::Semicolon: + case TokenType::kSemicolon: return "`;`"; - case TokenType::Period: + case TokenType::kPeriod: return "`.`"; - case TokenType::Equal: + case TokenType::kEqual: return "`=`"; - case TokenType::Integer: + case TokenType::kInteger: return "integer"; - case TokenType::Float: + case TokenType::kFloat: return "float"; - case TokenType::Plus: + case TokenType::kPlus: return "`+`"; - case TokenType::Star: + case TokenType::kStar: return "`*`"; - case 
TokenType::Minus: + case TokenType::kMinus: return "`-`"; - case TokenType::Division: + case TokenType::kDivision: return "`/`"; - case TokenType::RAngle: + case TokenType::kRAngle: return "`<`"; - case TokenType::LAngle: + case TokenType::kLAngle: return "`>`"; - case TokenType::RCurly: + case TokenType::kRCurly: return "`}`"; - case TokenType::LCurly: + case TokenType::kLCurly: return "`{`"; - case TokenType::RSquare: + case TokenType::kRSquare: return "`]`"; - case TokenType::LSquare: + case TokenType::kLSquare: return "`[`"; - case TokenType::Bang: + case TokenType::kBang: return "`!`"; - case TokenType::Underscore: + case TokenType::kUnderscore: return "`_`"; - case TokenType::At: + case TokenType::kAt: return "`@`"; - case TokenType::Let: + case TokenType::kLet: return "`let`"; - case TokenType::If: + case TokenType::kIf: return "`if`"; - case TokenType::Else: + case TokenType::kElse: return "`else`"; - case TokenType::Fn: + case TokenType::kFn: return "`fn`"; - case TokenType::Defn: + case TokenType::kDefn: return "`def`"; - case TokenType::TypeDef: + case TokenType::kTypeDef: return "`type`"; - case TokenType::Extern: + case TokenType::kExtern: return "`extern`"; - case TokenType::Boolean: + case TokenType::kBoolean: return "boolean"; - case TokenType::Metadata: + case TokenType::kMetadata: return "metadata section"; - case TokenType::MetaReference: + case TokenType::kMetaReference: return "`meta`"; - case TokenType::Match: + case TokenType::kFreeVar: + return "`free_var`"; + case TokenType::kMatch: return "`match`"; - case TokenType::PartialMatch: + case TokenType::kPartialMatch: return "`match?`"; - case TokenType::Question: + case TokenType::kQuestion: return "`?`"; - case TokenType::Unknown: + case TokenType::kUnknown: return "unknown"; - case TokenType::EndOfFile: + case TokenType::kEndOfFile: return "end of file"; - case TokenType::Null: + case TokenType::kNull: return "null"; - case TokenType::Version: + case TokenType::kVersion: return "version 
attribute"; // Older compilers warn even though the above code is exhaustive. default: @@ -366,7 +371,7 @@ Token::Token(Span span, TokenType token_type, ObjectRef data) { data_ = std::move(n); } -Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::Null); } +Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } int64_t Token::ToNumber() const { return Downcast(this->operator->()->data); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 0456ece4e293..f500d4ac6d58 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -66,9 +66,9 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' < bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { - {"let", TokenType::Let}, {"fn", TokenType::Fn}, {"def", TokenType::Defn}, - {"if", TokenType::If}, {"else", TokenType::Else}, {"type", TokenType::TypeDef}, - {"match", TokenType::Match}, {"extern", TokenType::Extern}}; + {"let", TokenType::kLet}, {"fn", TokenType::kFn}, {"def", TokenType::kDefn}, + {"if", TokenType::kIf}, {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, + {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, {"free_var", TokenType::kFreeVar}}; struct Tokenizer { DiagnosticContext* diag_ctx; @@ -102,14 +102,14 @@ struct Tokenizer { Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) { auto span = - Span(this->source_name, this->line, this->col, this->line + lines, this->col + cols); + Span(this->source_name, this->line, this->line + lines, this->col, this->col + cols); return Token(span, token_type, data); } Span SpanFrom(int line, int column) { int end_line = this->line; int end_column = this->col; - return Span(this->source_name, line, column, end_line, end_column); + return Span(this->source_name, line, end_line, column, end_column); } enum CommentParserState { @@ -172,7 +172,7 @@ 
struct Tokenizer { if (is_float) { throw std::invalid_argument("is_float"); } - auto token = NewToken(TokenType::Integer); + auto token = NewToken(TokenType::kInteger); size_t index = 0; int value = std::stoi(number, &index); if (number.size() > index) { @@ -182,7 +182,7 @@ struct Tokenizer { token->data = tvm::Integer(value); return token; } catch (const std::invalid_argument& ia) { - auto token = NewToken(TokenType::Float); + auto token = NewToken(TokenType::kFloat); if (number.back() == 'f') { number.pop_back(); @@ -233,10 +233,8 @@ struct Tokenizer { Next(); // todo: add error handling around bad indices auto index = ParseNumber(true, false, str_index.str()).ToNumber(); - int end_line = this->line; - int end_column = this->col; - auto span = Span(this->source_name, line, column, end_line, end_column); - return Token(span, TokenType::MetaReference, MetaRef(type_key.str(), index)); + auto span = SpanFrom(line, column); + return Token(span, TokenType::kMetaReference, MetaRef(type_key.str(), index)); } Token TokenizeAttr() { @@ -266,14 +264,14 @@ struct Tokenizer { } ObjectRef metadata_map = tvm::LoadJSON(metadata.str()); auto span = SpanFrom(line, column); - return Token(span, TokenType::Metadata, metadata_map); + return Token(span, TokenType::kMetadata, metadata_map); } if (attribute.rfind("version", 0) == 0) { std::string version = attribute.substr(attribute.find("=") + 1); ltrim(version); rtrim(version); auto span = SpanFrom(line, column); - return Token(span, TokenType::Version, tvm::String(version)); + return Token(span, TokenType::kVersion, tvm::String(version)); } else { // TOOD(@jroesch): maybe make this a warning an continue parsing? 
auto span = SpanFrom(line, column); @@ -296,13 +294,13 @@ struct Tokenizer { auto next = Peek(); DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); Next(); return token; } else if (next == '\r') { Next(); if (More() && Peek() == '\n') { - auto token = NewToken(TokenType::Newline); + auto token = NewToken(TokenType::kNewline); return token; } else { auto span = SpanFrom(line, col); @@ -320,9 +318,9 @@ struct Tokenizer { string_content << Next(); } Next(); - return NewToken(TokenType::StringLiteral, tvm::String(string_content.str())); + return NewToken(TokenType::kStringLiteral, tvm::String(string_content.str())); } else if (IsWhitespace(next)) { - auto token = NewToken(TokenType::Whitespace); + auto token = NewToken(TokenType::kWhitespace); Next(); return token; } else if (IsDigit(next) || next == '-') { @@ -336,7 +334,7 @@ struct Tokenizer { // with multi-token return or something. 
if (negs && !IsDigit(Peek())) { pos = pos - (negs - 1); - return NewToken(TokenType::Minus); + return NewToken(TokenType::kMinus); } bool is_neg = negs % 2 == 1; @@ -354,79 +352,79 @@ struct Tokenizer { return ParseNumber(!is_neg, is_float, ss.str()); } else if (next == '.') { - auto token = NewToken(TokenType::Period); + auto token = NewToken(TokenType::kPeriod); Next(); return token; } else if (next == ',') { - auto token = NewToken(TokenType::Comma); + auto token = NewToken(TokenType::kComma); Next(); return token; } else if (next == '=') { - auto token = NewToken(TokenType::Equal); + auto token = NewToken(TokenType::kEqual); Next(); return token; } else if (next == ';') { - auto token = NewToken(TokenType::Semicolon); + auto token = NewToken(TokenType::kSemicolon); Next(); return token; } else if (next == ':') { - auto token = NewToken(TokenType::Colon); + auto token = NewToken(TokenType::kColon); Next(); return token; } else if (next == '(') { - auto token = NewToken(TokenType::OpenParen); + auto token = NewToken(TokenType::kOpenParen); Next(); return token; } else if (next == ')') { - auto token = NewToken(TokenType::CloseParen); + auto token = NewToken(TokenType::kCloseParen); Next(); return token; } else if (next == '+') { - auto token = NewToken(TokenType::Plus); + auto token = NewToken(TokenType::kPlus); Next(); return token; } else if (next == '-') { - auto token = NewToken(TokenType::Minus); + auto token = NewToken(TokenType::kMinus); Next(); return token; } else if (next == '*') { - auto token = NewToken(TokenType::Star); + auto token = NewToken(TokenType::kStar); Next(); return token; } else if (next == '<') { - auto token = NewToken(TokenType::LAngle); + auto token = NewToken(TokenType::kLAngle); Next(); return token; } else if (next == '>') { - auto token = NewToken(TokenType::RAngle); + auto token = NewToken(TokenType::kRAngle); Next(); return token; } else if (next == '{') { - auto token = NewToken(TokenType::LCurly); + auto token = 
NewToken(TokenType::kLCurly); Next(); return token; } else if (next == '}') { - auto token = NewToken(TokenType::RCurly); + auto token = NewToken(TokenType::kRCurly); Next(); return token; } else if (next == '[') { - auto token = NewToken(TokenType::LSquare); + auto token = NewToken(TokenType::kLSquare); Next(); return token; } else if (next == ']') { - auto token = NewToken(TokenType::RSquare); + auto token = NewToken(TokenType::kRSquare); Next(); return token; } else if (next == '!') { - auto token = NewToken(TokenType::Bang); + auto token = NewToken(TokenType::kBang); Next(); return token; } else if (next == '@') { - auto token = NewToken(TokenType::At); + auto token = NewToken(TokenType::kAt); Next(); return token; } else if (next == '?') { - auto token = NewToken(TokenType::Question); + auto token = NewToken(TokenType::kQuestion); Next(); return token; } else if (MatchString("meta")) { @@ -434,7 +432,7 @@ struct Tokenizer { } else if (next == '#') { return TokenizeAttr(); } else if (next == '%') { - auto token = NewToken(TokenType::Percent); + auto token = NewToken(TokenType::kPercent); Next(); std::stringstream number; @@ -446,14 +444,14 @@ struct Tokenizer { if (number_str.size()) { auto num_tok = ParseNumber(true, false, number_str); auto span = SpanFrom(token->span->line, token->span->column); - token = Token(span, TokenType::Graph, num_tok->data); + token = Token(span, TokenType::kGraph, num_tok->data); } return token; } else if (next == '/') { Next(); if (Peek() == '/') { - auto token = NewToken(TokenType::LineComment); + auto token = NewToken(TokenType::kLineComment); // Consume the / Next(); std::stringstream comment; @@ -467,10 +465,10 @@ struct Tokenizer { Next(); std::string comment; MatchComment(&comment); - auto token = NewToken(TokenType::Comment, tvm::String(comment)); + auto token = NewToken(TokenType::kComment, tvm::String(comment)); return token; } else { - return NewToken(TokenType::Division); + return NewToken(TokenType::kDivision); } } 
else if (IsIdentLetter(next)) { std::stringstream ss; @@ -491,14 +489,14 @@ struct Tokenizer { if (it != KEYWORD_TABLE.end()) { token_type = it->second; - if (token_type == TokenType::Match) { + if (token_type == TokenType::kMatch) { if (More() && Peek() == '?') { Next(); - token_type = TokenType::PartialMatch; + token_type = TokenType::kPartialMatch; } } } else { - token_type = TokenType::Identifier; + token_type = TokenType::kIdentifier; } auto span = SpanFrom(line, col); @@ -508,7 +506,7 @@ struct Tokenizer { while (More() && !IsWhitespace(Peek())) { ss << Next(); } - auto token = NewToken(TokenType::Unknown); + auto token = NewToken(TokenType::kUnknown); token->data = tvm::String(ss.str()); return token; } @@ -521,7 +519,7 @@ struct Tokenizer { CHECK(token.defined()); this->tokens.push_back(token); } - this->tokens.push_back(NewToken(TokenType::EndOfFile)); + this->tokens.push_back(NewToken(TokenType::kEndOfFile)); } explicit Tokenizer(DiagnosticContext* ctx, const SourceName& source_name, @@ -541,18 +539,18 @@ std::vector Condense(const std::vector& tokens) { for (size_t i = 0; i < tokens.size(); i++) { auto current = tokens.at(i); switch (current->token_type) { - case TokenType::Percent: { + case TokenType::kPercent: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. 
i += 1; // TODO(@jroesch): merge spans - auto tok = Token(current->span, TokenType::Local, next->data); + auto tok = Token(current->span, TokenType::kLocal, next->data); CHECK(tok.defined()); out.push_back(tok); - } else if (next->token_type == TokenType::Integer) { + } else if (next->token_type == TokenType::kInteger) { i += 1; - auto tok = Token(current->span, TokenType::Graph, next->data); + auto tok = Token(current->span, TokenType::kGraph, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -561,13 +559,13 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::At: { + case TokenType::kAt: { auto next = tokens.at(i + 1); - if (next->token_type == TokenType::Identifier) { + if (next->token_type == TokenType::kIdentifier) { // Match this token. i += 1; // TODO(@jroesch): merge spans - auto tok = Token(current->span, TokenType::Global, next->data); + auto tok = Token(current->span, TokenType::kGlobal, next->data); CHECK(tok.defined()); out.push_back(tok); } else { @@ -576,18 +574,18 @@ std::vector Condense(const std::vector& tokens) { } continue; } - case TokenType::Identifier: { + case TokenType::kIdentifier: { std::string str = Downcast(current->data); Token tok; // TODO(@jroesch): merge spans if (str == "True") { auto data = tvm::Integer(1); - tok = Token(current->span, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "False") { auto data = tvm::Integer(0); - tok = Token(current->span, TokenType::Boolean, data); + tok = Token(current->span, TokenType::kBoolean, data); } else if (str == "_") { - tok = Token(current->span, TokenType::Underscore); + tok = Token(current->span, TokenType::kUnderscore); } else { tok = current; } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 90cf428f1ca1..1b09052a63d8 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -275,7 +275,7 @@ Doc 
RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c01457745e97..b65b03c38063 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -365,7 +365,8 @@ class TextPrinter { /*! \brief whether show meta data */ bool show_meta_data_; - /*! \brief whether show meta data */ + + /*! \brief whether show the meta data warning message */ bool show_warning_; /*! \brief meta data context */ diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2a88c0c99ae7..cd677096b2e9 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -22,20 +22,22 @@ from tvm.relay import Expr from tvm.relay.analysis import free_vars -do_print = [False] +do_print = [True] -SEMVER = "v0.0.4\n" +SEMVER = "#[version = \"0.0.4\"]\n" -def astext(p, unify_free_vars=False): - txt = p.astext() - if isinstance(p, Expr) and free_vars(p): - return txt - x = relay.fromtext(txt) - if unify_free_vars: - tvm.ir.assert_structural_equal(x, p, map_free_vars=True) +def astext(program, unify_free_vars=False): + text = program.astext() + print(text) + + if isinstance(program, Expr): + roundtrip_program = tvm.parser.parse_expr(text) else: - tvm.ir.assert_structural_equal(x, p) - return txt + roundtrip_program = tvm.parser.fromtext(text) + + tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True) + + return text def show(text): if do_print[0]: @@ -252,23 +254,23 @@ def test_null_attribute(): if __name__ == "__main__": 
do_print[0] = True test_lstm() - test_zeros() - test_meta_data() - test_let_inlining() - test_resnet() - test_mobilenet() - test_mlp() - test_dqn() - test_dcgan() - test_squeezenet() - test_inception_v3() - test_vgg() - test_densenet() - test_func() - test_env() - test_call_attrs() - test_let_if_scope() - test_variable_name() - test_call_node_order() - test_unapplied_constructor() - test_null_attribute() + # test_zeros() + # test_meta_data() + # test_let_inlining() + # test_resnet() + # test_mobilenet() + # test_mlp() + # test_dqn() + # test_dcgan() + # test_squeezenet() + # test_inception_v3() + # test_vgg() + # test_densenet() + # test_func() + # test_env() + # test_call_attrs() + # test_let_if_scope() + # test_variable_name() + # test_call_node_order() + # test_unapplied_constructor() + # test_null_attribute() From 2e7bdbd56bd9644274e06507c1a1ca91574ff306 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Aug 2020 14:59:27 -0700 Subject: [PATCH 41/48] WIP --- src/parser/parser.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 002ff8da5ffe..dcc98169c191 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -42,6 +42,9 @@ namespace parser { using namespace relay; using Expr = relay::Expr; +/*! \brief The meta table maps from type key to a sequence of objects. */ +using MetaTable = Map>; + /*! \brief A wrapper structure for capturing the result of parsing * a global definition *before* we add it to the IRModule. * @@ -262,14 +265,18 @@ class Parser { /*! \brief The set of expression scopes used for lexical scope. */ ScopeStack expr_scopes; + /*! \brief The metadata section. 
*/ + MetaTable meta_table; + Parser(DiagnosticContext* ctx, const SourceName& source_name, std::vector tokens, - OperatorTable op_table, Source source) + OperatorTable op_table, Source source, MetaTable table) : diag_ctx(ctx), source_name(source_name), pos(0), tokens(tokens), op_table(op_table), - ignore_whitespace(true) {} + ignore_whitespace(true), + meta_table(table) {} /*! \brief Examine the next token in the stream, the current parser is configured to be * whitespace insensitive so we will skip all whitespace or comment tokens. */ @@ -510,6 +517,10 @@ class Parser { return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } + Object ParseMetaRef() { + Consume(TokenType::kMetaReference); + LOG(FATAL) << "implement me"; + } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and * ending with a stop token. * @@ -1342,6 +1353,7 @@ class Parser { return LookupGraphBinding(next); } case TokenType::kMetaReference: { + return Downcast(ParseMetaRef()); Consume(TokenType::kMetaReference); return Downcast(next->data); } From e3e5ef880b4504017b89d296c5f257eaff29cd70 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Aug 2020 17:46:21 -0700 Subject: [PATCH 42/48] Fix ir_text_printer --- include/tvm/parser/source_map.h | 9 +- src/parser/parser.cc | 109 +++++++++++++++++---- src/parser/source_map.cc | 2 +- src/parser/token.h | 7 +- src/parser/tokenizer.h | 19 +++- tests/python/relay/test_ir_parser.py | 7 +- tests/python/relay/test_ir_text_printer.py | 57 ++++------- 7 files changed, 139 insertions(+), 71 deletions(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index ca0b7163aa66..1e897b5444c9 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -41,18 +41,21 @@ namespace parser { * source of a TVM program. */ struct Source { + /*! \brief The source name. */ + SourceName source_name; + /*! \brief The raw source. */ std::string source; /*! 
\brief A mapping of line breaks into the raw source. */ std::vector> line_map; /*! \brief An empty source. */ - Source() : source(), line_map() {} + Source() : source_name(), source(), line_map() {} /*! \brief Construct a source from a string. */ - TVM_DLL explicit Source(const std::string& source); + TVM_DLL explicit Source(const SourceName& src_name, const std::string& source); - TVM_DLL Source(const Source& source) : source(source.source), line_map(source.line_map) {} + TVM_DLL Source(const Source& source) : source_name(source.source_name), source(source.source), line_map(source.line_map) {} /*! \brief Generate an error message at a specific line and column with the * annotated message. diff --git a/src/parser/parser.cc b/src/parser/parser.cc index dcc98169c191..a68d69cd5a0d 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -517,9 +517,29 @@ class Parser { return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); } - Object ParseMetaRef() { - Consume(TokenType::kMetaReference); - LOG(FATAL) << "implement me"; + ObjectRef ParseMetaRef() { + auto meta_ref = Match(TokenType::kMetaReference); + Call ref = Downcast(meta_ref->data); + auto attrs = ref->attrs.as(); + auto type_key = attrs->node_type_key; + auto index = attrs->node_index; + auto it = this->meta_table.find(type_key); + if (it != this->meta_table.end()) { + auto nodes = (*it).second; + if (index < nodes.size()) { + return nodes[index]; + } else { + this->diag_ctx->Emit( + Diagnostic::Error(meta_ref->span) + << "the node index `" << index << "` is out of bounds for `" << type_key << "`"); + return ObjectRef(); + } + } else { + this->diag_ctx->Emit( + Diagnostic::Error(meta_ref->span) + << "no entry in the meta table for `" << type_key << "`"); + return ObjectRef(); + } } /*! \brief Parses a sequence beginning with a start token, seperated by a seperator token, and * ending with a stop token. 
@@ -607,8 +627,7 @@ class Parser { auto mod = IRModule({}, types); for (auto func : defs.funcs) { - auto function = ExpandMetaRefs(metadata, func.function); - mod->Add(func.global, function); + mod->Add(func.global, func.function); } return mod; @@ -801,8 +820,14 @@ class Parser { case TokenType::kFreeVar: { Consume(TokenType::kFreeVar); auto var_token = Match(TokenType::kLocal); - Match(TokenType::kColon); - auto type = ParseType(); + + Type type; + if (WhenMatch(TokenType::kColon)) { + type = ParseType(); + } else { + type = IncompleteType(); + } + BindFreeVar(var_token.ToString(), type); break; } @@ -950,7 +975,7 @@ class Parser { /*! Parse a function definition without a leading keyword or identifier. * - * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }. + * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. */ Function ParseFunctionDef() { DLOG(INFO) << "Parser::ParseFunctionDef"; @@ -968,6 +993,8 @@ class Parser { }); } + Map raw_attrs; + auto params = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { auto token = Match(TokenType::kLocal); @@ -977,6 +1004,16 @@ class Parser { type = ParseType(); } return BindVar(string, type); + }, [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } + + return false; }); Type ret_type; @@ -990,7 +1027,12 @@ class Parser { PopTypeScopes(1); PopScopes(1); - return relay::Function(params, body, ret_type, generics); + // TODO(@jroesch): attributes should never be null, they should always be empty. + if (raw_attrs.size()) { + return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + } else { + return relay::Function(params, body, ret_type, generics); + } } /*! \brief Parse an if-expression. 
*/ @@ -1170,6 +1212,22 @@ class Parser { return ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { return ParseAttributeValue(); }); } + case TokenType::kOpenParen: { + // TODO(@jroesch: need to figure out bracket vs. sequence) + // return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + // [&]() { return ParseAttributeValue(); }); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, [&]() { return ParseAttributeValue(); }); + } + // TODO(@jroesch): not sure about this being the right way to handle nulls. + case TokenType::kIdentifier: { + if (auto text = next->data.as()) { + std::string id = GetRef(text); + if (id == "nullptr") { + Match(TokenType::kIdentifier); + return ObjectRef(); + } + } + } default: return ParseAtomicExpr(); } @@ -1278,6 +1336,7 @@ class Parser { } Expr GetOp(const std::string& op_name, const Token& tok) { + DLOG(INFO) << "op_name=" << op_name << " token=" << tok; try { return Op::Get(op_name); } catch (dmlc::Error e) { @@ -1335,6 +1394,7 @@ class Parser { return Expr(ctor.value()); } else { auto idents = ParseHierName(); + CHECK_NE(idents.size(), 0); std::stringstream op_name; int i = 0; int periods = idents.size() - 1; @@ -1354,8 +1414,6 @@ class Parser { } case TokenType::kMetaReference: { return Downcast(ParseMetaRef()); - Consume(TokenType::kMetaReference); - return Downcast(next->data); } case TokenType::kFn: { Consume(TokenType::kFn); @@ -1408,7 +1466,8 @@ class Parser { Array ParseHierName() { Array idents; while (Peek()->token_type == TokenType::kIdentifier) { - idents.push_back(Peek().ToString()); + auto name = Peek().ToString(); + idents.push_back(name); Consume(TokenType::kIdentifier); if (Peek()->token_type == TokenType::kPeriod) { @@ -1426,8 +1485,14 @@ class Parser { Array ParseShape() { auto dims = ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { - auto tok = Match(TokenType::kInteger); - return 
Downcast(tok->data); + tvm::PrimExpr dim; + if (Peek()->token_type == TokenType::kMetaReference) { + dim = Downcast(ParseMetaRef()); + } else { + dim = Downcast(Match(TokenType::kInteger)->data); + } + + return dim; }); return dims; } @@ -1565,10 +1630,12 @@ class Parser { IRModule ParseModule(std::string file_name, std::string file_content) { DLOG(INFO) << "ParseModule"; SourceName src_name = SourceName::Get(file_name); - Source src(file_content); + Source src(src_name, file_content); DiagnosticContext ctx(src); - auto tokens = Tokenize(&ctx, src_name, file_content); - Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content)); + auto tokens_and_table = Tokenize(&ctx, src_name, file_content); + auto tokens = tokens_and_table.first; + auto meta_data_table = tokens_and_table.second; + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); auto mod = parser.ParseModule(); // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them @@ -1580,10 +1647,12 @@ IRModule ParseModule(std::string file_name, std::string file_content) { Expr ParseExpr(std::string file_name, std::string file_content) { DLOG(INFO) << "ParseExpr"; SourceName src_name = SourceName::Get(file_name); - Source src(file_content); + Source src(src_name, file_content); DiagnosticContext ctx(src); - auto tokens = Tokenize(&ctx, src_name, file_content); - Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content)); + auto tokens_and_table = Tokenize(&ctx, src_name, file_content); + auto tokens = tokens_and_table.first; + auto meta_data_table = tokens_and_table.second; + Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata()); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index beb32da7126c..549f7d33738e 
100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -27,7 +27,7 @@ namespace tvm { namespace parser { /*! \brief Construct a source from a string. */ -Source::Source(const std::string& source) : source(source) { +Source::Source(const SourceName& src_name, const std::string& source) : source_name(src_name), source(source) { int index = 0; int length = 0; line_map.push_back({index, length}); diff --git a/src/parser/token.h b/src/parser/token.h index 86a26cbada52..3750ec568cc8 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -378,7 +378,12 @@ int64_t Token::ToNumber() const { return Downcast(this->operator-> std::string Token::ToString() const { return Downcast(this->operator->()->data); } Map> Token::ToMetadata() const { - return Downcast>>(this->operator->()->data); + ObjectRef data = this->operator->()->data; + if (data.defined()) { + return Downcast>>(data); + } else { + return Map>({}); + } } } // namespace parser diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index f500d4ac6d58..7357106da41c 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -533,12 +533,22 @@ struct Tokenizer { tokens() {} }; -std::vector Condense(const std::vector& tokens) { +std::vector Condense(const std::vector& tokens, Token* table) { std::vector out; + bool found_metadata = false; for (size_t i = 0; i < tokens.size(); i++) { auto current = tokens.at(i); switch (current->token_type) { + case TokenType::kMetadata: { + if (!found_metadata) { + found_metadata = true; + *table = current; + } else { + LOG(FATAL) << "duplicate metadata section"; + } + continue; + } case TokenType::kPercent: { auto next = tokens.at(i + 1); if (next->token_type == TokenType::kIdentifier) { @@ -602,15 +612,16 @@ std::vector Condense(const std::vector& tokens) { return out; } -std::vector Tokenize(DiagnosticContext* ctx, const SourceName& source_name, +std::pair, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name, const std::string& 
source) { auto tokenizer = Tokenizer(ctx, source_name, source); tokenizer.Tokenize(); - auto tokens = Condense(tokenizer.tokens); + Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); + auto tokens = Condense(tokenizer.tokens, &meta_table); for (auto token : tokens) { CHECK(token.defined()); } - return tokens; + return { tokens, meta_table }; } } // namespace parser diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 2ec24cc50c25..50b87d2b94b0 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -234,9 +234,10 @@ def test_vars(): assert op.name == "nn.global_avg_pool2d" def test_meta_ref(): - meta_op = parse_text("meta[type_key][1337]") - assert meta_op.attrs.node_type_key == "type_key" - assert meta_op.attrs.node_index == 1337 + with pytest.raises(tvm.error.DiagnosticError): + meta_op = parse_text("meta[type_key][1337]") + assert meta_op.attrs.node_type_key == "type_key" + assert meta_op.attrs.node_index == 1337 def test_let(): diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index cd677096b2e9..52551bf68e77 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -17,19 +17,18 @@ import tvm from tvm import te from tvm import relay -import tvm.relay.testing +from tvm.relay import testing import numpy as np from tvm.relay import Expr from tvm.relay.analysis import free_vars -do_print = [True] +DEBUG_PRINT = False -SEMVER = "#[version = \"0.0.4\"]\n" +SEMVER = "#[version = \"0.0.5\"]\n" def astext(program, unify_free_vars=False): text = program.astext() print(text) - if isinstance(program, Expr): roundtrip_program = tvm.parser.parse_expr(text) else: @@ -40,7 +39,7 @@ def astext(program, unify_free_vars=False): return text def show(text): - if do_print[0]: + if DEBUG_PRINT: print("---------------------------") print(text) @@ -137,55 +136,55 @@ def test_variable_name(): def 
test_mlp(): - net, params = tvm.relay.testing.mlp.get_workload(batch_size=1) + net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1) astext(net) def test_resnet(): - net, params = tvm.relay.testing.resnet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1) astext(net) def test_mobilenet(): - net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1) astext(net) def test_dqn(): - net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) + net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1) astext(net) def test_dcgan(): - net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1) + net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1) astext(net) def test_lstm(): - net, params = tvm.relay.testing.lstm.get_workload(1, 1) + net, _ = tvm.relay.testing.lstm.get_workload(1, 1) astext(net) - net, params = tvm.relay.testing.lstm.get_workload(4, 4) + net, _ = tvm.relay.testing.lstm.get_workload(4, 4) astext(net) def test_inception_v3(): - net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) + net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1) astext(net) def test_squeezenet(): for version in ['1.0', '1.1']: - net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) + net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) astext(net) def test_vgg(): - net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) + net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1) astext(net) def test_densenet(): - net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) + net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1) astext(net) @@ -234,7 +233,7 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { Cons } """ - mod = relay.fromtext(SEMVER + type_def_str + main_def_str) + mod = tvm.parser.parse(SEMVER + type_def_str + main_def_str) 
mod_str = str(mod) # ensure constructors are printed correctly in type definitions (with their # signature) and as exprs (without their signature) @@ -252,25 +251,5 @@ def test_null_attribute(): if __name__ == "__main__": - do_print[0] = True - test_lstm() - # test_zeros() - # test_meta_data() - # test_let_inlining() - # test_resnet() - # test_mobilenet() - # test_mlp() - # test_dqn() - # test_dcgan() - # test_squeezenet() - # test_inception_v3() - # test_vgg() - # test_densenet() - # test_func() - # test_env() - # test_call_attrs() - # test_let_if_scope() - # test_variable_name() - # test_call_node_order() - # test_unapplied_constructor() - # test_null_attribute() + import sys + pytext.argv(sys.argv) From 2a8d67a0be0003f9d968f8cb2a5dbd2468a52b1e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Aug 2020 17:47:02 -0700 Subject: [PATCH 43/48] format --- include/tvm/parser/source_map.h | 3 +- src/ir/span.cc | 7 +- src/parser/parser.cc | 116 ++++++++++++++++---------------- src/parser/source_map.cc | 3 +- src/parser/tokenizer.h | 12 ++-- 5 files changed, 74 insertions(+), 67 deletions(-) diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 1e897b5444c9..339417eeb84e 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -55,7 +55,8 @@ struct Source { /*! \brief Construct a source from a string. */ TVM_DLL explicit Source(const SourceName& src_name, const std::string& source); - TVM_DLL Source(const Source& source) : source_name(source.source_name), source(source.source), line_map(source.line_map) {} + TVM_DLL Source(const Source& source) + : source_name(source.source_name), source(source.source), line_map(source.line_map) {} /*! \brief Generate an error message at a specific line and column with the * annotated message. 
diff --git a/src/ir/span.cc b/src/ir/span.cc index e936feae1723..e67e7046efb7 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -81,14 +81,15 @@ Span Span::Merge(const Span& other) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, int column, - int end_column) { +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, + int column, int end_column) { return Span(source, line, end_line, column, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; + p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " + << node->column << ", " << node->end_column << ")"; }); } // namespace tvm diff --git a/src/parser/parser.cc b/src/parser/parser.cc index a68d69cd5a0d..19774c66dd65 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -125,9 +125,7 @@ class ScopeStack { this->scope_stack.back().name_map.insert({name, value}); } - void AddFreeVar(const std::string& name, const T& value) { - free_vars.insert({name, value}); - } + void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } /*! \brief Looks up a variable name in the scope stack returning the matching variable * in most recent scope. 
*/ @@ -416,8 +414,8 @@ class Parser { Var LookupLocal(const Token& local) { auto var = this->expr_scopes.Lookup(local.ToString()); if (!var.defined()) { - diag_ctx->Emit(Diagnostic::Error(local->span) << - "this local variable has not been previously declared"); + diag_ctx->Emit(Diagnostic::Error(local->span) + << "this local variable has not been previously declared"); } return var; } @@ -429,8 +427,9 @@ class Parser { TypeVar LookupTypeVar(const Token& ident) { auto var = this->type_scopes.Lookup(ident.ToString()); if (!var.defined()) { - diag_ctx->Emit(Diagnostic::Error(ident->span) - << "this type variable has not been previously declared anywhere, perhaps a typo?"); + diag_ctx->Emit( + Diagnostic::Error(ident->span) + << "this type variable has not been previously declared anywhere, perhaps a typo?"); } return var; } @@ -529,16 +528,15 @@ class Parser { if (index < nodes.size()) { return nodes[index]; } else { - this->diag_ctx->Emit( - Diagnostic::Error(meta_ref->span) - << "the node index `" << index << "` is out of bounds for `" << type_key << "`"); + this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) + << "the node index `" << index << "` is out of bounds for `" + << type_key << "`"); return ObjectRef(); } } else { - this->diag_ctx->Emit( - Diagnostic::Error(meta_ref->span) - << "no entry in the meta table for `" << type_key << "`"); - return ObjectRef(); + this->diag_ctx->Emit(Diagnostic::Error(meta_ref->span) + << "no entry in the meta table for `" << type_key << "`"); + return ObjectRef(); } } /*! 
\brief Parses a sequence beginning with a start token, seperated by a seperator token, and @@ -555,7 +553,8 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) << "stop=" << ToString(stop); + DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << "sep=" << ToString(sep) + << "stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream @@ -717,8 +716,8 @@ class Parser { // If we have generics we need to add a type scope. PushTypeScope(); should_pop = true; - generics = - ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + generics = ParseSequence( + TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); @@ -739,8 +738,8 @@ class Parser { ctor = tvm::Constructor(ctor_name, {}, type_global); } else { auto arg_types = - ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, - [&]() { return ParseType(); }); + ParseSequence(TokenType::kOpenParen, TokenType::kComma, + TokenType::kCloseParen, [&]() { return ParseType(); }); ctor = tvm::Constructor(ctor_name, arg_types, type_global); } @@ -986,8 +985,8 @@ class Parser { if (Peek()->token_type == TokenType::kLSquare) { // If we have generics we need to add a type scope. 
PushTypeScope(); - generics = - ParseSequence(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { + generics = ParseSequence( + TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); return BindTypeVar(type_var_name, TypeKind::kType); }); @@ -995,8 +994,9 @@ class Parser { Map raw_attrs; - auto params = - ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { + auto params = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + [&]() { auto token = Match(TokenType::kLocal); auto string = token.ToString(); Type type; @@ -1004,16 +1004,17 @@ class Parser { type = ParseType(); } return BindVar(string, type); - }, [&] { - auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; - auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; - - if (is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); - return true; - } + }, + [&] { + auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; + auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; + + if (is_ident && next_is_equal) { + raw_attrs = ParseAttrs(); + return true; + } - return false; + return false; }); Type ret_type; @@ -1214,9 +1215,11 @@ class Parser { } case TokenType::kOpenParen: { // TODO(@jroesch: need to figure out bracket vs. sequence) - // return ParseSequence(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, + // return ParseSequence(TokenType::kOpenParen, TokenType::kComma, + // TokenType::kCloseParen, // [&]() { return ParseAttributeValue(); }); - return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, [&]() { return ParseAttributeValue(); }); + return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, + [&]() { return ParseAttributeValue(); }); } // TODO(@jroesch): not sure about this being the right way to handle nulls. 
case TokenType::kIdentifier: { @@ -1241,7 +1244,8 @@ class Parser { Match(TokenType::kEqual); // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. auto value = ParseAttributeValue(); - // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text format is bad. + // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text + // format is bad. kwargs.Set(key, value); WhenMatch(TokenType::kComma); } @@ -1280,9 +1284,9 @@ class Parser { Attrs attrs; if (is_op && op_key.size()) { - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - CHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + CHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); } return Expr(Call(op, args, attrs, {})); @@ -1290,10 +1294,9 @@ class Parser { return Expr(); } } catch (...) { - // TODO(@jroesch): AttrErrors should have fields - this->diag_ctx->Emit( - Diagnostic::Error(Peek()->span)); - // << err.what()); + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->Emit(Diagnostic::Error(Peek()->span)); + // << err.what()); } return Expr(); @@ -1317,10 +1320,9 @@ class Parser { break; } } catch (...) { - // TODO(@jroesch): AttrErrors should have fields - this->diag_ctx->EmitFatal( - Diagnostic::Error(Peek()->span)); - // << err.what()); + // TODO(@jroesch): AttrErrors should have fields + this->diag_ctx->EmitFatal(Diagnostic::Error(Peek()->span)); + // << err.what()); } } @@ -1483,17 +1485,17 @@ class Parser { /*! \brief Parse a shape. 
*/ Array ParseShape() { - auto dims = ParseSequence(TokenType::kOpenParen, TokenType::kComma, - TokenType::kCloseParen, [&]() { - tvm::PrimExpr dim; - if (Peek()->token_type == TokenType::kMetaReference) { - dim = Downcast(ParseMetaRef()); - } else { - dim = Downcast(Match(TokenType::kInteger)->data); - } - - return dim; - }); + auto dims = ParseSequence( + TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { + tvm::PrimExpr dim; + if (Peek()->token_type == TokenType::kMetaReference) { + dim = Downcast(ParseMetaRef()); + } else { + dim = Downcast(Match(TokenType::kInteger)->data); + } + + return dim; + }); return dims; } diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 549f7d33738e..a2efdb5a88fd 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -27,7 +27,8 @@ namespace tvm { namespace parser { /*! \brief Construct a source from a string. */ -Source::Source(const SourceName& src_name, const std::string& source) : source_name(src_name), source(source) { +Source::Source(const SourceName& src_name, const std::string& source) + : source_name(src_name), source(source) { int index = 0; int length = 0; line_map.push_back({index, length}); diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 7357106da41c..6c9349d66817 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -66,9 +66,11 @@ bool IsIdentLetter(char c) { return '_' == c || ('a' <= c && c <= 'z') || ('A' < bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { - {"let", TokenType::kLet}, {"fn", TokenType::kFn}, {"def", TokenType::kDefn}, - {"if", TokenType::kIf}, {"else", TokenType::kElse}, {"type", TokenType::kTypeDef}, - {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, {"free_var", TokenType::kFreeVar}}; + {"let", TokenType::kLet}, {"fn", TokenType::kFn}, + {"def", TokenType::kDefn}, {"if", TokenType::kIf}, + {"else", TokenType::kElse}, {"type", 
TokenType::kTypeDef}, + {"match", TokenType::kMatch}, {"extern", TokenType::kExtern}, + {"free_var", TokenType::kFreeVar}}; struct Tokenizer { DiagnosticContext* diag_ctx; @@ -613,7 +615,7 @@ std::vector Condense(const std::vector& tokens, Token* table) { } std::pair, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name, - const std::string& source) { + const std::string& source) { auto tokenizer = Tokenizer(ctx, source_name, source); tokenizer.Tokenize(); Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); @@ -621,7 +623,7 @@ std::pair, Token> Tokenize(DiagnosticContext* ctx, const Sour for (auto token : tokens) { CHECK(token.defined()); } - return { tokens, meta_table }; + return {tokens, meta_table}; } } // namespace parser From 470bd8687d85170de4d2c4e340939663d770eb71 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Aug 2020 18:36:02 -0700 Subject: [PATCH 44/48] Formatted --- python/tvm/error.py | 5 ++++- src/parser/diagnostic.h | 30 ++++++++++++------------------ src/parser/parser.cc | 34 +++++++++++++++++++++++----------- src/parser/tokenizer.h | 1 + 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/python/tvm/error.py b/python/tvm/error.py index 398d22ada25c..9125448e30c9 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -124,4 +124,7 @@ class OpAttributeUnImplemented(OpError, NotImplementedError): @register_error class DiagnosticError(TVMError): - pass + """Error diagnostics were reported during the execution of a pass. + + See the configured diagnostic renderer for detailed error information. + """ diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h index e43c4b295501..2eb38b312242 100644 --- a/src/parser/diagnostic.h +++ b/src/parser/diagnostic.h @@ -44,12 +44,12 @@ namespace tvm { namespace parser { /*! \brief The diagnostic level, controls the printing of the message. 
*/ -enum DiagnosticLevel { - Bug, - Error, - Warning, - Note, - Help, +enum class DiagnosticLevel { + kBug, + kError, + kWarning, + kNote, + kHelp, }; struct DiagnosticBuilder; @@ -63,12 +63,6 @@ struct Diagnostic { /*! \brief The diagnostic message. */ std::string message; - /*! \brief A diagnostic for a single character token. */ - Diagnostic(int line, int column, const std::string& message) - : level(DiagnosticLevel::Error), - span(SourceName(), line, column, line, column + 1), - message(message) {} - Diagnostic(DiagnosticLevel level, Span span, const std::string& message) : level(level), span(span), message(message) {} @@ -110,7 +104,7 @@ struct DiagnosticBuilder { return *this; } - DiagnosticBuilder() : level(DiagnosticLevel::Error), source_name(), span(Span()) {} + DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} DiagnosticBuilder(const DiagnosticBuilder& builder) : level(builder.level), source_name(builder.source_name), span(builder.span) {} @@ -125,23 +119,23 @@ struct DiagnosticBuilder { }; DiagnosticBuilder Diagnostic::Bug(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Bug, span); + return DiagnosticBuilder(DiagnosticLevel::kBug, span); } DiagnosticBuilder Diagnostic::Error(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Error, span); + return DiagnosticBuilder(DiagnosticLevel::kError, span); } DiagnosticBuilder Diagnostic::Warning(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Warning, span); + return DiagnosticBuilder(DiagnosticLevel::kWarning, span); } DiagnosticBuilder Diagnostic::Note(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Note, span); + return DiagnosticBuilder(DiagnosticLevel::kNote, span); } DiagnosticBuilder Diagnostic::Help(Span span) { - return DiagnosticBuilder(DiagnosticLevel::Note, span); + return DiagnosticBuilder(DiagnosticLevel::kHelp, span); } /*! \brief A diagnostic context for recording errors against a source file. 
diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 19774c66dd65..468dff2eba3b 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -322,7 +322,7 @@ class Parser { */ void Consume(const TokenType& token_type) { if (tokens[pos]->token_type != token_type) { - this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tokens[pos]->span) + this->diag_ctx->EmitFatal(Diagnostic::Error(tokens[pos]->span) << "expected a " << Pretty(token_type) << " found " << Pretty(Peek()->token_type)); } @@ -598,7 +598,7 @@ class Parser { return elements; } else { auto next = Peek(); - this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, next->span) + this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) << "expected a " << Pretty(stop) << " found " << Pretty(next->token_type)); return Array(nullptr); @@ -638,7 +638,7 @@ class Parser { auto version = Match(TokenType::kVersion); // TODO(@jroesch): we currently only support 0.0.5. if (version.ToString() != "\"0.0.5\"") { - this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, version->span) + this->diag_ctx->Emit(Diagnostic::Error(version->span) << "invalid semantic version `" << version.ToString() << "`"); } } else if (required) { @@ -681,8 +681,8 @@ class Parser { Consume(TokenType::kExtern); auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { - diag_ctx->Emit({DiagnosticLevel::Error, next->span, - "an external type may not have any constructors"}); + diag_ctx->Emit(Diagnostic::Error(next->span) + << "an external type may not have any constructors"); } defs.types.push_back(type_def); } @@ -1342,7 +1342,7 @@ class Parser { try { return Op::Get(op_name); } catch (dmlc::Error e) { - this->diag_ctx->Emit(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) + this->diag_ctx->Emit(Diagnostic::Error(tok->span) << "operator `" << op_name << "` not found, perhaps you forgot to register it?"); return Expr(); @@ -1395,7 +1395,7 @@ class Parser { 
Consume(TokenType::kIdentifier); return Expr(ctor.value()); } else { - auto idents = ParseHierName(); + auto idents = ParseHierarchicalName(); CHECK_NE(idents.size(), 0); std::stringstream op_name; int i = 0; @@ -1448,7 +1448,7 @@ class Parser { } } default: { - this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, next->span) + this->diag_ctx->EmitFatal(Diagnostic::Error(next->span) << "expected an expression found " << Pretty(next->token_type)); return Expr(); @@ -1464,18 +1464,30 @@ class Parser { return expr; } - /*! \brief Parse a hierarchical name. */ - Array ParseHierName() { + /*! \brief Parse a hierarchical name. + * + * The tokenizer produces a token stream of . + * and so on for names of the form `nn.conv2d`. + * Currently we only use string names everywhere instead + * of a notion of a hierarchical name. + * + * The below utility reassembles a token stream into a + * single stream inserting the required periods needed + * to look up registered names. + */ + Array ParseHierarchicalName() { Array idents; while (Peek()->token_type == TokenType::kIdentifier) { auto name = Peek().ToString(); idents.push_back(name); Consume(TokenType::kIdentifier); + // Keep parsing while we see a trailing period. if (Peek()->token_type == TokenType::kPeriod) { Consume(TokenType::kPeriod); continue; } else { + // No more periods means we are done! 
break; } } @@ -1577,7 +1589,7 @@ class Parser { if (WhenMatch(TokenType::kUnderscore)) { return IncompleteType(); } else { - this->diag_ctx->EmitFatal(DiagnosticBuilder(DiagnosticLevel::Error, tok->span) + this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span) << "failed to parse type found " << tok); return Type(); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 6c9349d66817..8fc0e85836fc 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -32,6 +32,7 @@ #include #include #include +#include #include "./meta_ref.h" #include "./token.h" From 43dfe425aa3ad420eebc7f21667a6f8036302117 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 8 Aug 2020 22:24:46 -0700 Subject: [PATCH 45/48] More formatting --- src/parser/parser.cc | 2 +- src/parser/tokenizer.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 468dff2eba3b..7877245725ea 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -682,7 +682,7 @@ class Parser { auto type_def = ParseTypeDef(); if (type_def->constructors.size()) { diag_ctx->Emit(Diagnostic::Error(next->span) - << "an external type may not have any constructors"); + << "an external type may not have any constructors"); } defs.types.push_back(type_def); } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 8fc0e85836fc..88a49290dc3d 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -31,8 +31,8 @@ #include #include #include -#include #include +#include #include "./meta_ref.h" #include "./token.h" From 7fa4c3f819daa2237de185e507bb314e131045a8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 10 Aug 2020 23:26:40 -0700 Subject: [PATCH 46/48] Repair test cases --- include/tvm/ir/span.h | 12 +++++----- include/tvm/parser/source_map.h | 4 ++-- python/tvm/ir/base.py | 4 ++-- src/ir/span.cc | 14 +++++------ tests/python/relay/test_ir_nodes.py | 12 ++++++---- tests/python/relay/test_ir_parser.py | 8 +++---- 
tests/python/relay/test_op_level10.py | 8 +++---- tests/python/relay/test_pass_eta_expand.py | 24 +++++++++---------- .../python/relay/test_pass_unmatched_cases.py | 4 ++-- 9 files changed, 47 insertions(+), 43 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index be8799eaca19..95a1acb9412d 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -80,7 +80,7 @@ class Span; class SpanNode : public Object { public: /*! \brief The source name. */ - SourceName source; + SourceName source_name; /*! \brief The line number. */ int line; /*! \brief The column offset. */ @@ -92,15 +92,15 @@ class SpanNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { - v->Visit("source", &source); + v->Visit("source_name", &source_name); v->Visit("line", &line); v->Visit("column", &column); - v->Visit("end_line", &line); - v->Visit("end_column", &column); + v->Visit("end_line", &end_line); + v->Visit("end_column", &end_column); } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return equal(source, other->source) && equal(line, other->line) && + return equal(source_name, other->source_name) && equal(line, other->line) && equal(column, other->column) && equal(end_line, other->end_line) && equal(end_column, other->end_column); } @@ -111,7 +111,7 @@ class SpanNode : public Object { class Span : public ObjectRef { public: - TVM_DLL Span(SourceName source, int line, int end_line, int column, int end_column); + TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column); /*! \brief Merge two spans into one which captures the combined regions. 
*/ TVM_DLL Span Merge(const Span& other); diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 339417eeb84e..cf926665e218 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -91,8 +91,8 @@ class SourceMapNode : public Object { return equal(source_map, other->source_map); } - static constexpr const char* _type_key = "Span"; - TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); + static constexpr const char* _type_key = "SourceMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); }; class SourceMap : public ObjectRef { diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index bab98382e713..b505a2ee00bb 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -84,9 +84,9 @@ class Span(Object): col_offset : int The column offset of the location. """ - def __init__(self, source, lineno, col_offset): + def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( - _ffi_api.Span, source, lineno, col_offset) + _ffi_api.Span, source_name, line, end_line, column, end_column) @tvm._ffi.register_object diff --git a/src/ir/span.cc b/src/ir/span.cc index e67e7046efb7..ffb44c4c7af9 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -61,9 +61,9 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) return static_cast(n)->name; }); -Span::Span(SourceName source, int line, int column, int end_line, int end_column) { +Span::Span(SourceName source_name, int line, int end_line, int column, int end_column) { auto n = make_object(); - n->source = std::move(source); + n->source_name = std::move(source_name); n->line = line; n->end_line = end_line; n->column = column; @@ -72,8 +72,8 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column } Span Span::Merge(const Span& other) { - CHECK((*this)->source == other->source); - return Span((*this)->source, std::min((*this)->line, other->line), + CHECK((*this)->source_name == other->source_name); + 
return Span((*this)->source_name, std::min((*this)->line, other->line), std::max((*this)->end_line, other->end_line), std::min((*this)->column, other->column), std::max((*this)->end_column, other->end_column)); @@ -81,15 +81,15 @@ Span Span::Merge(const Span& other) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, int column, int end_column) { - return Span(source, line, end_line, column, end_column); + return Span(source_name, line, end_line, column, end_column); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " + p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; }); } // namespace tvm diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index fed257fafd4a..b53423a2f4c1 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -31,10 +31,12 @@ def check_json_roundtrip(node): # Span def test_span(): - span = relay.Span(None, 1, 1) - assert span.source == None + span = relay.Span(None, 1, 2, 3, 4) + assert span.source_name == None assert span.line == 1 - assert span.column == 1 + assert span.end_line == 2 + assert span.column == 3 + assert span.end_column == 4 assert span.same_as(span) assert span == span assert isinstance(span, relay.base.Span) @@ -43,9 +45,11 @@ def test_span(): # span is not a node so we can't use graph_equal # to test the round trip back = tvm.ir.load_json(tvm.ir.save_json(span)) - assert back.source == span.source + assert back.source_name == span.source_name assert back.line == span.line + assert back.end_line == 
span.end_line assert back.column == span.column + assert back.end_column == span.end_column def test_constant(): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 50b87d2b94b0..3fcc7dab5bcd 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -895,9 +895,9 @@ def test_import_grad(): mod.import_from_std("gradient.rly") def test_resnet(): - mod, params = relay.testing.resnet.get_workload() - text = str(mod.astext()) - parsed_mod = parse_module(text) + mod, _ = relay.testing.resnet.get_workload() + text = mod.astext() + parsed_mod = tvm.parser.parse(text) tvm.ir.assert_structural_equal(mod, parsed_mod) def inline_params(mod, params): @@ -919,7 +919,7 @@ def test_resnet_inlined_params(): mod, params = relay.testing.resnet.get_workload() mod = inline_params(mod, params) text = mod.astext() - parsed_mod = parse_module(text) + parsed_mod = tvm.parser.parse(text) tvm.ir.assert_structural_equal(mod, parsed_mod) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index f65407acbcc9..c0a990ba9d2e 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -59,9 +59,9 @@ def test_checkpoint_alpha_equal(): mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] - df_parsed = relay.parser.fromtext( + df_parsed = tvm.parser.parse_expr( """ - v0.0.4 + #[version = "0.0.5"] fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> (Tensor[(1), float32], @@ -115,9 +115,9 @@ def test_checkpoint_alpha_equal_tuple(): mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] - df_parsed = relay.parser.fromtext( + df_parsed = tvm.parser.parse_expr( """ - v0.0.4 + #[version = "0.0.5"] fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], %z: Tensor[(1), float32], %w: Tensor[(1), float32]) -> 
((Tensor[(1), float32], Tensor[(1), float32]), diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index e0a189b5c2ee..05c5f0328e22 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -24,24 +24,24 @@ import tvm.relay.transform as _transform def test_eta_expand_global_var(): - mod = relay.fromtext(r""" - v0.0.4 + mod = tvm.parser.fromtext(r""" + #[version = "0.0.5"] def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } - def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] { @aux } """) seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)]) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) - expected = relay.fromtext(r""" - v0.0.4 + expected = tvm.parser.fromtext(r""" + #[version = "0.0.5"] def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { %x } - def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + def @main() -> fn(Tensor[(), int32]) -> Tensor[(), int32] { fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { @aux(%x) } @@ -52,26 +52,26 @@ def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { def test_eta_expand_constructor(): - mod = relay.fromtext(r""" - v0.0.4 + mod = tvm.parser.fromtext(r""" + #[version = "0.0.5"] type List[A] { Cons(A, List[A]), Nil, } - def @main[A]() -> (fn(A, List[A]) -> List[A]) { + def @main[A]() -> fn(A, List[A]) -> List[A] { Cons } """) seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) with tvm.transform.PassContext(opt_level=3): mod = seq(mod) - expected = relay.fromtext(r""" - v0.0.4 + expected = tvm.parser.fromtext(r""" + #[version = "0.0.5"] type List[A] { Cons(A, List[A]), Nil, } - def @main[A]() -> (fn(A, List[A]) -> List[A]) { + def @main[A]() -> fn(A, List[A]) -> List[A] { fn [A](%x: A, %xs: List[A]) -> List[A] { Cons(%x, %xs) } diff --git 
a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 42344bccabaa..07193e104a7c 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -279,7 +279,7 @@ def test_tuple_match(): def test_inf_loop_case(): code = """ -v0.0.4 +#[version = "0.0.5"] type Arith[A] { Zero, Const(A), @@ -294,7 +294,7 @@ def @shallow_opt[A](%a: Arith[A]) -> Arith[A] { } } """ - relay.fromtext(code) + tvm.parser.fromtext(code) # fromtext parse the module, then checked it (which include strictness checking). if __name__ == "__main__": From dd6108326260559402b4ea919923c1f0a591f041 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 10 Aug 2020 23:27:40 -0700 Subject: [PATCH 47/48] Fix CI --- src/ir/span.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ir/span.cc b/src/ir/span.cc index ffb44c4c7af9..d9c9bbc47c34 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -89,7 +89,7 @@ TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int lin TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line << ", " - << node->column << ", " << node->end_column << ")"; + p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line + << ", " << node->column << ", " << node->end_column << ")"; }); } // namespace tvm From 77397badf8124e80334d41abdf2827ea47283a76 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 11 Aug 2020 11:23:00 -0700 Subject: [PATCH 48/48] Retrigger CI