Fix itertype promotion for GatherScatter#5365
Conversation
|
!test |
|
Review updated until commit 12d6f71 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
|
|
||
| namespace nvfuser { | ||
|
|
||
| TEST_F(NVFuserTest, GatherScatterIterType) { |
There was a problem hiding this comment.
This is same as ScatterLargeConstrainedIDs.
I have made this separate to keep it independent of the other test in case it is modified in the future according to its purpose.
|
!test |
## This PR
Stops propagation of IterType::GatherScatter to be propagated from
producer to consumer.
## Context
We are propagating IterType::GatherScatter, in the example below, the
output T2 from scatter op contains `nS2` in logical domain. That
IterType is propagated through to its consumer T3, which felt wrong.
```
%kernel {
T2_l_int64_t[iS3{1024}]
= scatter(in = T0_g_int64_t[iS0{128}], dim = 0, src = 1, idx = T1_g_int64_t[iS1{1024}], accumulate = add )
T3_g_int64_t[nS4{128}]
= T2_l_int64_t[iS3{1024}]
+ 1;
T2_l_int64_t[iS3{1024}]
logical domain : (nS2{128})
contiguity: t
loop domain : (iS3{1024})
T3_g_int64_t[nS4{128}]
logical domain : (nS4{128})
contiguity: t
loop domain : (nS4{128})
} // %kernel
```
This PR
Stops propagation of IterType::GatherScatter to be propagated from producer to consumer.
Context
We are propagating IterType::GatherScatter, in the example below, the output T2 from scatter op contains
nS2in logical domain. That IterType is propagated through to its consumer T3, which felt wrong.