diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a2178167b2e3..316d59631782 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -379,6 +379,9 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate); + } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 13e4cfcd30ba..8ac745f675d9 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor, */ TVM_DLL Doc Print(const ObjectRef& node); - private: + protected: /*! \brief The tir prefix */ String tir_prefix_; /*! \brief whether show meta data */ @@ -208,6 +208,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); + virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); Doc PrintAnnotations(const Map& annotations); @@ -217,15 +218,24 @@ class TVMScriptPrinter : public StmtFunctor, Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); void TryDeallocVar(const Var& var); + bool ContainsOptionalInfo(const Stmt& stmt); /*! Helper functions for loop printing. */ /*! * \brief Print a single for loop * \param loop The for loop to be printed */ - Doc PrintLoop(const For& loop); + virtual Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); + /*! + * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \param for_op the for node to be checked + */ + bool IsSimpleLoop(const ForNode* for_op) { + return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && + is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); + } /*! * \brief Print additional info about expr in comment. @@ -234,11 +244,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintOptionalInfo(const Stmt& stmt) { Doc doc; // default annotations - if (annotate_ != nullptr) { + if (ContainsOptionalInfo(stmt)) { std::string annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - doc << "# " << annotated_stmt << Doc::NewLine(); - } + doc << "# " << annotated_stmt << Doc::NewLine(); } return doc; } @@ -391,6 +399,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +/*! + * \brief Check if any optional information exists in annotate_ for + * a given Stmt. + * \param stmt The statement. + */ +bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { + if (annotate_ == nullptr) return false; + return !annotate_(stmt).empty(); +} + /*! * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. @@ -835,14 +853,14 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + bool simple_loop = IsSimpleLoop(op); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); - return result; + return doc; } // It is a loop that can not be compressed bool print_above = !simple_loop_stack_.empty(); @@ -916,6 +934,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +/*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; @@ -1049,15 +1068,25 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { return body; } -Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { - const auto* block_op = op->block.as(); - // print block name and block vars +/*! + * \brief Print the name of a block + * \param block_op The block node to be printed + */ +Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { Doc doc; doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + Doc doc = PrintOptionalInfo(GetRef(block_op)); + // print block name and block vars + doc << PrintBlockName(block_op); Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1343,6 +1372,45 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} + + protected: + Doc PrintBlockName(const BlockNode* block_op) override; + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop) override; +}; + +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { + Doc doc = TVMScriptPrinter::PrintBlockName(block_op); + doc << PrintUnderline(GetRef(block_op), doc.str().size()); + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { + Doc doc; + // annotation + if (ContainsOptionalInfo(stmt)) { + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline; + } + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { + Doc res = TVMScriptPrinter::PrintLoop(loop); + res << PrintUnderline(loop, res.str().size()); + return res; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; @@ -1350,5 +1418,13 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n"; +} + +TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index d8dcf57b91e4..eb72773ffedb 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -24,29 +24,37 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // get locations of interest Array locs = LocationsOfInterest(); + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); - std::vector roi_names; - roi_names.reserve(n_locs); - if (n_locs > 0) { - os << "Regions of interest:\n"; - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; - roi_names.emplace_back(std::move(name)); - } - os << "\n"; - } std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) { - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); + if (n_locs > 0) { + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } + loc_obj_to_name.emplace(locs[i], std::move(name)); } } + + // print IR module + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc( + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto it = loc_obj_to_name.find(Downcast(expr)); + if (it == loc_obj_to_name.end()) return ""; + return it->second; + }); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR with diagnostic is:\n" + << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate); + + // print error message os << "Error message: " << msg; return os.str(); } diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 80c37229f519..3098c86a7c2e 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,6 +18,7 @@ import pytest import sys import tvm +from tvm import tir from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -511,5 +512,77 @@ def render(e): # TODO(Siyuan): block iter errors. + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_fuse_fail_nested_loop_outer(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.fuse(k, i) + expected_sub_error_message = ( + " # tir.For#1\n" + " for i in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + " for j in tir.serial(0, 128):\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))