Skip to content

Prepare matmul schedulers for 2d grid traversal pattern#4242

Merged
rdspring1 merged 3 commits intomainfrom
grid_traversal_2d
Apr 15, 2025
Merged

Prepare matmul schedulers for 2d grid traversal pattern#4242
rdspring1 merged 3 commits intomainfrom
grid_traversal_2d

Conversation

@rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented Apr 11, 2025

This PR prepares hopper matmul scheduler to use 2d grid traversal pattern.

  • Rename grid_swizzle_factor to grid_traversal_factor in matmul schedulers.
  • Changes grid_swizzle_factor from an int to std::pair<int, int>
  • If grid_traversal_factor.second == 1, thengrid_traversal_factor.first == grid_swizzle_factor.
  • Rename swizzleBlockTiles function to reorderBlockTileTraversal.

@github-actions
Copy link

Description

  • Rename grid_swizzle_factor to grid_traversal_factor

  • Change grid_traversal_factor type to std::pair<int, int>

  • Rename swizzleBlockTiles to reorderBlockTileTraversal


Changes walkthrough 📝

Relevant files
Enhancement
9 files
python_bindings.cpp
Rename `grid_swizzle_factor` to `grid_traversal_factor`   
+1/-1     
ampere_multi_matmul.cpp
Rename swizzleBlockTiles to reorderBlockTileTraversal and update
grid_traversal_factor
+8/-4     
hopper_multi_matmul.cpp
Rename swizzleBlockTiles to reorderBlockTileTraversal and update
grid_traversal_factor
+8/-4     
matmul_heuristic_plugin.cpp
Update `grid_traversal_factor` in config and params           
+2/-2     
matmul_utils.cpp
Update `grid_traversal_factor` calculation                             
+5/-6     
test_matmul.cpp
Update `grid_traversal_factor` in tests                                   
+2/-2     
ampere_multi_matmul.h
Rename `swizzleBlockTiles` to `reorderBlockTileTraversal`
+3/-2     
hopper_multi_matmul.h
Rename `swizzleBlockTiles` to `reorderBlockTileTraversal`
+3/-2     
matmul_heuristic.h
Change grid_swizzle_factor to std::pair and update related
methods
+7/-6     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Grid Traversal Factor

The code sets grid_traversal_factor to {grid_traversal_factor, 1} and disables 2D grid traversal. Ensure that this change does not inadvertently disable a feature that was previously enabled with grid_swizzle_factor.

mparams->grid_traversal_factor = {grid_traversal_factor, 1};
Grid Swizzle Factor

The code copies grid_traversal_factor.first to grid_swizzle_factor. Verify that this change is intentional and that the second element of grid_traversal_factor is not needed in the KernelConfig.

config->grid_swizzle_factor = mparams->grid_traversal_factor.first;
config->cta_order =
Grid Swizzle Factor

The code copies config->grid_swizzle_factor to grid_traversal_factor.first. Ensure that this change is intentional and that the second element of grid_traversal_factor is correctly handled.

mparams->grid_traversal_factor.first = config->grid_swizzle_factor;
switch (config->cta_order) {

@rdspring1 rdspring1 changed the title Rename grid_swizzle_factor to grid_traversal_factor. Prepare matmul schedulers for 2d grid traversal pattern Apr 11, 2025
@rdspring1
Copy link
Collaborator Author

!test

@rdspring1 rdspring1 marked this pull request as ready for review April 11, 2025 20:39
Copy link
Collaborator

@jacobhinkle jacobhinkle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me and is a step in the right direction. Maybe we could generalize it a bit further in the future by defining the operation recursively and holding a vector of ints instead of a tuple.

@rdspring1 rdspring1 merged commit 0b59727 into main Apr 15, 2025
53 checks passed
@rdspring1 rdspring1 deleted the grid_traversal_2d branch April 15, 2025 15:33
jacobhinkle added a commit that referenced this pull request Apr 25, 2025
#4242 turned on "grid traversal factor" which is a good thing. However,
it exposed a bug in how we limit that factor to prevent overrun in case
the swizzled axis has fewer tiles than the factor. This led to a
regression from 58% to 35% geomean perf compared to eager on H200.

This PR swaps the axes used to compute the number of swizzled tiles and
takes us from a geomean of 35% to 65% on
`benchmarks/python/test_matmul.py` on H200.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants