From 6c880b69f28121dd3c0fb7c0304ff4ea693a4eb7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 00:02:20 -0500 Subject: [PATCH 01/58] Check well-formedness in the parser --- python/tvm/script/parser/core/entry.py | 31 ++- tests/python/relax/test_analysis.py | 14 +- .../test_analysis_estimate_memory_usage.py | 2 +- .../test_transform_normalize_global_var.py | 53 ++-- ...ansform_operator_specific_normalization.py | 255 ++++++++++++------ .../relax/test_vm_alloc_storage_with_scope.py | 2 +- tests/python/relax/test_vm_codegen_only.py | 19 +- tests/python/relax/test_vm_codegen_tir.py | 8 +- tests/python/relax/test_vm_cuda_graph.py | 4 +- 9 files changed, 249 insertions(+), 139 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9a7430643cd8..3b4333029388 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -18,6 +18,8 @@ import inspect from typing import Any, Dict, Union +from ....relax.analysis import well_formed +from ....ir.module import IRModule from ...ir_builder import IRBuilder from . import doc from .diagnostics import Source @@ -43,6 +45,22 @@ def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> A return source, closure_vars +def find_decorator_annotation(node: doc.Module, annotation: str, default: bool = True) -> bool: + """ + Check the value of given annotation (argument name) in the function decorator. + Returns the value of the annotation if present, otherwise giving the default value. + """ + # Note: A Module body is always a list containing a single ClassDef. + # The ClassDef has a decorator list + for dec in node.body[0].decorator_list: + if not isinstance(dec, doc.Call) or dec.func.attr != "ir_module": + continue + for keyword in dec.keywords: + if keyword.arg == annotation: + return keyword.value.value + return default + + def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Register a method for a operand type, AST operator node and operand index. @@ -77,4 +95,15 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) parser.parse(extra_vars=extra_vars) except ParserError as err: parser.report_error(err.node, err.args[0]) - return builder.get() + ret = builder.get() + # well-formedness check will ignore any non-Relax functions + if isinstance(ret, IRModule): + # note: use the walrus operator (:=) once the project Python version + # supports it, would be more concise + source_ast = source.as_ast() + if find_decorator_annotation(source_ast, "check_well_formed") and not well_formed(ret): + parser.report_error( + source_ast, + err="Program containing Relax functions is not well-formed", + ) + return ret diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index abbe380d4839..178869442249 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -97,7 +97,7 @@ def test_binding_block_remove_all_unused(): @tvm.script.ir_module class IdentityUnused: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x @@ -113,7 +113,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: @tvm.script.ir_module class GroundTruth: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x @@ -202,7 +202,7 @@ def before(x: R.Tensor((32, 32), "int32")): def test_binding_block_remove_all_unused_func_without_dataflow(): @tvm.script.ir_module class IdentityUnused: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x @@ -217,7 +217,7 @@ def internal_unused_func(A: R.Tensor((32, 32), "float32")) -> R.Tensor: @tvm.script.ir_module class GroundTruth: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) @@ -229,7 +229,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def test_binding_block_fake_unused_remove_all_unused(): @tvm.script.ir_module class IdentityUnused: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x @@ -241,7 +241,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: @tvm.script.ir_module class GroundTruth: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x @@ -256,7 +256,7 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def test_edge_binding_block_fake_unused_remove_all_unused(): @tvm.script.ir_module class IdentityUnused: - @R.function + @R.function(pure=False) def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) return x diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py index 31419b544d23..ab036aab6141 100644 --- a/tests/python/relax/test_analysis_estimate_memory_usage.py +++ b/tests/python/relax/test_analysis_estimate_memory_usage.py @@ -66,7 +66,7 @@ def pad( ): T.evaluate(0) - @R.function + @R.function(pure=False) def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): cls = Module storage: R.Object = R.memory.alloc_storage( diff --git a/tests/python/relax/test_transform_normalize_global_var.py b/tests/python/relax/test_transform_normalize_global_var.py index 0a26ffc8e6f6..53ce6a5a8b19 100644 --- a/tests/python/relax/test_transform_normalize_global_var.py +++ b/tests/python/relax/test_transform_normalize_global_var.py @@ -15,31 +15,31 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm import tvm.testing from tvm import relax -from tvm import tir from tvm.ir.base import assert_structural_equal import tvm.script from tvm.script import tir as T, relax as R, ir as I -@pytest.mark.skip_well_formed_check_before_transform def test_normalize_relax_function(): - @I.ir_module - class Before: - @R.function(private=True) - def f(): - return R.const(1, "int32") - - @R.function - def f1(): - R.func_attr({"global_symbol": "f"}) - cls = Before - gv: R.Tensor((), dtype="int32") = cls.f() - return gv + # parser will check well-formedness so we can't use it to construct this example + bb = relax.BlockBuilder() + f = relax.Function( + [], + relax.SeqExpr([], relax.Constant(tvm.nd.array(np.int32(1)), R.Tensor((), "int32"))), + R.Tensor((), "int32"), + ) + f_gv = bb.add_func(bb.normalize(f).without_attr("global_symbol"), "f") + with bb.function("f1", []): + gv = bb.emit(f_gv(), "gv") + bb.emit_func_output(gv) + Before = bb.get() + Before.update_func(Before.get_global_var("f1"), Before["f1"].with_attr("global_symbol", "f")) @I.ir_module class Expected: @@ -62,18 +62,19 @@ def f1(): @pytest.mark.skip_well_formed_check_before_transform def test_normalize_tir_function(): - @I.ir_module - class Before: - @T.prim_func(private=True) - def f(x: T.Buffer((1,), "int32")): - x[0] = T.int32(0) - - @R.function - def f1(): - R.func_attr({"global_symbol": "f"}) - cls = Before - gv: R.Tensor((), dtype="int32") = R.call_tir(cls.f, (), R.Tensor((1,), dtype="int32")) - return gv + # parser will check well-formedness so we can't use it to construct this example + bb = relax.BlockBuilder() + + @T.prim_func(private=True) + def f(x: T.Buffer((1,), "int32")): + x[0] = T.int32(0) + + f_gv = bb.add_func(f, "f") + with bb.function("f1", []): + gv = bb.emit(R.call_tir(f_gv, (), R.Tensor((1,), dtype="int32"))) + bb.emit_func_output(gv) + Before = bb.get() + Before.update_func(Before.get_global_var("f1"), Before["f1"].with_attr("global_symbol", "f")) @I.ir_module class Expected: diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index 4ee17166452f..8209db40f458 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -95,11 +95,14 @@ def func(A: R.Tensor): def test_normalization_applied_during_cpp_mutator(custom_op): """FNormalize is applied by relax::ExprMutator subclasses""" - @I.ir_module - class Before: - @R.function - def main(A: R.Tensor): - return relax.Call(custom_op, [A]) + # can't use parser because it will check well-formedness proactively + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) + + bb = relax.BlockBuilder() + bb.add_func(main, "main") + Before = bb.get() @I.ir_module class Expected: @@ -155,11 +158,13 @@ def test_un_normalized_call_node_is_ill_formed(custom_op, define_normalization): FNormalize has no corresponding check applied. """ - @I.ir_module - class Module: - @R.function - def main(A: R.Tensor): - return relax.Call(custom_op, [A]) + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) + + bb = relax.BlockBuilder() + bb.add_func(main, "main") + Module = bb.get() if define_normalization: assert not relax.analysis.well_formed(Module) @@ -171,22 +176,41 @@ def main(A: R.Tensor): def test_normalize_to_inline_tuple_for_call_tir(custom_op): """FNormalize in-lines the argument tuple for R.call_tir""" - @I.ir_module - class Before: - @R.function - def main(A: R.Tensor([16], "float32")): - cls = Before - args = (A,) - return relax.Call( - tvm.ir.Op.get("relax.call_tir"), - [cls.multiply_by_two, args], - sinfo_args=[A.struct_info], - ) - - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + bb = relax.BlockBuilder() + primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") + a = relax.Var("a", R.Tensor([16], "float32")) + tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) + ret = relax.Var("ret", R.Tensor([16], "float32")) + # can't even use the block builder to assemble because it will normalize! + main_func = relax.Function( + [a], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(tup, relax.Tuple([a])), + relax.VarBinding( + ret, + relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [primfunc_gv, tup], + sinfo_args=[a.struct_info], + ), + ), + ] + ) + ], + ret, + ), + R.Tensor([16], "float32"), + ).with_attr("global_symbol", "main") + bb.add_func(main_func, "main") + Before = bb.get() @I.ir_module class Expected: @@ -219,21 +243,39 @@ def test_normalize_argument_to_inline_tuple_for_call_tir(custom_op): argument tuple is provided as a relax function argument. """ - @I.ir_module - class Before: - @R.function - def main(args: R.Tuple([R.Tensor([16], "float32")])): - cls = Before - return relax.Call( - tvm.ir.Op.get("relax.call_tir"), - [cls.multiply_by_two, args], - sinfo_args=[args[0].struct_info], - ) - - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + bb = relax.BlockBuilder() + primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") + + args = relax.Var("args", R.Tuple(R.Tensor([16], "float32"))) + ret = relax.Var("ret", R.Tensor([16], "float32")) + main_func = relax.Function( + [args], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + ret, + relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [primfunc_gv, args], + sinfo_args=[R.Tensor([16], "float32")], + ), + ) + ] + ), + ], + ret, + ), + args[0].struct_info, + ).with_attr("global_symbol", "main") + bb.add_func(main_func, "main") + Before = bb.get() @I.ir_module class Expected: @@ -261,9 +303,9 @@ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): def test_normalize_to_inline_tuple_for_call_tir_inplace(custom_op): """FNormalize in-lines the argument tuple for R.call_tir_inplace""" - # The CallTIRInplaceAttrs cannot be constructed from the Python - # API. Therefore, declaring the Expected output first, so that - # the attributes can be used for the non-normalized Before. + # The CallTIRInplaceAttrs is difficult to construct in the Python + # API, so it is more convenient to declare the expected one first + # and reuse its attributes @I.ir_module class Expected: @R.function @@ -284,23 +326,41 @@ def multiply_by_two(A: T.Buffer(16, "float32")): inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs - @I.ir_module - class Before: - @R.function - def main(A: R.Tensor([16], "float32")): - cls = Before - args = (A,) - return relax.Call( - tvm.ir.Op.get("relax.call_tir_inplace"), - [cls.multiply_by_two, args], - attrs=inplace_attrs, - sinfo_args=[A.struct_info], - ) - - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32")): - for i in range(16): - A[i] = A[i] * 2.0 + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 + + bb = relax.BlockBuilder() + primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") + a = relax.Var("a", R.Tensor([16], "float32")) + tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) + ret = relax.Var("ret", R.Tensor([16], "float32")) + main_func = relax.Function( + [a], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(tup, relax.Tuple([a])), + relax.VarBinding( + ret, + relax.Call( + tvm.ir.Op.get("relax.call_tir_inplace"), + [primfunc_gv, tup], + attrs=inplace_attrs, + sinfo_args=[a.struct_info], + ), + ), + ] + ) + ], + ret, + ), + R.Tensor([16], "float32"), + ).with_attr("global_symbol", "main") + bb.add_func(main_func, "main") + Before = bb.get() After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) @@ -312,9 +372,9 @@ def multiply_by_two(A: T.Buffer(16, "float32")): def test_normalize_to_inline_tuple_for_call_tir_with_grad(custom_op): """FNormalize in-lines the argument tuple for R.call_tir_with_grad""" - # The CallTIRWithGradAttrs cannot be constructed from the Python - # API. Therefore, declaring the Expected output first, so that - # the attributes can be used for the non-normalized Before. + # The CallTIRWithGradAttrs is difficult to construct in the Python + # API, so it is more convenient to declare the expected one first + # and reuse its attributes @I.ir_module class Expected: @R.function @@ -342,30 +402,49 @@ def f_grad( with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs - @I.ir_module - class Before: - @R.function - def main(A: R.Tensor([16], "float32")): - cls = Before - args = (A,) - return relax.Call( - tvm.ir.Op.get("relax.call_tir_with_grad"), - [cls.multiply_by_two, args], - attrs=with_grad_attrs, - sinfo_args=[A.struct_info], - ) - - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 - - @T.prim_func(private=True) - def f_grad( - A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") - ): - for i in range(16): - Grad[i] = 2.0 + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @T.prim_func(private=True) + def f_grad( + A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") + ): + for i in range(16): + Grad[i] = 2.0 + + bb = relax.BlockBuilder() + multiply_gv = bb.add_func(multiply_by_two, "multiply_by_two") + bb.add_func(f_grad, "f_grad") + a = relax.Var("a", R.Tensor([16], "float32")) + tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) + ret = relax.Var("ret", R.Tensor([16], "float32")) + main_func = relax.Function( + [a], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(tup, relax.Tuple([a])), + relax.VarBinding( + ret, + relax.Call( + tvm.ir.Op.get("relax.call_tir_with_grad"), + [multiply_gv, tup], + attrs=with_grad_attrs, + sinfo_args=[a.struct_info], + ), + ), + ] + ) + ], + ret, + ), + R.Tensor([16], "float32"), + ).with_attr("global_symbol", "main") + bb.add_func(main_func, "main") + Before = bb.get() After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) diff --git a/tests/python/relax/test_vm_alloc_storage_with_scope.py b/tests/python/relax/test_vm_alloc_storage_with_scope.py index ca1802b1f527..17ae449a5d6a 100644 --- a/tests/python/relax/test_vm_alloc_storage_with_scope.py +++ b/tests/python/relax/test_vm_alloc_storage_with_scope.py @@ -44,7 +44,7 @@ def add( T.writes(output[v_ax0, v_ax1]) output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] - @R.function + @R.function(pure=False) def main(x: R.Tensor((2, 2), dtype="float32")): cls = Module storage = R.vm.alloc_storage( diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 0d461f0713c2..a93eb8350ce2 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -42,7 +42,7 @@ def codegen(mod, target, exec_mode="bytecode"): def test_vm_copy(exec_mode): @tvm.script.ir_module class TestVMMove: - @R.function + @R.function(pure=False) def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) @@ -61,7 +61,7 @@ def foo(x: R.Tensor((3, 4), "float32")): def test_vm_to_device(exec_mode): @tvm.script.ir_module class TestVMToDevice: - @R.function + @R.function(pure=False) def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) # Copy x to the first cpu: device_type=1 and device_id=0. @@ -110,7 +110,7 @@ def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float3 def test_vm_exec_serialize_export_library(exec_mode): @tvm.script.ir_module class TestVMMove: - @R.function + @R.function(pure=False) def foo(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "foo"}) z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) @@ -133,7 +133,7 @@ def foo(x: R.Tensor((3, 4), "float32")): def test_if_cond(exec_mode): @tvm.script.ir_module class TestVMCompileIf: - @R.function + @R.function(pure=False) def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: @@ -183,7 +183,7 @@ def main(x: R.Tensor(ndim=2, dtype="float32")): def test_vm_const_as_call_arg(exec_mode): @tvm.script.ir_module class TestVMConstAsCallArg: - @R.function + @R.function(pure=False) def main(x: R.Tensor(ndim=2, dtype="float32")): R.func_attr({"global_symbol": "main"}) a = R.call_packed( @@ -219,7 +219,7 @@ def test_shape_check_builtin(exec_mode): @tvm.script.ir_module class TestVMShapeCheck: - @R.function + @R.function(pure=False) def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): R.func_attr({"global_symbol": "main"}) n = T.int64() @@ -338,7 +338,7 @@ def main(): def test_vm_builtin_reshape(exec_mode): @tvm.script.ir_module class TestVMBuiltinReshape: - @R.function + @R.function(pure=False) def main(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "main"}) y = R.call_packed( @@ -383,7 +383,8 @@ def full1(T_full: T.Buffer((T.int64(4),), "float32")): T.writes(T_full[v_ax0]) T_full[v_ax0] = T.float32(1) - @R.function + # PrimFuncs called directly are treated as impure + @R.function(pure=False) def main() -> R.Tensor((4,), dtype="float32"): R.func_attr({"global_symbol": "main"}) cls = TestKillObject @@ -425,7 +426,7 @@ def main() -> R.Tensor((4,), dtype="float32"): def test_preserve_trivial_bindings(exec_mode): @I.ir_module class mod: - @R.function + @R.function(pure=False) def main(): callback = R.ExternFunc("test.vm.check_if_defined") diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index d82715a3946f..21e192955b93 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -34,7 +34,7 @@ def get_tir_mod(mod): def test_add(): @tvm.script.ir_module class Before: - @R.function + @R.function(pure=False) def foo(x: R.Tensor): R.func_attr({"global_symbol": "foo"}) z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) @@ -71,7 +71,7 @@ def shape_func(H: T.Buffer(T.int64(4), "int64")): # generated compute function H[T.int64(0)] = H[T.int64(0)] + T.int64(1) - @R.function + @R.function(pure=False) def foo(x: R.Tensor): R.func_attr({"global_symbol": "foo"}) _ = Before.shape_func(x) @@ -104,7 +104,7 @@ def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): def test_if_cond(): @tvm.script.ir_module class Before: - @R.function + @R.function(pure=False) def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: R.func_attr({"global_symbol": "ife"}) if cond: @@ -191,7 +191,7 @@ def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): def test_const_call(): @tvm.script.ir_module class Before: - @R.function + @R.function(pure=False) def main(x: R.Tensor): R.func_attr({"global_symbol": "main"}) y = R.const([1, 2]) diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 8406b9df15d3..6a20b6b1f892 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -27,7 +27,7 @@ @I.ir_module class Module: - @R.function + @R.function(pure=False) def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"): cls = Module R.func_attr({"global_symbol": "main"}) @@ -63,7 +63,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): gv: R.Tuple(R.Object, R.Object) = (storage, storage1) return gv - @R.function + @R.function(pure=False) def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.Object, storage: R.Object) -> R.Tuple(R.Tensor((16, 16), dtype="float32")): cls = Module R.func_attr({"global_symbol": "cuda_graph_capture"}) From dbf64d0225a8067f7792a9e4ca1e80a3dba0d222 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 16:28:07 -0500 Subject: [PATCH 02/58] Correct packed funcs in NN frontend --- python/tvm/relax/frontend/nn/modules.py | 42 ++++++++----------- .../python/relax/test_frontend_nn_modules.py | 37 +++++++++------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 1579c5b512c5..477e0fe882f9 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -599,15 +599,12 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg init_shape = rx.ShapeExpr([self.init_seq_len] + self.unit_shape) return [ bb.emit( - rx.Call( - ir.Op.get("relax.call_pure_packed"), - args=[ - rx.extern("vm.builtin.attention_kv_cache_create"), - rx.op.zeros(init_shape, self.dtype), - init_shape, - rx.PrimValue(0), - ], - sinfo_args=[rx.ObjectStructInfo()], + rx.op.call_pure_packed( + "vm.builtin.attention_kv_cache_create", + rx.op.zeros(init_shape, self.dtype), + init_shape, + rx.PrimValue(0), + sinfo_args=rx.ObjectStructInfo(), ), name_hint=name_hint, ) @@ -675,14 +672,11 @@ def view(self, seq_len: tir.Var) -> Tensor: shape = rx.ShapeExpr([seq_len] + self.unit_shape) return Tensor( _expr=rx.BlockBuilder.current().emit( - rx.Call( - ir.Op.get("relax.call_pure_packed"), - args=[ - rx.extern("vm.builtin.attention_kv_cache_view"), - self.cache, - shape, - ], - sinfo_args=[rx.TensorStructInfo(shape, self.dtype)], + rx.op.call_pure_packed( + "vm.builtin.attention_kv_cache_view", + self.cache, + shape, + sinfo_args=rx.TensorStructInfo(shape, self.dtype), ) ) ) @@ -702,14 +696,12 @@ def append(self, new_element: Tensor) -> None: f'but got "{new_element.dtype}"' ) self.cache = rx.BlockBuilder.current().emit( - rx.Call( - ir.Op.get("relax.call_pure_packed"), - args=[ - rx.extern("vm.builtin.attention_kv_cache_append"), - self.cache, - new_element._expr, - ], - sinfo_args=[rx.ObjectStructInfo()], + rx.op.call_inplace_packed( + "vm.builtin.attention_kv_cache_append", + self.cache, + new_element._expr, + inplace_indices=[0], + sinfo_args=rx.ObjectStructInfo(), ) ) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 9b357114d351..5c8c9edd8369 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -297,9 +297,10 @@ def forward( lv1: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.nn.conv2d(x, weight) lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1])) conv2d: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.add(lv1, lv2) - gv1: R.Tuple( - R.Tensor((n, 32, h - 2, w - 2), dtype="float32"), R.Tuple(R.Object) - ) = conv2d, (_io,) + gv1: R.Tuple(R.Tensor((n, 32, h - 2, w - 2), dtype="float32"), R.Tuple(R.Object)) = ( + conv2d, + (_io,), + ) R.output(gv1) return gv1 @@ -463,9 +464,10 @@ def forward( get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype( lv11, dtype="float32" ) - gv1: R.Tuple( - R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object) - ) = get_timestep_embedding, (_io,) + gv1: R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)) = ( + get_timestep_embedding, + (_io,), + ) R.output(gv1) return gv1 @@ -489,7 +491,7 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object): lv, R.shape([8, 2, 4]), R.prim_value(0), - sinfo_args=(R.Object,), + sinfo_args=[R.Object()], ) lv1 = _io, cache gv = lv1 @@ -502,8 +504,12 @@ def forward( ) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)): R.func_attr({"num_input": 3}) with R.dataflow(): - lv2: R.Object = R.call_pure_packed( - "vm.builtin.attention_kv_cache_append", cache, x, sinfo_args=(R.Object,) + lv2: R.Object = R.call_inplace_packed( + "vm.builtin.attention_kv_cache_append", + cache, + x, + inplace_indices=[0], + sinfo_args=[R.Object()], ) lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", @@ -511,9 +517,10 @@ def forward( R.shape([4, 2, 4]), sinfo_args=(R.Tensor((4, 2, 4), dtype="float32"),), ) - gv1: R.Tuple( - R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object) - ) = lv3, (_io, lv2) + gv1: R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)) = ( + lv3, + (_io, lv2), + ) R.output(gv1) return gv1 @@ -585,9 +592,9 @@ def forward( reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape( matmul2, R.shape([2, 77, 10, 64]) ) - scaled_dot_product_attention: R.Tensor( - (2, 4096, 10, 64), dtype="float32" - ) = R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None) + scaled_dot_product_attention: R.Tensor((2, 4096, 10, 64), dtype="float32") = ( + R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None) + ) reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape( scaled_dot_product_attention, R.shape([2, 4096, 640]) ) From 471148c0dfe444ebff8dbe8da23051d2885f5500 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 17:03:51 -0500 Subject: [PATCH 03/58] Support the check_well_formed optional argument to I.ir_module --- python/tvm/script/parser/core/entry.py | 39 +-- python/tvm/script/parser/ir/entry.py | 31 ++- ...test_distributed_transform_lower_distir.py | 4 +- .../test_transform_normalize_global_var.py | 53 ++-- ...ansform_operator_specific_normalization.py | 243 ++++++------------ 5 files changed, 146 insertions(+), 224 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 3b4333029388..a5d9e43676ec 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -45,23 +45,11 @@ def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> A return source, closure_vars -def find_decorator_annotation(node: doc.Module, annotation: str, default: bool = True) -> bool: - """ - Check the value of given annotation (argument name) in the function decorator. - Returns the value of the annotation if present, otherwise giving the default value. - """ - # Note: A Module body is always a list containing a single ClassDef. - # The ClassDef has a decorator list - for dec in node.body[0].decorator_list: - if not isinstance(dec, doc.Call) or dec.func.attr != "ir_module": - continue - for keyword in dec.keywords: - if keyword.arg == annotation: - return keyword.value.value - return default - - -def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: +def parse( + program: Union[doc.AST, Any, str], + extra_vars: Dict[str, Any] = None, + check_well_formed: bool = True, +) -> Any: """Register a method for a operand type, AST operator node and operand index. Parameters @@ -72,6 +60,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) extra_vars : Dict[str, Any] The extra variable table for parsing. + check_well_formed : bool + Whether to check well-formedness after parsing. + Returns ------- func : Any @@ -97,13 +88,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) parser.report_error(err.node, err.args[0]) ret = builder.get() # well-formedness check will ignore any non-Relax functions - if isinstance(ret, IRModule): - # note: use the walrus operator (:=) once the project Python version - # supports it, would be more concise - source_ast = source.as_ast() - if find_decorator_annotation(source_ast, "check_well_formed") and not well_formed(ret): - parser.report_error( - source_ast, - err="Program containing Relax functions is not well-formed", - ) + if check_well_formed and isinstance(ret, IRModule) and not well_formed(ret): + parser.report_error( + source.as_ast(), + err="Program containing Relax functions is not well-formed", + ) return ret diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py index 5878a1ce55cc..0a75f846a5f1 100644 --- a/python/tvm/script/parser/ir/entry.py +++ b/python/tvm/script/parser/ir/entry.py @@ -17,14 +17,17 @@ """The entry point of TVM parser for ir module.""" import inspect -from typing import Type +from typing import Optional, Type from tvm.ir import IRModule from .._core import parse, utils -def ir_module(mod: Type) -> IRModule: +# this formulation allows us to support having @I.ir_module +# appear as a decorator by itself or to have optional arguments +# like @I.ir_module(check_well_formed=False) +def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRModule: """The parsing method for ir module, by using `@ir_module` as decorator. Parameters @@ -32,17 +35,29 @@ def ir_module(mod: Type) -> IRModule: mod : Type The class to be parsed as ir module. + check_well_formed : bool + Whether to check well-formedness during parsing. + Returns ------- ir_module : IRModule The parsed ir module. """ - if not inspect.isclass(mod): - raise TypeError(f"Expect a class, but got: {mod}") - - m = parse(mod, utils.inspect_class_capture(mod)) - setattr(m, "__name__", mod.__name__) - return m + def decorator_wrapper(mod): + if not inspect.isclass(mod): + raise TypeError(f"Expect a class, but got: {mod}") + m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed) + setattr(m, "__name__", mod.__name__) + return m + + if mod is not None: + # if there are no optional args given, this will directly invoke the wrapper + return decorator_wrapper(mod) + else: + # if there is a optional arg given, it returns the wrapper function + # as a new decorator and applies it + setattr(decorator_wrapper, "dispatch_token", "ir") + return decorator_wrapper setattr(ir_module, "dispatch_token", "ir") diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py index 3df65b3ea6ff..54f7fa3c613a 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py @@ -136,7 +136,7 @@ def foo( ) return lv3 - @I.ir_module + @I.ir_module(check_well_formed=False) class LoweredMLP: I.module_attrs({"device_num": 10}) I.module_global_infos( @@ -331,7 +331,7 @@ def foo( ) return lv4 - @I.ir_module + @I.ir_module(check_well_formed=False) class LoweredMLPWithTuple: I.module_attrs({"device_num": 10}) I.module_global_infos( diff --git a/tests/python/relax/test_transform_normalize_global_var.py b/tests/python/relax/test_transform_normalize_global_var.py index 53ce6a5a8b19..0dddab02edcf 100644 --- a/tests/python/relax/test_transform_normalize_global_var.py +++ b/tests/python/relax/test_transform_normalize_global_var.py @@ -15,31 +15,31 @@ # specific language governing permissions and limitations # under the License. import pytest -import numpy as np import tvm import tvm.testing from tvm import relax +from tvm import tir from tvm.ir.base import assert_structural_equal import tvm.script from tvm.script import tir as T, relax as R, ir as I +@pytest.mark.skip_well_formed_check_before_transform def test_normalize_relax_function(): - # parser will check well-formedness so we can't use it to construct this example - bb = relax.BlockBuilder() - f = relax.Function( - [], - relax.SeqExpr([], relax.Constant(tvm.nd.array(np.int32(1)), R.Tensor((), "int32"))), - R.Tensor((), "int32"), - ) - f_gv = bb.add_func(bb.normalize(f).without_attr("global_symbol"), "f") - with bb.function("f1", []): - gv = bb.emit(f_gv(), "gv") - bb.emit_func_output(gv) - Before = bb.get() - Before.update_func(Before.get_global_var("f1"), Before["f1"].with_attr("global_symbol", "f")) + @I.ir_module(check_well_formed=False) + class Before: + @R.function(private=True) + def f(): + return R.const(1, "int32") + + @R.function + def f1(): + R.func_attr({"global_symbol": "f"}) + cls = Before + gv: R.Tensor((), dtype="int32") = cls.f() + return gv @I.ir_module class Expected: @@ -62,19 +62,18 @@ def f1(): @pytest.mark.skip_well_formed_check_before_transform def test_normalize_tir_function(): - # parser will check well-formedness so we can't use it to construct this example - bb = relax.BlockBuilder() - - @T.prim_func(private=True) - def f(x: T.Buffer((1,), "int32")): - x[0] = T.int32(0) - - f_gv = bb.add_func(f, "f") - with bb.function("f1", []): - gv = bb.emit(R.call_tir(f_gv, (), R.Tensor((1,), dtype="int32"))) - bb.emit_func_output(gv) - Before = bb.get() - Before.update_func(Before.get_global_var("f1"), Before["f1"].with_attr("global_symbol", "f")) + @I.ir_module(check_well_formed=False) + class Before: + @T.prim_func(private=True) + def f(x: T.Buffer((1,), "int32")): + x[0] = T.int32(0) + + @R.function + def f1(): + R.func_attr({"global_symbol": "f"}) + cls = Before + gv: R.Tensor((), dtype="int32") = R.call_tir(cls.f, (), R.Tensor((1,), dtype="int32")) + return gv @I.ir_module class Expected: diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index 8209db40f458..cef2afd11a80 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -95,14 +95,11 @@ def func(A: R.Tensor): def test_normalization_applied_during_cpp_mutator(custom_op): """FNormalize is applied by relax::ExprMutator subclasses""" - # can't use parser because it will check well-formedness proactively - @R.function - def main(A: R.Tensor): - return relax.Call(custom_op, [A]) - - bb = relax.BlockBuilder() - bb.add_func(main, "main") - Before = bb.get() + @I.ir_module(check_well_formed=False) + class Before: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) @I.ir_module class Expected: @@ -158,13 +155,11 @@ def test_un_normalized_call_node_is_ill_formed(custom_op, define_normalization): FNormalize has no corresponding check applied. """ - @R.function - def main(A: R.Tensor): - return relax.Call(custom_op, [A]) - - bb = relax.BlockBuilder() - bb.add_func(main, "main") - Module = bb.get() + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) if define_normalization: assert not relax.analysis.well_formed(Module) @@ -176,41 +171,22 @@ def main(A: R.Tensor): def test_normalize_to_inline_tuple_for_call_tir(custom_op): """FNormalize in-lines the argument tuple for R.call_tir""" - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 - - bb = relax.BlockBuilder() - primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") - a = relax.Var("a", R.Tensor([16], "float32")) - tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) - ret = relax.Var("ret", R.Tensor([16], "float32")) - # can't even use the block builder to assemble because it will normalize! - main_func = relax.Function( - [a], - relax.SeqExpr( - [ - relax.BindingBlock( - [ - relax.VarBinding(tup, relax.Tuple([a])), - relax.VarBinding( - ret, - relax.Call( - tvm.ir.Op.get("relax.call_tir"), - [primfunc_gv, tup], - sinfo_args=[a.struct_info], - ), - ), - ] - ) - ], - ret, - ), - R.Tensor([16], "float32"), - ).with_attr("global_symbol", "main") - bb.add_func(main_func, "main") - Before = bb.get() + @I.ir_module(check_well_formed=False) + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, args], + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 @I.ir_module class Expected: @@ -243,39 +219,21 @@ def test_normalize_argument_to_inline_tuple_for_call_tir(custom_op): argument tuple is provided as a relax function argument. """ - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 - - bb = relax.BlockBuilder() - primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") - - args = relax.Var("args", R.Tuple(R.Tensor([16], "float32"))) - ret = relax.Var("ret", R.Tensor([16], "float32")) - main_func = relax.Function( - [args], - relax.SeqExpr( - [ - relax.BindingBlock( - [ - relax.VarBinding( - ret, - relax.Call( - tvm.ir.Op.get("relax.call_tir"), - [primfunc_gv, args], - sinfo_args=[R.Tensor([16], "float32")], - ), - ) - ] - ), - ], - ret, - ), - args[0].struct_info, - ).with_attr("global_symbol", "main") - bb.add_func(main_func, "main") - Before = bb.get() + @I.ir_module(check_well_formed=False) + class Before: + @R.function + def main(args: R.Tuple([R.Tensor([16], "float32")])): + cls = Before + return relax.Call( + tvm.ir.Op.get("relax.call_tir"), + [cls.multiply_by_two, args], + sinfo_args=[args[0].struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 @I.ir_module class Expected: @@ -326,41 +284,23 @@ def multiply_by_two(A: T.Buffer(16, "float32")): inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32")): - for i in range(16): - A[i] = A[i] * 2.0 - - bb = relax.BlockBuilder() - primfunc_gv = bb.add_func(multiply_by_two, "multiply_by_two") - a = relax.Var("a", R.Tensor([16], "float32")) - tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) - ret = relax.Var("ret", R.Tensor([16], "float32")) - main_func = relax.Function( - [a], - relax.SeqExpr( - [ - relax.BindingBlock( - [ - relax.VarBinding(tup, relax.Tuple([a])), - relax.VarBinding( - ret, - relax.Call( - tvm.ir.Op.get("relax.call_tir_inplace"), - [primfunc_gv, tup], - attrs=inplace_attrs, - sinfo_args=[a.struct_info], - ), - ), - ] - ) - ], - ret, - ), - R.Tensor([16], "float32"), - ).with_attr("global_symbol", "main") - bb.add_func(main_func, "main") - Before = bb.get() + @I.ir_module(check_well_formed=False) + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir_inplace"), + [cls.multiply_by_two, args], + attrs=inplace_attrs, + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32")): + for i in range(16): + A[i] = A[i] * 2.0 After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) @@ -402,49 +342,30 @@ def f_grad( with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs - @T.prim_func(private=True) - def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): - for i in range(16): - B[i] = A[i] * 2.0 - - @T.prim_func(private=True) - def f_grad( - A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") - ): - for i in range(16): - Grad[i] = 2.0 - - bb = relax.BlockBuilder() - multiply_gv = bb.add_func(multiply_by_two, "multiply_by_two") - bb.add_func(f_grad, "f_grad") - a = relax.Var("a", R.Tensor([16], "float32")) - tup = relax.Var("tup", R.Tuple(R.Tensor([16], "float32"))) - ret = relax.Var("ret", R.Tensor([16], "float32")) - main_func = relax.Function( - [a], - relax.SeqExpr( - [ - relax.BindingBlock( - [ - relax.VarBinding(tup, relax.Tuple([a])), - relax.VarBinding( - ret, - relax.Call( - tvm.ir.Op.get("relax.call_tir_with_grad"), - [multiply_gv, tup], - attrs=with_grad_attrs, - sinfo_args=[a.struct_info], - ), - ), - ] - ) - ], - ret, - ), - R.Tensor([16], "float32"), - ).with_attr("global_symbol", "main") - bb.add_func(main_func, "main") - Before = bb.get() + @I.ir_module(check_well_formed=False) + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + cls = Before + args = (A,) + return relax.Call( + tvm.ir.Op.get("relax.call_tir_with_grad"), + [cls.multiply_by_two, args], + attrs=with_grad_attrs, + sinfo_args=[A.struct_info], + ) + + @T.prim_func(private=True) + def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * 2.0 + + @T.prim_func(private=True) + def f_grad( + A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32") + ): + for i in range(16): + Grad[i] = 2.0 After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) From ca0ce7a82497c53b9166890e69d4b0a648e09e02 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 17:08:56 -0500 Subject: [PATCH 04/58] Also check well-formedness in TIR --- python/tvm/script/parser/core/entry.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index a5d9e43676ec..1cbd949e3e9c 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -18,7 +18,8 @@ import inspect from typing import Any, Dict, Union -from ....relax.analysis import well_formed +from ....relax.analysis import well_formed as relax_well_formed +from ....tir.analysis import verify_well_formed as tir_well_formed from ....ir.module import IRModule from ...ir_builder import IRBuilder from . import doc @@ -87,8 +88,12 @@ def parse( except ParserError as err: parser.report_error(err.node, err.args[0]) ret = builder.get() - # well-formedness check will ignore any non-Relax functions - if check_well_formed and isinstance(ret, IRModule) and not well_formed(ret): + # check well-formedness in both Relax and TIR + if ( + check_well_formed + and isinstance(ret, IRModule) + and not (relax_well_formed(ret) and tir_well_formed(ret, assert_mode=False)) + ): parser.report_error( source.as_ast(), err="Program containing Relax functions is not well-formed", From 76d79a6db9bd742560d126245fc527450da6aec6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 17:34:01 -0500 Subject: [PATCH 05/58] Enable normalization for individual Relax functions and PrimFuncs --- python/tvm/relax/block_builder.py | 11 ++++++++--- python/tvm/script/parser/core/entry.py | 18 +++++++++--------- python/tvm/script/parser/relax/entry.py | 4 ++-- python/tvm/script/parser/tir/entry.py | 6 ++++-- tests/python/relax/test_analysis.py | 8 ++++---- ...ransform_operator_specific_normalization.py | 9 +++++---- tests/python/relax/test_tvmscript_parser.py | 16 ++++++++-------- tests/python/tir-base/test_tir_specialize.py | 2 +- 8 files changed, 41 insertions(+), 33 deletions(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 330585599d08..37866840bd68 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -35,11 +35,12 @@ class FunctionScope(object): """Auxiliary scope for function""" - def __init__(self, block_builder, name, params, attrs): + def __init__(self, block_builder, name, params, attrs, is_pure): self._bb = block_builder self._name = name self._params = params self._attrs = attrs + self._is_pure = is_pure # Blocks that have been collected within the function self._blocks = [] @@ -208,6 +209,7 @@ def function( name: str, params: Optional[Union[Var, Tuple, List[Var]]] = None, attrs: Optional[Dict[str, Object]] = None, + pure: bool = True, private: bool = False, ) -> FunctionScope: """Annotate a Relax function. @@ -225,6 +227,9 @@ def function( attrs : Dict[str, Object], optional The function attrs + pure : bool, optional + Whether the function is annotated as pure. + private : bool, optional Whether the function is annotated as private. If the function is private, it will not have a global symbol attribute. @@ -254,7 +259,7 @@ def function( if not private: attrs["global_symbol"] = name - return FunctionScope(self, name, params, attrs) + return FunctionScope(self, name, params, attrs, is_pure=pure) def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: """Start a scope for unit-testing purposes. @@ -640,7 +645,7 @@ def emit_func_output( # do not specify ret_struct_info and let constructor deduce # from seqe.struct_info - func = rx.Function(self._func._params, seqe) + func = rx.Function(self._func._params, seqe, is_pure=self._func._is_pure) for key, value in self._func._attrs.items(): func = func.with_attr(key, value) self.end_scope() diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 1cbd949e3e9c..0017730f1289 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -89,13 +89,13 @@ def parse( parser.report_error(err.node, err.args[0]) ret = builder.get() # check well-formedness in both Relax and TIR - if ( - check_well_formed - and isinstance(ret, IRModule) - and not (relax_well_formed(ret) and tir_well_formed(ret, assert_mode=False)) - ): - parser.report_error( - source.as_ast(), - err="Program containing Relax functions is not well-formed", - ) + if check_well_formed: + check_ret = ret + if not isinstance(check_ret, IRModule): + check_ret = IRModule.from_expr(ret) + if not relax_well_formed(check_ret) or not tir_well_formed(check_ret, assert_mode=False): + parser.report_error( + source.as_ast(), + err="Program is not well-formed", + ) return ret diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index a82cbeb16349..a3b391637cb4 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -52,7 +52,7 @@ # appear as a decorator by itself or to have optional arguments # like @R.function(pure=False) def function( - f: Optional[FType] = None, pure: bool = True, private: bool = False + f: Optional[FType] = None, pure: bool = True, private: bool = False, check_well_formed=True ) -> Union[Function, FType]: # pylint: disable=unused-argument # (pure and private aren't used here, but are used later in parsing) @@ -66,7 +66,7 @@ def decorator_wrapper(f): raise TypeError(f"Expect a function, but got: {f}") if utils.is_defined_in_class(orig_stack, f): return f - return parse(f, utils.inspect_function_capture(f)) + return parse(f, utils.inspect_function_capture(f), check_well_formed=check_well_formed) if f is not None: # if there are no optional args given, this will directly invoke the wrapper diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d2fb070aaab1..79eb88dfc102 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -26,7 +26,9 @@ from ..core.parser import Parser, ScriptMacro -def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]: +def prim_func( + func: Optional[Callable] = None, private: bool = False, check_well_formed=True +) -> Union[PrimFunc, Callable]: """The parsing method for tir prim func, by using `@prim_func` as decorator. Parameters @@ -60,7 +62,7 @@ def decorator_wrapper(func): raise TypeError(f"Expect a function, but got: {func}") if utils.is_defined_in_class(outer_stack, func): return func - f = parse(func, utils.inspect_function_capture(func)) + f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed) setattr(f, "__name__", func.__name__) return f diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 178869442249..28ca13ad8991 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -157,7 +157,7 @@ def test_binding_block_keep_impure_without_dataflow(): contain side effects. """ - @R.function(private=True) + @R.function(private=True, pure=False) def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) @@ -185,7 +185,7 @@ def test_binding_block_keep_pure_func_used_only_for_impure(): it was required to evaluate the packed function. """ - @R.function(private=True) + @R.function(private=True, pure=False) def before(x: R.Tensor((32, 32), "int32")): y = x * R.const(2) z = R.call_packed( @@ -335,14 +335,14 @@ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def test_retain_impure_calls_unused_in_binding_block(): """An impure call may have side effects, and must be kept""" - @R.function + @R.function(pure=False) def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32")) return lv0 - @R.function + @R.function(pure=False) def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: lv0 = x unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index cef2afd11a80..beb1ee85946a 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -74,11 +74,12 @@ def test_normalization_suppressed_for_tvmscript(custom_op): """FNormalize isn't applied when parsing TVMScript TVMScript should be able to produce un-normalized Relax IR for - specifying test cases, and to ensure that no changes occur when - performing a round-trip through TVMScript. + specifying test cases if the well-formed check is disabled, + and to ensure that no changes occur when performing a round-trip + through TVMScript. """ - @R.function + @R.function(check_well_formed=False) def func(A: R.Tensor): return relax.Call(custom_op, [A]) @@ -116,7 +117,7 @@ def main(A: R.Tensor): def test_normalization_applied_during_python_mutator(custom_op): """FNormalize is applied by relax.ExprMutator subclasses""" - @R.function(private=True) + @R.function(private=True, check_well_formed=False) def before(A: R.Tensor): return relax.Call(custom_op, [A]) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 48d087c18a20..3f806de28dbd 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -821,14 +821,14 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): def test_call_packed(): - @R.function + @R.function(pure=False) def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) return z x = relax.Var("x", R.Tensor((32, 32), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x)): + with bb.function("foo", (x), pure=False): z = bb.emit( relax.Call( relax.ExternFunc("vm.builtin.copy"), @@ -843,14 +843,14 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: def test_call_packed_without_sinfo_args(): - @R.function + @R.function(pure=False) def foo(x: R.Object) -> R.Object: z = R.call_packed("test", x) return z x = relax.Var("x", R.Object()) bb = relax.BlockBuilder() - with bb.function("foo", (x)): + with bb.function("foo", (x), pure=False): z = bb.emit( relax.Call( relax.ExternFunc("test"), @@ -865,7 +865,7 @@ def foo(x: R.Object) -> R.Object: def test_annotation(): - @R.function + @R.function(pure=False) def foo( x: R.Tensor((32, "m"), "float32"), y: R.Tensor(("m",), "float32"), @@ -1576,7 +1576,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): def test_prim_value(): - @R.function + @R.function(pure=False) def foo(): gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) return gv @@ -1585,7 +1585,7 @@ def foo(): def test_string_imm(): - @R.function + @R.function(pure=False) def foo(): gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) return gv @@ -1594,7 +1594,7 @@ def foo(): def test_datatype_imm(): - @R.function + @R.function(pure=False) def foo(): gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) return gv diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index fd2843f743be..6e4007b08454 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -67,7 +67,7 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@T.prim_func +@T.prim_func(check_well_formed=False) def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: x = T.int32() m = T.int32() From cefdb7e7e440ce29fe2f3149bc0c7621bffc72f2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 23:09:26 -0500 Subject: [PATCH 06/58] Use the error raised by the TIR well-formed checker for the message --- python/tvm/script/parser/core/entry.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 0017730f1289..2f4e8d59e7ef 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -93,9 +93,11 @@ def parse( check_ret = ret if not isinstance(check_ret, IRModule): check_ret = IRModule.from_expr(ret) - if not relax_well_formed(check_ret) or not tir_well_formed(check_ret, assert_mode=False): - parser.report_error( - source.as_ast(), - err="Program is not well-formed", - ) + source_ast = source.as_ast() + if not relax_well_formed(check_ret): + parser.report_error(source_ast, err="Program is not well-formed") + try: + tir_well_formed(check_ret) + except Exception as err: + parser.report_error(source_ast, err=err) return ret From 3177880900f858f7ea00ba4a89e17f615a961b46 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 23:29:24 -0500 Subject: [PATCH 07/58] Fix tvmscript test failures --- .../tvmscript/test_tvmscript_parser_tir.py | 4 +- .../tvmscript/test_tvmscript_roundtrip.py | 840 +++++++++--------- 2 files changed, 429 insertions(+), 415 deletions(-) diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index b2b534064605..074603681f34 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -272,7 +272,7 @@ def test_tir_starred_for_loop(): @T.prim_func(private=True) def starred(a: T.handle, b: T.handle): A = T.match_buffer(a, [*dims, 128], "int32") - B = T.match_buffer(a, dims, "int32") + B = T.match_buffer(b, dims, "int32") for *spatial, reduction in T.grid(*A.shape): with T.block("reduce"): with T.init(): @@ -282,7 +282,7 @@ def starred(a: T.handle, b: T.handle): @T.prim_func(private=True) def non_starred(a: T.handle, b: T.handle): A = T.match_buffer(a, [128, 128, 128], "int32") - B = T.match_buffer(a, [128, 128], "int32") + B = T.match_buffer(b, [128, 128], "int32") for i, j, k in T.grid(128, 128, 128): with T.block("reduce"): with T.init(): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 85526f871bf1..3b2c1aef86fb 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -27,8 +27,9 @@ def opt_gemm_normalize(): - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class Module: + # packedB is treated as undefined @T.prim_func def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict @@ -57,9 +58,9 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: ) for x_c_init in T.serial(0, 32): for y_c_init in T.vectorized(0, 32): - C_global[ - (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) - ] = T.float32(0) + C_global[(x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32))] = ( + T.float32(0) + ) for k_outer in T.serial(0, 256): for x_c in T.serial(0, 32): for k_inner in T.unroll(0, 4): @@ -180,8 +181,9 @@ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None: def opt_gemm_mod_host(): - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class Module: + # packedB is treated as undefined @T.prim_func def mmult( args: T.handle, @@ -478,7 +480,7 @@ def mmult( def opt_conv_tensorcore_normalize(): - @T.prim_func + @T.prim_func(check_well_formed=False) def func(A: T.handle, W: T.handle, Conv: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) @@ -1018,583 +1020,623 @@ def func( for kh in T.serial(0, 3): for ax2 in T.serial(0, 3): with T.launch_thread(tx, 32): - Apad_shared[ - ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) - ] = T.if_then_else( - ( + Apad_shared[((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61440 - ), - ], - T.float16(0), - dtype="float16", + - 61440 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61408 - ), - ], - T.float16(0), - dtype="float16", + - 61408 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61376 - ), - ], - T.float16(0), - dtype="float16", + - 61376 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61344 - ), - ], - T.float16(0), - dtype="float16", + - 61344 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61312 - ), - ], - T.float16(0), - dtype="float16", + - 61312 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61280 - ), - ], - T.float16(0), - dtype="float16", + - 61280 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61248 - ), - ], - T.float16(0), - dtype="float16", + - 61248 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61216 - ), - ], - T.float16(0), - dtype="float16", + - 61216 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61184 - ), - ], - T.float16(0), - dtype="float16", + - 61184 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61152 - ), - ], - T.float16(0), - dtype="float16", + - 61152 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61120 - ), - ], - T.float16(0), - dtype="float16", + - 61120 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61088 - ), - ], - T.float16(0), - dtype="float16", + - 61088 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx - ) - - 61056 - ), - ], - T.float16(0), - dtype="float16", + - 61056 + ), + ], + T.float16(0), + dtype="float16", + ) ) with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) - ] = T.if_then_else( - ( + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416)] = ( + T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ ( ( ( ( ( ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) ) - + (kh * 57344) + + (bz * 4096) ) - + (bz * 4096) + + (ax2 * 4096) ) - + (ax2 * 4096) + + (ic_outer * 512) ) - + (ic_outer * 512) + + tx ) - + tx + - 61024 + ), + ], + T.float16(0), + dtype="float16", + ) + ) + with T.launch_thread(tx, 32): + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448)] = ( + T.if_then_else( + ( + ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - - 61024 + and ((ax2 + T.floormod(bz, 14)) < 15) ), - ], - T.float16(0), - dtype="float16", + A_1[ + ( + ( + ( + ( + ( + ( + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) + ) + + (bz * 4096) + ) + + (ax2 * 4096) + ) + + (ic_outer * 512) + ) + + tx + ) + - 60992 + ), + ], + T.float16(0), + dtype="float16", + ) ) - with T.launch_thread(tx, 32): - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) - ] = T.if_then_else( + T.launch_thread(tx, 32) + Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480)] = ( + T.if_then_else( ( ( ( @@ -1626,52 +1668,12 @@ def func( ) + tx ) - - 60992 + - 60960 ), ], T.float16(0), dtype="float16", ) - T.launch_thread(tx, 32) - Apad_shared[ - (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) - ] = T.if_then_else( - ( - ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) - ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ - ( - ( - ( - ( - ( - ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) - ) - + (bz * 4096) - ) - + (ax2 * 4096) - ) - + (ic_outer * 512) - ) - + tx - ) - - 60960 - ), - ], - T.float16(0), - dtype="float16", ) with T.launch_thread(tx, 32): W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ @@ -2909,7 +2911,8 @@ def constant_folding(a: T.handle) -> None: def simplify_bracket(): - @T.prim_func + # uninitialized variables + @T.prim_func(check_well_formed=False) def simplify_bracket() -> None: a = T.int32() b = T.int32() @@ -3024,7 +3027,8 @@ def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: def multiple_commreducer(): - @T.prim_func + # normal_reduce_temp0 is treated as uninitialized value + @T.prim_func(check_well_formed=False) def multiple_commreducer() -> None: normal_reduce_temp0 = T.Buffer([1], dtype="float32", strides=[1], scope="local") normal_reduce_temp1 = T.Buffer([1], dtype="float32", strides=[1], scope="local") @@ -3044,7 +3048,8 @@ def multiple_commreducer() -> None: def func_div_mod(): - @T.prim_func + # not well-formed: free variables + @T.prim_func(check_well_formed=False) def func_div_mod(): a = T.int32() b = T.int32() @@ -3057,7 +3062,7 @@ def func_div_mod(): def test_div_mod(): func = func_div_mod() - rt_func = tvm.script.from_source(func.script()) + rt_func = tvm.script.from_source(func.script(), check_well_formed=False) tvm.ir.assert_structural_equal(func, rt_func, True) assert isinstance(func.body[0].value, tvm.tir.FloorDiv) @@ -3220,7 +3225,8 @@ def ctpop(A: T.Buffer((16,), "uint8"), B: T.Buffer((16,), "uint8")) -> None: def parse_bufferslice_as_range_bound(): - @T.prim_func + # apparently the use of i in the "outer" block when it is defined outside of a block is wrong + @T.prim_func(check_well_formed=False) def segment_sum( A_ptr: T.handle, B_ptr: T.handle, indptr_ptr: T.handle, n: T.int32, m: T.int32 ) -> None: @@ -3485,7 +3491,8 @@ def func() -> None: def bool_cast(): - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func() -> None: a = T.bool() T.evaluate(T.bool(T.int32(0))) @@ -3608,7 +3615,8 @@ def func(): def let_stmt_value(): - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func(): y = T.int32() with T.LetStmt(y) as x: @@ -3654,7 +3662,8 @@ def main(a: T.handle, b: T.handle): def merge_shape_var_def(): - @T.prim_func + # uninitialized vars + @T.prim_func(check_well_formed=False) def main(A: T.handle, B: T.handle): T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) m, n = T.int32(), T.int32() @@ -3872,8 +3881,8 @@ def undefined_data_ptr_in_decl_buffer(): Allocate/DeclBuffer pair, performing a round-trip through TVMScript should not introduce an Allocate node. """ - - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func(): data_ptr = T.handle("float32") buf = T.decl_buffer(shape=[1], dtype="float32", data=data_ptr) @@ -3883,7 +3892,8 @@ def func(): def undefined_shape_in_decl_buffer(): - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func(): size = T.int32() buf = T.decl_buffer(shape=[size], dtype="float32") @@ -3893,7 +3903,8 @@ def func(): def undefined_stride_in_decl_buffer(): - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func(): stride = T.int32() buf = T.decl_buffer(shape=[1], dtype="float32", strides=[stride]) @@ -3903,7 +3914,8 @@ def func(): def undefined_elem_offset_in_decl_buffer(): - @T.prim_func + # uninitialized var + @T.prim_func(check_well_formed=False) def func(): elem_offset = T.int32() buf = T.decl_buffer(shape=[1], dtype="float32", elem_offset=elem_offset) @@ -4162,7 +4174,9 @@ def func(A: R.Object): def test_roundtrip(ir_generator): original = ir_generator() - after_roundtrip = tvm.script.from_source(original.script(show_meta=True)) + after_roundtrip = tvm.script.from_source( + original.script(show_meta=True), check_well_formed=False + ) tvm.ir.assert_structural_equal(original, after_roundtrip, True) From 4f9f97a7c73c6fd4c29961b0552f6624193e246d Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 23:31:48 -0500 Subject: [PATCH 08/58] Whitespace --- python/tvm/script/parser/ir/entry.py | 1 + tests/python/relax/test_frontend_nn_modules.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py index 0a75f846a5f1..f91c7701a2eb 100644 --- a/python/tvm/script/parser/ir/entry.py +++ b/python/tvm/script/parser/ir/entry.py @@ -43,6 +43,7 @@ def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRM ir_module : IRModule The parsed ir module. """ + def decorator_wrapper(mod): if not inspect.isclass(mod): raise TypeError(f"Expect a class, but got: {mod}") diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 5c8c9edd8369..5ddc10505591 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -592,9 +592,9 @@ def forward( reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape( matmul2, R.shape([2, 77, 10, 64]) ) - scaled_dot_product_attention: R.Tensor((2, 4096, 10, 64), dtype="float32") = ( - R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None) - ) + scaled_dot_product_attention: R.Tensor( + (2, 4096, 10, 64), dtype="float32" + ) = R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None) reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape( scaled_dot_product_attention, R.shape([2, 4096, 640]) ) From 019a85b251b62b8f73b3155940fc210a35c91bb0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Feb 2024 23:34:46 -0500 Subject: [PATCH 09/58] Fix errors in verify_well_formed test --- .../test_tir_analysis_verify_well_formed.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index a1b3bee1b282..629549721e4a 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -43,7 +43,7 @@ def element_wise( def test_fail_use_out_loop_var(): - @T.prim_func + @T.prim_func(check_well_formed=False) def element_wise( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), @@ -60,7 +60,7 @@ def element_wise( def test_error_for_out_of_scope_usage(): """A variable may not be used after its scope ends""" - @T.prim_func + @T.prim_func(check_well_formed=False) def func(): i = T.int32() with T.LetStmt(42, var=i): @@ -76,7 +76,7 @@ def func(): def test_error_for_nested_rebind_usage(): """A variable may not be re-defined within the initial scope""" - @T.prim_func + @T.prim_func(check_well_formed=False) def func(): i = T.int32() with T.LetStmt(42, var=i): @@ -92,7 +92,7 @@ def func(): def test_error_for_repeated_binding(): """A variable may not be re-defined after the scope ends""" - @T.prim_func + @T.prim_func(check_well_formed=False) def func(): i = T.int32() with T.LetStmt(42, var=i): @@ -109,7 +109,7 @@ def test_error_for_cross_function_reuse(): i = tvm.tir.Var("i", "int32") - @I.ir_module + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def func1(): @@ -175,7 +175,7 @@ def test_reuse_of_env_thread_across_functions_is_ill_formed(): threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") - @I.ir_module + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): From ab2f2bd7cf5d257e99d977d610a7401d87424e90 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 15 Feb 2024 15:42:13 -0500 Subject: [PATCH 10/58] Include a more helpful error message --- python/tvm/script/parser/core/entry.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 2f4e8d59e7ef..b63efcfb65e2 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -27,6 +27,13 @@ from .error import ParserError from .parser import Parser +WELL_FORMED_ERROR_MESSAGE = ( + "Program is not well-formed. If this is deliberate, consider " + "setting check_well_formed in the top-level decorator to False " + "(e.g., @I.ir_module(check_well_formed=False) or " + "@R.function(check_well_formed=False))." +) + def _default_globals() -> Dict[str, Any]: import tvm # pylint: disable=import-outside-toplevel @@ -95,9 +102,12 @@ def parse( check_ret = IRModule.from_expr(ret) source_ast = source.as_ast() if not relax_well_formed(check_ret): - parser.report_error(source_ast, err="Program is not well-formed") + parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE) try: tir_well_formed(check_ret) except Exception as err: - parser.report_error(source_ast, err=err) + parser.report_error( + source_ast, + err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}", + ) return ret From a6db5fa61f50e29fb19363fcd2cf73fd0c60264b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 15 Feb 2024 16:14:58 -0500 Subject: [PATCH 11/58] Fix TIR test failures --- tests/python/arith/test_arith_domain_touched.py | 3 ++- tests/python/tir-base/test_tir_renew_defs.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/arith/test_arith_domain_touched.py b/tests/python/arith/test_arith_domain_touched.py index 1553aabd4e4c..46effde3420c 100644 --- a/tests/python/arith/test_arith_domain_touched.py +++ b/tests/python/arith/test_arith_domain_touched.py @@ -71,7 +71,8 @@ def test_domain_touched(): def test_domain_touched_vector(): m = tvm.runtime.convert(128) - @T.prim_func + # n is undefined + @T.prim_func(check_well_formed=False) def func(a: T.handle, b: T.handle): n = T.int32() A = T.match_buffer(a, (n * m,)) diff --git a/tests/python/tir-base/test_tir_renew_defs.py b/tests/python/tir-base/test_tir_renew_defs.py index 22f7b65ca17b..5a11725c64b5 100644 --- a/tests/python/tir-base/test_tir_renew_defs.py +++ b/tests/python/tir-base/test_tir_renew_defs.py @@ -82,7 +82,8 @@ def _get_block(f): def test_match_buffer(): - @T.prim_func + # well-formed checker complains about multiple definitions for a variable A0_s1>? + @T.prim_func(check_well_formed=False) # A and B should be remapped def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): with T.block("root"): From 2048e60194336e29add2a6f7145c15fa043610bb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 15 Feb 2024 16:17:37 -0500 Subject: [PATCH 12/58] Address well-formed failures in test_tir_specialize --- tests/python/tir-base/test_tir_specialize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index 6e4007b08454..327884eaa866 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -277,7 +277,8 @@ def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")): for i in range(256): B_flat[i] = A_flat[i] * 2.0 - @T.prim_func(private=True) + # well-formed checker complains about multiple nested definitions of B_flat? + @T.prim_func(private=True, check_well_formed=False) def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle): B = T.match_buffer(B_handle, [16, 16], "float32", data=A.data) A_flat = T.decl_buffer([256], "float32", data=A.data) From 831f8f2a84b69fe1a2a5f86130fd58f89f69c77a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 15 Feb 2024 16:24:41 -0500 Subject: [PATCH 13/58] Correct well-formedness error in test_tir_analysis_oob --- tests/python/tir-analysis/test_tir_analysis_oob.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-analysis/test_tir_analysis_oob.py b/tests/python/tir-analysis/test_tir_analysis_oob.py index 7c8ceed36e10..754334b5658d 100644 --- a/tests/python/tir-analysis/test_tir_analysis_oob.py +++ b/tests/python/tir-analysis/test_tir_analysis_oob.py @@ -42,7 +42,8 @@ def bad_store_loop(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32" B[0, i] = A[1, i] -@T.prim_func +# N is undefined +@T.prim_func(check_well_formed=False) def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")): N = T.int32() for i in range(3): From e6d5f005a24688374d0c0a2f2b6f696da64b70d5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 15 Feb 2024 16:57:04 -0500 Subject: [PATCH 14/58] Correct further well-formedness failures --- python/tvm/testing/utils.py | 9 ++++++--- .../test_tir_transform_common_subexpr_elim.py | 9 +++++---- .../test_tir_transform_lower_match_buffer.py | 6 ++++-- ...m_merge_dynamic_shared_memory_allocations.py | 4 ++-- .../test_tir_transform_simplify.py | 17 ++++++++++++----- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index e1b1c654570a..d0ceee4aa2a0 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -527,7 +527,6 @@ def enabled_targets(): class Feature: - """A feature that may be required to run a test. Parameters @@ -1952,6 +1951,8 @@ def expected(A: T.Buffer(1, "int32")): """ + check_well_formed: bool = True + def __init_subclass__(cls): assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1 assert ( @@ -1995,7 +1996,9 @@ def inner(self): func_dict[name] = method.with_attr("global_symbol", name) else: source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method)) - prim_func = tvm.script.from_source(source_code) + prim_func = tvm.script.from_source( + source_code, check_well_formed=self.check_well_formed + ) func_dict[name] = prim_func.with_attr("global_symbol", name) return tvm.IRModule(func_dict) @@ -2004,7 +2007,7 @@ def inner(self): def inner(self): # pylint: disable=unused-argument source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func)) - return tvm.script.from_source(source_code) + return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed) return pytest.fixture(inner) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 7be1038ce5d4..62a3cbcbcd8c 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -348,14 +348,15 @@ def test_no_normalization_without_commoning(): # ------------------------------------------------- # Part for testing the commoning with equivalences # ------------------------------------------------- -@T.prim_func +# B is treated as uninitialized +@T.prim_func(check_well_formed=False) def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: B = T.Buffer((50,), "int32") B[i1] = x * (y + z) B[i2] = x * y + x * z -@T.prim_func +@T.prim_func(check_well_formed=False) def func_distributivity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: @@ -365,14 +366,14 @@ def func_distributivity_expected( B[i2] = cse_var_1 -@T.prim_func +@T.prim_func(check_well_formed=False) def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: B = T.Buffer((50,), "int32") B[i1] = (x + y) + z B[i2] = x + (y + z) -@T.prim_func +@T.prim_func(check_well_formed=False) def func_associativity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: diff --git a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py index 7dc164496501..410269ffae5c 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py @@ -466,7 +466,8 @@ def fail_match_store(a: T.handle) -> None: sub_A[()] = 1 -@T.prim_func +# well-formed checker complains about redefinition of a stride variable +@T.prim_func(check_well_formed=False) def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): @@ -479,7 +480,8 @@ def fail_buffer_bind(a: T.handle) -> None: sub_A[i, j * 4 + jj] = 1 -@T.prim_func +# well-formed checker complains about redefinition of a stride variable +@T.prim_func(check_well_formed=False) def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index efe2944aaa48..8661843d39c1 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -456,7 +456,7 @@ def func( class TestSimpleAllocNoReuse(tvm.testing.CompareBeforeAfter): """Test alloc and free within the same scope.""" - transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + transform = tvm.tir.transform.MergeSharedMemoryAllocations() def before(self): @T.prim_func @@ -485,7 +485,7 @@ def func(): class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter): """Test alloc and free within the same scope with a reuse chance.""" - transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + transform = tvm.tir.transform.MergeSharedMemoryAllocations() def before(self): @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index 6bad817c4955..f7887bc61137 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -142,6 +142,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): apply_constraints_to_boolean_branches = False propagate_knowns_to_prove_conditional = False propagate_knowns_to_simplify_expressions = False + # from base class + check_well_formed = False def transform(self): def inner(mod): @@ -650,7 +652,8 @@ class TestRemoveTransitivelyProvableCondition(BaseBeforeAfter): def before(self, test_case): priors, postulate, _ = test_case - @T.prim_func + # well formed checker complains of undefined variables in condition + @T.prim_func(check_well_formed=False) def func(A: T.Buffer(1, "bool")): if priors: A[0] = postulate @@ -666,7 +669,8 @@ def expected(self, test_case): if provable: - @T.prim_func + # well formed checker complains of undefined variables in condition + @T.prim_func(check_well_formed=False) def func(A: T.Buffer(1, "bool")): if priors: A[0] = True @@ -676,7 +680,8 @@ def func(A: T.Buffer(1, "bool")): else: postulate = analyzer.canonical_simplify(postulate) - @T.prim_func + # well formed checker complains of undefined variables in condition + @T.prim_func(check_well_formed=False) def func(A: T.Buffer(1, "bool")): if priors: A[0] = postulate @@ -1034,7 +1039,8 @@ class TestMostRestrictiveConditional(BaseBeforeAfter): def before(self, test_case): priors, expr_before, _ = test_case - @T.prim_func + # well formed checker complains of undefined variables in condition + @T.prim_func(check_well_formed=False) def func(A: T.Buffer(1, "bool")): if priors: A[0] = expr_before @@ -1045,7 +1051,8 @@ def func(A: T.Buffer(1, "bool")): def expected(self, test_case): priors, _, expr_after = test_case - @T.prim_func + # well formed checker complains of undefined variables in condition + @T.prim_func(check_well_formed=False) def func(A: T.Buffer(1, "bool")): if priors: A[0] = expr_after From 3070b18cfff2c15fc2206c64880c6c7019410b57 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Feb 2024 17:05:53 -0500 Subject: [PATCH 15/58] Remove __tvm_meta__ from test case to avoid parsing error --- ...est_tir_transform_inject_rolling_buffer.py | 167 ++++++++++++++---- 1 file changed, 136 insertions(+), 31 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py index b7bd6cb46fd6..6aa4e96cb207 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py @@ -199,49 +199,124 @@ def test_mixed_buffers(make_rolling): _verify_schedule(sch, [A], pool_c) -# fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class PreRollingBuffer: @T.prim_func - def main(A: T.handle, tensor: T.handle) -> None: + def main( + A: T.handle, + tensor: T.handle, + tensor_2: T.Buffer( + [1, 10, 12, 16], + dtype="int8", + elem_offset=0, + align=64, + offset_factor=1, + ), + ) -> None: # function attr dict - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition - tensor_2 = T.Buffer([1, 10, 12, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) - A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) - tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) + T.func_attr( + { + "from_legacy_te_schedule": True, + "global_symbol": "main", + "tir.noalias": True, + } + ) + A_1 = T.match_buffer( + A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 + ) + tensor_1 = T.match_buffer( + tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 + ) # body T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "") for ax1_outer in T.serial(0, 2): - T.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "") + T.realize(tensor_2[0:1, (ax1_outer * 4) : ((ax1_outer * 4) + 6), 0:12, 0:16], "") T.attr(tensor_2, "rolling_buffer_scope", True) for ax1 in T.serial(0, 6): for ax2 in T.serial(0, 12): for ax3 in T.serial(0, 16): - tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.int8(0) + tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3] = T.int8(0) for dh in T.serial(0, 3): for dw in T.serial(0, 3): - tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.max(tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3]) + tensor_2[ + 0, + (ax1 + (ax1_outer * 4)), + ax2, + ax3, + ] = T.max( + tensor_2[ + 0, + (ax1 + (ax1_outer * 4)), + ax2, + ax3, + ], + A_1[ + 0, + ((ax1 + (ax1_outer * 4)) + dh), + (ax2 + dw), + ax3, + ], + ) for ax1_inner in T.serial(0, 4): for ax2_inner in T.serial(0, 8): for ax3_inner in T.serial(0, 16): - tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0) + tensor_1[ + 0, + (ax1_inner + (ax1_outer * 4)), + ax2_inner, + ax3_inner, + ] = T.int8(0) for dh_1 in T.serial(0, 3): for dw_1 in T.serial(0, 5): - tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, ((ax1_inner + (ax1_outer*4)) + dh_1), (ax2_inner + dw_1), ax3_inner]) - __tvm_meta__ = None - - -@tvm.script.ir_module + tensor_1[ + 0, + (ax1_inner + (ax1_outer * 4)), + ax2_inner, + ax3_inner, + ] = T.max( + tensor_1[ + 0, + (ax1_inner + (ax1_outer * 4)), + ax2_inner, + ax3_inner, + ], + tensor_2[ + 0, + ((ax1_inner + (ax1_outer * 4)) + dh_1), + (ax2_inner + dw_1), + ax3_inner, + ], + ) + + +@tvm.script.ir_module(check_well_formed=False) class PostRollingBuffer: @T.prim_func - def main(A: T.handle, tensor: T.handle) -> None: + def main( + A: T.handle, + tensor: T.handle, + tensor_2: T.Buffer( + [1, 10, 12, 16], + dtype="int8", + elem_offset=0, + align=64, + offset_factor=1, + ), + ) -> None: # function attr dict - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - # buffer definition - tensor_2 = T.Buffer([1, 10, 12, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) - A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) - tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1) + T.func_attr( + { + "from_legacy_te_schedule": True, + "global_symbol": "main", + "tir.noalias": True, + } + ) + A_1 = T.match_buffer( + A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 + ) + tensor_1 = T.match_buffer( + tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 + ) # body T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "") T.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "") @@ -249,21 +324,51 @@ def main(A: T.handle, tensor: T.handle) -> None: for ax1 in T.serial(0, 6): for ax2 in T.serial(0, 12): for ax3 in T.serial(0, 16): - if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool') : - tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.int8(0) + if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype="bool"): + tensor_2[ + 0, + T.floormod((ax1 + (ax1_outer * 4)), 6), + ax2, + ax3, + ] = T.int8(0) for dh in T.serial(0, 3): for dw in T.serial(0, 3): - if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool'): - tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.max(tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3]) + if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype="bool"): + tensor_2[ + 0, T.floormod((ax1 + (ax1_outer * 4)), 6), ax2, ax3 + ] = T.max( + tensor_2[ + 0, T.floormod((ax1 + (ax1_outer * 4)), 6), ax2, ax3 + ], + A_1[0, ((ax1 + (ax1_outer * 4)) + dh), (ax2 + dw), ax3], + ) for ax1_inner in T.serial(0, 4): for ax2_inner in T.serial(0, 8): for ax3_inner in T.serial(0, 16): - tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0) + tensor_1[ + 0, + (ax1_inner + (ax1_outer * 4)), + ax2_inner, + ax3_inner, + ] = T.int8(0) for dh_1 in T.serial(0, 3): for dw_1 in T.serial(0, 5): - tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, T.floormod(((ax1_inner + (ax1_outer*4)) + dh_1), 6), (ax2_inner + dw_1), ax3_inner]) - __tvm_meta__ = None -# fmt: on + tensor_1[ + 0, + (ax1_inner + (ax1_outer * 4)), + ax2_inner, + ax3_inner, + ] = T.max( + tensor_1[ + 0, (ax1_inner + (ax1_outer * 4)), ax2_inner, ax3_inner + ], + tensor_2[ + 0, + T.floormod(((ax1_inner + (ax1_outer * 4)) + dh_1), 6), + (ax2_inner + dw_1), + ax3_inner, + ], + ) def test_rolling_buffer_ir_transform(): From 5c7aec3628708bf5b650b20ab0eb7c2856ee6f5f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 18:26:33 -0500 Subject: [PATCH 16/58] Avoid circular import in entryy.py --- python/tvm/script/parser/core/entry.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index b63efcfb65e2..49c919b9dc0f 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -18,8 +18,6 @@ import inspect from typing import Any, Dict, Union -from ....relax.analysis import well_formed as relax_well_formed -from ....tir.analysis import verify_well_formed as tir_well_formed from ....ir.module import IRModule from ...ir_builder import IRBuilder from . import doc @@ -97,6 +95,11 @@ def parse( ret = builder.get() # check well-formedness in both Relax and TIR if check_well_formed: + # do the imports here to avoid a circular import at the start + # (since importing Relax will import a dependenc on the parser) + from ....relax.analysis import well_formed as relax_well_formed + from ....tir.analysis import verify_well_formed as tir_well_formed + check_ret = ret if not isinstance(check_ret, IRModule): check_ret = IRModule.from_expr(ret) From 87426d3316dd8792fe05e4f2cd7d2fbbe09a1d3d Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 18:26:56 -0500 Subject: [PATCH 17/58] Formatting fixes --- ...est_tir_transform_inject_rolling_buffer.py | 21 +- .../tvmscript/test_tvmscript_roundtrip.py | 795 +++++++++--------- 2 files changed, 400 insertions(+), 416 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py index 6aa4e96cb207..046df6c4cb36 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py @@ -238,24 +238,9 @@ def main( tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3] = T.int8(0) for dh in T.serial(0, 3): for dw in T.serial(0, 3): - tensor_2[ - 0, - (ax1 + (ax1_outer * 4)), - ax2, - ax3, - ] = T.max( - tensor_2[ - 0, - (ax1 + (ax1_outer * 4)), - ax2, - ax3, - ], - A_1[ - 0, - ((ax1 + (ax1_outer * 4)) + dh), - (ax2 + dw), - ax3, - ], + tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3] = T.max( + tensor_2[0, (ax1 + (ax1_outer * 4)), ax2, ax3], + A_1[0, ((ax1 + (ax1_outer * 4)) + dh), (ax2 + dw), ax3], ) for ax1_inner in T.serial(0, 4): for ax2_inner in T.serial(0, 8): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 3b2c1aef86fb..6c21ffb13fa9 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -58,9 +58,9 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: ) for x_c_init in T.serial(0, 32): for y_c_init in T.vectorized(0, 32): - C_global[(x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32))] = ( - T.float32(0) - ) + C_global[ + (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) + ] = T.float32(0) for k_outer in T.serial(0, 256): for x_c in T.serial(0, 32): for k_inner in T.unroll(0, 4): @@ -1020,627 +1020,586 @@ def func( for kh in T.serial(0, 3): for ax2 in T.serial(0, 3): with T.launch_thread(tx, 32): - Apad_shared[((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx)] = ( - T.if_then_else( + Apad_shared[ + ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx - ) - - 61440 - ), - ], - T.float16(0), - dtype="float16", - ) - ) - with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32)] = ( - T.if_then_else( - ( - ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) + + (ic_outer * 512) ) - and (1 <= (ax2 + T.floormod(bz, 14))) + + tx ) - and ((ax2 + T.floormod(bz, 14)) < 15) + - 61440 ), - A_1[ - ( - ( - ( - ( - ( - ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) - ) - + (bz * 4096) - ) - + (ax2 * 4096) - ) - + (ic_outer * 512) - ) - + tx - ) - - 61408 - ), - ], - T.float16(0), - dtype="float16", - ) + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61376 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61408 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61344 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61376 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61312 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61344 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61280 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61312 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61248 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61280 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61216 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61248 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61184 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61216 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61152 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61184 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61120 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61152 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61088 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61120 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61056 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61088 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 61024 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61056 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448)] = ( - T.if_then_else( + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) + ] = T.if_then_else( + ( ( ( - ( - (1 <= (T.floordiv(bz, 14) + kh)) - and ((T.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + T.floormod(bz, 14))) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and ((ax2 + T.floormod(bz, 14)) < 15) - ), - A_1[ + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( ( ( ( ( ( ( - ( - ((bx * 6422528) + (ty * 1605632)) - + (tz * 802816) - ) - + (kh * 57344) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) - + (bz * 4096) + + (kh * 57344) ) - + (ax2 * 4096) + + (bz * 4096) ) - + (ic_outer * 512) + + (ax2 * 4096) ) - + tx + + (ic_outer * 512) ) - - 60992 - ), - ], - T.float16(0), - dtype="float16", - ) + + tx + ) + - 61024 + ), + ], + T.float16(0), + dtype="float16", ) - T.launch_thread(tx, 32) - Apad_shared[(((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480)] = ( - T.if_then_else( - ( + with T.launch_thread(tx, 32): + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) + ] = T.if_then_else( ( ( - (1 <= (T.floordiv(bz, 14) + kh)) + (1 <= (T.floordiv(bz, 14) + kh) and ((T.floordiv(bz, 14) + kh) < 15) ) and (1 <= (ax2 + T.floormod(bz, 14))) @@ -1668,12 +1627,52 @@ def func( ) + tx ) - - 60960 + - 60992 ), ], T.float16(0), dtype="float16", ) + T.launch_thread(tx, 32) + Apad_shared[ + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) + ] = T.if_then_else( + ( + ( + ( + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) + ) + and (1 <= (ax2 + T.floormod(bz, 14))) + ) + and ((ax2 + T.floormod(bz, 14)) < 15) + ), + A_1[ + ( + ( + ( + ( + ( + ( + ( + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) + ) + + (kh * 57344) + ) + + (bz * 4096) + ) + + (ax2 * 4096) + ) + + (ic_outer * 512) + ) + + tx + ) + - 60960 + ), + ], + T.float16(0), + dtype="float16", ) with T.launch_thread(tx, 32): W_shared[T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8)] = W_1[ From 7e299e9e44abc128adbad4cd82d4c41e61dd5bae Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 19:04:39 -0500 Subject: [PATCH 18/58] lint fix --- tests/python/tvmscript/test_tvmscript_roundtrip.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 6c21ffb13fa9..73bf200bb22a 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -1597,9 +1597,10 @@ def func( Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) ] = T.if_then_else( + ( ( ( - (1 <= (T.floordiv(bz, 14) + kh) + 1 <= (T.floordiv(bz, 14) + kh) and ((T.floordiv(bz, 14) + kh) < 15) ) and (1 <= (ax2 + T.floormod(bz, 14))) From b8616f2c4643ce7393297e33a89f2d96da51e513 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 19:07:17 -0500 Subject: [PATCH 19/58] Add pylint exceptions --- python/tvm/script/parser/core/entry.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 49c919b9dc0f..7fd0dba634fd 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -95,10 +95,10 @@ def parse( ret = builder.get() # check well-formedness in both Relax and TIR if check_well_formed: - # do the imports here to avoid a circular import at the start - # (since importing Relax will import a dependenc on the parser) - from ....relax.analysis import well_formed as relax_well_formed - from ....tir.analysis import verify_well_formed as tir_well_formed + # (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency, + # since importing Relax imports a dependency on the parser) + from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415 + from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415 check_ret = ret if not isinstance(check_ret, IRModule): @@ -108,7 +108,7 @@ def parse( parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE) try: tir_well_formed(check_ret) - except Exception as err: + except Exception as err: # pylint: disable=broad-exception-caught parser.report_error( source_ast, err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}", From cb114e4376f7f145d8ae2664d801192dc9002552 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 19:07:58 -0500 Subject: [PATCH 20/58] Fix whitespace --- python/tvm/script/parser/core/entry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 7fd0dba634fd..0c88cacf8a62 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -97,8 +97,8 @@ def parse( if check_well_formed: # (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency, # since importing Relax imports a dependency on the parser) - from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415 - from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415 + from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415 + from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415 check_ret = ret if not isinstance(check_ret, IRModule): @@ -108,7 +108,7 @@ def parse( parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE) try: tir_well_formed(check_ret) - except Exception as err: # pylint: disable=broad-exception-caught + except Exception as err: # pylint: disable=broad-exception-caught parser.report_error( source_ast, err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}", From ea357bbc59c6ab311034a1a16c8b7e977092034e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 21:25:07 -0500 Subject: [PATCH 21/58] Fix more failed test cases --- tests/python/codegen/test_inject_ptx_ldg32.py | 3 ++- .../micro/test_aot_legalize_packed_call.py | 5 +++-- .../test_tir_analysis_identify_memcpy.py | 1 + .../tir-schedule/test_tir_schedule_rfactor.py | 3 ++- .../test_tir_transform_convert_ssa.py | 14 +++++++++----- ...r_transform_lower_cross_thread_reduction.py | 18 ++++++++++++------ 6 files changed, 29 insertions(+), 15 deletions(-) diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index 8e8547c572d0..d7a92802598b 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -20,7 +20,8 @@ import tvm.testing -@T.prim_func +# A_local is undefined +@T.prim_func(check_well_formed=False) def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> None: T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) bx = T.env_thread("blockIdx.x") diff --git a/tests/python/micro/test_aot_legalize_packed_call.py b/tests/python/micro/test_aot_legalize_packed_call.py index 6f66f3a43283..3e66a96dfb43 100644 --- a/tests/python/micro/test_aot_legalize_packed_call.py +++ b/tests/python/micro/test_aot_legalize_packed_call.py @@ -22,7 +22,8 @@ from tvm.script import tir as T -@tvm.script.ir_module +# complains of an undefined var being used +@tvm.script.ir_module(check_well_formed=False) class Module: @T.prim_func def tvm_test_cpacked( @@ -52,7 +53,7 @@ def tir_packed_call() -> None: ) -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class Expected: @T.prim_func def tvm_test_cpacked( diff --git a/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py b/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py index b69d3aea3ea3..b3c6a489c9f7 100644 --- a/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py +++ b/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py @@ -32,6 +32,7 @@ class BaseTest: """Utility class for defining unit tests for memcpy""" def __init_subclass__(cls): + cls.check_well_formed = False # CompareBeforeAfter has a member var cls.func = tvm.testing.CompareBeforeAfter._normalize_before(cls.func) cls.expected = pytest.fixture(cls.expected) diff --git a/tests/python/tir-schedule/test_tir_schedule_rfactor.py b/tests/python/tir-schedule/test_tir_schedule_rfactor.py index 37e68fa21a0e..a15bd3d9137b 100644 --- a/tests/python/tir-schedule/test_tir_schedule_rfactor.py +++ b/tests/python/tir-schedule/test_tir_schedule_rfactor.py @@ -951,7 +951,8 @@ def argmax_split_body_bufferstore_value_not_var( argmax_v1[i] = v_argmax_v1 -@T.prim_func +# v_unbound is unbound +@T.prim_func(check_well_formed=False) def argmax_split_body_bufferstore_value_unbound_var( idx: T.Buffer((128, 128), "int32"), val: T.Buffer((128, 128), "float32"), diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 644ab3b624ef..371fa0b4e71c 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -327,7 +327,8 @@ class TestDeDuplicateThreadIdxAcrossMultipleFunctions(BaseBeforeAfter): def before(self): threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") - @I.ir_module + # complaints of duplicate definitions of threadIdx_x + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): @@ -350,7 +351,8 @@ def kernel_2(A: T.Buffer([256], "float32")): return mod def expected(self): - @I.ir_module + # complaints of duplicate definitions of threadIdx_x + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): @@ -389,7 +391,8 @@ def before(self): tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" ) - @I.ir_module + # complaints of multiple definitions for threadIdx_x + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): @@ -404,7 +407,7 @@ def kernel_2(A: T.Buffer([256], "float32")): return mod def expected(self): - @I.ir_module + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): @@ -445,7 +448,8 @@ def before(self): tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" ) - @I.ir_module + # complaints of multiple definitions of threadIdx_x + @I.ir_module(check_well_formed=False) class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): diff --git a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py index aa55b25f1668..35b4d55ea51d 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py @@ -116,7 +116,8 @@ def no_normal_reduction(a: T.handle, b: T.handle) -> None: B[vi] = B[vi] + A[vi, vk] -@T.prim_func +# complains that k is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") @@ -162,7 +163,8 @@ def two_bound_loops(a: T.handle, b: T.handle) -> None: B[vi] = B[vi] + A[vi, vk] -@T.prim_func +# complains that ko is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") @@ -899,7 +901,8 @@ def reducer_max(a: T.handle, b: T.handle) -> None: B[vi] = T.max(B[vi], A[vi, vk]) -@T.prim_func +# complains that k is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_reducer_max(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") B = T.match_buffer(b, [128], dtype="float32") @@ -942,7 +945,8 @@ def zero_rank_buffer(a: T.handle, b: T.handle) -> None: B[()] = B[()] + A[vk] -@T.prim_func +# complains that k is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128], dtype="float32") B = T.match_buffer(b, [], dtype="float32") @@ -1572,7 +1576,8 @@ def thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), " B[vi] = temp_local[vi] + T.float32(1) -@T.prim_func +# complains that k is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): temp_local = T.alloc_buffer((256,), scope="local") cross_thread_temp_local = T.alloc_buffer((1,), strides=(1,), scope="local") @@ -1745,7 +1750,8 @@ def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 25 B[vi, vj] = A[vi, vj] + temp_2_local[0] -@T.prim_func +# complains that k is defined outside of a block +@T.prim_func(check_well_formed=False) def lowered_no_thread_broadcast( A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32") ): From 5d6de7db9d09cda3d108d354a7a4f84c07564914 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Feb 2024 21:29:57 -0500 Subject: [PATCH 22/58] Catch inappropriate use of decl_function instead of segfaulting --- python/tvm/script/ir_builder/ir/ir.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 0d3523ec7dd7..d35d73678b47 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -43,7 +43,7 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: func_name : str The function unique name. - func_signature: Optional[BaseFunc] + func_signature: BaseFunc A Function w/o body, which used to specify the function signature (i.e. func params and func return type/shape). @@ -55,7 +55,11 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: gv : GlobalVar The corresponding GlobalVar. """ - + if not isinstance(func_signature, BaseFunc): + raise ValueError( + "decl_function expects an instance of BaseFunc, " + f"but {func_signature} is of type {type(func_signature)}" + ) return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member func_name, func_signature ) From 574677a536ea0a1a4f1049126b836b6635108528 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:14:13 -0500 Subject: [PATCH 23/58] Fix test_lower.py --- tests/python/integration/test_lower.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 965ab80bebb2..1d042610ac07 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -22,7 +22,8 @@ from tvm.script import tir as T -@T.prim_func +# complains that index_i is defined outside of a block +@T.prim_func(check_well_formed=False) def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None: # pylint: disable=missing-function-docstring # match buffer From 08819aa2ff042d833ec9dd1190afc5e707fea4a3 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:28:39 -0500 Subject: [PATCH 24/58] Mark purity in test_relax_2d_buffer_allocation.py --- .../contrib/test_hexagon/test_relax_2d_buffer_allocation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index 40de28cca0a8..ae459dc770d7 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -46,7 +46,7 @@ def add( T.writes(output[v_ax0, v_ax1]) output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] - @R.function + @R.function(pure=False) def main(x: R.Tensor((2, 2), dtype="float32")): cls = Module # Try allocating 2d storage (2,2) in global.vtcm scope with nd allocator From f93c406c6135a556faef155f0eb42061d9361a21 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:32:35 -0500 Subject: [PATCH 25/58] Mark purity in test_dma_builtin.py --- tests/python/contrib/test_hexagon/test_dma_builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index af82c2b55afd..e1c98ac35650 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -49,7 +49,7 @@ def compute_add_in_vtcm(a: T.handle, b: T.handle, c: T.handle) -> None: T.writes(C[v_ax0]) C[v_ax0] = A[v_ax0] + B[v_ax0] - @R.function + @R.function(pure=False) def main( x: R.Tensor((12800,), data_type), y: R.Tensor((12800,), data_type), From 299abbc3a0f345ab31058b45bf15535b9d80f3f3 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:35:10 -0500 Subject: [PATCH 26/58] Remove __tvm_meta___ from test_tir_usmp_analysis_extract_bufferinfo.py --- .../tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py index 662f86479c09..4a55c6c84bfb 100644 --- a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py @@ -18,6 +18,7 @@ import sys import tvm +import tvm.testing from tvm import tir, script from tvm.ir import Range from tvm.script import tir as T @@ -171,7 +172,6 @@ def run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) - __tvm_meta__ = None # fmt: on @@ -653,7 +653,6 @@ def run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast_1", sid_4, sid_32, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", sid_32, T.lookup_param("p17", dtype="handle"), T.lookup_param("p18", dtype="handle"), sid_31, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_concatenate", sid_2, sid_19, sid_25, sid_31, output, dtype="int32")) - __tvm_meta__ = None # fmt: on From fda20a363bc245d9f9c0223fff4c59071bee965d Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:38:45 -0500 Subject: [PATCH 27/58] Suppress well-formed check in test_tir_transform_convert_blocks_to_opaque.py --- .../test_tir_transform_convert_blocks_to_opaque.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py index 8fbbaf59bb58..f920a46ba57e 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/tir-transform/test_tir_transform_convert_blocks_to_opaque.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import tir, te from tvm.script import tir as T @@ -84,6 +85,7 @@ def test_lower_te(): class TestErrorIfPredicateUsesBlockVariables(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.ConvertBlocksToOpaque() + check_well_formed = False def before(A: T.Buffer(8, "int32")): for i in T.serial(8): From 8b9506d7715161fc9422f714996f1cca6e508466 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:41:18 -0500 Subject: [PATCH 28/58] Remove __tvm_meta__ in test_tir_usmp_algo.py --- tests/python/tir-usmp/test_tir_usmp_algo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/tir-usmp/test_tir_usmp_algo.py b/tests/python/tir-usmp/test_tir_usmp_algo.py index 265e6fe5d5d5..b9cfde485633 100644 --- a/tests/python/tir-usmp/test_tir_usmp_algo.py +++ b/tests/python/tir-usmp/test_tir_usmp_algo.py @@ -359,7 +359,6 @@ def run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) - __tvm_meta__ = None # fmt: on @@ -530,7 +529,6 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32") for ax3_inner_1 in T.serial(0, 64): T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") - __tvm_meta__ = None # fmt: on From 4944c3a5c1eaafd0aca3b849fd789887204e6d0f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:44:02 -0500 Subject: [PATCH 29/58] Remove __tvm_meta__ from more USMP tests --- .../tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py | 2 -- ...st_tir_usmp_transform_convert_pool_allocations_to_offsets.py | 1 - tests/python/tir-usmp/test_tir_usmp_utils.py | 1 - 3 files changed, 4 deletions(-) diff --git a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py index 4a55c6c84bfb..f8da0ef9f42d 100644 --- a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py @@ -245,7 +245,6 @@ def run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) -__tvm_meta__ = None # fmt: on @@ -286,7 +285,6 @@ def run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", input, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), output, dtype="int32")) -__tvm_meta__ = None # fmt: on diff --git a/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 03929c5436be..9e9fea7c8152 100644 --- a/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -509,7 +509,6 @@ def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), out T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/tir-usmp/test_tir_usmp_utils.py b/tests/python/tir-usmp/test_tir_usmp_utils.py index 0fece9dcd263..635c9a760f87 100644 --- a/tests/python/tir-usmp/test_tir_usmp_utils.py +++ b/tests/python/tir-usmp/test_tir_usmp_utils.py @@ -91,7 +91,6 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) - __tvm_meta__ = None # fmt: on From 7889711aa7706f561d5dc23e31e634eea48bd46e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:45:08 -0500 Subject: [PATCH 30/58] Fix incorrect var in test_tir_transform_storage_flatten.py --- .../python/tir-transform/test_tir_transform_storage_flatten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/tir-transform/test_tir_transform_storage_flatten.py b/tests/python/tir-transform/test_tir_transform_storage_flatten.py index f09645462366..8ddfbb5adfd3 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_flatten.py +++ b/tests/python/tir-transform/test_tir_transform_storage_flatten.py @@ -153,7 +153,7 @@ def main(): @T.prim_func def tir_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2]) - B = T.match_buffer(a, [2, 2]) + B = T.match_buffer(b, [2, 2]) A[0, 1] = B[1, 1] From e11cf8cc5c7d5897896f7ad3cbdf80aa331d2c40 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 13:55:28 -0500 Subject: [PATCH 31/58] Remove all remaining instances of __tvm_meta__ --- .../contrib/test_ethosu/test_create_tiles.py | 8 -------- .../test_ethosu/test_encode_constants.py | 9 --------- .../test_ethosu/test_merge_constants.py | 8 -------- .../test_ethosu/test_remove_concatenates.py | 1 - .../test_ethosu/test_replace_conv2d.py | 12 ------------ .../contrib/test_ethosu/test_replace_copy.py | 2 -- .../contrib/test_ethosu/test_scheduler.py | 1 - .../test_ethosu/test_tir_to_cs_translator.py | 19 ------------------- .../contrib/test_ethosu/test_vela_api.py | 6 ------ 9 files changed, 66 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_create_tiles.py b/tests/python/contrib/test_ethosu/test_create_tiles.py index e4b4067a2977..ac90e3c27839 100644 --- a/tests/python/contrib/test_ethosu/test_create_tiles.py +++ b/tests/python/contrib/test_ethosu/test_create_tiles.py @@ -56,8 +56,6 @@ def main(placeholder1: T.Buffer((100,), "int8"), placeholder2: T.Buffer((100,), for i3 in T.serial(0, 1): for i4 in T.serial(0, 16): placeholder1[((i1*16) + i4)] = placeholder2[((T.floormod((i1 + 4), 6)*16) + i4)] - - __tvm_meta__ = None # fmt: on stmt = Module["main"].body @@ -87,8 +85,6 @@ def main(placeholder1: T.Buffer((100,), "int8"), placeholder2: T.Buffer((100,), for i3 in T.serial(0, 6): for i4 in T.serial(0, 16): placeholder1[((i3*16) + i4)] = placeholder2[((T.floormod((i3 + 4), 6)*16) + i4)] - - __tvm_meta__ = None # fmt: on stmt = Module["main"].body @@ -118,8 +114,6 @@ def main(placeholder1: T.Buffer((100,), "int8"), placeholder2: T.Buffer((100,), for i3 in T.serial(0, 1): for i4 in T.serial(0, 16): placeholder1[((i1*16) + i4)] = placeholder2[((T.floormod((i1 + 4), 6)*8) + i4)] - - __tvm_meta__ = None # fmt: on stmt = Module["main"].body @@ -148,8 +142,6 @@ def main(placeholder1: T.Buffer((100,), "int8"), placeholder2: T.Buffer((100,), for i2 in T.serial(0, 6): for i3 in T.serial(0, 4): placeholder1[(((i1*24) + (i2*4)) + i3)] = placeholder2[(((((T.floordiv((i1 - 1), 2)*48) + (T.floormod((i1 + 1), 2)*24)) + (i2*4)) + i3) + 96)] - - __tvm_meta__ = None # fmt: on stmt = Module["main"].body diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 4341f367f0e1..b16fe2a85a34 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -60,7 +60,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 144, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, buffer9[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -89,7 +88,6 @@ def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_4_1[0], 80, p2_global_4_1[80], 80, 12, p2_global_4_1[160], 16, p2_global_4_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_5_1[0], 96, p2_global_5_1[96], 80, 12, p2_global_5_1[176], 16, p2_global_5_1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_1[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2_global_6_1[0], 80, p2_global_6_1[80], 80, 12, p2_global_6_1[160], 16, p2_global_6_1[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - __tvm_meta__ = None # fmt: on @@ -168,7 +166,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 304, T.int8(-1), T.int8(-1), 12, p2[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -190,8 +187,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - - __tvm_meta__ = None # fmt: on @@ -269,7 +264,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ ethosu_write_1 = T.Buffer([4096], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -290,7 +284,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ ethosu_write_2 = T.Buffer([4096], "int8", data=ethosu_write_2_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -380,7 +373,6 @@ def main(input_ifm: T.Buffer((1,16,16,32), "int8"), input_ethosu_write: T.Buffer T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -414,7 +406,6 @@ def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_4[0], 48, p5_global_4[48], 48, 12, p5_global_4[96], 16, p5_global_4[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_5[0], 48, p5_global_5[48], 48, 12, p5_global_5[96], 16, p5_global_5[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) T.call_extern("handle", "ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write_3[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5_global_6[0], 48, p5_global_6[48], 48, 12, p5_global_6[96], 16, p5_global_6[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 624bef00c7f8..1c22f0b5c8b6 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -413,7 +413,6 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), buffer1: T.Buffer T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -430,7 +429,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((4 p1 = T.Buffer([464], "uint8", data=p1_data) T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on const_dict = { @@ -470,7 +468,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((3 T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -491,7 +488,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((4 T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on const_dict = { @@ -536,7 +532,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((3 T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -557,7 +552,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((4 T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on const_dict = { @@ -602,7 +596,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((3 T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -623,7 +616,6 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((4 T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on const_dict = { diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index ef034930d7bc..56205bcd73d4 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -54,7 +54,6 @@ def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1: T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_7[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 32d1303e124e..ff47343811f5 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -383,7 +383,6 @@ def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -405,7 +404,6 @@ def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -430,7 +428,6 @@ def main(input_placeholder_5: T.Buffer((1, 16, 16, 3), "int8"), input_ethosu_wri T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, placeholder_5[576], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -452,7 +449,6 @@ def main(input_placeholder_5: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_w T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -474,7 +470,6 @@ def main(input_placeholder: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write: T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -494,7 +489,6 @@ def main(input_placeholder: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_wri ethosu_write_1 = T.Buffer([12288], "int8", data=ethosu_write_1_data) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_3[0], 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -657,7 +651,6 @@ def main(input_placeholder_3: T.Buffer((1, 10, 12, 8), "int8"), input_ethosu_wri ethosu_write_1 = T.Buffer([1024], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -672,7 +665,6 @@ def main(input_placeholder_3: T.Buffer((1, 7, 9, 5), "int8"), input_ethosu_write ethosu_write_1 = T.Buffer([240], 'int8', data=input_ethosu_write_1.data) # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -717,7 +709,6 @@ def main(input_placeholder_3: T.Buffer((4, 6, 8, 1), "int8"), input_ethosu_write # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -733,7 +724,6 @@ def main(input_placeholder_3: T.Buffer((1, 24, 8), "int8"), input_ethosu_write_1 # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -749,7 +739,6 @@ def main(input_placeholder_3: T.Buffer((192, 1), "int8"), input_ethosu_write_1: # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None @tvm.script.ir_module @@ -764,7 +753,6 @@ def main(placeholder_3: T.Buffer((192,), "int8"), input_ethosu_write_1: T.Buffer # body T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 94763c5d3fbf..7329b40d5851 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -45,7 +45,6 @@ def main(input_placeholder_3: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_wr placeholder_global = T.Buffer([384], "uint8", data=placeholder_global_data) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -94,7 +93,6 @@ def main(input_placeholder_5: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_wr T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_1[272], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index e7abb707a69c..74394fbd46bf 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -234,7 +234,6 @@ def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p2[0], 736, T.int8(-1), T.int8(-1), 12, p2[736], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0,T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 05d6f71037fa..9a61be6b975d 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -146,7 +146,6 @@ def main(placeholder: T.Buffer((8192,), "int8"), ethosu_write: T.Buffer((2048,), T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -199,7 +198,6 @@ def main(placeholder: T.Buffer((8192,), "int8"), ethosu_write: T.Buffer((2048,), T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - __tvm_meta__ = None # fmt: on @@ -558,7 +556,6 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [126], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, ethosu_depthwise_conv2d_1[0], 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, placeholder_4[0], 18, 13, placeholder_5[0], 30, 0, 0, 0, 0, "CLIP", 15, 105, "TFL", "NONE", 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1039,7 +1036,6 @@ def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) ethosu_write_2 = T.match_buffer(ethosu_write, [75], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, placeholder_4[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "TFL", "NONE", 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1116,8 +1112,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ) # body T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - - __tvm_meta__ = None # fmt: on # fmt: off @@ -1132,7 +1126,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on # fmt: off @@ -1147,7 +1140,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1163,7 +1155,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1179,7 +1170,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1195,7 +1185,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, "TFL", 0, 0, 0, 0, 0, 0, dtype="int32")) - __tvm_meta__ = None # fmt: on @@ -1211,7 +1200,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [135], dtype="int32", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, placeholder_2[135], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int32")) - __tvm_meta__ = None # fmt: on @@ -1332,7 +1320,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on # fmt: off @@ -1347,7 +1334,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on # fmt: off @@ -1362,7 +1348,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1378,7 +1363,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1394,7 +1378,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int8", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int8")) - __tvm_meta__ = None # fmt: on @@ -1410,7 +1393,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, "TFL", 0, 0, 0, 0, 0, 0, dtype="int32")) - __tvm_meta__ = None # fmt: on @@ -1426,7 +1408,6 @@ def main(placeholder: T.handle, ethosu_write: T.handle) -> None: ethosu_write_2 = T.match_buffer(ethosu_write, [24], dtype="int32", elem_offset=0, align=64, offset_factor=1) # body T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, placeholder_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, ethosu_write_2[0], 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, "TFL", 0, 0, 0, 0, 0, 0, dtype="int32")) - __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index 16785e182a49..c70746964baf 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -123,8 +123,6 @@ def main( ) ) - __tvm_meta__ = None - """Test case 2 with per-channel quantization""" @@ -219,8 +217,6 @@ def main( ) ) - __tvm_meta__ = None - # fmt: off @tvm.script.ir_module @@ -239,8 +235,6 @@ def main(ethos_u_0_i0: T.Buffer((1, 299, 299, 2), "int8"), ethosu_write: T.Buffe ethos_u_0_i0_1 = T.Buffer((178802,), "int8", data=ethos_u_0_i0.data) ethosu_write_1 = T.Buffer((268203,), "int8", data=ethosu_write.data) T.call_extern("handle", "ethosu_conv2d", "int8", 299, 299, 2, 299, 0, 299, ethos_u_0_i0_1[0], 0, 0, 0, T.float32(0.0039215683937072754), -128, "NHWC", 598, 2, 1, "int8", 299, 299, 3, 299, 0, 299, ethosu_write_1[0], 0, 0, 0, T.float32(0.025585981085896492), -128, "NHWC", 897, 3, 1, 2, 3, 1, 1, 1, 2, p2_global_1[0], 96, T.int8(-1), T.int8(-1), 0, p2_global_1[96], 32, T.int8(-1), T.int8(-1), 2, 0, 2, 1, "NONE", 0, 0, "TFL", "NONE", 32, 12, 8) - - __tvm_meta__ = None # fmt: on From e35499dd05b4d707a742dbc034447916cf8b2a3b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 16:16:30 -0500 Subject: [PATCH 32/58] Fix purity error in test_dataflow_pattern.py --- tests/python/relax/test_dataflow_pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index a717e3da043f..583e2a8d0822 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -314,7 +314,7 @@ def test_is_call_tir(): assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val) -@R.function +@R.function(pure=False) def simple_call_packed( x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") ) -> R.Tensor: From 261e3eb255fdc597d15cc40c6a2e2507650c84a0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 16:25:27 -0500 Subject: [PATCH 33/58] Fix purity error in test_ast_printer --- tests/python/relax/test_ast_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 2a554f16e23f..97ad9f5dd034 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -564,7 +564,7 @@ def foo(x: R.Tensor): # axis is -1 assert "PrimExpr(value=`T.int64(-1)`)" in foo_str - @R.function + @R.function(pure=False) def bar(x: R.Tensor): return R.print(x, format="{}") From 6cc354c9341592e8cd013ba4def0b1d3b1e78fff Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:21:43 -0500 Subject: [PATCH 34/58] Fix test_arith_domain_touched example --- tests/python/arith/test_arith_domain_touched.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/python/arith/test_arith_domain_touched.py b/tests/python/arith/test_arith_domain_touched.py index 46effde3420c..e8d49316bdd6 100644 --- a/tests/python/arith/test_arith_domain_touched.py +++ b/tests/python/arith/test_arith_domain_touched.py @@ -71,17 +71,15 @@ def test_domain_touched(): def test_domain_touched_vector(): m = tvm.runtime.convert(128) - # n is undefined - @T.prim_func(check_well_formed=False) - def func(a: T.handle, b: T.handle): - n = T.int32() + @T.prim_func + def func(a: T.handle, b: T.handle, n: T.int32): A = T.match_buffer(a, (n * m,)) B = T.match_buffer(b, (n * m,)) for i in T.serial(n): A[i * m : (i + 1) * m : 1] = A[i * m : (i + 1) * m : 1] + B[i * m : (i + 1) * m : 1] - a, b = [func.buffer_map[var] for var in func.params] + a, b = [func.buffer_map[var] for var in func.params[:2]] assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128 From 355aedf9b5e812d09354d72e6b0a0376a4df9baf Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:37:23 -0500 Subject: [PATCH 35/58] Okay to set check_well_formed to True in test_tir_analysis_identify_mcmcpy --- tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py b/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py index b3c6a489c9f7..8510a66d308d 100644 --- a/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py +++ b/tests/python/tir-analysis/test_tir_analysis_identify_memcpy.py @@ -32,7 +32,7 @@ class BaseTest: """Utility class for defining unit tests for memcpy""" def __init_subclass__(cls): - cls.check_well_formed = False # CompareBeforeAfter has a member var + cls.check_well_formed = True # CompareBeforeAfter has a member var cls.func = tvm.testing.CompareBeforeAfter._normalize_before(cls.func) cls.expected = pytest.fixture(cls.expected) From f1b6dcbe21ebfba5b65e08c94de7ebcda6799342 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:38:19 -0500 Subject: [PATCH 36/58] Define variable in test_tir_analysis_oob --- tests/python/tir-analysis/test_tir_analysis_oob.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/tir-analysis/test_tir_analysis_oob.py b/tests/python/tir-analysis/test_tir_analysis_oob.py index 754334b5658d..c4d520881797 100644 --- a/tests/python/tir-analysis/test_tir_analysis_oob.py +++ b/tests/python/tir-analysis/test_tir_analysis_oob.py @@ -42,10 +42,8 @@ def bad_store_loop(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32" B[0, i] = A[1, i] -# N is undefined -@T.prim_func(check_well_formed=False) -def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")): - N = T.int32() +@T.prim_func +def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32"), N: T.int32): for i in range(3): B[0, N] = A[1, i] From d9dbeb5803b2efe59856932638e0288c83829488 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:42:20 -0500 Subject: [PATCH 37/58] Typo fix --- tests/python/tir-base/test_tir_renew_defs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-base/test_tir_renew_defs.py b/tests/python/tir-base/test_tir_renew_defs.py index 5a11725c64b5..7fe8d7c679fa 100644 --- a/tests/python/tir-base/test_tir_renew_defs.py +++ b/tests/python/tir-base/test_tir_renew_defs.py @@ -82,7 +82,8 @@ def _get_block(f): def test_match_buffer(): - # well-formed checker complains about multiple definitions for a variable A0_s1>? + # well-formed checker complains about multiple definitions for variable A0_s1, + # likely stemming from strides=[s, s] @T.prim_func(check_well_formed=False) # A and B should be remapped def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): From 0ca4698b486798e17217a89b123c7a681c724bc1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:44:29 -0500 Subject: [PATCH 38/58] Add explanatory comment to test case --- tests/python/tir-base/test_tir_specialize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index 327884eaa866..042288723376 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -67,6 +67,8 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +# x is considered undefined because it appears as part of x*8, +# but not on its own @T.prim_func(check_well_formed=False) def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: x = T.int32() @@ -277,7 +279,8 @@ def before(A: T.Buffer([16, 16], "float32"), B: T.Buffer([16, 16], "float32")): for i in range(256): B_flat[i] = A_flat[i] * 2.0 - # well-formed checker complains about multiple nested definitions of B_flat? + # well-formed checker complains about multiple nested definitions of B_flat + # since it appears in the buffer map twice @T.prim_func(private=True, check_well_formed=False) def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle): B = T.match_buffer(B_handle, [16, 16], "float32", data=A.data) From 8af392dd8d0c2ef6b5b3d03c2371dbf608957f12 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:51:35 -0500 Subject: [PATCH 39/58] Define the undefined vars in test_tir_transform_common_subexpr_elim --- .../test_tir_transform_common_subexpr_elim.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 62a3cbcbcd8c..e64d3c74932b 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -348,36 +348,35 @@ def test_no_normalization_without_commoning(): # ------------------------------------------------- # Part for testing the commoning with equivalences # ------------------------------------------------- -# B is treated as uninitialized -@T.prim_func(check_well_formed=False) -def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: - B = T.Buffer((50,), "int32") +@T.prim_func +def func_distributivity( + B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: B[i1] = x * (y + z) B[i2] = x * y + x * z -@T.prim_func(check_well_formed=False) +@T.prim_func def func_distributivity_expected( - i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 + B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - B = T.Buffer((50,), "int32") with T.LetStmt(x * y + x * z) as cse_var_1: B[i1] = cse_var_1 B[i2] = cse_var_1 -@T.prim_func(check_well_formed=False) -def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: - B = T.Buffer((50,), "int32") +@T.prim_func +def func_associativity( + B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: B[i1] = (x + y) + z B[i2] = x + (y + z) -@T.prim_func(check_well_formed=False) +@T.prim_func def func_associativity_expected( - i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 + B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - B = T.Buffer((50,), "int32") with T.LetStmt((x + y) + z) as cse_var_1: B[i1] = cse_var_1 B[i2] = cse_var_1 @@ -461,6 +460,7 @@ def test_deterministic_cse(): ["PR", 3, 0, "auto_unroll_max_step$512"], ["AN", 1, 3, 2], ["AN", 3, 21, 2], \ ["AN", 6, 6, 2]]]], "r": [[0.0331129], 0, 0.900362, 1647464342], "v": "v0.6"}\n' + # The workload associated with the log @auto_scheduler.register_workload def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding): From 07c8819e07e5e37cce62704b74376e107e7648b1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 27 Feb 2024 22:54:47 -0500 Subject: [PATCH 40/58] Exception no longer necessary in test_tir_transform_inject_rolling_buffer --- .../tir-transform/test_tir_transform_inject_rolling_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py index 046df6c4cb36..c1c8141f70a7 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py @@ -199,7 +199,7 @@ def test_mixed_buffers(make_rolling): _verify_schedule(sch, [A], pool_c) -@tvm.script.ir_module(check_well_formed=False) +@tvm.script.ir_module class PreRollingBuffer: @T.prim_func def main( @@ -274,7 +274,7 @@ def main( ) -@tvm.script.ir_module(check_well_formed=False) +@tvm.script.ir_module class PostRollingBuffer: @T.prim_func def main( From ae8001c35f80d0cb600f4e4706a9e2203047905e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 28 Feb 2024 17:49:22 -0500 Subject: [PATCH 41/58] Remove unnecessary check exemption in test_tir_transform_convert_ssa --- tests/python/tir-transform/test_tir_transform_convert_ssa.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index 371fa0b4e71c..ec768ba74f7b 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -327,7 +327,7 @@ class TestDeDuplicateThreadIdxAcrossMultipleFunctions(BaseBeforeAfter): def before(self): threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") - # complaints of duplicate definitions of threadIdx_x + # threadIdx_x is defined outside @I.ir_module(check_well_formed=False) class mod: @T.prim_func @@ -351,8 +351,7 @@ def kernel_2(A: T.Buffer([256], "float32")): return mod def expected(self): - # complaints of duplicate definitions of threadIdx_x - @I.ir_module(check_well_formed=False) + @I.ir_module class mod: @T.prim_func def kernel_1(A: T.Buffer([256], "float32")): From b373a4f8495a6e759a798e515fe6a8e44d4a7511 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 28 Feb 2024 17:58:09 -0500 Subject: [PATCH 42/58] Avoid checking exemption in test_inject_ptx_ldg32 --- tests/python/codegen/test_inject_ptx_ldg32.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index d7a92802598b..4a6d4c366a61 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -20,21 +20,21 @@ import tvm.testing -# A_local is undefined -@T.prim_func(check_well_formed=False) +@T.prim_func def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> None: T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) T.launch_thread(tx, 32) - A_local = T.Buffer((32), "float32", scope="local") - with T.block(): - T.reads(A[0:16]) - T.writes(A_local[0:32]) - A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") - B[tx] = A_local[tx] + 1.0 + A_local = T.alloc_buffer((32), "float32", scope="local") + + with T.block(): + T.reads(A[0:16]) + T.writes(A_local[0:32]) + A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32") + B[tx] = A_local[tx] + 1.0 @tvm.testing.requires_cuda From 785394d0de57d48a8f5f4f3692d5d8484f0b955a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 28 Feb 2024 22:25:54 -0500 Subject: [PATCH 43/58] Note special case in test_distributed_transform_propagate_sharding --- .../test_distributed_transform_propagate_sharding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index 990a7b1557e5..e1f45d278d6c 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -1370,7 +1370,9 @@ def foo( gv: R.Tensor((1, 256, 4096), dtype="float16") = lv44 return gv - @I.ir_module + # the below uses global vars that are not yet defined but the definitions + # will be added later + @I.ir_module(check_well_formed=False) class ShardedLlamaAttentionLayerTIR: I.module_attrs({"device_num": 10}) I.module_global_infos( From 571cf47d3da1913229790385e0c7faaaeebc1a05 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 29 Feb 2024 21:58:43 -0500 Subject: [PATCH 44/58] Exempt well-formed error in dlight/test_benchmark --- tests/python/dlight/test_benchmark.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/dlight/test_benchmark.py b/tests/python/dlight/test_benchmark.py index 3153be2cc9b0..5f8bdf49551e 100644 --- a/tests/python/dlight/test_benchmark.py +++ b/tests/python/dlight/test_benchmark.py @@ -36,9 +36,11 @@ ) import tvm.testing +# The test function uses an undefined symbolic var in Relax. +# In principle, this should be attached to an argument. # pylint: disable=no-self-argument,invalid-name,line-too-long,no-method-argument # fmt: off -@I.ir_module +@I.ir_module(check_well_formed=False) class Module: @T.prim_func def full1(var_T_full: T.handle): From aff00dec2e9b59a6b5e5fec6a616faff96457312 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 29 Feb 2024 22:12:17 -0500 Subject: [PATCH 45/58] Exempt well-formedness errors in test_ethosu/, mostly uninitialized vars --- .../test_copy_compute_reordering.py | 45 ++++++++++++------- .../test_ethosu/test_encode_constants.py | 32 ++++++++----- .../test_ethosu/test_merge_constants.py | 5 ++- .../test_ethosu/test_remove_concatenates.py | 5 ++- .../test_ethosu/test_replace_conv2d.py | 24 ++++++---- .../contrib/test_ethosu/test_replace_copy.py | 10 +++-- .../contrib/test_ethosu/test_scheduler.py | 3 +- .../contrib/test_ethosu/test_vela_api.py | 3 +- 8 files changed, 81 insertions(+), 46 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py index 1a00e01b6031..6b9702f012ca 100644 --- a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py +++ b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py @@ -22,8 +22,9 @@ from tvm.script import tir as T from tvm.relay.backend.contrib.ethosu.tir.passes import CopyComputeReordering +# Uninitialized vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class AllOperatorsWithWeights: @T.prim_func def main() -> None: @@ -70,8 +71,9 @@ def test_all_operators_with_weights_max_copy_movements_0(): def test_all_operators_with_weights_max_copy_movements_1(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -116,8 +118,9 @@ def main() -> None: def test_all_operators_with_weights_max_copy_movements_2(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -161,8 +164,9 @@ def main() -> None: tvm.ir.assert_structural_equal(test_mod, reference_mod, True) +# Uninitialized vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class AllOperatorsWithoutWeights: @T.prim_func def main() -> None: @@ -183,8 +187,9 @@ def test_all_operators_without_weights(max_copy_movements): tvm.ir.assert_structural_equal(test_mod, reference_mod, True) +# Uninitialized vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class OperatorsWithAndWithoutWeights: @T.prim_func def main() -> None: @@ -218,8 +223,9 @@ def test_operators_with_and_without_weights_max_copy_movements_0(): def test_operators_with_and_without_weights_max_copy_movements_1(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -251,8 +257,9 @@ def main() -> None: def test_operators_with_and_without_weights_max_copy_movements_2(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -283,8 +290,9 @@ def main() -> None: tvm.ir.assert_structural_equal(test_mod, reference_mod, True) +# Uninitialized vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class CopyToBufferWithLocalScope: @T.prim_func def main() -> None: @@ -324,8 +332,9 @@ def test_copy_to_buffer_with_local_scope_max_copy_movements_0(): @pytest.mark.parametrize("max_copy_movements", [1, 2]) def test_copy_to_buffer_with_local_scope_max_copy_movements_n(max_copy_movements): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -400,8 +409,9 @@ def abs(): def test_default_max_copy_movements(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -433,8 +443,9 @@ def main() -> None: def test_pass_context_option_max_copy_movements(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: @@ -469,8 +480,9 @@ def main() -> None: def test_reordering_based_on_cycles(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ModuleBefore: @T.prim_func def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None: @@ -518,7 +530,8 @@ def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208 T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 103, 106, 4, 103, 0, 106, ethosu_write_5[0], 0, 0, 0, T.float32(0.0057637207210063934), -128, "NHCWB16", 1696, 16, 1, "int8", 103, 106, 4, 103, 0, 106, ethosu_write[0], 0, 0, 0, T.float32(0.0057619437575340271), -128, "NHWC", 424, 4, 1, 3, 2, 1, 1, 2, 2, placeholder_d_global_3[0], 64, 0, placeholder_d_global_3[64], 48, 1, 2, 1, 2, "NONE", 0, 0, "TFL", "NONE", 14, 18, 8, dtype="handle")) - @tvm.script.ir_module + # Uninitialized vars used + @tvm.script.ir_module(check_well_formed=False) class ModuleAfter: @T.prim_func def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None: @@ -572,8 +585,9 @@ def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208 def test_reordering_based_on_cycles_luts_present(): + # Uninitialized vars used # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ModuleBefore: @T.prim_func def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None: @@ -623,7 +637,8 @@ def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208 T.evaluate(T.call_extern("ethosu_pooling", "int8", 105, 110, 4, 105, 0, 110, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1760, 16, 1, "int8", 105, 110, 4, 105, 0, 110, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 440, 4, 1, "MAX", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 4, 64, 8, dtype="handle")) - @tvm.script.ir_module + # Uninitialized vars used + @tvm.script.ir_module(check_well_formed=False) class ModuleAfter: @T.prim_func def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None: diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index b16fe2a85a34..8c35a43e47e9 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -32,8 +32,9 @@ from .infra import make_ethosu_binary_elementwise, make_ethosu_conv2d +# Uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class WeightStreamOnlyU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -62,7 +63,8 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# Uninitialized variables used +@tvm.script.ir_module(check_well_formed=False) class WeightStreamOnlyU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): @@ -140,15 +142,16 @@ def _get_func(): func = _get_func() mod, consts = _lower_to_tir(func, cascader=_planner) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] assert reference_const_sizes.sort() == test_const_size.sort() +# Uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class RereadWeightsU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -168,7 +171,8 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 304, T.int8(-1), T.int8(-1), 12, p2[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# Uninitialized variables used +@tvm.script.ir_module(check_well_formed=False) class RereadWeightsU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -239,15 +243,16 @@ def _get_func(): func = _get_func() mod, consts = _lower_to_tir(func, cascader=_cascader) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] assert reference_const_sizes.sort() == test_const_size.sort() +# Uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class DirectReadOnlyU55: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -266,7 +271,8 @@ def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_writ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# Uninitialized variables used +@tvm.script.ir_module(check_well_formed=False) class DirectReadOnlyU65: @T.prim_func def main(input_placeholder: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -335,15 +341,16 @@ def _get_func(): mod, consts = _lower_to_tir(func) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] assert reference_const_sizes.sort() == test_const_size.sort() +# Uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class MixedReadU55: @T.prim_func def main(input_ifm: T.Buffer((1,16,16,32), "int8"), input_ethosu_write: T.Buffer((1,16,16,8), "int8")) -> None: @@ -375,7 +382,8 @@ def main(input_ifm: T.Buffer((1,16,16,32), "int8"), input_ethosu_write: T.Buffer T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# Uninitialized variables used +@tvm.script.ir_module(check_well_formed=False) class MixedReadU65: @T.prim_func def main(ifm: T.Buffer((1, 16, 16, 32), "int8"), ethosu_write: T.Buffer((1, 16, 16, 8), "int8")): @@ -468,7 +476,7 @@ def _get_func(): mod, consts = _lower_to_tir(func, cascader=_planner) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) test_const_size = [value.size for value in list(consts.values())] diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 1c22f0b5c8b6..b989c15e065e 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -305,8 +305,9 @@ def main(buffer1: T.Buffer((64,), "uint8"), def test_no_copies(): + # the vars placeholder and ethosu_write are undefined # fmt: off - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main() -> None: @@ -320,7 +321,7 @@ def main() -> None: T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle")) T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main() -> None: diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index 56205bcd73d4..58cf5f72d7c0 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -28,7 +28,8 @@ # fmt: off -@tvm.script.ir_module +# complains of an undefined buffer +@tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(input_placeholder: T.Buffer((1,8,12,16), "int8"), input_placeholder_1: T.Buffer((1,8,10,16), "int8"), input_T_concat: T.Buffer((1,8,32,16), "int8")) -> None: @@ -74,7 +75,7 @@ def _get_func(): func = _get_func() mod, _ = _lower_to_tir(func) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) reference_mod = ReferenceModule tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index ff47343811f5..be529cdb32fa 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -363,8 +363,9 @@ def _visit(stmt): assert data[0] == answer, data[0] +# Undefined variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade1: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), "int8")) -> None: @@ -385,7 +386,8 @@ def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined variables used +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade2: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 8), "int8")) -> None: @@ -406,7 +408,8 @@ def main(input_placeholder_5: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined variables used +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade3: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 16, 16, 3), "int8"), input_ethosu_write_1: T.Buffer((1, 20, 4, 8), "int8")) -> None: @@ -430,7 +433,8 @@ def main(input_placeholder_5: T.Buffer((1, 16, 16, 3), "int8"), input_ethosu_wri T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined variables used +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade4: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 2, 8, 16), "int8")) -> None: @@ -451,7 +455,8 @@ def main(input_placeholder_5: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_w T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined variables used +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade5: @T.prim_func def main(input_placeholder: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write: T.Buffer((1, 32, 32, 8), "int8")) -> None: @@ -472,7 +477,8 @@ def main(input_placeholder: T.Buffer((1, 8, 8, 3), "int8"), input_ethosu_write: T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined variables used +@tvm.script.ir_module(check_well_formed=False) class Conv2dDoubleCascade6: @T.prim_func def main(input_placeholder: T.Buffer((1, 8, 1, 8, 16), "int8"), input_ethosu_write: T.Buffer((1, 32, 2, 32, 16), "int8")) -> None: @@ -634,7 +640,7 @@ def _get_func( func = _get_func(*params[:-1]) mod, _ = _lower_to_tir(func, cascader=total_cascader(params[-1])) script = mod.script() - mod = tvm.script.from_source(script) + mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -691,7 +697,7 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): func = _get_func(*params) mod, _ = _lower_to_tir(func) script = mod.script() - mod = tvm.script.from_source(script) + mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) @@ -789,7 +795,7 @@ def _get_func(ifm_shape, reshaped, ifm_layout): func = _get_func(*params) mod, _ = _lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) script = mod.script() - mod = tvm.script.from_source(script) + mod = tvm.script.from_source(script, check_well_formed=False) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 7329b40d5851..ff343517352d 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -30,8 +30,9 @@ from .infra import make_ethosu_conv2d +# uninitialized varaibles used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(input_placeholder_3: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 8), "int8")) -> None: @@ -68,13 +69,14 @@ def _get_func(): mod, _ = _lower_to_tir(func, cascader=copy_constants()) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) reference_mod = ReferenceModule tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) +# Uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class WeightStream: @T.prim_func def main(input_placeholder_5: T.Buffer((1, 16, 16, 32), "int8"), input_ethosu_write_1: T.Buffer((1, 16, 16, 16), "int8")) -> None: @@ -127,7 +129,7 @@ def _get_func(): mod, _ = _lower_to_tir(func, cascader=_cascader) script = mod.script() - test_mod = tvm.script.from_source(script) + test_mod = tvm.script.from_source(script, check_well_formed=False) reference_mod = WeightStream tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 74394fbd46bf..0b6f4a2629b7 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -211,8 +211,9 @@ def test_schedule_cache_reads(): assert list(sch[cr].iter_var_attrs[iv].pragma_values) == ["ethosu_copy"] +# uninitialized variables used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class DiamondGraphTir: @T.prim_func def main(input_placeholder: T.Buffer((1, 56, 56, 96), "int8"), input_ethosu_write: T.Buffer((1, 56, 56, 24), "int8")) -> None: diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index c70746964baf..7f4b5b8c7052 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -218,8 +218,9 @@ def main( ) +# Complains of the use of undefined vars # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class Module3: @T.prim_func def main(ethos_u_0_i0: T.Buffer((1, 299, 299, 2), "int8"), ethosu_write: T.Buffer((1, 299, 299, 3), "int8")): From a8e74155748599e70bfef6ebc265607bc9b1e458 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 1 Mar 2024 15:19:26 -0500 Subject: [PATCH 46/58] Whitespace --- tests/python/dlight/test_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/dlight/test_benchmark.py b/tests/python/dlight/test_benchmark.py index 5f8bdf49551e..695a0e90263d 100644 --- a/tests/python/dlight/test_benchmark.py +++ b/tests/python/dlight/test_benchmark.py @@ -36,7 +36,7 @@ ) import tvm.testing -# The test function uses an undefined symbolic var in Relax. +# The test function uses an undefined symbolic var in Relax. # In principle, this should be attached to an argument. # pylint: disable=no-self-argument,invalid-name,line-too-long,no-method-argument # fmt: off From f7721907f7edc2c27b2618727f0f1c3037b86462 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 5 Mar 2024 18:55:54 -0500 Subject: [PATCH 47/58] Include non-CUDA GPUs in IsScheduledOnGPU --- src/tir/transforms/default_gpu_schedule.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 6cf7f6e06743..6d0542257309 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -113,7 +113,8 @@ bool IsScheduledOnGPU(const BaseFunc& func) { if (target.defined()) { int dev_type = target->GetTargetDeviceType(); - if (dev_type != kDLCUDA) { + if (!(dev_type == kDLCUDA || dev_type == kDLMetal || dev_type == kDLROCM || + dev_type == kDLWebGPU)) { return false; } } From f82ca07cea703da09ece34509e8337fb20ff3512 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 5 Mar 2024 19:25:04 -0500 Subject: [PATCH 48/58] Fix thread binding bug by changing thread binding var dtype --- src/tir/ir/data_type_rewriter.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 2d2c097be494..3461597b8e0f 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -532,6 +532,12 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { n->loop_var = new_loop_var; n->min = cast(new_loop_var.dtype(), min); n->extent = cast(new_loop_var.dtype(), extent); + if (op->thread_binding.defined()) { + auto old_thread_binding = op->thread_binding.value(); + auto* ptr = old_thread_binding.CopyOnWrite(); + ptr->var = old_thread_binding->var.copy_with_dtype(new_loop_var.dtype()); + n->thread_binding = std::move(Optional(std::move(old_thread_binding))); + } n->body = new_body; return std::move(new_for); } else { From 2b7c4caa6ccbe4cfd758922a3e5640221e2fd09f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 5 Mar 2024 22:50:22 -0500 Subject: [PATCH 49/58] Include overrides in test_runtime_builtin_paged_attention_kv_cache.py --- ...me_builtin_paged_attention_kv_cache_tir.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 64887ca5b653..c33686d16e77 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -566,7 +566,8 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): def kv_cache_transpose_append(head_dim, dtype): - @T.prim_func + # undefined vars used + @T.prim_func(check_well_formed=False) def _kv_cache_transpose_append( var_pages: T.handle, var_k_data: T.handle, @@ -604,7 +605,8 @@ def _kv_cache_transpose_append( def copy_cache(head_dim, dtype): - @T.prim_func + # undefined vars used + @T.prim_func(check_well_formed=False) def _copy_cache( var_pages: T.handle, var_position_map: T.handle, @@ -677,7 +679,8 @@ def _rope( # pylint: disable=too-many-arguments ) return cos + sin - @T.prim_func(private=True) + # undefined vars used + @T.prim_func(private=True, check_well_formed=False) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, @@ -852,9 +855,10 @@ def _attention_prefill( tile_z = 8 num_warps = 2 + # undefined vars used # pylint: disable=line-too-long,too-many-arguments,too-many-branches # fmt: off - @T.prim_func + @T.prim_func(check_well_formed=False) def batch_prefill_paged_kv( _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d] @@ -1214,9 +1218,10 @@ def _attention_decode( tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 log2e = math.log2(math.exp(1)) + # undefined vars used # pylint: disable=line-too-long,too-many-arguments,too-many-branches # fmt: off - @T.prim_func + @T.prim_func(check_well_formed=False) def batch_decode_paged_kv( _0: T.int32, # pylint: disable=unused-argument Q_handle: T.handle, @@ -1457,9 +1462,10 @@ def _attention_prefill_ragged( tile_z = 8 num_warps = 2 + # undefined vars used # fmt: off - @T.prim_func - def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches + @T.prim_func(check_well_formed=False) + def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches var_q: T.handle, # [total_len, h_q, d] var_q_indptr: T.handle, # [batch_size + 1] var_k: T.handle, # [total_len, h_kv, d] @@ -1775,7 +1781,8 @@ def _merge_state_inplace( bdy //= 2 gdy = num_heads // bdy - @T.prim_func + # undefined vars used + @T.prim_func(check_well_formed=False) def merge_state_inplace( v: T.handle, s: T.handle, From d2541775617b1ef1decb941b0db51b585041a63e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Mar 2024 17:51:06 -0500 Subject: [PATCH 50/58] add exemptions in test_ethosu/test_replace_conv2d --- .../contrib/test_ethosu/test_replace_conv2d.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index be529cdb32fa..a8aa4043293f 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -644,8 +644,9 @@ def _get_func( tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) +# Undefined vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineCopy1: @T.prim_func def main(input_placeholder_3: T.Buffer((1, 10, 12, 8), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 8, 16), "int8")) -> None: @@ -659,7 +660,8 @@ def main(input_placeholder_3: T.Buffer((1, 10, 12, 8), "int8"), input_ethosu_wri T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# Undefined vars used +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineCopy2: @T.prim_func def main(input_placeholder_3: T.Buffer((1, 7, 9, 5), "int8"), input_ethosu_write_1: T.Buffer((1, 3, 5, 16), "int8")) -> None: @@ -701,8 +703,9 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) +# Undefined vars used # fmt: off -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineReshape1: @T.prim_func def main(input_placeholder_3: T.Buffer((4, 6, 8, 1), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 6, 16), "int8")) -> None: @@ -717,7 +720,8 @@ def main(input_placeholder_3: T.Buffer((4, 6, 8, 1), "int8"), input_ethosu_write T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined vars used +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineReshape2: @T.prim_func def main(input_placeholder_3: T.Buffer((1, 24, 8), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 6, 16), "int8")) -> None: @@ -732,7 +736,8 @@ def main(input_placeholder_3: T.Buffer((1, 24, 8), "int8"), input_ethosu_write_1 T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined vars used +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineReshape3: @T.prim_func def main(input_placeholder_3: T.Buffer((192, 1), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 6, 16), "int8")) -> None: @@ -747,7 +752,8 @@ def main(input_placeholder_3: T.Buffer((192, 1), "int8"), input_ethosu_write_1: T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) -@tvm.script.ir_module +# undefined vars used +@tvm.script.ir_module(check_well_formed=False) class Conv2dInlineReshape4: @T.prim_func def main(placeholder_3: T.Buffer((192,), "int8"), input_ethosu_write_1: T.Buffer((1, 8, 6, 16), "int8")) -> None: From 3247cb5711230861d1dd90f2296640d4219894cf Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Mar 2024 17:56:03 -0500 Subject: [PATCH 51/58] Add more ethosu exemptions --- .../test_ethosu/test_merge_constants.py | 9 ++++---- .../test_ethosu/test_tir_to_cs_translator.py | 21 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index b989c15e065e..8bb3cd70eec2 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -16,11 +16,11 @@ # under the License. import pytest -pytest.importorskip("ethosu.vela") +#pytest.importorskip("ethosu.vela") import tvm from tvm.script import tir as T -from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants +#from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants import numpy as np @@ -637,7 +637,8 @@ def main(input_placeholder: T.Buffer((1,16,16,32), "int8"), buffer1: T.Buffer((4 def test_cycle_count(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,), "uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,), "uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,), "uint8")) -> None: @@ -700,7 +701,7 @@ def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), "uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,), "uint8")) -> None: diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 9a61be6b975d..69076f5337c8 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -29,8 +29,9 @@ # fmt: off +# Undefined vars used """A sample tir test case for translator""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class SingleEthosUConv2D: @T.prim_func def main(placeholder_3: T.Buffer((8192,), "int8"), ethosu_conv2d_1: T.Buffer((1024,), "int8")) -> None: @@ -44,8 +45,9 @@ def main(placeholder_3: T.Buffer((8192,), "int8"), ethosu_conv2d_1: T.Buffer((10 # fmt: off +# undefined vars used """A sample tir test case with multiple convolutions for translator""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class MultiEthosUConv2D: @T.prim_func def main(placeholder_6: T.Buffer((192,), "int8"), ethosu_conv2d_1: T.Buffer((512,), "int8")) -> None: @@ -66,8 +68,9 @@ def main(placeholder_6: T.Buffer((192,), "int8"), ethosu_conv2d_1: T.Buffer((512 # fmt: off +# undefined vars used """A sample tir test case with copy operations for translator""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class MultiEthosUCopy: @T.prim_func def main(placeholder_3: T.Buffer((8192,), "int8"), ethosu_conv2d_1: T.Buffer((2048,), "int8")) -> None: @@ -85,8 +88,9 @@ def main(placeholder_3: T.Buffer((8192,), "int8"), ethosu_conv2d_1: T.Buffer((20 # fmt: off +# undefined vars used """A tir test case with copy operation having a buffer size less than the minimum for a DMA operation""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class CopyLessMinimal: @T.prim_func def main(ethos_u_0_i0: T.Buffer((1, 4), "int8"), ethosu_write: T.Buffer((1, 4), "int8")): @@ -105,8 +109,9 @@ def main(ethos_u_0_i0: T.Buffer((1, 4), "int8"), ethosu_write: T.Buffer((1, 4), # fmt: off +# undefined vars used """A TIR test module of weight streaming""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class WeightStreamOnly: @T.prim_func def main(placeholder: T.Buffer((8192,), "int8"), ethosu_write: T.Buffer((2048,), "int8")) -> None: @@ -150,8 +155,9 @@ def main(placeholder: T.Buffer((8192,), "int8"), ethosu_write: T.Buffer((2048,), # fmt: off +# undefined vars used """A TIR test module of weight streaming and direct reading""" -@tvm.script.ir_module +@tvm.script.ir_module(check_well_formed=False) class MixedRead: @T.prim_func def main(placeholder: T.Buffer((8192,), "int8"), ethosu_write: T.Buffer((2048,), "int8")) -> None: @@ -703,7 +709,8 @@ def populate_ethosu_copy_calls(stmt): # fmt: off -@tvm.script.ir_module +# undefined vars used +@tvm.script.ir_module(check_well_formed=False) class MixedConstantDatatypes: @T.prim_func def main(placeholder_4: T.Buffer((2048,), "int8"), ethosu_write_1: T.Buffer((16,), "int8")) -> None: From 9cd556dcc65781a6ff9763cc50bb195dddab79a6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Mar 2024 17:59:57 -0500 Subject: [PATCH 52/58] More exemptions for ethosu tests --- .../test_ethosu/test_merge_constants.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 8bb3cd70eec2..5c5cd960e5d3 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -16,11 +16,11 @@ # under the License. import pytest -#pytest.importorskip("ethosu.vela") +pytest.importorskip("ethosu.vela") import tvm from tvm.script import tir as T -#from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants +from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants import numpy as np @@ -35,7 +35,8 @@ def check_const_dictionaries(const_dict, new_const_dict): def test_only_one_operator(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8")) -> None: @@ -53,7 +54,8 @@ def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer((160,), "uint8")) -> None: @@ -80,7 +82,8 @@ def main(buffer2: T.Buffer((160,), "uint8")) -> None: def test_all_operators_with_weights(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,), "uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,), "uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,), "uint8")) -> None: @@ -119,7 +122,8 @@ def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), "uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,), "uint8")) -> None: @@ -170,7 +174,8 @@ def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), "uint8"), def test_operators_with_and_without_weights(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer2: T.Buffer((80,), "uint8"), buffer3: T.Buffer((64,), "uint8")) -> None: @@ -189,7 +194,8 @@ def main(buffer2: T.Buffer((80,), "uint8"), buffer3: T.Buffer((64,), "uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer((144,), "uint8")) -> None: @@ -218,7 +224,8 @@ def main(buffer2: T.Buffer((144,), "uint8")) -> None: def test_copy_to_buffer_with_local_scope(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer1: T.Buffer((64,), "uint8"), @@ -255,7 +262,8 @@ def main(buffer1: T.Buffer((64,), "uint8"), T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p5[0], 16, 0, p6[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer1: T.Buffer((64,), "uint8"), @@ -346,7 +354,8 @@ def main() -> None: def test_copies_to_the_same_buffer(): # fmt: off - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8")) -> None: @@ -367,7 +376,8 @@ def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - @tvm.script.ir_module + # undefined vars used + @tvm.script.ir_module(check_well_formed=False) class ReferenceModule: @T.prim_func def main(buffer2: T.Buffer((160,), "uint8")) -> None: From 180018879bf34769ca951bf2486859b717d3b469 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 13 Mar 2024 20:33:27 -0400 Subject: [PATCH 53/58] Remove unused reference --- python/tvm/relax/frontend/nn/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 477e0fe882f9..b61656a2e6bd 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -19,7 +19,7 @@ from typing import List, Optional, Sequence, Union from tvm import relax as rx -from tvm import tir, ir +from tvm import tir from . import op from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype From c54267beecdbd2244127d36979bab5d0f35f629c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 14 Mar 2024 19:21:45 -0400 Subject: [PATCH 54/58] Indicate purity in test_transform_rewrite_cuda_graph --- tests/python/relax/test_transform_rewrite_cuda_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index dc115939a7e4..91b3fce2640a 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -709,7 +709,7 @@ def main(): def test_static_args(): @I.ir_module class Before: - @R.function + @R.function(pure=False) def main(): storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32") @@ -734,7 +734,7 @@ def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple: gv: R.Tuple = R.tuple() return gv - @R.function + @R.function(pure=False) def main() -> R.Tuple: cls = Expected gv: R.Tuple(R.Object) = R.call_builtin_with_ctx( From 3684fd3d2ac9f3b97022374495ddc7c328692f5c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 14 Mar 2024 19:22:55 -0400 Subject: [PATCH 55/58] Indicate purity in test_transform_normalize --- tests/python/relax/test_transform_normalize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index f37df4d07969..335ca7c70a12 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -571,11 +571,11 @@ def test_remove_usage_of_void_type_variables(): relax.VarBinding(x, R.assert_op(R.const(True, "bool"))), ] seq = relax.SeqExpr([relax.BindingBlock(bindings)], x) - before = relax.Function([], seq, ret_struct_info=R.Tuple([])) + before = relax.Function([], seq, ret_struct_info=R.Tuple([]), is_pure=False) after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"] - @R.function(private=True) + @R.function(private=True, pure=False) def expected(): x = R.assert_op(R.const(True, "bool")) return R.tuple() From e77ef6e77847a7990de922b0320999af24fa15f2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 18 Mar 2024 19:52:49 -0400 Subject: [PATCH 56/58] Reorder MergeSharedMemoryAllocations in GPU codegen --- src/driver/driver_api.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 33b4514e6b29..e3b4a5a6517c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -590,6 +590,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); + mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); @@ -607,9 +608,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - // MergeSharedMemoryAllocations must be applied after SplitHostDevice - // because the merged allocation site is at the beginning of each device function - mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) .value_or(relay::Executor::Create("graph", {})) From c8fb78cdad98ab71cfccf631635afab2c35174f4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 19 Mar 2024 19:47:52 -0400 Subject: [PATCH 57/58] Add target parameter for FP8StorageLegalize and FP8ComputeLegalize --- python/tvm/tir/transform/transform.py | 17 +++++++++++++---- .../test_tir_transform_fp8_legalize.py | 9 ++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c2022b918643..66472010ed50 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -23,6 +23,7 @@ from . import _ffi_api from . import function_pass as _fpass +from ...target import Target def Apply(ftransform): @@ -323,7 +324,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(target: Target, promote_dtype_str: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -331,12 +332,15 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"): promote_dtype : str The data type we promote fp8 to, options: float16/float32. + target : tvm.target.Target + The legalization target + Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type: ignore def BF16StorageLegalize(): @@ -350,15 +354,20 @@ def BF16StorageLegalize(): return _ffi_api.BF16StorageLegalize() # type: ignore -def FP8StorageLegalize(): +def FP8StorageLegalize(target: Target): """Legalize fp8 storage types to u8. + Parameters + ---------- + target : tvm.target.Target + The legalization target + Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8StorageLegalize() # type: ignore + return _ffi_api.FP8StorageLegalize(target) # type: ignore def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py index f5786808a6f3..6e44b53d0cae 100644 --- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py @@ -17,6 +17,7 @@ import tvm import tvm.script import tvm.testing +from tvm.target import Target from tvm.script import tir as T # pylint: disable=no-member,invalid-name,unused-variable @@ -204,18 +205,20 @@ def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8 def test_fp8_compute_legalize(dtype, promote_dtype): + target = Target("cuda") before = get_before(dtype) expected = get_after_compute_legalize(dtype, promote_dtype) # run the transform twice to ensure we can afford to deal # with this repeative optimizations - after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before) - after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after) + after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before) + after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after) tvm.ir.assert_structural_equal(after, expected) def test_fp8_storage_legalize(dtype, promote_dtype): + target = Target("cuda") before = get_after_compute_legalize(dtype, promote_dtype) - after = tvm.tir.transform.FP8StorageLegalize()(before) + after = tvm.tir.transform.FP8StorageLegalize(target)(before) expected = get_after_storage_legalize(dtype, promote_dtype) tvm.ir.assert_structural_equal(after, expected) From 33958089c199a80392a33ac8d6de9f682371f658 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 20 Mar 2024 20:38:46 -0400 Subject: [PATCH 58/58] Don't re-import Target in tvm/tir/transform/transform.py --- python/tvm/tir/transform/transform.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 66472010ed50..9f7f92dbed74 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -19,11 +19,10 @@ import enum -from typing import Callable, Optional +from typing import Any, Callable, Optional from . import _ffi_api from . import function_pass as _fpass -from ...target import Target def Apply(ftransform): @@ -324,7 +323,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(target: Target, promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -354,7 +353,7 @@ def BF16StorageLegalize(): return _ffi_api.BF16StorageLegalize() # type: ignore -def FP8StorageLegalize(target: Target): +def FP8StorageLegalize(target: Any): """Legalize fp8 storage types to u8. Parameters