
Patch allocation logic to produce outputs with correct logical size #5170

Closed
jjsjann123 wants to merge 15 commits into main from jj/layout_op_PR3_allocation_patch

Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Sep 17, 2025

Motivation:

We want to use the allocation domain to represent padding logic. This is/will be used by the swizzle layout for block scaling factors in quantization, as well as by the more complex padding used in grouped matmul.

For a simple example, see the added cpp test LogicalAndAllocationSizes.

Context

There are two cases where padding is represented at allocation:

Case 1: non-divisible split:

For a given 2d TensorView with logical domain [i0, i1], suppose we split i1 to construct its allocation domain as [i0, ceilDiv(i1, 16), 16].
The implication is that we'll be allocating "extra" space so that the inner dimension is padded to an extent that is a multiple of 16. This makes the strides of the TensorView [ceilDiv(i1, 16) * 16, 1].

However, the given TensorView still has a logical domain of [i0, i1], which means we should produce a tensor of shape [i0, i1] to avoid feeding the consumer a tensor with a mismatched shape.

Case 2: unconnected allocation & logical domain:

For PreprocessGroupedMatmulInputSf, the allocation domain and the logical domain are not connected via any ID operations; instead, we rely on arithmetic operations directly on their extents to compute the proper buffer size. So we cannot project the allocation domain to the logical domain to compute its sizes/strides. In this case, we simply slice a section of the allocated buffer with trivial strides (contiguous) to avoid causing any issue with shape validation. It is safe to do so, since the consumer of this TensorView will not index it through its strides.

What's in this PR:

Fixes in allocation.cpp:

  1. The existing allocation logic, which projects the allocation domain to the logical domain, doesn't remove the extra extent added by ceilDiv, leading to wrong logical sizes. E.g. with the above example, we'd produce a logical size of [i0, ceilDiv(i1, 16) * 16] instead of [i0, i1].
  2. allocateOutputs needs to allocate the buffer using allocation sizes and strides (so we get enough buffer space) and then slice out the correct extent.
  3. A minor bug fix on stride computation.

Changes in vectorize_helper.cpp and csrc/multidevice/utils.cpp[.h]:

Relax the assert in vectorization analysis to allow non-device splits.

A cpp test verifies the updated behavior.

@jjsjann123
Collaborator Author

!test

@github-actions

github-actions bot commented Sep 17, 2025

Review updated until commit db304bd

Description

  • Fix device split validation logic in layout operations

  • Improve allocation and stride computation for padded tensors

  • Add proper handling of logical vs. allocation domains

  • Enhance tensor reshaping with correct extent slicing


Changes walkthrough 📝

Relevant files:

Bug fix

  utils.cpp (csrc/multidevice/utils.cpp) — Update device split validation logic
    • Replace validateDeviceSplit with isValidateDeviceSplit returning bool
    • Use NVF_ERROR with validation check in projectShardedAllocationToLogical
    • Apply same validation in projectLogicalToShardedAllocation
    • +18/-17

  utils.h (csrc/multidevice/utils.h) — Update header for split validation
    • Update isValidateDeviceSplit declaration to return bool
    • +1/-1

Enhancement

  allocations.cpp (csrc/runtime/allocations.cpp) — Improve allocation and stride handling
    • Allocate based on allocation sizes/strides if available
    • Apply as_strided_ to match logical shape/strides
    • Fix stride computation order in getShapeAndStrideAfterDimMerged
    • Slice padded dimensions using evaluated extent in backward traversal
    • +43/-9

  vectorize_helper.cpp (csrc/scheduler/vectorize_helper.cpp) — Fix vectorization with device splits
    • Replace validateDeviceSplit with isValidateDeviceSplit check
    • Break early if non-device split detected
    • Prevent vectorization on padded dimensions
    • +10/-1

