
Move MarkAliasAnalysisPreparePass before propagateShardingsPass#4274

Merged
Priya2698 merged 3 commits into main from pm/alias_analysis
Apr 23, 2025

Conversation

@Priya2698
Collaborator

@Priya2698 Priya2698 commented Apr 18, 2025

This makes #3838 performance neutral. PR #3838 sets the allocation domain for multidevice tensorviews in the makeReshardingContiguous pass, and aliasing is not performed if the allocation domain has already been set for a tensorview. This PR moves the multidevice preseg passes after markAliasAnalysisPreparePass to avoid the performance regression.

Update: The codegen changes are due to missed aliasing opportunities for the final permute operation inserted in reorderShardedAxisPass (here and here); however, this does not have a significant performance impact (see benchmark results). Since only the input/output of the communication needs to have the allocation domain specified, new_output can have the same allocation as output.

Once we fix markAliasPreparePass to propagate DID transforms and shardings for copied tensorviews, the presegmentation passes will be ordered as [propagateShardingsPass, insertResharding, reorderShardedAxis] -> [markAliasPreparePass, AllocationDomainPass] -> [makeReshardingContiguous]. This avoids missed aliasing opportunities for operators added in the insertResharding and reorderShardedAxis passes.

Benchmarking results on GH200 nodes:

On main:

Name (time in ms)               Min     Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations

test_transformer_forward     6.2744  7.0567  6.4946  0.3369  6.2961  0.4077       1;0  153.9732       5           1
test_transformer_forward     6.2781  7.0573  6.4949  0.3368  6.2962  0.4076       1;0  153.9664       5           1
-------------------------------------------------------------------------------------------------------------------

test_transformer_backward     12.5244  13.7777  13.0152  0.6278  12.5900  1.1082       1;0  76.8331       5           1
test_transformer_backward     12.5348  13.7620  13.0204  0.6094  12.6391  1.0909       1;0  76.8024       5           1
-----------------------------------------------------------------------------------------------------------------------

This branch:


Name (time in ms)               Min     Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
test_transformer_forward     6.2889  7.0885  6.5132  0.3481  6.2960  0.4302       1;0  153.5349       5           1
test_transformer_forward     6.2895  7.0262  6.5010  0.3231  6.2963  0.4195       1;0  153.8221       5           1

Name (time in ms)                 Min      Max     Mean  StdDev   Median     IQR  Outliers      OPS  Rounds  Iterations
test_transformer_backward     12.4542  13.6518  12.9532  0.5625  12.6231  0.9795       1;0  77.2012       5           1
test_transformer_backward     12.4778  13.6544  12.9510  0.5641  12.5828  0.9724       1;0  77.2139       5           1
-----------------------------------------------------------------------------------------------------------------------

@Priya2698
Collaborator Author

!test

@github-actions

github-actions bot commented Apr 18, 2025

Review updated until commit afacb99

Description

  • Reordered MarkAliasAnalysisPreparePass before propagateShardingsPass

  • Updated test cases for alias analysis and resharding expressions

  • Adjusted expected segment count in multidevice matmul test


Changes walkthrough 📝

Relevant files
Enhancement
alias_analysis.cpp
Remove resharding check in alias analysis                               

csrc/alias_analysis.cpp

  • Removed resharding check in AliasFinder::handle
+0/-4     
pre_segmenter.cpp
Reorder multidevice passes                                                             

csrc/preseg_passes/pre_segmenter.cpp

  • Moved multidevice passes after allocation-related passes
+10/-6   
Tests
test_alias_analysis.cpp
Update resharding alias test                                                         

tests/cpp/test_alias_analysis.cpp

  • Updated test case for resharding expressions
+4/-2     
test_multidevice_matmul.cpp
Update matmul segment count test                                                 

tests/cpp/test_multidevice_matmul.cpp

  • Adjusted expected segment count in matmul test
+1/-1     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Pass Order

The order of passes has been changed, which may affect the performance and correctness of the fusion. Ensure that the new order does not introduce any unintended side effects.

// open an issue for this and see if we want to have a more aggressive
// approach inside MovePadPass instead. removes extra cast added from pushing
// pad out
OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
OptimizationPass<MarkAliasesPreparePass>::runPass(fusion);
OptimizationPass<ExactMappedExtentSubstitutionPass>::runPass(fusion);
OptimizationPass<AllocationDomainPass>::runPass(fusion);

// All the multidevice passes are moved after the allocation-related passes
// (MarkAliasesPreparePass and AllocationDomainPass). Multidevice passes will
// try to set the allocation domain for tvs with a device mesh, which would
// conflict with these passes.
OptimizationPass<PropagateShardingsPass>::runPass(fusion);
OptimizationPass<InsertReshardingsPass>::runPass(fusion);
OptimizationPass<ReorderShardedAxisPass>::runPass(fusion);
OptimizationPass<MakeReshardingContiguousPass>::runPass(fusion);
Test Update

