From d7fa520f5f73f7ab7c57e22cbdaffc92fd77c3cd Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 00:06:04 -0700 Subject: [PATCH 01/16] add structural error printing --- src/printer/text_printer.h | 3 + src/printer/tvmscript_printer.cc | 96 +++++++++++++++++++++++++++++++- src/tir/schedule/error.cc | 23 +++++++- 3 files changed, 118 insertions(+), 4 deletions(-) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a2178167b2e3..afd31f9065b4 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -379,6 +379,9 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate); + } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 13e4cfcd30ba..19925b53875f 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor, */ TVM_DLL Doc Print(const ObjectRef& node); - private: + protected: /*! \brief The tir prefix */ String tir_prefix_; /*! \brief whether show meta data */ @@ -1343,6 +1343,92 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta), annotate_(std::move(annotate)) {} + + protected: + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const BlockRealizeNode* op) override; + Doc PrintAnnotation(const Stmt& stmt, int length); +}; + +Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { + Doc doc; + + // + var_not_in_headers_.insert(op->loop_var.get()); + const auto* body = op->body.as(); + bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + if (simple_loop) loop_stack_.push_back(GetRef(op)); + // It is a loop that can be compressed, let the loops below print it out + if (simple_loop && body != nullptr) { + Doc result = Print(GetRef(body)); + TryDeallocVar(op->loop_var); + return result; + } + // It is a loop that can not be compressed + bool print_above = !loop_stack_.empty(); + // print loops above if needed + if (print_above) { + doc << PrintLoopStack(); + loop_stack_.clear(); + } + if (!simple_loop) { + // print current loop if needed + Doc current_loop; + current_loop << PrintLoop(GetRef(op)) + << PrintAnnotation(GetRef(op), doc.str().size()) + << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop); + } else { + doc << PrintAnnotation(GetRef(op), doc.str().size()) + << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } + TryDeallocVar(op->loop_var); + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + // print block name and block vars + Doc doc = PrintBlockVar(block_op); + + // annotation + doc << PrintAnnotation(GetRef(op), doc.str().size()); + + // print predicate, binding, read/write tensor region, annotations + Doc block_attr_doc = PrintBlockAttr(op); + // print body + Doc body = PrintBlockBody(block_op); + doc << Doc::Indent(4, block_attr_doc << Doc::NewLine() << body); + for (const auto& iter_var : block_op->iter_vars) { + TryDeallocVar(iter_var->var); + } + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintAnnotation(const Stmt& stmt, int length) { + Doc doc; + // annotation + if (annotate_ != nullptr) { + String annotated_stmt = std::string(length, '^'); + if (!annotated_stmt.empty()) { + doc << Doc::NewLine() << annotated_stmt; + } + } + return doc; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; @@ -1350,5 +1436,13 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n"; +} + +TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index d8dcf57b91e4..5cc80e53863c 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -24,13 +24,18 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // print IRModule Array locs = LocationsOfInterest(); int n_locs = locs.size(); std::vector roi_names; roi_names.reserve(n_locs); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR is:\n" + << AsTVMScript(mod); + + // print region of interest if (n_locs > 0) { os << "Regions of interest:\n"; for (const ObjectRef& obj : locs) { @@ -40,6 +45,18 @@ String ScheduleError::RenderReport(const String& primitive) const { } os << "\n"; } + + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc([](const ObjectRef& expr) -> String { + std::string annotations = std::string(10, '^'); + return annotations; + }); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR with diagnostic is:\n" + << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate); + + // print error message std::string msg = DetailRenderTemplate(); for (int i = 0; i < n_locs; ++i) { std::string src = "{" + std::to_string(i) + "}"; From d0ce41d878a4e6c177caae8ee9d376211af967c8 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 00:11:57 -0700 Subject: [PATCH 02/16] remove old code --- src/tir/schedule/error.cc | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 5cc80e53863c..38c7fcc98384 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -25,27 +25,19 @@ String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - // print IRModule + // get locations of interest Array locs = LocationsOfInterest(); int n_locs = locs.size(); std::vector roi_names; roi_names.reserve(n_locs); - - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); - - // print region of interest if (n_locs > 0) { - os << "Regions of interest:\n"; for (const ObjectRef& obj : locs) { String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; roi_names.emplace_back(std::move(name)); } - os << "\n"; } + // print IR module runtime::TypedPackedFunc annotate = runtime::TypedPackedFunc([](const ObjectRef& expr) -> String { std::string annotations = std::string(10, '^'); From 1070e1f4f81dffb14b18c1959b37ce1fdc5cf71d Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 18:09:46 -0700 Subject: [PATCH 03/16] address comments --- src/printer/text_printer.h | 2 +- src/printer/tvmscript_printer.cc | 19 +++++++++---------- src/tir/schedule/error.cc | 31 +++++++++++++++++-------------- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index afd31f9065b4..316d59631782 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -380,7 +380,7 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, - runtime::TypedPackedFunc annotate); + runtime::TypedPackedFunc annotate); } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 19925b53875f..e5921a95a9dc 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1351,15 +1351,13 @@ Doc TVMScriptPrinter::PrintLoopStack() { class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { public: explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, - runtime::TypedPackedFunc annotate) - : TVMScriptPrinter(tir_prefix, show_meta), annotate_(std::move(annotate)) {} + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} protected: - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; Doc VisitStmt_(const ForNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; - Doc PrintAnnotation(const Stmt& stmt, int length); + Doc PrintAnnotation(const Stmt stmt, int length); }; Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { @@ -1404,7 +1402,7 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { Doc doc = PrintBlockVar(block_op); // annotation - doc << PrintAnnotation(GetRef(op), doc.str().size()); + doc << PrintAnnotation(GetRef(block_op), doc.str().size()); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1417,13 +1415,14 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { return doc; } -Doc TVMScriptPrinterWithDiagnostic::PrintAnnotation(const Stmt& stmt, int length) { +Doc TVMScriptPrinterWithDiagnostic::PrintAnnotation(const Stmt stmt, int length) { Doc doc; // annotation if (annotate_ != nullptr) { - String annotated_stmt = std::string(length, '^'); + String annotated_stmt = annotate_(stmt); if (!annotated_stmt.empty()) { - doc << Doc::NewLine() << annotated_stmt; + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline << Doc::NewLine() << annotated_stmt; } } return doc; @@ -1437,7 +1436,7 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, - runtime::TypedPackedFunc annotate) { + runtime::TypedPackedFunc annotate) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n"; } diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 38c7fcc98384..46d54c91713e 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -27,35 +27,38 @@ String ScheduleError::RenderReport(const String& primitive) const { // get locations of interest Array locs = LocationsOfInterest(); + std::unordered_set loc_set; int n_locs = locs.size(); std::vector roi_names; roi_names.reserve(n_locs); + std::string msg = DetailRenderTemplate(); if (n_locs > 0) { - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(roi_names.size()); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } roi_names.emplace_back(std::move(name)); + loc_set.emplace(locs[i]); } } + msg = "Error message: " + std::move(msg); // print IR module - runtime::TypedPackedFunc annotate = - runtime::TypedPackedFunc([](const ObjectRef& expr) -> String { - std::string annotations = std::string(10, '^'); - return annotations; - }); + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc( + [&loc_set, &msg](const Stmt& expr) -> std::string { + auto search = loc_set.find(Downcast(expr)); + if (search == loc_set.end()) return ""; + return msg; + }); os << "ScheduleError: An error occurred in the schedule primitive '" << primitive << "'.\n\nThe IR with diagnostic is:\n" << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate); // print error message - std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) { - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); - } - } os << "Error message: " << msg; return os.str(); } From 2eb3f1f586ae42fe6d800a54d901dbfdd4993ac2 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 20:14:29 -0700 Subject: [PATCH 04/16] address comments --- src/tir/schedule/error.cc | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 46d54c91713e..28b168889329 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -27,31 +27,27 @@ String ScheduleError::RenderReport(const String& primitive) const { // get locations of interest Array locs = LocationsOfInterest(); - std::unordered_set loc_set; + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); - std::vector roi_names; - roi_names.reserve(n_locs); std::string msg = DetailRenderTemplate(); if (n_locs > 0) { for (int i = 0; i < n_locs; ++i) { - std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(roi_names.size()); + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); std::string src = "{" + std::to_string(i) + "}"; for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { msg.replace(pos, src.length(), name); } - roi_names.emplace_back(std::move(name)); - loc_set.emplace(locs[i]); + loc_obj_to_name.emplace(locs[i], std::move(name)); } } - msg = "Error message: " + std::move(msg); // print IR module runtime::TypedPackedFunc annotate = runtime::TypedPackedFunc( - [&loc_set, &msg](const Stmt& expr) -> std::string { - auto search = loc_set.find(Downcast(expr)); - if (search == loc_set.end()) return ""; - return msg; + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto search = loc_obj_to_name.find(Downcast(expr)); + if (search == loc_obj_to_name.end()) return ""; + return search->second; }); os << "ScheduleError: An error occurred in the schedule primitive '" << primitive From d115c3bfaebd97d6c31ff7fe669891842c06ceff Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 21:27:11 -0700 Subject: [PATCH 05/16] add test --- src/printer/tvmscript_printer.cc | 30 +++++++++++-------- .../unittest/test_tir_schedule_reorder.py | 13 ++++++++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e5921a95a9dc..61ff8a3863fc 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1362,53 +1362,59 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { Doc doc; - - // var_not_in_headers_.insert(op->loop_var.get()); + loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); - if (simple_loop) loop_stack_.push_back(GetRef(op)); + if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out if (simple_loop && body != nullptr) { Doc result = Print(GetRef(body)); TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return result; } // It is a loop that can not be compressed - bool print_above = !loop_stack_.empty(); + bool print_above = !simple_loop_stack_.empty(); // print loops above if needed if (print_above) { doc << PrintLoopStack(); - loop_stack_.clear(); + simple_loop_stack_.clear(); } if (!simple_loop) { // print current loop if needed Doc current_loop; current_loop << PrintLoop(GetRef(op)) - << PrintAnnotation(GetRef(op), doc.str().size()) - << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + << PrintAnnotation(GetRef(op), doc.str().size()); + current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop); } else { doc << PrintAnnotation(GetRef(op), doc.str().size()) << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } TryDeallocVar(op->loop_var); + loop_var_map_.erase(op->loop_var.get()); return doc; } Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); - // print block name and block vars - Doc doc = PrintBlockVar(block_op); - + // print block name + Doc doc; + doc << "with " << tir_prefix_ << ".block("; + if (!block_op->name_hint.empty()) { + doc << Doc::StrLiteral(block_op->name_hint); + } + doc << "):"; // annotation doc << PrintAnnotation(GetRef(block_op), doc.str().size()); - + // print block vars + Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); // print body Doc body = PrintBlockBody(block_op); - doc << Doc::Indent(4, block_attr_doc << Doc::NewLine() << body); + doc << Doc::Indent(4, block_var << block_attr_doc << Doc::NewLine() << body); for (const auto& iter_var : block_op->iter_vars) { TryDeallocVar(iter_var->var); } diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 8267a369cf5d..0f0f26863a9d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -279,5 +279,18 @@ def test_reorder_fail_not_affine_bindings(): sch.reorder(l, i) +def test_reorder_fail_error_msg(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + 'with tir.block("B"):\n' "\t\t\t^^^^^^^^^^^^^^^^^^^^\n" "\t\t\ttir.Block#0" + ) + print(str(execinfo.value)) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 6cc7a0357892c7ea59b8cbe4f2dedb90d8547f11 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 18 Oct 2021 21:30:36 -0700 Subject: [PATCH 06/16] fix test case --- tests/python/unittest/test_tir_schedule_reorder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 0f0f26863a9d..b83c056b9422 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -286,9 +286,8 @@ def test_reorder_fail_error_msg(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - 'with tir.block("B"):\n' "\t\t\t^^^^^^^^^^^^^^^^^^^^\n" "\t\t\ttir.Block#0" + 'with tir.block("B"):\n' "\t\t\t^^^^^^^^^^^^^^^^^^^^\n" " tir.Block#0" ) - print(str(execinfo.value)) assert expected_sub_error_message in str(execinfo.value) From 47da01e130cbf76cc24f7c96fc7db3860b175179 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 15:59:56 -0700 Subject: [PATCH 07/16] fix nested loop --- src/printer/tvmscript_printer.cc | 107 +++++++++++++----- .../unittest/test_tir_schedule_reorder.py | 22 +++- 2 files changed, 101 insertions(+), 28 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 61ff8a3863fc..f8001390e8e5 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -217,6 +217,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); void TryDeallocVar(const Var& var); + bool ContainsOptionalInfo(const Stmt& stmt); /*! Helper functions for loop printing. */ /*! @@ -226,6 +227,14 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); + /*! + * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \param for_op the for node to be checked + */ + bool IsSimpleLoop(const ForNode* for_op) { + return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && + is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); + } /*! * \brief Print additional info about expr in comment. @@ -234,11 +243,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintOptionalInfo(const Stmt& stmt) { Doc doc; // default annotations - if (annotate_ != nullptr) { + if (ContainsOptionalInfo(stmt)) { std::string annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - doc << "# " << annotated_stmt << Doc::NewLine(); - } + doc << "# " << annotated_stmt << Doc::NewLine(); } return doc; } @@ -391,6 +398,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +/*! + * \brief Check if any optional information exists in annotate_ for + * a given Stmt. + * \param stmt The statement. + */ +bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { + if (annotate_ == nullptr) return false; + return !annotate_(stmt).empty(); +} + /*! * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. @@ -835,14 +852,14 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + bool simple_loop = IsSimpleLoop(op) && body != nullptr && IsSimpleLoop(body); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); - return result; + return doc; } // It is a loop that can not be compressed bool print_above = !simple_loop_stack_.empty(); @@ -1357,7 +1374,9 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { protected: Doc VisitStmt_(const ForNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; - Doc PrintAnnotation(const Stmt stmt, int length); + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop); + Doc PrintLoopStack(); }; Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { @@ -1365,32 +1384,32 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); - if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); + + bool simple_loop = IsSimpleLoop(op) && body != nullptr && IsSimpleLoop(body); + if (simple_loop) { + simple_loop_stack_.push_back(GetRef(op)); + } // It is a loop that can be compressed, let the loops below print it out if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); - return result; + return doc; } // It is a loop that can not be compressed bool print_above = !simple_loop_stack_.empty(); // print loops above if needed if (print_above) { doc << PrintLoopStack(); - simple_loop_stack_.clear(); } if (!simple_loop) { // print current loop if needed Doc current_loop; - current_loop << PrintLoop(GetRef(op)) - << PrintAnnotation(GetRef(op), doc.str().size()); + current_loop << PrintLoop(GetRef(op)); current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop); } else { - doc << PrintAnnotation(GetRef(op), doc.str().size()) - << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); @@ -1400,14 +1419,17 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); // print block name - Doc doc; + // we want to print the option info for its BlockNode body first + // as it wouldn't be picked up by PrintOptionalInfo(BlockRealizeNode& op) + Doc doc = PrintOptionalInfo(GetRef(block_op)); + std::size_t option_info_size = doc.str().size(); doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; // annotation - doc << PrintAnnotation(GetRef(block_op), doc.str().size()); + doc << PrintUnderline(GetRef(block_op), doc.str().size() - option_info_size); // print block vars Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations @@ -1421,19 +1443,52 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { return doc; } -Doc TVMScriptPrinterWithDiagnostic::PrintAnnotation(const Stmt stmt, int length) { +Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { Doc doc; // annotation - if (annotate_ != nullptr) { - String annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - String underline = std::string(length, '^'); - doc << Doc::NewLine() << underline << Doc::NewLine() << annotated_stmt; - } + if (ContainsOptionalInfo(stmt)) { + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline; } return doc; } +Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { + Doc res; + res << "for " << Print(loop->loop_var) << " in " << tir_prefix_ + << "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", " + << Print(loop->min + loop->extent); + if (loop->thread_binding.defined()) { + res << ", thread="; + res << Print(loop->thread_binding.value()->thread_tag); + } + if (!loop->annotations.empty()) { + res << ", annotations={"; + res << PrintAnnotations(loop->annotations); + res << "}"; + } + res << "):"; + res << PrintUnderline(loop, res.str().size()); + return res; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintLoopStack() { + Doc res; + if (simple_loop_stack_.size() == 1) { + res << PrintLoop(simple_loop_stack_[0]); + } else if (simple_loop_stack_.size() > 1) { + std::vector vars, extents; + for (const auto& loop : simple_loop_stack_) { + vars.push_back(Print(loop->loop_var)); + extents.push_back(Print(loop->extent)); + } + res << "for " << PrintSep(vars, Doc::Text(", ")) << " in " << tir_prefix_ << ".grid(" + << PrintSep(extents, Doc::Text(", ")) << "):"; + } + simple_loop_stack_.clear(); + return res; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index b83c056b9422..8e0a900cf73a 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -275,8 +275,9 @@ def test_reorder_fail_not_affine_bindings(): sch = tir.Schedule(elementwise_not_affine, debug_mask="all") block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) - with pytest.raises(tvm.tir.ScheduleError): + with pytest.raises(tvm.tir.ScheduleError) as msg: sch.reorder(l, i) + print(str(msg.value)) def test_reorder_fail_error_msg(): @@ -286,7 +287,24 @@ def test_reorder_fail_error_msg(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - 'with tir.block("B"):\n' "\t\t\t^^^^^^^^^^^^^^^^^^^^\n" " tir.Block#0" + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) From edda6a3b760ab47a12facc13d014f18de3efb3f6 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 16:02:34 -0700 Subject: [PATCH 08/16] rm print --- tests/python/unittest/test_tir_schedule_reorder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 8e0a900cf73a..88b61de1bd04 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -277,7 +277,6 @@ def test_reorder_fail_not_affine_bindings(): i, j, k, l = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError) as msg: sch.reorder(l, i) - print(str(msg.value)) def test_reorder_fail_error_msg(): From aa481299036bb2a48dd5c9ca9b4ddfc57a42ffb3 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 16:22:32 -0700 Subject: [PATCH 09/16] change simple loop cond --- src/printer/tvmscript_printer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index f8001390e8e5..4264b77845e4 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1385,8 +1385,8 @@ Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = IsSimpleLoop(op) && body != nullptr && IsSimpleLoop(body); - if (simple_loop) { + bool simple_loop = IsSimpleLoop(op); + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { simple_loop_stack_.push_back(GetRef(op)); } // It is a loop that can be compressed, let the loops below print it out From 6bf45cf4085a6e92a82c83f9f5753a8875fc83ed Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 17:23:32 -0700 Subject: [PATCH 10/16] address comments --- src/printer/tvmscript_printer.cc | 117 ++++-------------- .../unittest/test_tir_schedule_reorder.py | 29 ----- .../unittest/test_tvmscript_error_report.py | 71 +++++++++++ 3 files changed, 93 insertions(+), 124 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4264b77845e4..c0b3caa2465f 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -208,6 +208,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); + virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); Doc PrintAnnotations(const Map& annotations); @@ -224,7 +225,7 @@ class TVMScriptPrinter : public StmtFunctor, * \brief Print a single for loop * \param loop The for loop to be printed */ - Doc PrintLoop(const For& loop); + virtual Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); /*! @@ -852,10 +853,10 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = IsSimpleLoop(op) && body != nullptr && IsSimpleLoop(body); + bool simple_loop = IsSimpleLoop(op); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); @@ -933,6 +934,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +/*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; @@ -1066,15 +1068,24 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { return body; } -Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { - const auto* block_op = op->block.as(); - // print block name and block vars +/*! + * \brief Print the name of a block + * \param block_op The block node to be printed + */ +Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { Doc doc; doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + // print block name and block vars + Doc doc = PrintBlockName(block_op); Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1372,74 +1383,20 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} protected: - Doc VisitStmt_(const ForNode* op) override; - Doc VisitStmt_(const BlockRealizeNode* op) override; + Doc PrintBlockName(const BlockNode* block_op); Doc PrintUnderline(const Stmt& stmt, int length); Doc PrintLoop(const For& loop); - Doc PrintLoopStack(); }; -Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const ForNode* op) { - Doc doc; - var_not_in_headers_.insert(op->loop_var.get()); - loop_var_map_[op->loop_var.get()] = GetRef(op); - const auto* body = op->body.as(); - - bool simple_loop = IsSimpleLoop(op); - if (simple_loop && body != nullptr && IsSimpleLoop(body)) { - simple_loop_stack_.push_back(GetRef(op)); - } - // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { - doc << Print(GetRef(body)); - TryDeallocVar(op->loop_var); - loop_var_map_.erase(op->loop_var.get()); - return doc; - } - // It is a loop that can not be compressed - bool print_above = !simple_loop_stack_.empty(); - // print loops above if needed - if (print_above) { - doc << PrintLoopStack(); - } - if (!simple_loop) { - // print current loop if needed - Doc current_loop; - current_loop << PrintLoop(GetRef(op)); - current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); - doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop); - } else { - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); - } - TryDeallocVar(op->loop_var); - loop_var_map_.erase(op->loop_var.get()); - return doc; -} - -Doc TVMScriptPrinterWithDiagnostic::VisitStmt_(const BlockRealizeNode* op) { - const auto* block_op = op->block.as(); - // print block name - // we want to print the option info for its BlockNode body first - // as it wouldn't be picked up by PrintOptionalInfo(BlockRealizeNode& op) +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { Doc doc = PrintOptionalInfo(GetRef(block_op)); - std::size_t option_info_size = doc.str().size(); + auto optional_info_size = doc.str().size(); doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; - // annotation - doc << PrintUnderline(GetRef(block_op), doc.str().size() - option_info_size); - // print block vars - Doc block_var = PrintBlockVars(op); - // print predicate, binding, read/write tensor region, annotations - Doc block_attr_doc = PrintBlockAttr(op); - // print body - Doc body = PrintBlockBody(block_op); - doc << Doc::Indent(4, block_var << block_attr_doc << Doc::NewLine() << body); - for (const auto& iter_var : block_op->iter_vars) { - TryDeallocVar(iter_var->var); - } + doc << PrintUnderline(GetRef(block_op), doc.str().size() - optional_info_size); return doc; } @@ -1454,41 +1411,11 @@ Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) } Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { - Doc res; - res << "for " << Print(loop->loop_var) << " in " << tir_prefix_ - << "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", " - << Print(loop->min + loop->extent); - if (loop->thread_binding.defined()) { - res << ", thread="; - res << Print(loop->thread_binding.value()->thread_tag); - } - if (!loop->annotations.empty()) { - res << ", annotations={"; - res << PrintAnnotations(loop->annotations); - res << "}"; - } - res << "):"; + Doc res = TVMScriptPrinter::PrintLoop(loop); res << PrintUnderline(loop, res.str().size()); return res; } -Doc TVMScriptPrinterWithDiagnostic::PrintLoopStack() { - Doc res; - if (simple_loop_stack_.size() == 1) { - res << PrintLoop(simple_loop_stack_[0]); - } else if (simple_loop_stack_.size() > 1) { - std::vector vars, extents; - for (const auto& loop : simple_loop_stack_) { - vars.push_back(Print(loop->loop_var)); - extents.push_back(Print(loop->extent)); - } - res << "for " << PrintSep(vars, Doc::Text(", ")) << " in " << tir_prefix_ << ".grid(" - << PrintSep(extents, Doc::Text(", ")) << "):"; - } - simple_loop_stack_.clear(); - return res; -} - String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 88b61de1bd04..362318e63ec2 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -279,34 +279,5 @@ def test_reorder_fail_not_affine_bindings(): sch.reorder(l, i) -def test_reorder_fail_error_msg(): - sch = tir.Schedule(elementwise_not_affine, debug_mask="all") - block_b = sch.get_block("B") - i, j, k, l = sch.get_loops(block_b) - with pytest.raises(tvm.tir.ScheduleError) as execinfo: - sch.reorder(l, i) - expected_sub_error_message = ( - " # tir.Block#0\n" - ' with tir.block("B"):\n' - " ^^^^^^^^^^^^^^^^^^^^\n" - ) - assert expected_sub_error_message in str(execinfo.value) - - -def test_reorder_fail_nested_loop(): - sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") - block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) - with pytest.raises(tvm.tir.ScheduleError) as execinfo: - sch.reorder(k, i) - expected_sub_error_message = ( - " for i in tir.serial(0, 128):\n" - " # tir.For#0\n" - " for j in tir.serial(0, 128):\n" - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" - ) - assert expected_sub_error_message in str(execinfo.value) - - if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 80c37229f519..a1b6517b0c02 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,6 +18,7 @@ import pytest import sys import tvm +from tvm import tir from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -511,5 +512,75 @@ def render(e): # TODO(Siyuan): block iter errors. + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + "# tir.Block#0\n" ' with tir.block("B"):\n' " ^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in tir.serial(0, 128):\n" + " # tir.For#0\n" + " for j in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_outer(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.fuse(k, i) + expected_sub_error_message = ( + " # tir.For#1\n" + " for i in tir.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + " for j in tir.serial(0, 128):\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 92a114e381e2a327c4650e3c76c3959418dbcede Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 17:36:31 -0700 Subject: [PATCH 11/16] fix test --- tests/python/unittest/test_tvmscript_error_report.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index a1b6517b0c02..032bc7b77812 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -547,7 +547,9 @@ def test_reorder_fail_block(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - "# tir.Block#0\n" ' with tir.block("B"):\n' " ^^^^^^^^^^^^^^^^^^^^\n" + " # tir.Block#0\n" + ' with tir.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) From 3ff431ab23dfeff755a6ce236e7322a6000b059f Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 17:41:39 -0700 Subject: [PATCH 12/16] address comments --- src/printer/tvmscript_printer.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index c0b3caa2465f..554e5513b3e8 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1390,13 +1390,8 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { Doc doc = PrintOptionalInfo(GetRef(block_op)); - auto optional_info_size = doc.str().size(); - doc << "with " << tir_prefix_ << ".block("; - if (!block_op->name_hint.empty()) { - doc << Doc::StrLiteral(block_op->name_hint); - } - doc << "):"; - doc << PrintUnderline(GetRef(block_op), doc.str().size() - optional_info_size); + Doc block_name = TVMScriptPrinter::PrintBlockName(block_op); + doc << block_name << PrintUnderline(GetRef(block_op), block_name.str().size()); return doc; } From b3c232c7104abb72d78ed68506a797a77ea1b9c9 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 17:43:00 -0700 Subject: [PATCH 13/16] remove msg --- tests/python/unittest/test_tir_schedule_reorder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 362318e63ec2..8267a369cf5d 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -275,7 +275,7 @@ def test_reorder_fail_not_affine_bindings(): sch = tir.Schedule(elementwise_not_affine, debug_mask="all") block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) - with pytest.raises(tvm.tir.ScheduleError) as msg: + with pytest.raises(tvm.tir.ScheduleError): sch.reorder(l, i) From 05bfc315ac3876045ac9651e07359b8ee46ad012 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 21 Oct 2021 18:01:44 -0700 Subject: [PATCH 14/16] add override --- src/printer/tvmscript_printer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 554e5513b3e8..f22c5429b3ed 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1383,9 +1383,9 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} protected: - Doc PrintBlockName(const BlockNode* block_op); + Doc PrintBlockName(const BlockNode* block_op) override; Doc PrintUnderline(const Stmt& stmt, int length); - Doc PrintLoop(const For& loop); + Doc PrintLoop(const For& loop) override; }; Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { From 6388858a63ab4568ae4461d86768a27d35271a11 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 22 Oct 2021 11:19:25 -0700 Subject: [PATCH 15/16] address comments --- src/printer/tvmscript_printer.cc | 8 ++++---- src/tir/schedule/error.cc | 6 +++--- tests/python/unittest/test_tvmscript_error_report.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index f22c5429b3ed..d8fb8fa3b237 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1073,7 +1073,7 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { * \param block_op The block node to be printed */ Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { - Doc doc; + Doc doc = PrintOptionalInfo(GetRef(block_op)); doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); @@ -1389,9 +1389,9 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { }; Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { - Doc doc = PrintOptionalInfo(GetRef(block_op)); - Doc block_name = TVMScriptPrinter::PrintBlockName(block_op); - doc << block_name << PrintUnderline(GetRef(block_op), block_name.str().size()); + Doc doc = TVMScriptPrinter::PrintBlockName(block_op); + Doc optional_info = PrintOptionalInfo(GetRef(block_op)); + doc << PrintUnderline(GetRef(block_op), doc.str().size() - optional_info.str().size()); return doc; } diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 28b168889329..eb72773ffedb 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -45,9 +45,9 @@ String ScheduleError::RenderReport(const String& primitive) const { runtime::TypedPackedFunc annotate = runtime::TypedPackedFunc( [&loc_obj_to_name](const Stmt& expr) -> std::string { - auto search = loc_obj_to_name.find(Downcast(expr)); - if (search == loc_obj_to_name.end()) return ""; - return search->second; + auto it = loc_obj_to_name.find(Downcast(expr)); + if (it == loc_obj_to_name.end()) return ""; + return it->second; }); os << "ScheduleError: An error occurred in the schedule primitive '" << primitive diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 032bc7b77812..3098c86a7c2e 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -569,7 +569,7 @@ def test_reorder_fail_nested_loop_inner(): assert expected_sub_error_message in str(execinfo.value) -def test_reorder_fail_nested_loop_outer(): +def test_fuse_fail_nested_loop_outer(): sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) From f690a33275b23f9c86f36c917bed17a6b3de70ef Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 22 Oct 2021 11:34:01 -0700 Subject: [PATCH 16/16] address comments --- src/printer/tvmscript_printer.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d8fb8fa3b237..8ac745f675d9 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1073,7 +1073,7 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { * \param block_op The block node to be printed */ Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { - Doc doc = PrintOptionalInfo(GetRef(block_op)); + Doc doc; doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); @@ -1084,8 +1084,9 @@ Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); + Doc doc = PrintOptionalInfo(GetRef(block_op)); // print block name and block vars - Doc doc = PrintBlockName(block_op); + doc << PrintBlockName(block_op); Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1390,8 +1391,7 @@ class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { Doc doc = TVMScriptPrinter::PrintBlockName(block_op); - Doc optional_info = PrintOptionalInfo(GetRef(block_op)); - doc << PrintUnderline(GetRef(block_op), doc.str().size() - optional_info.str().size()); + doc << PrintUnderline(GetRef(block_op), doc.str().size()); return doc; }