Tests

  test_layout_op.cpp (tests/cpp/test_layout_op.cpp) — Add test for allocation logic
    • Add test for logical vs allocation sizes
    • Validate output shape matches input despite padding
    • Check strides reflect padding in allocation
    • +38/-0

Formatting

  test_low_precision_recipe.cpp (tests/cpp/test_low_precision_recipe.cpp) — Fix test name typo
    • Fix typo in test name: 'Ouput' → 'Output'
    • +1/-1

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

    Possible Issue

    The logic for computing strides in getShapeAndStrideAfterDimMerged may produce incorrect results due to a change in the order of operations. The updated code multiplies prod after assigning it to tensor_new_strides[i], which differs from the original behavior where multiplication occurred before assignment. This could lead to incorrect stride values.

    std::vector<int64_t> tensor_new_strides(tensor_new_shape.size(), 1);
    int64_t prod = 1;
    for (int i = static_cast<int>(tensor_new_shape.size()) - 1; i >= 0; --i) {
      tensor_new_strides[i] = prod;
      prod *= tensor_new_shape[i];
    }
    Performance Concern

    The function isValidateDeviceSplit is called multiple times across different files to validate device splits. However, each call involves dynamic checks that could be optimized or cached if the validation results are reused, potentially impacting performance in hot paths.

    bool isValidateDeviceSplit(Expr* expr) {
      if (expr == nullptr || !expr->isA<Split>()) {
        return false;
      }
      auto* split = expr->as<Split>();
      if (split == nullptr || !split->outer()->isDeviceDim() ||
          split->innerSplit()) {
        return false;
      }
      return true;
    }
    Possible Issue

    In ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize, the early exit condition based on only_valid_device_split may prevent further processing even when some splits are valid, potentially leading to missed optimization opportunities or incorrect vectorization decisions.

    // Get the logical ID corresponding to the allocation ID.
    auto exprs = DependencyCheck::getAllExprsBetween(
        {of_tv->getLogicalDomain().begin(), of_tv->getLogicalDomain().end()},
        {alloc_iid});
    IterDomain* logical_id = alloc_iid;
    Val* num_devices = of_tv->container()->oneVal();
    bool only_valid_device_split = true;
    for (Expr* expr : exprs | std::views::reverse) {
      if (!isValidateDeviceSplit(expr)) {
        only_valid_device_split = false;
        break;
      }
      auto* split = expr->as<Split>();
      logical_id = split->in();
      num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
    }
    
    // Non device split could lead to padding, which prevents vectorization
    if (!only_valid_device_split) {
      break;
    }
    
    // Mapping order isn't correct, cannot expand vectorization dimension.
    if (projected_dims[--projected_dims_i] != logical_id) {
      break;
    }

    if (!only_valid_device_split) {
      break;
    }

Collaborator Author

