
refactor getAllocationDomainsAndContiguity#4792

Merged
liqiangxl merged 1 commit into main from
llu/refactor_getAllocationDomainsAndContiguity
Jul 22, 2025
Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Jul 17, 2025

Changes:
Refactor getAllocationDomainsAndContiguity

  • Extracted function canUsePresetAllocationDomain
  • Extracted function usePresetAllocationDomain

Motivations:
getAllocationDomainsAndContiguity mainly does 3 things:

  • Determine whether the preset allocation domain can be reused
  • Check how to reuse it if possible
  • Set a new allocation domain if the preset domain can't be used

Each section contains many deeply nested if-else branches, making the function hard to follow and maintain. This refactor extracts the first two tasks into two helper functions.

After the refactor, the logic of getAllocationDomainsAndContiguity is:

    if (canUsePresetAllocationDomain()) {
      return usePresetAllocationDomain();
    }
    // Otherwise, derive the domains that should be allocated from the
    // loop domain and the allocation position.

@github-actions

github-actions bot commented Jul 17, 2025

Review updated until commit 3a06741

Description

  • Refactored getAllocationDomainsAndContiguity to use helper functions.

  • Added canUsePresetAllocationDomain to determine if preset allocation domain can be used.

  • Added usePresetAllocationDomain to select the domains from the preset allocation domain that should be allocated.

  • Simplified logic by extracting conditions and helper functions.


Changes walkthrough 📝

Relevant files
Enhancement
allocation.cpp
Refactor `getAllocationDomainsAndContiguity` with helper functions

