Skip to content

s/reshape/set when no transforms are applied#4056

Open
wujingyue wants to merge 3 commits intomainfrom
wjy/empty
Open

s/reshape/set when no transforms are applied#4056
wujingyue wants to merge 3 commits intomainfrom
wjy/empty

Conversation

@wujingyue
Copy link
Collaborator

@wujingyue wujingyue commented Mar 10, 2025

This PR fixes #1691 in a different way and therefore reverts #1692.

When working on #3950, I realize this might be a better fix because it simplifies the replay contract. We don't need to deal with a close match where the concrete logical shape contains more reduction dimensions than the symbolic shape.

@wujingyue
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented Mar 10, 2025

Review updated until commit d438d83

Description

  • Simplify reshape logic when no transforms are applied

  • Replace reshape with set when input and output tensors are the same

  • Refactor reshape function for better readability


Changes walkthrough 📝

Relevant files
Enhancement
dynamic_transform.cpp
Simplify logical domain assignment                                             

csrc/dynamic_transform.cpp

  • Removed unnecessary TensorDomain::noReductions call
  • Simplified logical domain assignment
  • +1/-2     
    transform_view.cpp
    Refactor reshape function and use set                                       

    csrc/transform_view.cpp

  • Refactored reshape function for better readability
  • Used std::identity for checking axes
  • Added set call when input and output tensors are the same
  • +24/-19 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Code Clarity

    The new implementation of the reshape function introduces multiple conditional checks and operations. It would be beneficial to ensure that the logic is clear and that each step is necessary. Consider adding comments to explain the purpose of each conditional block.

    TensorView* out_tv = inp_tv;
    if (std::any_of(
            view_analysis.squeeze_axes.begin(),
            view_analysis.squeeze_axes.end(),
            std::identity())) {
      out_tv = squeeze(out_tv, view_analysis.squeeze_axes);
    }
    
    if (!view_analysis.transforms.empty()) {
      out_tv = applyViewTransforms(inp_tv, out_tv, view_analysis);
    }
    
    if (std::any_of(
            view_analysis.broadcast_axes.begin(),
            view_analysis.broadcast_axes.end(),
            std::identity())) {
      out_tv = broadcast(out_tv, view_analysis.broadcast_axes);
    }
    
    if (out_tv == inp_tv) {
      out_tv = set(inp_tv);
    Redundant Checks

    The check if (out_tv == inp_tv) might be redundant if the previous operations guarantee that out_tv will always be different from inp_tv when transformations are applied. Verify if this condition is necessary or if it can be simplified.

    if (out_tv == inp_tv) {
      out_tv = set(inp_tv);

    @wujingyue wujingyue requested a review from jacobhinkle March 11, 2025 03:21
    @wujingyue wujingyue marked this pull request as ready for review March 11, 2025 03:21
    @wujingyue
    Copy link
    Collaborator Author

    !test --diff

    Comment on lines +976 to +978
    if (out_tv == inp_tv) {
    out_tv = set(inp_tv);
    }
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Just so I understand fully: this part is the only functional change to this file right?

    auto old_logical = incomplete_out_tv->getLogicalDomain();
    auto new_logical =
    TensorDomain::noReductions(concrete_reshape_out_tv->getLogicalDomain());
    auto new_logical = concrete_reshape_out_tv->getLogicalDomain();
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Do we need to update the comment above here?

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    reshape fails with RuntimeError: old_rfactor.size() == new_rfactor.size() INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/dynamic_transform.cpp":652

    2 participants