diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index deda5d666e40..19aba6e06933 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -60,19 +60,26 @@ class ObjectStructInfo : public StructInfo { */ class PrimStructInfoNode : public StructInfoNode { public: + /*! \brief Underlying primitive value, if known */ + Optional value; + /*! \brief Underlying data type of the primitive value */ DataType dtype; void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); v->Visit("dtype", &dtype); v->Visit("span", &span); } bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype); + return equal(value, other->value) && equal(dtype, other->dtype); } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(value); + hash_reduce(dtype); + } static constexpr const char* _type_key = "relax.PrimStructInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); @@ -84,8 +91,12 @@ class PrimStructInfoNode : public StructInfoNode { */ class PrimStructInfo : public StructInfo { public: + /* Construct a PrimStructInfo with a known dtype, but unknown value */ TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + /* Construct a PrimStructInfo with a known value */ + TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span()); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); }; diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index fe30c01d3a04..8389b6ec1c27 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -23,6 +23,7 @@ from tvm.ir import Span, EnvFunc, Array, VDevice from tvm.tir import PrimExpr +from tvm.runtime import DataType from .expr import StructInfo, Expr, ShapeExpr from . import _ffi_api, ty, expr @@ -42,14 +43,68 @@ class PrimStructInfo(StructInfo): Parameters ---------- - dtype : str - The data type of the prim value. + dtype_or_expr : Union[str, DataType, PrimExpr] + + The data type of the prim value, or a known expression for the prim + value. """ + value: Optional[PrimExpr] dtype: str - def __init__(self, dtype: str, span: Span = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore + def __init__( + self, + dtype: Optional[Union[str, DataType]] = None, + value: Optional[Union[int, float, PrimExpr]] = None, + span: Span = None, + ) -> None: + # Guard against incorrect usage. For backwards compatibility, + # the dtype and value are in the opposite order from most + # usages. While PrimStructInfo could take a single positional + # argument and check the type, this would require an API + # difference from TVMScript's PrimProxy, which cannot. + # (PrimProxy uses string arguments for datatype, and also for + # inline variable definitions when used in a function + # signature, and requires separate arguments to distinguish + # the two cases.) + if isinstance(dtype, (PrimExpr, int, float)): + raise TypeError( + f"The first positional argument of PrimStructInfo must be the datatype, " + f", but received {type(dtype)}. " + f"The value can be specified as a keyword argument " + f"without needing specifying the dtype: " + f"PrimStructInfo(value=arg)." + ) + + if dtype is None and value is None: + raise TypeError( + "PrimStructInfo.__init__ missing required argument. " + "Must provide either 'dtype' or 'value'" + ) + + if dtype is not None: + if isinstance(value, PrimExpr): + assert value.dtype == dtype, ( + "When providing both 'value' and 'dtype' to PrimStructInfo.__init__, " + "they must be consistent with each other. " + "However, the value {value} has dtype {value.dtype}, " + "but the specified dtype was {dtype}." + ) + elif isinstance(value, (int, float)): + value = tvm.tir.const(value, dtype) + + # Use relax's default integer type if not otherwise specified. + if isinstance(value, int): + value = tvm.tir.IntImm("int64", value) + + if value is None: + self.__init_handle_by_constructor__( + _ffi_api.PrimStructInfoFromDtype, dtype, span + ) # type: ignore + else: + self.__init_handle_by_constructor__( + _ffi_api.PrimStructInfoFromValue, value, span + ) # type: ignore @tvm._ffi.register_object("relax.ShapeStructInfo") diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 1c18d75be45e..a847778b90a6 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -296,7 +296,7 @@ def __init__( def get_symbolic_vars(self) -> Set[str]: if self.values is None: - return {} + return set() else: return {v for v in self.values if isinstance(v, str) and v.isidentifier()} @@ -342,27 +342,52 @@ def Object() -> ObjectProxy: class PrimProxy(StructInfoProxy): - dtype: str - """The type of shape values. + dtype: Optional[str] + value: Optional[Union[int, float, str, PrimExpr]] + + """The type of TIR-representable values. Parameters ---------- - dtype : str + dtype : Optional[str] The data type. + + value: Optional[Union[int, float, str, PrimExpr]] + The known value """ - def __init__(self, dtype: str) -> None: + def __init__( + self, + dtype: Optional[str] = None, + value: Optional[Union[int, float, str, PrimExpr]] = None, + ) -> None: + if dtype is None and value is None: + raise TypeError( + "R.Prim missing required argument. " "Must provide either 'dtype' or 'value'" + ) + self.dtype = dtype + self.value = value def get_symbolic_vars(self) -> Set[str]: - return set() + if isinstance(self.value, str) and self.value.isidentifier(): + return {self.value} + else: + return set() def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: - return PrimStructInfo(self.dtype) + if self.value is None: + return PrimStructInfo(dtype=self.dtype) + else: + value = _eval_shape(self.value, dict_globals) + return PrimStructInfo(dtype=self.dtype, value=value) -def Prim(dtype: str) -> PrimProxy: - return PrimProxy(dtype) +def Prim( + dtype: Optional[str] = None, + value: Optional[Union[int, float, str, PrimExpr]] = None, +) -> PrimProxy: + return PrimProxy(dtype, value) ############################ R.match_cast ############################# diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 9fae7762790c..e1401ac63842 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -1053,6 +1053,9 @@ class SymbolicVarCollector : public relax::ExprVisitor, this->VisitStructInfoExprField(val); } } + if (auto prim_value = expr.as()) { + this->VisitStructInfoExprField(prim_value.value()->value); + } } void VisitStructInfoExprField(const PrimExpr& expr) final { diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 3808325670aa..447d3935f5ca 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relax.Constant") PrimValue::PrimValue(PrimExpr value, Span span) { ObjectPtr n = make_object(); n->checked_type_ = PrimType(value.dtype()); - n->struct_info_ = PrimStructInfo(value.dtype()); + n->struct_info_ = PrimStructInfo(value); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index f0f0d29b517a..01743088028c 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -225,6 +225,9 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { void ExprVisitor::VisitExpr_(const PrimValueNode* op) { this->VisitPrimExpr(op->value); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } this->VisitSpan(op->span); } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 31784af00041..9b635bb479a9 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -42,19 +42,32 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { }); // Prim +PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) { + ObjectPtr n = make_object(); + n->dtype = value->dtype; + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; + n->value = NullOpt; n->span = span; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); -TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { +TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }); +TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) { + return PrimStructInfo(value, span); +}); + // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { ObjectPtr n = make_object(); diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc index 10babe4b066e..bb7c7d654dc7 100644 --- a/src/relax/ir/struct_info_functor.cc +++ b/src/relax/ir/struct_info_functor.cc @@ -28,7 +28,11 @@ namespace relax { void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {} -void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {} +void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) { + if (op->value.defined()) { + this->VisitStructInfoExprField(op->value.value()); + } +} void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) { if (op->values.defined()) { @@ -68,7 +72,16 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { } StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { - return GetRef(op); + if (!op->value.defined()) { + return GetRef(op); + } + + auto new_expr = VisitStructInfoExprField(op->value.value()); + if (new_expr.same_as(op->value)) { + return GetRef(op); + } else { + return PrimStructInfo(new_expr); + } } StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index cccf9ed08b2c..11f987c368fa 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -30,12 +30,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return Relax(d, "Object"); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { - return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); - }); - ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) { ExprDoc expr_doc = d->AsDoc(e, e_p); // Step 1. Find if `func_vars` are being collected @@ -66,6 +60,23 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifie return expr_doc; } +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array args; + Array kwargs_keys; + Array kwargs_values; + + if (n->value.defined()) { + kwargs_keys.push_back("value"); + kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d)); + } else { + args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + } + + return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index d279b60b541c..e0844829d340 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -578,7 +578,7 @@ def test_tir_vars_in_struct_info(): tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m]) -def test_symbolic_var_collector(): +def test_collect_symbolic_var_from_tensor_shape(): n, m, k, q, p = ( tir.Var("n", "int64"), tir.Var("m", "int64"), @@ -600,5 +600,46 @@ def test_symbolic_var_collector(): assert free_vars == {n, p, q} +param_type = tvm.testing.parameter("shape_expr", "prim_value") +param_order = tvm.testing.parameter("definition_first", "usage_first") + + +def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): + tir_n = tir.Var("n", "int64") + tir_m = tir.Var("m", "int64") + + bb = rx.BlockBuilder() + arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m])) + + if param_type == "shape_expr": + extra_params = [ + rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])), + ] + elif param_type == "prim_value": + extra_params = [ + rx.Var("n", rx.PrimStructInfo(value=tir_n)), + rx.Var("m", rx.PrimStructInfo(value=tir_m)), + ] + else: + raise ValueError(f"Unknown param_type: {param_type}") + + if param_order == "definition_first": + params = [*extra_params, arg] + elif param_order == "usage_first": + params = [arg, *extra_params] + else: + raise ValueError(f"Unknown param_order: {param_order}") + + with bb.function("main", params=params): + out = rx.op.reshape(arg, [tir_n, tir_m]) + bb.emit_func_output(out) + func = bb.get()["main"] + + defined_vars = set(rx.analysis.defined_symbolic_vars(func)) + free_vars = set(rx.analysis.free_symbolic_vars(func)) + assert defined_vars == {tir_n, tir_m} + assert free_vars == set() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py index 1dc1189a67f3..82798c56dfff 100644 --- a/tests/python/relax/test_bind_symbolic_vars.py +++ b/tests/python/relax/test_bind_symbolic_vars.py @@ -163,7 +163,7 @@ def expected(A: R.Tensor(["outside_var * 2", "outside_var"])): tvm.ir.assert_structural_equal(expected, after) -def test_bind_symbolic_vars_in_shape(): +def test_bind_symbolic_vars_in_tensor_shape(): """The bound variable should be replaced when appearing in struct info""" @R.function(private=True) @@ -183,6 +183,91 @@ def expected(A: R.Tensor(["M", 16])): tvm.ir.assert_structural_equal(expected, after) +def test_bind_symbolic_vars_in_shape_expr(): + """The bound variable should be replaced when appearing in R.Shape""" + + @R.function(private=True) + def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])): + M = T.int64() + N = T.int64() + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + return B + + @R.function(private=True) + def expected(A: R.Tensor(["M * 16"]), x: R.Shape(["M", 16])): + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32])) + return B + + after = before.bind_symbolic_vars({"N": 16}) + tvm.ir.assert_structural_equal(expected, after) + + +def test_bind_defining_of_symbolic_vars_in_prim_value(): + """R.Prim may define symbolic variables + + This case is a bit odd, because it always results in a + fully-constrained parameter at the relax level. After binding in + this test case, we have a function that accepts three parameters, + and the third parameter must always be the number 16. + + However, this provides the most consistent behavior with other + uses of `relax.Function.bind_symbolic_vars`, which restricts the + allowed values for each parameter, but does not alter the number + of parameters. This is in contrast to the `BindParams` pass, + which provides a known value for relax parameters, removing them + from the function signature. + + This convention also prevents surprise changes to the function + signature, such as shown in + `test_bind_symbolic_vars_with_expr_in_prim_value`. + """ + + @R.function(private=True) + def before(A: R.Tensor(["M * N"]), x: R.Prim(value="M"), y: R.Prim(value="N")): + M = T.int64() + N = T.int64() + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + return B + + @R.function(private=True) + def expected(A: R.Tensor(["M * 16"]), x: R.Prim(value="M"), y: R.Prim(value=16)): + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32])) + return B + + after = before.bind_symbolic_vars({"N": 16}) + tvm.ir.assert_structural_equal(expected, after) + + +def test_bind_usage_of_symbolic_vars_in_prim_value(): + """R.Prim may use symbolic variables defined by other parameters + + Like test_bind_defining_of_symbolic_vars_in_prim_value, but with + R.Prim using a symbolic variable rather than defining it. + + This also demonstrates why we should not remove fully-constrained + R.Prim function parameters. In this case, we have a function that + accepts two parameters, and we have specialized the shape of the + first parameter. It would be unexpected for specialization of the + first parameter to result in removal of a different parameter + altogether. + """ + + @R.function(private=True) + def before(A: R.Tensor(["M", "N"]), x: R.Prim(value="M*N")): + M = T.int64() + N = T.int64() + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N])) + return B + + @R.function(private=True) + def expected(A: R.Tensor([16, 16]), x: R.Prim(value=256)): + B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([512])) + return B + + after = before.bind_symbolic_vars({"M": 16, "N": 16}) + tvm.ir.assert_structural_equal(expected, after) + + def test_bind_strided_slice(): """relax.op.strided_slice stores PrimExpr attributes""" diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 902c4785610f..fbd37b307eba 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -238,6 +238,23 @@ def test_prim_value(): _check_json_roundtrip(pv) +def test_prim_value_with_var(): + n = tir.Var("n", "int64") + pv = rx.PrimValue(n) + assert pv.value.same_as(n) + tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n)) + _check_equal(pv, rx.PrimValue(n)) + _check_json_roundtrip(pv) + + +def test_prim_value_with_expr(): + n = tir.Var("n", "int64") + pv = rx.PrimValue(n + 1) + tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n + 1)) + _check_equal(pv, rx.PrimValue(n + 1)) + _check_json_roundtrip(pv) + + def test_string_imm(): s0 = rx.StringImm("hello") s1 = rx.StringImm("hello") diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py index 80ebc3cb182a..33dcd7e9d77e 100644 --- a/tests/python/relax/test_struct_info.py +++ b/tests/python/relax/test_struct_info.py @@ -86,7 +86,23 @@ def test_prim_struct_info(): # wrong API constructors with pytest.raises(TVMError): - rx.PrimStructInfo(1) + rx.PrimStructInfo([1]) + + +def test_prim_struct_info_with_expr(): + n = tir.Var("n", "int64") + sinfo = rx.PrimStructInfo(value=n + 1) + + _check_equal(sinfo, rx.PrimStructInfo(value=n + 1)) + assert not tvm.ir.structural_equal(sinfo, rx.PrimStructInfo(dtype=n.dtype)) + + # can turn into str + str(sinfo) + + assert isinstance(sinfo, rx.PrimStructInfo) + _check_json_roundtrip(sinfo) + + assert sinfo.dtype == "int64" def test_shape_struct_info(): diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 39a4d33ca6db..abb086be858c 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1192,8 +1192,9 @@ def foo(x: R.Tuple()): _check(foo, bb.get()["foo"]) -def test_symbolic_shape_computing(): - # Tensor Case 1 +def test_symbolic_vars_in_tensor_shape_with_usage_first(): + """First param may use symbolic variable defined in second param""" + @R.function def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): z = R.add(x, y) @@ -1209,7 +1210,10 @@ def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): _check(foo, bb.get()["foo"]) - # Tensor Case 2 + +def test_symbolic_vars_in_tensor_shape_with_definition_first(): + """Second param may use symbolic variable defined in first param""" + @R.function def bar( x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") @@ -1232,7 +1236,10 @@ def bar( _check(bar, bb.get()["bar"]) - # Shape Case + +def test_symbolic_vars_in_shape(): + """Symbolic variable may be defined in R.Shape""" + @R.function def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): m = T.int64() @@ -1249,7 +1256,36 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): _check(baz, bb.get()["baz"]) - # Error Case + +def test_symbolic_vars_in_prim_value(): + """Symbolic variable may be defined in R.Prim""" + + @R.function + def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")): + m = T.int64() + z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.PrimStructInfo(value=m)) + y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + bb = relax.BlockBuilder() + with bb.function("baz", (x, y)): + z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + bb.emit_func_output(z) + + _check(baz, bb.get()["baz"]) + + +def test_undefined_symbolic_var_raises_error(): + """An undefined symbolic variable in an error + + A symbolic variables is defined at the first site where it appears + as a shape parameter without any modification. TVMScript does not + support solving for a symbolic variable in terms of the argument + shape. That is, this test case raises an error, and will not + attempt to define `m` as either `x.shape[0]-1` or `x.shape[1]//2`. + """ with pytest.raises(tvm.error.DiagnosticError): @R.function diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 9f4ffd9acdfe..2e4218b2abef 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -205,6 +205,7 @@ def test_func_struct_info(): relax.PrimStructInfo("float32"), relax.ObjectStructInfo(), relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + relax.PrimStructInfo(value=tir.Var("b", "int64")), ], ret=relax.TensorStructInfo( shape=relax.ShapeExpr([1, 2, 3]), @@ -214,7 +215,8 @@ def test_func_struct_info(): _assert_print( obj, "a = T.int64()\n" - 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), ' + "b = T.int64()\n" + 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3]), R.Prim(value=b)), ' 'R.Tensor((1, 2, 3), dtype="float32"), True)', )