csrc/device_lower/pass/allocation.cpp

  • Added canUsePresetAllocationDomain function.
  • Added usePresetAllocationDomain function.
  • Refactored getAllocationDomainsAndContiguity to use new helper
    functions.
  • Simplified and cleaned up existing logic.
  • +229/-169

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The logic for determining whether to use the preset allocation domain (canUsePresetAllocationDomain) seems complex and may have edge cases that are not covered. It is important to ensure that all conditions are correctly evaluated and that no valid allocation domains are incorrectly excluded.

    bool canUsePresetAllocationDomain(TensorView* tv) {
      if (!tv->hasAllocation()) {
        return false;
      }
      // Honor the allocation domain if the tensor is global or Hopper MMA's
      // output
      if (tv->getMemoryType() == MemoryType::Global ||
          (tv->definition()->isA<MmaOp>() &&
           isHopper(tv->definition()->as<MmaOp>()->macro()))) {
        return true;
      }
      // If it's a shared memory tensor, the set domain is likely
      // valid if Swizzle or Bulk is used. Also, if the allocation
      // domain is just a permutation of the loop domain, use the
      // set allocation domain. This seems to happen only with
      // AllocationDomainTest.TransposedIntermediate.
      if (tv->getMemoryType() == MemoryType::Shared) {
        if (std::any_of(
                tv->getAllocationDomain().begin(),
                tv->getAllocationDomain().end(),
                [](IterDomain* allocation_domain) {
                  return dynamic_cast<Swizzle*>(
                             allocation_domain->definition()) != nullptr ||
                      allocation_domain->getParallelType() == ParallelType::Bulk;
                }) ||
            std::is_permutation(
                tv->getLoopDomain().begin(),
                tv->getLoopDomain().end(),
                tv->getAllocationDomain().begin(),
                tv->getAllocationDomain().end())) {
          return true;
        }
    
        // Honor the set allocation domain if the tensor is used by a
        // TMA store or MmaOp
        if (std::ranges::any_of(tv->uses(), [](Expr* expr) {
              return ir_utils::isCpAsyncBulkStore(expr) || expr->isA<MmaOp>();
            })) {
          return true;
        }
      }
      return false;
    }
    Code Duplication

    The logic for handling allocation domains when the preset domain is not used (usePresetAllocationDomain and the else block in getAllocationDomainsAndContiguity) seems to overlap with the old logic. This could lead to inconsistencies and should be reviewed to ensure that the refactoring does not introduce bugs.

       if (tv->getMemoryType() == MemoryType::Global) {
         auto allocation_domains = tv->getAllocationDomain();
         auto contiguity = tv->domain()->contiguity();
         NVF_ERROR(allocation_domains.size() == contiguity.size());
         return {allocation_domains, contiguity};
       }
    
       // Get allocation position and collect excluded loop domains
       // For example:
       // T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 )
       // iS6{2}, iS11{3} are excluded.
       int64_t allocation_pos =
           lower_utils::getAllocPosInfo(tv, for_loops).alloc_pos;
       auto exclude_ca_ids = collectExcludedLoopDomains(tv, allocation_pos);
    
       // Process allocation domains
       std::vector<IterDomain*> allocation_domains;
       std::vector<std::optional<bool>> contiguity;
    
       for (auto i : arange(tv->getAllocationDomain().size())) {
         auto id = tv->getAllocationDomain()[i];
    
         // Excluded based on allocation position
         IterDomain* excluded_id = getExcludedAllocationDomain(id, exclude_ca_ids);
         if (excluded_id != nullptr) {
           exclude_ca_ids.erase(excluded_id);
           continue;
         }
         // Excluded based on memory partitioning
         if (ir_utils::isMemoryPartitionedAcross(
                 tv->getMemoryType(), id->getParallelType())) {
           continue;
         }
         allocation_domains.push_back(id);
         contiguity.push_back(tv->domain()->contiguity()[i]);
       }
    
       // Verify all excluded domains were found
       NVF_ERROR(
           exclude_ca_ids.empty(),
           "The non-allocating compute-at IDs are not found in the allocation "
           "domain. ",
           "It is unclear how to allocate the tensor: ",
           tv->toString(),
           " allocation domain: ",
           ir_utils::toString(tv->getAllocationDomain()));
    
       NVF_ERROR(allocation_domains.size() == contiguity.size());
       return {allocation_domains, contiguity};
     }
    
    public:
     using IrVisitor::dispatch;
    
     // Set allocation domain info for all tensors
     void setup(const std::vector<Expr*>& exprs) {
       // Find out correct allocation domains for all consumer
       // tensors. Input tensors are handled after this
       for (auto expr : exprs) {
         dispatch(expr);
       }
    
       // Make sure all tensors have allocation domains
       for (TensorView* producer_tv : used_as_producer) {
         auto it = tv_alloc_info_map.find(producer_tv);
         if (it != tv_alloc_info_map.end()) {
           continue;
         }
    
         // Not yet set. This must be an input tensor or it must be aliased via
         // aliasTensorProducer, in which case it will not be allocated.
         NVF_ERROR(
             producer_tv->isFusionInput() ||
                 GpuLower::current()->getTensorProducerAlias(producer_tv) !=
                     nullptr,
             "Expected a fusion input or aliased tensor but found: ",
             producer_tv->toString());
    
         // For fusion input, we can just use getMaybeAllocationDomain.
    
         auto alloc_info = getAllocationDomainInfo(
             producer_tv,
             producer_tv->getMaybeAllocationDomain(),
             producer_tv->domain()->contiguity());
    
         tv_alloc_info_map.emplace(producer_tv, alloc_info);
       }
     }
    
     void dispatch(Expr* expr) override {
       if (ir_utils::isTvOp(expr)) {
         for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
           // Note that since we are dealing with a Kernel IR, a single
           // tensor may show up as consumers multiple times, e.g.,
           // zero initialization and actual definition. Using the last
           // expr should give us correct allocation info. See
           // IndexingTest.InlinedUnroll for a concrete
           // example. Specifically, the initialization expression of t2
           // doesn't have an unrolling loop, so the allocation info
           // obtained from that expression would fail to give the
           // correct allocation domains.
           auto [alloc_domains, contiguity] =
               getAllocationDomainsAndContiguity(out_tv, for_loops_);
           auto alloc_info =
               getAllocationDomainInfo(out_tv, alloc_domains, contiguity);
           tv_alloc_info_map[out_tv] = alloc_info;
         }
         for (auto in_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
           used_as_producer.insert(in_tv);
         }
       } else {
         IrVisitor::dispatch(expr);
       }
     }
    
     // Get the allocation domains and contiguity of a given tensor
     //
     // TODO: Ideally, all tensors should have their correct allocation
     // domains, but that isn't always the case at this moment. The logic
     // here is duplicated in multiple locations and should be cleaned up.
     std::pair<std::vector<IterDomain*>, std::vector<std::optional<bool>>>
     getAllocationDomainsAndContiguity(
         TensorView* tv,
         const std::vector<ForLoop*>& for_loops) {
       if (canUsePresetAllocationDomain(tv)) {
         return usePresetAllocationDomain(tv, for_loops);
       }
       std::vector<IterDomain*> allocation_domains;
       std::vector<std::optional<bool>> contiguity;
    
       // If allocation domain is not set, assume that:
       // - Global: logical domains
       // - Local/Shared: loop domains to the right of the CA position
       if (tv->getMemoryType() == MemoryType::Global) {
         allocation_domains = tv->getLogicalDomain();
         contiguity = tv->domain()->contiguity();
       } else {
         int64_t allocation_pos =
             lower_utils::getAllocPosInfo(tv, for_loops).alloc_pos;
         for (const auto i : arange(tv->nDims())) {
           auto loop_id = tv->getLoopDomain().at(i);
           auto pt = loop_id->getParallelType();
    
           // If the position is left of the inlining position, no need to
           // allocate the domain unless it's shared. For example, if this
           // is a Shared tensor and the domain is parallelized with TID,
           // even if it's outside of the CA position, since the domain
           // is shared, it must be allocated.
           if (i < allocation_pos &&
               !ir_utils::isMemorySharedAcross(tv->getMemoryType(), pt)) {
             continue;
           }
    
           allocation_domains.push_back(loop_id);
         }
         // Assume Local and Shared are always fully contiguous
         contiguity =
             std::vector<std::optional<bool>>(allocation_domains.size(), true);
       }
    
       if (auto indexed_alloc_dom =
               patchAllocationOfIndexedProducerTensor(tv, allocation_domains);
           indexed_alloc_dom.has_value()) {
         allocation_domains = indexed_alloc_dom.value();
         // Make sure the original allocation domains are fully contiguous
         NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
           return b.has_value() && b.value();
         }));
         // Set the new allocation domains fully contiguous
         contiguity =
             std::vector<std::optional<bool>>(allocation_domains.size(), true);
       }
    
    Performance Impact

    The refactoring introduces new functions and logic, which could potentially impact performance. It is important to measure the performance impact of these changes and ensure that they do not degrade the performance of the allocation process.

    // domains, but that isn't always the case at this moment. The logic
    // here is duplicated in multiple locations and should be cleaned up.
    std::pair<std::vector<IterDomain*>, std::vector<std::optional<bool>>>
    getAllocationDomainsAndContiguity(
        TensorView* tv,
        const std::vector<ForLoop*>& for_loops) {
      if (canUsePresetAllocationDomain(tv)) {
        return usePresetAllocationDomain(tv, for_loops);
      }
      std::vector<IterDomain*> allocation_domains;
      std::vector<std::optional<bool>> contiguity;
    
      // If allocation domain is not set, assume that:
      // - Global: logical domains
      // - Local/Shared: loop domains to the right of the CA position
      if (tv->getMemoryType() == MemoryType::Global) {
        allocation_domains = tv->getLogicalDomain();
        contiguity = tv->domain()->contiguity();
      } else {
        int64_t allocation_pos =
            lower_utils::getAllocPosInfo(tv, for_loops).alloc_pos;
        for (const auto i : arange(tv->nDims())) {
          auto loop_id = tv->getLoopDomain().at(i);
          auto pt = loop_id->getParallelType();
    
          // If the position is left of the inlining position, no need to
          // allocate the domain unless it's shared. For example, if this
          // is a Shared tensor and the domain is parallelized with TID,
          // even if it's outside of the CA position, since the domain
          // is shared, it must be allocated.
          if (i < allocation_pos &&
              !ir_utils::isMemorySharedAcross(tv->getMemoryType(), pt)) {
            continue;
          }
    
          allocation_domains.push_back(loop_id);
        }
        // Assume Local and Shared are always fully contiguous
        contiguity =
            std::vector<std::optional<bool>>(allocation_domains.size(), true);
      }
    
      if (auto indexed_alloc_dom =
              patchAllocationOfIndexedProducerTensor(tv, allocation_domains);
          indexed_alloc_dom.has_value()) {
        allocation_domains = indexed_alloc_dom.value();
        // Make sure the original allocation domains are fully contiguous
        NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
          return b.has_value() && b.value();
        }));
        // Set the new allocation domains fully contiguous
        contiguity =
            std::vector<std::optional<bool>>(allocation_domains.size(), true);
      }
    
      // reorderAllocationDomains and
      // patchAllocationOfTransposedSmemTensor assume unallocated IDs
      // are removed
      std::vector<IterDomain*> actual_allocation_ids;
      std::vector<std::optional<bool>> actual_contiguity;
      for (auto [i, id] : enumerate(allocation_domains)) {
        if (mayRequireAllocation(tv, id)) {
          actual_allocation_ids.push_back(id);
          actual_contiguity.push_back(contiguity.at(i));
        }
      }
      std::swap(allocation_domains, actual_allocation_ids);
      std::swap(contiguity, actual_contiguity);
    
      if (auto reordered_domains =
              reorderAllocationDomains(tv, allocation_domains);
          reordered_domains.has_value()) {
        allocation_domains = reordered_domains.value();
        NVF_ERROR(
            std::all_of(
                contiguity.begin(),
                contiguity.end(),
                [](auto b) { return b.has_value() && b.value(); }),
            tv->toString());
      }
    
      // WAR for transpose
      if (auto transposed_smem_alloc_dom = patchAllocationOfTransposedSmemTensor(
              tv,
              allocation_domains,
              GpuLower::current()->idModel().idGraph(IdMappingMode::EXACT));
          transposed_smem_alloc_dom.has_value()) {
        allocation_domains = transposed_smem_alloc_dom.value();
        // Make sure the original allocation domains are fully contiguous
        NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
          return b.has_value() && b.value();
        }));
        // Set the new allocation domains fully contiguous
        contiguity =
            std::vector<std::optional<bool>>(allocation_domains.size(), true);
      }
    
      NVF_ERROR(allocation_domains.size() == contiguity.size());
    

    @liqiangxl
    Collaborator Author

    !test --diff

    @liqiangxl liqiangxl requested a review from Copilot July 17, 2025 13:59

    Copilot AI left a comment


    Pull Request Overview

    This PR refactors the getAllocationDomainsAndContiguity function by extracting helper functions into an anonymous namespace to improve code organization and readability. The refactoring breaks down the complex logic into smaller, more focused functions without changing the overall functionality.

    Key changes:

    • Extract helper functions into anonymous namespace for better organization
    • Improve code structure by separating concerns into dedicated functions
    • Minor formatting improvements for comments and code alignment
    Comments suppressed due to low confidence (2)

    csrc/device_lower/pass/allocation.cpp:166

    • [nitpick] Function name should follow camelCase convention. Consider renaming to 'canUsePresetAllocationDomain' or use snake_case if that's the project convention.
        bool canUsePresetAllocationDomain(TensorView * tv) {
    

    csrc/device_lower/pass/allocation.cpp:227

    • [nitpick] Function name suggests a boolean return but returns a pointer. Consider renaming to 'getExcludedAllocationDomain' or 'findExcludedAllocationDomain' to better reflect the actual return type.
        IterDomain* shouldExcludeAllocationDomain(
    

    @liqiangxl
    Collaborator Author

    !test --diff

    1 similar comment
    @liqiangxl
    Collaborator Author

    !test --diff

    @liqiangxl liqiangxl force-pushed the llu/refactor_getAllocationDomainsAndContiguity branch from 040e349 to 3a06741 on July 17, 2025 18:14
    @liqiangxl
    Collaborator Author

    !test --diff

    @liqiangxl
    Collaborator Author

    The code diff is only PTX register name changes:
    MmaTest/HopperSS.MultipleTile/Hopper_64_80_16_TN_NoSwizzle_64B__bfloat
    MmaTest/HopperSS.SingleTileTransposed/Hopper_64_232_16_NN_NoSwizzle_128B__bfloat

    @liqiangxl liqiangxl requested review from jjsjann123 and naoyam July 17, 2025 21:42
    @liqiangxl liqiangxl marked this pull request as ready for review July 17, 2025 21:42
    usePresetAllocationDomain(
    TensorView* tv,
    const std::vector<ForLoop*>& for_loops) {
    if (tv->getMemoryType() == MemoryType::Global) {
    Collaborator


    Is my understanding correct that the only changes made to the logic are inside usePresetAllocationDomain?

    Collaborator


    I could use some quick context / a high-level description of the motivation behind the refactor.

    Collaborator Author


    Is my understanding correct that the only changes made to the logic are inside usePresetAllocationDomain?

    Yes, the logic in usePresetAllocationDomain will be changed in a follow-up PR; see #4791.
    This PR is just a refactor; it doesn't change the logic.

    Collaborator Author


    I could use some quick context / a high-level description of the motivation behind the refactor.

    The previous function became too large and difficult to maintain. It includes logic to determine whether the preset allocation domain can be reused, how to reuse it if possible, and how to set a new allocation domain if not. The code contains many deeply nested if-else branches, making it hard to follow and maintain.

    Collaborator


    This PR is just a refactor; it doesn't change the logic.

    Thanks a lot for confirming that. I thought we had changed the logic in usePresetAllocationDomain, but it looks like we are just moving some utils around.

    Collaborator

    @jjsjann123 jjsjann123 left a comment


    stamp for code restructure.

    🙇


    @liqiangxl liqiangxl merged commit f533819 into main Jul 22, 2025
    56 of 59 checks passed
    @liqiangxl liqiangxl deleted the llu/refactor_getAllocationDomainsAndContiguity branch July 22, 2025 12:15
    nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
    **Changes:**
    Refactor `getAllocationDomainsAndContiguity`
    - Extracted function `canUsePresetAllocationDomain`
    - Extracted function `usePresetAllocationDomain`
    
    
    **Motivations:**
    `getAllocationDomainsAndContiguity` mainly does 3 things:
    
    - Determine whether the preset allocation domain can be reused
    - Check how to reuse it if possible
    - Set a new allocation domain if the preset domain can't be used

    Each section contains many deeply nested if-else branches, making it
    hard to follow and maintain.
    This refactor abstracts the first two tasks into two helper functions.
    
    
    After refactor, the logic of `getAllocationDomainsAndContiguity` is:
    ```
    if (canUsePresetAllocationDomain()) {
        return usePresetAllocationDomain();
    }
    // Derive domains that should be allocated from loop domain & allocation position.
    ```

    Labels: None yet
    Projects: None yet
    3 participants