Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,

String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false);

String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate);

} // namespace tir
} // namespace tvm

Expand Down
102 changes: 89 additions & 13 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
*/
TVM_DLL Doc Print(const ObjectRef& node);

private:
protected:
/*! \brief The tir prefix */
String tir_prefix_;
/*! \brief whether show meta data */
Expand Down Expand Up @@ -208,6 +208,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintBlockVars(const BlockRealizeNode* op);
Doc PrintBlockAttr(const BlockRealizeNode* op);
Doc PrintBlockBody(const BlockNode* op);
virtual Doc PrintBlockName(const BlockNode* block_op);
Doc PrintBufferRegion(const BufferRegionNode* op);
Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op);
Doc PrintAnnotations(const Map<String, ObjectRef>& annotations);
Expand All @@ -217,15 +218,24 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc AllocVar(const Var& var);
Doc AllocBuf(const Buffer& buffer);
void TryDeallocVar(const Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);

/*! Helper functions for loop printing. */
/*!
* \brief Print a single for loop
* \param loop The for loop to be printed
*/
Doc PrintLoop(const For& loop);
virtual Doc PrintLoop(const For& loop);
/*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
Doc PrintLoopStack();
/*!
* \brief Print all simple loops in stack into one line using tir_prefix_.grid().
* \param for_op the for node to be checked
*/
bool IsSimpleLoop(const ForNode* for_op) {
return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
}

/*!
* \brief Print additional info about expr in comment.
Expand All @@ -234,11 +244,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintOptionalInfo(const Stmt& stmt) {
Doc doc;
// default annotations
if (annotate_ != nullptr) {
if (ContainsOptionalInfo(stmt)) {
std::string annotated_stmt = annotate_(stmt);
if (!annotated_stmt.empty()) {
doc << "# " << annotated_stmt << Doc::NewLine();
}
doc << "# " << annotated_stmt << Doc::NewLine();
}
return doc;
}
Expand Down Expand Up @@ -391,6 +399,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
return val;
}

/*!
* \brief Check if any optional information exists in annotate_ for
* a given Stmt.
* \param stmt The statement.
*/
bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) {
if (annotate_ == nullptr) return false;
return !annotate_(stmt).empty();
}

/*!
* \brief Try to dealloc vars out of space and leave the index to coming vars.
* \note It is not a necessary step.
Expand Down Expand Up @@ -835,14 +853,14 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
var_not_in_headers_.insert(op->loop_var.get());
loop_var_map_[op->loop_var.get()] = GetRef<For>(op);
const auto* body = op->body.as<ForNode>();
bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min);
bool simple_loop = IsSimpleLoop(op);
if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
// It is a loop that can be compressed, let the loops below print it out
if (simple_loop && body != nullptr) {
Doc result = Print(GetRef<For>(body));
if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
doc << Print(GetRef<For>(body));
TryDeallocVar(op->loop_var);
loop_var_map_.erase(op->loop_var.get());
return result;
return doc;
}
// It is a loop that can not be compressed
bool print_above = !simple_loop_stack_.empty();
Expand Down Expand Up @@ -916,6 +934,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
return doc;
}

/*! Helper functions for block printing. */
Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) {
Doc doc;
doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis.";
Expand Down Expand Up @@ -1049,15 +1068,25 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
return body;
}

Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
const auto* block_op = op->block.as<BlockNode>();
// print block name and block vars
/*!
* \brief Print the name of a block
* \param block_op The block node to be printed
*/
Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) {
Doc doc;
doc << "with " << tir_prefix_ << ".block(";
if (!block_op->name_hint.empty()) {
doc << Doc::StrLiteral(block_op->name_hint);
}
doc << "):";
return doc;
}

Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
const auto* block_op = op->block.as<BlockNode>();
Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op));
// print block name and block vars
doc << PrintBlockName(block_op);
Doc block_var = PrintBlockVars(op);
// print predicate, binding, read/write tensor region, annotations
Doc block_attr_doc = PrintBlockAttr(op);
Expand Down Expand Up @@ -1343,12 +1372,59 @@ Doc TVMScriptPrinter::PrintLoopStack() {
return res;
}

/*!
* \brief The printer for TVMScript with diagnostic
* \details The printer obtain the precedence of the top-level operation when printing each
* subexpression to decide whether or not parentheses is needed.
*/
class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter {
public:
explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate)
: TVMScriptPrinter(tir_prefix, show_meta, annotate) {}

protected:
Doc PrintBlockName(const BlockNode* block_op) override;
Doc PrintUnderline(const Stmt& stmt, int length);
Doc PrintLoop(const For& loop) override;
};

Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) {
Doc doc = TVMScriptPrinter::PrintBlockName(block_op);
doc << PrintUnderline(GetRef<Stmt>(block_op), doc.str().size());
return doc;
}

Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) {
Doc doc;
// annotation
if (ContainsOptionalInfo(stmt)) {
String underline = std::string(length, '^');
doc << Doc::NewLine() << underline;
}
return doc;
}

Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
Doc res = TVMScriptPrinter::PrintLoop(loop);
res << PrintUnderline(loop, res.str().size());
return res;
}

String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n";
}

TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);

String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n";
}

TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);

} // namespace tir
} // namespace tvm
44 changes: 26 additions & 18 deletions src/tir/schedule/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,37 @@ namespace tir {
String ScheduleError::RenderReport(const String& primitive) const {
IRModule mod = this->mod();
std::ostringstream os;
os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
<< "'.\n\nThe IR is:\n"
<< AsTVMScript(mod);

// get locations of interest
Array<ObjectRef> locs = LocationsOfInterest();
std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> loc_obj_to_name;
int n_locs = locs.size();
std::vector<String> roi_names;
roi_names.reserve(n_locs);
if (n_locs > 0) {
os << "Regions of interest:\n";
for (const ObjectRef& obj : locs) {
String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size());
os << name << "\n" << obj;
roi_names.emplace_back(std::move(name));
}
os << "\n";
}
std::string msg = DetailRenderTemplate();
for (int i = 0; i < n_locs; ++i) {
std::string src = "{" + std::to_string(i) + "}";
for (size_t pos; (pos = msg.find(src)) != std::string::npos;) {
msg.replace(pos, src.length(), roi_names[i]);
if (n_locs > 0) {
for (int i = 0; i < n_locs; ++i) {
std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i);
std::string src = "{" + std::to_string(i) + "}";
for (size_t pos; (pos = msg.find(src)) != std::string::npos;) {
msg.replace(pos, src.length(), name);
}
loc_obj_to_name.emplace(locs[i], std::move(name));
}
}

// print IR module
runtime::TypedPackedFunc<std::string(Stmt)> annotate =
runtime::TypedPackedFunc<std::string(Stmt)>(
[&loc_obj_to_name](const Stmt& expr) -> std::string {
auto it = loc_obj_to_name.find(Downcast<ObjectRef>(expr));
if (it == loc_obj_to_name.end()) return "";
return it->second;
});

os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
<< "'.\n\nThe IR with diagnostic is:\n"
<< AsTVMScriptWithDiagnostic(mod, "tir", false, annotate);

// print error message
os << "Error message: " << msg;
return os.str();
}
Expand Down
73 changes: 73 additions & 0 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
import sys
import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.ir.diagnostics import override_renderer
import inspect
Expand Down Expand Up @@ -511,5 +512,77 @@ def render(e):

# TODO(Siyuan): block iter errors.


@T.prim_func
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 8):
with T.block("B"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
vl = T.axis.S(128, l * 16)
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0


@T.prim_func
def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128))
C = T.alloc_buffer((128, 128, 128))
B = T.match_buffer(b, (128, 128, 128))
for i, j in T.grid(128, 128):
for k in T.serial(0, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = A[vi, vj, vk] * 2.0
for k in T.serial(0, 128):
with T.block("B"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
B[vi, vj, vk] = C[vi, vj, vk] * 2.0


def test_reorder_fail_block():
sch = tir.Schedule(elementwise_not_affine, debug_mask="all")
block_b = sch.get_block("B")
i, j, k, l = sch.get_loops(block_b)
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(l, i)
expected_sub_error_message = (
" # tir.Block#0\n"
' with tir.block("B"):\n'
" ^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)


def test_reorder_fail_nested_loop_inner():
sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all")
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(k, i)
expected_sub_error_message = (
" for i in tir.serial(0, 128):\n"
" # tir.For#0\n"
" for j in tir.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)


def test_fuse_fail_nested_loop_outer():
sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all")
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.fuse(k, i)
expected_sub_error_message = (
" # tir.For#1\n"
" for i in tir.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in tir.serial(0, 128):\n"
)
assert expected_sub_error_message in str(execinfo.value)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))