From 77533199db9708e190be826e473c6e76b0815110 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 26 Jul 2024 11:38:24 -0500
Subject: [PATCH 1/2] [Relax][Analysis] Validate global_symbol on non-Relax functions

Prior to this commit, the well-formed checker verified that the
`"global_symbol"` attribute, if present, matches the name of the
`GlobalVar`.  However, this check was only applied to Relax functions.
As a result, discrepancies between the `"global_symbol"` and the
`gvar->name_hint` could result in unexpected bugs.  (For example,
https://github.com/apache/tvm/issues/17176.)

This commit updates the well-formed checker to verify `"global_symbol"`
on all functions in an `IRModule`.
---
 src/relax/analysis/well_formed.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 7688c4a64291..6c0394b9783c 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -90,11 +90,11 @@ class WellFormedChecker : public relax::ExprVisitor,
         WellFormedChecker(obj.as<IRModule>(), check_struct_info);
 
     if (const auto* mod = obj.as<IRModuleNode>()) {
-      for (const auto& it : mod->functions) {
+      for (const auto& [gvar, base_func] : mod->functions) {
+        well_formed_checker.CheckGlobalVarAndGsymbolConsistency(gvar, base_func);
         // visit relax.Function
-        if (auto* n = it.second.as<FunctionNode>()) {
-          Function func = GetRef<Function>(n);
-          well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
+        if (auto opt = base_func.as<Function>()) {
+          Function func = opt.value();
           well_formed_checker.VisitExpr(func);
         }
       }
@@ -133,7 +133,7 @@ class WellFormedChecker : public relax::ExprVisitor,
     LOG(WARNING) << "This IR is not well formed: " << diag->message;
   }
 
-  void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) {
+  void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, BaseFunc func) {
     // the uniqueness of all global vars are ensured by IRModule->global_var_map_, so do not need
     // to check again

From ed94359576eab728615e55d67441bcb4b49199a6 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 16 Sep 2024 15:30:11 -0500
Subject: [PATCH 2/2] fix unit tests

---
 tests/python/tir-base/test_debug_info.py      |   5 +-
 tests/python/tir-base/test_tir_host_func.py   |   4 +-
 tests/python/tir-base/test_tir_intrin.py      | 107 +++++++++---------
 ...test_tir_transform_device_kernel_launch.py |   8 +-
 ...form_manifest_shared_memory_local_stage.py |   4 +-
 .../test_transform_default_gpu_schedule.py    |  14 +--
 tests/python/tir-usmp/test_tir_usmp_algo.py   |   2 +-
 ...st_tir_usmp_analysis_extract_bufferinfo.py |  11 +-
 8 files changed, 76 insertions(+), 79 deletions(-)

diff --git a/tests/python/tir-base/test_debug_info.py b/tests/python/tir-base/test_debug_info.py
index ecd25b3a6749..da92cc5a40dd 100644
--- a/tests/python/tir-base/test_debug_info.py
+++ b/tests/python/tir-base/test_debug_info.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
"""Test line-level debug info for TIR""" + import tvm import tvm.testing from tvm import tir @@ -104,7 +105,7 @@ def find_span(m): class module_before: @T.prim_func def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")}) + T.func_attr({"tir.noalias": True, "target": T.target("llvm")}) A = T.match_buffer(a, (8,), dtype="float32") B = T.match_buffer(b, (8,), dtype="float32") for i in range(8): @@ -114,7 +115,7 @@ def main(a: T.handle, b: T.handle): @T.prim_func def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) A = T.decl_buffer(1, "float32", data=a_ptr) B = T.decl_buffer(1, "float32", data=b_ptr) B[0] = A[1] + 1.0 diff --git a/tests/python/tir-base/test_tir_host_func.py b/tests/python/tir-base/test_tir_host_func.py index ed04985bdda1..7cce4db962ae 100644 --- a/tests/python/tir-base/test_tir_host_func.py +++ b/tests/python/tir-base/test_tir_host_func.py @@ -33,7 +33,6 @@ def main( ): T.func_attr( { - "global_symbol": "test", "target": tvm.target.Target("llvm", host="llvm"), "tir.noalias": True, } @@ -59,12 +58,11 @@ def test_host_func(): func = tvm.te.create_prim_func( te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32") ) - mod = tvm.ir.IRModule({"main": func}) + mod = tvm.ir.IRModule({"main": func.with_attr("global_symbol", "main")}) target = tvm.target.Target("cuda") mod = tvm.tir.transform.Apply( lambda f: f.with_attr( { - "global_symbol": "test", "tir.is_host_func": 1, } ) diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 1ee709191c41..ef01889fafc1 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -19,7 +19,7 @@ from tvm import te, tir from tvm import topi from tvm.contrib import utils, clang -from tvm.script import tir as T +from tvm.script import ir as I, tir as T import numpy as np import ctypes import math @@ -187,59 +187,60 @@ def clz_np(x, dtype): np.testing.assert_equal(b.numpy(), ref) -@tvm.script.ir_module -class Module: - @T.prim_func - def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) - n = T.int32() - stride = T.int32() - stride_1 = T.int32() - stride_2 = T.int32() - stride_3 = T.int32() - A_1 = T.match_buffer( - A, - [n], - strides=[stride], - elem_offset=0, - align=64, - offset_factor=1, - buffer_type="auto", - ) - B_1 = T.match_buffer( - B, - [n], - strides=[stride_1], - elem_offset=0, - align=64, - offset_factor=1, - buffer_type="auto", - ) - C_1 = T.match_buffer( - C, - [n], - strides=[stride_2], - elem_offset=0, - align=64, - offset_factor=1, - buffer_type="auto", - ) - d_1 = T.match_buffer( - d, - [n], - strides=[stride_3], - elem_offset=0, - align=64, - offset_factor=1, - buffer_type="auto", - ) - # body - for i in T.serial(0, n): - d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[(i * stride_2)] - - def test_fma(): + @I.ir_module + class Module: + @T.prim_func + def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + n = T.int32() + stride = T.int32() + stride_1 = T.int32() + stride_2 = T.int32() + stride_3 = T.int32() + A_1 = T.match_buffer( + A, + [n], + strides=[stride], + elem_offset=0, + align=64, + offset_factor=1, + 
buffer_type="auto", + ) + B_1 = T.match_buffer( + B, + [n], + strides=[stride_1], + elem_offset=0, + align=64, + offset_factor=1, + buffer_type="auto", + ) + C_1 = T.match_buffer( + C, + [n], + strides=[stride_2], + elem_offset=0, + align=64, + offset_factor=1, + buffer_type="auto", + ) + d_1 = T.match_buffer( + d, + [n], + strides=[stride_3], + elem_offset=0, + align=64, + offset_factor=1, + buffer_type="auto", + ) + # body + for i in T.serial(0, n): + d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[ + (i * stride_2) + ] + opt = tvm.transform.Sequential( [ tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))), diff --git a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py index 34cde4e4b6ce..adb920d11fec 100644 --- a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py +++ b/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py @@ -43,7 +43,7 @@ def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("llvm")}) mod.kernel(A.data) - @T.prim_func + @T.prim_func(private=True) def kernel(A_data: T.handle("float32")): T.func_attr({"target": T.target("cuda")}) A = T.decl_buffer(1, dtype="float32", data=A_data) @@ -66,7 +66,6 @@ def kernel(A_data: T.handle("float32")): "target": T.target("cuda"), "calling_conv": 2, "tir.kernel_launch_params": [], - "global_symbol": "kernel", "tir.is_global_func": True, } ) @@ -99,7 +98,7 @@ def main(A: T.Buffer(1, "float32")): @T.prim_func def kernel(A_data: T.handle("float32")): - T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"}) + T.func_attr({"target": T.target("cuda")}) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 @@ -111,7 +110,7 @@ class mod: @T.prim_func def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("llvm")}) - T.call_packed("kernel_by_another_name", A.data) + T.call_packed("kernel", A.data) @T.prim_func def kernel(A_data: T.handle("float32")): @@ -120,7 +119,6 @@ def kernel(A_data: T.handle("float32")): "target": T.target("cuda"), "calling_conv": 2, "tir.kernel_launch_params": [], - "global_symbol": "kernel_by_another_name", "tir.is_global_func": True, } ) diff --git a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py index 15d7118fb8a9..ef8615393f7a 100644 --- a/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/tir-transform/test_tir_transform_manifest_shared_memory_local_stage.py @@ -28,7 +28,7 @@ class MatmulBefore: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # body # with T.block("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): @@ -69,7 +69,7 @@ class MatmulAfter: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # body # with T.block("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): diff --git 
diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
index 0a648338490c..6b31e8068021 100644
--- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
+++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
@@ -77,7 +77,7 @@ def matmul(
         B: T.Buffer((32, 32), "float16"),
         C: T.Buffer((32, 32), "float16"),
     ):