The test AliasForReshardingExprs has been added, and the test NoAliasForReshardingExprs has been modified. Verify that these changes accurately reflect the expected behavior of the alias analysis.

// Tests alias analysis for resharding exprs
TEST_F(AliasAnalysisTest, AliasForReshardingExprs) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int kNumDevices = 4;
  auto mesh = DeviceMesh::createForNumDevices(kNumDevices);

  TensorView* in = makeContigTensor(2);
  TensorView* out = set(in);

  in->setDeviceMesh(mesh);
  in->axis(0)->parallelize(ParallelType::DIDx);
  out->setDeviceMesh(mesh);

  fusion.addInput(in);
  fusion.addOutput(out);

  AliasAnalysisResult analysis = findAliases(&fusion);
  EXPECT_TRUE(analysis.getRoot(out) == in);
}
Test Assertion

The assertion in Matmul_LayoutTN_Allgather has been updated to expect 3 segments instead of 2. Validate that this change is correct and that it aligns with the expected behavior of the fusion.

kernel_runtime->fusionSegments()->groups(),
Contains(HeuristicIs(SchedulerType::ExprEval)).Times(3));

@Priya2698 Priya2698 requested a review from wujingyue April 18, 2025 20:18
@Priya2698 Priya2698 marked this pull request as ready for review April 18, 2025 20:18
@Priya2698 Priya2698 changed the title move alias analysis before propagateShardings Move MarkAliasAnalysisPreparePass before propagateShardingsPass Apr 18, 2025
@wujingyue wujingyue requested a review from jjsjann123 April 18, 2025 20:22
@wujingyue
Collaborator

@jjsjann123 do you request to check any perf benchmarks? I remember allocation-related passes have been sensitive to the order.

@jjsjann123
Copy link
Collaborator

@jjsjann123 do you request to check any perf benchmarks? I remember allocation-related passes have been sensitive to the order.

We don't have a specific benchmark that just tests allocation domain inference on end-to-end examples, so it's tricky to figure out how to examine the perf impact of this change.

Meanwhile, since we are only moving the sharding-related passes, this shouldn't introduce any codegen diff for the single-GPU tests here. So maybe we can do a diff to verify that?

@Priya2698
Collaborator Author

!test --diff



@wujingyue wujingyue left a comment


Results for 2 GPUs on main:

Can you specify the benchmark environment when you show any benchmark results?

Also, I'd try viking-prod-pjnl which has H100x8

@Priya2698
Collaborator Author

Priya2698 commented Apr 22, 2025

None of the single-GPU tests are affected.
I am looking at the codegen changes in the MultiDevice tests, although they are expected to be impacted (there may be extra kernels due to the segment_set inserted by the MarkAliasAnalysisPass).

Update: The codegen changes are due to missed aliasing opportunities for the final permute operation inserted in reorderShardedAxisPass (here and here); however, this does not have a significant performance impact (see benchmark results). Since only the input/output of the communication needs to have the allocation domain specified, new_output can have the same allocation as output.

@Priya2698
Collaborator Author

Results for 2 GPUs on main:

Can you specify the benchmark environment when you show any benchmark results?

Also, I'd try viking-prod-pjnl which has H100x8

Updated the comment; the existing results are from GH200 nodes.
viking-prod-pjnl is currently down. I'll update the comment when I get access to one of those nodes.

@Priya2698
Collaborator Author

Using GH200 nodes:

Overlap allgather benchmarks:

On main:

Name (time in us)                                                                                    Min                   Max                  Mean              StdDev                Median                 IQR            Outliers         OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        769.7269 (1.0)      1,153.0219 (1.27)       811.4853 (1.01)      73.5782 (2.43)       792.3190 (1.00)      24.0248 (1.0)           1;2  1,232.3081 (0.99)         25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       770.2380 (1.00)       908.0300 (1.0)        800.4824 (1.0)       30.2735 (1.0)        789.9511 (1.0)       26.1198 (1.09)          3;1  1,249.2467 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      3,983.1620 (5.17)     4,350.4890 (4.79)     4,062.3394 (5.07)      69.8278 (2.31)     4,044.6970 (5.12)      43.9199 (1.83)          3;2    246.1636 (0.20)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     4,003.1620 (5.20)     5,627.7660 (6.20)     4,194.1767 (5.24)     311.8692 (10.30)    4,160.9530 (5.27)     122.3673 (5.09)          1;1    238.4258 (0.19)         25           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        787.8711 (1.04)       984.7671 (1.05)       826.1004 (1.04)      40.1384 (1.15)       818.7820 (1.04)      37.2551 (1.39)          3;1  1,210.5066 (0.97)         25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       754.3659 (1.0)        938.1110 (1.0)        797.2863 (1.0)       34.9924 (1.0)        788.6700 (1.0)       26.8243 (1.0)           3;2  1,254.2546 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      4,122.8090 (5.47)     4,379.5439 (4.67)     4,162.4954 (5.22)      49.9946 (1.43)     4,155.7370 (5.27)      29.1997 (1.09)          1;1    240.2405 (0.19)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     4,142.6170 (5.49)     5,407.6710 (5.76)     4,326.8906 (5.43)     260.0871 (7.43)     4,227.9931 (5.36)     274.3837 (10.23)         1;1    231.1128 (0.18)         25           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

