Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,5 @@
from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation
from .popen_pool import timeout_job

from .tir import check_error

from . import auto_scheduler
from . import autotvm
45 changes: 2 additions & 43 deletions python/tvm/testing/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,6 @@
# under the License.
# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
"""Common utility functions in TVM tir"""
import inspect
import re
import tvm
from tvm.ir.diagnostics import override_renderer


CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$")


def check_error(func, rel_lineno):
"""check if TIR script throws error"""
# Override the default renderer to accumulate errors
errors = []

def render(e):
for d in e.diagnostics:
errors.append(d)

override_renderer(render)
# The diagnostic context throws an exception when it gets an error
try:
source_code = inspect.getsource(func)
source_code = "@T.prim_func\n" + source_code
from tvm.script import from_source

# to avoid cyclic import
from_source(source_code)
except tvm.error.DiagnosticError as e:
pass
assert len(errors) == 1, errors
for d in errors:
assert (
d.span.line - 1 == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"

error_line = source_code.split("\n")[rel_lineno]
m = CHECK_ERROR_RE.match(error_line)
if m:
expected_error_text = m.group(1)
errors = [e.message for e in errors]
assert (
expected_error_text in errors
), f'check_error expects "{expected_error_text} in str(errors): {errors}'


def mma_schedule(
Expand All @@ -80,6 +37,8 @@ def mma_schedule(
shared_scope="shared",
):
"""Create a tensorized schedule for GEMM with MMA intrinsics."""
import tvm # pylint: disable=import-outside-toplevel

ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)

Expand Down
Loading