-        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        T.func_attr({"tir.noalias": True})
         # with T.block("root"):
         for i, j, k in T.grid(32, 32, 32):
             with T.block("C"):
@@ -94,8 +94,7 @@ def matmul_gpu(
         B: T.Buffer((32, 32), "float16"),
         C: T.Buffer((32, 32), "float16"),
     ):
-        T.func_attr({"global_symbol": "main",
-                     "target": T.target({"arch": "sm_86",
+        T.func_attr({"target": T.target({"arch": "sm_86",
                                          "keys": ["cuda", "gpu"],
                                          "kind": "cuda",
                                          "max_num_threads": 1024,
                                          "tag": "",
                                          "thread_warp_size": 32}),
                     "tir.noalias": True})
         # with T.block("root"):
         for i, j, k in T.grid(32, 32, 32):
             with T.block("C"):
@@ -118,8 +117,7 @@ def matmul_cpu(
         B: T.Buffer((32, 32), "float16"),
         C: T.Buffer((32, 32), "float16"),
     ):
-        T.func_attr({"global_symbol": "main",
-                     "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
+        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
                     "tir.noalias": True})
         # with T.block("root"):
         for i, j, k in T.grid(32, 32, 32):
@@ -139,7 +137,7 @@ def matmul(
         B: T.Buffer((32, 32), "float16"),
         C: T.Buffer((32, 32), "float16"),
     ):
-        T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", "tir.noalias": True})
+        T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
         # with T.block("root"):
         for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
             for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
@@ -160,7 +158,7 @@ def matmul(
 
     @T.prim_func
     def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
-        T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
         # with T.block("root"):
         for i, j, k in T.grid(32, 32, 32):
             with T.block("C"):
@@ -173,7 +171,7 @@ def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"
 
     @T.prim_func
     def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
-        T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
+        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
         # with T.block("root"):
         for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
             for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
diff --git a/tests/python/tir-usmp/test_tir_usmp_algo.py b/tests/python/tir-usmp/test_tir_usmp_algo.py
index b9cfde485633..4105730803fb 100644
--- a/tests/python/tir-usmp/test_tir_usmp_algo.py
+++ b/tests/python/tir-usmp/test_tir_usmp_algo.py
@@ -350,7 +350,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
     @T.prim_func
     def run_model(input: T.handle, output: T.handle) -> None:
         # function attr dict
"tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) diff --git a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py index f8da0ef9f42d..8aa1ed3f2abb 100644 --- a/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/tir-usmp/test_tir_usmp_analysis_extract_bufferinfo.py @@ -105,6 +105,7 @@ def _assign_targets_to_primfuncs_irmodule(mod, target): # These are test IRModules that contains varied topologies of operator graphs # that includes a main TIR function that includes call to such operators. + # fmt: off @tvm.script.ir_module class LinearStructure: @@ -163,7 +164,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -238,7 +239,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -278,7 +279,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -618,7 +619,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol @T.prim_func def run_model(input: T.handle, output: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) @@ -1334,7 +1335,7 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27 @T.prim_func def run_model(data: T.handle, output: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + T.func_attr({"runner_function": True}) data_buffer = T.match_buffer(data, [864], dtype="float32", align=16) output_buffer = T.match_buffer(output, [864], dtype="float32", align=16) # body