replay loop domain transforms to allocation domain#4795

Merged
liqiangxl merged 45 commits into main from llu/selfReplayLoopToAllocation
Aug 8, 2025

Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Jul 17, 2025

Add selfReplayLoopToAllocation
Assuming the allocation domain is a permutation of the logical domain, we can use ReplaySelf to replay the loop domain transformations onto the allocation domain.
The replay happens within ReplaySelf in the following steps:

  1. Given a loop domain, find the logical domain that produces it.
  2. Map the logical domain to the allocation domain.
  3. Apply the same transformation on the allocation domain.

After replay, reset allocation domain to the transformed version.

Add scheduler_utils::replayLoopToAllocation(fusion)

Two PRs are split from this PR:
(1) refactor getAllocationDomainsAndContiguity #4792
(2) Schedule allocation domain manually and use IdModel to detect mapping between scheduled allocation domain and loop domain. #4791

@github-actions

github-actions bot commented Jul 17, 2025

Review updated until commit 1579436

Description

  • Fix allocation domain mismatch for shared memory tensors

  • Replay loop domain transforms to allocation domain

  • Ensure allocation domain consistency with compute-at

  • Add tests for allocation domain transformations


Changes walkthrough 📝

Relevant files

Bug fix

normalization_inner_outer_tma_ws.cpp: Schedule allocation domain for shared memory
csrc/scheduler/normalization_inner_outer_tma_ws.cpp

  • Added call to buildAllocationDomainForSharedMemoryTvs after scheduling
  • Ensures shared memory tensors have correct allocation domains
  • Fixes allocation based on transformed loop domains
  • +3/-0

Enhancement

utils.cpp: Implement allocation domain replay from loop domain
csrc/scheduler/utils.cpp

  • Added buildAllocationDomainFromLoopIds to map loop to allocation domains
  • Implemented transformation replay using splits and merges
  • Added buildAllocationDomainForSharedMemoryTvs to process all shared memory TVs
  • Uses dependency check to find transformation expressions
  • +80/-0

utils.h: Declare allocation domain utility functions
csrc/scheduler/utils.h

  • Declared buildAllocationDomainFromLoopIds function
  • Declared buildAllocationDomainForSharedMemoryTvs function
  • Added documentation for allocation domain utilities
  • Improved API for allocation domain management
  • +11/-0

Tests

test_allocation_domain.cpp: Add tests for allocation domain transformations
tests/cpp/test_allocation_domain.cpp

  • Added test buildAllocationDomainFromLoopIdsSplit for split cases
  • Added test buildAllocationDomainFromLoopIdsMerge for merge cases
  • Added SmemAllocationDomainChanged test for bank conflict detection
  • Includes validation of allocation domain extents
  • +87/-41

