diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index a3b391637cb4..f8c13d30b172 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -150,6 +150,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Struc def get_symbolic_vars(self) -> Set[str]: return {} + def get_symbolic_size_vars(self) -> Set[str]: + return self.get_symbolic_vars() + def asobject(self): return self.as_struct_info(None) @@ -172,9 +175,6 @@ class ObjectProxy(StructInfoProxy): def __init__(self) -> None: pass - def get_symbolic_vars(self) -> Set[str]: - return set() - def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: return ObjectStructInfo() @@ -327,6 +327,12 @@ def get_symbolic_vars(self) -> Set[str]: else: return set().union(*[p.get_symbolic_vars() for p in self.params]) + def get_symbolic_size_vars(self) -> Set[str]: + if self.params is None: + return set() + else: + return set().union(*[p.get_symbolic_size_vars() for p in self.params]) + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo: if self.ret is None: ret = None @@ -377,6 +383,9 @@ def __init__( def get_symbolic_vars(self) -> Set[str]: return set().union(*[f.get_symbolic_vars() for f in self.fields]) + def get_symbolic_size_vars(self) -> Set[str]: + return set().union(*[f.get_symbolic_size_vars() for f in self.fields]) + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo: fields = [field.as_struct_info(dict_globals) for field in self.fields] return TupleStructInfo(fields) @@ -463,6 +472,13 @@ def get_symbolic_vars(self) -> Set[str]: else: return set() + def get_symbolic_size_vars(self) -> Set[str]: + # While variables defined by R.Shape and R.Tensor arguments + # are known to be non-negative, R.Prim arguments may be + # negative. Overriding the default implementation of + # `get_symbolic_size_vars()` + return set() + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: if self.value is None: return PrimStructInfo(dtype=self.dtype) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 9d73749b0aa4..aec5a0965978 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -147,15 +147,23 @@ def is_recursive(node: doc.FunctionDef) -> bool: def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: # Collect symbolic vars from parameters symbolic_vars = set() + symbolic_size_vars = set() for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + symbolic_size_vars.update(param_sinfo_proxy.get_symbolic_size_vars()) + + assert len(symbolic_size_vars - symbolic_vars) == 0, ( + "Internal error: " + "All collected tir.SizeVar names must also appear in the list of tir.Var names" + ) # Define symbolic vars to the current var_table frame for var_name in symbolic_vars: - self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) + var_cls = tir.SizeVar if var_name in symbolic_size_vars else tir.Var + self.var_table.add(var_name, var_cls(var_name, "int64"), allow_shadowing=False) @dispatch.register(token="relax", type_name="FunctionDef") diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64014d1c49be..768a379b7061 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2293,7 +2293,6 @@ def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]): assert func.attrs is not None -@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing") def test_function_symbolic_variables_are_annotated(): """Symbolic variables must be exposed for struct inference @@ -2317,5 +2316,155 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): tvm.ir.assert_structural_equal(inferred_sinfo, expected) +def test_symbolic_shape_variables_are_size_var(): + """Symbolic variables inferred from shapes are SizeVar + + The indices in `R.strided_slice` follow Python's conventions for + negative indices. Absent any additional information, a slice + `arr[0:i]` would either have length `i` when `i >= 0`, or length + `len(arr) + i` when `i < 0`. + + In this case, though, the dynamic `extent` variable is known to be + non-negative, because negative values may not be used as the + dimensions of `R.Tensor` or `R.Shape`. Because Relax struct + inference is performed while TVMScript is being parsed, this + constraint must be exposed during TVMScript parsing in order to + correctly infer the resulting StructInfo. + + """ + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor(["extent"])): + extent = T.int64() + output = R.strided_slice(A, [0], [0], [extent]) + return output + + @R.function(private=True) + def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent"]): + extent = T.int64() + output: R.Tensor([extent]) = R.strided_slice(A, [0], [0], [extent]) + return output + + tvm.ir.assert_structural_equal(inferred_sinfo, expected) + + assert isinstance(inferred_sinfo.params[0].struct_info.shape[0], tir.SizeVar) + + +def test_symbolic_variables_from_prim_value_may_be_negative(): + """Symbolic variables inferred from R.Prim are Var + + Not all symbolic variables represent shapes. While a + `relax::PrimValue` can be the source of definition for a TIR + variable, a `relax::PrimValue` may not represent a shape, and may + be negative. + + This test is similar to + `test_symbolic_shape_variables_are_size_var`, except that the + `extent` variable is defined by a `R.Prim` argument, and not by a + `R.Tensor` argument. As a result, we do not know whether `extent` + is negative, and cannot simplify expressions that depend on + `extent<0`. + + """ + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([16]), _: R.Prim(value="extent")): + extent = T.int64() + output = R.strided_slice(A, [0], [0], [extent]) + return output + + @R.function(private=True) + def expected(A: R.Tensor([16]), _: R.Prim(value="extent")): + extent = T.int64() + output: R.Tensor( + [T.min(T.max(T.if_then_else(extent < 0, extent + 16, extent), 0), 16)] + ) = R.strided_slice(A, [0], [0], [extent]) + return output + + tvm.ir.assert_structural_equal(inferred_sinfo, expected) + + assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar) + + +def test_other_arguments_may_cause_prim_value_to_define_size_var(): + """Other arguments may cause R.Prim to hold SizeVar + + This test is similar to + `test_symbolic_variables_from_prim_value_may_be_negative`, except + that `extent` also appears in a `R.Shape`. While the + `R.Prim(value="extent")` occurs first in the parameter list, and + is the source of definition, the presence of `extent` in `R.Shape` + parameter shows that it is a `SizeVar`. + + """ + + @R.function(private=True) + def inferred_sinfo( + A: R.Tensor([16]), + _prim: R.Prim(value="extent"), + _shape: R.Shape( + ["extent"], + ), + ): + extent = T.int64() + output = R.strided_slice(A, [0], [0], [extent]) + return output + + @R.function(private=True) + def expected( + A: R.Tensor([16]), + _prim: R.Prim(value="extent"), + _shape: R.Shape(["extent"]), + ): + extent = T.int64() + output: R.Tensor([T.min(extent, 16)]) = R.strided_slice(A, [0], [0], [extent]) + return output + + tvm.ir.assert_structural_equal(inferred_sinfo, expected) + + assert isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar) + + +@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing") +def test_known_positive_expressions(): + """Expressions may be known as non-negative + + The variable `N` is not defined as a shape variable, and may be + either positive or negative. However, the expression `N+16` is + used as the shape of a tensor, and is therefore known not to be + negative. Later use of the expression `N+16 < 0` may therefore be + simplified. + + This test is currently marked as failing. When using + `relax::BlockBuilder::VisitWithNewScope` is provided with + parameters, it can mark shape expressions as non-negative, in + addition to individual variables. However, this is not currently + used for TVMScript parsing. + + """ + + @R.function(private=True) + def inferred_sinfo( + A: R.Tensor(["N + 16"]), + _: R.Prim(value="N"), + ): + N = T.int64() + output = R.strided_slice(A, [0], [0], [N + 16]) + return output + + @R.function(private=True) + def expected( + A: R.Tensor(["N + 16"]), + _: R.Prim(value="N"), + ): + N = T.int64() + output: R.Tensor([N + 16]) = R.strided_slice(A, [0], [0], [N + 16]) + return output + + tvm.ir.assert_structural_equal(inferred_sinfo, expected) + + assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar) + + if __name__ == "__main__": tvm.testing.main()