From 933f14532966d40912fa28b570fa4fc3b6b33b3f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 14 Feb 2024 16:03:07 +0000 Subject: [PATCH 1/2] [Relax] Additional unit tests for RemoveUnusedParameters Verifying behavior for subroutines that receive `R.Prim` or `R.Shape` parameters, if the symbolic variables defined by those parameters are already defined by another parameter. --- ...test_transform_remove_unused_parameters.py | 109 +++++++++++++++++- 1 file changed, 106 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py index 82c8d0bd1d29..72a32de89332 100644 --- a/tests/python/relax/test_transform_remove_unused_parameters.py +++ b/tests/python/relax/test_transform_remove_unused_parameters.py @@ -24,7 +24,14 @@ class BaseCompare(tvm.testing.CompareBeforeAfter): transform = tvm.relax.transform.RemoveUnusedParameters() -class TestSimple(BaseCompare): +class TestRemoveUnusedRelaxParameter(BaseCompare): + """A relax parameter may be removed + + This is only allowed for internal function calls, where all + callsites can be updated. For externally-exposed functions, the + signature may not be modified. + """ + @I.ir_module class Before: @R.function @@ -46,7 +53,15 @@ def func(A: R.Tensor) -> R.Tensor: return A -class TestSymbolicVariables(BaseCompare): +class TestReplaceSymbolicVariables(BaseCompare): + """If a parameter is only required for its symbolic variables, provide them directly + + The relax parameter `A` isn't used by the subroutine. However, + its shape defines the symbolic variables `m` and `n`. When + removing the `R.Tensor` argument, we may need to provide + additional parameters to define the symbolic variables. + """ + @I.ir_module class Before: @R.function @@ -78,7 +93,12 @@ def func( class TestNoExtraSymbolicVariables(BaseCompare): - """Don't add symbolic variables if they can be inferred.""" + """Don't add symbolic variables if they can be inferred. + + Even though some cases require adding new parameters to provide + symbolic variables, not every symbolic variable requires a + distinct parameter. + """ @I.ir_module class Before: @@ -97,5 +117,88 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): Expected = Before +class TestRemoveExtraPrimVariables(BaseCompare): + """Remove parameters that only serve to define existing symbolic variables + + If a `R.Prim` parameter provies a definition of a symbolic + variable, but that symbolic variable can be determined from a + different parameter, then the `R.Prim` parameter can be removed. + """ + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + return Before.func(A, R.prim_value(m), R.prim_value(n)) + + @R.function(private=True) + def func( + A: R.Tensor(["m", "n"], "float32"), _m: R.Prim(value="m"), _n: R.Prim(value="n") + ) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + zeros = R.zeros(R.shape([m, n]), dtype="float32") + out = R.add(A, zeros) + return out + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + return Expected.func(A) + + @R.function(private=True) + def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + zeros = R.zeros(R.shape([m, n]), dtype="float32") + out = R.add(A, zeros) + return out + + +class TestRemoveExtraShapeVariables(BaseCompare): + """Remove parameters that only serve to define existing symbolic variables + + If a `R.Shape` parameter provies a definition of a symbolic + variable, but that symbolic variable can be determined from a + different parameter, then the `R.Shape` parameter can be removed. + """ + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + return Before.func(A, R.shape([m, n])) + + @R.function(private=True) + def func( + A: R.Tensor(["m", "n"], "float32"), + _: R.Shape(["m", "n"]), + ) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + zeros = R.zeros(R.shape([m, n]), dtype="float32") + out = R.add(A, zeros) + return out + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + return Expected.func(A) + + @R.function(private=True) + def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + zeros = R.zeros(R.shape([m, n]), dtype="float32") + out = R.add(A, zeros) + return out + + if __name__ == "__main__": tvm.testing.main() From d0fb692290fbee2e24d08fced6bfbbaacbd8ebf0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Feb 2024 02:33:28 +0000 Subject: [PATCH 2/2] Typo fix --- tests/python/relax/test_transform_remove_unused_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py index 72a32de89332..ea905eb88283 100644 --- a/tests/python/relax/test_transform_remove_unused_parameters.py +++ b/tests/python/relax/test_transform_remove_unused_parameters.py @@ -161,7 +161,7 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): class TestRemoveExtraShapeVariables(BaseCompare): """Remove parameters that only serve to define existing symbolic variables - If a `R.Shape` parameter provies a definition of a symbolic + If a `R.Shape` parameter provides a definition of a symbolic variable, but that symbolic variable can be determined from a different parameter, then the `R.Shape` parameter can be removed. """