Do some more simplifications specific to extents#3891

Merged
naoyam merged 8 commits into main from simplify_resize_extents
Feb 20, 2025

Conversation

@naoyam
Collaborator

@naoyam naoyam commented Feb 14, 2025

While working on #3848, I noticed that test_unpadded_catop_issue2275_repro2 took an extremely long time (> 30 min). That seems to be largely due to index hoisting and expression simplification: the test took just a few seconds when they were disabled. The likely cause is the large number of nested min and max expressions produced by slicing symbolic extents, as shown below:

T10_l_float[iS152{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}, iS153{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 32) )) ), 32) )) )}, iS154{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 4096) )) ), 4096) )) )}, iS155{( ( ( fmax(64, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 128) )) ) - 64 ) + ( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 64) )) ) )}]
   = __bfloat2float(T9_l___bfloat[iS148{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}, iS149{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 32) )) ), 32) )) )}, iS150{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 4096) )) ), 4096) )) )}, iS151{( ( ( fmax(64, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 128) )) ) - 64 ) + ( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 64) )) ) )}]);

This PR tries to simplify these extents a little further, which results in:

T10_l_float[?S54{( fmin(i0, 2) )}, ?S55{( fmin(i1, 32) )}, ?S56{( fmin(i2, 4096) )}, ?S57{( ( ( fmax(64, ( fmin(i3, 128) )) ) - 64 ) + ( fmin(i3, 64) ) )}]
      = __bfloat2float(T9_l___bfloat[?S50{( fmin(i0, 2) )}, ?S51{( fmin(i1, 32) )}, ?S52{( fmin(i2, 4096) )}, ?S53{( ( ( fmax(64, ( fmin(i3, 128) )) ) - 64 ) + ( fmin(i3, 64) ) )}]);

The test time is reduced to several seconds by these simplifications.
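
As a sanity check on the before and after extents shown above, here is a minimal brute-force sketch in plain C++. It is not nvFuser code, and the function names are made up for illustration. It assumes the symbolic extents i0 to i3 are non-negative integers, which is the property the simplification relies on.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Extent of the first three axes before the simplification:
// fmax(0, fmin(fmax(0, fmin(i, limit)), limit))
int64_t extentBefore(int64_t i, int64_t limit) {
  return std::max<int64_t>(
      0, std::min(std::max<int64_t>(0, std::min(i, limit)), limit));
}

// The same extent after the simplification: fmin(i, limit)
int64_t extentAfter(int64_t i, int64_t limit) {
  return std::min(i, limit);
}

// Extent of the last axis (iS155 before, ?S57 after).
int64_t lastBefore(int64_t i3) {
  const int64_t inner = std::max<int64_t>(0, std::min<int64_t>(i3, 128));
  return (std::max<int64_t>(64, std::min<int64_t>(inner, 128)) - 64) +
      std::max<int64_t>(0, std::min<int64_t>(inner, 64));
}

int64_t lastAfter(int64_t i3) {
  return (std::max<int64_t>(64, std::min<int64_t>(i3, 128)) - 64) +
      std::min<int64_t>(i3, 64);
}

// Returns true if both forms agree for every extent in [0, upper].
bool formsAgree(int64_t upper) {
  for (int64_t i = 0; i <= upper; ++i) {
    for (int64_t limit : {2, 32, 4096}) {
      if (extentBefore(i, limit) != extentAfter(i, limit)) {
        return false;
      }
    }
    if (lastBefore(i) != lastAfter(i)) {
      return false;
    }
  }
  return true;
}
```

Calling formsAgree(10000) covers all the constants that appear in the dump. The two forms differ for a negative input such as -3, which is why these rewrites are only valid for extents.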

Confirmed no failure with manual_ci.sh on an H100 machine.

@github-actions

github-actions bot commented Feb 14, 2025

Review updated until commit 885d52f

Description

  • Added simplifications for expressions involving extents in slice function.

  • Enhanced addExpr to handle specific simplifications for negation.

  • Updated test cases to reflect new simplifications.


Changes walkthrough 📝

Relevant files

Enhancement
  builder.cpp (csrc/ir/builder.cpp, +11/-0): Add negation simplifications in addExpr
    • Added simplification for (-x) + x and x + (-x) to 0 in addExpr.

  alias.cpp (csrc/ops/alias.cpp, +66/-12): Enhance extent simplifications in slice
    • Included expr_evaluator.h for expression evaluation.
    • Added get_int function to evaluate integer values.
    • Implemented min_extents function for specialized min simplifications.
    • Updated normalize_slice_range to use min_extents and get_int.

Tests
  test_resize.cpp (tests/cpp/test_resize.cpp, +2/-2): Update test for extent simplification
    • Updated test case SliceExtentSimplification to expect fmin instead of fmax.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Simplification Logic

    The added simplification logic for expressions like (-x) + x and x + (-x) to 0 should be validated to ensure it does not introduce any unintended side effects or incorrect simplifications.

    // Simplify (-x) + x to 0
    if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop != nullptr &&
        uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(rhs)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }
    // Simplify x + (-x) to 0
    if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop != nullptr &&
        uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(lhs)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }

    Extent Simplification

    The new extent simplification logic, especially the min_extents function, should be thoroughly tested to ensure it correctly handles all edge cases and does not introduce regressions.

    // Specialized min for extents. Do some more simplification beyond
    // SimplifyingIrBuilder that are only valid for extents.
    const auto min_extents = [&](Val* x, Val* y) -> Val* {
      auto x_int = get_int(x);
      auto y_int = get_int(y);
      // Since extents are never negative, if one is 0, that must be the minimum.
      if (x_int == 0) {
        return x;
      } else if (y_int == 0) {
        return y;
      }
      // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it
      // isn't uncommon.
      auto bop = dynamic_cast<BinaryOp*>(x->definition());
      if (y_int != std::nullopt && bop != nullptr &&
          bop->getBinaryOpType() == BinaryOpType::Min) {
        if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) {
          return SimplifyingIrBuilder::minExpr(
              bop->rhs(), IrBuilder::create<Val>(std::min(*lhs_int, *y_int)));
        } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) {
          return SimplifyingIrBuilder::minExpr(
              bop->lhs(), IrBuilder::create<Val>(std::min(*rhs_int, *y_int)));
        }
      }
    
      return SimplifyingIrBuilder::minExpr(x, y);
    };
    
    const auto normalize_slice_range =
        [&manual_normalization, &min_extents, &get_int](
            Slice range, Val* extent) -> Slice {
      auto cast_extent =
          SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent);
    
      auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index);
    
      auto start_int = get_int(range.start);
      auto stop_int = get_int(range.stop);
    
      // norm_start = max(0, start < 0 ? start + extent : start)
      if (range.start == nullptr) {
        range.start = zero;
        start_int = 0;
      } else if (start_int != 0) {
        range.start =
            SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start);
        if (!manual_normalization) {
          range.start = SimplifyingIrBuilder::maxExpr(
              zero,
              SimplifyingIrBuilder::whereExpr(
                  SimplifyingIrBuilder::ltExpr(range.start, zero),
                  SimplifyingIrBuilder::addExpr(range.start, cast_extent),
                  range.start));
        }
        start_int = get_int(range.start);
      }
    
      // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop))
      if (range.stop == nullptr) {
        range.stop = cast_extent;
      } else if (!range.stop->sameAs(extent)) {
        range.stop =
            SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop);
        // Commonly, range.start is zero and stop is non negative
        if (start_int == 0 && stop_int >= 0) {
          range.stop = min_extents(cast_extent, range.stop);
        } else {
          if (!manual_normalization) {
            range.stop = SimplifyingIrBuilder::maxExpr(
                range.start,
                min_extents(
                    cast_extent,
                    SimplifyingIrBuilder::whereExpr(
                        SimplifyingIrBuilder::ltExpr(range.stop, zero),
                        SimplifyingIrBuilder::addExpr(range.stop, cast_extent),
                        range.stop)));
          }
        }
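
The two normalization formulas in the comments above (norm_start and norm_stop) are easier to follow on concrete integers. The sketch below is a plain-integer rendition, assuming a known extent. IntSlice, normalizeSlice and fastPathStop are hypothetical names for this illustration and do not exist in nvFuser.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical plain-integer stand-in for a slice range.
struct IntSlice {
  int64_t start;
  int64_t stop;
};

// norm_start = max(0, start < 0 ? start + extent : start)
// norm_stop  = max(norm_start, min(extent, stop < 0 ? stop + extent : stop))
IntSlice normalizeSlice(IntSlice r, int64_t extent) {
  const int64_t start =
      std::max<int64_t>(0, r.start < 0 ? r.start + extent : r.start);
  const int64_t stop = std::max(
      start, std::min(extent, r.stop < 0 ? r.stop + extent : r.stop));
  return {start, stop};
}

// The common case singled out in the snippet: start is 0 and stop is
// non-negative, so the max and the where both drop out.
int64_t fastPathStop(int64_t stop, int64_t extent) {
  return std::min(extent, stop);
}
```

For an extent of 5, the range [-2, -1) normalizes to [3, 4), [0, 10) is clamped to [0, 5), and the inverted range [4, 2) collapses to the empty range [4, 4). When start is 0 and stop is non-negative, fastPathStop gives the same stop as the general formula, which is what lets the symbolic version skip the extra max and where nodes.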

    Test Expectation

    The test expectation for the simplified extent in SliceExtentSimplification should be verified to ensure it accurately reflects the expected behavior after the simplification changes.

      //   fmin(i0, 1)
    
      fusion.addOutput(tv1);
    
      auto resize_extent = tv1->axis(0)->extent();
      auto bop = dynamic_cast<BinaryOp*>(resize_extent->definition());
      ASSERT_TRUE(bop != nullptr)
          << "Unexpected resize output extent: " << resize_extent->toInlineString();
      EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min)
          << "Unexpected resize output extent: " << resize_extent->toInlineString();
    }
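
The expectation above (the resize output extent is a single Min node) follows from the folding done in min_extents. Below is a minimal sketch of that folding on a toy expression tree. Expr, constant, symbol, makeMin, getInt and minExtents are all hypothetical stand-ins for this illustration, and makeMin does none of the extra work that SimplifyingIrBuilder::minExpr may do.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy IR; not nvFuser's Val/BinaryOp classes.
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  enum class Kind { Const, Sym, Min };
  Kind kind = Kind::Const;
  int64_t value = 0;  // used by Const
  std::string name;   // used by Sym
  ExprPtr lhs, rhs;   // used by Min
};

ExprPtr constant(int64_t v) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::Kind::Const;
  e->value = v;
  return e;
}

ExprPtr symbol(std::string n) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::Kind::Sym;
  e->name = std::move(n);
  return e;
}

ExprPtr makeMin(ExprPtr a, ExprPtr b) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::Kind::Min;
  e->lhs = std::move(a);
  e->rhs = std::move(b);
  return e;
}

// Returns the value of a constant node, or nullopt for anything else.
std::optional<int64_t> getInt(const ExprPtr& e) {
  if (e != nullptr && e->kind == Expr::Kind::Const) {
    return e->value;
  }
  return std::nullopt;
}

// Min specialized for extents, which are assumed to be non-negative.
ExprPtr minExtents(const ExprPtr& x, const ExprPtr& y) {
  const auto x_int = getInt(x);
  const auto y_int = getInt(y);
  // A zero extent is always the minimum.
  if (x_int == 0) {
    return x;
  }
  if (y_int == 0) {
    return y;
  }
  // Fold min(min(a, c1), c2) into min(a, min(c1, c2)).
  if (y_int.has_value() && x->kind == Expr::Kind::Min) {
    if (auto l = getInt(x->lhs)) {
      return makeMin(x->rhs, constant(std::min(*l, *y_int)));
    }
    if (auto r = getInt(x->rhs)) {
      return makeMin(x->lhs, constant(std::min(*r, *y_int)));
    }
  }
  return makeMin(x, y);
}
```

With this sketch, slicing a symbolic extent i0 to 32 twice yields min(i0, 32) rather than min(min(i0, 32), 32), and a constant 0 short-circuits the whole expression.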

    @naoyam naoyam marked this pull request as ready for review February 14, 2025 06:42
    @naoyam
    Collaborator Author

    naoyam commented Feb 14, 2025

    !test

    @naoyam naoyam requested a review from jacobhinkle February 14, 2025 06:43
    @naoyam
    Collaborator Author

    naoyam commented Feb 16, 2025

    !test

    @jacobhinkle
    Collaborator

    Maybe it's time for me to revive #511. With runtime info we should be able to fully simplify these expressions at concretization.

    Collaborator

    @jacobhinkle jacobhinkle left a comment

    LGTM. Comments are minor.

    Comment on lines +383 to +397
    // Simplify x + (-x) to 0
    Val* x = nullptr;
    auto uop = dynamic_cast<UnaryOp*>(lhs->definition());
    if (uop != nullptr) {
      // lhs may be (-x). Pick rhs as x
      x = rhs;
    } else {
      uop = dynamic_cast<UnaryOp*>(rhs->definition());
      // rhs may be (-x). Pick lhs as x
      x = lhs;
    }
    if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg &&
        uop->in()->sameAs(x)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }
    Collaborator

    I think if you had abs(y) + (-abs(y)) this might not catch it because lhs=abs(y) is not a Neg. Instead, what about something like this?

    Suggested change
    - // Simplify x + (-x) to 0
    - Val* x = nullptr;
    - auto uop = dynamic_cast<UnaryOp*>(lhs->definition());
    - if (uop != nullptr) {
    -   // lhs may be (-x). Pick rhs as x
    -   x = rhs;
    - } else {
    -   uop = dynamic_cast<UnaryOp*>(rhs->definition());
    -   // rhs may be (-x). Pick lhs as x
    -   x = lhs;
    - }
    - if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg &&
    -     uop->in()->sameAs(x)) {
    -   return lhs->fusion()->zeroVal(lhs->dtype());
    - }
    + // Simplify x + (-x) to 0
    + Val* x = nullptr;
    + Val* neg_x_in = nullptr;
    + if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
    +   // lhs is -x. Pick rhs as x
    +   neg_x_in = uop->in();
    +   x = rhs;
    + } else if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
    +   // rhs is -x. Pick lhs as x
    +   neg_x_in = uop->in();
    +   x = lhs;
    + }
    + if (x != nullptr && neg_x_in->sameAs(x)) {
    +   return lhs->fusion()->zeroVal(lhs->dtype());
    + }

    Collaborator Author

    That's right, thanks. My initial version was actually exactly like your suggested code, but clang-tidy didn't like it because of the repetition of the condition expressions. Let me try rewriting the code again.
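
The case raised in this thread, abs(y) + (-abs(y)), can be reproduced on a toy IR. The sketch below is not nvFuser code: Node, Op, sym, unary, sameAs and both cancels functions are hypothetical names. cancelsFirstVersion mirrors the control flow of the snippet under review, and cancelsBothSides mirrors the approach of checking each operand for Neg independently.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR; not nvFuser's Val/UnaryOp classes.
struct Node;
using NodePtr = std::shared_ptr<Node>;

enum class Op { Sym, Neg, Abs };

struct Node {
  Op op;
  std::string name;  // used by Sym
  NodePtr in;        // operand of Neg / Abs
};

NodePtr sym(std::string n) {
  return std::make_shared<Node>(Node{Op::Sym, std::move(n), nullptr});
}

NodePtr unary(Op op, NodePtr in) {
  return std::make_shared<Node>(Node{op, "", std::move(in)});
}

// Structural equality, standing in for Val::sameAs.
bool sameAs(const NodePtr& a, const NodePtr& b) {
  if (a == nullptr || b == nullptr) {
    return a == b;
  }
  return a->op == b->op && a->name == b->name && sameAs(a->in, b->in);
}

bool isUnary(const NodePtr& n) {
  return n->op == Op::Neg || n->op == Op::Abs;
}

// First version: commits to lhs as soon as it is any unary op, so rhs is
// never inspected when lhs is, for example, abs(y).
bool cancelsFirstVersion(const NodePtr& lhs, const NodePtr& rhs) {
  const NodePtr uop = isUnary(lhs) ? lhs : (isUnary(rhs) ? rhs : nullptr);
  const NodePtr x = isUnary(lhs) ? rhs : lhs;
  return uop != nullptr && uop->op == Op::Neg && sameAs(uop->in, x);
}

// Both-sides version: test each operand for Neg on its own.
bool cancelsBothSides(const NodePtr& lhs, const NodePtr& rhs) {
  if (lhs->op == Op::Neg && sameAs(lhs->in, rhs)) {
    return true;
  }
  return rhs->op == Op::Neg && sameAs(rhs->in, lhs);
}
```

On abs(y) + (-abs(y)), cancelsFirstVersion returns false because lhs is a unary op that is not a Neg, while cancelsBothSides returns true. Both agree on the simpler y + (-y) and (-y) + y cases.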

    const auto normalize_slice_range = [&manual_normalization](
                                           Slice range, Val* extent) -> Slice {
      const auto get_int = [](Val* x) -> std::optional<int64_t> {
        if (x != nullptr && x->isConstInt()) {
    Collaborator

    Minor nit: if you initialize an ExpressionEvaluator in this function, I think you could use expr_eval.evaluate(x) to get a PolymorphicValue and check pv.hasValue() instead of using optional<int64_t>. Then you could replace the isConstInt() condition here, which redundantly evaluates the val, with isIntegralScalar().
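
The cost argument in this comment (check for constness, then evaluate again) can be illustrated with a toy expression tree that counts node visits. This is only a sketch of the general idea. It does not use nvFuser's ExpressionEvaluator or PolymorphicValue, and Node, checkThenEvaluate and tryEvaluate are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical toy expression: a constant leaf, a symbolic leaf (no value),
// or an addition of two children.
struct Node {
  std::optional<int64_t> leaf;  // set only for constant leaves
  const Node* lhs = nullptr;    // both set for an addition node
  const Node* rhs = nullptr;
};

Node constantLeaf(int64_t v) {
  Node n;
  n.leaf = v;
  return n;
}

Node add(const Node& a, const Node& b) {
  Node n;
  n.lhs = &a;
  n.rhs = &b;
  return n;
}

// Pass 1 of the two-step approach: walk the tree to see if it is constant.
bool isConst(const Node& n, int& visits) {
  ++visits;
  if (n.lhs != nullptr && n.rhs != nullptr) {
    return isConst(*n.lhs, visits) && isConst(*n.rhs, visits);
  }
  return n.leaf.has_value();
}

// Pass 2: walk the tree again to compute the value (requires isConst).
int64_t evaluate(const Node& n, int& visits) {
  ++visits;
  if (n.lhs != nullptr && n.rhs != nullptr) {
    return evaluate(*n.lhs, visits) + evaluate(*n.rhs, visits);
  }
  return *n.leaf;
}

std::optional<int64_t> checkThenEvaluate(const Node& n, int& visits) {
  if (!isConst(n, visits)) {
    return std::nullopt;
  }
  return evaluate(n, visits);
}

// Single pass: evaluate and report "no value" in the same traversal.
std::optional<int64_t> tryEvaluate(const Node& n, int& visits) {
  ++visits;
  if (n.lhs != nullptr && n.rhs != nullptr) {
    const auto l = tryEvaluate(*n.lhs, visits);
    if (!l.has_value()) {
      return std::nullopt;
    }
    const auto r = tryEvaluate(*n.rhs, visits);
    if (!r.has_value()) {
      return std::nullopt;
    }
    return *l + *r;
  }
  return n.leaf;
}
```

For a constant tree 2 + 3, the two-step version visits each of the three nodes twice, while the single-pass version visits each once. Whether the saving matters in the real code depends on how isConstInt() and evaluate() are implemented, which this sketch does not model.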

    Comment on lines +812 to +814
    if (range.start->isConstInt()) {
      start_int = range.start->evaluate().as<int64_t>();
    }
    Collaborator

    See comment above if we have an expr_eval. Not important but will avoid re-computing these vals.

    @naoyam
    Collaborator Author

    naoyam commented Feb 20, 2025

    !test

    @naoyam
    Collaborator Author

    naoyam commented Feb 20, 2025

    !test

    @naoyam naoyam merged commit 93fdc32 into main Feb 20, 2025
    54 checks passed
    @naoyam naoyam deleted the simplify_resize_extents branch February 20, 2025 06:17

    2 participants