-
Notifications
You must be signed in to change notification settings - Fork 79
CTA Swizzles #87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CTA Swizzles #87
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -729,6 +729,100 @@ TEST_F(NVFuserTest, FusionAmpereMatmulPipelineGmem_CUDA) { | |
| } | ||
| } | ||
|
|
||
| // Matmul test for Ampere MMA: checking CTA Swizzles | ||
| TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { | ||
| // Keep multiples of 8 to keep vectorizable. | ||
| int dim = 8192; | ||
| int M = dim, N = dim, K = dim; | ||
| const auto all_orders = { | ||
| MatmulParam::TileRasterizationOrder::RowMajor, | ||
| MatmulParam::TileRasterizationOrder::ColumnMajor}; | ||
|
|
||
| REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0); | ||
|
|
||
| auto test = [&](MatmulLayout layout, | ||
| MatmulParam::TileRasterizationOrder order, | ||
| int swizzle, | ||
| float& runtime) { | ||
| Fusion fusion; | ||
| FusionGuard fg(&fusion); | ||
| auto tv0 = makeContigTensor(2, DataType::Half); | ||
| auto tv1 = makeContigTensor(2, DataType::Half); | ||
|
|
||
| fusion.addInput(tv0); | ||
| fusion.addInput(tv1); | ||
|
|
||
| auto tv2 = matmul(tv0, tv1, layout); | ||
|
|
||
| fusion.addOutput(tv2); | ||
|
|
||
| MatMulTileOptions gemm_tile; | ||
| gemm_tile.cta_tile = GemmTile(128, 128, 32); | ||
| gemm_tile.warp_tile = GemmTile(64, 64, 32); | ||
| gemm_tile.instruction_tile = GemmTile(16, 8, 16); | ||
|
|
||
| auto mma_builder = | ||
| MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) | ||
| .layout(layout); | ||
|
|
||
| MatmulParam params(mma_builder); | ||
| params.tile_sizes = gemm_tile; | ||
| params.async_gmem_load_operands = true; | ||
| params.double_buffer_options.double_buffer_smem_write = true; | ||
| params.double_buffer_options.double_buffer_smem_read = true; | ||
| params.double_buffer_options.smem_double_buffer_stage = 3; | ||
|
|
||
| params.rasterization_order = order; | ||
| params.grid_swizzle_factor = swizzle; | ||
|
|
||
| scheduleMatmul(tv2, tv0, tv1, params); | ||
|
|
||
| at::manual_seed(0); | ||
| auto inputs = fp16MatmulAtInput(M, N, K, layout); | ||
|
|
||
| FusionExecutor fe; | ||
| fe.setMeasureKernelTimeFlag(true); | ||
| NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( | ||
| 8, | ||
| 0, | ||
| fe.compileFusion( | ||
| &fusion, | ||
| {inputs.first, inputs.second}, | ||
| LaunchParams(), | ||
| matmul_cparams)); | ||
| auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); | ||
| auto tref = atMatmul( | ||
| inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); | ||
| TORCH_CHECK(cg_outputs[0].allclose(tref, 0.01, 0.01)); | ||
|
|
||
| int gdimx = fe.lastLaunchParams().gdimx(); | ||
| int gdimy = fe.lastLaunchParams().gdimy(); | ||
|
|
||
| int expected_gdim_unswizzled = (dim + 128 - 1) / 128; | ||
| int expected_gdimx = expected_gdim_unswizzled * swizzle; | ||
| int expected_gdimy = (expected_gdim_unswizzled + swizzle - 1) / swizzle; | ||
|
|
||
| TORCH_CHECK(gdimx == expected_gdimx); | ||
| TORCH_CHECK(gdimy == expected_gdimy); | ||
|
|
||
| runtime = fe.kernelTimeMs(); | ||
| }; | ||
|
|
||
| // Gmem pipeline stage | ||
|
|
||
| for (auto layout : {MatmulLayout::TT}) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just removed it to keep a short runtime as the test checks four configs per layout already
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Current test takes 15s, would jump to 45s.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, then let's just test one layout |
||
| for (auto order : all_orders) { | ||
| float runtime1 = 0; | ||
| test(layout, order, 1, runtime1); | ||
|
|
||
| float runtime4 = 0; | ||
| test(layout, order, 4, runtime4); | ||
|
|
||
| TORCH_CHECK(runtime4 < runtime1); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { | ||
| // Keep multiples of 8 to keep vectorizable. | ||
| int M = 504, N = 136, K = 248; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will add some debug printting when running this test? Could you comment this line out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No debug prints. The only effect is to create cudaEvents and return the runtime through
fe.kernelTimeMs()(otherwise it's just zero).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahhh, I thought it would automatically
std::cout << fe.kernelTimeMs(). If no debug prints, then we can keep it.