Fix overzealous concretization root->rfactor propagation#397
Fix overzealous concretization root->rfactor propagation#397jacobhinkle merged 6 commits intomainfrom
Conversation
|
!build |
Why is this? Aren't we generating a broadcast domain when the extent is 1? |
|
I wonder why this needs to be symbolic: Its extent is known to be |
Yes. Sorry I should've been more clear. The root should get concretized to Iteration, and it does. This is correct. However, currently we also propagate that to rfactor, assuming that if we updated the root then all uses up until rfactor will need an update. In this case, the rfactor is already a concrete Broadcast, so we need to avoid overwriting it. |
One reason to keep it symbolic would be if we allow our |
Yes, but even if we overwrite it, shouldn't it be overwritten as a broadcast domain since the extent is 1? I agree we shouldn't overwrite it from the beginning, but I wonder if there's anything wrong in the propagation step of fusion concretization. |
So, the ID should really be something like: where |
Yes, I think we probably should address this in the ops. I don't think However, with |
Yes that is a bit of an issue I think. The way it works now is we work from root to rfactor and we do not infer the "proper" iter_type at each step, we just copy the input iter_type if the intermediate ID is |
Yeah, I kind of expected there would be some op-specific corner cases. |
Ah, I see. So, we should just skip non-symbolic IDs. For any symbolic ID, if it requires concretization with an expr eval, it should be taken care before the propagation by |
I think that's correct. One thing I've been trying to sort out is how far we need to traverse to find the IDs that need an expr_eval to decide, before concretization. Currently this problematic code is in root->rfactor that happens during forward propagation after we have run |
naoyam
left a comment
There was a problem hiding this comment.
LGTM. Thanks for the fix. Please create an issue for the slice problem.
Repro:
Running this we encounter an error at
runFusionWithInputs():The error occurs during concretization when creating a new TensorDomain for T2. We are replacing the old tensor domain:
with a new one:
The
IterTypeis being changed fromBroadcasttoIterationwhen concretizing this slice op. The root domain?S3{2}rfis concretized toiS5{2}rf, which is correct, but when it's propagated from root to rfactor insideDynamicTransformConcretizer::mutate(TensorView* tv)we should not overwrite each op's outputs unless it isSymbolic. Even then, it may be best to delegate this to the op instead. For now, this PR skips non-Symbolicoutputs when doing this propagation.