!test
**Greptile Summary**

This PR extends …

**Key changes:**
- The inverse swizzle formula is mathematically verified against the evaluator's forward formula, and the …

**Confidence Score:** 4/5
**Flowchart**

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["haveDifferentShardings(producer, consumer, parallel_types)"] --> B["Build pt_to_index from\nproducer DeviceMesh\n(for each DID in kParallelTypeDIDs)"]
    B --> C["For each parallel_type in parallel_types"]
    C --> D["computeLoopIndex(p_id, producer logical, id_to_index, pt_to_index)"]
    C --> E["computeLoopIndex(c_id, consumer root, id_to_index, pt_to_index)"]
    D --> F{"Transform type?"}
    F -->|Split| G["outer = in / inner_extent\ninner = in % inner_extent"]
    F -->|Merge| H["out = outer * inner_extent + inner"]
    F -->|Swizzle1D| I["out_idx = (in_idx - pt_to_index[pt] + extent) % extent\nInverse of: in_idx = (out_idx + device_idx) % extent"]
    E --> F
    G --> J["simplifyExpr(p_index == c_index, assumptions)"]
    H --> J
    I --> J
    J -->|"Cannot prove equal"| K["return true (resharding)"]
    J -->|"Proven equal"| L["return false (non-resharding)"]
```
Last reviewed commit: 89885c0
```cpp
std::unordered_map<ParallelType, Val*> pt_to_index;
const DeviceMesh& mesh = producer->getDeviceMesh();
for (ParallelType pt : kParallelTypeDIDs) {
  if (!mesh.hasParallelType(pt)) {
    continue;
  }
  Val* device_idx = IrBuilder::create<Val>(DataType::Index);
  pt_to_index[pt] = device_idx;
  Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
  assumptions.push_back(
      SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
  assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
}
```
`pt_to_index` is populated exclusively from the producer's mesh. It is then passed verbatim into `computeLoopIndex` for the consumer side as well (lines 358–362). This means that if `haveDifferentShardings` is called with `parallel_types = {Stream}` and the consumer holds a `Swizzle1D` referencing a `ParallelType` (e.g. `DIDy`) that is not present in the producer's mesh, `pt_to_index.at(swizzle->parallelType())` (line 110) will throw `std::out_of_range`.

The early-exit guard at lines 152–158 only fires when `parallel_types` contains a DID type, so it does not protect this path when only `ParallelType::Stream` is being checked.
A minimal fix is to also fold the consumer's mesh into `pt_to_index`:
```diff
  std::unordered_map<ParallelType, Val*> pt_to_index;
- const DeviceMesh& mesh = producer->getDeviceMesh();
- for (ParallelType pt : kParallelTypeDIDs) {
-   if (!mesh.hasParallelType(pt)) {
-     continue;
-   }
-   Val* device_idx = IrBuilder::create<Val>(DataType::Index);
-   pt_to_index[pt] = device_idx;
-   Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
-   assumptions.push_back(
-       SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
-   assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
- }
+ // Collect device-parallel symbolic indices from both meshes.
+ for (const DeviceMesh* m : {&producer->getDeviceMesh(),
+                             &consumer->getDeviceMesh()}) {
+   for (ParallelType pt : kParallelTypeDIDs) {
+     if (!m->hasParallelType(pt) || pt_to_index.count(pt)) {
+       continue;
+     }
+     Val* device_idx = IrBuilder::create<Val>(DataType::Index);
+     pt_to_index[pt] = device_idx;
+     Val* team_size = IrBuilder::create<Val>(m->size(pt), DataType::Index);
+     assumptions.push_back(
+         SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
+     assumptions.push_back(
+         SimplifyingIrBuilder::ltExpr(device_idx, team_size));
+   }
+ }
```
```cpp
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
}
```
`Swizzle1D_ConsistentSwizzle` only asserts that `{ParallelType::Stream}` is non-resharding. Because neither tensor's loop domain has a `DIDx`-parallelized axis in this configuration, it would be useful to also verify that checking `{ParallelType::DIDx}` does not regress (i.e. still returns false). This mirrors the two-assertion pattern used in `Swizzle1D_DIDToStream` and guards against future regressions where a swizzle-internal DID type is accidentally surfaced.
```diff
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
+ EXPECT_FALSE(haveDifferentShardings(
+     in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::DIDx}));
  }
```
!test |
No description provided.