Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 124 additions & 1 deletion csrc/ir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,132 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) {
return addExpr(rhs, lhs->value());
} else if (rhs->isConst()) {
return addExpr(lhs, rhs->value(), rhs->dtype());
} else {
}
// Flatten nested additions. For example, suppose we have
// lhs = a + ( b + c )
// rhs = d + ( ( -b ) + e )
// Then we would flatten this to a sum of the terms
// a b c d ( -b ) e
// Now suppose we have
// lhs = a + b
// rhs = d + ( -( b + e ) )
// We can flatten this to the sum of
// a b d ( -b ) ( -e )
// We track a b d as positive terms and b e as negative terms, after which
// we can cancel b easily
std::vector<Val*> pos_terms;
std::vector<Val*> neg_terms;

// We'll do a stack-based search, using two stacks: one for positive and one
// for negative terms. Whenever the negative term stack is non-empty, we
// draw from it first.
std::vector<Val*> pos_term_stack{lhs, rhs};
std::vector<Val*> neg_term_stack;
while (!pos_term_stack.empty() || !neg_term_stack.empty()) {
Val* term = nullptr;
bool term_is_neg = false;
if (!neg_term_stack.empty()) {
term = neg_term_stack.back();
neg_term_stack.pop_back();
term_is_neg = true;
} else {
term = pos_term_stack.back();
pos_term_stack.pop_back();
}
if (auto* uop = dynamic_cast<UnaryOp*>(term->definition());
uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
Val* unneg = uop->in();
if (term_is_neg) {
// This is a term like -(-a) or (-b) within -(a + (-b) ), so we treat
// it as a positive term
pos_term_stack.push_back(unneg);
} else {
neg_term_stack.push_back(unneg);
}
continue;
} else if (auto* bop = dynamic_cast<BinaryOp*>(term->definition())) {
if (bop->getBinaryOpType() == BinaryOpType::Add) {
if (term_is_neg) {
neg_term_stack.push_back(bop->lhs());
neg_term_stack.push_back(bop->rhs());
} else {
pos_term_stack.push_back(bop->lhs());
pos_term_stack.push_back(bop->rhs());
}
continue;
} else if (bop->getBinaryOpType() == BinaryOpType::Sub) {
if (term_is_neg) {
neg_term_stack.push_back(bop->lhs());
pos_term_stack.push_back(bop->rhs());
} else {
pos_term_stack.push_back(bop->lhs());
neg_term_stack.push_back(bop->rhs());
}
continue;
}
}
if (term_is_neg) {
neg_terms.push_back(term);
} else {
pos_terms.push_back(term);
}
}
// Now do cancellation
std::vector<bool> cancelled_pos(pos_terms.size(), false);
std::vector<bool> cancelled_neg(neg_terms.size(), false);
bool performed_cancellation = false;
for (size_t j : c10::irange(neg_terms.size())) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterate over negative terms first since they are likely more rare. If no negative terms were found, we'll quickly see that no cancellation is possible and return the simple sum.

if (cancelled_neg[j]) {
continue;
}
for (size_t i : c10::irange(pos_terms.size())) {
if (cancelled_pos[i]) {
continue;
}
if (neg_terms[j]->sameAs(pos_terms[i])) {
// Mark this pair of terms as cancelled
cancelled_pos[i] = true;
cancelled_neg[j] = true;
performed_cancellation = true;
break;
}
}
}
// If we didn't perform any cancellations, then perform the simple sum of
// lhs and rhs
if (!performed_cancellation) {
return IrBuilder::addExpr(lhs, rhs);
}

// If we did cancellations, then unflatten to produce simplified sum. We
// unflatten all the positives then all the negatives, then we negate the
// unflattened negative sum and add it to the positive sum.
Val* pos_sum = nullptr;
for (size_t i : c10::irange(pos_terms.size())) {
if (cancelled_pos[i]) {
continue;
}
pos_sum = pos_sum == nullptr ? pos_terms[i]
: IrBuilder::addExpr(pos_sum, pos_terms[i]);
}
Val* neg_sum = nullptr;
for (size_t j : c10::irange(neg_terms.size())) {
if (cancelled_neg[j]) {
continue;
}
neg_sum = neg_sum == nullptr ? neg_terms[j]
: IrBuilder::addExpr(neg_sum, neg_terms[j]);
}
if (neg_sum == nullptr) {
return pos_sum == nullptr
? lhs->fusion()->zeroVal(promoteType(lhs->dtype(), rhs->dtype()))
: pos_sum;
}
neg_sum = IrBuilder::negExpr(neg_sum);
if (pos_sum == nullptr) {
return neg_sum;
}
return IrBuilder::addExpr(pos_sum, neg_sum);
}

Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) {
Expand Down