From 44acec2814bcf157d4d95d026ef7c9749e00dfb6 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Tue, 2 Apr 2024 14:18:32 +0000
Subject: [PATCH] Perform cancellation in SimplifyingIrBuilder::addExpr

---
Review notes (placed below the three-dash marker, so they are not part of
the commit message):

* What the hunk does: when neither operand is constant, addExpr walks the
  definitions of lhs and rhs through Add, Sub and unary Neg, collecting
  leaf terms into a positive list and a negative list. Every negative term
  that compares sameAs() a not-yet-cancelled positive term is cancelled
  against it, and the surviving terms are re-summed as
  (sum of positives) + -(sum of negatives).
* If nothing cancels, the function still returns
  IrBuilder::addExpr(lhs, rhs), exactly as before this patch.
* If every term cancels, the function returns
  lhs->fusion()->zeroVal(promoteType(lhs->dtype(), rhs->dtype())).
* NOTE(review): the cancellation is purely algebraic. For floating-point
  terms a + b - b is not guaranteed to equal a (rounding, inf, NaN);
  presumably addExpr is only used on integer index math -- confirm.
* NOTE(review): the From: line in this copy has no e-mail address; restore
  it before feeding the file to git am.

 csrc/ir/builder.cpp | 125 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 124 insertions(+), 1 deletion(-)

diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp
index 3833ed054d4..5a84a0cd276 100644
--- a/csrc/ir/builder.cpp
+++ b/csrc/ir/builder.cpp
@@ -355,9 +355,132 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) {
     return addExpr(rhs, lhs->value());
   } else if (rhs->isConst()) {
     return addExpr(lhs, rhs->value(), rhs->dtype());
-  } else {
+  }
+  // Flatten nested additions. For example, suppose we have
+  //   lhs = a + ( b + c )
+  //   rhs = d + ( ( -b ) + e )
+  // Then we would flatten this to a sum of the terms
+  //   a b c d ( -b ) e
+  // Now suppose we have
+  //   lhs = a + b
+  //   rhs = d + ( -( b + e ) )
+  // We can flatten this to the sum of
+  //   a b d ( -b ) ( -e )
+  // We track a b d as positive terms and b e as negative terms, after which
+  // we can cancel b easily
+  std::vector<Val*> pos_terms;
+  std::vector<Val*> neg_terms;
+
+  // We'll do a stack-based search, using two stacks: one for positive and one
+  // for negative terms. Whenever the negative term stack is non-empty, we
+  // draw from it first.
+  std::vector<Val*> pos_term_stack{lhs, rhs};
+  std::vector<Val*> neg_term_stack;
+  while (!pos_term_stack.empty() || !neg_term_stack.empty()) {
+    Val* term = nullptr;
+    bool term_is_neg = false;
+    if (!neg_term_stack.empty()) {
+      term = neg_term_stack.back();
+      neg_term_stack.pop_back();
+      term_is_neg = true;
+    } else {
+      term = pos_term_stack.back();
+      pos_term_stack.pop_back();
+    }
+    if (auto* uop = dynamic_cast<UnaryOp*>(term->definition());
+        uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
+      Val* unneg = uop->in();
+      if (term_is_neg) {
+        // This is a term like -(-a) or (-b) within -(a + (-b) ), so we treat
+        // it as a positive term
+        pos_term_stack.push_back(unneg);
+      } else {
+        neg_term_stack.push_back(unneg);
+      }
+      continue;
+    } else if (auto* bop = dynamic_cast<BinaryOp*>(term->definition())) {
+      if (bop->getBinaryOpType() == BinaryOpType::Add) {
+        if (term_is_neg) {
+          neg_term_stack.push_back(bop->lhs());
+          neg_term_stack.push_back(bop->rhs());
+        } else {
+          pos_term_stack.push_back(bop->lhs());
+          pos_term_stack.push_back(bop->rhs());
+        }
+        continue;
+      } else if (bop->getBinaryOpType() == BinaryOpType::Sub) {
+        if (term_is_neg) {
+          neg_term_stack.push_back(bop->lhs());
+          pos_term_stack.push_back(bop->rhs());
+        } else {
+          pos_term_stack.push_back(bop->lhs());
+          neg_term_stack.push_back(bop->rhs());
+        }
+        continue;
+      }
+    }
+    if (term_is_neg) {
+      neg_terms.push_back(term);
+    } else {
+      pos_terms.push_back(term);
+    }
+  }
+  // Now do cancellation
+  std::vector<bool> cancelled_pos(pos_terms.size(), false);
+  std::vector<bool> cancelled_neg(neg_terms.size(), false);
+  bool performed_cancellation = false;
+  for (size_t j : c10::irange(neg_terms.size())) {
+    if (cancelled_neg[j]) {
+      continue;
+    }
+    for (size_t i : c10::irange(pos_terms.size())) {
+      if (cancelled_pos[i]) {
+        continue;
+      }
+      if (neg_terms[j]->sameAs(pos_terms[i])) {
+        // Mark this pair of terms as cancelled
+        cancelled_pos[i] = true;
+        cancelled_neg[j] = true;
+        performed_cancellation = true;
+        break;
+      }
+    }
+  }
+  // If we didn't perform any cancellations, then perform the simple sum of
+  // lhs and rhs
+  if (!performed_cancellation) {
     return IrBuilder::addExpr(lhs, rhs);
   }
+
+  // If we did cancellations, then unflatten to produce simplified sum. We
+  // unflatten all the positives then all the negatives, then we negate the
+  // unflattened negative sum and add it to the positive sum.
+  Val* pos_sum = nullptr;
+  for (size_t i : c10::irange(pos_terms.size())) {
+    if (cancelled_pos[i]) {
+      continue;
+    }
+    pos_sum = pos_sum == nullptr ? pos_terms[i]
+                                 : IrBuilder::addExpr(pos_sum, pos_terms[i]);
+  }
+  Val* neg_sum = nullptr;
+  for (size_t j : c10::irange(neg_terms.size())) {
+    if (cancelled_neg[j]) {
+      continue;
+    }
+    neg_sum = neg_sum == nullptr ? neg_terms[j]
+                                 : IrBuilder::addExpr(neg_sum, neg_terms[j]);
+  }
+  if (neg_sum == nullptr) {
+    return pos_sum == nullptr
+        ? lhs->fusion()->zeroVal(promoteType(lhs->dtype(), rhs->dtype()))
+        : pos_sum;
+  }
+  neg_sum = IrBuilder::negExpr(neg_sum);
+  if (pos_sum == nullptr) {
+    return neg_sum;
+  }
+  return IrBuilder::addExpr(pos_sum, neg_sum);
 }
 
 Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) {