
Support Split between logical domain to allocation domain to represent padding#5184

Draft
jjsjann123 wants to merge 89 commits into jj/skip_vectorization_allocation_validation from jj/allocation_PR_0

Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Sep 18, 2025

Stacked PR

PR0: #5622 skip aggressive validation check on allocation domain for vectorization
PR1: #5184 Support Split between logical domain to allocation domain to represent padding <-- this one

This PR

Allows a split of an IterDomain on the logical->allocation path to represent padding logic on the allocation. Notably, we no longer require the allocation domain to be on the logical->loop path.

Motivation

A split on the allocation domain allows a clean representation for padding, e.g.

  // `out` is a 2d TensorView with logical domain as [i0, i1]
  auto&& [io, ii] = IterDomain::split(
      out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
  // out now has
  //   logical [i0, i1]
  //     io(i1/16), ii(16) = split(i1, 16)
  //   alloc    [i0, io(i1/16), ii(16)]
  out->setAllocationDomain({out->axis(0), io, ii}, true);

The example above simply specifies that dimension i1 is padded to a multiple of 16.

Main Code Change

In order to support this, we have to update TensorView::cacheBefore. cacheBefore replaces this in the graph with producer -> set -> consumer:

  • The old cacheBefore logic kept this->domain() on the producer and replayed from logical to loop on the consumer;
  • This was arguably incorrect, since the output tensor (consumer) shouldn't dictate the layout of the cache;
  • A split that sits only between logical and allocation wouldn't work either, since it isn't on the replay path.

Hence this PR changes the cacheBefore logic such that:

  • We replay the transformation from root to loop on the producer;
  • this->domain() is now preserved on the consumer, after reduction IDs are removed.

Technical Challenges

  1. In theory, we shouldn't need an allocation domain on the cache at all. One exception where the allocation domain is preserved on the cache is when the cache is sharded. This is because our shape inference via ExpressionEvaluator relies on the allocation domain; without a proper allocation domain, the reshape call would be made on the global tensor instead of the local tensor;
  2. Shape inference and indexing correctness are compromised with a non-divisible split. See the added example in LogicalAndAllocationSizes. Since this PR is growing in size, I'll fix it in follow-up PRs;
  3. There's a separate codegen test where a modified allocation domain on the cache leads to incorrect codegen on a vectorized store. See comment. I think this is more of a scheduler issue, which I'll continue investigating separately.

@github-actions

github-actions bot commented Sep 18, 2025

Review updated until commit 4d240a4

Description

  • Enable split operations between logical and allocation domains for padding representation

  • Refactor TensorView::cacheBefore to properly handle domain transformations and preserve allocation domains

  • Update transform replay logic to maintain parallelization types and rfactor product information

  • Add support for scatter operations in cacheBefore with proper domain handling

  • Improve allocation domain preservation during caching operations

Changes walkthrough

Relevant files

Enhancement

csrc/tensor_view.cpp: Refactor cacheBefore with improved domain handling (+139/-38)

  • Major refactor of TensorView::cacheBefore() method with a new two-step approach
  • Add scatter operation support with custom domain handling
  • Implement proper cleanup of consumer domains, removing root and reduction IDs
  • Preserve allocation domains and parallelization information during caching
  • Add device mesh handling with allocation domain mapping

csrc/transform_replay.cpp: Enhance transform replay with parallelization preservation (+42/-10)

  • Preserve parallelization types during split operations
  • Update merge operations to handle rfactor products correctly
  • Refactor fullSelfReplay to return the replay mapping for allocation domain updates
  • Add new applyFullSelfReplay helper function

csrc/transform_replay.h: Extend fullSelfReplay API with replay mapping (+9/-1)

  • Add new fullSelfReplay overload accepting a replay_map parameter
  • Update documentation to clarify replay transformation behavior

csrc/ir/internal_base_nodes.h: Add resetRFactorProduct utility method (+5/-0)

  • Add resetRFactorProduct method to IterDomain for clearing the rfactor domain flag

Bug fix

csrc/scheduler/matmul.cpp: Fix matmul scheduler ID model updates (+20/-6)

  • Update updateIdModel to handle reduction IDs eliminated by cacheBefore
  • Fix cacheBefore to properly map logical domains between consumer and producer
  • Add ValGroup traversal logic to find remaining IDs in the new id_model

Tests

tests/cpp/test_layout_op.cpp: Add allocation domain padding and vectorization tests (+65/-0)

  • Add test for logical and allocation domain sizes with padding
  • Add test for allocation domain split vectorization factor
  • Validate padding behavior and vectorization with allocation domain splits

tests/cpp/test_allocation_domain.cpp: Update allocation domain test expectations (+0/-2)

  • Remove assertions about allocation domain preservation after cacheBefore
  • Update test expectations to match new cacheBefore behavior

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Memory Management

    The new cacheBefore implementation creates multiple new IterDomains using IrBuilder::createInContainer but doesn't explicitly clean up the old domain objects. While the old domain is stored in the old_domain pointer, there's no clear deletion strategy, which could lead to memory leaks, especially in long-running applications or when cacheBefore is called multiple times.

    TensorDomain* old_domain = domain();
    ScatterOp Edge Case

    The special handling for ScatterOp creates logical and loop domains separately, but the comment suggests this is a workaround for limitations in replay. The logic for handling scatter dimensions and creating new IDs might have edge cases where the mapping isn't correct, particularly when the scatter has complex indexing patterns.

    if (definition()->isA<ScatterOp>()) {
      // scatter output's loop is not connected to its root, we cannot support it
      // in replay
      NVF_ERROR(
          !domain()->hasRoot(),
          "scatter output's with root domain is not supported in cacheBefore");
      std::vector<IterDomain*> logical;
      std::vector<IterDomain*> loop;
    
      std::ranges::transform(
          domain()->logical(), std::back_inserter(logical), [&](IterDomain* id) {
            IterDomain* cloned_id =
                IrBuilder::createInContainer<IterDomain>(container(), id);
            producer_map[id] = cloned_id;
            return cloned_id;
          });
      std::ranges::transform(
          domain()->loop(), std::back_inserter(loop), [&](IterDomain* id) {
            if (auto it = producer_map.find(id); it != producer_map.end()) {
              // reuse cloned_ids
              return it->second;
            }
            // for scatter dimension, create new ID
            return IrBuilder::createInContainer<IterDomain>(container(), id);
          });
      producer = IrBuilder::createInContainer<TensorView>(
          container(),
          IrBuilder::createInContainer<TensorDomain>(
              container(),
              logical,
              loop,
              TensorDomain::getContiguityFilledWith(logical, true),
              /*skip_loop_validation=*/true),
          getDataType().value());
    } else {
    Allocation Domain Mapping

    The new allocation domain mapping logic (lines 1272-1279) assumes that all IDs in the old allocation domain exist in the producer_map. This might not hold true in all scenarios, particularly with complex transformations or when reduction IDs are involved, potentially causing runtime crashes or incorrect memory layouts.

    if (consumer->domain()->hasAllocation()) {
      std::vector<IterDomain*> mapped_alloc;
      mapped_alloc.reserve(old_domain->allocation().size());
      for (auto* c_id : old_domain->allocation()) {
        mapped_alloc.push_back(producer_map.at(c_id));
      }
      producer->setAllocationDomain(mapped_alloc, true);
    }

    Test failures

    • (Medium, 1) Tensor numerical mismatches in nvFuser matmul tests (H100 runner)

      Test Name (H100): HopperMatmulTest.HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile (Source: Link)

    out->split(1, 16);
    out->setAllocationDomain(out->getLoopDomain(), true);
    // restore loop domain
    out->merge(1);
    Collaborator

    This doesn't restore. Is this necessary?

    Collaborator Author

    Touché. It unsplits the loop domain so that it has the same size as the logical domain.
    You are right that the extent is no longer the same, so it's not a restoration.

    Schedulers expect an un-scheduled fusion. Without this merge, I'm hitting the assert here:

    NVF_ERROR(broadcast_bit_multiples.size() == ref_loop.size());

    Collaborator

    Hmm, not sure that's a good enough WAR, though this is just a test.

    I thought the schedulers can work with some scheduled loop domains (for DID parallelization), no?

    Collaborator Author

    // We always cacheBefore output at the beginning of the scheduling. And after
    // cacheBefore, the reference tensor will have all reduction IDs removed.
    ref_loop = TensorDomain::noDevices(TensorDomain::noReductions(ref_loop));

    DID-related IDs are just ignored by the scheduler, so that's too specific to multi-device.

    I'm not a fan of this either. Let me see if I can skip messing with the loop domain and play the transformation on allocation directly.

    Collaborator

    I suppose you can just modify the allocation domain with AbstractTensor. I remember there are some tests.

    Collaborator Author

    I can also directly use IterDomain::split for that.

    Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay won't pick it up. Feels similar to the allocation domain replay where rfactor was missing. fyi @Priya2698

    #0  nvfuser::nvfCheckFail (func=0xaaaaac218080 "validateDomainEquivalence",
        file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
        msg=" INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. \nExpected !compare_result.dom0_has_u"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:267
    #1  0x0000aaaaab1bbe68 in nvfuser::nvfErrorFail (func=0xaaaaac218080 "validateDomainEquivalence",
        file=0xaaaaac216938 "/opt/pytorch/nvfuser/csrc/ir/utils.cpp", line=1162,
        condMsg=0xaaaaac217fd8 " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ",
        userMsg="Expected !compare_result.dom0_has_unreachable_ids . dom0 has unreachable IDs. dom0: iS10{i0}, iS11{i2}. dom1: iS10{i0}") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:277
    #2  0x0000aaaaab60a3e8 in nvfuser::ir_utils::validateDomainEquivalence (
        dom0=std::vector of length 2, capacity 2 = {...}, dom1=std::vector of length 1, capacity 3 = {...},
        additional_ids=std::vector of length 0, capacity 0) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1162
    #3  0x0000aaaaab4aac30 in nvfuser::TensorDomain::setAllocationDomain (this=0xaaaab20918b0,
        new_allocation_domain=std::vector of length 1, capacity 3 = {...},
        new_contiguity=std::vector of length 1, capacity 3 = {...})
        at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:4055
    #4  0x0000aaaaabc7b368 in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
        producer=0xaaaab2091200, producer_pos=2, logical_map=..., opt=...)
        at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:917
    #5  0x0000aaaaabc7b7fc in nvfuser::TransformReplay::replayCasP (consumer=0xaaaab2088c00,
        producer=0xaaaab2091200, compute_at_axis=-1, opt=...)
        at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:945
    #6  0x0000aaaaabc44ccc in nvfuser::TensorView::cacheBefore (this=0xaaaab2088c00,
        op_type=nvfuser::LoadStoreOpType::Set) at /opt/pytorch/nvfuser/csrc/tensor_view.cpp:1160
    #7  0x0000aaaaabbdb250 in nvfuser::scheduler_utils::cacheAndForkOutputs (fusion=0xaaaab2084910,
        unroll=true) at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:1357
    #8  0x0000aaaaabb067dc in nvfuser::schedulePointwise (fusion=0xaaaab2084910, pparams=0xaaaab207f880)
        at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:822
    #9  0x0000aaaaabb0898c in nvfuser::PointWiseScheduler::schedule (this=0xaaaab2083460,
        fusion=0xaaaab2084910, params=0xaaaab207f880)
        at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:1304
    

    Collaborator

    So, what did you decide to do? Nothing seems to have changed?

    I can also directly use IterDomain::split for that.

    Of course, but you'd need to maintain the proper ordering of the ID vector yourself.

    Collaborator

    I can also directly use IterDomain::split for that.

    Anyway, it looks like if the transformation is not on the logical-to-loop path, our replay won't pick it up. Feels similar to the allocation domain replay where rfactor was missing. fyi @Priya2698

    Yes, rfactor replay for allocation will also complain similarly if the allocation transforms are disjoint from root-to-loop.
    replayPasC also uses the loop domain as the target, so if you intend to use IterDomain::split, we will have to update that, among other things.

    Collaborator Author

    Yep, switched to selfReplay instead of replayCasP for TensorView::cacheBefore.

    }
    };

    TEST_F(LayoutOpTest, LogicalAndAllocationSizes) {
    Collaborator

    What is being tested here?

    Collaborator Author

    Without the relaxation in vectorization analysis, this test would trigger an assert.

    So the test just verifies that we now allow an allocation domain split.
    In the follow-up PR, we add more validation to this test to check that the produced tensor matches the logical sizes.

    Collaborator
    @Priya2698 Priya2698 left a comment

    The changes look good for the multidevice support part. I am not familiar enough with the requirements for LayoutOp, so I will defer to Naoya to approve the PR.
    Is there an existing issue or doc detailing the LayoutOp design?

    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123
    Collaborator Author

    Is there an existing issue or doc detailing the LayoutOp design?

    Sorry, I don't have anything on that yet. I'll try to write one up when I have the end-to-end example working, at least in a prototype. Mostly trying to wing it at the moment.

    @jjsjann123
    Collaborator Author

    !test


    // Replay loop.
    if (self_loop != self->logical()) {
    ReplaySelf replay(self_loop, axis_map);
    Collaborator

    Just FYI: #4585 reversed this. I expect some tests to break.

    Collaborator Author

    Thanks a ton. Let me sweep through failing tests and see if there's anything easy to patch. 🧑‍💼

    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123
    Collaborator Author

    !test

    fusion.addOutput(out);
    // padding output to multiple of 16 on allocation domain
    auto&& [io, ii] = IterDomain::split(
    out->axis(1), IrBuilder::create<Val>(16L, DataType::Index), true);
    Collaborator Author

    Tagging @naoyam: changed the test to only apply the split on logical -> allocation.

    Contributor
    @greptile-apps greptile-apps bot left a comment

    12 files reviewed, 3 comments


    Comment on lines +1186 to +1189
    domain()->logical() | std::views::transform([](IterDomain* id) {
    id->setDefinition(nullptr);
    return id->resetRFactorProduct();
    }),
    Contributor

    logic: mutating IterDomain objects by clearing their definitions with setDefinition(nullptr) affects the original objects that may still be referenced elsewhere in the codebase, potentially causing issues if other code paths expect these definitions to be intact

    Comment on lines +105 to +108
    // FIXME: this doesn't feel right, we have to mark contiguity on axis(0) as
    // `false` to avoid accidntal indexing collapsing, this should be figured out
    // by indexing from the ceilDiv.
    out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
    Contributor

    logic: the FIXME indicates the contiguity workaround (false for axis(0)) is hiding an indexing bug - the indexing logic should automatically handle ceilDiv from splits without manual contiguity adjustments

    Comment on lines +117 to +127
    // FIXME: output shape inference is not correct yet.
    // output should remain the correct logical size
    // EXPECT_EQ(
    // cg_outputs[0].as<at::Tensor>().sizes(), std::vector<int64_t>({512,
    // 9}));
    // padding on the inner dimension is represented as stride on the outer
    // dimension
    EXPECT_EQ(
    cg_outputs[0].as<at::Tensor>().strides(), std::vector<int64_t>({16, 1}));
    // We need to slice because output buffer shape is not right
    EXPECT_TRUE(t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k)));
    Contributor

    logic: shape inference is broken for non-divisible splits - the output tensor has wrong size requiring manual slicing to verify correctness, which means production code using this feature will produce incorrectly-shaped outputs

    @jjsjann123 jjsjann123 changed the base branch from main to jj/refactor_allocation_domain_lowering November 25, 2025 19:46
    jjsjann123 and others added 2 commits November 25, 2025 12:08
    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    Contributor
    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 1 comment


    Comment on lines +1186 to +1188
    domain()->logical() | std::views::transform([](IterDomain* id) {
    id->setDefinition(nullptr);
    return id->resetRFactorProduct();
    Contributor

    logic: in-place mutation of IterDomain objects via setDefinition(nullptr) modifies objects that may still be referenced elsewhere (e.g., in producer_map, graph structures, or other tensor views). The comment mentions this is needed for BestEffortReplay, but this mutation happens BEFORE the consumer domain is created and could affect code that expects these IDs to retain their definitions

    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123
    Collaborator Author

    Looks like it's hitting a vectorization validation error now. I'm not sure if it's coming from changes in #5594.

    Taking a look at that now.

    @jjsjann123 jjsjann123 marked this pull request as draft December 2, 2025 01:13
    @jjsjann123 jjsjann123 changed the base branch from jj/refactor_allocation_domain_lowering to main December 2, 2025 17:12
    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123 jjsjann123 changed the base branch from main to jj/skip_vectorization_allocation_validation December 4, 2025 00:22
    @jjsjann123
    Collaborator Author

    !test

    Contributor
    @greptile-apps greptile-apps bot left a comment

    7 files reviewed, 3 comments


    Comment on lines +1197 to +1199
    domain()->logical() | std::views::transform([](IterDomain* id) {
    id->setDefinition(nullptr);
    return id->resetRFactorProduct();
    Contributor

    logic: calling setDefinition(nullptr) mutates the original IterDomain objects from domain()->logical(). These IDs are shared and may be referenced by:

    • old_domain (stored at line 1237)
    • producer_map entries (populated at lines 1140, 1180)
    • Other tensors in the fusion graph
    • Downstream transform replay operations

    The mutation happens before the consumer domain is created, so any code expecting these IDs to retain their definitions will break. The comment mentions this is needed for BestEffortReplay, but the timing and scope of this mutation need validation.

    Comment on lines +105 to +108
    // FIXME: this doesn't feel right, we have to mark contiguity on axis(0) as
    // `false` to avoid accidntal indexing collapsing, this should be figured out
    // by indexing from the ceilDiv.
    out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
    Contributor

    logic: FIXME indicates a workaround for broken indexing logic. Setting axis(0) contiguity to false prevents indexing from collapsing dimensions, but this should be automatically inferred from the ceilDiv in the split. This manual workaround masks an underlying indexing bug that could affect other use cases.

    Suggested change
    // FIXME: this doesn't feel right, we have to mark contiguity on axis(0) as
    // `false` to avoid accidntal indexing collapsing, this should be figured out
    // by indexing from the ceilDiv.
    out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});
    // TODO: Remove this workaround once indexing properly handles ceilDiv splits
    out->setAllocationDomain({out->axis(0), io, ii}, {false, true, true});

    Comment on lines +117 to +127
    // FIXME: output shape inference is not correct yet.
    // output should remain the correct logical size
    // EXPECT_EQ(
    // cg_outputs[0].as<at::Tensor>().sizes(), std::vector<int64_t>({512,
    // 9}));
    // padding on the inner dimension is represented as stride on the outer
    // dimension
    EXPECT_EQ(
    cg_outputs[0].as<at::Tensor>().strides(), std::vector<int64_t>({16, 1}));
    // We need to slice because output buffer shape is not right
    EXPECT_TRUE(t0.equal(cg_outputs[0].as<at::Tensor>().slice(1, 0, k)));
    Contributor

    logic: shape inference is broken - the output tensor has the wrong size and requires manual slicing to verify correctness. This means production code using non-divisible splits will produce incorrectly-shaped outputs, breaking downstream operations that depend on tensor shapes.

    The commented-out assertion at lines 119-121 shows the expected behavior isn't working. This is a critical correctness issue that needs resolution before merge.

    @jjsjann123 jjsjann123 marked this pull request as draft December 4, 2025 19:42