
Use proper extent expression for dynamic reshape extents#4479

Open
jacobhinkle wants to merge 8 commits into main from jh/dynamic_reshape_neg_one_extent

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented May 19, 2025

Previously, dynamic reshapes used the given Vals directly as the output extents. However, these are not necessarily the actual extents, since as of #590 it is valid to pass -1 as a dynamic reshape size. This PR replaces a given Val like i2 with an expression that is correct no matter what the actual values wind up being. For example:

tv1 = reshape(tv0[i0, i1], new_sizes=[i2, i3, i4])
tv1[
    (i2 == -1) ? (i0 * i1) / (i3 * i4 == 0 ? 1 : (i3 * i4)) : i2,
    (i3 == -1) ? (i0 * i1) / (i2 * i4 == 0 ? 1 : (i2 * i4)) : i3,
    (i4 == -1) ? (i0 * i1) / (i2 * i3 == 0 ? 1 : (i2 * i3)) : i4,
]

This type of expression is admittedly verbose, but it handles the cases where any of the dynamic scalars is -1 or any of the input dimensions is 0, avoiding a division-by-zero error.

During concretization, these extents should be replaced by the usual multiply/ceilDiv expressions coming from the particular reshape transforms.

Fixes #4476

Related to Issue #249 and PR #590

@jacobhinkle
Collaborator Author

!test --diff


github-actions bot commented May 19, 2025

Review updated until commit 88166f8

Description

  • Use proper extent expression for dynamic reshape extents

  • Compute numel only when necessary

  • Fix division by zero and other related issues

  • Add tests for dynamic reshape with -1 and other edge cases


Changes walkthrough 📝

Relevant files

Enhancement

csrc/ops/alias.cpp — Improve dynamic reshape extent handling (+42/-14)

  • Added a lambda function neg_one_size to compute numel only when necessary
  • Updated reshape logic to use neg_one_size for -1 extents
  • Added explicit casting of new sizes to Index

Tests

tests/cpp/test_dynamic_transform.cpp — Update dynamic reshape test cases (+17/-20)

  • Updated test cases to handle dynamic reshape with -1 extents
  • Fixed expected behavior for reshape with -1 extents
  • Added checks for reshape transforms

tests/cpp/test_evaluator.cpp — Add test for reshape with -1 input (+24/-0)

  • Added a new test case for evaluating reshape with -1 as input

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The computation of numel and other_new_numel is done inside a lambda that could be called multiple times. Ensure that this computation is efficient and not redundant.

    const auto neg_one_size = [&numel, &inp_dom, &new_sizes](size_t pos) {
      if (numel == nullptr) {
        numel = FusionGuard::getCurFusion()->oneVal();
        for (const auto j : arange(inp_dom.size())) {
          numel = SimplifyingIrBuilder::mulExpr(numel, inp_dom.at(j)->extent());
        }
      }
    
      Val* other_new_numel = FusionGuard::getCurFusion()->oneVal();
      for (const auto j : arange(new_sizes.size())) {
        if (pos == j) {
          continue;
        }
        Val* new_size =
            SimplifyingIrBuilder::maybeCastExpr(DataType::Index, new_sizes.at(j));
        other_new_numel =
            SimplifyingIrBuilder::mulExpr(other_new_numel, new_size);
      }
      // In case numel is 0, other_new_numel might also be 0 and we would hit a
      // division by zero. In such cases, using 1 as the denominator will cause
      // us to properly compute 0 for this extent.
      other_new_numel = SimplifyingIrBuilder::whereExpr(
          eq(other_new_numel, FusionGuard::getCurFusion()->zeroVal()),
          FusionGuard::getCurFusion()->oneVal(),
          other_new_numel);
    
      Val* new_size = SimplifyingIrBuilder::divExpr(numel, other_new_numel);
      NVF_ERROR(new_size->dtype() == DataType::Index);
Test Case

The test case DynamicTransform1 now includes bindings for the tv1 extents. Verify that these bindings are correct and necessary for the test.

// input: 4, 3
// output: 3, -1
expr_eval.bind(tv0->axis(0)->extent(), 4L);
expr_eval.bind(tv0->axis(1)->extent(), 3L);
expr_eval.bind(reshape_shape0, 3L);
expr_eval.bind(reshape_shape1, -1L);
// We cannot infer the shape of tv1 from the above bound values, since
// either axis of tv2 might be broadcast against one from tv1.
expr_eval.bind(tv1->axis(0)->extent(), 3L);
expr_eval.bind(tv1->axis(1)->extent(), 4L);
Test Case

The test case Issue249InputNegative1 now tests a dynamic -1 as input. Ensure that this test case correctly validates the behavior of dynamic reshapes.

FusionExecutorCache executor_cache(std::move(fusion_ptr));

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor at_x = at::randn({2, 3, 4, 5}, options);

// Test that running with dynamic -1 works as expected
executor_cache.runFusionWithInputs({at_x, 2, 4, -1});
    

@jacobhinkle
Collaborator Author

Looks like zero extents also need special handling. This could mean the extent expressions get complicated.

@jacobhinkle jacobhinkle marked this pull request as ready for review May 22, 2025 17:10

@jacobhinkle
Collaborator Author

!test

@jacobhinkle
Collaborator Author

!test --diff

@wujingyue
Collaborator

(i3 * i4 == 0 ? 1 : (i3 * i4))

(Trust but verify)

Is this necessary? Can we instead error out when i3 * i4 == 0? That's what torch does:

>>> torch.randn(0, 4).reshape(-1, 0, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0, 2] because the unspecified dimension size -1 can be any value and is ambiguous
    

@jacobhinkle
Collaborator Author

(i3 * i4 == 0 ? 1 : (i3 * i4))

(Trust but verify)

Is this necessary? Can we instead error out when i3 * i4 == 0? That's what torch does:

>>> torch.randn(0, 4).reshape(-1, 0, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0, 2] because the unspecified dimension size -1 can be any value and is ambiguous

Ah, thanks for pointing out the torch behavior. Let me try to guard against that properly. It will simplify these expressions considerably.
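The simplification discussed above can be sketched in isolation. This is a hypothetical illustration of the agreed-upon behavior, not the PR's actual code: `inferNegOneExtent` is an invented helper operating on concrete integers, and it rejects the ambiguous zero-denominator case up front instead of folding a guard into the extent expression, mirroring torch's error.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Infer the extent of a dimension requested as -1, given the total input
// element count and the other requested sizes. If the product of the other
// sizes is zero, the -1 dimension could be any value, so we error out
// (matching the torch behavior quoted above) rather than guard the division.
int64_t inferNegOneExtent(
    int64_t numel,
    const std::vector<int64_t>& other_new_sizes) {
  int64_t other = 1;
  for (int64_t s : other_new_sizes) {
    other *= s;
  }
  if (other == 0) {
    throw std::runtime_error(
        "cannot infer -1 extent: the product of the remaining sizes is "
        "zero, so the unspecified dimension is ambiguous");
  }
  return numel / other;
}
```

With this check, the ternary zero-guards in the extent expression become unnecessary, since the denominator is known to be nonzero whenever the division is reached.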



Development

Successfully merging this pull request may close these issues.

ViewOp::evaluate with dynamic input

2 participants