Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 227 additions & 20 deletions src/tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,63 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {

bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
const auto* rhs = other.as<ForNode>();
if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
if (!VisitExpr(op->min, rhs->min)) return false;
if (!VisitExpr(op->extent, rhs->extent)) return false;
if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false;
if (!DefEqual(op->loop_var, rhs->loop_var)) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode loop vars do not match: op->loop_var=" << op->loop_var
<< " vs rhs->loop_var=" << rhs->loop_var;
EmitError(os.str());
}
return false;
}
if (!VisitExpr(op->min, rhs->min)) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode min values do not match: op->min=" << op->min << " vs rhs->min=" << rhs->min;
EmitError(os.str());
}
return false;
}
if (!VisitExpr(op->extent, rhs->extent)) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode extent values do not match: op->extent=" << op->extent
<< " vs rhs->extent=" << rhs->extent;
EmitError(os.str());
}
return false;
}
if (op->thread_binding.defined() != rhs->thread_binding.defined()) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode thread_bindings do not match: op->thread_binding.defined()="
<< op->thread_binding.defined()
<< " vs rhs->thread_binding.defined()=" << rhs->thread_binding.defined();
EmitError(os.str());
}
return false;
}
if (op->thread_binding.defined() &&
!VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) {
return false;
}
if (op->kind != rhs->kind) return false;
if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false;
if (op->kind != rhs->kind) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode kinds do not match: op->kind=" << op->kind << " vs rhs->kind=" << rhs->kind;
EmitError(os.str());
}
return false;
}
if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
if (assert_mode_) {
std::ostringstream os;
os << "ForNode annotation maps do not match: op->annotations=" << op->annotations
<< " vs rhs->annotations=" << rhs->annotations;
EmitError(os.str());
}
return false;
}
return VisitStmt(op->body, rhs->body);
}

Expand All @@ -112,6 +159,12 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth
const auto* rhs = other.as<BlockRealizeNode>();
if (!is_scope_block) {
if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) {
if (assert_mode_) {
std::ostringstream os;
os << "BlockRealizeNode iter_values do not match: op->iter_values=" << op->iter_values
<< " vs rhs->iter_values=" << rhs->iter_values;
EmitError(os.str());
}
return false;
}
}
Expand All @@ -125,16 +178,40 @@ bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
// When checking iter vars, DefEqual is used to remap variables.
if (!is_scope_block) {
if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) {
if (assert_mode_) {
std::ostringstream os;
os << "BlockNode iter_vars do not match: op->alloc_buffers=" << op->iter_vars
<< " vs rhs->alloc_buffers=" << rhs->iter_vars;
EmitError(os.str());
}
return false;
}
if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) {
if (assert_mode_) {
std::ostringstream os;
os << "BlockNode alloc_buffers do not match: op->alloc_buffers=" << op->alloc_buffers
<< " vs rhs->alloc_buffers=" << rhs->alloc_buffers;
EmitError(os.str());
}
return false;
}
}
if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) {
if (assert_mode_) {
std::ostringstream os;
os << "BlockNode write buffers do not match: op->writes=" << op->writes
<< " vs rhs->writes=" << rhs->writes;
EmitError(os.str());
}
return false;
}
if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) {
if (assert_mode_) {
std::ostringstream os;
os << "BlockNode read buffers regions do not match: op->reads=" << op->reads
<< " vs rhs->reads=" << rhs->reads;
EmitError(os.str());
}
return false;
}
is_scope_block = false;
Expand Down Expand Up @@ -168,12 +245,30 @@ TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode);

bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
const auto* rhs = other.as<IntImmNode>();
return op->value == rhs->value;
if (op->value != rhs->value) {
if (assert_mode_) {
std::ostringstream os;
os << "IntImmNode values do not match: op->value=" << op->value
<< " vs rhs->value=" << rhs->value;
EmitError(os.str());
}
return false;
}
return true;
}

bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
const auto* rhs = other.as<FloatImmNode>();
return op->value == rhs->value;
if (op->value != rhs->value) {
if (assert_mode_) {
std::ostringstream os;
os << "FloatImmNode values do not match: op->value=" << op->value
<< " vs rhs->value=" << rhs->value;
EmitError(os.str());
}
return false;
}
return true;
}

bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) {
Expand All @@ -185,7 +280,15 @@ bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) {
const auto* rhs = other.as<VarNode>();
auto lhs = GetRef<Var>(op);
if (lhs.same_as(other)) return true;
if (op->dtype.code() != rhs->dtype.code()) return false;
if (op->dtype.code() != rhs->dtype.code()) {
if (assert_mode_) {
std::ostringstream os;
os << "VarNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
<< " vs rhs->dtype.code()=" << rhs->dtype.code();
EmitError(os.str());
}
return false;
}
auto it = equal_map_.find(lhs);
return it != equal_map_.end() && it->second.same_as(other);
}
Expand Down Expand Up @@ -216,14 +319,30 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {

bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
const std::pair<String, ObjectRef>& rhs) {
if (lhs.first != rhs.first) return false;
if (lhs.first != rhs.first) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareAnnotation key mismatch: lhs.first=" << lhs.first
<< " vs rhs.first=" << rhs.first;
EmitError(os.str());
}
return false;
}
return VisitExpr(Downcast<PrimExpr>(lhs.second), Downcast<PrimExpr>(rhs.second));
}

bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>& lhs,
const Map<String, ObjectRef>& rhs) {
if (lhs.same_as(rhs)) return true;
if (lhs.size() != rhs.size()) return false;
if (lhs.size() != rhs.size()) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareAnnotationMap size mismatch: lhs.size()=" << lhs.size()
<< " vs rhs.size()=" << rhs.size();
EmitError(os.str());
}
return false;
}

auto sort_map =
[](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String, ObjectRef>> {
Expand All @@ -236,7 +355,14 @@ bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>& lhs
std::vector<std::pair<String, ObjectRef>> rhs_array = sort_map(rhs);

for (size_t i = 0; i < lhs.size(); ++i) {
if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false;
if (!CompareAnnotation(lhs_array[i], rhs_array[i])) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareAnnotationMap annotations mismatch within AnnotationMap.";
EmitError(os.str());
}
return false;
}
}
return true;
}
Expand All @@ -253,6 +379,14 @@ bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope();
if (equal) {
rhs_buffer_map_[rhs] = lhs;
} else {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBuffer buffer mismatch. data: " << lhs->data << " vs " << rhs->data
<< ", dtypes: " << lhs->dtype << " vs " << rhs->dtype << ", scope(): " << lhs.scope()
<< " vs " << rhs.scope();
EmitError(os.str());
}
}
}
return equal;
Expand All @@ -262,14 +396,24 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
if (!CompareBuffer(lhs->buffer, rhs->buffer)) {
if (assert_mode_) {
std::ostringstream os;
os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer;
os << "CompareBufferRegion returning false due to buffer mismatch: lhs->buffer="
<< lhs->buffer << " vs rhs->buffer=" << rhs->buffer;
EmitError(os.str());
}
return false;
}
int offset = static_cast<int>(lhs->region.size()) - static_cast<int>(rhs->region.size());
// Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS
if (offset < 0) return false;
if (offset < 0) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion returning false because buffer region sizes do not match: "
"lhs->region.size()="
<< lhs->region.size() << " vs rhs->region.size()=" << rhs->region.size();
EmitError(os.str());
}
return false;
}

auto it = buffer_indices_.find(lhs->buffer);
if (it == buffer_indices_.end()) {
Expand All @@ -279,14 +423,29 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
indices_base.reserve(lhs->region.size());
for (int i = 0; i < offset; i++) {
// High-dim region must be element-wise
if (!is_one(lhs->region[i]->extent)) return false;
if (!is_one(lhs->region[i]->extent)) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion returning false because buffer extent high-dim region must be "
"element-wise. lhs->region[i]->extent="
<< lhs->region[i]->extent;
EmitError(os.str());
}
return false;
}
indices_base.emplace_back(lhs->region[i]->min);
}
for (size_t i = 0; i < rhs->region.size(); i++) {
// save base index
indices_base.emplace_back(lhs->region[i + offset]->min);
// check extent match
if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]="
<< lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
EmitError(os.str());
}
return false;
}
}
Expand All @@ -296,16 +455,46 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
const std::vector<PrimExpr>& indices_base = it->second;
for (int i = 0; i < offset; i++) {
// High-dim region must be element-wise
if (!is_one(lhs->region[i]->extent)) return false;
if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false;
if (!is_one(lhs->region[i]->extent)) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion returning false because buffer extent high-dim region must be "
"element-wise. lhs->region[i]->extent="
<< lhs->region[i]->extent;
EmitError(os.str());
}
return false;
}
if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
if (assert_mode_) {
std::ostringstream os;
os << "Buffer base index consistency check failed due to unequal index base: "
"indices_base[i]="
<< indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min;
EmitError(os.str());
}
return false;
}
}
for (size_t i = 0; i < rhs->region.size(); i++) {
// check extent match
if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion buffer region extent mismatch. lhs->region[i + offset]="
<< lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
EmitError(os.str());
}
return false;
}
PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]);
if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]="
<< lhs->region[i + offset] << " vs rhs->region[i]=" << rhs->region[i];
EmitError(os.str());
}
return false;
}
}
Expand All @@ -318,7 +507,16 @@ template <typename T>
bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
int offset = static_cast<int>(lhs->indices.size()) - static_cast<int>(rhs->indices.size());
if (offset < 0) return false;
if (offset < 0) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareBufferAccess returning false because buffer indices sizes do not match: "
"lhs->indices.size()="
<< lhs->indices.size() << " vs rhs->indices.size()=" << rhs->indices.size();
EmitError(os.str());
}
return false;
}
auto it = buffer_indices_.find(lhs->buffer);
ICHECK(it != buffer_indices_.end());
const std::vector<PrimExpr>& indices_base = (*it).second;
Expand All @@ -328,7 +526,8 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) {
if (assert_mode_) {
std::ostringstream os;
os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs " << rhs->indices[i];
os << "CompareBufferAccess buffer indices mismatch. lhs->indices[i + offset]="
<< lhs->indices[i + offset] << " vs rhs->indices[i]=" << rhs->indices[i];
EmitError(os.str());
}
return false;
Expand All @@ -340,7 +539,15 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
template <typename T, typename Self, typename F>
bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp) {
if (lhs.same_as(rhs)) return true;
if (lhs.size() != rhs.size()) return false;
if (lhs.size() != rhs.size()) {
if (assert_mode_) {
std::ostringstream os;
os << "CompareArray array size mismatch. lhs.size()=" << lhs.size()
<< " vs rhs.size()=" << rhs.size();
EmitError(os.str());
}
return false;
}
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(static_cast<Self*>(this)->*cmp)(lhs[i], rhs[i])) return false;
}
Expand Down