Conversation
|
Review updated until commit 8649ff1 Description
|
| Relevant files | |||||
|---|---|---|---|---|---|
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Device Mesh Validation
|
95871f0 to
54bbceb
Compare
Greptile SummaryMigrates the Key Changes:
Implementation Notes:
Confidence Score: 4/5
Important Files Changed
Last reviewed commit: 8649ff1 |
chaserileyroberts
left a comment
There was a problem hiding this comment.
Small comments
| (z_out,) = fd.execute( | ||
| [ | ||
| z_in, | ||
| w_norm_in, | ||
| b_norm_in, | ||
| w_p_in, | ||
| w_g_in, | ||
| w_norm_out, | ||
| b_norm_out, | ||
| w_p_out, | ||
| w_g_out, | ||
| mask, | ||
| ] | ||
| ) |
There was a problem hiding this comment.
Do we want to include a numerics execution test as well? Although I know those are hard to make for more complex ops.
There was a problem hiding this comment.
Do we want to include a numerics execution test as well?
Yes, we should.
Although I know those are hard to make for more complex ops.
Yes, it's hard to figure out a reasonable comparison threshold. We may have to go with toy sizes as in
There was a problem hiding this comment.
Are you aiming to do this in the PR?
I do think we should have a numerics check to avoid silent errors within schedulers.
There was a problem hiding this comment.
Are you aiming to do this in the PR?
I do think we should have a numerics check to avoid silent errors within schedulers.
Yes
| case Direction.INCOMING: | ||
| matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_y) | ||
| matmul_out.outer_split(3, cp_size) | ||
| matmul_out.axis(3).parallelize(nvfuser.ParallelType.mesh_x) |
There was a problem hiding this comment.
Do the X, Y, and Z mesh axis represent real hardware communication links? I.e., Communicating over X would always be local NVLinks, Y being infiniband, etc, or is that arbitrary and left to the how the mesh gets constructed?
There was a problem hiding this comment.
is that arbitrary and left to the how the mesh gets constructed?
That
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
|
!test |
|
!test |
I got the "triangle updates" test passing with 3D sharding in this PR. Below are the key issues identified, workarounds and their current status:
1. Sharding Propagation Rework
2. Multi-Dimensional Sharding &
getCommunicationInfogetCommunicationInfoto support multi-dimensional sharding. It reuseshaveDifferentShardingsto identify inconsistencies between input and outputTensorViewobjects. The commit needs cleanup and further test verification to be merged.haveDifferentShardingsis currently bottlenecked by the expensiveExpressionSimplifier. We need to transition this to be IdModel-based in a future iteration.3. Misaligned Memory Access in Transpose Kernels
ReorderShardedAxisPassto ensure the scattered axis of theReduceScatteris allocated outermost.4. High memory usage
AllGatherpreceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.cc @DejunL