diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index ea0ac0bc733d..9d89c641630b 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -85,16 +85,63 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { const auto* rhs = other.as(); - if (!DefEqual(op->loop_var, rhs->loop_var)) return false; - if (!VisitExpr(op->min, rhs->min)) return false; - if (!VisitExpr(op->extent, rhs->extent)) return false; - if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false; + if (!DefEqual(op->loop_var, rhs->loop_var)) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode loop vars do not match: op->loop_var=" << op->loop_var + << " vs rhs->loop_var=" << rhs->loop_var; + EmitError(os.str()); + } + return false; + } + if (!VisitExpr(op->min, rhs->min)) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode min values do not match: op->min=" << op->min << " vs rhs->min=" << rhs->min; + EmitError(os.str()); + } + return false; + } + if (!VisitExpr(op->extent, rhs->extent)) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode extent values do not match: op->extent=" << op->extent + << " vs rhs->extent=" << rhs->extent; + EmitError(os.str()); + } + return false; + } + if (op->thread_binding.defined() != rhs->thread_binding.defined()) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode thread_bindings do not match: op->thread_binding.defined()=" + << op->thread_binding.defined() + << " vs rhs->thread_binding.defined()=" << rhs->thread_binding.defined(); + EmitError(os.str()); + } + return false; + } if (op->thread_binding.defined() && !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) { return false; } - if (op->kind != rhs->kind) return false; - if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false; + if (op->kind != rhs->kind) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode kinds do not match: op->kind=" << op->kind << " vs rhs->kind=" << rhs->kind; + EmitError(os.str()); + } + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + if (assert_mode_) { + std::ostringstream os; + os << "ForNode annotation maps do not match: op->annotations=" << op->annotations + << " vs rhs->annotations=" << rhs->annotations; + EmitError(os.str()); + } + return false; + } return VisitStmt(op->body, rhs->body); } @@ -112,6 +159,12 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth const auto* rhs = other.as(); if (!is_scope_block) { if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { + if (assert_mode_) { + std::ostringstream os; + os << "BlockRealizeNode iter_values do not match: op->iter_values=" << op->iter_values + << " vs rhs->iter_values=" << rhs->iter_values; + EmitError(os.str()); + } return false; } } @@ -125,16 +178,40 @@ bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { // When checking iter vars, DefEqual is used to remap variables. if (!is_scope_block) { if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) { + if (assert_mode_) { + std::ostringstream os; + os << "BlockNode iter_vars do not match: op->alloc_buffers=" << op->iter_vars + << " vs rhs->alloc_buffers=" << rhs->iter_vars; + EmitError(os.str()); + } return false; } if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + if (assert_mode_) { + std::ostringstream os; + os << "BlockNode alloc_buffers do not match: op->alloc_buffers=" << op->alloc_buffers + << " vs rhs->alloc_buffers=" << rhs->alloc_buffers; + EmitError(os.str()); + } return false; } } if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + if (assert_mode_) { + std::ostringstream os; + os << "BlockNode write buffers do not match: op->writes=" << op->writes + << " vs rhs->writes=" << rhs->writes; + EmitError(os.str()); + } return false; } if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + if (assert_mode_) { + std::ostringstream os; + os << "BlockNode read buffers regions do not match: op->reads=" << op->reads + << " vs rhs->reads=" << rhs->reads; + EmitError(os.str()); + } return false; } is_scope_block = false; @@ -168,12 +245,30 @@ TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return op->value == rhs->value; + if (op->value != rhs->value) { + if (assert_mode_) { + std::ostringstream os; + os << "IntImmNode values do not match: op->value=" << op->value + << " vs rhs->value=" << rhs->value; + EmitError(os.str()); + } + return false; + } + return true; } bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - return op->value == rhs->value; + if (op->value != rhs->value) { + if (assert_mode_) { + std::ostringstream os; + os << "FloatImmNode values do not match: op->value=" << op->value + << " vs rhs->value=" << rhs->value; + EmitError(os.str()); + } + return false; + } + return true; } bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { @@ -185,7 +280,15 @@ bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { const auto* rhs = other.as(); auto lhs = GetRef(op); if (lhs.same_as(other)) return true; - if (op->dtype.code() != rhs->dtype.code()) return false; + if (op->dtype.code() != rhs->dtype.code()) { + if (assert_mode_) { + std::ostringstream os; + os << "VarNode data type codes do not match: op->dtype.code()=" << op->dtype.code() + << " vs rhs->dtype.code()=" << rhs->dtype.code(); + EmitError(os.str()); + } + return false; + } auto it = equal_map_.find(lhs); return it != equal_map_.end() && it->second.same_as(other); } @@ -216,14 +319,30 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, const std::pair& rhs) { - if (lhs.first != rhs.first) return false; + if (lhs.first != rhs.first) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareAnnotation key mismatch: lhs.first=" << lhs.first + << " vs rhs.first=" << rhs.first; + EmitError(os.str()); + } + return false; + } return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); } bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, const Map& rhs) { if (lhs.same_as(rhs)) return true; - if (lhs.size() != rhs.size()) return false; + if (lhs.size() != rhs.size()) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareAnnotationMap size mismatch: lhs.size()=" << lhs.size() + << " vs rhs.size()=" << rhs.size(); + EmitError(os.str()); + } + return false; + } auto sort_map = [](const Map& map) -> std::vector> { @@ -236,7 +355,14 @@ bool TensorizeComparator::CompareAnnotationMap(const Map& lhs std::vector> rhs_array = sort_map(rhs); for (size_t i = 0; i < lhs.size(); ++i) { - if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareAnnotationMap annotations mismatch within AnnotationMap."; + EmitError(os.str()); + } + return false; + } } return true; } @@ -253,6 +379,14 @@ bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); if (equal) { rhs_buffer_map_[rhs] = lhs; + } else { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBuffer buffer mismatch. data: " << lhs->data << " vs " << rhs->data + << ", dtypes: " << lhs->dtype << " vs " << rhs->dtype << ", scope(): " << lhs.scope() + << " vs " << rhs.scope(); + EmitError(os.str()); + } } } return equal; @@ -262,14 +396,24 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf if (!CompareBuffer(lhs->buffer, rhs->buffer)) { if (assert_mode_) { std::ostringstream os; - os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer; + os << "CompareBufferRegion returning false due to buffer mismatch: lhs->buffer=" + << lhs->buffer << " vs rhs->buffer=" << rhs->buffer; EmitError(os.str()); } return false; } int offset = static_cast(lhs->region.size()) - static_cast(rhs->region.size()); // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS - if (offset < 0) return false; + if (offset < 0) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion returning false because buffer region sizes do not match: " + "lhs->region.size()=" + << lhs->region.size() << " vs rhs->region.size()=" << rhs->region.size(); + EmitError(os.str()); + } + return false; + } auto it = buffer_indices_.find(lhs->buffer); if (it == buffer_indices_.end()) { @@ -279,7 +423,16 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf indices_base.reserve(lhs->region.size()); for (int i = 0; i < offset; i++) { // High-dim region must be element-wise - if (!is_one(lhs->region[i]->extent)) return false; + if (!is_one(lhs->region[i]->extent)) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion returning false because buffer extent high-dim region must be " + "element-wise. lhs->region[i]->extent=" + << lhs->region[i]->extent; + EmitError(os.str()); + } + return false; + } indices_base.emplace_back(lhs->region[i]->min); } for (size_t i = 0; i < rhs->region.size(); i++) { @@ -287,6 +440,12 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf indices_base.emplace_back(lhs->region[i + offset]->min); // check extent match if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=" + << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; + EmitError(os.str()); + } return false; } } @@ -296,16 +455,46 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf const std::vector& indices_base = it->second; for (int i = 0; i < offset; i++) { // High-dim region must be element-wise - if (!is_one(lhs->region[i]->extent)) return false; - if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false; + if (!is_one(lhs->region[i]->extent)) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion returning false because buffer extent high-dim region must be " + "element-wise. lhs->region[i]->extent=" + << lhs->region[i]->extent; + EmitError(os.str()); + } + return false; + } + if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer base index consistency check failed due to unequal index base: " + "indices_base[i]=" + << indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min; + EmitError(os.str()); + } + return false; + } } for (size_t i = 0; i < rhs->region.size(); i++) { // check extent match if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion buffer region extent mismatch. lhs->region[i + offset]=" + << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; + EmitError(os.str()); + } return false; } PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=" + << lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i]; + EmitError(os.str()); + } return false; } } @@ -318,7 +507,16 @@ template bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; int offset = static_cast(lhs->indices.size()) - static_cast(rhs->indices.size()); - if (offset < 0) return false; + if (offset < 0) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareBufferAccess returning false because buffer indices sizes do not match: " + "lhs->indices.size()=" + << lhs->indices.size() << " vs rhs->indices.size()=" << rhs->indices.size(); + EmitError(os.str()); + } + return false; + } auto it = buffer_indices_.find(lhs->buffer); ICHECK(it != buffer_indices_.end()); const std::vector& indices_base = (*it).second; @@ -328,7 +526,8 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { if (assert_mode_) { std::ostringstream os; - os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs " << rhs->indices[i]; + os << "CompareBufferAccess buffer indices mismatch. lhs->indices[i + offset]=" + << lhs->indices[i + offset] << " vs rhs->indices[i]=" << rhs->indices[i]; EmitError(os.str()); } return false; @@ -340,7 +539,15 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { template bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { if (lhs.same_as(rhs)) return true; - if (lhs.size() != rhs.size()) return false; + if (lhs.size() != rhs.size()) { + if (assert_mode_) { + std::ostringstream os; + os << "CompareArray array size mismatch. lhs.size()=" << lhs.size() + << " vs rhs.size()=" << rhs.size(); + EmitError(os.str()); + } + return false; + } for (size_t i = 0; i < lhs.size(); ++i) { if (!(static_cast(this)->*cmp)(lhs[i], rhs[i])) return false; }