
Handle swizzle1d in isResharding#6028

Open
Priya2698 wants to merge 2 commits into main from pm/swizzle1d_resharding

Conversation

@Priya2698
Collaborator

No description provided.

@Priya2698
Collaborator Author

!test

@greptile-apps
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR extends haveDifferentShardings / isResharding to correctly handle Swizzle1D transforms in the loop domain by adding symbolic inverse-swizzle computation to computeLoopIndex and building a pt_to_index map from the producer's DeviceMesh.

Key changes:

  • device_mesh.cpp: hasParallelType now guards with size() > 0 to avoid a false positive on the default DeviceMesh (a 1-D tensor with 0 elements but rank() == 1, which would cause parallelTypeToAxis to incorrectly return 0 instead of -1).
  • resharding.cpp: computeLoopIndex gains a pt_to_index parameter and handles Swizzle1D transforms by applying the correct inverse formula out_idx = (in_idx − device_idx + extent) % extent, consistent with the evaluator's forward formula. haveDifferentShardings creates symbolic Val* entries for each DID present in the producer's mesh and passes the complete map to both producer and consumer computeLoopIndex calls.
  • tests/cpp/test_resharding.cpp: Two new unit tests validate the resharding (DID → Stream swizzle) and non-resharding (symmetric swizzle) cases with correct assertions.

The inverse swizzle formula is mathematically verified against the evaluator's forward formula, and the device_mesh.cpp guard is correct and necessary. Tests cover both resharding and non-resharding paths.

Confidence Score: 4/5

  • Safe to merge; the PR correctly implements Swizzle1D handling in resharding checks and the known edge case (consumer Swizzle1D referencing a DID absent from producer mesh) does not affect normal usage patterns and is tracked in existing threads.
  • The inverse swizzle formula is mathematically correct and verified against the evaluator. The hasParallelType guard is necessary and correct. Both the implementation and tests are sound. The one limitation (consumer-side DID mismatch in Swizzle1D) is an edge case outside the normal call path and is intentionally tracked for future work.
  • No files require special attention.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["haveDifferentShardings(producer, consumer, parallel_types)"] --> B["Build pt_to_index from\nproducer DeviceMesh\n(for each DID in kParallelTypeDIDs)"]
    B --> C["For each parallel_type in parallel_types"]
    C --> D["computeLoopIndex(p_id, producer logical, id_to_index, pt_to_index)"]
    C --> E["computeLoopIndex(c_id, consumer root, id_to_index, pt_to_index)"]

    D --> F{"Transform type?"}
    F -->|Split| G["outer = in / inner_extent\ninner = in % inner_extent"]
    F -->|Merge| H["out = outer * inner_extent + inner"]
    F -->|Swizzle1D| I["out_idx = (in_idx - pt_to_index[pt] + extent) % extent\nInverse of: in_idx = (out_idx + device_idx) % extent"]

    E --> F

    G --> J["simplifyExpr(p_index == c_index, assumptions)"]
    H --> J
    I --> J

    J -->|"Cannot prove equal"| K["return true (resharding)"]
    J -->|"Proven equal"| L["return false (non-resharding)"]

Last reviewed commit: 89885c0

Comment on lines +266 to +278
std::unordered_map<ParallelType, Val*> pt_to_index;
const DeviceMesh& mesh = producer->getDeviceMesh();
for (ParallelType pt : kParallelTypeDIDs) {
  if (!mesh.hasParallelType(pt)) {
    continue;
  }
  Val* device_idx = IrBuilder::create<Val>(DataType::Index);
  pt_to_index[pt] = device_idx;
  Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
  assumptions.push_back(
      SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
  assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
}

pt_to_index is populated exclusively from the producer's mesh. It is then passed verbatim into computeLoopIndex for the consumer side as well (lines 358–362). This means that if haveDifferentShardings is called with parallel_types = {Stream} and the consumer holds a Swizzle1D referencing a ParallelType (e.g. DIDy) that is not present in the producer's mesh, pt_to_index.at(swizzle->parallelType()) (line 110) will throw std::out_of_range.

The early-exit guard at lines 152–158 only fires when parallel_types contains a DID type, so it does not protect this path when only ParallelType::Stream is being checked.

A minimal fix is to also fold the consumer's mesh into pt_to_index:

Suggested change
std::unordered_map<ParallelType, Val*> pt_to_index;
// Collect device-parallel symbolic indices from both meshes.
for (const DeviceMesh* m :
     {&producer->getDeviceMesh(), &consumer->getDeviceMesh()}) {
  for (ParallelType pt : kParallelTypeDIDs) {
    if (!m->hasParallelType(pt) || pt_to_index.count(pt)) {
      continue;
    }
    Val* device_idx = IrBuilder::create<Val>(DataType::Index);
    pt_to_index[pt] = device_idx;
    Val* team_size = IrBuilder::create<Val>(m->size(pt), DataType::Index);
    assumptions.push_back(
        SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
    assumptions.push_back(
        SimplifyingIrBuilder::ltExpr(device_idx, team_size));
  }
}

Comment on lines +676 to +678
EXPECT_FALSE(haveDifferentShardings(
    in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
}

Swizzle1D_ConsistentSwizzle only asserts that {ParallelType::Stream} is non-resharding. Because neither tensor's loop domain has a DIDx-parallelized axis in this configuration, it would be useful to also verify that checking {ParallelType::DIDx} does not regress (i.e. still returns false). This mirrors the two-assertion pattern used in Swizzle1D_DIDToStream and guards against future regressions where a swizzle-internal DID type is accidentally surfaced.

Suggested change
EXPECT_FALSE(haveDifferentShardings(
    in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
EXPECT_FALSE(haveDifferentShardings(
    in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::DIDx}));
}

@Priya2698
Copy link
Collaborator Author

!test

@Priya2698 Priya2698 requested a review from wujingyue March 5, 2026 19:45