@Priya2698 would you mind double checking the changes here, as well as the two files under csrc/multidevice/utils.h[.cpp]?

    I'm trying to relax it to allow using allocation domain to represent padding.

      for (int i = static_cast<int>(tensor_new_shape.size()) - 1; i >= 0; --i) {
    -   prod *= tensor_new_shape[i];
        tensor_new_strides[i] = prod;
    +   prod *= tensor_new_shape[i];
Collaborator Author

@protonu this seems to be a bug in the old code: we were not computing the correct stride.

The existing swizzle test doesn't check correctness, which is why this wasn't caught earlier. (I hit an error in the added LogicalAndAllocationSizes test, because an as_strided call with allocation sizes/strides was out of bounds.)

Collaborator

    Thanks for catching this!

    @jjsjann123 jjsjann123 changed the title Jj/layout op pr3 allocation patch Patch allocation logic to produce outputs with correct logical size Sep 17, 2025
    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123 jjsjann123 marked this pull request as ready for review September 17, 2025 09:52
    @jjsjann123 jjsjann123 requested review from Priya2698, naoyam and protonu and removed request for Priya2698 September 17, 2025 09:52
    std::set<IterDomain*> logical_set(logical.begin(), logical.end());
    if (frontier_set != logical_set) {
      return tensor;
    std::vector<int64_t> logical_sizes(logical.size(), 0);
Collaborator

    Looks like the comment above needs to be updated.

    Can you please remind me when this condition hits and what we were doing previously? Was it just returning an incorrect tensor?

Collaborator Author

    The comment above still holds.

This is the case we have for the PreprocessGroupedMatmulInputSf output, where the logical domain and the allocation domain are not connected and are not mapped.

logical domain [ i0, i1 ]
allocation domain [ i0 + (i_group - 1) * 127, (i1 + 4 - 1) / 4 * 4 ]

Note that the allocation domain extents are operated on directly, instead of using a resize op.

    Fuser/csrc/ops/indexing.cpp

    Lines 333 to 364 in f722efc

    // Note: output logical domain handles potential padding required for the
    // layout. Since the actual padding size is data-dependent, we allocate for
    // the maximum padding (reflected on logical/allocation domain).
    // NOTE: We could use resize operations for the padding logic, I think this
    // might simplify predication. Not doing that for now for simpler
    // implementation. We'll re-evaluate when we add scheduler support.
    // pad row size: num_groups * (row_multiple - 1) + row_size
    auto pad_to_max_extent = [&](IterDomain* id, int multiple) -> IterDomain* {
      auto* maximum_pad_value_per_group =
          IrBuilder::create<Val>(multiple - 1, DataType::Index);
      Val* padded_ext = SimplifyingIrBuilder::addExpr(
          id->extent(),
          SimplifyingIrBuilder::mulExpr(num_groups, maximum_pad_value_per_group));
      return IterDomainBuilder(id).extent(padded_ext).build();
    };
    out_alloc_dom.push_back(pad_to_max_extent(out_logical_dom[0], row_multiple));
    // pad col size: (col_size + col_multiple - 1) / col_multiple * col_multiple
    auto pad_to_multiple = [&](IterDomain* id, int multiple) -> IterDomain* {
      Val* ext = id->extent();
      auto* multiple_val = IrBuilder::create<Val>(multiple, DataType::Index);
      Val* padded_ext = SimplifyingIrBuilder::mulExpr(
          SimplifyingIrBuilder::divExpr(
              SimplifyingIrBuilder::subExpr(
                  SimplifyingIrBuilder::addExpr(ext, multiple_val), one_val),
              multiple_val),
          multiple_val);
      return IterDomainBuilder(id).extent(padded_ext).build();
    };
    out_alloc_dom.push_back(pad_to_multiple(out_logical_dom[1], col_multiple));

The implication here is that frontier_set and the projected logical_set are not equivalent.

The behavior before the change was to return the allocated buffer directly; because we allocate using the allocation domain, we were returning a tensor with padded size.
The updated behavior is to return a slice into the allocated buffer, so that the output aten tensor has a shape that's consistent with its logical domain.

Collaborator

    The comment above still holds.

I meant this part: "give up on producing right shape/stride". I thought that's now fixed by this PR, no?

Collaborator

I'm still trying to understand when this happens.

Looks like the frontier is always updated with the split input ID, regardless of whether the split is divisible.

    // update frontier
        if (inner_dim < outer_dim) {
          *inner_it = in;
          frontier_.erase(outer_it);
        } else {
          *outer_it = in;
          frontier_.erase(inner_it);
        }
    

    Wouldn't that mean the frontier should eventually get to the same IDs as the logical IDs?

Collaborator Author

I meant this part: give up on producing right shape/stride. I thought that's now fixed by this PR, no?

Ah, you are right. We are producing the right shape, but the stride is still wrong. And there's simply no right stride to index that tensor.

Wouldn't that mean the frontier should eventually get to the same IDs as the logical IDs?

If there are ID operations between logical and allocation, then yes.
The case with PreprocessGroupedMatmulInputSf is that there are no ID ops between the allocation and logical domains at all; we operate directly on the ID extents.
So the two sets are not related at all.

Collaborator Author

For some reason I totally forgot to mention the second case with unconnected logical & allocation domains, and only mentioned the non-divisible split case. 🥹 Apologies for that.

Collaborator

    Hmm, so the change here is actually not related to the above change about the non-divisible split case?

Collaborator Author

Yep, I'm hearing the feedback as: split this into a separate PR. I'll do that.

    {alloc_iid});
    IterDomain* logical_id = alloc_iid;
    Val* num_devices = of_tv->container()->oneVal();
    bool only_valid_device_split = true;
Collaborator

    Why does device split have anything to do with the padding?

Collaborator Author

No, device split doesn't have anything to do with padding.

This is the part I was tagging Priya on. The earlier fix we had on vectorization analysis assumed that every Split between the logical and allocation domains has to be a device split.

We are now using such splits for padding, so I'm extending the analysis to allow non-device splits. I'm staying on the conservative side and stopping vectorization on split IDs.

    int64_t in_extent = ee_.evaluate(in->extent()).as<int64_t>();
    if (areDimsToBeMergedContiguous(tensor_, new_shape)) {
      tensor_ = tensor_.view(new_shape);
    if (in_extent != tensor_.size(left)) {
Collaborator

    I'd like to have more extensive unit tests for this case since this seems like depending on rather intricate behavior. For example, allocation domains with not just one split, but with multiple (both inner and outer) splits and merges.

    Indexing and predication with non-divisible splits is tricky, and I had a lot of issues. Some of the stuff may be just hidden with some implicit assumptions, and this case is certainly breaking some of the common assumptions like global tensors are always allocated based off logical IDs.

Collaborator Author

I totally agree with you that this is a delicate topic.

I was trying to avoid pulling that off right now because I don't know how deep a rabbit hole this is going to be. 😉
Let me try adding some more interesting examples to see if we can prefetch some issues.

assumptions like global tensors are always allocated based off logical IDs.

That's the part that doesn't sound right to me.

I understand the old assumption that logical and allocation domains are consistent (exact coverage).

But if we break that assumption, then conceptually the logical domain is used for predication and the allocation domain is used for indexing. So the actual buffer should be allocated based on the allocation domain instead; otherwise, there could be out-of-bound accesses.

Collaborator

    That's why I'm suggesting adding more tests. I'm not aware of any logic that is based on some implicit assumptions, but this pattern is not common and not well tested, which isn't ideal.

    !split->innerSplit(),
    "Inner split by device dimension is not supported: ",
    expr->toString());
    bool isValidateDeviceSplit(Expr* expr) {
Collaborator

    Why is this related?

Collaborator Author

This is used by vectorization analysis.

The old behavior only asserts; I didn't want to wrap a try/catch around that, so I decided to change the function to return a boolean for the check result and have the call site assert on the return value when necessary.

Collaborator

    I guess my confusion is why we would need to change this for fixing the allocation?

Collaborator Author

This is an orthogonal issue.

We don't need to change this to fix the computed allocation sizes, but the cpp example added in this PR would trigger an assert on ToT.

      auto inp = makeSymbolicTensor(2);
      fusion.addInput(inp);
      auto out = set(inp);
      fusion.addOutput(out);
      // padding output to multiple of 16
      out->split(1, 16);
      out->setAllocationDomain(out->getLoopDomain(), true);
      // restore loop domain
      out->merge(1);
    

So the related changes under multi-device and vectorization analysis are just to patch that.

Collaborator

    I'd say the allocation fix itself is a significant change and deserves its own PR, which would be easier to review.

    @jjsjann123
    Collaborator Author

    closing this one, since it's now split up as #5184 #5185 #5186

    @jjsjann123 jjsjann123 closed this Sep 18, 2025