From 04a97a604b2610cf8240417272815b75e31a2d65 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 9 Nov 2022 15:33:10 -0800 Subject: [PATCH] [Minor][Testing] Consolidate IRs into corresponding functions We moved most of the IR definition into the testing methods correspondingly. Co-authored-by: Yaxing Cai --- python/tvm/testing/__init__.py | 2 - python/tvm/testing/tir.py | 45 +- .../unittest/test_tvmscript_error_report.py | 710 ++++++++---------- .../unittest/test_tvmscript_syntax_sugar.py | 13 +- 4 files changed, 330 insertions(+), 440 deletions(-) diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index 9a18f1689100..d84846725ec4 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -28,7 +28,5 @@ from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation from .popen_pool import timeout_job -from .tir import check_error - from . import auto_scheduler from . import autotvm diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index 8dd482673829..57c1a85c5b9f 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -16,49 +16,6 @@ # under the License. # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Common utility functions in TVM tir""" -import inspect -import re -import tvm -from tvm.ir.diagnostics import override_renderer - - -CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$") - - -def check_error(func, rel_lineno): - """check if TIR script throws error""" - # Override the default renderer to accumulate errors - errors = [] - - def render(e): - for d in e.diagnostics: - errors.append(d) - - override_renderer(render) - # The diagnostic context throws an exception when it gets an error - try: - source_code = inspect.getsource(func) - source_code = "@T.prim_func\n" + source_code - from tvm.script import from_source - - # to avoid cyclic import - from_source(source_code) - except tvm.error.DiagnosticError as e: - pass - assert len(errors) == 1, errors - for d in errors: - assert ( - d.span.line - 1 == rel_lineno - ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" - - error_line = source_code.split("\n")[rel_lineno] - m = CHECK_ERROR_RE.match(error_line) - if m: - expected_error_text = m.group(1) - errors = [e.message for e in errors] - assert ( - expected_error_text in errors - ), f'check_error expects "{expected_error_text} in str(errors): {errors}' def mma_schedule( @@ -80,6 +37,8 @@ def mma_schedule( shared_scope="shared", ): """Create a tensorized schedule for GEMM with MMA intrinsics.""" + import tvm # pylint: disable=import-outside-toplevel + ir_module = tvm.IRModule({"main": workload}) sch = tvm.tir.Schedule(ir_module) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index acc68af065dd..36de35fa928b 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -14,310 +14,304 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import inspect +import re import pytest -import sys import tvm +import tvm.testing from tvm import tir -from tvm.testing import check_error +from tvm.ir.diagnostics import override_renderer +from tvm.script import from_source from tvm.script import tir as T -def buffer_bind_missing_args(a: T.handle) -> None: - A = T.match_buffer((16, 16), "float32") # error +def check_error(func, rel_lineno): + check_error_re = re.compile(r"^.*# check_error: (.+)$") + """check if TIR script throws error""" + # Override the default renderer to accumulate errors + errors = [] + + def render(e): + for d in e.diagnostics: + errors.append(d) + + override_renderer(render) + # The diagnostic context throws an exception when it gets an error + try: + source_code = inspect.getsource(func) + indent = len(re.match(r"^\s*", source_code).group(0)) + source_code = "@T.prim_func\n" + "\n".join( + line[indent:] for line in source_code.splitlines() + ) + from_source(source_code) + except tvm.error.DiagnosticError as e: + pass + assert len(errors) == 1, errors + if rel_lineno is None: + return + error = errors[0] + assert ( + error.span.line - 1 == rel_lineno + ), f"Expected error to be on line {rel_lineno}, but it was on {error.span.line - 1}" + + error_line = source_code.split("\n")[rel_lineno] + m = check_error_re.match(error_line) + if m: + expected_error_text = m.group(1) + error = error.message + assert ( + expected_error_text == error + ), f'check_error expects "{expected_error_text} in str(errors): {error}' def test_buffer_bind(): - check_error(buffer_bind_missing_args, 2) - + def buffer_bind_missing_args(a: T.handle) -> None: + A = T.match_buffer((16, 16), "float32") # error -def undefined_buffer(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - - T.attr(A, "realize_scope", "") - T.realize(C[0:16, 0:16], "") # error - for i in T.serial(16): - for j in T.serial(0, 16): - A[i, j] = 0.0 + check_error(buffer_bind_missing_args, 2) def test_undefined_buffer(): - check_error(undefined_buffer, 5) + def undefined_buffer(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.attr(A, "realize_scope", "") + T.realize(C[0:16, 0:16], "") # error + for i in T.serial(16): + for j in T.serial(0, 16): + A[i, j] = 0.0 -def unsupported_stmt(a: T.int32) -> None: - if a > 0: - print("I love tvm") # error + check_error(undefined_buffer, 5) def test_unsupported_stmt(): - check_error(unsupported_stmt, 3) - + def unsupported_stmt(a: T.int32) -> None: + if a > 0: + print("I love tvm") # error -def unsupported_function_call(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - - T.attr(A, "realize_scope", "") - T.realize(A[0:16, 0:16], "") - for i in T.const_range(16): # error - for j in T.serial(0, 16): - A[i, j] = 0.0 + check_error(unsupported_stmt, 3) def test_unsupported_function_call(): - check_error(unsupported_function_call, 6) + def unsupported_function_call(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.attr(A, "realize_scope", "") + T.realize(A[0:16, 0:16], "") + for i in T.const_range(16): # error + for j in T.serial(0, 16): + A[i, j] = 0.0 -def missing_type_annotation(a) -> None: # error - T.evaluate(0.0) + check_error(unsupported_function_call, 6) def test_missing_type_annotation(): - check_error(missing_type_annotation, 1) - - -def invalid_expr_stmt() -> None: - T.max(1, 2) # error - - -def test_invalid_expr_stmt(): - check_error(invalid_expr_stmt, 2) - - -def invalid_for_function(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") + def missing_type_annotation(a) -> None: # error + T.evaluate(0.0) - for i in T.evaluate(0.0): # error - for j in T.serial(0, 16): - A[i, j] = 0.0 + check_error(missing_type_annotation, 1) def test_invalid_for_function(): - check_error(invalid_for_function, 4) + def invalid_for_function(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + for i in T.evaluate(0.0): # error + for j in T.serial(0, 16): + A[i, j] = 0.0 -def invalid_block_function(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - - with T.evaluate(0.0): # error - T.evaluate(1.0) + check_error(invalid_for_function, 4) def test_invalid_block_function(): - check_error(invalid_block_function, 4) + def invalid_block_function(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + with T.evaluate(0.0): # error + T.evaluate(1.0) -def return_not_allowed(a: T.handle) -> None: - return T.evaluate(0) # error + check_error(invalid_block_function, 4) def test_return_not_allowed(): - check_error(return_not_allowed, 2) + def return_not_allowed(a: T.handle) -> None: + return T.evaluate(0) # error - -def tir_assert(a: T.handle) -> None: - T.Assert(0, "") # error - - -def test_tir_assert(): - check_error(tir_assert, 2) - - -def no_body(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - T.realize(A, "") # error + check_error(return_not_allowed, 2) def test_no_body(): - check_error(no_body, 3) + def no_body(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.realize(A, "") # error - -def allocate_with_buffers() -> None: - with T.allocate([1], "float32", "") as [A, B]: # error - T.evaluate(1.0) + check_error(no_body, 3) def test_allocate_with_buffers(): - check_error(allocate_with_buffers, 2) - + def allocate_with_buffers() -> None: + with T.allocate([1], "float32", "") as [A, B]: # error + T.evaluate(1.0) -def inconsistent_binding_value() -> None: - for i, j in T.grid(16, 16): - vi, vj = T.axis.remap("SS", [i]) # error - T.evaluate(1.0) + check_error(allocate_with_buffers, 2) -def inconsistent_binding_type() -> None: - for i, j in T.grid(16, 16): - vi, vj = T.axis.remap("S", [i, j]) # error - T.evaluate(1.0) +def test_inconsistent_binding(): + def inconsistent_binding_value() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("SS", [i]) # error + T.evaluate(1.0) + def inconsistent_binding_type() -> None: + for i, j in T.grid(16, 16): + vi, vj = T.axis.remap("S", [i, j]) # error + T.evaluate(1.0) -def test_inconsistent_binding(): check_error(inconsistent_binding_value, 3) check_error(inconsistent_binding_type, 3) -def error_remap_type() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("TT", [i, j]) # error - T.evaluate(1.0) - - -def error_remap_value() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i + j, j]) # error - T.evaluate(1.0) +def test_error_remap_args(): + def error_remap_type() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("TT", [i, j]) # error + T.evaluate(1.0) + def error_remap_value() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i + j, j]) # error + T.evaluate(1.0) -def test_error_remap_args(): check_error(error_remap_type, 4) check_error(error_remap_value, 4) -def invalid_block_axes(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - for i, j in T.grid(16, 16): - with T.block(): - vi = T.axis.S(i, A) # error - T.evaluate(1.0) - - def test_invalid_block_axes(): - check_error(invalid_block_axes, 5) - + def invalid_block_axes(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(i, A) # error + T.evaluate(1.0) -def duplicate_block_axes() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi = T.axis.S(16, i) - vi = T.axis.S(16, j) # error - T.evaluate(1.0) + check_error(invalid_block_axes, 5) -def duplicate_block_axes_remap() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vi = T.axis.remap("SS", [i, j]) # error - T.evaluate(1.0) +def test_duplicate_block_axes(): + def duplicate_block_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi = T.axis.S(16, i) + vi = T.axis.S(16, j) # error + T.evaluate(1.0) + def duplicate_block_axes_remap() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vi = T.axis.remap("SS", [i, j]) # error + T.evaluate(1.0) -def test_duplicate_block_axes(): check_error(duplicate_block_axes, 5) check_error(duplicate_block_axes_remap, 4) -def miss_block_bind_value() -> None: - for i, j in T.grid(128, 128): - with T.block(): - vi = T.axis.S(i) # error - T.evaluate(1.0) - - def test_miss_block_bind(): - check_error(miss_block_bind_value, 4) - + def miss_block_bind_value() -> None: + for i, j in T.grid(128, 128): + with T.block(): + vi = T.axis.S(i) # error + T.evaluate(1.0) -def invalid_loop_var() -> None: - for i, j in range(0, 16): # error - T.evaluate(1.0) + check_error(miss_block_bind_value, 4) def test_invalid_loop_var(): - check_error(invalid_loop_var, 2) - + def invalid_loop_var() -> None: + for i, j in range(0, 16): # error + T.evaluate(1.0) -def inconsistent_grid() -> None: - for i in T.grid(16, 16): # error - T.evaluate(1.0) + check_error(invalid_loop_var, 2) def test_inconsistent_grid(): - check_error(inconsistent_grid, 2) - - -def invalid_match_buffer_region() -> None: - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A = T.match_buffer(vi) # error + def inconsistent_grid() -> None: + for i in T.grid(16, 16): # error T.evaluate(1.0) + check_error(inconsistent_grid, 2) -def test_invalid_match_buffer_region(): - check_error(invalid_match_buffer_region, 5) +def test_invalid_match_buffer_region(): + def invalid_match_buffer_region() -> None: + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A = T.match_buffer(vi) # error + T.evaluate(1.0) -def duplicate_buffer() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A = T.alloc_buffer((128, 128), "float32") # error - T.evaluate(1.0) + check_error(invalid_match_buffer_region, 5) def test_duplicate_buffer(): - check_error(duplicate_buffer, 6) + def duplicate_buffer() -> None: + A = T.alloc_buffer((128, 128), "float32") + A = T.alloc_buffer((128, 128), "float32") # error - -def duplicate_reads() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(A[0:8, 0:8]) - T.reads(A[0:16, 0:16]) # error - T.evaluate(1.0) - - -def duplicate_writes() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.writes(A[0:8, 0:8]) - T.writes(A[0:16, 0:16]) # error - T.evaluate(1.0) + check_error(duplicate_buffer, 3) -def duplicate_predicate() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.where(1) - T.where(0) # error - - -def duplicate_annotations() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({}) - T.block_attr({}) # error - - -def duplicate_init() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - with T.init(): +def test_duplicate_block_signature(): + def duplicate_reads() -> None: + A = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[0:8, 0:8]) + T.reads(A[0:16, 0:16]) # error T.evaluate(1.0) - with T.init(): # error + + def duplicate_writes() -> None: + A = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.writes(A[0:8, 0:8]) + T.writes(A[0:16, 0:16]) # error T.evaluate(1.0) + def duplicate_predicate() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(1) + T.where(0) # error -def duplicate_axes() -> None: - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - vi = T.axis.S(i, 16) # error - T.evaluate(1.0) + def duplicate_annotations() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({}) + T.block_attr({}) # error + def duplicate_init() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + with T.init(): + T.evaluate(1.0) + with T.init(): # error + T.evaluate(1.0) + + def duplicate_axes() -> None: + for i, j in T.grid(16, 16): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + vi = T.axis.S(i, 16) # error + T.evaluate(1.0) -def test_duplicate_block_signature(): check_error(duplicate_reads, 7) check_error(duplicate_writes, 7) check_error(duplicate_predicate, 6) @@ -326,143 +320,105 @@ def test_duplicate_block_signature(): check_error(duplicate_axes, 5) -def opaque_access_during_complete(a: T.handle) -> None: # error - A = T.match_buffer(a, (16, 16), "float32") - for i, j in T.grid(16, 16): - with T.block(): - T.evaluate(T.call_extern("dummy_extern_function", A.data, dtype="int32")) - - def test_opaque_access_during_complete(): - check_error(opaque_access_during_complete, 1) - + def opaque_access_during_complete(a: T.handle) -> None: # error + A = T.match_buffer(a, (16, 16), "float32") + for i, j in T.grid(16, 16): + with T.block(): + T.evaluate(T.call_extern("dummy_extern_function", A.data, dtype="int32")) -def convert_slice_to_bufferload() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = A[vi : vi + 2, vj] + 1 # error + check_error(opaque_access_during_complete, None) def test_convert_slice_to_bufferload(): - check_error(convert_slice_to_bufferload, 6) - - -def error_index_type() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = A[vi, 0.0] + 1 # error - - -def error_bufferslice_index_type() -> None: - A = T.alloc_buffer((1,), "float32") - B = T.alloc_buffer((16, 16), "float32") - C = T.alloc_buffer((16, 16), "float32") - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, A[0]] # error - - -def test_error_index_type(): - check_error(error_index_type, 6) - check_error(error_bufferslice_index_type, 8) - - -def special_stmt_except() -> None: - A = T.alloc_buffer("(128, 128)", "float32") # error - T.evaluate(1.0) + def convert_slice_to_bufferload() -> None: + A = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = A[vi : vi + 2, vj] + 1 # error - -def scope_handler_except() -> None: - for i in T.serial("1", "1"): # error - T.evaluate(1) + check_error(convert_slice_to_bufferload, 6) -def intrin_except_unassign(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - T.evaluate(A) # error +def test_tvm_exception_catch(): + def special_stmt_except() -> None: + A = T.alloc_buffer("(128, 128)", "float32") # error + T.evaluate(1.0) + def scope_handler_except() -> None: + for i in T.serial("1", "1"): # error + T.evaluate(1) -def intrin_except_assign(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - A[0, 0] = A[A] # error + def intrin_except_unassign(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.evaluate(A) # error + def intrin_except_assign(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + A[0, 0] = A[A] # error -def test_tvm_exception_catch(): - # test catching c++ side exception check_error(special_stmt_except, 2) check_error(scope_handler_except, 2) check_error(intrin_except_unassign, 3) check_error(intrin_except_assign, 3) -def buffer_shape_mismatch(a: T.handle) -> None: - A = T.match_buffer(a, (8, 8)) - for i, j in T.grid(8, 2): - with T.block(): - T.reads([]) - T.writes([A[i, j * 4 : j * 4 + 4]]) - sub_A = T.match_buffer( - A[i, j * 4 : j * 4 + 4], (5) - ) # error: shape mismatched between 4 and 5 - for jj in range(0, 4): - sub_A[i, j * 4 + jj] = 1 - - def test_match_buffer_shape_mismatch(): - check_error(buffer_shape_mismatch, 7) - + def buffer_shape_mismatch(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 2): + with T.block(): + T.reads([]) + T.writes([A[i, j * 4 : j * 4 + 4]]) + sub_A = T.match_buffer( + A[i, j * 4 : j * 4 + 4], (5) + ) # error: shape mismatched between 4 and 5 + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 -def high_dim_store() -> None: - with T.block("root"): - B = T.allocate([256], "float32", "global") - for i, j in T.grid(16, 16): - B[i, j] = 1.0 # error: Store is only allowed with one index + check_error(buffer_shape_mismatch, 7) def test_high_dim_store(): - check_error(high_dim_store, 5) + def high_dim_store() -> None: + with T.block("root"): + B = T.allocate([256], "float32", "global") + for i, j in T.grid(16, 16): + B[i, j] = 1.0 # error: Store is only allowed with one index - -def block_has_option_vars() -> None: - with T.block("root") as x: # error: block does not support option_vars - T.evaluate(0.0) + check_error(high_dim_store, 5) def test_block_has_option_vars(): - check_error(block_has_option_vars, 2) - - -def implicit_root_has_read(): - T.reads([]) # error: implicit root does not support reads - T.evaluate(0.0) - - -def implicit_root_has_write(): - T.writes([]) # error: implicit root does not support writes - T.evaluate(0.0) + def block_has_option_vars() -> None: + with T.block("root") as x: # error: block does not support option_vars + T.evaluate(0.0) + check_error(block_has_option_vars, 2) -def implicit_root_has_attrs(): - T.block_attr({}) # error: implicit root does not support block_attr - T.evaluate(0.0) +def test_implicit_root_has_attrs(): + def implicit_root_has_read(): + T.reads([]) # error: implicit root does not support reads + T.evaluate(0.0) -def implicit_root_has_predicate(): - T.where(True) # error: implicit root does not support predicate - T.evaluate(0.0) + def implicit_root_has_write(): + T.writes([]) # error: implicit root does not support writes + T.evaluate(0.0) + def implicit_root_has_attrs(): + T.block_attr({}) # error: implicit root does not support block_attr + T.evaluate(0.0) -def implicit_root_has_axes(): - v = T.axis.S(0, 0) # error: implicit root does not support axis define - T.evaluate(0.0) + def implicit_root_has_predicate(): + T.where(True) # error: implicit root does not support predicate + T.evaluate(0.0) + def implicit_root_has_axes(): + v = T.axis.S(0, 0) # error: implicit root does not support axis define + T.evaluate(0.0) -def test_implicit_root_has_attrs(): check_error(implicit_root_has_read, 2) check_error(implicit_root_has_write, 2) check_error(implicit_root_has_attrs, 2) @@ -554,127 +510,115 @@ def test_report_error_root_block(): assert expected_sub_error_message in str(execinfo.value) -def load_var_multiple() -> None: - d = T.var("float32") - d[2] = d[2, 1] # error cannot provide two indices to load - - def test_load_var(): - check_error(load_var_multiple, 3) - + def load_var_multiple() -> None: + d = T.var("float32") + d[2] = d[2, 1] # error cannot provide two indices to load -def store_var_multiple() -> None: - d = T.var("float32") - d[2, 1] = d[1] # error cannot provide two indices to store + check_error(load_var_multiple, 3) def test_store_var(): - check_error(store_var_multiple, 3) - + def store_var_multiple() -> None: + d = T.var("float32") + d[2, 1] = d[1] # error cannot provide two indices to store -def load_handle(h: T.handle) -> None: - h_ = T.match_buffer(h, [1]) - h_[0] = h[0] # error cannot load from handle + check_error(store_var_multiple, 3) def test_load_handle(): - check_error(load_var_multiple, 3) + def load_handle(h: T.handle) -> None: + h_ = T.match_buffer(h, [1]) + h_[0] = h[0] # error cannot load from handle - -def store_handle(h: T.handle) -> None: - h_ = T.match_buffer(h, [1]) - h[0] = h_[0] # error cannot store to handle + check_error(load_handle, 3) def test_store_handle(): - check_error(store_var_multiple, 3) - + def store_handle(h: T.handle) -> None: + h_ = T.match_buffer(h, [1]) + h[0] = h_[0] # error cannot store to handle -def binop_bad_ast_type(h: T.handle): - h_ = T.match_buffer(h, [1]) - h_[0] = h + [2] # error rhs should be a primexpr + check_error(store_handle, 3) def test_binop_bad_ast_type(): - check_error(binop_bad_ast_type, 3) - + def binop_bad_ast_type(h: T.handle): + h_ = T.match_buffer(h, [1]) + h_[0] = h + [2] # error rhs should be a primexpr -def binop_bad_type(h: T.handle): - h_ = T.match_buffer(h, [1]) - h_[0] = h + 2 # error lhs and rhs should be the same type + check_error(binop_bad_ast_type, 3) def test_binop_bad_type(): - check_error(binop_bad_type, 3) - - -def floor_dtype(h: T.handle): - h_ = T.match_buffer(h, [1]) - h_[0] = T.floor(2) # error floor requires a dtype - + def binop_bad_type(h: T.handle): + h_ = T.match_buffer(h, [1]) + h_[0] = h + 2 # error lhs and rhs should be the same type -def test_floor_dtype(): - check_error(floor_dtype, 3) - - -def non_integer_typed_block_iter(): - with T.block(): - i = T.axis.S(0.1, 0.1) # error IterVar requires an integer dtype + check_error(binop_bad_type, 3) def test_non_integer_typed_block_iter(): - check_error(non_integer_typed_block_iter, 3) - + def non_integer_typed_block_iter(): + with T.block(): + i = T.axis.S(0.1, 0.1) # error IterVar requires an integer dtype -def preflattened_buffer_map_align_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], align="bar" - ) # check_error: align: want int or IntImm, got 'bar' + check_error(non_integer_typed_block_iter, 3) def test_preflattened_buffer_map_align(): - check_error(preflattened_buffer_map_align_nonint, 3) - + def preflattened_buffer_map_align_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + T.preflattened_buffer( + foo_1, [1], align="bar" + ) # check_error: align: want int or IntImm, got 'bar' -def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], offset_factor="bar" - ) # check_error: offset_factor: want int or IntImm, got 'bar' + check_error(preflattened_buffer_map_align_nonint, 3) def test_preflattened_buffer_map_offset_factor(): - check_error(preflattened_buffer_map_offset_factor_nonint, 3) - - -def strided_buffer_region(A: T.handle): - # do not allow stride in buffer region - A = T.match_buffer((128, 128), "int32") - with T.block(): - T.reads([]) - T.writes([A[0:128:2, 0:128:3]]) # error - T.evaluate(T.call_extern("strided_compute", dtype="")) + def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + T.preflattened_buffer( + foo_1, [1], offset_factor="bar" + ) # check_error: offset_factor: want int or IntImm, got 'bar' + check_error(preflattened_buffer_map_offset_factor_nonint, 3) -def access_reversed_slice(A: T.handle): - # do not allow reversed slice step - A = T.match_buffer((128,), "int32") - A[0:128:-1] = T.broadcast(1, 128) # error +def test_illegal_buffer_slice(): + def strided_buffer_region(A: T.handle): + # do not allow stride in buffer region + A = T.match_buffer((128, 128), "int32") + with T.block(): + T.reads([]) + T.writes([A[0:128:2, 0:128:3]]) # error + T.evaluate(T.call_extern("strided_compute", dtype="")) -def access_non_const_slice_length(A: T.handle): - # do not allow non-constant slice length - A = T.match_buffer((128,), "int32") - for i in range(4): - T.evaluate(A[0:i:1]) # error + def access_reversed_slice(A: T.handle): + # do not allow reversed slice step + A = T.match_buffer((128,), "int32") + A[0:128:-1] = T.broadcast(1, 128) # error + def access_non_const_slice_length(A: T.handle): + # do not allow non-constant slice length + A = T.match_buffer((128,), "int32") + for i in range(4): + T.evaluate(A[0:i:1]) # error -def test_illegal_buffer_slice(): check_error(strided_buffer_region, 3) check_error(access_reversed_slice, 3) check_error(access_non_const_slice_length, 3) +def test_syntax_sugar_fail(): + def loop_syntax_sugar_fail(a: T.handle) -> None: + A = T.match_buffer(a, (128,)) + for i in T.thread_binding(128, 128): + A[i] = A[i] * 2.0 + + check_error(loop_syntax_sugar_fail, 3) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 849b0fc03d92..32572d392c51 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -20,9 +20,8 @@ import pytest import tvm.testing from tvm.ir import assert_structural_equal +from tvm.script import from_source from tvm.script import tir as T -from tvm.script.parser import from_source -from tvm.testing import check_error @T.prim_func @@ -89,20 +88,10 @@ def loop_syntax_sugar(a: T.handle) -> None: A[i, j, k, x] = A[i, j, k, x] * 2.0 -def loop_syntax_sugar_fail(a: T.handle) -> None: - A = T.match_buffer(a, (128,)) - for i in T.thread_binding(128, 128): - A[i] = A[i] * 2.0 - - def test_loop_syntax_sugar(): assert_structural_equal(loop_no_syntax_sugar, loop_syntax_sugar) -def test_syntax_sugar_fail(): - check_error(loop_syntax_sugar_fail, 3) - - # match buffer - use kwargs @T.prim_func def elementwise_handle(