test_combined_inner_outer_reduction.cpp: Add test for allocation domain broadcast bug
tests/cpp/test_combined_inner_outer_reduction.cpp

  • Added AllocationDomainBroadcast test for normalization scheduler
  • Reproduces bug 5374767 with broadcast and reduction
  • Validates fix for non-allocating compute-at IDs
  • Uses warp specialized normalization
  • +32/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function buildAllocationDomainFromLoopIds assumes the allocation domain is a permutation of the logical domain, but does not handle cases where transformations like broadcast are involved. This could lead to incorrect allocation domain construction when broadcasted dimensions are present.

    void buildAllocationDomainFromLoopIds(TensorView* tv) {
      const auto& logical = tv->getLogicalDomain();
      const auto& alloc = tv->getMaybeAllocationDomain();
      NVF_ERROR(
          std::is_permutation(
              logical.begin(), logical.end(), alloc.begin(), alloc.end()),
          "buildAllocationDomainFromLoopIds expects the allocation domain to be a "
          "permutation of the logical domain");
      const auto& loop = tv->getLoopDomain();
    
      // Get transformation expressions from allocation to loop domain
      auto transform_exprs = DependencyCheck::getAllExprsBetween(
          {alloc.begin(), alloc.end()}, {loop.begin(), loop.end()});
    
      // Track which allocation IDs each transformed ID came from
      std::unordered_map<IterDomain*, std::vector<IterDomain*>> id_to_alloc_sources;
      for (auto alloc_id : alloc) {
        id_to_alloc_sources[alloc_id] = {alloc_id};
      }
      for (auto expr : transform_exprs) {
        if (auto split = dynamic_cast<Split*>(expr)) {
          NVF_ERROR(id_to_alloc_sources.contains(split->in()));
          auto sources = id_to_alloc_sources[split->in()];
          id_to_alloc_sources[split->outer()] = sources;
          id_to_alloc_sources[split->inner()] = sources;
        } else if (auto merge = dynamic_cast<Merge*>(expr)) {
          NVF_ERROR(id_to_alloc_sources.contains(merge->outer()));
          NVF_ERROR(id_to_alloc_sources.contains(merge->inner()));
          auto outer_sources = id_to_alloc_sources[merge->outer()];
          auto inner_sources = id_to_alloc_sources[merge->inner()];
          outer_sources.insert(
              outer_sources.end(), inner_sources.begin(), inner_sources.end());
          id_to_alloc_sources[merge->out()] = std::move(outer_sources);
        } else {
          NVF_ERROR(false, "Unsupported expression type: ", expr->toString());
        }
      }
    
      // Build final allocation domain preserving allocation order
      std::vector<IterDomain*> new_alloc_domain;
      std::unordered_set<IterDomain*> used_loop_ids;
      for (auto alloc_id : alloc) {
        for (auto loop_id : loop) {
          // skip if the loop ID has already been used
          if (used_loop_ids.count(loop_id)) {
            continue;
          }
          // skip if the loop ID is not derived from any allocation ID
          if (!id_to_alloc_sources.contains(loop_id)) {
            continue;
          }
          // skip if the loop ID is not derived from the current allocation ID
          auto& sources = id_to_alloc_sources.at(loop_id);
          if (std::find(sources.begin(), sources.end(), alloc_id) ==
              sources.end()) {
            continue;
          }
          new_alloc_domain.push_back(loop_id);
          used_loop_ids.insert(loop_id);
        }
      }
    
      tv->setAllocationDomain(new_alloc_domain, true);
    }

    Performance Concern

    The call to buildAllocationDomainForSharedMemoryTvs is added unconditionally for all shared memory tensors, which may introduce unnecessary overhead in cases where the allocation domain does not require transformation.

    // replay loop domain transformations to allocation domain for shared memory
    // tensors. Ensure we can allocate based on the allocation domain.
    scheduler_utils::buildAllocationDomainForSharedMemoryTvs(fusion);
    Test Coverage

    The test SmemAllocationDomainChanged verifies bank conflict behavior, but does not validate the correctness of the allocation domain transformation logic itself, leaving a gap in testing the core functionality.

    TEST_F(AllocationDomainTest, SmemAllocationDomainChanged) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
      auto tv0 = makeContigConcreteTensor({512, 32});
      fusion->addInput(tv0);
      std::vector<IterDomain*> tv0_dom = {tv0->axis(1), tv0->axis(0)};
      tv0->setAllocationDomain(tv0_dom, true);
      auto tv2 = add(tv0, tv0);
      fusion->addOutput(tv2);
    
      auto tv1 = tv0->cacheAfter();
      tv1->setMemoryType(MemoryType::Shared);
      for (auto tv : fusion->allTvs()) {
        tv->axis(0)->parallelize(ParallelType::TIDx);
      }
      // smem tensor has allocation domain (32, 512)
      // and loop domain (512(TIDx), 32(S))
      // there is no bank conflict since the index goes to allocation
      // domain where 512 is the inner-most dim.
      ASSERT_TRUE(fusion->bankConflictInfo().empty());
    
      // If we reset its allocation domain to (512, 32) and still keep loop
      // domain as (512(TIDx), 32(S)), then there are bank conflicts, e.g.
      // all threads in a warp access bank-0, then bank-1, then bank-2, etc.
      tv1->setAllocationDomain(tv1->getLoopDomain(), /*new_contiguity=*/true);
      ASSERT_FALSE(fusion->bankConflictInfo().empty());
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
      // shape: (x, y), alloc: (y, x), stride: (1, x)
      auto t0 = at::randn({512, 32}, options).as_strided({512, 32}, {1, 512});
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
      auto outputs = ke.run({t0});
      testValidate(fusion.get(), outputs, {t0}, __LINE__, __FILE__);
    }

    @liqiangxl
    Collaborator Author

    !test

    Collaborator

    @jjsjann123 jjsjann123 left a comment

    I wonder what's behind the refactor on the allocation lowering pass?

    Comment on lines +1469 to +1477
    IterDomainMap logical_to_alloc_map;
    for (auto logical_id : logical) {
      auto it = std::find(alloc.begin(), alloc.end(), logical_id);
      NVF_ERROR(
          it != alloc.end(),
          "Could not find matching allocation ID for logical ID: ",
          logical_id);
      logical_to_alloc_map[logical_id] = *it;
    }
    Collaborator

    nitpick, if we can compute the permutation, we can directly use that.

    Collaborator Author

    nitpick, if we can compute the permutation, we can directly use that.

    This permutation is between the logical domain and the allocation domain. The logical domain is usually further transformed to get the loop domain, so we may not directly use the allocation domain as-is.
    For example, in the test added in #4791, we have

      // T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 )
      // logical domain : (iS6{2}, iS7{12}, iB8{16})
      // allocation domain : (iS7{12}, iS6{2}, iB8{16})
      // contiguity: t t t
      //  Split: iS7{12} by factor 4 -> iS11{3}, iS12{4}
      // loop domain : (iS6{2}, iS11{3}, iS12{4}, iB8{16})
    
  // T2 is computed at pos 2, so we don't need to allocate domains iS6{2} and
  // iS11{3}. nvFuser tries to exclude these two domains from the allocation
  // domain; however, iS11{3} doesn't exist in the allocation domain, so it's
  // not excluded and this is considered a failed case.
    

    Collaborator

    Sorry, I meant just to refactor how we produce logical_to_alloc_map, i.e., we could compute the permutation order with this:

    Fuser/csrc/ir/utils.h

    Lines 661 to 679 in 1353ec7

    std::optional<std::vector<int64_t>> computePermutation(
        const std::vector<T>& in,
        const std::vector<T>& out) {
      // Both std::is_permutation and the rest of this function are O(n^2). This is
      // fine for the current use case of computing the root-to-rfactor
      // permutation. If needed, this can be improved by requiring T to be hashable
      // (leading to O(n)) and/or comparable (leading to O(nlogn)).
      if (!std::is_permutation(in.begin(), in.end(), out.begin(), out.end())) {
        return std::nullopt;
      }
      std::vector<int64_t> permutation;
      permutation.reserve(out.size());
      for (const T& out_element : out) {
        permutation.push_back(std::distance(
            in.begin(), std::find(in.begin(), in.end(), out_element)));
      }
      return permutation;
    }

    It's a nitpick, since the logic there is identical to what you used here. 😉

    Collaborator Author

    Got you! I gave computePermutation a try; it returns the permutation index, which is helpful in some contexts. However, in this case, we need the actual IterDomain mapping. I believe the current implementation is a bit more straightforward, as it allows us to directly find the IterDomain without needing to compute the index with std::distance and then retrieve the IterDomain from that. Let me know if I'm missing something!

    // Happens within ReplaySelf in the following steps:
    // 1. Given a loop domain, find the logical domain that produces it.
    // 2. Map the logical domain to the allocation domain.
    // 3. Do the same transformation on the allocation domain.
    Collaborator

    do we actually need a replay like this? I'm wondering if there's anything blocking us from just setting current loop domain as allocation domain as-is.

    Collaborator Author

    @liqiangxl liqiangxl Jul 22, 2025

    do we actually need a replay like this? I'm wondering if there's anything blocking us from just setting current loop domain as allocation domain as-is.

    Directly setting the loop domain as the allocation domain may change the allocation domain. When it is a smem tensor, this change leads to bank conflicts; see the newly added test SmemAllocationDomainChanged.
    In that case, the input tensor has allocation domain (32, 512) and loop domain (512, 32). To achieve a coalesced load from gmem to smem, we want to parallelize 512 with TIDx, so the parallelized loop domain is (512(TIDx), 32(S)). Both the loop and allocation domains are passed to the shared memory cached input. If we set the loop domain as the allocation domain, the allocation becomes (512(TIDx), 32(S)), which has bank conflicts.

      // smem tensor has allocation domain (32, 512)
      // and loop domain (512(TIDx), 32(S))
      // there is no bank conflict since the index goes to allocation
      // domain where 512 is the inner-most dim.
      ASSERT_TRUE(fusion->bankConflictInfo().empty());
    
      // If we reset its allocation domain to (512, 32) and still keep loop
      // domain as (512(TIDx), 32(S)), then there are bank conflicts, e.g.
      // all threads in a warp access bank-0, then bank-1, then bank-2, etc.
      tv1->setAllocationDomain(tv1->getLoopDomain(), /*new_contiguity=*/true);
      ASSERT_FALSE(fusion->bankConflictInfo().empty());
    

    Collaborator

    Are you saying it's just ordering? If so, why not just reorder the loop domain?

    Collaborator Author

    Are you saying it's just ordering? If so, why not just reorder the loop domain?

    If a tv's loop domain is reordered, the inlined position will be affected, and there may be other issues as well.

    Without reorder: T2_s_float[ithreadIdx.x4{512}, iS5{32}] ca_pos( 2 )

    Inputs:
      T0_g_float[ithreadIdx.x0{512}, iS1{32}]
    Outputs:
      T1_g_float[ithreadIdx.x2{512}, iS3{32}] ca_pos( 2 ) produce_pos( 2 )
    
    %kernel {
    T2_s_float[ithreadIdx.x4{512}, iS5{32}] ca_pos( 2 )
       = Set( T0_g_float[ithreadIdx.x0{512}, iS1{32}], cache_op=Streaming )
    T1_g_float[ithreadIdx.x2{512}, iS3{32}] ca_pos( 2 ) produce_pos( 2 )
       = T2_s_float[ithreadIdx.x4{512}, iS5{32}] ca_pos( 2 )
       + T2_s_float[ithreadIdx.x4{512}, iS5{32}] ca_pos( 2 );
    

    With reorder T2_s_float[iS5{32}, ithreadIdx.x4{512}]

    Inputs:
      T0_g_float[ithreadIdx.x0{512}, iS1{32}]
    Outputs:
      T1_g_float[ithreadIdx.x2{512}, iS3{32}] ca_pos( 2 )
    
    %kernel {
    T2_s_float[iS5{32}, ithreadIdx.x4{512}]
       = Set( T0_g_float[ithreadIdx.x0{512}, iS1{32}], cache_op=Streaming )
    T1_g_float[ithreadIdx.x2{512}, iS3{32}] ca_pos( 2 )
       = T2_s_float[iS5{32}, ithreadIdx.x4{512}]
       + T2_s_float[iS5{32}, ithreadIdx.x4{512}];
    

    Collaborator

    Hmm, I'm not sure what you mean by that.

    Here's the example you gave for the IdModel usage:

    T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 ) 
    logical domain : (iS6{2}, iS7{12}, iB8{16}) 
    allocation domain : (iS15{3}, iS16{4}, iS6{2}, iB8{16}) contiguity: t t t t 
    Split: iS7{12} by factor 4 -> iS15{3}, iS16{4} 
    Split: iS7{12} by factor 4 -> iS11{3}, iS12{4} 
    loop domain : (iS6{2}, iS11{3}, iS12{4}, iB8{16})
    

    What I'm suggesting is that since we want the allocation domain to have the same iter-domain expressions, we could use iS6{2}, iS11{3}, iS12{4}, iB8{16} to create the allocation domain as: iS11{3}, iS12{4}, iS6{2}, iB8{16}. Why do we need to create new iter-domains?

    Collaborator Author

    @liqiangxl liqiangxl Jul 25, 2025

    Got you. So what you suggested is: we don't change the loop domain; we create the allocation domain by reordering the loop domain IDs. Then we won't create new IDs when scheduling the allocation domain. For example:
    (1) current approach: directly split the allocation domain, iS7{12}, iS6{2}, iB8{16} ---> iS15{3}, iS16{4}, iS6{2}, iB8{16}
    (2) new approach: use IDs from the loop domain, iS7{12}, iS6{2}, iB8{16} ---> iS11{3}, iS12{4}, iS6{2}, iB8{16}
    The difference is that we re-use iS11{3}, iS12{4} instead of creating new iS15{3}, iS16{4}. Is this understanding correct?

    @liqiangxl liqiangxl marked this pull request as ready for review July 22, 2025 15:52
    @liqiangxl liqiangxl requested a review from naoyam July 22, 2025 15:52
    os() << " contiguity: " << tv->domain()->getContiguityString() << "\n";

    for (const auto exp : tv->domain()->allExprs()) {
    const auto& loop_domain = tv->getLoopDomain();
    Collaborator

    Why this change? logical_to_loop does not capture all exprs returned by allExprs(). In some cases, the loop domain may not be a dependent of the logical domain either.

    Collaborator Author

    @liqiangxl liqiangxl Jul 24, 2025

    I made this change to avoid printing transforms made on allocation domains. For example, there are two splits in the following tensor: one for the logical domain and the other for the allocation domain.
    I don't think we want to print out the transforms on the allocation domain, since the original allocation domain was replaced; this additional split expr looks confusing to me.

      T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 )
      logical domain : (iS6{2}, iS7{12}, iB8{16})
      allocation domain : (iS15{3}, iS16{4}, iS6{2}, iB8{16})
      contiguity: t t t t
       Split: iS7{12} by factor 4 -> iS15{3}, iS16{4}
       Split: iS7{12} by factor 4 -> iS11{3}, iS12{4}
      loop domain : (iS6{2}, iS11{3}, iS12{4}, iB8{16})
    

    I didn't realize the loop domain may not be a dependent of the logical domain. Then, we should revise to still use allExprs() but exclude the exprs that generate allocation domains, that is, Split: iS7{12} by factor 4 -> iS15{3}, iS16{4} in this case.

    Collaborator

    Yes, it is not the best way to show them, since these transformations are no longer just straight-line transformations from root to logical, etc. Maybe we could have multiple sections for exprs, like "root to logical", "logical to loop", and "logical to allocation", as those expr sequences are typically what matter most.

    Collaborator Author

    Or maybe we change to

      T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 )
      logical domain : (iS6{2}, iS7{12}, iB8{16})
       Split: iS7{12} by factor 4 -> iS15{3}, iS16{4}
      allocation domain : (iS15{3}, iS16{4}, iS6{2}, iB8{16})
      contiguity: t t t t
       Split: iS7{12} by factor 4 -> iS11{3}, iS12{4}
      loop domain : (iS6{2}, iS11{3}, iS12{4}, iB8{16})
    

    then we know Split: iS7{12} by factor 4 -> iS15{3}, iS16{4} was used to generate allocation domain : (iS15{3}, iS16{4}, iS6{2}, iB8{16})

    Collaborator

    Yeah, something like that would be more helpful than just dumping all expressions. See TensorDomain::allExprs() to see how to grab each set of expressions.

    Collaborator Author

    I extended allExprs() to allExprsToIds(alloc_domain)

    if (exclude_it != exclude_ca_ids.end()) {
      return *exclude_it;
    }
    // Fallback: use IdModel to check if any excluded ID is mapped
    Collaborator

    Why is this necessary?

    Collaborator Author

    For example, we have allocation domain (iS15{3}, iS16{4}, iS6{2}, iB8{16}) and loop domain (iS6{2}, iS11{3}, iS12{4}, iB8{16}) in

      T2_s_float[iS6{2}, iS11{3}, iS12{4}, iB8{16}] ca_pos( 2 )
      logical domain : (iS6{2}, iS7{12}, iB8{16})
      allocation domain : (iS15{3}, iS16{4}, iS6{2}, iB8{16})
      contiguity: t t t t
       Split: iS7{12} by factor 4 -> iS15{3}, iS16{4}
       Split: iS7{12} by factor 4 -> iS11{3}, iS12{4}
      loop domain : (iS6{2}, iS11{3}, iS12{4}, iB8{16})
    

    Based on the loop domain and the compute-at position, we don't need to allocate iS6{2} and iS11{3}.
    Then the corresponding allocation domain IDs iS6{2} and iS15{3} should be excluded.
    iS6{2} exists in both the allocation and loop domains, so it is found directly by pointer comparison.
    iS15{3} only exists in the allocation domain, but it is mapped to iS11{3} in the loop domain. Here we need IdModel to find this pair.

    Collaborator

    I see. I think that makes sense. Perhaps, we could simplify the code a bit:

    const auto excluded_ca_groups = GpuLower::current()->idModel().idGraph(IdMappingMode::EXACT).toGroups(exclude_ca_ids);
    return excluded_ca_groups.has(GpuLower::current()->idModel().idGraph(IdMappingMode::EXACT).toGroup(id));
    

    Then we can also remove lines 165 to 168.

    Collaborator Author

    That's great, thanks for the suggestion!
    In addition to checking for existence, we also need to remove the actual ID from exclude_ca_ids to ensure all intended exclusions are correctly applied. I've slightly extended this approach and submitted a refactor PR at #4843


    // For shared memory tensor, replay loop domain transformations to allocation
    // domain
    void replayLoopToAllocation(Fusion* fusion);
    Collaborator

    The name and the comment of the function seem vague. Nothing is mentioned about whether it silently ignores tensors that already have an allocation domain. Ideally, this should be clear from the function name alone.

    Collaborator Author

    renamed to replayLoopToAllocationForSharedMemoryTvs

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    Summary of buildAllocationDomainFromLoopIds Logic

    Purpose: Replace allocation domain IDs with their corresponding loop domain IDs while preserving allocation order.

    Algorithm:

    1. Validate allocation domain is a permutation of logical domain
    2. Get transformations from allocation (permutation of logical) → loop domain
    3. Track sources: Map each ID to its allocation origins
      • Initially: each allocation ID maps to itself
      • Split: both outputs inherit input's sources
      • Merge: output gets combined sources from both inputs
    4. Build result: For each allocation ID (in order), collect loop IDs that came from it

    Collaborator

    @naoyam naoyam left a comment

    LGTM

    @liqiangxl liqiangxl merged commit bc5f462 into main Aug 8, 2025
    55 checks passed
    @liqiangxl liqiangxl deleted the llu/selfReplayLoopToAllocation branch August 8, 2025 13:33