Switch axis we use to compute swizzled_tiles#4311

Merged
jacobhinkle merged 1 commit into main from jh/transpose_grid_traversal_limit
Apr 25, 2025

@jacobhinkle
Collaborator

#4242 turned on the "grid traversal factor", which is a good thing. However, it exposed a bug in how we limit that factor to prevent overrun when the swizzled axis has fewer tiles than the factor. This led to a regression in geomean perf compared to eager on H200, from 58% to 35%.

This PR swaps the axes used to compute the number of swizzled tiles and takes us from a geomean of 35% to 65% on benchmarks/python/test_matmul.py on H200.
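For readers unfamiliar with the metric, the geomean figures above can be read as a geometric mean over per-case performance ratios. The following is a minimal sketch, assuming each benchmark case yields one positive ratio of nvFuser performance to eager performance; the function name `geomean` and the sample ratios are hypothetical and are not taken from the benchmark suite.

```python
import math


def geomean(ratios):
    """Geometric mean of positive ratios, computed in log space."""
    if not ratios:
        raise ValueError("need at least one ratio")
    if any(r <= 0 for r in ratios):
        raise ValueError("ratios must be positive")
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


# Hypothetical per-case ratios (1.0 would mean parity with eager).
example_ratios = [0.25, 0.5, 1.0]
print(f"geomean: {geomean(example_ratios):.0%}")
```

A geometric mean is the usual choice for ratios because a case that is 2x slower and a case that is 2x faster cancel out, which an arithmetic mean would not give.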

@jacobhinkle jacobhinkle requested a review from rdspring1 April 24, 2025 19:58
@jacobhinkle
Collaborator Author

!build

@github-actions

Description

  • Switched axis for computing swizzled tiles to improve performance

  • Added memory check to skip large test cases in benchmarks


Changes walkthrough 📝

Relevant files

Enhancement

  • matmul_utils.cpp (csrc/scheduler/matmul_utils.cpp, +1/-1): Switched swizzled_tiles axis logic. Changed the logic to determine the swizzled_tiles axis.

  • test_matmul.py (benchmarks/python/test_matmul.py, +6/-0): Added memory check for large test cases. Added memory check to skip large test cases.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Swizzled Tiles Calculation

The change in the calculation of swizzled_tiles may have unintended consequences on performance or correctness. Ensure that this change does not introduce any edge cases or regressions.

int64_t swizzled_tiles = Mtiles >= Ntiles ? Ntiles : Mtiles;
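
To make the limiting step concrete, here is a rough Python sketch of clamping a traversal factor to the number of tiles on the swizzled axis. It mirrors the `swizzled_tiles` expression quoted above, but everything else is an assumption for illustration: the function name `limit_traversal_factor`, the tile sizes, and the way the factor is clamped are hypothetical and are not copied from `matmul_utils.cpp`.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive operands."""
    return -(-a // b)


def limit_traversal_factor(m, n, tile_m, tile_n, requested_factor):
    """Clamp a requested factor so it never exceeds the swizzled tile count.

    Hypothetical helper: names and tile sizes are illustrative only.
    """
    m_tiles = ceil_div(m, tile_m)
    n_tiles = ceil_div(n, tile_n)
    # Same selection as the quoted C++ line:
    #   Mtiles >= Ntiles ? Ntiles : Mtiles
    swizzled_tiles = n_tiles if m_tiles >= n_tiles else m_tiles
    # A factor larger than the available tiles would overrun the axis.
    return max(1, min(requested_factor, swizzled_tiles))


# Skinny problem: only 2 tiles along M, so a factor of 16 is cut to 2.
print(limit_traversal_factor(256, 4096, 128, 256, 16))
```

Under these assumptions, picking the wrong axis for `swizzled_tiles` would leave the factor either too large (overrun) or needlessly small (lost locality) on non-square problems.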
Memory Check

The added memory check in the benchmark tests is a good practice to prevent OOM errors. Ensure that the threshold of 20GiB is appropriate and that no important test cases are being skipped unnecessarily.

if (m * k + n * k + m * n) * 2 > 20 * (2**30):
    pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM")
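
The check above can be unpacked as a footprint estimate: two operands of `m * k` and `n * k` elements plus an `m * n` output, at 2 bytes per element. The sketch below restates it as standalone helpers; the 2-byte element size is assumed to correspond to half-precision inputs, and the names `matmul_bytes` and `exceeds_limit` are hypothetical rather than part of the benchmark file.

```python
GIB = 2**30


def matmul_bytes(m: int, n: int, k: int, bytes_per_element: int = 2) -> int:
    """Bytes for operands (m x k, n x k) and output (m x n)."""
    return (m * k + n * k + m * n) * bytes_per_element


def exceeds_limit(m: int, n: int, k: int, limit_gib: int = 20) -> bool:
    """True when the estimated footprint is above the limit in GiB."""
    return matmul_bytes(m, n, k) > limit_gib * GIB


# 8192^3 needs 0.375 GiB and runs; 65536^3 needs 24 GiB and is skipped.
for size in (8192, 65536):
    gib = matmul_bytes(size, size, size) / GIB
    print(size, f"{gib:.3f} GiB", "skip" if exceeds_limit(size, size, size) else "run")
```

The estimate counts only the tensors themselves, so any workspace or allocator overhead would come on top of it.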
Comment on lines +44 to +46
if (m * k + n * k + m * n) * 2 > 20 * (2**30):
    pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM")

Collaborator Author

Limiting mem use to 20GB. This is conservative but we don't expect problem sizes bigger than this for DL at this time, and it prevents OOM on most devices.

@jacobhinkle jacobhinkle merged commit fadfde5 into main Apr 25, 2025
16 checks passed
@jacobhinkle jacobhinkle deleted the jh/transpose_grid_traversal_limit branch April 25, 2025 01:57