From 8883a1adba2dc039606328dcd9d4863c26568d01 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Wed, 16 Nov 2022 13:18:57 +0800 Subject: [PATCH 1/4] [Relay][Frontend] Span filling common API - Expose and add span attribute of Expr-derived types from C++ to Python - Add common API of span filling - Add test cases of span filling - Add function to control whether to fill span via environment variable - Modify the way of pretty-print to print span --- python/tvm/relay/expr.py | 75 +++++++--- python/tvm/relay/frontend/common.py | 142 ++++++++++++++++++- python/tvm/relay/function.py | 7 +- python/tvm/relay/loops.py | 2 +- python/tvm/testing/utils.py | 25 ++++ src/printer/relay_text_printer.cc | 28 ++-- src/printer/text_printer.h | 23 ++- src/relay/ir/expr.cc | 32 +++-- src/relay/ir/function.cc | 4 +- tests/python/frontend/test_common.py | 204 ++++++++++++++++++++++++++- tests/python/relay/utils/tag_span.py | 106 ++++++++++++++ 11 files changed, 584 insertions(+), 64 deletions(-) create mode 100644 tests/python/relay/utils/tag_span.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index fefc2857230d..f2b2b4ebbdc5 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -171,10 +171,13 @@ class Constant(ExprWithOp): ---------- data : tvm.nd.NDArray The data content of the constant expression. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, data): - self.__init_handle_by_constructor__(_ffi_api.Constant, data) + def __init__(self, data, span=None): + self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) @tvm._ffi.register_object("relay.Tuple") @@ -187,7 +190,7 @@ class Tuple(ExprWithOp): The fields in the tuple. span: Optional[tvm.relay.Span] - Span that points to original source code + Span that points to original source code. """ def __init__(self, fields, span=None): @@ -221,10 +224,13 @@ class Var(ExprWithOp): type_annotation: tvm.relay.Type, optional The type annotation on the variable. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, name_hint, type_annotation=None): - self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation) + def __init__(self, name_hint, type_annotation=None, span=None): + self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation, span) @property def name_hint(self): @@ -256,7 +262,7 @@ class Call(ExprWithOp): used in advanced usecase of template functions. span: Optional[tvm.relay.Span] - Span that points to original source code + Span that points to original source code. """ def __init__(self, op, args, attrs=None, type_args=None, span=None): @@ -279,10 +285,13 @@ class Let(ExprWithOp): body: tvm.relay.Expr The body of the let binding. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, variable, value, body): - self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body) + def __init__(self, variable, value, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span) @tvm._ffi.register_object("relay.If") @@ -299,10 +308,13 @@ class If(ExprWithOp): false_branch: tvm.relay.Expr The expression evaluated when condition is false. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, cond, true_branch, false_branch): - self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch) + def __init__(self, cond, true_branch, false_branch, span=None): + self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span) @tvm._ffi.register_object("relay.TupleGetItem") @@ -316,10 +328,13 @@ class TupleGetItem(ExprWithOp): index: int The index. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, tuple_value, index): - self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index) + def __init__(self, tuple_value, index, span=None): + self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) @tvm._ffi.register_object("relay.RefCreate") @@ -329,10 +344,13 @@ class RefCreate(ExprWithOp): ---------- value: tvm.relay.Expr The initial value. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, value): - self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) + def __init__(self, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span) @tvm._ffi.register_object("relay.RefRead") @@ -342,10 +360,13 @@ class RefRead(ExprWithOp): ---------- ref: tvm.relay.Expr The reference. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, ref): - self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) + def __init__(self, ref, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span) @tvm._ffi.register_object("relay.RefWrite") @@ -357,12 +378,16 @@ class RefWrite(ExprWithOp): ---------- ref: tvm.relay.Expr The reference. + value: tvm.relay.Expr The new value. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, ref, value): - self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value) + def __init__(self, ref, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span) class TempExpr(ExprWithOp): @@ -433,7 +458,7 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") -def var(name_hint, type_annotation=None, shape=None, dtype="float32"): +def var(name_hint, type_annotation=None, shape=None, dtype="float32", span=None): """Create a new tvm.relay.Var. This is a simple wrapper function that allows specify @@ -456,6 +481,9 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"): dtype: str, optional The data type of the tensor. + span: Optional[tvm.relay.Span] + Span that points to original source code. + Examples -------- .. code-block:: python @@ -476,10 +504,10 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"): type_annotation = _ty.TensorType(shape, dtype) elif isinstance(type_annotation, str): type_annotation = _ty.TensorType((), type_annotation) - return Var(name_hint, type_annotation) + return Var(name_hint, type_annotation, span) -def const(value, dtype=None): +def const(value, dtype=None, span=None): """Create a constant value. Parameters @@ -490,6 +518,9 @@ def const(value, dtype=None): dtype: str, optional The data type of the resulting constant. + span: Optional[tvm.relay.Span] + Span that points to original source code. + Note ---- When dtype is None, we use the following rule: @@ -516,7 +547,7 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") - return Constant(value) + return Constant(value, span) def bind(expr, binds): diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 5f961f1ae0e8..4b211431f37e 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -18,12 +18,14 @@ """Common utilities""" from __future__ import absolute_import as _abs import logging +import os import numpy as np import tvm from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple +from ..expr_functor import ExprMutator from .. import expr as _expr from .. import function as _function from .. import transform as _transform @@ -304,13 +306,17 @@ def __init__(self): self.const_ctr = 1 self.in_padding = False - def new_const(self, value, shape=None, dtype="float32"): + def new_const(self, value, shape=None, dtype="float32", source_name=None): + """Construct a new var expr and add to exprs dictionary""" name = "_param_%d" % (self.const_ctr) if hasattr(value, "shape"): shape = value.shape self.const_ctr += 1 self.params[name] = value - self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype) + tmp_var = _expr.var(name_hint=name, shape=shape, dtype=dtype) + if source_name: + tmp_var = set_span(tmp_var, source_name) + self.exprs[name] = tmp_var return self.exprs[name] def get_expr(self, name): @@ -997,3 +1003,135 @@ def try_resolve_var_to_const(x, graph_params): return _op.const(value, dtype) return x + + +class _SpanFiller(ExprMutator): + """SpanFiller""" + + def __init__(self, span): + ExprMutator.__init__(self) + if isinstance(span, tvm.relay.Span): + self._span = span + elif isinstance(span, str): + self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0) + elif isinstance(span, bytes): + self._span = tvm.relay.Span(tvm.relay.SourceName(span.decode("utf-8")), 0, 0, 0, 0) + else: + assert False, f"unsupported span type: {type(span)}" + + def visit(self, expr): + if hasattr(expr, "span") and expr.span: + return expr + + return super().visit(expr) + + def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] + new_body = self.visit(fn.body) + return _function.Function( + list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, self._span + ) + + def visit_let(self, let): + new_variable = self.visit(let.var) + new_value = self.visit(let.value) + new_body = self.visit(let.body) + return _expr.Let(new_variable, new_value, new_body, self._span) + + def visit_call(self, call): + new_args = [self.visit(arg) for arg in call.args] + # call.op might be RelayExpr or Op type + # ExprMutator will return directly if subject belongs to Op type + new_op = self.visit(call.op) + return _expr.Call(new_op, new_args, call.attrs, call.type_args, self._span) + + def visit_var(self, var): + return _expr.Var(var.name_hint, var.type_annotation, self._span) + + def visit_if(self, ite): + return _expr.If( + self.visit(ite.cond), + self.visit(ite.true_branch), + self.visit(ite.false_branch), + self._span, + ) + + def visit_tuple(self, tup): + return _expr.Tuple([self.visit(field) for field in tup.fields], self._span) + + def visit_tuple_getitem(self, op): + return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._span) + + def visit_constant(self, const): + return _expr.Constant(const.data, self._span) + + # TODO: Frontend model translation could not use following relay expressions so far, + # enable them when new models/impls leverage these kinds of relay expressions. + def visit_ref_create(self, _): + raise NotImplementedError() + + def visit_ref_write(self, _): + raise NotImplementedError() + + def visit_ref_read(self, _): + raise NotImplementedError() + + def visit_match(self, _): + raise NotImplementedError() + + def fill(self, sym): + """Fill span to sym when it is an expr, or return it without change + + Parameters + ---------- + sym : + A symbol which is generated from the conversion of a frontend operator. + + Returns + ------- + sym: + A expr with span-filled or the original sym. + """ + if isinstance(sym, _expr.TupleWrapper): + return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) + elif isinstance(sym, _expr.RelayExpr): + return self.visit(sym) + elif isinstance(sym, list): + assert all( + isinstance(expr, _expr.RelayExpr) for expr in sym + ), f"unexpected relay expressions in {sym}" + return [self.visit(expr) for expr in sym] + elif isinstance(sym, tuple): + # some op conversion may return dummy elements + # e.g. op in frontend/pytorch.py: min_max_common + assert all( + isinstance(expr, (_expr.RelayExpr, type(None))) for expr in sym + ), f"unexpected relay expressions in {sym}" + return tuple(self.visit(expr) if expr else None for expr in sym) + elif isinstance(sym, (float, int)): + return sym + elif isinstance(sym, np.ndarray): + return sym + + raise RuntimeError(f"unsupported type {type(sym)}") + + +def _should_fill_span(): + should_fill_span = os.environ.get("TVM_SPANFILLING", "1") + + try: + should_fill_span = bool(int(should_fill_span)) + except ValueError: + raise ValueError( + f"invalid value for TVM_SPANFILLING {should_fill_span}, please set to 0 or 1." + ) + + return should_fill_span + + +def set_span(sym, span): + """Set up the sapn of relay expression(s) while converting OP""" + + if _should_fill_span(): + return _SpanFiller(span).fill(sym) + return sym diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 6b3513cb5e1a..68d8953900cf 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -44,14 +44,17 @@ class Function(BaseFunc): type_params: Optional[List[tvm.relay.TypeParam]] The additional type parameters, this is only used in advanced usecase of template functions. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, params, body, ret_type=None, type_params=None, attrs=None): + def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, span=None): if type_params is None: type_params = convert([]) self.__init_handle_by_constructor__( - _ffi_api.Function, params, body, ret_type, type_params, attrs + _ffi_api.Function, params, body, ret_type, type_params, attrs, span ) def __call__(self, *args): diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py index 6c2ab2e23d72..d46e34860f0b 100644 --- a/python/tvm/relay/loops.py +++ b/python/tvm/relay/loops.py @@ -54,7 +54,7 @@ def while_loop(cond, loop_vars, loop_bodies): for i, loop_var in enumerate(loop_vars): name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i) - new_var = _expr.var(name, type_annotation=sb.type_of(loop_var)) + new_var = _expr.var(name, type_annotation=sb.type_of(loop_var), span=loop_var.span) fresh_vars.append(new_var) with sb.if_scope(cond(*fresh_vars)): diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 74ca326bca7e..9dd1f450d241 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -2081,3 +2081,28 @@ def pprint(name, obj): f"or an instance of `tvm.tir.PrimFunc`. " f"Instead, received {type(expected)}." ) + + +class _control_span_filling: + def __init__(self, on=True): + self._old_state = os.environ["TVM_SPANFILLING"] if "TVM_SPANFILLING" in os.environ else None + self._on = on + + def __enter__(self): + os.environ["TVM_SPANFILLING"] = str(int(self._on)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._old_state: + os.environ["TVM_SPANFILLING"] = self._old_state + else: + del os.environ["TVM_SPANFILLING"] + + +class enable_span_filling(_control_span_filling): + def __init__(self): + super().__init__() + + +class disable_span_filling(_control_span_filling): + def __init__(self): + super().__init__(on=False) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 76cac28b07f7..27cb685985d1 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -65,15 +65,8 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { if (annotate_ == nullptr) { if ((expr.as() || expr.as() || expr.as() || expr.as() || expr.as() || expr.as()) && - (expr->checked_type_.defined() || expr->span.defined())) { - doc << " /*"; - if (expr->checked_type_.defined()) { - doc << " ty=" << Print(expr->checked_type()); - } - if (expr->span.defined()) { - doc << " span=" << PrintSpan(expr->span); - } - doc << " */"; + expr->checked_type_.defined()) { + doc << " /* ty=" << Print(expr->checked_type()) << " */"; } } else { std::string annotated_expr = annotate_(expr); @@ -81,6 +74,11 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { doc << annotated_expr; } } + + if (expr->span.defined()) { + doc << " /* si=" << Print(expr->span) << " */"; + } + return doc; } @@ -132,6 +130,10 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); + } else if (node.as()) { + std::ostringstream os; + os << Downcast(node); + return Doc::RawText(os.str()); } else { // default module. std::ostringstream os; @@ -962,14 +964,6 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& return doc; } -Doc RelayTextPrinter::PrintSpan(const Span& span) { - Doc doc; - const auto* span_node = span.as(); - ICHECK(span_node); - doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; - return doc; -} - TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { auto text = AsText(node, false, nullptr); return text; diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 2dc0997f82ec..5a47a8382e8c 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -113,8 +113,6 @@ class RelayTextPrinter : public ExprFunctor, */ Doc PrintMapAsAttributeValue(const Map& map); - Doc PrintSpan(const Span& span); - Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); Doc TempVar(int n); @@ -472,4 +470,25 @@ class TextPrinter { }; } // namespace tvm +namespace tvm { +namespace runtime { + +inline std::ostream& operator<<(std::ostream& os, const SourceName& source_name) { // NOLINT(*) + ICHECK(source_name->name.defined()); + os << source_name->name; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Span& span) { // NOLINT(*) + if (span.defined()) { + os << span->source_name; + } else { + os << "nullptr"; + } + return os; +} + +} // namespace runtime +} // namespace tvm + #endif // TVM_PRINTER_TEXT_PRINTER_H_ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 5c85b3b29df7..38db974f3970 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -72,8 +72,8 @@ Constant::Constant(runtime::NDArray data, Span span) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) { - return Constant(data); +TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data, Span span) { + return Constant(data, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -200,8 +200,8 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) { - return Var(str, type_annotation); +TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation, Span span) { + return Var(str, type_annotation, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -320,8 +320,8 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { - return Let(var, value, body); +TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body, Span span) { + return Let(var, value, body, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -367,8 +367,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") - .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -410,8 +410,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { - return TupleGetItem(tuple, index); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { + return TupleGetItem(tuple, index, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -448,8 +448,8 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { - return RefCreate(value); +TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span span) { + return RefCreate(value, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -486,7 +486,9 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); +TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span) { + return RefRead(ref, span); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -525,8 +527,8 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { - return RefWrite(ref, value); +TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value, Span span) { + return RefWrite(ref, value, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 1a3db9974f05..07cfb27b1d35 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -124,8 +124,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") .set_body_typed([](tvm::Array params, Expr body, Type ret_type, - tvm::Array ty_params, tvm::DictAttrs attrs) { - return Function(params, body, ret_type, ty_params, attrs); + tvm::Array ty_params, tvm::DictAttrs attrs, Span span) { + return Function(params, body, ret_type, ty_params, attrs, span); }); TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields") .set_body_typed([](Function function, Optional> opt_params, Optional opt_body, diff --git a/tests/python/frontend/test_common.py b/tests/python/frontend/test_common.py index e706f2af304a..0e2e05b21ffd 100644 --- a/tests/python/frontend/test_common.py +++ b/tests/python/frontend/test_common.py @@ -14,7 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from tvm.relay.frontend.common import StrAttrsDict + +import numpy as np + +from tvm import relay, testing +from tvm.relay.frontend.common import StrAttrsDict, set_span +from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span def test_key_is_present(): @@ -27,6 +32,203 @@ def test_key_is_not_present(): assert not attrs.has_attr("b") +def test_set_span(): + def _verify_env_var_switch(): + def _res(should_fill): + if should_fill: + with testing.enable_span_filling(): + return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + else: + with testing.disable_span_filling(): + return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + + disable = relay.var("x", shape=(1, 64, 56, 56)) + enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + + assert _verify_structural_equal_with_span(_res(False), disable) + assert _verify_structural_equal_with_span(_res(True), enable) + + # Should tag all exprs without span, and stop when expr is span-tagged + def _verify_builtin_tuple(): + def _res(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64") + return set_span(tuple([a, b]), "tuple") + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple")) + return tuple([a, b]) + + res_tuple, golden_tuple = _res(), _golden() + assert len(res_tuple) == len(golden_tuple) + for i in range(len(res_tuple)): + assert _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i]) + + def _verify_builtin_list(): + def _res(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64") + t = relay.Tuple([a, b]) + t_a = relay.TupleGetItem(t, 0) + t_b = relay.TupleGetItem(t, 1) + return set_span([t_a, t_b], "list") + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list")) + t = relay.Tuple([a, b], span=_create_span("list")) + t_a = relay.TupleGetItem(t, 0, span=_create_span("list")) + t_b = relay.TupleGetItem(t, 1, span=_create_span("list")) + return [t_a, t_b] + + res_list, golden_list = _res(), _golden() + assert len(res_list) == len(golden_list) + for i in range(len(res_list)): + assert _verify_structural_equal_with_span(res_list[i], golden_list[i]) + + def _verify_var(): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + assert _verify_structural_equal_with_span(x, x_expected) + + def _verify_constant(): + c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c") + c_expected = relay.const( + np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c") + ) + assert _verify_structural_equal_with_span(c, c_expected) + + def _verify_call(): + def _res(): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + return relay.Function([x], y) + + def _golden(): + x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d")) + y = _set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + return relay.Function([x], y) + + assert _verify_structural_equal_with_span(_res(), _golden()) + + def _verify_tuple(): + def _res(): + a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") + b = relay.const(np.ones([1, 1, 1]), dtype="int64") + t = set_span(relay.Tuple([a, b]), "t") + return relay.Function([], t) + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("t")) + t = relay.Tuple([a, b], span=_create_span("t")) + return relay.Function([], t) + + assert _verify_structural_equal_with_span(_res(), _golden()) + + def _verify_tuple_getitem(): + def _res(): + a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") + b = relay.const(np.ones([1, 1, 1]), dtype="int64") + t = relay.Tuple([a, b]) + i = set_span(relay.TupleGetItem(t, 0), "i") + return relay.Function([], i) + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i")) + t = relay.Tuple([a, b], span=_create_span("i")) + i = relay.TupleGetItem(t, 0, span=_create_span("i")) + return relay.Function([], i) + + assert _verify_structural_equal_with_span(_res(), _golden()) + + def _verify_let(): + def _res(): + x = set_span(relay.Var("x"), "x_var") + c_1 = relay.const(np.ones(10)) + add = relay.add(x, x) + body = set_span(relay.Let(x, c_1, add), "let") + + c_2 = set_span(relay.const(np.zeros(10)), "zeros") + y = set_span(relay.add(body, c_2), "add_2") + return relay.Function([x], y) + + def _golden(): + x = relay.Var("x", span=_create_span("x_var")) + c_1 = relay.const(np.ones(10), span=_create_span("let")) + add = _set_span(relay.add(x, x), "let") + body = relay.Let(x, c_1, add, span=_create_span("let")) + + c_2 = relay.const(np.zeros(10), span=_create_span("zeros")) + y = _set_span(relay.add(body, c_2), "add_2") + return relay.Function([x], y) + + assert _verify_structural_equal_with_span(_res(), _golden()) + + def _verify_if(): + def _res(): + x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var") + y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var") + eq = relay.equal(x, y) + + true_branch = set_span(relay.add(x, y), "true_branch") + false_branch = relay.subtract(x, y) + ife = set_span(relay.If(eq, true_branch, false_branch), "if") + return relay.Function([x, y], ife) + + def _golden(): + x = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var")) + y = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var")) + eq = _set_span(relay.equal(x, y), "if") + + true_branch = _set_span(relay.add(x, y), "true_branch") + false_branch = _set_span(relay.subtract(x, y), "if") + ife = relay.If(eq, true_branch, false_branch, span=_create_span("if")) + return relay.Function([x, y], ife) + + assert _verify_structural_equal_with_span(_res(), _golden()) + + def _verify_fn(): + def _res(): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) + f = set_span(relay.Function([x], y), "func") + return f + + def _golden(): + x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func")) + y = _set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "func" + ) + f = relay.Function([x], y, span=_create_span("func")) + return f + + assert _verify_structural_equal_with_span(_res(), _golden()) + + _verify_env_var_switch() + _verify_builtin_tuple() + _verify_builtin_list() + _verify_var() + _verify_constant() + _verify_call() + _verify_tuple() + _verify_tuple_getitem() + _verify_let() + _verify_if() + _verify_fn() + + if __name__ == "__main__": test_key_is_present() test_key_is_present() + test_set_span() diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py new file mode 100644 index 000000000000..30ec32752457 --- /dev/null +++ b/tests/python/relay/utils/tag_span.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay, tir +from tvm.relay.expr_functor import ExprVisitor + + +def _set_span(expr, src): + if isinstance(expr, relay.Call): + return relay.Call(expr.op, expr.args, expr.attrs, expr.type_args, _create_span(src)) + elif isinstance(expr, relay.Var): + return relay.var(expr.name_hint, expr.type_annotation, None, None, _create_span(src)) + elif isinstance(expr, relay.TupleGetItem): + return relay.TupleGetItem(expr.tuple_value, expr.index, _create_span(src)) + elif isinstance(expr, relay.Constant): + return relay.Constant(expr.data, _create_span(src)) + elif isinstance(expr, relay.TupleWrapper): + return relay.TupleWrapper(_set_span(expr.tuple_value, src), expr.size) + elif isinstance(expr, relay.Tuple): + return relay.Tuple(expr.fields, _create_span(src)) + elif isinstance(expr, tir.AttrStmt): + return tir.AttrStmt(expr.node, expr.attr_key, expr.value, expr.body, _create_span(src)) + + assert False, f"unsupported type {type(expr)}" + + +def _create_span(src): + if isinstance(src, list): + tmp_list = [] + for s in src: + if isinstance(s, str): + tmp_list.append(_create_span(s)) + elif isinstance(s, relay.Span): + tmp_list.append(s) + elif isinstance(s, relay.SequentialSpan): + tmp_list.extend(s.spans) + elif s is None: + tmp_list.append(s) + else: + assert False, f"unsupported type {type(s)}" + return relay.SequentialSpan(tmp_list) + return relay.Span(relay.SourceName(src), 0, 0, 0, 0) + + +def _collect_spans(objref): + class Collector: + def __init__(self): + self._spans = [] + + def collect(self, objref): + if hasattr(objref, "span"): + self._spans.append(objref.span) + + @property + def get_spans(self): + return self._spans + + pov = None + if isinstance(objref, relay.Expr): + pov = relay.analysis.post_order_visit + elif isinstance(objref, (tir.Stmt, tir.expr.PrimExprWithOp)): + pov = tir.stmt_functor.post_order_visit + else: + assert False, f"unsupported type {type(objref)}" + + c = Collector() + pov(objref, c.collect) + return c.get_spans + + +def _verify_span(lhs, rhs): + lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs) + + if len(lhs_spans) != len(rhs_spans): + return False + + for i in range(len(lhs_spans)): + if not tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]): + return False + return True + + +def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False): + if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var): + return _verify_span(lhs, rhs) + + if assert_mode: + tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars) + elif not tvm.ir.structural_equal(lhs, rhs, map_free_vars): + return False + + return _verify_span(lhs, rhs) From 556ec6b38f922f9bcc600582706638172c8954bc Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Thu, 17 Nov 2022 11:32:57 +0800 Subject: [PATCH 2/4] [SpanFillingCommonAPI] - Change based on comment - Discard the change of pretty-print - Add document to set_span --- python/tvm/relay/frontend/common.py | 39 ++++++++++++++++++++++++++--- src/printer/relay_text_printer.cc | 28 +++++++++++++-------- src/printer/text_printer.h | 23 ++--------------- 3 files changed, 54 insertions(+), 36 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 4b211431f37e..18f9db72974e 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -313,10 +313,9 @@ def new_const(self, value, shape=None, dtype="float32", source_name=None): shape = value.shape self.const_ctr += 1 self.params[name] = value - tmp_var = _expr.var(name_hint=name, shape=shape, dtype=dtype) + self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype) if source_name: - tmp_var = set_span(tmp_var, source_name) - self.exprs[name] = tmp_var + self.exprs[name] = set_span(self.exprs[name], source_name) return self.exprs[name] def get_expr(self, name): @@ -1130,7 +1129,39 @@ def _should_fill_span(): def set_span(sym, span): - """Set up the sapn of relay expression(s) while converting OP""" + """ + Recursively tag the span to the symbol. Stop when it encounters a span-tagged expr. Disabled + when setting the environment variable "TVM_SPANFILLING" as 0. + + Parameters + ---------- + sym : + A symbol is generated from the conversion of a frontend operator. Raise an error when the + type of the symbol is not supported. + + span : String, Span, or bytes + The source information of the corresponding symbol. + + Returns + ------- + result : + The symbol tagged with span. + + Examples + -------- + .. code-block:: python + + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + print(relay.Function([x], y)) + + #fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) { + # nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, ...) /* span=conv2d:0:0 */ + #} + """ if _should_fill_span(): return _SpanFiller(span).fill(sym) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 27cb685985d1..76cac28b07f7 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -65,8 +65,15 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { if (annotate_ == nullptr) { if ((expr.as() || expr.as() || expr.as() || expr.as() || expr.as() || expr.as()) && - expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; + (expr->checked_type_.defined() || expr->span.defined())) { + doc << " /*"; + if (expr->checked_type_.defined()) { + doc << " ty=" << Print(expr->checked_type()); + } + if (expr->span.defined()) { + doc << " span=" << PrintSpan(expr->span); + } + doc << " */"; } } else { std::string annotated_expr = annotate_(expr); @@ -74,11 +81,6 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { doc << annotated_expr; } } - - if (expr->span.defined()) { - doc << " /* si=" << Print(expr->span) << " */"; - } - return doc; } @@ -130,10 +132,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); - } else if (node.as()) { - std::ostringstream os; - os << Downcast(node); - return Doc::RawText(os.str()); } else { // default module. std::ostringstream os; @@ -964,6 +962,14 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& return doc; } +Doc RelayTextPrinter::PrintSpan(const Span& span) { + Doc doc; + const auto* span_node = span.as(); + ICHECK(span_node); + doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; + return doc; +} + TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { auto text = AsText(node, false, nullptr); return text; diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 5a47a8382e8c..2dc0997f82ec 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -113,6 +113,8 @@ class RelayTextPrinter : public ExprFunctor, */ Doc PrintMapAsAttributeValue(const Map& map); + Doc PrintSpan(const Span& span); + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); Doc TempVar(int n); @@ -470,25 +472,4 @@ class TextPrinter { }; } // namespace tvm -namespace tvm { -namespace runtime { - -inline std::ostream& operator<<(std::ostream& os, const SourceName& source_name) { // NOLINT(*) - ICHECK(source_name->name.defined()); - os << source_name->name; - return os; -} - -inline std::ostream& operator<<(std::ostream& os, const Span& span) { // NOLINT(*) - if (span.defined()) { - os << span->source_name; - } else { - os << "nullptr"; - } - return os; -} - -} // namespace runtime -} // namespace tvm - #endif // TVM_PRINTER_TEXT_PRINTER_H_ From 3a873d91a143d54979ad57d13efc7f205eed5cf3 Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Wed, 23 Nov 2022 16:18:48 +0800 Subject: [PATCH 3/4] [SpanFillingCommonAPI] - Change the test cases to pytest style - Group the set_span test cases to a class --- tests/python/frontend/test_common.py | 64 +++++++++++----------------- tests/python/relay/utils/tag_span.py | 17 ++++---- 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/tests/python/frontend/test_common.py b/tests/python/frontend/test_common.py index 0e2e05b21ffd..6fd7a193f564 100644 --- a/tests/python/frontend/test_common.py +++ b/tests/python/frontend/test_common.py @@ -32,8 +32,8 @@ def test_key_is_not_present(): assert not attrs.has_attr("b") -def test_set_span(): - def _verify_env_var_switch(): +class TestSetSpan: + def test_env_var_switch(self): def _res(should_fill): if should_fill: with testing.enable_span_filling(): @@ -45,11 +45,11 @@ def _res(should_fill): disable = relay.var("x", shape=(1, 64, 56, 56)) enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) - assert _verify_structural_equal_with_span(_res(False), disable) - assert _verify_structural_equal_with_span(_res(True), enable) + _verify_structural_equal_with_span(_res(False), disable) + _verify_structural_equal_with_span(_res(True), enable) # Should tag all exprs without span, and stop when expr is span-tagged - def _verify_builtin_tuple(): + def test_builtin_tuple(self): def _res(): a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) b = relay.const(np.zeros([1, 1, 1]), dtype="int64") @@ -63,9 +63,9 @@ def _golden(): res_tuple, golden_tuple = _res(), _golden() assert len(res_tuple) == len(golden_tuple) for i in range(len(res_tuple)): - assert _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i]) + _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i]) - def _verify_builtin_list(): + def test_builtin_list(self): def _res(): a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) b = relay.const(np.zeros([1, 1, 1]), dtype="int64") @@ -85,21 +85,21 @@ def _golden(): res_list, golden_list = _res(), _golden() assert len(res_list) == len(golden_list) for i in range(len(res_list)): - assert _verify_structural_equal_with_span(res_list[i], golden_list[i]) + _verify_structural_equal_with_span(res_list[i], golden_list[i]) - def _verify_var(): + def test_var(self): x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) - assert _verify_structural_equal_with_span(x, x_expected) + _verify_structural_equal_with_span(x, x_expected) - def _verify_constant(): + def test_constant(self): c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c") c_expected = relay.const( np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c") ) - assert _verify_structural_equal_with_span(c, c_expected) + _verify_structural_equal_with_span(c, c_expected) - def _verify_call(): + def test_call(self): def _res(): x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") @@ -116,9 +116,9 @@ def _golden(): ) return relay.Function([x], y) - assert _verify_structural_equal_with_span(_res(), _golden()) + _verify_structural_equal_with_span(_res(), _golden()) - def _verify_tuple(): + def test_tuple(self): def _res(): a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") b = relay.const(np.ones([1, 1, 1]), dtype="int64") @@ -131,9 +131,9 @@ def _golden(): t = relay.Tuple([a, b], span=_create_span("t")) return relay.Function([], t) - assert _verify_structural_equal_with_span(_res(), _golden()) + _verify_structural_equal_with_span(_res(), _golden()) - def _verify_tuple_getitem(): + def test_tuple_getitem(self): def _res(): a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") b = relay.const(np.ones([1, 1, 1]), dtype="int64") @@ -148,9 +148,9 @@ def _golden(): i = relay.TupleGetItem(t, 0, span=_create_span("i")) return relay.Function([], i) - assert _verify_structural_equal_with_span(_res(), _golden()) + _verify_structural_equal_with_span(_res(), _golden()) - def _verify_let(): + def test_let(self): def _res(): x = set_span(relay.Var("x"), "x_var") c_1 = relay.const(np.ones(10)) @@ -171,9 +171,9 @@ def _golden(): y = _set_span(relay.add(body, c_2), "add_2") return relay.Function([x], y) - assert _verify_structural_equal_with_span(_res(), _golden()) + _verify_structural_equal_with_span(_res(), _golden()) - def _verify_if(): + def test_if(self): def _res(): x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var") y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var") @@ -194,9 +194,9 @@ def _golden(): ife = relay.If(eq, true_branch, false_branch, span=_create_span("if")) return relay.Function([x, y], ife) - assert _verify_structural_equal_with_span(_res(), _golden()) + _verify_structural_equal_with_span(_res(), _golden()) - def _verify_fn(): + def test_fn(self): def _res(): x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") @@ -213,22 +213,8 @@ def _golden(): f = relay.Function([x], y, span=_create_span("func")) return f - assert _verify_structural_equal_with_span(_res(), _golden()) - - _verify_env_var_switch() - _verify_builtin_tuple() - _verify_builtin_list() - _verify_var() - _verify_constant() - _verify_call() - _verify_tuple() - _verify_tuple_getitem() - _verify_let() - _verify_if() - _verify_fn() + _verify_structural_equal_with_span(_res(), _golden()) if __name__ == "__main__": - test_key_is_present() - test_key_is_present() - test_set_span() + testing.main() diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py index 30ec32752457..d7511a809eaa 100644 --- a/tests/python/relay/utils/tag_span.py +++ b/tests/python/relay/utils/tag_span.py @@ -85,22 +85,21 @@ def get_spans(self): def _verify_span(lhs, rhs): lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs) - if len(lhs_spans) != len(rhs_spans): - return False + assert len(lhs_spans) == len(rhs_spans) for i in range(len(lhs_spans)): - if not tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]): - return False - return True + assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]) def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False): if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var): - return _verify_span(lhs, rhs) + # SEqualReduce compares the vid of Var type. Threrfore we only compare span here. + _verify_span(lhs, rhs) + return if assert_mode: tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars) - elif not tvm.ir.structural_equal(lhs, rhs, map_free_vars): - return False + else: + assert tvm.ir.structural_equal(lhs, rhs, map_free_vars) - return _verify_span(lhs, rhs) + _verify_span(lhs, rhs) From 5a2cfd84f29db6e974a567cf4a1191c48202a10c Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Tue, 13 Dec 2022 16:47:13 +0800 Subject: [PATCH 4/4] [SpanFillingCommonAPI] - Expose Relay Expr WithFields APIs to python side - Change the APIs in _SpanFiller from creating new instance to WithFields. - Change the control of frontend span filler from env var to the passcontext config --- python/tvm/relay/expr.py | 127 +++++++++++++++++++++++++++ python/tvm/relay/frontend/common.py | 44 ++++------ python/tvm/testing/utils.py | 9 +- src/ir/span.cc | 4 + src/relay/ir/expr.cc | 56 ++++++++++++ tests/python/frontend/test_common.py | 4 +- tests/python/relay/utils/tag_span.py | 31 ++++--- 7 files changed, 228 insertions(+), 47 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index f2b2b4ebbdc5..88b84bbe7ebc 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -180,6 +180,21 @@ def __init__(self, data, span=None): self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) +@tvm._ffi.register_func("relay.ConstantWithFields") +def ConstantWithFields( + constant, + data=None, + virtual_device=None, + span=None, +): + """ + Returns constant with the given properties. A None property denotes 'no change'. + Returns constant if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.ConstantWithFields(constant, data, virtual_device, span) + + @tvm._ffi.register_object("relay.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -208,6 +223,16 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") +@tvm._ffi.register_func("relay.TupleWithFields") +def TupleWithFields(tup, fields=None, virtual_device=None, span=None): + """ + Returns tuple with the given properties. A None property denotes 'no change'. + Returns tuple if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.TupleWithFields(tup, fields, virtual_device, span) + + @tvm._ffi.register_object("relay.Var") class Var(ExprWithOp): """A local variable in Relay. @@ -239,6 +264,16 @@ def name_hint(self): return name +@tvm._ffi.register_func("relay.VarWithFields") +def VarWithFields(variable, vid=None, type_annotation=None, virtual_device=None, span=None): + """ + Returns var with the given properties. A None property denotes 'no change'. + Returns var if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.VarWithFields(variable, vid, type_annotation, virtual_device, span) + + @tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. @@ -271,6 +306,18 @@ def __init__(self, op, args, attrs=None, type_args=None, span=None): self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span) +@tvm._ffi.register_func("relay.CallWithFields") +def CallWithFields( + call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, span=None +): + """ + Returns call with the given properties. A None property denotes 'no change'. + Returns call if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.CallWithFields(call, op, args, attrs, type_args, virtual_device, span) + + @tvm._ffi.register_object("relay.Let") class Let(ExprWithOp): """Let variable binding expression. @@ -294,6 +341,16 @@ def __init__(self, variable, value, body, span=None): self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span) +@tvm._ffi.register_func("relay.LetWithFields") +def LetWithFields(let, variable=None, value=None, body=None, virtual_device=None, span=None): + """ + Returns let with the given properties. A None property denotes 'no change'. + Returns let if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, span) + + @tvm._ffi.register_object("relay.If") class If(ExprWithOp): """A conditional expression in Relay. @@ -317,6 +374,18 @@ def __init__(self, cond, true_branch, false_branch, span=None): self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span) +@tvm._ffi.register_func("relay.IfWithFields") +def IfWithFields( + if_expr, cond=None, true_branch=None, false_branch=None, virtual_device=None, span=None +): + """ + Returns if with the given properties. A None property denotes 'no change'. + Returns if if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, virtual_device, span) + + @tvm._ffi.register_object("relay.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -337,6 +406,18 @@ def __init__(self, tuple_value, index, span=None): self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) +@tvm._ffi.register_func("relay.TupleGetItemWithFields") +def TupleGetItemWithFields( + tuple_get_item, tuple_value=None, index=None, virtual_device=None, span=None +): + """ + Returns tuple_get_item with the given properties. A None property denotes 'no change'. + Returns tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, virtual_device, span) + + @tvm._ffi.register_object("relay.RefCreate") class RefCreate(ExprWithOp): """Create a new reference from initial value. @@ -353,6 +434,21 @@ def __init__(self, value, span=None): self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span) +@tvm._ffi.register_func("relay.RefCreateWithFields") +def RefCreateWithFields( + ref_create, + value=None, + virtual_device=None, + span=None, +): + """ + Returns ref_create with the given properties. A None property denotes 'no change'. + Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, span) + + @tvm._ffi.register_object("relay.RefRead") class RefRead(ExprWithOp): """Get the value inside the reference. @@ -369,6 +465,21 @@ def __init__(self, ref, span=None): self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span) +@tvm._ffi.register_func("relay.RefReadWithFields") +def RefReadWithFields( + ref_read, + ref=None, + virtual_device=None, + span=None, +): + """ + Returns ref_read with the given properties. A None property denotes 'no change'. + Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span) + + @tvm._ffi.register_object("relay.RefWrite") class RefWrite(ExprWithOp): """ @@ -390,6 +501,22 @@ def __init__(self, ref, value, span=None): self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span) +@tvm._ffi.register_func("relay.RefWriteWithFields") +def RefWriteWithFields( + ref_write, + ref=None, + value=None, + virtual_device=None, + span=None, +): + """ + Returns ref_write with the given properties. A None property denotes 'no change'. + Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, span) + + class TempExpr(ExprWithOp): """Baseclass of all TempExpr. diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 18f9db72974e..925feb765ad0 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -18,7 +18,6 @@ """Common utilities""" from __future__ import absolute_import as _abs import logging -import os import numpy as np import tvm @@ -1027,42 +1026,50 @@ def visit(self, expr): def visit_function(self, fn): new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) - return _function.Function( - list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, self._span + return _function.FunctionWithFields( + fn, list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, None, self._span ) def visit_let(self, let): new_variable = self.visit(let.var) new_value = self.visit(let.value) new_body = self.visit(let.body) - return _expr.Let(new_variable, new_value, new_body, self._span) + return _expr.LetWithFields(let, new_variable, new_value, new_body, None, self._span) def visit_call(self, call): new_args = [self.visit(arg) for arg in call.args] # call.op might be RelayExpr or Op type # ExprMutator will return directly if subject belongs to Op type new_op = self.visit(call.op) - return _expr.Call(new_op, new_args, call.attrs, call.type_args, self._span) + return _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, self._span + ) def visit_var(self, var): - return _expr.Var(var.name_hint, var.type_annotation, self._span) + return _expr.VarWithFields(var, var.vid, var.type_annotation, None, self._span) def visit_if(self, ite): - return _expr.If( + return _expr.IfWithFields( + ite, self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch), + None, self._span, ) def visit_tuple(self, tup): - return _expr.Tuple([self.visit(field) for field in tup.fields], self._span) + return _expr.TupleWithFields( + tup, [self.visit(field) for field in tup.fields], None, self._span + ) def visit_tuple_getitem(self, op): - return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._span) + return _expr.TupleGetItemWithFields( + op, self.visit(op.tuple_value), op.index, None, self._span + ) def visit_constant(self, const): - return _expr.Constant(const.data, self._span) + return _expr.ConstantWithFields(const, const.data, None, self._span) # TODO: Frontend model translation could not use following relay expressions so far, # enable them when new models/impls leverage these kinds of relay expressions. @@ -1115,23 +1122,10 @@ def fill(self, sym): raise RuntimeError(f"unsupported type {type(sym)}") -def _should_fill_span(): - should_fill_span = os.environ.get("TVM_SPANFILLING", "1") - - try: - should_fill_span = bool(int(should_fill_span)) - except ValueError: - raise ValueError( - f"invalid value for TVM_SPANFILLING {should_fill_span}, please set to 0 or 1." - ) - - return should_fill_span - - def set_span(sym, span): """ Recursively tag the span to the symbol. Stop when it encounters a span-tagged expr. Disabled - when setting the environment variable "TVM_SPANFILLING" as 0. + when setting the "relay.frontend.fill_span" as False to the config of PassContext Parameters ---------- @@ -1163,6 +1157,6 @@ def set_span(sym, span): #} """ - if _should_fill_span(): + if tvm.transform.PassContext.current().config.get("relay.frontend.fill_span", True): return _SpanFiller(span).fill(sym) return sym diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 9dd1f450d241..899b05440388 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -2085,17 +2085,14 @@ def pprint(name, obj): class _control_span_filling: def __init__(self, on=True): - self._old_state = os.environ["TVM_SPANFILLING"] if "TVM_SPANFILLING" in os.environ else None self._on = on + self._pass_ctx = tvm.transform.PassContext(config={"relay.frontend.fill_span": self._on}) def __enter__(self): - os.environ["TVM_SPANFILLING"] = str(int(self._on)) + self._pass_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): - if self._old_state: - os.environ["TVM_SPANFILLING"] = self._old_state - else: - del os.environ["TVM_SPANFILLING"] + self._pass_ctx.__exit__(exc_type, exc_val, exc_tb) class enable_span_filling(_control_span_filling): diff --git a/src/ir/span.cc b/src/ir/span.cc index e19bef4cb864..39f0044d16d3 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -20,13 +20,17 @@ * \file span.cc * \brief The span data structure. */ +#include #include +#include #include #include namespace tvm { +TVM_REGISTER_PASS_CONFIG_OPTION("relay.frontend.fill_span", Bool); + ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 38db974f3970..062d9206cf92 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -75,6 +75,11 @@ TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data, Span span) { return Constant(data, span); }); +TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields") + .set_body_typed([](Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(constant, opt_data, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -129,6 +134,11 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); +TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields") + .set_body_typed([](Tuple tuple, Optional> opt_fields, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(tuple, opt_fields, opt_virtual_device, opt_span); + }); Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_virtual_device, Optional opt_span) { @@ -203,6 +213,11 @@ TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation, Span span) { return Var(str, type_annotation, span); }); +TVM_REGISTER_GLOBAL("relay.ir.VarWithFields") + .set_body_typed([](Var var, Optional opt_vid, Optional opt_type_annotation, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -278,6 +293,13 @@ TVM_REGISTER_GLOBAL("relay.ir.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args, Span span) { return Call(op, args, attrs, type_args, span); }); +TVM_REGISTER_GLOBAL("relay.ir.CallWithFields") + .set_body_typed([](Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_type_args, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args, opt_virtual_device, + opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -323,6 +345,12 @@ TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body, Span span) { return Let(var, value, body, span); }); +TVM_REGISTER_GLOBAL("relay.ir.LetWithFields") + .set_body_typed([](Let let, Optional opt_var, Optional opt_value, + Optional opt_body, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -370,6 +398,13 @@ TVM_REGISTER_GLOBAL("relay.ir.If") .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { return If(cond, true_branch, false_branch, span); }); +TVM_REGISTER_GLOBAL("relay.ir.IfWithFields") + .set_body_typed([](If if_expr, Optional opt_cond, Optional opt_true_branch, + Optional opt_false_branch, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch, opt_virtual_device, + opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -413,6 +448,12 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { return TupleGetItem(tuple, index, span); }); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields") + .set_body_typed([](TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(tuple_get_item, opt_tuple, opt_index, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -451,6 +492,11 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span span) { return RefCreate(value, span); }); +TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields") + .set_body_typed([](RefCreate ref_create, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_create, opt_value, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -489,6 +535,11 @@ TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span) { return RefRead(ref, span); }); +TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields") + .set_body_typed([](RefRead ref_read, Optional opt_ref, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -530,6 +581,11 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value, Span span) { return RefWrite(ref, value, span); }); +TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields") + .set_body_typed([](RefWrite ref_write, Optional opt_ref, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/tests/python/frontend/test_common.py b/tests/python/frontend/test_common.py index 6fd7a193f564..2b35ae71f2d6 100644 --- a/tests/python/frontend/test_common.py +++ b/tests/python/frontend/test_common.py @@ -17,7 +17,7 @@ import numpy as np -from tvm import relay, testing +from tvm import relay, testing, transform from tvm.relay.frontend.common import StrAttrsDict, set_span from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span @@ -33,7 +33,7 @@ def test_key_is_not_present(): class TestSetSpan: - def test_env_var_switch(self): + def test_pass_ctx_switch(self): def _res(should_fill): if should_fill: with testing.enable_span_filling(): diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py index d7511a809eaa..77042be60285 100644 --- a/tests/python/relay/utils/tag_span.py +++ b/tests/python/relay/utils/tag_span.py @@ -16,24 +16,27 @@ # under the License. import tvm from tvm import relay, tir +from tvm.relay import expr as _expr from tvm.relay.expr_functor import ExprVisitor def _set_span(expr, src): - if isinstance(expr, relay.Call): - return relay.Call(expr.op, expr.args, expr.attrs, expr.type_args, _create_span(src)) - elif isinstance(expr, relay.Var): - return relay.var(expr.name_hint, expr.type_annotation, None, None, _create_span(src)) - elif isinstance(expr, relay.TupleGetItem): - return relay.TupleGetItem(expr.tuple_value, expr.index, _create_span(src)) - elif isinstance(expr, relay.Constant): - return relay.Constant(expr.data, _create_span(src)) - elif isinstance(expr, relay.TupleWrapper): - return relay.TupleWrapper(_set_span(expr.tuple_value, src), expr.size) - elif isinstance(expr, relay.Tuple): - return relay.Tuple(expr.fields, _create_span(src)) - elif isinstance(expr, tir.AttrStmt): - return tir.AttrStmt(expr.node, expr.attr_key, expr.value, expr.body, _create_span(src)) + if isinstance(expr, _expr.Call): + return _expr.CallWithFields( + expr, expr.op, expr.args, expr.attrs, expr.type_args, None, _create_span(src) + ) + elif isinstance(expr, _expr.Var): + return _expr.VarWithFields(expr, expr.vid, expr.type_annotation, None, _create_span(src)) + elif isinstance(expr, _expr.TupleGetItem): + return _expr.TupleGetItemWithFields( + expr, expr.tuple_value, expr.index, None, _create_span(src) + ) + elif isinstance(expr, _expr.Constant): + return _expr.ConstantWithFields(expr, expr.data, None, _create_span(src)) + elif isinstance(expr, _expr.Tuple): + return _expr.TupleWithFields(expr, expr.fields, None, _create_span(src)) + elif isinstance(expr, _expr.TupleWrapper): + return _expr.TupleWrapper(_set_span(expr.tuple_value, src), expr.size) assert False, f"unsupported type {type(expr)}"