diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b4f205ca487..60cc4d28b7a 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -316,7 +316,7 @@ void DynamicTransformConcretizationInfo::analyzeResizes( out_id->toString()); auto extent_int = extent_val.as(); NVF_ERROR( - extent_int > 0, + extent_int >= 0, "Invalid resized domain extent ", extent_int, " for domain ", diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 085c8c19636..c44fb3b5fcd 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -897,8 +897,8 @@ FusionKernelRuntime::FusionKernelRuntime( fusion.get()); if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) { - std::cout << "Fusion IR after pre-segmenter optimization passes:" - << std::endl; + debug() << "Fusion IR after pre-segmenter optimization passes:" + << std::endl; fusion->printMath(); } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 2f28f406c95..bfd831e0d8b 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -690,9 +690,6 @@ TensorView* cat( return out; } -// Currently there's no error check about the actual values of the -// Slice parameters. For example, the start parameter of a range of a -// domain is assumed to be >= 0 and < the extent of the domain. TensorView* slice(TensorView* inp, const std::vector& ranges) { const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); const int ndims = static_cast(inp_dom.size()); @@ -704,28 +701,50 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { ", Expected: ", ndims); - auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { + const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { + auto cast_extent = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); + + auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + + // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { - range.start = FusionGuard::getCurFusion()->zeroVal(); - } - if (range.stop == nullptr) { - range.stop = extent; - } - if (range.step == nullptr) { - range.step = FusionGuard::getCurFusion()->oneVal(); - } - if (range.start->dtype() != DataType::Index) { + range.start = zero; + } else if (!range.start->isZeroInt()) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); + range.start = SimplifyingIrBuilder::maxExpr( + zero, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.start, zero), + SimplifyingIrBuilder::addExpr(range.start, cast_extent), + range.start)); } - if (range.stop->dtype() != DataType::Index) { + + // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) + if (range.stop == nullptr) { + range.stop = cast_extent; + } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + SimplifyingIrBuilder::minExpr( + cast_extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), + range.stop))); } - if (range.step->dtype() != DataType::Index) { + + // Ensure step is of type Index + if (range.step == nullptr) { + range.step = FusionGuard::getCurFusion()->oneVal(DataType::Index); + } else { range.step = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.step); } + return range; }; @@ -733,7 +752,7 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { // Step not supported yet NVF_CHECK( range.step == nullptr || range.step->isOneInt(), - "Unsupported step: ", + "Unsupported step (must be 1 or null): ", range.step->toString()); } @@ -754,12 +773,13 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { out_root_id = inp_root_id->cloneWithoutRFactor(); out_rf_id = out_root_id; } else { + // Clip the start and stop values to the extent of the input out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); out_rf_id = IterDomain::resize( out_root_id, SimplifyingIrBuilder::negExpr(range.start), - sub(range.stop, inp_root_id->extent()), + SimplifyingIrBuilder::subExpr(range.stop, inp_root_id->extent()), true); needs_real_slicing = true; } diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 3f540ab3373..25e3e80293d 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -91,7 +91,9 @@ TensorView* cat( std::optional iter_type_opt = std::nullopt); //! Return a tensor where each dimension is sliced as specified by the -//! ranges parameter. Stepping must be one at this moment. +//! ranges parameter. Stepping must be one at this moment. The semantics of +//! slicing with negative values and values >= extent follow those of numpy and +//! PyTorch. TensorView* slice(TensorView* inp, const std::vector& ranges); } // namespace nvfuser diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 84a519cb688..7d04e555040 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1124,6 +1124,132 @@ TEST_F(ResizeTest, FusionResizeSlice5) { testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); } +std::vector> slice_cases( + {{0, 5}, + {3, 9}, + {3, 4}, + {7, 5}, + {0, 11}, + {11, 13}, + {-3, 8}, + {-3, -1}, + {-3, -5}, + {13, -1}, + {-11, 9}, + {-11, 0}, + {-13, -11}}); + +// Test slice with a variety of constant ranges +TEST_F(NVFuserTest, FusionResizeSliceConstantShmoo_CUDA) { + for (auto [start, stop] : slice_cases) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, {{IrBuilder::create(start), IrBuilder::create(stop)}}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); + } +} + +// Test slice with a variety of non-constant input ranges +TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + auto s0 = IrBuilder::create(DataType::Index); + auto s1 = IrBuilder::create(DataType::Index); + fusion.addInput(tv0); + fusion.addInput(s0); + fusion.addInput(s1); + + auto tv1 = slice(tv0, {{s0, s1}}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + { + // Concretize so that we set output IterType as Iteration. We should now + // have expressions that work with any input range. + ExpressionEvaluator expr_eval; + + expr_eval.bind(tv0->axis(0)->extent(), 9); + expr_eval.bind(s0, 0); + expr_eval.bind(s1, 9); + + auto initial_info = DynamicTransform::getInitialInfo(&fusion); + auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); + + DynamicTransform::concretizeFusion(&fusion, &info); + NVF_CHECK( + !fusion.hasDynamicTransform(), "Expected to have no dynamic transform"); + } + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto t0 = at::randn(shape, options); + for (auto [start, stop] : slice_cases) { + std::vector aten_inputs({t0, start, stop}); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); + } +} + +// Same as FusionResizeSliceInputShmoo_CUDA but use FusionExecutorCache, which +// might re-concretize when output sizes change +TEST_F(NVFuserTest, FusionResizeSliceInputShmooFusionExecutorCache_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + auto s0 = IrBuilder::create(DataType::Index); + auto s1 = IrBuilder::create(DataType::Index); + fusion->addInput(tv0); + fusion->addInput(s0); + fusion->addInput(s1); + + auto tv1 = slice(tv0, {{s0, s1}}); + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto t0 = at::randn(shape, options); + for (auto [start, stop] : slice_cases) { + std::vector aten_inputs({t0, start, stop}); + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); + + testValidate(fec.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); + } +} + // Auto scheduled version of Slice1 TEST_F(ResizeTest, FusionResizeSliceScheduler1) { auto fusion_ptr = std::make_unique(); @@ -2319,7 +2445,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual1) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); // Using a concrete tensor to avoid dynamic reshape auto tv0 = makeContigConcreteTensor(shape); @@ -2358,7 +2484,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual2) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2414,7 +2540,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual3) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2463,7 +2589,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual4) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor({shape[0] - 4}); fusion.addInput(tv0); @@ -2505,7 +2631,7 @@ TEST_F(ResizeTest, Slice2DVectorizeManual1) { // The extent of the innermost domain is just 2, and the outer // domain is sliced. This slicing should be vectorizable by a // factor of 4 as the two domains can be merged and vectorized. - const std::vector shape({1024 * 1024, 2}); + const std::vector shape({1024L * 1024L, 2}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); diff --git a/test/test_tutorial.cpp b/test/test_tutorial.cpp index d591062e046..757ecadc2af 100644 --- a/test/test_tutorial.cpp +++ b/test/test_tutorial.cpp @@ -446,4 +446,188 @@ TEST_F(Tutorial, ReductionRFactor) { } } +TEST_F(Tutorial, Reshape) { + { + // Simple reshape example + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // Shape of tv0 is assumed to be [4, 8], which is then reshaped to [32] + auto tv1 = reshape(tv0, {4, 8}, {32}); + fusion.addOutput(tv1); + + if (verbose_) { + // Notice that tv1 has root and rfactor domains. The root domain + // should consist of two IterDomains, whreas the rfactor domain + // consists of a single IterDomain that is an output of a merge + // operation of the two root IterDomains + fusion.print(); + } + + // Check if the tv1 domains are generated as expected + ASSERT_TRUE(tv1->hasRFactor()); + ASSERT_EQ(tv1->getRFactorDomain().size(), 1); + ASSERT_TRUE(tv1->getRFactorDomain().at(0)->definition()->isA()); + Merge* tv1_merge = tv1->getRFactorDomain().at(0)->definition()->as(); + ASSERT_EQ(tv1_merge->inner(), tv1->getRootDomain().at(1)); + ASSERT_EQ(tv1_merge->outer(), tv1->getRootDomain().at(0)); + } + + { + // Reshape example with broadcast domains + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D tensor with a broadcast domain + auto tv0 = makeConcreteTensor({1, -1, -1}); + fusion.addInput(tv0); + + // tv0 is first squeezed and then reshaped and unsqueezed + auto tv1 = reshape(tv0, {1, 2, 3}, {3, 2, 1}); + fusion.addOutput(tv1); + + if (verbose_) { + fusion.print(); + } + + // The fusion should look like: + // + // tv1 = unsqueeze(reshape(squeeze(tv0))); + ASSERT_TRUE(tv1->definition()->isA()); + auto reshape_output = tv1->definition()->input(0)->as(); + ASSERT_TRUE(reshape_output->definition()->isA()); + auto squeeze_output = + reshape_output->definition()->input(0)->as(); + ASSERT_TRUE(squeeze_output->definition()->isA()); + + ASSERT_TRUE(reshape_output->hasRFactor()); + ASSERT_EQ(reshape_output->getRFactorDomain().size(), 2); + ASSERT_TRUE( + reshape_output->getRFactorDomain().at(0)->definition()->isA()); + auto reshape_output_split = + reshape_output->getRFactorDomain().at(0)->definition()->as(); + ASSERT_EQ( + reshape_output_split->outer(), + reshape_output->getRFactorDomain().at(0)); + ASSERT_EQ( + reshape_output_split->inner(), + reshape_output->getRFactorDomain().at(1)); + ASSERT_TRUE(reshape_output_split->in()->definition()->isA()); + auto reshape_output_merge = + reshape_output_split->in()->definition()->as(); + ASSERT_EQ( + reshape_output_merge->outer(), reshape_output->getRootDomain().at(0)); + ASSERT_EQ( + reshape_output_merge->inner(), reshape_output->getRootDomain().at(1)); + + // So far, the fusion has transformations as part of its + // definition. It can be further extended with scheduling transformations. + reshape_output->merge(0, 1); + reshape_output->split(0, 128); + + ASSERT_TRUE( + reshape_output->getLeafDomain().at(0)->definition()->isA()); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->inner(), + reshape_output->getLeafDomain().at(1)); + ASSERT_TRUE(reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->isA()); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->outer(), + reshape_output->getRFactorDomain().at(0)); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->inner(), + reshape_output->getRFactorDomain().at(1)); + + // Here's how we propagate the transformations of reshape_output + // to all other tensors in the fusion + TransformPropagatorWithCheck propagator(reshape_output); + MaxRootDomainInfoSpanningTree(reshape_output).traverse(&propagator); + + // Now, all tensors, including those before the reshape op, should + // be transformed to 2D tensors with an inner domain of extent + // 128. + if (verbose_) { + fusion.print(); + } + + // Notice that all transformations of the reshape tensor, + // including both the reshape and scheduling transformations, are + // propagated. For example, squeeze_output should have the merge and split + // for the reshape, followed by another merge and split for + // scheduling. Specifically: + // + // Root domain: [b0, i1, i2] + // merge(1, 2) -> [b0, i1*i2] + // outer split(1, 3) -> [b0, 3, i1*i2/3] + // merge(1, 2) -> [b0, 3*i1*i2/3] + // split(1, 128) -> [b0, 3*i1*i2/3/128, 128] + ASSERT_TRUE( + squeeze_output->getLeafDomain().at(0)->definition()->isA()); + auto squeeze_output_second_split = + squeeze_output->getLeafDomain().at(0)->definition()->as(); + ASSERT_EQ( + squeeze_output_second_split->outer(), + squeeze_output->getLeafDomain().at(0)); + ASSERT_EQ( + squeeze_output_second_split->inner(), + squeeze_output->getLeafDomain().at(1)); + + ASSERT_TRUE(squeeze_output_second_split->in()->definition()->isA()); + auto squeeze_output_second_merge = + squeeze_output_second_split->in()->definition()->as(); + + ASSERT_TRUE( + squeeze_output_second_merge->outer()->definition()->isA()); + auto squeeze_output_first_split = + squeeze_output_second_merge->outer()->definition()->as(); + ASSERT_EQ( + squeeze_output_first_split->outer(), + squeeze_output_second_merge->outer()); + ASSERT_EQ( + squeeze_output_first_split->inner(), + squeeze_output_second_merge->inner()); + + ASSERT_TRUE(squeeze_output_first_split->in()->definition()->isA()); + auto squeeze_output_first_merge = + squeeze_output_first_split->in()->definition()->as(); + ASSERT_EQ( + squeeze_output_first_merge->outer(), + squeeze_output->getRootDomain().at(0)); + ASSERT_EQ( + squeeze_output_first_merge->inner(), + squeeze_output->getRootDomain().at(1)); + + // Note that all the transformations of squeeze_output are scheduling + // transformations, thus it should not have a rfactor domain + ASSERT_FALSE(squeeze_output->hasRFactor()); + } +} + } // namespace nvfuser