Update propagateSharding preseg pass for DID loop split#3838
Update propagateSharding preseg pass for DID loop split#3838
Conversation
|
Review updated until commit b09d926 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
dcab5b7 to
c41b7ee
Compare
baad591 to
6d159ac
Compare
b50134f to
f136246
Compare
|
!test |
…backward transform propagation
|
!test |
|
!test |
|
!test |
For my reference:
|
|
|
||
| } // namespace | ||
|
|
||
| void MakeReshardingContiguousPass::runPass(Fusion* fusion) { |
There was a problem hiding this comment.
I'll rename this pass in a separate PR.
…4274) This makes #3838 performance neutral. Benchmarking results on GH200 nodes: On main: ``` Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_forward 6.2744 7.0567 6.4946 0.3369 6.2961 0.4077 1;0 153.9732 5 1 test_transformer_forward 6.2781 7.0573 6.4949 0.3368 6.2962 0.4076 1;0 153.9664 5 1 ------------------------------------------------------------------------------------------------------------------- test_transformer_backward 12.5244 13.7777 13.0152 0.6278 12.5900 1.1082 1;0 76.8331 5 1 test_transformer_backward 12.5348 13.7620 13.0204 0.6094 12.6391 1.0909 1;0 76.8024 5 1 ----------------------------------------------------------------------------------------------------------------------- ``` This branch: ``` Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_forward 6.2889 7.0885 6.5132 0.3481 6.2960 0.4302 1;0 153.5349 5 1 test_transformer_forward 6.2895 7.0262 6.5010 0.3231 6.2963 0.4195 1;0 153.8221 5 1 Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_backward 12.4542 13.6518 12.9532 0.5625 12.6231 0.9795 1;0 77.2012 5 1 test_transformer_backward 12.4778 13.6544 12.9510 0.5641 12.5828 0.9724 1;0 77.2139 5 1 -----------------------------------------------------------------------------------------------------------------------
|
!test |
wujingyue
left a comment
There was a problem hiding this comment.
I don't yet fully understand the code around selectiveReorderDIDToFront. LGTM otherwise!
|
!test |
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
|
!test |
|
!test |
This PR extends the
propagateShardingpresegmentation pass for DID loop splits.Key changes:
ViewOpwhich is handled manually since TransformPropagator does not support it without first propagating the reshape to the producer.makeReshardingContiguoussets allocation domain for tvs with device mesh. Ideally, we need to set it only for global tensors but this is not known before segmentation, but should be set before segmentation.The following tests are modified: See discussion. PR MoveMarkAliasAnalysisPreparePassbeforepropagateShardingsPass#4274 resolved this.Follow-up PRs:
ViewOpwill be handled in a followup PR.shardAllLikecan be modified to specify which parallel type to propagate. SinceinsertReshardingandpropagateShardingrequire different behavior, I will handle it in a separate PR.TransformReplay::CasPin lieu of TransformPropagator.castOp: privatizeUpcast clones cast operations, which fails segmentation since the transforms are not replicated.Findings from experiments: #3838 (comment)