From c4f71416a3173ecfd8117cd05aef799d74c38dc3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 10 Dec 2018 12:52:10 -0800 Subject: [PATCH 01/12] Add error reporting start WIP --- include/tvm/relay/source_map.h | 44 ++ src/relay/ir/expr.cc | 8 + src/relay/ir/source_map.cc | 70 +++ src/relay/util/rang.h | 503 +++++++++++++++++++++ tests/python/relay/test_error_reporting.py | 4 + 5 files changed, 629 insertions(+) create mode 100644 include/tvm/relay/source_map.h create mode 100644 src/relay/ir/source_map.cc create mode 100644 src/relay/util/rang.h create mode 100644 tests/python/relay/test_error_reporting.py diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h new file mode 100644 index 000000000000..2eabdfd28422 --- /dev/null +++ b/include/tvm/relay/source_map.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.h + * \brief A representation of source files and a data structure for + * storing them. + */ +#ifndef TVM_RELAY_SOURCE_MAP_H_ +#define TVM_RELAY_SOURCE_MAP_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +struct SourceFragment { + std::string file_name; + std::vector source_lines; + + SourceFragment(std::string file_name, std::string source); + + SourceFragment(const SourceFragment& sf) { + this->file_name = sf.file_name; + this->source_lines = sf.source_lines; + } + + std::string SourceAt(Span sp, int lines); +}; + +/*! \brief Maps from FileId's to a SourceFragment. + */ +class SourceMap { + /*! \brief Map from unique token to a fragment of a source file. */ + std::unordered_map map_; + public: + SourceMap() : map_() {} + SourceName AddSource(std::string file_name, std::string source); + const SourceFragment & GetSource(SourceName id) const; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_SOURCE_MAP_H_ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index cdb2a32a0009..d740bab595e6 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -279,5 +279,13 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") }); +TVM_REGISTER_API("relay._expr.set_span") +.set_body([](TVMArgs args, TVMRetValue* ret) { + Expr e = args[0]; + Span sp = args[1]; + *ret = temp->Realize(); +}); + + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/source_map.cc b/src/relay/ir/source_map.cc new file mode 100644 index 000000000000..9df4e4cb831f --- /dev/null +++ b/src/relay/ir/source_map.cc @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.cc + * \brief Source maps for Relay. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +SourceFragment::SourceFragment(std::string file_name, std::string source) + : file_name(file_name), source_lines({}) { + RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; + std::stringstream source_stream; + source_stream.str(source.c_str()); + std::string line; + + while (std::getline(source_stream, line)) { + RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line << std::endl; + std::string copy(line); + source_lines.push_back(copy); + } +} + +std::string SourceFragment::SourceAt(Span sp, int max_lines) { + std::stringstream out; + + // We need to move from 1 based indexing to zero based indexing. + int starting_line = sp->lineno; + + if (starting_line >= static_cast(this->source_lines.size())) { + throw dmlc::Error("SourceFragment: index out of bounds"); + } + + auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + + for (size_t i = 0; i < lines; i++) { + out << std::endl << this->source_lines.at(starting_line + i); + } + + auto source_slice = out.str(); + + RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + return source_slice; +} + +SourceName SourceMap::AddSource(std::string file_name, std::string source) { + auto new_id = SourceNameNode::make(file_name); + SourceFragment sfile(file_name, source); + this->map_.insert({new_id, sfile}); + return new_id; +} + +const SourceFragment& SourceMap::GetSource(SourceName id) const { + auto item = map_.find(id); + if (item != map_.end()) { + return (*item).second; + } else { + throw dmlc::Error("could not find requested source fragment"); + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/util/rang.h b/src/relay/util/rang.h new file mode 100644 index 000000000000..2da32a0a3ccf --- /dev/null +++ b/src/relay/util/rang.h @@ -0,0 +1,503 @@ +// This code is a header only library, which can be found here: https://github.com/agauniyal/rang. +#ifndef RANG_DOT_HPP +#define RANG_DOT_HPP + +#if defined(__unix__) || defined(__unix) || defined(__linux__) +#define OS_LINUX +#elif defined(WIN32) || defined(_WIN32) || defined(_WIN64) +#define OS_WIN +#elif defined(__APPLE__) || defined(__MACH__) +#define OS_MAC +#else +#error Unknown Platform +#endif + +#if defined(OS_LINUX) || defined(OS_MAC) +#include + +#elif defined(OS_WIN) + +#if defined(_WIN32_WINNT) && (_WIN32_WINNT < 0x0600) +#error \ + "Please include rang.hpp before any windows system headers or set _WIN32_WINNT at least to _WIN32_WINNT_VISTA" +#elif !defined(_WIN32_WINNT) +#define _WIN32_WINNT _WIN32_WINNT_VISTA +#endif + +#include +#include +#include + +// Only defined in windows 10 onwards, redefining in lower windows since it +// doesn't gets used in lower versions +// https://docs.microsoft.com/en-us/windows/console/getconsolemode +#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING +#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 +#endif + +#endif + +#include +#include +#include +#include +#include + +namespace rang { + +/* For better compability with most of terminals do not use any style settings + * except of reset, bold and reversed. + * Note that on Windows terminals bold style is same as fgB color. + */ +enum class style { + reset = 0, + bold = 1, + dim = 2, + italic = 3, + underline = 4, + blink = 5, + rblink = 6, + reversed = 7, + conceal = 8, + crossed = 9 +}; + +enum class fg { + black = 30, + red = 31, + green = 32, + yellow = 33, + blue = 34, + magenta = 35, + cyan = 36, + gray = 37, + reset = 39 +}; + +enum class bg { + black = 40, + red = 41, + green = 42, + yellow = 43, + blue = 44, + magenta = 45, + cyan = 46, + gray = 47, + reset = 49 +}; + +enum class fgB { + black = 90, + red = 91, + green = 92, + yellow = 93, + blue = 94, + magenta = 95, + cyan = 96, + gray = 97 +}; + +enum class bgB { + black = 100, + red = 101, + green = 102, + yellow = 103, + blue = 104, + magenta = 105, + cyan = 106, + gray = 107 +}; + +enum class control { // Behaviour of rang function calls + Off = 0, // toggle off rang style/color calls + Auto = 1, // (Default) autodect terminal and colorize if needed + Force = 2 // force ansi color output to non terminal streams +}; +// Use rang::setControlMode to set rang control mode + +enum class winTerm { // Windows Terminal Mode + Auto = 0, // (Default) automatically detects wheter Ansi or Native API + Ansi = 1, // Force use Ansi API + Native = 2 // Force use Native API +}; +// Use rang::setWinTermMode to explicitly set terminal API for Windows +// Calling rang::setWinTermMode have no effect on other OS + +namespace rang_implementation { + + inline std::atomic &controlMode() noexcept + { + static std::atomic value(control::Auto); + return value; + } + + inline std::atomic &winTermMode() noexcept + { + static std::atomic termMode(winTerm::Auto); + return termMode; + } + + inline bool supportsColor() noexcept + { +#if defined(OS_LINUX) || defined(OS_MAC) + + static const bool result = [] { + const char *Terms[] + = { "ansi", "color", "console", "cygwin", "gnome", + "konsole", "kterm", "linux", "msys", "putty", + "rxvt", "screen", "vt100", "xterm" }; + + const char *env_p = std::getenv("TERM"); + if (env_p == nullptr) { + return false; + } + return std::any_of(std::begin(Terms), std::end(Terms), + [&](const char *term) { + return std::strstr(env_p, term) != nullptr; + }); + }(); + +#elif defined(OS_WIN) + // All windows versions support colors through native console methods + static constexpr bool result = true; +#endif + return result; + } + +#ifdef OS_WIN + + + inline bool isMsysPty(int fd) noexcept + { + // Dynamic load for binary compability with old Windows + const auto ptrGetFileInformationByHandleEx + = reinterpret_cast( + GetProcAddress(GetModuleHandle(TEXT("kernel32.dll")), + "GetFileInformationByHandleEx")); + if (!ptrGetFileInformationByHandleEx) { + return false; + } + + HANDLE h = reinterpret_cast(_get_osfhandle(fd)); + if (h == INVALID_HANDLE_VALUE) { + return false; + } + + // Check that it's a pipe: + if (GetFileType(h) != FILE_TYPE_PIPE) { + return false; + } + + // POD type is binary compatible with FILE_NAME_INFO from WinBase.h + // It have the same alignment and used to avoid UB in caller code + struct MY_FILE_NAME_INFO { + DWORD FileNameLength; + WCHAR FileName[MAX_PATH]; + }; + + auto pNameInfo = std::unique_ptr( + new (std::nothrow) MY_FILE_NAME_INFO()); + if (!pNameInfo) { + return false; + } + + // Check pipe name is template of + // {"cygwin-","msys-"}XXXXXXXXXXXXXXX-ptyX-XX + if (!ptrGetFileInformationByHandleEx(h, FileNameInfo, pNameInfo.get(), + sizeof(MY_FILE_NAME_INFO))) { + return false; + } + std::wstring name(pNameInfo->FileName, pNameInfo->FileNameLength / sizeof(WCHAR)); + if ((name.find(L"msys-") == std::wstring::npos + && name.find(L"cygwin-") == std::wstring::npos) + || name.find(L"-pty") == std::wstring::npos) { + return false; + } + + return true; + } + +#endif + + inline bool isTerminal(const std::streambuf *osbuf) noexcept + { + using std::cerr; + using std::clog; + using std::cout; +#if defined(OS_LINUX) || defined(OS_MAC) + if (osbuf == cout.rdbuf()) { + static const bool cout_term = isatty(fileno(stdout)) != 0; + return cout_term; + } else if (osbuf == cerr.rdbuf() || osbuf == clog.rdbuf()) { + static const bool cerr_term = isatty(fileno(stderr)) != 0; + return cerr_term; + } +#elif defined(OS_WIN) + if (osbuf == cout.rdbuf()) { + static const bool cout_term + = (_isatty(_fileno(stdout)) || isMsysPty(_fileno(stdout))); + return cout_term; + } else if (osbuf == cerr.rdbuf() || osbuf == clog.rdbuf()) { + static const bool cerr_term + = (_isatty(_fileno(stderr)) || isMsysPty(_fileno(stderr))); + return cerr_term; + } +#endif + return false; + } + + template + using enableStd = typename std::enable_if< + std::is_same::value || std::is_same::value + || std::is_same::value || std::is_same::value + || std::is_same::value, + std::ostream &>::type; + + +#ifdef OS_WIN + + struct SGR { // Select Graphic Rendition parameters for Windows console + BYTE fgColor; // foreground color (0-15) lower 3 rgb bits + intense bit + BYTE bgColor; // background color (0-15) lower 3 rgb bits + intense bit + BYTE bold; // emulated as FOREGROUND_INTENSITY bit + BYTE underline; // emulated as BACKGROUND_INTENSITY bit + BOOLEAN inverse; // swap foreground/bold & background/underline + BOOLEAN conceal; // set foreground/bold to background/underline + }; + + enum class AttrColor : BYTE { // Color attributes for console screen buffer + black = 0, + red = 4, + green = 2, + yellow = 6, + blue = 1, + magenta = 5, + cyan = 3, + gray = 7 + }; + + inline HANDLE getConsoleHandle(const std::streambuf *osbuf) noexcept + { + if (osbuf == std::cout.rdbuf()) { + static const HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE); + return hStdout; + } else if (osbuf == std::cerr.rdbuf() || osbuf == std::clog.rdbuf()) { + static const HANDLE hStderr = GetStdHandle(STD_ERROR_HANDLE); + return hStderr; + } + return INVALID_HANDLE_VALUE; + } + + inline bool setWinTermAnsiColors(const std::streambuf *osbuf) noexcept + { + HANDLE h = getConsoleHandle(osbuf); + if (h == INVALID_HANDLE_VALUE) { + return false; + } + DWORD dwMode = 0; + if (!GetConsoleMode(h, &dwMode)) { + return false; + } + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(h, dwMode)) { + return false; + } + return true; + } + + inline bool supportsAnsi(const std::streambuf *osbuf) noexcept + { + using std::cerr; + using std::clog; + using std::cout; + if (osbuf == cout.rdbuf()) { + static const bool cout_ansi + = (isMsysPty(_fileno(stdout)) || setWinTermAnsiColors(osbuf)); + return cout_ansi; + } else if (osbuf == cerr.rdbuf() || osbuf == clog.rdbuf()) { + static const bool cerr_ansi + = (isMsysPty(_fileno(stderr)) || setWinTermAnsiColors(osbuf)); + return cerr_ansi; + } + return false; + } + + inline const SGR &defaultState() noexcept + { + static const SGR defaultSgr = []() -> SGR { + CONSOLE_SCREEN_BUFFER_INFO info; + WORD attrib = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), + &info) + || GetConsoleScreenBufferInfo(GetStdHandle(STD_ERROR_HANDLE), + &info)) { + attrib = info.wAttributes; + } + SGR sgr = { 0, 0, 0, 0, FALSE, FALSE }; + sgr.fgColor = attrib & 0x0F; + sgr.bgColor = (attrib & 0xF0) >> 4; + return sgr; + }(); + return defaultSgr; + } + + inline BYTE ansi2attr(BYTE rgb) noexcept + { + static const AttrColor rev[8] + = { AttrColor::black, AttrColor::red, AttrColor::green, + AttrColor::yellow, AttrColor::blue, AttrColor::magenta, + AttrColor::cyan, AttrColor::gray }; + return static_cast(rev[rgb]); + } + + inline void setWinSGR(rang::bg col, SGR &state) noexcept + { + if (col != rang::bg::reset) { + state.bgColor = ansi2attr(static_cast(col) - 40); + } else { + state.bgColor = defaultState().bgColor; + } + } + + inline void setWinSGR(rang::fg col, SGR &state) noexcept + { + if (col != rang::fg::reset) { + state.fgColor = ansi2attr(static_cast(col) - 30); + } else { + state.fgColor = defaultState().fgColor; + } + } + + inline void setWinSGR(rang::bgB col, SGR &state) noexcept + { + state.bgColor = (BACKGROUND_INTENSITY >> 4) + | ansi2attr(static_cast(col) - 100); + } + + inline void setWinSGR(rang::fgB col, SGR &state) noexcept + { + state.fgColor + = FOREGROUND_INTENSITY | ansi2attr(static_cast(col) - 90); + } + + inline void setWinSGR(rang::style style, SGR &state) noexcept + { + switch (style) { + case rang::style::reset: state = defaultState(); break; + case rang::style::bold: state.bold = FOREGROUND_INTENSITY; break; + case rang::style::underline: + case rang::style::blink: + state.underline = BACKGROUND_INTENSITY; + break; + case rang::style::reversed: state.inverse = TRUE; break; + case rang::style::conceal: state.conceal = TRUE; break; + default: break; + } + } + + inline SGR ¤t_state() noexcept + { + static SGR state = defaultState(); + return state; + } + + inline WORD SGR2Attr(const SGR &state) noexcept + { + WORD attrib = 0; + if (state.conceal) { + if (state.inverse) { + attrib = (state.fgColor << 4) | state.fgColor; + if (state.bold) + attrib |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; + } else { + attrib = (state.bgColor << 4) | state.bgColor; + if (state.underline) + attrib |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; + } + } else if (state.inverse) { + attrib = (state.fgColor << 4) | state.bgColor; + if (state.bold) attrib |= BACKGROUND_INTENSITY; + if (state.underline) attrib |= FOREGROUND_INTENSITY; + } else { + attrib = state.fgColor | (state.bgColor << 4) | state.bold + | state.underline; + } + return attrib; + } + + template + inline void setWinColorAnsi(std::ostream &os, T const value) + { + os << "\033[" << static_cast(value) << "m"; + } + + template + inline void setWinColorNative(std::ostream &os, T const value) + { + const HANDLE h = getConsoleHandle(os.rdbuf()); + if (h != INVALID_HANDLE_VALUE) { + setWinSGR(value, current_state()); + // Out all buffered text to console with previous settings: + os.flush(); + SetConsoleTextAttribute(h, SGR2Attr(current_state())); + } + } + + template + inline enableStd setColor(std::ostream &os, T const value) + { + if (winTermMode() == winTerm::Auto) { + if (supportsAnsi(os.rdbuf())) { + setWinColorAnsi(os, value); + } else { + setWinColorNative(os, value); + } + } else if (winTermMode() == winTerm::Ansi) { + setWinColorAnsi(os, value); + } else { + setWinColorNative(os, value); + } + return os; + } +#else + template + inline enableStd setColor(std::ostream &os, T const value) + { + return os << "\033[" << static_cast(value) << "m"; + } +#endif +} // namespace rang_implementation + +template +inline rang_implementation::enableStd operator<<(std::ostream &os, + const T value) +{ + const control option = rang_implementation::controlMode(); + switch (option) { + case control::Auto: + return rang_implementation::supportsColor() + && rang_implementation::isTerminal(os.rdbuf()) + ? rang_implementation::setColor(os, value) + : os; + case control::Force: return rang_implementation::setColor(os, value); + default: return os; + } +} + +inline void setWinTermMode(const rang::winTerm value) noexcept +{ + rang_implementation::winTermMode() = value; +} + +inline void setControlMode(const control value) noexcept +{ + rang_implementation::controlMode() = value; +} + +} // namespace rang + +#undef OS_LINUX +#undef OS_WIN +#undef OS_MAC + +#endif /* ifndef RANG_DOT_HPP */ diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py new file mode 100644 index 000000000000..d1ebf188e9d9 --- /dev/null +++ b/tests/python/relay/test_error_reporting.py @@ -0,0 +1,4 @@ +from tvm.relay.parser import fromtext + +def annotate_spans(expr): + return fromtext(expr.astext()) From 7c77432a6483f2befff1b60638831fb2a1353a0a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Dec 2018 22:12:59 -0800 Subject: [PATCH 02/12] Errors --- cmake/modules/ANTLR.cmake | 4 ++-- include/tvm/relay/source_map.h | 1 + python/tvm/relay/_base.py | 5 +++++ python/tvm/relay/_parser.py | 14 ++++++++++++++ python/tvm/relay/base.py | 10 ++++++++++ src/relay/ir/base.cc | 14 ++++++++++++++ src/relay/ir/expr.cc | 7 +------ src/relay/ir/source_map.cc | 4 ++-- tests/python/relay/test_error_reporting.py | 10 +++++++++- 9 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 python/tvm/relay/_base.py diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index 72eb5925bda0..d8747d399802 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,6 +1,6 @@ if(USE_ANTLR) - if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) - set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") + if(EXISTS /usr/local/Cellar/antlr/4.7.1_1/antlr-4.7.1-complete.jar) + set(ANTLR4 "/usr/local/Cellar/antlr/4.7.1_1/antlr-4.7.1-complete.jar") set(RELAY_PARSER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index 2eabdfd28422..99bcb269b3a0 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -7,6 +7,7 @@ #ifndef TVM_RELAY_SOURCE_MAP_H_ #define TVM_RELAY_SOURCE_MAP_H_ +#include #include #include #include diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/_base.py new file mode 100644 index 000000000000..b23655a0406a --- /dev/null +++ b/python/tvm/relay/_base.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface of expr function exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._base", __name__) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index f64c635dd4ff..ffb41da80997 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -9,10 +9,12 @@ from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any from . import module +from .base import Span, SourceName from . import expr from . import ty from . import op + class ParseError(Exception): """Exception type for parse errors.""" @@ -76,6 +78,17 @@ def lookup(scopes, name): return val return None +def spanify(f): + def _wrapper(*args, **kwargs): + ctx = args[1] + ast = f(*args, **kwargs) + line, col = ctx.getSourceInterval() + sn = SourceName("foo") + sp = Span(sn, line, col) + ast.set_span(sp) + return ast + return _wrapper + # TODO(@jmp): Use https://stackoverflow.com/q/13889941 # to figure out how to get ANTLR4 to be more unhappy about syntax errors class ParseTreeToRelayIR(RelayVisitor): @@ -269,6 +282,7 @@ def visitBinOp(self, ctx): return relay_op(arg0, arg1) + @spanify def visitVar(self, ctx): # type: (RelayParser.VarContext) -> expr.Var ident = ctx.ident().LOCAL_VAR() diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index c50013b199ac..780d52863079 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -4,6 +4,7 @@ from .._ffi.node import NodeBase, register_node as _register_tvm_node from . import _make from . import _expr +from . import _base NodeBase = NodeBase @@ -63,6 +64,9 @@ def astext(self, show_meta_data=True, annotate=None): """ return _expr.RelayPrint(self, show_meta_data, annotate) + def set_span(self, span): + _base.set_span(self, span) + @register_relay_node class Span(RelayNode): @@ -71,6 +75,12 @@ class Span(RelayNode): def __init__(self, source, lineno, col_offset): self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) +@register_relay_node +class SourceName(RelayNode): + """A identifier for a source location""" + + def __init__(self, name): + self.__init_handle_by_constructor__(_make.SourceName, name) @register_relay_node class Id(NodeBase): diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 06593b6420f5..8df54883616a 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) { return SourceName(GetSourceNameNode(name)); } +TVM_REGISTER_API("relay._make.SourceName") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SourceName::Get(args[0]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SourceNameNode* node, tvm::IRPrinter* p) { p->stream << "SourceName(" << node->name << ", " << node << ")"; @@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); +TVM_REGISTER_API("relay._base.set_span") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef node_ref = args[0]; + auto rn = node_ref.as_derived(); + CHECK(rn); + Span sp = args[1]; + rn->span = sp; +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index d740bab595e6..77956701429c 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -279,12 +279,7 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") }); -TVM_REGISTER_API("relay._expr.set_span") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Expr e = args[0]; - Span sp = args[1]; - *ret = temp->Realize(); -}); + } // namespace relay diff --git a/src/relay/ir/source_map.cc b/src/relay/ir/source_map.cc index 9df4e4cb831f..330a48ec9895 100644 --- a/src/relay/ir/source_map.cc +++ b/src/relay/ir/source_map.cc @@ -4,8 +4,8 @@ * \brief Source maps for Relay. */ -#include #include +#include #include namespace tvm { @@ -51,7 +51,7 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines) { } SourceName SourceMap::AddSource(std::string file_name, std::string source) { - auto new_id = SourceNameNode::make(file_name); + auto new_id = SourceName::Get(file_name); SourceFragment sfile(file_name, source); this->map_.insert({new_id, sfile}); return new_id; diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index d1ebf188e9d9..bddc29f86096 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -1,4 +1,12 @@ from tvm.relay.parser import fromtext +from tvm import relay def annotate_spans(expr): - return fromtext(expr.astext()) + # sn = SourceName("my_expr.relay") + return fromtext(expr.astext(), sn) + +def test_var(): + x = relay.var('x') + func = relay.Function([x], x) + func = annotate_spans(func) + import pdb; pdb.set_trace() From 867e3165805bd33edd2b236f99eb5ae087106840 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Dec 2018 22:39:20 -0800 Subject: [PATCH 03/12] hacking --- python/tvm/relay/_parser.py | 22 +++++++++++++++++----- tests/python/relay/test_error_reporting.py | 2 +- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index ffb41da80997..7b61383c8150 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -80,10 +80,10 @@ def lookup(scopes, name): def spanify(f): def _wrapper(*args, **kwargs): + sn = args[0].source_name ctx = args[1] ast = f(*args, **kwargs) line, col = ctx.getSourceInterval() - sn = SourceName("foo") sp = Span(sn, line, col) ast.set_span(sp) return ast @@ -94,8 +94,9 @@ def _wrapper(*args, **kwargs): class ParseTreeToRelayIR(RelayVisitor): """Parse Relay text format into Relay IR.""" - def __init__(self): + def __init__(self, source_name): # type: () -> None + self.source_name = source_name self.module = module.Module({}) # type: module.Module # Adding an empty scope allows naked lets without pain. @@ -104,6 +105,7 @@ def __init__(self): super(ParseTreeToRelayIR, self).__init__() + def enter_var_scope(self): # type: () -> None """Enter a new Var scope so it can be popped off later.""" @@ -318,10 +320,12 @@ def mk_func(self, ctx): return expr.Function(var_list, body, ret_type, type_params) # type: ignore + @spanify def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> expr.Function return self.mk_func(ctx) + @spanify def visitDefn(self, ctx): # type: (RelayParser.DefnContext) -> None ident = ctx.ident().GLOBAL_VAR() @@ -331,6 +335,7 @@ def visitDefn(self, ctx): self.module[ident] = self.mk_func(ctx) + @spanify def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call visited_exprs = self.visit_list(ctx.expr()) @@ -340,6 +345,7 @@ def visitCall(self, ctx): return expr.Call(func, args, None, None) + @spanify def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> expr.If """Construct a Relay If node. Creates a new scope for each branch.""" @@ -432,8 +438,14 @@ def make_parser(data): token_stream = CommonTokenStream(lexer) return RelayParser(token_stream) -def fromtext(data): - # type: (str) -> Union[expr.Expr, env.Environment] +__source_name_counter__ = 0 + +def fromtext(data, source_name=None): + # type: (str, str) -> Union[expr.Expr, env.Environment] """Parse a Relay program.""" + global __source_name_counter__ + if source_name is None: + source_name = "source_file{0}".format(__source_name_counter__) + source_name = SourceName(source_name) tree = make_parser(data).prog() - return ParseTreeToRelayIR().visit(tree) + return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index bddc29f86096..e19c08f21922 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -3,7 +3,7 @@ def annotate_spans(expr): # sn = SourceName("my_expr.relay") - return fromtext(expr.astext(), sn) + return fromtext(expr.astext()) def test_var(): x = relay.var('x') From 273fddf7c0d26a41c8dab9046cfbe8f98f05f90b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Dec 2018 23:04:14 -0800 Subject: [PATCH 04/12] Add initial tests --- tests/python/relay/test_error_reporting.py | 29 +++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index e19c08f21922..ef382f012756 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -1,6 +1,33 @@ from tvm.relay.parser import fromtext +from tvm.relay.expr import ExprFunctor +from tvm.relay.op import Op from tvm import relay +class SpanChecker(ExprFunctor): + def visit(self, expr): + if isinstance(expr, Op): + return self.visit_op(expr) + else: + return super().visit(expr) + + def visit_var(self, var): + assert var.span + + def visit_op(self, op): + pass + + def visit_function(self, func): + for param in func.params: + self.visit(param) + + self.visit(func.body) + + assert func.span + +def check_spans(expr): + sp_ck = SpanChecker() + sp_ck.visit(expr) + def annotate_spans(expr): # sn = SourceName("my_expr.relay") return fromtext(expr.astext()) @@ -9,4 +36,4 @@ def test_var(): x = relay.var('x') func = relay.Function([x], x) func = annotate_spans(func) - import pdb; pdb.set_trace() + check_spans(func) From b0abcf6079792f79457e09c7967a9b3303c35c86 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Dec 2018 23:29:04 -0800 Subject: [PATCH 05/12] Hacking --- python/tvm/relay/_parser.py | 3 +++ python/tvm/relay/grammar/Relay.g4 | 3 ++- tests/python/relay/test_error_reporting.py | 17 ++++++++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 7b61383c8150..0400a771ffae 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -430,6 +430,9 @@ def visitFuncType(self, ctx): return ty.FuncType(arg_types, ret_type, [], None) + def visitGraphExpr(self, ctx): + import pdb; pdb.set_trace() + def make_parser(data): # type: (str) -> RelayParser """Construct a RelayParser a given data stream.""" diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index c74a42c97e77..59d3ab0cebd2 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -20,7 +20,7 @@ NE: '!=' ; opIdent: CNAME ; GLOBAL_VAR: '@' CNAME ; -LOCAL_VAR: '%' CNAME ; +LOCAL_VAR: '%' (CNAME | INT); MUT: 'mut' ; @@ -83,6 +83,7 @@ expr | ident # identExpr | scalar # scalarExpr + | LOCAL_VAR '=' expr # graphExpr // | expr '.' INT # project // | 'debug' # debug ; diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index ef382f012756..01d862ec6424 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -32,8 +32,23 @@ def annotate_spans(expr): # sn = SourceName("my_expr.relay") return fromtext(expr.astext()) -def test_var(): +def test_var_span(): x = relay.var('x') func = relay.Function([x], x) func = annotate_spans(func) check_spans(func) + +def test_type_check_call(): + x = relay.var('x', shape=(10, 10)) + func = relay.Function([x], x) + y = relay.var('x', shape=(10, 11)) + call = relay.Call(func, [y]) + func2 = relay.Function([y], call) + print(func2.astext()) + call = annotate_spans(func2) + check_spans(func2) + relay.ir_pass.infer_type(func2) + +if __name__ == "__main__": + test_var_span() + test_type_check_call() From 23b3e54790906e76610a1e721712beede4bbbaa9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Dec 2018 02:21:25 -0800 Subject: [PATCH 06/12] Parser hacking --- python/tvm/relay/_parser.py | 31 ++++++++++++++++------ python/tvm/relay/grammar/Relay.g4 | 8 +++--- src/relay/ir/text_printer.cc | 14 +++++----- tests/python/relay/test_error_reporting.py | 16 +++++------ 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 0400a771ffae..e2dac3a9eac7 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -102,6 +102,7 @@ def __init__(self, source_name): # Adding an empty scope allows naked lets without pain. self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + self.graph_expr = [] super(ParseTreeToRelayIR, self).__init__() @@ -146,6 +147,20 @@ def mk_typ(self, name, kind): self.type_param_scopes[0].appendleft((name, typ)) return typ + def local_lookup(self, name): + try: + graph_nid = int(name) + return self.graph_expr[graph_nid] + except ValueError: + var = lookup(self.var_scopes, name) + + if var is None: + raise ParseError("Couldn't resolve `{}`.".format(name)) + + return var + except IndexError: + raise ParseError("Graph Expr error {}".format(name)) + def visitTerminal(self, node): # type: (TerminalNode) -> Union[expr.Expr, int, float] """Visit lexer tokens that aren't ignored or visited by other functions.""" @@ -157,13 +172,8 @@ def visitTerminal(self, node): if node_type == RelayLexer.GLOBAL_VAR: return expr.GlobalVar(node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: - name = node_text[1:] - var = lookup(self.var_scopes, name) - if var is None: - raise ParseError("Couldn't resolve `{}`.".format(name)) - - return var - + # Remove the leading '%' and lookup the name. + return self.local_lookup(node_text[1:]) # data types elif node_type == RelayLexer.INT: return int(node_text) @@ -431,7 +441,12 @@ def visitFuncType(self, ctx): return ty.FuncType(arg_types, ret_type, [], None) def visitGraphExpr(self, ctx): - import pdb; pdb.set_trace() + graph_nid = int(ctx.LOCAL_VAR().getText()[1:]) + value = self.visit(ctx.expr(0)) + assert graph_nid == len(self.graph_expr) + self.graph_expr.append(value) + kont = self.visit(ctx.expr(1)) + return kont def make_parser(data): # type: (str) -> RelayParser diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 59d3ab0cebd2..7bc463e76f11 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -83,9 +83,9 @@ expr | ident # identExpr | scalar # scalarExpr - | LOCAL_VAR '=' expr # graphExpr - // | expr '.' INT # project - // | 'debug' # debug + | LOCAL_VAR '=' expr ';' expr # graphExpr + // | expr '.' INT # project + // | 'debug' # debug ; func: 'fn' varList ('->' type_)? body ; @@ -132,7 +132,7 @@ identType: CNAME ; // Float16, Float32, Float64 // Bool -body: '{' expr '}' ; +body: '{' expr ';' '}' ; scalar : FLOAT # scalarFloat diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 46b0d25b3d7d..5ecd220bda8a 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -287,7 +287,7 @@ class TextPrinter : std::ostringstream os; os << id << " = fn"; this->PrintFuncInternal(os.str(), GetRef(op)); - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); return id; } @@ -325,7 +325,7 @@ class TextPrinter : } this->PrintCallAttrs(op->op, op->attrs, stream_); stream_ << ")"; - this->PrintEndInst(""); + this->PrintEndInst(";"); this->PrintOptionalInfo(GetRef(op)); stream_ << '\n'; return id; @@ -336,7 +336,7 @@ class TextPrinter : this->PrintIndent(); stream_ << id << " = "; this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); return id; } @@ -345,7 +345,7 @@ class TextPrinter : this->PrintIndent(); stream_ << id << " = "; this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); return id; } @@ -358,7 +358,7 @@ class TextPrinter : TextValue id = this->AllocTempVar(); this->PrintIndent(); stream_ << id << " = " << tuple << "." << op->index << ""; - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); return id; } @@ -492,7 +492,7 @@ class TextPrinter : stream_ << "let "; this->PrintVarDecl(let->var, stream_); stream_ << " = " << value; - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); this->PrintScopeBody(let->body); } else if (const IfNode* ifnode = body.as()) { TextValue cond = GetValue(ifnode->cond); @@ -507,7 +507,7 @@ class TextPrinter : TextValue value = GetValue(body); this->PrintIndent(); stream_ << value; - this->PrintEndInst("\n"); + this->PrintEndInst(";\n"); } } diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 01d862ec6424..98ffb18bb250 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -1,15 +1,9 @@ from tvm.relay.parser import fromtext -from tvm.relay.expr import ExprFunctor +from tvm.relay.expr_functor import ExprFunctor from tvm.relay.op import Op from tvm import relay class SpanChecker(ExprFunctor): - def visit(self, expr): - if isinstance(expr, Op): - return self.visit_op(expr) - else: - return super().visit(expr) - def visit_var(self, var): assert var.span @@ -24,6 +18,9 @@ def visit_function(self, func): assert func.span + def visit_call(self, call): + assert call.span + def check_spans(expr): sp_ck = SpanChecker() sp_ck.visit(expr) @@ -46,8 +43,9 @@ def test_type_check_call(): func2 = relay.Function([y], call) print(func2.astext()) call = annotate_spans(func2) - check_spans(func2) - relay.ir_pass.infer_type(func2) + check_spans(call) + import pdb; pdb.set_trace() + relay.ir_pass.infer_type(call) if __name__ == "__main__": test_var_span() From eb18473abe262084beee5297edf3f036be687b37 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Dec 2018 03:31:27 -0800 Subject: [PATCH 07/12] Hacking on errors --- include/tvm/relay/error.h | 14 ++++++++++++++ src/relay/pass/type_infer.cc | 28 +++++++++++++++++++++------- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 1c2b90611bbd..6d07a145d405 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -8,11 +8,13 @@ #include #include "./base.h" +#include "./source_map.h" namespace tvm { namespace relay { struct Error : public dmlc::Error { + Span sp; explicit Error(const std::string &msg) : dmlc::Error(msg) {} }; @@ -28,6 +30,18 @@ struct TypecheckerError : public Error { explicit TypecheckerError(const std::string &msg) : Error(msg) {} }; +class ErrorReporter { +public: + SourceMap src_map; + std::vector errors; + + void Report(const Error& err) { + this->errors.push_back(err); + } + + dmlc::Error Render(); +}; + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 13da159e99a8..8b1f98164aee 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -104,6 +104,7 @@ class TypeInferencer : private ExprFunctor { // constructors TypeInferencer() { } + explicit TypeInferencer(Module mod) : mod_(mod) { } @@ -114,8 +115,12 @@ class TypeInferencer : private ExprFunctor { private: // type resolver that maps back to type class Resolver; + // internal environment Module mod_; + + ErrorReporter err_reporter; + // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -125,19 +130,28 @@ class TypeInferencer : private ExprFunctor { // relation function TypeRelationFn tuple_getitem_rel_; TypeRelationFn make_tuple_rel_; + + [[noreturn]] void FatalError(const Error& err) { + this->err_reporter.Report(err); + throw this->err_reporter.Render(); + } + // Unify two types Type Unify(const Type& t1, const Type& t2, const Span& span) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2); } catch (const dmlc::Error &e) { - LOG(FATAL) - << "Error unifying `" - << t1 - << "` and `" - << t2 - << "`: " << e.what(); - return Type(); + // LOG(FATAL) + Error err("failed to unify"); + err.sp = span; + FatalError(err); + // << "Error unifying `" + // << t1 + // << "` and `" + // << t2 + // << "`: " << e.what(); + // return Type(); } } // Lazily get type for expr From bf6509233bbf98f40b8336f1512122d0958944d3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Dec 2018 21:24:13 -0800 Subject: [PATCH 08/12] First version of type inference errors working --- include/tvm/relay/base.h | 4 +- include/tvm/relay/error.h | 3 ++ include/tvm/relay/module.h | 11 +++++ include/tvm/relay/pass.h | 1 + include/tvm/relay/source_map.h | 9 ++-- python/tvm/relay/_parser.py | 7 ++- python/tvm/relay/parser.py | 6 ++- src/relay/ir/module.cc | 4 ++ src/relay/ir/source_map.cc | 20 ++++---- src/relay/pass/type_infer.cc | 53 +++++++++++++++++++--- tests/python/relay/test_error_reporting.py | 19 ++++---- 11 files changed, 104 insertions(+), 33 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f72f557a9765..f90acdc9400b 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -108,7 +108,9 @@ class SourceName : public NodeRef { * \brief access the internal node container * \return the pointer to the internal node container */ - inline const SourceNameNode* operator->() const; + inline const SourceNameNode* operator->() const { + return static_cast(this->node_.get()); + } /*! * \brief Get an SourceName for a given operator name. diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 6d07a145d405..ec9f3628d830 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -35,6 +35,9 @@ class ErrorReporter { SourceMap src_map; std::vector errors; + ErrorReporter() : src_map(), errors() {} + ErrorReporter(SourceMap src_map) : errors() {} + void Report(const Error& err) { this->errors.push_back(err); } diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index b04d6fec20c5..243d63f05cda 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -43,11 +43,16 @@ class ModuleNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map functions; + mutable ErrorReporter err_reporter; + + mutable GlobalVar entry_func; + ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("global_var_map_", &global_var_map_); + v->Visit("entry_func", &entry_func); } TVM_DLL static Module make(tvm::Map global_funcs); @@ -102,6 +107,12 @@ class ModuleNode : public RelayNode { */ void Update(const Module& other); + /*! + * \brief Get the entry point of the module. + * + */ + Expr EntryPoint(); + static constexpr const char* _type_key = "relay.Module"; TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8fff7016a827..b4b587991810 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -28,6 +28,7 @@ namespace relay { * \return A type checked expression with its checked_type field populated. */ Expr InferType(const Expr& expr, const Module& mod); + /*! * \brief Infer the type of a function as if it is mapped to var in the mod. * diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index 99bcb269b3a0..d8a4f99022e5 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -16,13 +16,13 @@ namespace tvm { namespace relay { struct SourceFragment { - std::string file_name; + SourceName name; std::vector source_lines; - SourceFragment(std::string file_name, std::string source); + SourceFragment(const SourceName& file_name, const std::string& source); SourceFragment(const SourceFragment& sf) { - this->file_name = sf.file_name; + this->name = sf.name; this->source_lines = sf.source_lines; } @@ -36,7 +36,8 @@ class SourceMap { std::unordered_map map_; public: SourceMap() : map_() {} - SourceName AddSource(std::string file_name, std::string source); + SourceName AddSource(const std::string& file_name, const std::string& source); + SourceName AddSource(const SourceName& source_name, const std::string& source); const SourceFragment & GetSource(SourceName id) const; }; diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index e2dac3a9eac7..32299931f05e 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -462,8 +462,13 @@ def fromtext(data, source_name=None): # type: (str, str) -> Union[expr.Expr, env.Environment] """Parse a Relay program.""" global __source_name_counter__ + if source_name is None: source_name = "source_file{0}".format(__source_name_counter__) - source_name = SourceName(source_name) + + if isinstance(source_name, str): + source_name = SourceName(source_name) + tree = make_parser(data).prog() + import pdb; pdb.set_trace() return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 51200343f147..b3b502edc84b 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,5 +1,6 @@ """A parser for Relay's text format.""" from __future__ import absolute_import +from .. import register_func def enabled(): """Is the parser enabled/Can we import the parser?""" @@ -11,7 +12,8 @@ def enabled(): except Exception: return False -def fromtext(data): +@register_func("relay.fromtext") +def fromtext(data, source_name=None): """Parse a Relay program.""" from tvm.relay import _parser - return _parser.fromtext(data) + return _parser.fromtext(data, source_name) diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 4443ed50783e..ad4d5fa9748a 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -91,6 +91,10 @@ void ModuleNode::Update(const Module& mod) { } } +Expr ModuleNode::EntryPoint() { + return this->Lookup(this->entry_func); +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") diff --git a/src/relay/ir/source_map.cc b/src/relay/ir/source_map.cc index 330a48ec9895..f59ee2e854f2 100644 --- a/src/relay/ir/source_map.cc +++ b/src/relay/ir/source_map.cc @@ -14,8 +14,8 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceFragment::SourceFragment(std::string file_name, std::string source) - : file_name(file_name), source_lines({}) { +SourceFragment::SourceFragment(const SourceName& name, const std::string& source) + : name(name), source_lines({}) { RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; std::stringstream source_stream; source_stream.str(source.c_str()); @@ -50,11 +50,15 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines) { return source_slice; } -SourceName SourceMap::AddSource(std::string file_name, std::string source) { - auto new_id = SourceName::Get(file_name); - SourceFragment sfile(file_name, source); - this->map_.insert({new_id, sfile}); - return new_id; +SourceName SourceMap::AddSource(const SourceName& source_name, const std::string& source) { + SourceFragment sfile(source_name, source); + this->map_.insert({source_name, sfile}); + return source_name; +} + +SourceName SourceMap::AddSource(const std::string& file_name, const std::string& source) { + auto source_name = SourceName::Get(file_name); + return this->AddSource(source_name, source); } const SourceFragment& SourceMap::GetSource(SourceName id) const { @@ -62,7 +66,7 @@ const SourceFragment& SourceMap::GetSource(SourceName id) const { if (item != map_.end()) { return (*item).second; } else { - throw dmlc::Error("could not find requested source fragment"); + LOG(FATAL) << "could not find requested source fragment" << id; } } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 8b1f98164aee..803364c99edd 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -112,6 +112,10 @@ class TypeInferencer : private ExprFunctor { // inference the type of expr. Expr Infer(Expr expr); + inline ErrorReporter& ErrReporter() { + return this->mod_->err_reporter; + } + private: // type resolver that maps back to type class Resolver; @@ -119,7 +123,6 @@ class TypeInferencer : private ExprFunctor { // internal environment Module mod_; - ErrorReporter err_reporter; // map from expression to checked type // type inferencer will populate it up @@ -132,8 +135,8 @@ class TypeInferencer : private ExprFunctor { TypeRelationFn make_tuple_rel_; [[noreturn]] void FatalError(const Error& err) { - this->err_reporter.Report(err); - throw this->err_reporter.Render(); + this->ErrReporter().Report(err); + throw this->ErrReporter().Render(); } // Unify two types @@ -545,19 +548,55 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +std::pair AnnotateSpans(const Expr& expr, const SourceName& source_name) { + auto text = RelayPrint(expr); + auto fromtext = runtime::Registry::Get("relay.fromtext"); + CHECK(fromtext != nullptr); + auto annotated_expr = (*fromtext)(text, source_name); + return { annotated_expr, text }; +} -Expr InferType(const Expr& expr, const Module& mod) { - auto e = TypeInferencer(mod).Infer(expr); +Expr InferType(const Expr& expr, const Module& m) { + Module mod = m; + if (!mod.defined()) + mod = ModuleNode::make({}); + CHECK(mod.defined()); + auto main = GlobalVarNode::make("main"); + auto source_name = SourceName::Get("main"); + mod->entry_func = main; + auto annotated = AnnotateSpans(expr, source_name); + auto spanned_expr = annotated.first; + auto text = annotated.second; + mod->Add(main, Downcast(spanned_expr)); + mod->err_reporter.src_map.AddSource(source_name, text); + auto e = TypeInferencer(mod).Infer(main); CHECK(WellFormed(e)); - return e; + return mod->Lookup(main); +} + +Expr InferType(const Expr& expr) { + auto mod = ModuleNode::make({}); + return InferType(expr, mod); } Function InferType(const Function& func, - const Module& mod, + const Module& m, const GlobalVar& var) { Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); + + // Module mod + Module mod = m; + if (!mod.defined()) + mod = ModuleNode::make({}); + mod->functions.Set(var, func_copy); + CHECK(mod.defined()); + auto source_name = SourceName::Get(var->name_hint); + auto annotated = AnnotateSpans(func, source_name); + auto spanned_expr = annotated.first; + auto text = annotated.second; + mod->err_reporter.src_map.AddSource(source_name, text); Expr func_ret = TypeInferencer(mod).Infer(func_copy); auto map_node = mod->functions.CopyOnWrite(); map_node->data.erase(var.node_); diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 98ffb18bb250..8a78dfc14450 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -1,3 +1,4 @@ +import tvm from tvm.relay.parser import fromtext from tvm.relay.expr_functor import ExprFunctor from tvm.relay.op import Op @@ -25,14 +26,16 @@ def check_spans(expr): sp_ck = SpanChecker() sp_ck.visit(expr) -def annotate_spans(expr): - # sn = SourceName("my_expr.relay") - return fromtext(expr.astext()) +@tvm.register_func("annotate_spans") +def annotate_spans(expr, source_name): + text = expr.astext() + expr = fromtext(text, source_name) + check_spans(expr) + return expr, text def test_var_span(): x = relay.var('x') func = relay.Function([x], x) - func = annotate_spans(func) check_spans(func) def test_type_check_call(): @@ -41,12 +44,8 @@ def test_type_check_call(): y = relay.var('x', shape=(10, 11)) call = relay.Call(func, [y]) func2 = relay.Function([y], call) - print(func2.astext()) - call = annotate_spans(func2) - check_spans(call) - import pdb; pdb.set_trace() - relay.ir_pass.infer_type(call) + relay.ir_pass.infer_type(func2) if __name__ == "__main__": - test_var_span() + # test_var_span() test_type_check_call() From 1d02a8ccf51439ae45f93c28ec0089330d194cc5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Dec 2018 21:28:52 -0800 Subject: [PATCH 09/12] First version of type inference errors working --- include/tvm/relay/source_map.h | 2 +- src/relay/ir/source_map.cc | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index d8a4f99022e5..04cddb17bfad 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -26,7 +26,7 @@ struct SourceFragment { this->source_lines = sf.source_lines; } - std::string SourceAt(Span sp, int lines); + std::vector LinesAt(Span sp, int lines); }; /*! \brief Maps from FileId's to a SourceFragment. diff --git a/src/relay/ir/source_map.cc b/src/relay/ir/source_map.cc index f59ee2e854f2..262f8820df48 100644 --- a/src/relay/ir/source_map.cc +++ b/src/relay/ir/source_map.cc @@ -28,9 +28,7 @@ SourceFragment::SourceFragment(const SourceName& name, const std::string& source } } -std::string SourceFragment::SourceAt(Span sp, int max_lines) { - std::stringstream out; - +std::vector SourceFragment::LinesAt(Span sp, int max_lines) { // We need to move from 1 based indexing to zero based indexing. int starting_line = sp->lineno; @@ -38,16 +36,17 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines) { throw dmlc::Error("SourceFragment: index out of bounds"); } - auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + auto num_of_lines = + std::max(static_cast(max_lines), + source_lines.size() - starting_line); - for (size_t i = 0; i < lines; i++) { - out << std::endl << this->source_lines.at(starting_line + i); + std::vector lines; + for (size_t i = 0; i < num_of_lines; i++) { + lines.push_back(this->source_lines.at(starting_line + i)); } - auto source_slice = out.str(); - - RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; - return source_slice; + // RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + return lines; } SourceName SourceMap::AddSource(const SourceName& source_name, const std::string& source) { From 2af982b783b819f0efacd72f28ad32ae0fc5e394 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Dec 2018 22:10:18 -0800 Subject: [PATCH 10/12] Add more --- python/tvm/relay/_parser.py | 3 ++- src/relay/pass/type_infer.cc | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 32299931f05e..1aec2e2a8107 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -469,6 +469,7 @@ def fromtext(data, source_name=None): if isinstance(source_name, str): source_name = SourceName(source_name) - tree = make_parser(data).prog() import pdb; pdb.set_trace() + print(data) + tree = make_parser(data).prog() return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 803364c99edd..338cdddefb2d 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -225,14 +225,14 @@ class TypeInferencer : private ExprFunctor { return op->op_type; } - Type VisitExpr_(const LetNode* op) final { - Type vtype = GetType(op->value); - if (op->var->type_annotation.defined()) { - vtype = Unify(vtype, op->var->type_annotation, op->span); + Type VisitExpr_(const LetNode* let) final { + Type vtype = GetType(let->value); + if (let->var->type_annotation.defined()) { + vtype = Unify(vtype, let->var->type_annotation, let->span); } - CHECK(!type_map_.count(op->var)); + CHECK(!type_map_.count(let->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = vtype; + type_map_[let->var].checked_type = vtype; return GetType(op->body); } From 90391541260f5b623ef14f945185e42f1ade1b48 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 26 Dec 2018 15:04:04 -0800 Subject: [PATCH 11/12] Add error.cc --- src/relay/ir/error.cc | 48 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/relay/ir/error.cc diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc new file mode 100644 index 000000000000..9e6d2dc4a99f --- /dev/null +++ b/src/relay/ir/error.cc @@ -0,0 +1,48 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error.cc + * \brief Relay type inference and checking. + * + */ + +#include +#include "../util/rang.h" + +namespace tvm { +namespace relay { + +dmlc::Error ErrorReporter::Render() { + for (auto err : this->errors) { + auto sp = err.sp; + CHECK(sp.defined()) << "while attempting to report an error its span was null"; + auto source_file = this->src_map.GetSource(err.sp->source); + auto file_name = source_file.name->name; + auto lines = source_file.LinesAt(err.sp, 1); + std::string error_marker = "error:"; + auto line_info = + std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + + std::cout << rang::style::bold << rang::fg::red << error_marker + << rang::fg::reset << file_name << ":" << line_info + << rang::style::reset << " " << lines[0] << std::endl; + + // Build the cursor. + + // Fix this code, hardwired to compute alignment of pointer. + size_t spaces = error_marker.size() + line_info.size() + file_name.size() + + sp->col_offset - 3; + + std::string cursor = "~~~~^~~~~"; + for (size_t i = 0; i < spaces; i++) { + std::cout << " "; + } + + std::cout << rang::fg::red << cursor << " " << err.what() << rang::style::reset + << std::endl; + } + return dmlc::Error("print me"); + } + + +} // relay +} // tvm From 7546c4837a0b9ea965a3d61934f3da3d0737dff7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 26 Dec 2018 18:41:15 -0800 Subject: [PATCH 12/12] Remove uneeded function --- tests/python/relay/test_error_reporting.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 8a78dfc14450..afd8071bc65a 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -26,13 +26,6 @@ def check_spans(expr): sp_ck = SpanChecker() sp_ck.visit(expr) -@tvm.register_func("annotate_spans") -def annotate_spans(expr, source_name): - text = expr.astext() - expr = fromtext(text, source_name) - check_spans(expr) - return expr, text - def test_var_span(): x = relay.var('x') func = relay.Function([x], x)