On this branch:

------------------------------------------------------------------------------------------------------------------------------ benchmark: 4 tests ------------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                    Min                   Max                  Mean              StdDev                Median                 IQR            Outliers         OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        807.4871 (1.06)     1,213.7579 (1.28)       858.3986 (1.06)      77.7686 (2.11)       840.3820 (1.05)      23.3600 (1.0)           1;3  1,164.9599 (0.94)         25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       761.6311 (1.0)        945.8550 (1.0)        809.3273 (1.0)       36.8810 (1.0)        798.2701 (1.0)       29.0880 (1.25)          3;2  1,235.5940 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      4,482.9370 (5.89)     4,701.5600 (4.97)     4,568.5709 (5.64)      47.9336 (1.30)     4,566.4880 (5.72)      41.7361 (1.79)          6;3    218.8868 (0.18)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     4,152.9849 (5.45)     4,794.5840 (5.07)     4,269.6082 (5.28)     167.4440 (4.54)     4,200.0250 (5.26)     162.3110 (6.95)          2;2    234.2135 (0.19)         25           1

test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        839.8379 (1.08)     1,049.3100 (1.15)       888.6295 (1.11)      37.1963 (1.42)       883.2950 (1.11)     17.6237 (1.46)          2;2  1,125.3284 (0.90)         25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       776.3189 (1.0)        913.7900 (1.0)        800.1471 (1.0)       26.1388 (1.0)        794.2379 (1.0)      12.0320 (1.0)           2;3  1,249.7702 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      4,205.9129 (5.42)     4,563.7370 (4.99)     4,280.3422 (5.35)      67.7339 (2.59)     4,267.9600 (5.37)     32.9513 (2.74)          3;2    233.6262 (0.19)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     4,040.8251 (5.21)     4,666.0400 (5.11)     4,198.2393 (5.25)     115.1670 (4.41)     4,189.0170 (5.27)     43.8160 (3.64)          4;4    238.1951 (0.19)         25           1

CPP transformer benchmarks:
On main:

Sequence Parallel: False

1: Average forward time 6.5ms
0: Average forward time 6.5ms
0: Average backward time 13ms
1: Average backward time 13ms

Sequence Parallel: True

1: Average forward time 6.3ms
0: Average forward time 6.3ms
0: Average backward time 13ms
1: Average backward time 13ms

On this branch:

Sequence Parallel: False

0: Average forward time 6.5ms
1: Average forward time 6.5ms
1: Average backward time 13.1ms
0: Average backward time 13.1ms

Sequence Parallel: True

0: Average forward time 6.3ms
1: Average forward time 6.3ms
1: Average backward time 13ms
0: Average backward time 13ms

I do not see any major performance dips in the other benchmarks. However, the overlap benchmarks are less stable, with very high standard deviations.

@Priya2698
Collaborator Author

On 8xA100 40GB(luna_prod):

CPP benchmarks
On main:

Sequence parallel: False
Average forward time 6.9ms
Average backward time 11.2ms

Sequence parallel: True
Average forward time 6.4ms
Average backward time 10.3ms

This branch:

Sequence parallel: False
Average forward time 6.9ms
Average backward time 10.8-10.9ms

Sequence parallel: True
Average forward time 6.4-6.5ms
Average backward time 9.9ms

Python transformer benchmarks:
On main:

Name (time in ms)               Min      Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------
test_transformer_forward     7.2599  10.0077  8.0222  1.1733  7.3541  1.3432       1;0  124.6539       5           1
test_transformer_backward     14.3580  20.7857  15.6574  2.8668  14.3675  1.6381       1;1  63.8674       5           1

This branch:

Name (time in ms)               Min     Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------
test_transformer_forward     7.2479  9.7135  7.9143  1.0730  7.2500  1.2643       1;0  126.3539       5           1
test_transformer_backward     13.6592  16.2368  14.1995  1.1392  13.6857  0.6783       1;1  70.4252       5           1

