Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/modules/ANTLR.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
if(EXISTS /usr/local/Cellar/antlr/4.7.1_1/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/Cellar/antlr/4.7.1_1/antlr-4.7.1-complete.jar")

set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const;
inline const SourceNameNode* operator->() const {
return static_cast<SourceNameNode*>(this->node_.get());
}

/*!
* \brief Get an SourceName for a given operator name.
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

#include <string>
#include "./base.h"
#include "./source_map.h"

namespace tvm {
namespace relay {

struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};

Expand All @@ -28,6 +30,21 @@ struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
};

class ErrorReporter {
public:
SourceMap src_map;
std::vector<Error> errors;

ErrorReporter() : src_map(), errors() {}
ErrorReporter(SourceMap src_map) : errors() {}

void Report(const Error& err) {
this->errors.push_back(err);
}

dmlc::Error Render();
};

} // namespace relay
} // namespace tvm

Expand Down
11 changes: 11 additions & 0 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,16 @@ class ModuleNode : public RelayNode {
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;

mutable ErrorReporter err_reporter;

mutable GlobalVar entry_func;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra space


ModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
Expand Down Expand Up @@ -102,6 +107,12 @@ class ModuleNode : public RelayNode {
*/
void Update(const Module& other);

/*!
* \brief Get the entry point of the module.
*
*/
Expr EntryPoint();

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace relay {
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Expr& expr, const Module& mod);

/*!
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
Expand Down
46 changes: 46 additions & 0 deletions include/tvm/relay/source_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*!
* Copyright (c) 2018 by Contributors
* \file source_map.h
* \brief A representation of source files and a data structure for
* storing them.
*/
#ifndef TVM_RELAY_SOURCE_MAP_H_
#define TVM_RELAY_SOURCE_MAP_H_

#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

struct SourceFragment {
SourceName name;
std::vector<std::string> source_lines;

SourceFragment(const SourceName& file_name, const std::string& source);

SourceFragment(const SourceFragment& sf) {
this->name = sf.name;
this->source_lines = sf.source_lines;
}

std::vector<std::string> LinesAt(Span sp, int lines);
};

/*! \brief Maps from FileId's to a SourceFragment.
*/
class SourceMap {
/*! \brief Map from unique token to a fragment of a source file. */
std::unordered_map<SourceName, SourceFragment, NodeHash> map_;
public:
SourceMap() : map_() {}
SourceName AddSource(const std::string& file_name, const std::string& source);
SourceName AddSource(const SourceName& source_name, const std::string& source);
const SourceFragment & GetSource(SourceName id) const;
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_SOURCE_MAP_H_
5 changes: 5 additions & 0 deletions python/tvm/relay/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api

_init_api("relay._base", __name__)
72 changes: 61 additions & 11 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any

from . import module
from .base import Span, SourceName
from . import expr
from . import ty
from . import op


class ParseError(Exception):
"""Exception type for parse errors."""

Expand Down Expand Up @@ -76,21 +78,35 @@ def lookup(scopes, name):
return val
return None

def spanify(f):
def _wrapper(*args, **kwargs):
sn = args[0].source_name
ctx = args[1]
ast = f(*args, **kwargs)
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
ast.set_span(sp)
return ast
return _wrapper

# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""

def __init__(self):
def __init__(self, source_name):
# type: () -> None
self.source_name = source_name
self.module = module.Module({}) # type: module.Module

# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = []

super(ParseTreeToRelayIR, self).__init__()


def enter_var_scope(self):
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
Expand Down Expand Up @@ -131,6 +147,20 @@ def mk_typ(self, name, kind):
self.type_param_scopes[0].appendleft((name, typ))
return typ

def local_lookup(self, name):
try:
graph_nid = int(name)
return self.graph_expr[graph_nid]
except ValueError:
var = lookup(self.var_scopes, name)

if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))

return var
except IndexError:
raise ParseError("Graph Expr error {}".format(name))

def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
Expand All @@ -142,13 +172,8 @@ def visitTerminal(self, node):
if node_type == RelayLexer.GLOBAL_VAR:
return expr.GlobalVar(node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:]
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))

return var

