diff --git a/.gitignore b/.gitignore index 410a36aecdec..04dad2039860 100644 --- a/.gitignore +++ b/.gitignore @@ -209,3 +209,7 @@ tvm_t.* # patch sentinel patched.txt + +# Python type checking +.mypy_cache/ +.pyre/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 98bbc5b650d3..363b2056a87a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) +tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) # include directories include_directories("include") @@ -183,6 +184,7 @@ include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) include(cmake/modules/SGX.cmake) include(cmake/modules/LLVM.cmake) +include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Sort.cmake) diff --git a/Jenkinsfile b/Jenkinsfile index adc9e12ca74b..02f00e42e8fd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -98,6 +98,7 @@ stage('Build') { echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake @@ -133,6 +134,7 @@ stage('Build') { echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake + echo set\\(USE_ANTLR ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake """ diff --git a/cmake/config.cmake b/cmake/config.cmake index a92be7ce3008..a97def410ddd 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -128,3 +128,6 @@ set(USE_ROCBLAS OFF) # Whether use contrib sort set(USE_SORT OFF) + +# Build ANTLR parser for Relay text format +set(USE_ANTLR OFF) diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake new file mode 100644 index 000000000000..72eb5925bda0 --- /dev/null +++ b/cmake/modules/ANTLR.cmake @@ -0,0 +1,28 @@ +if(USE_ANTLR) + if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) + set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") + + set(RELAY_PARSER_DIR + ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) + + set(RELAY_PARSER + ${RELAY_PARSER_DIR}/py2/RelayVisitor.py + ${RELAY_PARSER_DIR}/py2/RelayParser.py + ${RELAY_PARSER_DIR}/py2/RelayLexer.py + + ${RELAY_PARSER_DIR}/py3/RelayVisitor.py + ${RELAY_PARSER_DIR}/py3/RelayParser.py + ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + + # Generate ANTLR grammar for parsing. + add_custom_command(OUTPUT ${RELAY_PARSER} + COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 + WORKING_DIRECTORY ${RELAY_PARSER_DIR}) + + add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) + else() + message(FATAL_ERROR "Can't find ANTLR4!") + endif() +endif(USE_ANTLR) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 11a77adbfdde..e6e2dd7a37b0 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -40,10 +40,3 @@ COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin - -# ANTLR deps -COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh -RUN bash /install/ubuntu_install_java.sh - -COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh -RUN bash /install/ubuntu_install_antlr.sh diff --git a/docker/install/ubuntu_install_antlr.sh b/docker/install/ubuntu_install_antlr.sh index f1066c4220d4..d2f2d6a8c48f 100644 --- a/docker/install/ubuntu_install_antlr.sh +++ b/docker/install/ubuntu_install_antlr.sh @@ -1,5 +1,3 @@ cd /usr/local/lib wget https://www.antlr.org/download/antlr-4.7.1-complete.jar cd - - -alias antlr4='java -jar /usr/local/lib/antlr-4.7.1-complete.jar' diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 6b071f65a794..b66132f27775 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -8,6 +8,7 @@ from . import module from . import ir_pass from .build_module import build, build_config, create_executor +from . import parser # Root operators from .op import Op @@ -52,7 +53,6 @@ If = expr.If TupleGetItem = expr.TupleGetItem - # helper functions var = expr.var const = expr.const @@ -63,3 +63,6 @@ def _debug(*args): import pdb pdb.set_trace() + +# Parser +fromtext = parser.fromtext diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py new file mode 100644 index 000000000000..f64c635dd4ff --- /dev/null +++ b/python/tvm/relay/_parser.py @@ -0,0 +1,425 @@ + +# pylint: disable=invalid-name, unused-import +"""A parser for Relay's text format.""" +from __future__ import absolute_import + +import sys + +from collections import deque +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any + +from . import module +from . import expr +from . import ty +from . import op + +class ParseError(Exception): + """Exception type for parse errors.""" + + def __init__(self, message): + # type: (str) -> None + super(ParseError, self).__init__() + self.message = message + +PYTHON_VERSION = sys.version_info.major +try: + if PYTHON_VERSION == 2: + from .grammar.py2.RelayVisitor import RelayVisitor + from .grammar.py2.RelayParser import RelayParser + from .grammar.py2.RelayLexer import RelayLexer + else: + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer +except ImportError: + raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") + +try: + from antlr4 import ParserRuleContext, InputStream, CommonTokenStream + from antlr4.tree.Tree import TerminalNode +except ImportError: + raise ParseError("Couldn't find ANTLR runtime." + + "Try running `pip{} install antlr4-python{}-runtime`." + .format(PYTHON_VERSION, PYTHON_VERSION)) + +BINARY_OPS = { + RelayParser.MUL: op.multiply, + RelayParser.DIV: op.divide, + RelayParser.ADD: op.add, + RelayParser.SUB: op.subtract, + RelayParser.LT: op.less, + RelayParser.GT: op.greater, + RelayParser.LE: op.less_equal, + RelayParser.GE: op.greater_equal, + RelayParser.EQ: op.equal, + RelayParser.NE: op.not_equal, +} + +TYPE_PREFIXES = [ + "int", + "uint", + "float", + "bool", +] + +T = TypeVar("T") +Scope = Deque[Tuple[str, T]] +Scopes = Deque[Scope[T]] + +def lookup(scopes, name): + # type: (Scopes[T], str) -> Optional[T] + """Look up `name` in `scopes`.""" + + for scope in scopes: + for key, val in scope: + if key == name: + return val + return None + +# TODO(@jmp): Use https://stackoverflow.com/q/13889941 +# to figure out how to get ANTLR4 to be more unhappy about syntax errors +class ParseTreeToRelayIR(RelayVisitor): + """Parse Relay text format into Relay IR.""" + + def __init__(self): + # type: () -> None + self.module = module.Module({}) # type: module.Module + + # Adding an empty scope allows naked lets without pain. + self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + + super(ParseTreeToRelayIR, self).__init__() + + def enter_var_scope(self): + # type: () -> None + """Enter a new Var scope so it can be popped off later.""" + + self.var_scopes.appendleft(deque()) + + def exit_var_scope(self): + # type: () -> Scope[expr.Var] + """Pop off the current Var scope and return it.""" + + return self.var_scopes.popleft() + + def mk_var(self, name, type_): + # type: (str, ty.Type) -> expr.Var + """Create a new Var and add it to the Var scope.""" + + var = expr.Var(name, type_) + self.var_scopes[0].appendleft((name, var)) + return var + + def enter_type_param_scope(self): + # type: () -> None + """Enter a new TypeVar scope so it can be popped off later.""" + + self.type_param_scopes.appendleft(deque()) + + def exit_type_param_scope(self): + # type: () -> Scope[ty.TypeVar] + """Pop off the current TypeVar scope and return it.""" + + return self.type_param_scopes.popleft() + + def mk_typ(self, name, kind): + # (str, ty.Kind) -> ty.TypeVar + """Create a new TypeVar and add it to the TypeVar scope.""" + + typ = ty.TypeVar(name, kind) + self.type_param_scopes[0].appendleft((name, typ)) + return typ + + def visitTerminal(self, node): + # type: (TerminalNode) -> Union[expr.Expr, int, float] + """Visit lexer tokens that aren't ignored or visited by other functions.""" + + node_type = node.getSymbol().type + node_text = node.getText() + + # variables + if node_type == RelayLexer.GLOBAL_VAR: + return expr.GlobalVar(node_text[1:]) + elif node_type == RelayLexer.LOCAL_VAR: + name = node_text[1:] + var = lookup(self.var_scopes, name) + if var is None: + raise ParseError("Couldn't resolve `{}`.".format(name)) + + return var + + # data types + elif node_type == RelayLexer.INT: + return int(node_text) + elif node_type == RelayLexer.FLOAT: + return float(node_text) + elif node_type == RelayLexer.BOOL_LIT: + if node_text == "True": + return True + elif node_text == "False": + return False + else: + raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text)) + + else: + raise ParseError("todo: {}".format(node_text)) + + def visit_list(self, ctx_list): + # type: (List[ParserRuleContext]) -> List[Any] + """"Visit a list of contexts.""" + + return [self.visit(ctx) for ctx in ctx_list] + + def getType_(self, ctx): + # type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type] + """Return a (possibly None) Relay type.""" + + if ctx is None: + return None + + return self.visit(ctx) + + def visitProg(self, ctx): + # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] + if ctx.defn(): + self.visit_list(ctx.defn()) + return self.module + + return self.visit(ctx.expr()) + + # Exprs + + def visitOpIdent(self, ctx): + # type: (RelayParser.OpIdentContext) -> op.Op + return op.get(ctx.CNAME().getText()) + + # pass through + def visitParens(self, ctx): + # type: (RelayParser.ParensContext) -> expr.Expr + return self.visit(ctx.expr()) + + # pass through + def visitBody(self, ctx): + # type: (RelayParser.BodyContext) -> expr.Expr + return self.visit(ctx.expr()) + + def visitScalarFloat(self, ctx): + # type: (RelayParser.ScalarFloatContext) -> expr.Constant + return expr.const(self.visit(ctx.FLOAT())) + + def visitScalarInt(self, ctx): + # type: (RelayParser.ScalarIntContext) -> expr.Constant + return expr.const(self.visit(ctx.INT())) + + def visitScalarBool(self, ctx): + # type: (RelayParser.ScalarBoolContext) -> expr.Constant + return expr.const(self.visit(ctx.BOOL_LIT())) + + def visitNeg(self, ctx): + # type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call] + val = self.visit(ctx.expr()) + if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0: + # fold Neg in for scalars + return expr.const(-val.data.asnumpy().item()) + + return op.negative(val) + + def visitTuple(self, ctx): + # type: (RelayParser.TupleContext) -> expr.Tuple + tup = self.visit_list(ctx.expr()) + return expr.Tuple(tup) + + # Currently doesn't support mutable sequencing. + def visitSeq(self, ctx): + # type: (RelayParser.SeqContext) -> expr.Let + """Desugar various sequence constructs to Relay Let nodes.""" + if ctx.MUT() is not None: + raise ParseError("Mutation is currently unsupported.") + + if ctx.var() is None or ctx.var().ident() is None: + # anonymous identity + ident = "_" + type_ = None + else: + local_var = ctx.var().ident().LOCAL_VAR() + if local_var is None: + raise ParseError('Only local ids may be used in `let`s.') + ident = local_var.getText()[1:] + type_ = self.getType_(ctx.var().type_()) + + var = self.mk_var(ident, type_) + + self.enter_var_scope() + value = self.visit(ctx.expr(0)) + self.exit_var_scope() + + body = self.visit(ctx.expr(1)) + + return expr.Let(var, value, body) + + def visitBinOp(self, ctx): + # type: (RelayParser.BinOpContext) -> expr.Call + """Desugar binary operators.""" + arg0, arg1 = self.visit_list(ctx.expr()) + relay_op = BINARY_OPS.get(ctx.op.type) + + if relay_op is None: + raise ParseError("Unimplemented binary op.") + + return relay_op(arg0, arg1) + + def visitVar(self, ctx): + # type: (RelayParser.VarContext) -> expr.Var + ident = ctx.ident().LOCAL_VAR() + + if ident is None: + raise ParseError('Only local ids may be used in params.') + + type_ = self.getType_(ctx.type_()) + + return self.mk_var(ident.getText()[1:], type_) + + def visitVarList(self, ctx): + # type: (RelayParser.VarListContext) -> List[expr.Var] + return self.visit_list(ctx.var()) + + def mk_func(self, ctx): + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function + """Construct a function from either a Func or Defn.""" + + # Enter var scope early to put params in scope. + self.enter_var_scope() + # Capture type params in params. + self.enter_type_param_scope() + var_list = self.visit(ctx.varList()) + ret_type = self.getType_(ctx.type_()) + + type_params = list(self.exit_type_param_scope()) + if type_params: + _, type_params = zip(*type_params) + + body = self.visit(ctx.body()) + self.exit_var_scope() + + return expr.Function(var_list, body, ret_type, type_params) # type: ignore + + def visitFunc(self, ctx): + # type: (RelayParser.FuncContext) -> expr.Function + return self.mk_func(ctx) + + def visitDefn(self, ctx): + # type: (RelayParser.DefnContext) -> None + ident = ctx.ident().GLOBAL_VAR() + if ident is None: + raise ParseError('Only global ids may be used in `def`s.') + ident = expr.GlobalVar(ident.getText()[1:]) + + self.module[ident] = self.mk_func(ctx) + + def visitCall(self, ctx): + # type: (RelayParser.CallContext) -> expr.Call + visited_exprs = self.visit_list(ctx.expr()) + + func = visited_exprs[0] + args = visited_exprs[1:] + + return expr.Call(func, args, None, None) + + def visitIfElse(self, ctx): + # type: (RelayParser.IfElseContext) -> expr.If + """Construct a Relay If node. Creates a new scope for each branch.""" + cond = self.visit(ctx.expr()) + + self.enter_var_scope() + true_branch = self.visit(ctx.body(0)) + self.exit_var_scope() + + self.enter_var_scope() + false_branch = self.visit(ctx.body(1)) + self.exit_var_scope() + + return expr.If(cond, true_branch, false_branch) + + # Types + + # pylint: disable=unused-argument + def visitIncompleteType(self, ctx): + # type (RelayParser.IncompleteTypeContext) -> None: + return None + + def visitIdentType(self, ctx): + # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] + ident_type = ctx.CNAME().getText() + + # look through all type prefixes for a match + for type_prefix in TYPE_PREFIXES: + if ident_type.startswith(type_prefix): + return ty.scalar_type(ident_type) + + raise ParseError("Unknown builtin type: {}".format(ident_type)) + + # def visitCallType(self, ctx): + # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] + # ident_type = ctx.identType().CNAME().getText() + + # args = self.visit_list(ctx.type_()) + + # if not args: + # raise ParseError("Type-level functions must have arguments!") + + # func_type = TYPE_FUNCS.get(ident_type)(args) + + # if func_type is None: + # raise ParseError("Unknown type-level function: `{}`".format(ident_type)) + # else: + # return func_type + + def visitParensShape(self, ctx): + # type: (RelayParser.ParensShapeContext) -> int + return self.visit(ctx.shape()) + + def visitShapeSeq(self, ctx): + # type: (RelayParser.ShapeSeqContext) -> List[int] + return self.visit_list(ctx.shape()) + + def visitTensorType(self, ctx): + # type: (RelayParser.TensorTypeContext) -> ty.TensorType + """Create a simple tensor type. No generics.""" + + shape = self.visit(ctx.shapeSeq()) + dtype = self.visit(ctx.type_()) + + if not isinstance(dtype, ty.TensorType): + raise ParseError("Expected dtype to be a Relay base type.") + + dtype = dtype.dtype + + return ty.TensorType(shape, dtype) + + def visitTupleType(self, ctx): + # type: (RelayParser.TupleTypeContext) -> ty.TupleType + return ty.TupleType(self.visit_list(ctx.type_())) + + def visitFuncType(self, ctx): + # type: (RelayParser.FuncTypeContext) -> ty.FuncType + types = self.visit_list(ctx.type_()) + + arg_types = types[:-1] + ret_type = types[-1] + + return ty.FuncType(arg_types, ret_type, [], None) + +def make_parser(data): + # type: (str) -> RelayParser + """Construct a RelayParser a given data stream.""" + input_stream = InputStream(data) + lexer = RelayLexer(input_stream) + token_stream = CommonTokenStream(lexer) + return RelayParser(token_stream) + +def fromtext(data): + # type: (str) -> Union[expr.Expr, env.Environment] + """Parse a Relay program.""" + tree = make_parser(data).prog() + return ParseTreeToRelayIR().visit(tree) diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index e73a5963e5b1..bc2e5115df0d 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -22,7 +22,7 @@ class Constant(Expr): class Tuple(Expr): - fields = .. # type: List[Expr] + fields = ... # type: List[Expr] def __init__(self, fields): # type: (List[Expr]) -> None @@ -77,10 +77,10 @@ class Call(Expr): """A function call in Relay, see tvm/relay/expr.h for more details.""" op = ... # type: Expr args = ... # type: List[Expr] - # todo(@jroesch): add attrs + # todo(@jroesch): add attrs. revise attrs type in __init__ - def __init__(self, op, args, attrs, ty_args=None): - # type: (Expr, List[Expr], Optional[List[Type]]) -> None + def __init__(self, op, args, attrs=None, ty_args=None): + # type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None if not ty_args: ty_args = [] diff --git a/python/tvm/relay/grammar/.gitignore b/python/tvm/relay/grammar/.gitignore new file mode 100644 index 000000000000..cffe35e1a41a --- /dev/null +++ b/python/tvm/relay/grammar/.gitignore @@ -0,0 +1 @@ +/.antlr/ diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 new file mode 100644 index 000000000000..c74a42c97e77 --- /dev/null +++ b/python/tvm/relay/grammar/Relay.g4 @@ -0,0 +1,146 @@ +grammar Relay; + +// Lexing +// comments +WS : [ \t\n\r]+ -> skip ; +LINE_COMMENT : '//' .*? '\n' -> skip ; +COMMENT : '/*' .*? '*/' -> skip ; + +// operators +MUL: '*' ; +DIV: '/' ; +ADD: '+' ; +SUB: '-' ; +LT: '<' ; +GT: '>' ; +LE: '<=' ; +GE: '>=' ; +EQ: '==' ; +NE: '!=' ; + +opIdent: CNAME ; +GLOBAL_VAR: '@' CNAME ; +LOCAL_VAR: '%' CNAME ; + +MUT: 'mut' ; + +BOOL_LIT + : 'True' + | 'False' + ; + +// non-negative floats +FLOAT + : INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5 + | INT EXP // 1e10 3e4 + ; + +// non-negative ints +INT: DIGIT+ ; +fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...] + +CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; +fragment LETTER: [a-zA-Z] ; +fragment DIGIT: [0-9] ; + +// Parsing + +// A Relay program is a list of global definitions or an expression. +prog: (defn* | expr) EOF ; + +// option: 'set' ident BOOL_LIT ; + +expr + // operators + : '(' expr ')' # parens + | '-' expr # neg + | expr op=('*'|'/') expr # binOp + | expr op=('+'|'-') expr # binOp + | expr op=('<'|'>'|'<='|'>=') expr # binOp + | expr op=('=='|'!=') expr # binOp + + // function definition and application + | expr '(' (expr (',' expr)*)? ')' # call + | func # funcExpr + + // tuples and tensors + | '(' ')' # tuple + | '(' expr ',' ')' # tuple + | '(' expr (',' expr)+ ')' # tuple + | '[' (expr (',' expr)*)? ']' # tensor + + | 'if' '(' expr ')' body 'else' body # ifElse + + // sequencing + | 'let' MUT? var '=' expr ';' expr # seq + | 'let' MUT? var '=' '{' expr '}' ';' expr # seq + // sugar for let %_ = expr; expr + | expr ';' expr # seq + + // mutable update + // | ident '=' expr # writeRef + // | expr '^' # readRef + + | ident # identExpr + | scalar # scalarExpr + // | expr '.' INT # project + // | 'debug' # debug + ; + +func: 'fn' varList ('->' type_)? body ; +defn: 'def' ident varList ('->' type_)? body ; + +varList: '(' (var (',' var)*)? ')' ; +var: ident (':' type_)? ; + +// TODO(@jmp): for improved type annotations +// returnAnno: (ident ':')? type_ ; + +// relations: 'where' relation (',' relation)* ; +// relation: ident '(' (type_ (',' type_)*)? ')' ; + +type_ + : '(' ')' # tupleType + | '(' type_ ',' ')' # tupleType + | '(' type_ (',' type_)+ ')' # tupleType + | identType # identTypeType + | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType + // currently unused + // | identType '[' (type_ (',' type_)*)? ']' # callType + | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + | '_' # incompleteType + | INT # intType + ; + +shapeSeq + : '(' ')' + | '(' shape ',' ')' + | '(' shape (',' shape)+ ')' + ; + +shape + : '(' shape ')' # parensShape + // | type_ op=('*'|'/') type_ # binOpType + // | type_ op=('+'|'-') type_ # binOpType + | INT # intShape + ; + +identType: CNAME ; +// Int8, Int16, Int32, Int64 +// UInt8, UInt16, UInt32, UInt64 +// Float16, Float32, Float64 +// Bool + +body: '{' expr '}' ; + +scalar + : FLOAT # scalarFloat + | INT # scalarInt + | BOOL_LIT # scalarBool + ; + +ident + : opIdent + | GLOBAL_VAR + | LOCAL_VAR + ; diff --git a/python/tvm/relay/grammar/__init__.py b/python/tvm/relay/grammar/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/grammar/py2/.gitignore b/python/tvm/relay/grammar/py2/.gitignore new file mode 100644 index 000000000000..d677ff551940 --- /dev/null +++ b/python/tvm/relay/grammar/py2/.gitignore @@ -0,0 +1 @@ +Relay* diff --git a/python/tvm/relay/grammar/py2/__init__.py b/python/tvm/relay/grammar/py2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/grammar/py3/.gitignore b/python/tvm/relay/grammar/py3/.gitignore new file mode 100644 index 000000000000..d677ff551940 --- /dev/null +++ b/python/tvm/relay/grammar/py3/.gitignore @@ -0,0 +1 @@ +Relay* diff --git a/python/tvm/relay/grammar/py3/__init__.py b/python/tvm/relay/grammar/py3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py new file mode 100644 index 000000000000..51200343f147 --- /dev/null +++ b/python/tvm/relay/parser.py @@ -0,0 +1,17 @@ +"""A parser for Relay's text format.""" +from __future__ import absolute_import + +def enabled(): + """Is the parser enabled/Can we import the parser?""" + try: + # pylint: disable=unused-variable + from tvm.relay import _parser + return True + # pylint: disable=broad-except + except Exception: + return False + +def fromtext(data): + """Parse a Relay program.""" + from tvm.relay import _parser + return _parser.fromtext(data) diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi index 221fc228081d..933814853f3e 100644 --- a/python/tvm/relay/ty.pyi +++ b/python/tvm/relay/ty.pyi @@ -156,7 +156,7 @@ class FuncType(Type): class IncompleteType(Type): """An incomplete type.""" - def __init__(self, kind): + def __init__(self, kind=Kind.Type): self.__init_handle_by_constructor__(_make.IncompleteType, kind) @register_relay_node