@Priya2698
Collaborator Author

Priya2698 commented Apr 23, 2025

On 8xA100 40GB(luna_prod):
Overlap allgather matmul benchmarks.
On main:

Name (time in us)                                                                                    Min                   Max                  Mean              StdDev                Median                IQR            Outliers       OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       997.6337 (1.0)      3,330.4309 (1.96)     1,208.4252 (1.12)     497.3844 (2.93)     1,036.8880 (1.03)     62.8247 (6.08)          3;5  827.5233 (0.90)         25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]      1,004.3662 (1.01)     1,701.0877 (1.0)      1,082.9736 (1.0)      169.6304 (1.0)      1,010.7378 (1.0)      10.3327 (1.0)           3;5  923.3835 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     5,114.2760 (5.13)     6,057.7570 (3.56)     5,219.6112 (4.82)     249.6188 (1.47)     5,126.0488 (5.07)     21.3305 (2.06)          3;5  191.5852 (0.21)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      5,100.4710 (5.11)     9,172.3478 (5.39)     5,332.4179 (4.92)     817.5165 (4.82)     5,112.7030 (5.06)     40.7953 (3.95)          1;6  187.5322 (0.20)         25           1

This branch:

Name (time in us)                                                                                    Min                   Max                  Mean              StdDev                Median                IQR            Outliers       OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       983.6568 (1.0)      2,261.9828 (1.25)     1,051.0509 (1.0)      255.5715 (1.28)       986.5817 (1.0)        7.6344 (1.0)           1;4  951.4287 (1.0)          25           1
test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]      1,001.9219 (1.02)     1,802.8207 (1.0)      1,095.7827 (1.04)     199.5873 (1.0)      1,017.8220 (1.03)      44.0930 (5.78)          3;3  912.5897 (0.96)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     5,091.1531 (5.18)     6,288.4660 (3.49)     5,262.4144 (5.01)     246.2242 (1.23)     5,230.3867 (5.30)     160.5271 (21.03)         2;2  190.0268 (0.20)         25           1
test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      5,099.9811 (5.18)     8,211.1852 (4.55)     5,356.8633 (5.10)     613.3263 (3.07)     5,219.2560 (5.29)     137.0180 (17.95)         1;2  186.6764 (0.20)         25           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@Priya2698
Collaborator Author

Priya2698 commented Apr 23, 2025

@wujingyue viking-prod-pjnl may be down for some time. I ran on A100x8 and the results are in the comments above. We do not see a performance penalty due to the missed aliasing opportunity at this time.
Note that the earlier results for 2 nodes were using GH200 nodes.

Let me know if you would like to see any additional results. I will merge the PR after this.

@wujingyue
Collaborator

As you said, the codegen diff shows a potential performance hit, but it doesn't seem to affect any real benchmarks, and you have an idea of how to fix this in the near future (e.g. move markAliasesPrepare between reorder and set-allocation). So LGTM!

@Priya2698 Priya2698 merged commit 096b681 into main Apr 23, 2025
58 of 60 checks passed
@Priya2698 Priya2698 deleted the pm/alias_analysis branch April 23, 2025 07:09
Priya2698 added a commit that referenced this pull request Apr 24, 2025
This PR extends the `propagateSharding` presegmentation pass for DID
loop splits.
Key changes:
1. We use TransformPropagator for all expressions except `ViewOp` which
is handled manually since TransformPropagator does not support it
without first propagating the reshape to the producer.
2. `makeReshardingContiguous` sets the allocation domain for tvs with a device mesh. Ideally, we would set it only for global tensors, but that is not known before segmentation, while the allocation domain must be set before segmentation.
3. ~~The following tests are modified: see the
[discussion](#3838 (comment)).~~
PR #4274 resolved this.

Follow-up PRs:

- `ViewOp` will be handled in a followup PR.
- Currently, we only backpropagate sharding for a tv that does not
already have a device dimension. This can be extended to propagate for
all parallel types not present on the tv. This will be done in a
followup. Backpropagating shardings can incorrectly change DIDx to
serial or modify DIDx to be on another location. `shardAllLike` can be
modified to specify which parallel type to propagate. Since
`insertResharding` and `propagateSharding` require different behavior, I
will handle it in a separate PR.
- Use `TransformReplay::CasP` in lieu of TransformPropagator.
- Propagate DID transforms within `castOp`:
[privatizeUpcast](https://github.com/NVIDIA/Fuser/blob/ed687366cf717837c8ea3e40f56542fec48e1616/csrc/fusion_segmenter.cpp#L4235-L4238)
clones cast operations, which fails segmentation since the transforms
are not replicated.

Findings from experiments:
#3838 (comment)

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>