# Remove the leading '%' and lookup the name.
return self.local_lookup(node_text[1:])
# data types
elif node_type == RelayLexer.INT:
return int(node_text)
Expand Down Expand Up @@ -269,6 +294,7 @@ def visitBinOp(self, ctx):

return relay_op(arg0, arg1)

@spanify
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
ident = ctx.ident().LOCAL_VAR()
Expand Down Expand Up @@ -304,10 +330,12 @@ def mk_func(self, ctx):

return expr.Function(var_list, body, ret_type, type_params) # type: ignore

@spanify
def visitFunc(self, ctx):
# type: (RelayParser.FuncContext) -> expr.Function
return self.mk_func(ctx)

@spanify
def visitDefn(self, ctx):
# type: (RelayParser.DefnContext) -> None
ident = ctx.ident().GLOBAL_VAR()
Expand All @@ -317,6 +345,7 @@ def visitDefn(self, ctx):

self.module[ident] = self.mk_func(ctx)

@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())
Expand All @@ -326,6 +355,7 @@ def visitCall(self, ctx):

return expr.Call(func, args, None, None)

@spanify
def visitIfElse(self, ctx):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
Expand Down Expand Up @@ -410,6 +440,14 @@ def visitFuncType(self, ctx):

return ty.FuncType(arg_types, ret_type, [], None)

def visitGraphExpr(self, ctx):
graph_nid = int(ctx.LOCAL_VAR().getText()[1:])
value = self.visit(ctx.expr(0))
assert graph_nid == len(self.graph_expr)
self.graph_expr.append(value)
kont = self.visit(ctx.expr(1))
return kont

def make_parser(data):
# type: (str) -> RelayParser
"""Construct a RelayParser a given data stream."""
Expand All @@ -418,8 +456,20 @@ def make_parser(data):
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)

def fromtext(data):
# type: (str) -> Union[expr.Expr, env.Environment]
__source_name_counter__ = 0

def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, env.Environment]
"""Parse a Relay program."""
global __source_name_counter__

if source_name is None:
source_name = "source_file{0}".format(__source_name_counter__)

if isinstance(source_name, str):
source_name = SourceName(source_name)

import pdb; pdb.set_trace()
print(data)
tree = make_parser(data).prog()
return ParseTreeToRelayIR().visit(tree)
return ParseTreeToRelayIR(source_name).visit(tree)
10 changes: 10 additions & 0 deletions python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
from . import _base

NodeBase = NodeBase

Expand Down Expand Up @@ -63,6 +64,9 @@ def astext(self, show_meta_data=True, annotate=None):
"""
return _expr.RelayPrint(self, show_meta_data, annotate)

def set_span(self, span):
_base.set_span(self, span)


@register_relay_node
class Span(RelayNode):
Expand All @@ -71,6 +75,12 @@ class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)

@register_relay_node
class SourceName(RelayNode):
"""A identifier for a source location"""

def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)

@register_relay_node
class Id(NodeBase):
Expand Down
9 changes: 5 additions & 4 deletions python/tvm/relay/grammar/Relay.g4
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ NE: '!=' ;

opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
LOCAL_VAR: '%' (CNAME | INT);

MUT: 'mut' ;

Expand Down Expand Up @@ -83,8 +83,9 @@ expr

| ident # identExpr
| scalar # scalarExpr
// | expr '.' INT # project
// | 'debug' # debug
| LOCAL_VAR '=' expr ';' expr # graphExpr
// | expr '.' INT # project
// | 'debug' # debug
;

func: 'fn' varList ('->' type_)? body ;
Expand Down Expand Up @@ -131,7 +132,7 @@ identType: CNAME ;
// Float16, Float32, Float64
// Bool

body: '{' expr '}' ;
body: '{' expr ';' '}' ;

scalar
: FLOAT # scalarFloat
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A parser for Relay's text format."""
from __future__ import absolute_import
from .. import register_func

def enabled():
"""Is the parser enabled/Can we import the parser?"""
Expand All @@ -11,7 +12,8 @@ def enabled():
except Exception:
return False

def fromtext(data):
@register_func("relay.fromtext")
def fromtext(data, source_name=None):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data)
return _parser.fromtext(data, source_name)
Loading