diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index fefc2857230d..88b84bbe7ebc 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -171,10 +171,28 @@ class Constant(ExprWithOp): ---------- data : tvm.nd.NDArray The data content of the constant expression. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, data): - self.__init_handle_by_constructor__(_ffi_api.Constant, data) + def __init__(self, data, span=None): + self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) + + +@tvm._ffi.register_func("relay.ConstantWithFields") +def ConstantWithFields( + constant, + data=None, + virtual_device=None, + span=None, +): + """ + Returns constant with the given properties. A None property denotes 'no change'. + Returns constant if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.ConstantWithFields(constant, data, virtual_device, span) @tvm._ffi.register_object("relay.Tuple") @@ -187,7 +205,7 @@ class Tuple(ExprWithOp): The fields in the tuple. span: Optional[tvm.relay.Span] - Span that points to original source code + Span that points to original source code. """ def __init__(self, fields, span=None): @@ -205,6 +223,16 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") +@tvm._ffi.register_func("relay.TupleWithFields") +def TupleWithFields(tup, fields=None, virtual_device=None, span=None): + """ + Returns tuple with the given properties. A None property denotes 'no change'. + Returns tuple if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.TupleWithFields(tup, fields, virtual_device, span) + + @tvm._ffi.register_object("relay.Var") class Var(ExprWithOp): """A local variable in Relay. @@ -221,10 +249,13 @@ class Var(ExprWithOp): type_annotation: tvm.relay.Type, optional The type annotation on the variable. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, name_hint, type_annotation=None): - self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation) + def __init__(self, name_hint, type_annotation=None, span=None): + self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation, span) @property def name_hint(self): @@ -233,6 +264,16 @@ def name_hint(self): return name +@tvm._ffi.register_func("relay.VarWithFields") +def VarWithFields(variable, vid=None, type_annotation=None, virtual_device=None, span=None): + """ + Returns var with the given properties. A None property denotes 'no change'. + Returns var if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.VarWithFields(variable, vid, type_annotation, virtual_device, span) + + @tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. @@ -256,7 +297,7 @@ class Call(ExprWithOp): used in advanced usecase of template functions. span: Optional[tvm.relay.Span] - Span that points to original source code + Span that points to original source code. """ def __init__(self, op, args, attrs=None, type_args=None, span=None): @@ -265,6 +306,18 @@ def __init__(self, op, args, attrs=None, type_args=None, span=None): self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span) +@tvm._ffi.register_func("relay.CallWithFields") +def CallWithFields( + call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, span=None +): + """ + Returns call with the given properties. A None property denotes 'no change'. + Returns call if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.CallWithFields(call, op, args, attrs, type_args, virtual_device, span) + + @tvm._ffi.register_object("relay.Let") class Let(ExprWithOp): """Let variable binding expression. @@ -279,10 +332,23 @@ class Let(ExprWithOp): body: tvm.relay.Expr The body of the let binding. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, variable, value, body): - self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body) + def __init__(self, variable, value, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span) + + +@tvm._ffi.register_func("relay.LetWithFields") +def LetWithFields(let, variable=None, value=None, body=None, virtual_device=None, span=None): + """ + Returns let with the given properties. A None property denotes 'no change'. + Returns let if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, span) @tvm._ffi.register_object("relay.If") @@ -299,10 +365,25 @@ class If(ExprWithOp): false_branch: tvm.relay.Expr The expression evaluated when condition is false. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, cond, true_branch, false_branch): - self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch) + def __init__(self, cond, true_branch, false_branch, span=None): + self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span) + + +@tvm._ffi.register_func("relay.IfWithFields") +def IfWithFields( + if_expr, cond=None, true_branch=None, false_branch=None, virtual_device=None, span=None +): + """ + Returns if with the given properties. A None property denotes 'no change'. + Returns if if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, virtual_device, span) @tvm._ffi.register_object("relay.TupleGetItem") @@ -316,10 +397,25 @@ class TupleGetItem(ExprWithOp): index: int The index. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, tuple_value, index): - self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index) + def __init__(self, tuple_value, index, span=None): + self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) + + +@tvm._ffi.register_func("relay.TupleGetItemWithFields") +def TupleGetItemWithFields( + tuple_get_item, tuple_value=None, index=None, virtual_device=None, span=None +): + """ + Returns tuple_get_item with the given properties. A None property denotes 'no change'. + Returns tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, virtual_device, span) @tvm._ffi.register_object("relay.RefCreate") @@ -329,10 +425,28 @@ class RefCreate(ExprWithOp): ---------- value: tvm.relay.Expr The initial value. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, value): - self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) + def __init__(self, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span) + + +@tvm._ffi.register_func("relay.RefCreateWithFields") +def RefCreateWithFields( + ref_create, + value=None, + virtual_device=None, + span=None, +): + """ + Returns ref_create with the given properties. A None property denotes 'no change'. + Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, span) @tvm._ffi.register_object("relay.RefRead") @@ -342,10 +456,28 @@ class RefRead(ExprWithOp): ---------- ref: tvm.relay.Expr The reference. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, ref): - self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) + def __init__(self, ref, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span) + + +@tvm._ffi.register_func("relay.RefReadWithFields") +def RefReadWithFields( + ref_read, + ref=None, + virtual_device=None, + span=None, +): + """ + Returns ref_read with the given properties. A None property denotes 'no change'. + Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span) @tvm._ffi.register_object("relay.RefWrite") @@ -357,12 +489,32 @@ class RefWrite(ExprWithOp): ---------- ref: tvm.relay.Expr The reference. + value: tvm.relay.Expr The new value. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, ref, value): - self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value) + def __init__(self, ref, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span) + + +@tvm._ffi.register_func("relay.RefWriteWithFields") +def RefWriteWithFields( + ref_write, + ref=None, + value=None, + virtual_device=None, + span=None, +): + """ + Returns ref_write with the given properties. A None property denotes 'no change'. + Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new + fields. + """ + return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, span) class TempExpr(ExprWithOp): @@ -433,7 +585,7 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") -def var(name_hint, type_annotation=None, shape=None, dtype="float32"): +def var(name_hint, type_annotation=None, shape=None, dtype="float32", span=None): """Create a new tvm.relay.Var. This is a simple wrapper function that allows specify @@ -456,6 +608,9 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"): dtype: str, optional The data type of the tensor. + span: Optional[tvm.relay.Span] + Span that points to original source code. + Examples -------- .. code-block:: python @@ -476,10 +631,10 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"): type_annotation = _ty.TensorType(shape, dtype) elif isinstance(type_annotation, str): type_annotation = _ty.TensorType((), type_annotation) - return Var(name_hint, type_annotation) + return Var(name_hint, type_annotation, span) -def const(value, dtype=None): +def const(value, dtype=None, span=None): """Create a constant value. Parameters @@ -490,6 +645,9 @@ def const(value, dtype=None): dtype: str, optional The data type of the resulting constant. + span: Optional[tvm.relay.Span] + Span that points to original source code. + Note ---- When dtype is None, we use the following rule: @@ -516,7 +674,7 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") - return Constant(value) + return Constant(value, span) def bind(expr, binds): diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 5f961f1ae0e8..925feb765ad0 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple +from ..expr_functor import ExprMutator from .. import expr as _expr from .. import function as _function from .. import transform as _transform @@ -304,13 +305,16 @@ def __init__(self): self.const_ctr = 1 self.in_padding = False - def new_const(self, value, shape=None, dtype="float32"): + def new_const(self, value, shape=None, dtype="float32", source_name=None): + """Construct a new var expr and add to exprs dictionary""" name = "_param_%d" % (self.const_ctr) if hasattr(value, "shape"): shape = value.shape self.const_ctr += 1 self.params[name] = value self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype) + if source_name: + self.exprs[name] = set_span(self.exprs[name], source_name) return self.exprs[name] def get_expr(self, name): @@ -997,3 +1001,162 @@ def try_resolve_var_to_const(x, graph_params): return _op.const(value, dtype) return x + + +class _SpanFiller(ExprMutator): + """SpanFiller""" + + def __init__(self, span): + ExprMutator.__init__(self) + if isinstance(span, tvm.relay.Span): + self._span = span + elif isinstance(span, str): + self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0) + elif isinstance(span, bytes): + self._span = tvm.relay.Span(tvm.relay.SourceName(span.decode("utf-8")), 0, 0, 0, 0) + else: + assert False, f"unsupported span type: {type(span)}" + + def visit(self, expr): + if hasattr(expr, "span") and expr.span: + return expr + + return super().visit(expr) + + def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] + new_body = self.visit(fn.body) + return _function.FunctionWithFields( + fn, list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, None, self._span + ) + + def visit_let(self, let): + new_variable = self.visit(let.var) + new_value = self.visit(let.value) + new_body = self.visit(let.body) + return _expr.LetWithFields(let, new_variable, new_value, new_body, None, self._span) + + def visit_call(self, call): + new_args = [self.visit(arg) for arg in call.args] + # call.op might be RelayExpr or Op type + # ExprMutator will return directly if subject belongs to Op type + new_op = self.visit(call.op) + return _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, self._span + ) + + def visit_var(self, var): + return _expr.VarWithFields(var, var.vid, var.type_annotation, None, self._span) + + def visit_if(self, ite): + return _expr.IfWithFields( + ite, + self.visit(ite.cond), + self.visit(ite.true_branch), + self.visit(ite.false_branch), + None, + self._span, + ) + + def visit_tuple(self, tup): + return _expr.TupleWithFields( + tup, [self.visit(field) for field in tup.fields], None, self._span + ) + + def visit_tuple_getitem(self, op): + return _expr.TupleGetItemWithFields( + op, self.visit(op.tuple_value), op.index, None, self._span + ) + + def visit_constant(self, const): + return _expr.ConstantWithFields(const, const.data, None, self._span) + + # TODO: Frontend model translation could not use following relay expressions so far, + # enable them when new models/impls leverage these kinds of relay expressions. + def visit_ref_create(self, _): + raise NotImplementedError() + + def visit_ref_write(self, _): + raise NotImplementedError() + + def visit_ref_read(self, _): + raise NotImplementedError() + + def visit_match(self, _): + raise NotImplementedError() + + def fill(self, sym): + """Fill span to sym when it is an expr, or return it without change + + Parameters + ---------- + sym : + A symbol which is generated from the conversion of a frontend operator. + + Returns + ------- + sym: + A expr with span-filled or the original sym. + """ + if isinstance(sym, _expr.TupleWrapper): + return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) + elif isinstance(sym, _expr.RelayExpr): + return self.visit(sym) + elif isinstance(sym, list): + assert all( + isinstance(expr, _expr.RelayExpr) for expr in sym + ), f"unexpected relay expressions in {sym}" + return [self.visit(expr) for expr in sym] + elif isinstance(sym, tuple): + # some op conversion may return dummy elements + # e.g. op in frontend/pytorch.py: min_max_common + assert all( + isinstance(expr, (_expr.RelayExpr, type(None))) for expr in sym + ), f"unexpected relay expressions in {sym}" + return tuple(self.visit(expr) if expr else None for expr in sym) + elif isinstance(sym, (float, int)): + return sym + elif isinstance(sym, np.ndarray): + return sym + + raise RuntimeError(f"unsupported type {type(sym)}") + + +def set_span(sym, span): + """ + Recursively tag the span to the symbol. Stop when it encounters a span-tagged expr. Disabled + when setting the "relay.frontend.fill_span" as False to the config of PassContext + + Parameters + ---------- + sym : + A symbol is generated from the conversion of a frontend operator. Raise an error when the + type of the symbol is not supported. + + span : String, Span, or bytes + The source information of the corresponding symbol. + + Returns + ------- + result : + The symbol tagged with span. + + Examples + -------- + .. code-block:: python + + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + print(relay.Function([x], y)) + + #fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) { + # nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, ...) /* span=conv2d:0:0 */ + #} + """ + + if tvm.transform.PassContext.current().config.get("relay.frontend.fill_span", True): + return _SpanFiller(span).fill(sym) + return sym diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 6b3513cb5e1a..68d8953900cf 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -44,14 +44,17 @@ class Function(BaseFunc): type_params: Optional[List[tvm.relay.TypeParam]] The additional type parameters, this is only used in advanced usecase of template functions. + + span: Optional[tvm.relay.Span] + Span that points to original source code. """ - def __init__(self, params, body, ret_type=None, type_params=None, attrs=None): + def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, span=None): if type_params is None: type_params = convert([]) self.__init_handle_by_constructor__( - _ffi_api.Function, params, body, ret_type, type_params, attrs + _ffi_api.Function, params, body, ret_type, type_params, attrs, span ) def __call__(self, *args): diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py index 6c2ab2e23d72..d46e34860f0b 100644 --- a/python/tvm/relay/loops.py +++ b/python/tvm/relay/loops.py @@ -54,7 +54,7 @@ def while_loop(cond, loop_vars, loop_bodies): for i, loop_var in enumerate(loop_vars): name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i) - new_var = _expr.var(name, type_annotation=sb.type_of(loop_var)) + new_var = _expr.var(name, type_annotation=sb.type_of(loop_var), span=loop_var.span) fresh_vars.append(new_var) with sb.if_scope(cond(*fresh_vars)): diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 74ca326bca7e..899b05440388 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -2081,3 +2081,25 @@ def pprint(name, obj): f"or an instance of `tvm.tir.PrimFunc`. " f"Instead, received {type(expected)}." ) + + +class _control_span_filling: + def __init__(self, on=True): + self._on = on + self._pass_ctx = tvm.transform.PassContext(config={"relay.frontend.fill_span": self._on}) + + def __enter__(self): + self._pass_ctx.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self._pass_ctx.__exit__(exc_type, exc_val, exc_tb) + + +class enable_span_filling(_control_span_filling): + def __init__(self): + super().__init__() + + +class disable_span_filling(_control_span_filling): + def __init__(self): + super().__init__(on=False) diff --git a/src/ir/span.cc b/src/ir/span.cc index e19bef4cb864..39f0044d16d3 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -20,13 +20,17 @@ * \file span.cc * \brief The span data structure. */ +#include #include +#include #include #include namespace tvm { +TVM_REGISTER_PASS_CONFIG_OPTION("relay.frontend.fill_span", Bool); + ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 5c85b3b29df7..062d9206cf92 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -72,9 +72,14 @@ Constant::Constant(runtime::NDArray data, Span span) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) { - return Constant(data); +TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data, Span span) { + return Constant(data, span); }); +TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields") + .set_body_typed([](Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(constant, opt_data, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -129,6 +134,11 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); +TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields") + .set_body_typed([](Tuple tuple, Optional> opt_fields, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(tuple, opt_fields, opt_virtual_device, opt_span); + }); Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_virtual_device, Optional opt_span) { @@ -200,9 +210,14 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) { - return Var(str, type_annotation); +TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation, Span span) { + return Var(str, type_annotation, span); }); +TVM_REGISTER_GLOBAL("relay.ir.VarWithFields") + .set_body_typed([](Var var, Optional opt_vid, Optional opt_type_annotation, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -278,6 +293,13 @@ TVM_REGISTER_GLOBAL("relay.ir.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args, Span span) { return Call(op, args, attrs, type_args, span); }); +TVM_REGISTER_GLOBAL("relay.ir.CallWithFields") + .set_body_typed([](Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_type_args, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args, opt_virtual_device, + opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -320,9 +342,15 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { - return Let(var, value, body); +TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body, Span span) { + return Let(var, value, body, span); }); +TVM_REGISTER_GLOBAL("relay.ir.LetWithFields") + .set_body_typed([](Let let, Optional opt_var, Optional opt_value, + Optional opt_body, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -367,8 +395,15 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") - .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); + }); +TVM_REGISTER_GLOBAL("relay.ir.IfWithFields") + .set_body_typed([](If if_expr, Optional opt_cond, Optional opt_true_branch, + Optional opt_false_branch, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch, opt_virtual_device, + opt_span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -410,9 +445,15 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { - return TupleGetItem(tuple, index); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { + return TupleGetItem(tuple, index, span); }); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields") + .set_body_typed([](TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_virtual_device, + Optional opt_span) { + return WithFields(tuple_get_item, opt_tuple, opt_index, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -448,9 +489,14 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { - return RefCreate(value); +TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span span) { + return RefCreate(value, span); }); +TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields") + .set_body_typed([](RefCreate ref_create, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_create, opt_value, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -486,7 +532,14 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); +TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span) { + return RefRead(ref, span); +}); +TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields") + .set_body_typed([](RefRead ref_read, Optional opt_ref, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -525,9 +578,14 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { - return RefWrite(ref, value); +TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value, Span span) { + return RefWrite(ref, value, span); }); +TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields") + .set_body_typed([](RefWrite ref_write, Optional opt_ref, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { + return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device, opt_span); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 1a3db9974f05..07cfb27b1d35 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -124,8 +124,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") .set_body_typed([](tvm::Array params, Expr body, Type ret_type, - tvm::Array ty_params, tvm::DictAttrs attrs) { - return Function(params, body, ret_type, ty_params, attrs); + tvm::Array ty_params, tvm::DictAttrs attrs, Span span) { + return Function(params, body, ret_type, ty_params, attrs, span); }); TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields") .set_body_typed([](Function function, Optional> opt_params, Optional opt_body, diff --git a/tests/python/frontend/test_common.py b/tests/python/frontend/test_common.py index e706f2af304a..2b35ae71f2d6 100644 --- a/tests/python/frontend/test_common.py +++ b/tests/python/frontend/test_common.py @@ -14,7 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from tvm.relay.frontend.common import StrAttrsDict + +import numpy as np + +from tvm import relay, testing, transform +from tvm.relay.frontend.common import StrAttrsDict, set_span +from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span def test_key_is_present(): @@ -27,6 +32,189 @@ def test_key_is_not_present(): assert not attrs.has_attr("b") +class TestSetSpan: + def test_pass_ctx_switch(self): + def _res(should_fill): + if should_fill: + with testing.enable_span_filling(): + return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + else: + with testing.disable_span_filling(): + return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + + disable = relay.var("x", shape=(1, 64, 56, 56)) + enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + + _verify_structural_equal_with_span(_res(False), disable) + _verify_structural_equal_with_span(_res(True), enable) + + # Should tag all exprs without span, and stop when expr is span-tagged + def test_builtin_tuple(self): + def _res(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64") + return set_span(tuple([a, b]), "tuple") + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple")) + return tuple([a, b]) + + res_tuple, golden_tuple = _res(), _golden() + assert len(res_tuple) == len(golden_tuple) + for i in range(len(res_tuple)): + _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i]) + + def test_builtin_list(self): + def _res(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64") + t = relay.Tuple([a, b]) + t_a = relay.TupleGetItem(t, 0) + t_b = relay.TupleGetItem(t, 1) + return set_span([t_a, t_b], "list") + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list")) + t = relay.Tuple([a, b], span=_create_span("list")) + t_a = relay.TupleGetItem(t, 0, span=_create_span("list")) + t_b = relay.TupleGetItem(t, 1, span=_create_span("list")) + return [t_a, t_b] + + res_list, golden_list = _res(), _golden() + assert len(res_list) == len(golden_list) + for i in range(len(res_list)): + _verify_structural_equal_with_span(res_list[i], golden_list[i]) + + def test_var(self): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + _verify_structural_equal_with_span(x, x_expected) + + def test_constant(self): + c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c") + c_expected = relay.const( + np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c") + ) + _verify_structural_equal_with_span(c, c_expected) + + def test_call(self): + def _res(): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + return relay.Function([x], y) + + def _golden(): + x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d")) + y = _set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d" + ) + return relay.Function([x], y) + + _verify_structural_equal_with_span(_res(), _golden()) + + def test_tuple(self): + def _res(): + a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") + b = relay.const(np.ones([1, 1, 1]), dtype="int64") + t = set_span(relay.Tuple([a, b]), "t") + return relay.Function([], t) + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("t")) + t = relay.Tuple([a, b], span=_create_span("t")) + return relay.Function([], t) + + _verify_structural_equal_with_span(_res(), _golden()) + + def test_tuple_getitem(self): + def _res(): + a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a") + b = relay.const(np.ones([1, 1, 1]), dtype="int64") + t = relay.Tuple([a, b]) + i = set_span(relay.TupleGetItem(t, 0), "i") + return relay.Function([], i) + + def _golden(): + a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a")) + b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i")) + t = relay.Tuple([a, b], span=_create_span("i")) + i = relay.TupleGetItem(t, 0, span=_create_span("i")) + return relay.Function([], i) + + _verify_structural_equal_with_span(_res(), _golden()) + + def test_let(self): + def _res(): + x = set_span(relay.Var("x"), "x_var") + c_1 = relay.const(np.ones(10)) + add = relay.add(x, x) + body = set_span(relay.Let(x, c_1, add), "let") + + c_2 = set_span(relay.const(np.zeros(10)), "zeros") + y = set_span(relay.add(body, c_2), "add_2") + return relay.Function([x], y) + + def _golden(): + x = relay.Var("x", span=_create_span("x_var")) + c_1 = relay.const(np.ones(10), span=_create_span("let")) + add = _set_span(relay.add(x, x), "let") + body = relay.Let(x, c_1, add, span=_create_span("let")) + + c_2 = relay.const(np.zeros(10), span=_create_span("zeros")) + y = _set_span(relay.add(body, c_2), "add_2") + return relay.Function([x], y) + + _verify_structural_equal_with_span(_res(), _golden()) + + def test_if(self): + def _res(): + x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var") + y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var") + eq = relay.equal(x, y) + + true_branch = set_span(relay.add(x, y), "true_branch") + false_branch = relay.subtract(x, y) + ife = set_span(relay.If(eq, true_branch, false_branch), "if") + return relay.Function([x, y], ife) + + def _golden(): + x = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var")) + y = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var")) + eq = _set_span(relay.equal(x, y), "if") + + true_branch = _set_span(relay.add(x, y), "true_branch") + false_branch = _set_span(relay.subtract(x, y), "if") + ife = relay.If(eq, true_branch, false_branch, span=_create_span("if")) + return relay.Function([x, y], ife) + + _verify_structural_equal_with_span(_res(), _golden()) + + def test_fn(self): + def _res(): + x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var") + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64") + y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) + f = set_span(relay.Function([x], y), "func") + return f + + def _golden(): + x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var")) + w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func")) + y = _set_span( + relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "func" + ) + f = relay.Function([x], y, span=_create_span("func")) + return f + + _verify_structural_equal_with_span(_res(), _golden()) + + if __name__ == "__main__": - test_key_is_present() - test_key_is_present() + testing.main() diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py new file mode 100644 index 000000000000..77042be60285 --- /dev/null +++ b/tests/python/relay/utils/tag_span.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay, tir +from tvm.relay import expr as _expr +from tvm.relay.expr_functor import ExprVisitor + + +def _set_span(expr, src): + if isinstance(expr, _expr.Call): + return _expr.CallWithFields( + expr, expr.op, expr.args, expr.attrs, expr.type_args, None, _create_span(src) + ) + elif isinstance(expr, _expr.Var): + return _expr.VarWithFields(expr, expr.vid, expr.type_annotation, None, _create_span(src)) + elif isinstance(expr, _expr.TupleGetItem): + return _expr.TupleGetItemWithFields( + expr, expr.tuple_value, expr.index, None, _create_span(src) + ) + elif isinstance(expr, _expr.Constant): + return _expr.ConstantWithFields(expr, expr.data, None, _create_span(src)) + elif isinstance(expr, _expr.Tuple): + return _expr.TupleWithFields(expr, expr.fields, None, _create_span(src)) + elif isinstance(expr, _expr.TupleWrapper): + return _expr.TupleWrapper(_set_span(expr.tuple_value, src), expr.size) + + assert False, f"unsupported type {type(expr)}" + + +def _create_span(src): + if isinstance(src, list): + tmp_list = [] + for s in src: + if isinstance(s, str): + tmp_list.append(_create_span(s)) + elif isinstance(s, relay.Span): + tmp_list.append(s) + elif isinstance(s, relay.SequentialSpan): + tmp_list.extend(s.spans) + elif s is None: + tmp_list.append(s) + else: + assert False, f"unsupported type {type(s)}" + return relay.SequentialSpan(tmp_list) + return relay.Span(relay.SourceName(src), 0, 0, 0, 0) + + +def _collect_spans(objref): + class Collector: + def __init__(self): + self._spans = [] + + def collect(self, objref): + if hasattr(objref, "span"): + self._spans.append(objref.span) + + @property + def get_spans(self): + return self._spans + + pov = None + if isinstance(objref, relay.Expr): + pov = relay.analysis.post_order_visit + elif isinstance(objref, (tir.Stmt, tir.expr.PrimExprWithOp)): + pov = tir.stmt_functor.post_order_visit + else: + assert False, f"unsupported type {type(objref)}" + + c = Collector() + pov(objref, c.collect) + return c.get_spans + + +def _verify_span(lhs, rhs): + lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs) + + assert len(lhs_spans) == len(rhs_spans) + + for i in range(len(lhs_spans)): + assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]) + + +def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False): + if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var): + # SEqualReduce compares the vid of Var type. Threrfore we only compare span here. + _verify_span(lhs, rhs) + return + + if assert_mode: + tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars) + else: + assert tvm.ir.structural_equal(lhs, rhs, map_free_vars) + + _verify_span(lhs, rhs)