From 99e39e11edd74299bb9483a0a34580f7ddbf43c3 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 25 May 2021 16:42:32 +0000 Subject: [PATCH] [TensorIR][M2a] Structural Error Reporting This PR is part of the TensorIR upstreaming effort (#7527), stage M2a. In this PR, we implemented ScheduleError, an error reporting mechanism for schedule primitives to report user-facing error messages, with the functionality of rendering the TIR out in the TVM script syntax. This set of APIs allows future improvement of error location rendering, e.g. more colorful rendering mechanisms like synr does. Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Tristan Konolige --- include/tvm/tir/schedule/schedule.h | 14 +++- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/schedule/__init__.py | 2 +- python/tvm/tir/schedule/schedule.py | 22 ++++++ src/tir/schedule/concrete_schedule.cc | 71 +++++++++++++++++-- src/tir/schedule/concrete_schedule.h | 3 + src/tir/schedule/error.cc | 55 ++++++++++++++ src/tir/schedule/error.h | 60 ++++++++++++++++ src/tir/schedule/schedule.cc | 5 +- src/tir/schedule/utils.h | 1 + .../unittest/test_tir_schedule_error.py | 70 ++++++++++++++++++ 11 files changed, 296 insertions(+), 9 deletions(-) create mode 100644 src/tir/schedule/error.cc create mode 100644 src/tir/schedule/error.h create mode 100644 tests/python/unittest/test_tir_schedule_error.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index b85fdec8cba9..2aee2cb136b3 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -24,6 +24,16 @@ namespace tvm { namespace tir { +/*! \brief The level of detailed error message rendering */ +enum class ScheduleErrorRenderLevel : int32_t { + /*! \brief Render a detailed error message */ + kDetail = 0, + /*! 
\brief Render the error in fast mode */ + kFast = 1, + /*! \brief No error message at all */ + kNone = 2, +}; + /**************** Random variable: BlockRV ****************/ /*! \brief A random variable that evaluates to a TensorIR block */ @@ -209,13 +219,15 @@ class Schedule : public runtime::ObjectRef { * \param mod The IRModule to be scheduled * \param debug_mode Do extra correctness checking after the class creation * and each time after calling the Replace method. + * \param error_render_level The level of error rendering * \return The concrete schedule created * \sa ScheduleDebugMask * \note The checks performed includes: * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); + TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index afe521a74361..eb200df0c599 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,7 +48,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError from . import schedule from . 
import ir_builder diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5550a9e3c74f..ef1cab1fb663 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,4 +19,4 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .state import ScheduleDebugMask, ScheduleState -from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule +from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f207fa274212..d420f7d32db0 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,6 +19,7 @@ from typing import List, Optional, Union from tvm._ffi import register_object as _register_object +from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object from tvm.tir import Block, For, IntImm, PrimFunc, Var @@ -27,6 +28,11 @@ from .state import ScheduleState, StmtSRef +@register_error +class ScheduleError(TVMError): + """Error that happens during TensorIR scheduling.""" + + @_register_object("tir.LoopRV") class LoopRV(Object): """A random variable that refers to a loop""" @@ -57,10 +63,14 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ + ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2} + def __init__( self, func_or_mod: Union[PrimFunc, IRModule], + *, debug_mode: Union[bool, int] = False, + error_render_level: str = "detail", ): """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc @@ -71,6 +81,11 @@ def __init__( debug_mode : Union[bool, int] Do extra correctness checking after the class creation and each time scheduling primitive + error_render_level : str = "detail" + The level of error rendering. Choices: "detail", "fast", "none". 
+ "detail": Render a detailed error message, with the TIR and error locations printed + "fast": Show a simple error message without rendering or string manipulation + "none": Do not show any error message. Note ---------- @@ -85,10 +100,17 @@ def __init__( debug_mode = 0 if not isinstance(debug_mode, int): raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") + if error_render_level not in Schedule.ERROR_RENDER_LEVEL: + raise ValueError( + 'error_render_level can be "detail", "fast", or "none", but got: ' + + f"{error_render_level}" + ) + error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level) self.__init_handle_by_constructor__( _ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member func_or_mod, debug_mode, + error_render_level, ) ########## Utilities ########## diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index ef12f10fa924..60ab7920c37b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -21,9 +21,11 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mode) { +Schedule Schedule::Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); + n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -136,6 +138,7 @@ class ScheduleCopier { scope->src2deps = Copy(old_info.scope->src2deps); scope->dst2deps = Copy(old_info.scope->dst2deps); scope->buffer_writers = Copy(old_info.scope->buffer_writers); + scope->stage_pipeline = old_info.scope->stage_pipeline; new_info.scope = BlockScope(std::move(scope)); result[Copy(old_sref)] = std::move(new_info); } @@ -173,21 +176,81 @@ class ScheduleCopier { void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { 
ScheduleCopier::Copy(this, new_state, new_symbol_table); + new_state->get()->DebugVerify(); } Schedule ConcreteScheduleNode::Copy() const { ObjectPtr n = make_object(); - Copy(&n->state_, &n->symbol_table_); + n->error_render_level_ = this->error_render_level_; + this->Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); } +/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */ +#define TVM_TIR_SCHEDULE_BEGIN() try { +/*! + * \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error + * message rendering + * \param level A ScheduleErrorRenderLevel enum, level of error rendering + * \sa ScheduleErrorRenderLevel + */ +#define TVM_TIR_SCHEDULE_END(level) \ + } \ + catch (const ScheduleError& error) { \ + if ((level) == ScheduleErrorRenderLevel::kDetail) { \ + throw tvm::runtime::Error(error.RenderReport()); \ + } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ + throw tvm::runtime::Error(error.FastErrorString()); \ + } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ + throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ + } \ + } + /******** Block/Loop relation ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { + class NotSingleResult : public ScheduleError { + public: + explicit NotSingleResult(String name, IRModule mod, const Array& blocks) + : name_(name), mod_(mod), blocks_{} { + blocks_.reserve(blocks.size()); + for (const StmtSRef& block_sref : blocks) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + blocks_.push_back(GetRef(block)); + } + } + + String primitive() const final { return "get-block"; } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } + + String DetailRenderTemplate() const final { + if (blocks_.empty()) { + return "Cannot find a block with the name: " + 
name_; + } else { + return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_; + } + } + + String FastErrorString() const final { + if (blocks_.empty()) { + return "ScheduleError: Cannot find a block with the specified name"; + } else { + return "ScheduleError: Found multiple blocks with the specified name"; + } + } + + String name_; + IRModule mod_; + Array blocks_; + }; Array blocks = tir::GetBlocks(this->state_, name, func_name); - CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size() - << " blocks with the name: " << name; + if (blocks.size() != 1) { + TVM_TIR_SCHEDULE_BEGIN(); + throw NotSingleResult(name, this->state_->mod, blocks); + TVM_TIR_SCHEDULE_END(this->error_render_level_); + } return CreateRV(blocks[0]); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 39eab1159db9..ab467cec9ee3 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -37,6 +37,8 @@ class ConcreteScheduleNode : public ScheduleNode { protected: /*! \brief The internal state of scheduling */ ScheduleState state_; + /*! \brief The level of error rendering */ + ScheduleErrorRenderLevel error_render_level_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ @@ -44,6 +46,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: void VisitAttrs(tvm::AttrVisitor* v) { + // `error_render_level_` is not visited // `state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc new file mode 100644 index 000000000000..f64d4aeb984b --- /dev/null +++ b/src/tir/schedule/error.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +String ScheduleError::RenderReport() const { + IRModule mod = this->mod(); + std::ostringstream os; + os << "ScheduleError: An error occurred in the schedule primitive '" << this->primitive() + << "'.\n\nThe IR is:\n" + << AsTVMScript(mod); + Array locs = LocationsOfInterest(); + int n_locs = locs.size(); + std::vector roi_names; + roi_names.reserve(n_locs); + if (n_locs > 0) { + os << "Regions of interest:\n"; + for (const ObjectRef& obj : locs) { + String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); + os << name << "\n" << obj; + roi_names.emplace_back(std::move(name)); + } + os << "\n"; + } + std::string msg = DetailRenderTemplate(); + for (int i = 0; i < n_locs; ++i) { + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), roi_names[i]); + } + } + os << "Error message: " << msg; + return os.str(); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h new file mode 100644 index 000000000000..1031672f0010 --- /dev/null +++ b/src/tir/schedule/error.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_ERROR_H_ +#define TVM_TIR_SCHEDULE_ERROR_H_ + +#include + +namespace tvm { +namespace tir { + +/*! \brief Error that happens during TensorIR scheduling */ +class ScheduleError : public tvm::runtime::Error { + public: + /*! \brief Base constructor */ + ScheduleError() : tvm::runtime::Error("") {} + /*! \brief The error occurred in this scheduling primitive */ + virtual String primitive() const = 0; + /*! \brief The error occurred in this IRModule */ + virtual IRModule mod() const = 0; + /*! \brief The locations of interest that we want to point out */ + virtual Array LocationsOfInterest() const = 0; + /*! + * \brief Returns an error string template for rendering, corresponds to the "detail" mode. + * \sa ScheduleErrorRenderLevel + * \note The template is a string, e.g. + * "Some error occurred on block {0} and loop {1} blah blah" + * And renderer will replace {0} and {1} according to the list provided by LocationsOfInterest. Right + * now it only prints out all the locations in plain text, but in the future, we may want to mark + * the IR with underscores and attach names to each location of interest, like what synr does. + */ + virtual String DetailRenderTemplate() const = 0; + /*! 
+ * \brief Returns an error string without needing to render, corresponds to the "fast" mode + * \sa ScheduleErrorRenderLevel + */ + virtual String FastErrorString() const = 0; + /*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ + String RenderReport() const; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_ERROR_H_ diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index b407b07e5312..a1a4f09a7525 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // /**************** (FFI) Constructor ****************/ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule { + .set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule { IRModule mod{nullptr}; if (const auto* func = obj.as()) { mod = IRModule({{GlobalVar("main"), GetRef(func)}}); @@ -66,7 +66,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); } - return Schedule::Concrete(mod, debug_mode); + return Schedule::Concrete(mod, debug_mode, + static_cast(error_render_level)); }); /******** (FFI) Lookup random variables ********/ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index b72fd8e05706..e7c73120c730 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -35,6 +35,7 @@ #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "./analysis.h" +#include "./error.h" namespace tvm { namespace tir { diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py new file mode 100644 index 000000000000..1fa658feabe3 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -0,0 +1,70 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_tir_schedule_error_detail(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="detail") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "Cannot find a block with the name: wrong_name" in msg + + +def test_tir_schedule_error_fast(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="fast") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + 
(msg,) = excinfo.value.args + assert "Cannot find a block with the specified name" in msg + + +def test_tir_schedule_error_none(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="none") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "(not rendered)" in msg + + +if __name__ == "__main__": + test_tir_schedule_error_detail() + test_tir_schedule_error_fast() + test_tir_schedule_error_none()