From 0673e8aef33fedc28941998a880c6822b5b10157 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 18 May 2023 09:55:13 -0400 Subject: [PATCH 1/4] Remove extent>0 check in IterDomain::split --- csrc/ir_nodes.cpp | 4 --- test/test_resize.cpp | 66 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index 4c4634841a4..75c84576d32 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -2330,10 +2330,6 @@ std::pair IterDomain::split( bool inner_split, Val* start_offset, Val* stop_offset) { - TORCH_CHECK( - !in->extent()->isZeroInt(), - "Splitting IterDomains with ending values that are 0 is not supported at this time."); - TORCH_CHECK( factor->isIntegralScalar(), "Cannot split by non-integer value ", factor); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index d8fe4cffa76..19a7c97f91a 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2053,4 +2053,70 @@ TEST_F(NVFuserTest, ResizePermuteAndSlice_CUDA) { __FILE__); } +// When scheduling this test, the pointwise scheduler attempt to replay a Split +// transform on a size-0 dimension, which is not allowed. 
+TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + std::vector shape({8}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeContigConcreteTensor(shape); + fusion->addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(0), + IrBuilder::create(2), + IrBuilder::create(1)}}); + auto tv2 = slice( + tv0, + {{IrBuilder::create(2), + IrBuilder::create(4), + IrBuilder::create(1)}}); + auto tv3 = slice( + tv0, + {{IrBuilder::create(4), + IrBuilder::create(6), + IrBuilder::create(1)}}); + auto tv4 = slice( + tv0, + {{IrBuilder::create(6), + IrBuilder::create(6), + IrBuilder::create(1)}}); + auto tv5 = slice( + tv0, + {{IrBuilder::create(6), + IrBuilder::create(6), + IrBuilder::create(1)}}); + auto tv6 = slice( + tv0, + {{IrBuilder::create(6), + IrBuilder::create(8), + IrBuilder::create(1)}}); + fusion->addOutput(tv1); + fusion->addOutput(tv2); + fusion->addOutput(tv3); + fusion->addOutput(tv4); + fusion->addOutput(tv5); + fusion->addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + FusionExecutor fe; + + auto ref0 = t0.index({at::indexing::Slice(0, 2)}); + auto ref1 = t0.index({at::indexing::Slice(2, 4)}); // tv2 is tv0[2:4] + + TORCH_CHECK(ref0.equal(cg_outputs[0])); + TORCH_CHECK(ref1.equal(cg_outputs[1])); +} + } // namespace nvfuser From df36a61ccec225225fa5c41d1031d39e0b5e2d73 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 18 May 2023 10:05:16 -0400 Subject: [PATCH 2/4] Add simpler merge/split test Kept the old one which exercises the pointwise scheduler too --- test/test_resize.cpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp 
index 19a7c97f91a..aec0f055e80 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2055,7 +2055,7 @@ TEST_F(NVFuserTest, ResizePermuteAndSlice_CUDA) { // When scheduling this test, the pointwise scheduler attempt to replay a Split // transform on a size-0 dimension, which is not allowed. -TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { +TEST_F(NVFuserTest, FusionSizeZeroSliceSplitSchedule_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2119,4 +2119,30 @@ TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { TORCH_CHECK(ref1.equal(cg_outputs[1])); } +// In this test, we split and merge with size-zero dimensions directly. +TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + std::vector shape({4, 5}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeContigConcreteTensor(shape); + fusion->addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(2), + IrBuilder::create(2), + IrBuilder::create(1)}, + {IrBuilder::create(0), + IrBuilder::create(5), + IrBuilder::create(1)}}); + // tv1 is of shape {0, 5} + fusion->addOutput(tv1); + + tv1->merge(0, 1); // size 0*5 = 0 + tv1->split(0, 4); // sizes (0, 4) +} + } // namespace nvfuser From b3fdafe660b7b965abfaa4073c6f7aaf881cb115 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 18 May 2023 10:25:47 -0400 Subject: [PATCH 3/4] Remove size-zero merge assert --- csrc/ir_nodes.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index 75c84576d32..e42998b00d9 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -2250,9 +2250,6 @@ std::vector IterDomain::clone( // domains have valid start and stop, it's not possible to contiguous // predication. 
IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { - TORCH_CHECK( - !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(), - "Merging IterDomains with ending values that are 0 is not supported at this time."); TORCH_CHECK( outer->isReduction() == inner->isReduction(), "Merging IterDomains requires that their iteration types match. ", From e43ac6f25260ab3d8b80df4ca1702617f200d20b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 18 May 2023 10:34:20 -0400 Subject: [PATCH 4/4] Complete merge/split test --- test/test_resize.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index aec0f055e80..77ea76f34ac 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2143,6 +2143,21 @@ TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { tv1->merge(0, 1); // size 0*5 = 0 tv1->split(0, 4); // sizes (0, 4) + + FusionExecutor fe; + fe.compileFusion(fusion.get()); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref0 = t0.index({at::indexing::Slice(2, 2), at::indexing::Slice(0, 5)}); + + TORCH_CHECK(ref0.equal(cg_outputs[0])); } } // namespace nvfuser