From f9def574ad119065244756079c558ed001d2cc6e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 6 Jun 2023 15:34:52 -0400 Subject: [PATCH 01/24] Add clipping to slice output extent expressions --- csrc/ops/alias.cpp | 10 ++++++-- test/test_resize.cpp | 57 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 49ed3b9912d..93679d7df22 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -691,12 +691,18 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { out_root_id = inp_root_id->cloneWithoutRFactor(); out_rf_id = out_root_id; } else { + // Clip the start and stop values to the extent of the input + auto clipped_start = + SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.start); + auto clipped_stop = + SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.stop); + out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); out_rf_id = IterDomain::resize( out_root_id, - SimplifyingIrBuilder::negExpr(range.start), - sub(range.stop, inp_root_id->extent()), + SimplifyingIrBuilder::negExpr(clipped_start), + sub(clipped_stop, inp_root_id->extent()), true); needs_real_slicing = true; } diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 29b16d6ed74..a014c0e7416 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1122,6 +1122,63 @@ TEST_F(NVFuserTest, FusionResizeSlice5_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); } +// Slice with end beyond size of input. This should clip to input, not pad. 
+TEST_F(NVFuserTest, FusionResizeSlice6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = slice(tv0, {{fusion.zeroVal(), IrBuilder::create(11)}}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = t0.index({at::indexing::Slice(0, 11)}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// Slice with start beyond size of input. This should produce zero-size tensor. +TEST_F(NVFuserTest, FusionResizeSlice7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = + slice(tv0, {{IrBuilder::create(11), IrBuilder::create(13)}}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = t0.index({at::indexing::Slice(11, 13)}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + // Auto scheduled version of Slice1 TEST_F(NVFuserTest, FusionResizeSliceScheduler1_CUDA) { auto fusion_ptr = std::make_unique(); From 2c4af30628b74f9329c32f2f108cb02598d65e0c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 6 Jun 2023 21:45:36 -0400 Subject: [PATCH 02/24] Add where to ExpressionEvaluator, handle negative in slice This currently fails at lowering due to infinite recursion in nvfuser::prove::lessEqual when trying to 
simplify index expressions for index hoisting. --- csrc/evaluator_common.cpp | 69 +++++++++++++++++++++++++++++++++++++++ csrc/evaluator_common.h | 23 ++++++++++--- csrc/ops/alias.cpp | 12 +++++++ test/test_resize.cpp | 49 +++++++++++++++++---------- 4 files changed, 130 insertions(+), 23 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5a3933f4363..02653fffc1c 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -367,6 +367,8 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) makeUnaryOp(uop); } else if (auto bop = dynamic_cast(def)) { makeBinaryOp(bop); + } else if (auto top = dynamic_cast(def)) { + makeTernaryOp(top); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr"); } @@ -393,12 +395,19 @@ void NaiveValueMachine::copyFrom(const NaiveValueMachine& other) { bop_type_.insert( bop_type_.end(), other.bop_type_.begin(), other.bop_type_.end()); + top_type_.clear(); + top_type_.insert( + top_type_.end(), other.top_type_.begin(), other.top_type_.end()); + src0_.clear(); src0_.insert(src0_.end(), other.src0_.begin(), other.src0_.end()); src1_.clear(); src1_.insert(src1_.end(), other.src1_.begin(), other.src1_.end()); + src2_.clear(); + src2_.insert(src2_.end(), other.src2_.begin(), other.src2_.end()); + dest_.clear(); dest_.insert(dest_.end(), other.dest_.begin(), other.dest_.end()); } @@ -448,14 +457,36 @@ void NaiveValueMachine::makeBinaryOp(BinaryOp* bop) { dest_[index] = out; } +void NaiveValueMachine::makeTernaryOp(TernaryOp* top) { + int in0 = top->inputs()[0]->evaluatorIndex(); + int in1 = top->inputs()[1]->evaluatorIndex(); + int in2 = top->inputs()[2]->evaluatorIndex(); + int out = top->outputs()[0]->evaluatorIndex(); + + TORCH_INTERNAL_ASSERT(in0 >= 0, "Integer Machine: unknown input 0: ", top); + TORCH_INTERNAL_ASSERT(in1 >= 0, "Integer Machine: unknown input 1: ", top); + TORCH_INTERNAL_ASSERT(in2 >= 0, "Integer Machine: unknown input 2: ", top); + 
TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", top); + + int index = makeInstructionEntry(); + inst_type_[index] = InstructionType::TERNARY_OP; + top_type_[index] = top->getTernaryOpType(); + src0_[index] = in0; + src1_[index] = in1; + src2_[index] = in2; + dest_[index] = out; +} + int NaiveValueMachine::makeInstructionEntry() { int index = num_of_instructions_++; inst_type_.emplace_back(InstructionType::UNARY_OP); uop_type_.emplace_back(UnaryOpType::Abs); bop_type_.emplace_back(BinaryOpType::Add); + top_type_.emplace_back(TernaryOpType::Where); data_type_.emplace_back(DataType::Null); src0_.emplace_back(-1); src1_.emplace_back(-1); + src2_.emplace_back(-1); dest_.emplace_back(-1); return index; } @@ -472,6 +503,9 @@ void NaiveValueMachine::runInstruction(int index) { case InstructionType::BINARY_OP: runBinaryOp(index); break; + case InstructionType::TERNARY_OP: + runTernaryOp(index); + break; } } @@ -575,4 +609,39 @@ void NaiveValueMachine::runBinaryOp(int index) { precomputed_values_.defined_[dest_index] = true; } +void NaiveValueMachine::runTernaryOp(int index) { + using namespace EvaluatorValue_functions; + int src0_index = src0_[index]; + int src1_index = src1_[index]; + int src2_index = src2_[index]; + bool src0_is_const = precomputed_values_.is_constant_[src0_index]; + bool src1_is_const = precomputed_values_.is_constant_[src1_index]; + bool src2_is_const = precomputed_values_.is_constant_[src2_index]; + + bool src_defined = + (precomputed_values_.defined_[src0_index] || src0_is_const) && + (precomputed_values_.defined_[src1_index] || src1_is_const) && + (precomputed_values_.defined_[src2_index] || src2_is_const); + + if (!src_defined) { + return; + } + int dest_index = dest_[index]; + + auto& in1 = precomputed_values_.values_[src0_index]; + auto& in2 = precomputed_values_.values_[src1_index]; + auto& in3 = precomputed_values_.values_[src1_index]; + auto& dest = precomputed_values_.values_[dest_index]; + + switch (top_type_[index]) { + 
case TernaryOpType::Where: + dest = in1 ? in2 : in3; + break; + default: + TORCH_CHECK(!"Unexpected operator type"); + } + + precomputed_values_.defined_[dest_index] = true; +} + } // namespace nvfuser diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index 112d770a27e..7e872ce2fc8 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -29,9 +29,8 @@ struct TensorArgAbstract; //! PrecomputedValues that will provide the workspace //! containing the concrete values for the values. class NaiveValueMachine { - //! The generic types of instructions supported for this - //! machine, currently only binary and unary. - enum class InstructionType { UNARY_OP, BINARY_OP, SET_OP }; + //! The generic types of instructions supported for this machine. + enum class InstructionType { UNARY_OP, BINARY_OP, TERNARY_OP, SET_OP }; public: //! Constructor lowers all the expr IR nodes stored in precomputed_values @@ -55,6 +54,9 @@ class NaiveValueMachine { //! Convert an binary IR expr to an instruction void makeBinaryOp(BinaryOp* bop); + //! Convert an ternary IR expr to an instruction + void makeTernaryOp(TernaryOp* bop); + //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. int makeInstructionEntry(); @@ -70,6 +72,9 @@ class NaiveValueMachine { //! Runs a binary operation at given index of instruction buffer void runBinaryOp(int index); + //! Runs a ternary operation at given index of instruction buffer + void runTernaryOp(int index); + private: friend PrecomputedValues; @@ -97,10 +102,14 @@ class NaiveValueMachine { //! value at each index corresponding other ops. std::vector data_type_; - //! Unary operator type if applicable, contains a default - //! value at each index corresponding to a unary op. + //! Binary operator type if applicable, contains a default + //! value at each index corresponding to a binary op. std::vector bop_type_; + //! 
Ternary operator type if applicable, contains a default + //! value at each index corresponding to a ternary op. + std::vector top_type_; + //! Indexes of operands and destination of each instruction. //! The indexes corresponds to positions in the workspace //! where concrete values are hosted. @@ -112,6 +121,10 @@ class NaiveValueMachine { //! each index corresponding to a unary op. std::vector src1_; + //! Operand 2 of each instruction, a default value at + //! each index corresponding to a unary or binary op. + std::vector src2_; + //! Destination of each instruction. std::vector dest_; }; diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 93679d7df22..4b945f8be99 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -656,9 +656,21 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { if (range.start == nullptr) { range.start = FusionGuard::getCurFusion()->zeroVal(); + } else { + // Negative start and stop values are relative to end of axis + range.start = where( + lt(range.start, FusionGuard::getCurFusion()->zeroVal()), + add(range.start, extent), + range.start); } if (range.stop == nullptr) { range.stop = extent; + } else { + // Negative start and stop values are relative to end of axis + range.stop = where( + lt(range.stop, FusionGuard::getCurFusion()->zeroVal()), + add(range.stop, extent), + range.stop); } if (range.step == nullptr) { range.step = FusionGuard::getCurFusion()->oneVal(); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index a014c0e7416..7b8958ca233 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1122,32 +1122,45 @@ TEST_F(NVFuserTest, FusionResizeSlice5_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); } -// Slice with end beyond size of input. This should clip to input, not pad. 
-TEST_F(NVFuserTest, FusionResizeSlice6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); +// Test slice with a variety of (constant) inputs +TEST_F(NVFuserTest, FusionResizeSliceShmoo_CUDA) { + for (auto [start, stop] : std::vector>( + {// Slice with end beyond size of input. This should clip to input, + // not pad. + {0, 11}, + {11, 13}, + {-3, 8}, + {-3, -1}, + {13, -1}})) { + std::cout << "start=" << start << " stop=" << stop << std::endl; + Fusion fusion; + FusionGuard fg(&fusion); - std::vector shape({9}); + std::vector shape({9}); - // concrete shapes to avoid dynamic Fusion - auto tv0 = makeConcreteTensor(shape); - fusion.addInput(tv0); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); - auto tv1 = slice(tv0, {{fusion.zeroVal(), IrBuilder::create(11)}}); - fusion.addOutput(tv1); + auto tv1 = slice( + tv0, {{IrBuilder::create(start), IrBuilder::create(stop)}}); + fusion.addOutput(tv1); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + fusion.printMath(); - auto t0 = at::randn(shape, options); - std::vector aten_inputs({t0}); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); - auto ref = t0.index({at::indexing::Slice(0, 11)}); + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + auto ref = t0.index({at::indexing::Slice(start, stop)}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + } } // Slice with start beyond size of input. This should produce zero-size tensor. 
From 5dcf2c88d41af5f51b41961ff58c1c65054533ab Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 11 Sep 2023 12:17:12 -0400 Subject: [PATCH 03/24] Support Set, Where, bool ops in NaiveValueMachine --- csrc/evaluator_common.cpp | 63 ++++++++++++++++++++++++++++----------- csrc/evaluator_common.h | 4 +++ csrc/ir/nodes.cpp | 6 ++++ csrc/ops/alias.cpp | 50 ++++++++++++++----------------- test/test_resize.cpp | 7 ++--- 5 files changed, 79 insertions(+), 51 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index dbde96ca54e..155bbe01462 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -355,6 +355,12 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) makeBinaryOp(bop); } else if (auto top = dynamic_cast(def)) { makeTernaryOp(top); + } else if (auto lsop = dynamic_cast(def)) { + NVF_ERROR( + lsop->opType() == LoadStoreOpType::Set, + "NaiveValueMachine: unsupported LoadStoreOpType: ", + lsop->opType()); + makeSetOp(lsop); } else { // There could be some ops not supported yet. For these ops, we will // bind their outputs. So ignoring them here. 
@@ -382,19 +388,12 @@ void NaiveValueMachine::copyFrom(const NaiveValueMachine& other) { bop_type_.insert( bop_type_.end(), other.bop_type_.begin(), other.bop_type_.end()); - top_type_.clear(); - top_type_.insert( - top_type_.end(), other.top_type_.begin(), other.top_type_.end()); - src0_.clear(); src0_.insert(src0_.end(), other.src0_.begin(), other.src0_.end()); src1_.clear(); src1_.insert(src1_.end(), other.src1_.begin(), other.src1_.end()); - src2_.clear(); - src2_.insert(src2_.end(), other.src2_.begin(), other.src2_.end()); - dest_.clear(); dest_.insert(dest_.end(), other.dest_.begin(), other.dest_.end()); } @@ -450,10 +449,10 @@ void NaiveValueMachine::makeTernaryOp(TernaryOp* top) { int in2 = top->inputs()[2]->evaluatorIndex(); int out = top->outputs()[0]->evaluatorIndex(); - TORCH_INTERNAL_ASSERT(in0 >= 0, "Integer Machine: unknown input 0: ", top); - TORCH_INTERNAL_ASSERT(in1 >= 0, "Integer Machine: unknown input 1: ", top); - TORCH_INTERNAL_ASSERT(in2 >= 0, "Integer Machine: unknown input 2: ", top); - TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", top); + NVF_ERROR(in0 >= 0, "Integer Machine: unknown first input: ", top); + NVF_ERROR(in1 >= 0, "Integer Machine: unknown second input: ", top); + NVF_ERROR(in2 >= 0, "Integer Machine: unknown third input: ", top); + NVF_ERROR(out >= 0, "Integer Machine: unknown out: ", top); int index = makeInstructionEntry(); inst_type_[index] = InstructionType::TERNARY_OP; @@ -464,6 +463,19 @@ void NaiveValueMachine::makeTernaryOp(TernaryOp* top) { dest_[index] = out; } +void NaiveValueMachine::makeSetOp(LoadStoreOp* lsop) { + int in = lsop->in()->evaluatorIndex(); + int out = lsop->out()->evaluatorIndex(); + + NVF_ERROR(in >= 0, "Integer Machine: unknown input: ", lsop); + NVF_ERROR(out >= 0, "Integer Machine: unknown out: ", lsop); + + int index = makeInstructionEntry(); + inst_type_[index] = InstructionType::SET_OP; + src0_[index] = in; + dest_[index] = out; +} + int 
NaiveValueMachine::makeInstructionEntry() { int index = num_of_instructions_++; inst_type_.emplace_back(InstructionType::UNARY_OP); @@ -608,15 +620,30 @@ void NaiveValueMachine::runBinaryOp(int index) { case BinaryOpType::Gcd: dest = gcd(lhs, rhs); break; + case BinaryOpType::LT: + dest = lhs < rhs; + break; + case BinaryOpType::LE: + dest = lhs <= rhs; + break; + case BinaryOpType::Eq: + dest = lhs == rhs; + break; + case BinaryOpType::GE: + dest = lhs >= rhs; + break; + case BinaryOpType::GT: + dest = lhs > rhs; + break; default: - NVF_CHECK(!"Unexpected operator type"); + NVF_CHECK(false, "Unexpected operator type ", bop_type_[index]); } precomputed_values_.defined_[dest_index] = true; } void NaiveValueMachine::runTernaryOp(int index) { - using namespace EvaluatorValue_functions; + using namespace PolymorphicValue_functions; int src0_index = src0_[index]; int src1_index = src1_[index]; int src2_index = src2_[index]; @@ -634,17 +661,17 @@ void NaiveValueMachine::runTernaryOp(int index) { } int dest_index = dest_[index]; - auto& in1 = precomputed_values_.values_[src0_index]; - auto& in2 = precomputed_values_.values_[src1_index]; - auto& in3 = precomputed_values_.values_[src1_index]; + auto& a = precomputed_values_.values_[src0_index]; + auto& b = precomputed_values_.values_[src1_index]; + auto& c = precomputed_values_.values_[src2_index]; auto& dest = precomputed_values_.values_[dest_index]; switch (top_type_[index]) { case TernaryOpType::Where: - dest = in1 ? in2 : in3; + dest = a ? b : c; break; default: - TORCH_CHECK(!"Unexpected operator type"); + NVF_CHECK(!"Unexpected operator type"); } precomputed_values_.defined_[dest_index] = true; diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index 18dd88112bd..28b4a1e1d9f 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -58,6 +58,10 @@ class NaiveValueMachine { //! Convert an ternary IR expr to an instruction void makeTernaryOp(TernaryOp* bop); + //! 
Convert a LoadStoreOp expr to an instruction. This assumes lsop->opType() + //! is equal to LoadStoreOpType::Set. + void makeSetOp(LoadStoreOp* lsop); + //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. int makeInstructionEntry(); diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 98b4f94ad1b..145836e23eb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2844,6 +2844,12 @@ IterDomain* IterDomain::resize( } } + if (resized_id_size->dtype() != DataType::Index) { + std::cout << "Casting resized extent " << resized_id_size->toInlineString() + << " to Index" << std::endl; + resized_id_size = castOp(DataType::Index, resized_id_size); + } + auto resized_id = IterDomainBuilder(in->container()->zeroVal(), resized_id_size) .is_rfactor_domain(mark_as_rfactor) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 63aa735d1da..3b2ceb2d681 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -690,9 +690,6 @@ TensorView* cat( return out; } -// Currently there's no error check about the actual values of the -// Slice parameters. For example, the start parameter of a range of a -// domain is assumed to be >= 0 and < the extent of the domain. 
TensorView* slice(TensorView* inp, const std::vector& ranges) { const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); const int ndims = static_cast(inp_dom.size()); @@ -704,36 +701,30 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { ", Expected: ", ndims); - auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { - if (range.start == nullptr) { - range.start = FusionGuard::getCurFusion()->zeroVal(); + const auto normalize_arg = [](Val* a, Val* extent, Val* def) -> Val* { + if (a == nullptr) { + return def; + } else if (a->isZeroInt() || a->sameAs(extent)) { + // These do not need normalization } else { // Negative start and stop values are relative to end of axis - range.start = where( - lt(range.start, FusionGuard::getCurFusion()->zeroVal()), - add(range.start, extent), - range.start); + a = where( + lt(a, FusionGuard::getCurFusion()->zeroVal()), add(a, extent), a); } - if (range.stop == nullptr) { - range.stop = extent; - } else { - // Negative start and stop values are relative to end of axis - range.stop = where( - lt(range.stop, FusionGuard::getCurFusion()->zeroVal()), - add(range.stop, extent), - range.stop); + if (a->dtype() != DataType::Index) { + a = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, a); } + return a; + }; + + const auto normalize_slice_range = [&normalize_arg]( + Slice range, Val* extent) -> Slice { + range.start = normalize_arg( + range.start, extent, FusionGuard::getCurFusion()->zeroVal()); + range.stop = normalize_arg(range.stop, extent, extent); if (range.step == nullptr) { range.step = FusionGuard::getCurFusion()->oneVal(); } - if (range.start->dtype() != DataType::Index) { - range.start = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); - } - if (range.stop->dtype() != DataType::Index) { - range.stop = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - } if (range.step->dtype() != DataType::Index) { range.step = 
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.step); @@ -769,8 +760,11 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { // Clip the start and stop values to the extent of the input auto clipped_start = SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.start); - auto clipped_stop = - SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.stop); + // stop is clipped the same as start, then we additionally clip so that + // stop >= start + auto clipped_stop = SimplifyingIrBuilder::maxExpr( + SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.stop), + clipped_start); out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 29a3a92f00d..95b52942777 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1151,7 +1151,6 @@ TEST_F(NVFuserTest, FusionResizeSliceShmoo_CUDA) { {-3, 8}, {-3, -1}, {13, -1}})) { - std::cout << "start=" << start << " stop=" << stop << std::endl; Fusion fusion; FusionGuard fg(&fusion); @@ -1162,11 +1161,9 @@ TEST_F(NVFuserTest, FusionResizeSliceShmoo_CUDA) { fusion.addInput(tv0); auto tv1 = slice( - tv0, {{IrBuilder::create(start), IrBuilder::create(stop)}}); + tv0, {{IrBuilder::create(start), IrBuilder::create(stop)}}); fusion.addOutput(tv1); - fusion.printMath(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); @@ -1194,7 +1191,7 @@ TEST_F(NVFuserTest, FusionResizeSlice7_CUDA) { fusion.addInput(tv0); auto tv1 = - slice(tv0, {{IrBuilder::create(11), IrBuilder::create(13)}}); + slice(tv0, {{IrBuilder::create(11), IrBuilder::create(13)}}); fusion.addOutput(tv1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); From 7432adba430faddaf9ee0e890ae18890da1145fb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 11 Sep 2023 12:27:29 -0400 Subject: [PATCH 04/24] Silence clang-tidy in test_resize.cpp --- 
test/test_resize.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 95b52942777..b28c5c9669e 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2427,7 +2427,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual1) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); // Using a concrete tensor to avoid dynamic reshape auto tv0 = makeContigConcreteTensor(shape); @@ -2466,7 +2466,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual2) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2522,7 +2522,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual3) { FusionGuard fg(fusion_ptr.get()); const int64_t slice_offset = 4; - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); @@ -2571,7 +2571,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual4) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - const std::vector shape({1024 * 1024}); + const std::vector shape({1024L * 1024L}); auto tv0 = makeContigConcreteTensor({shape[0] - 4}); fusion.addInput(tv0); @@ -2613,7 +2613,7 @@ TEST_F(ResizeTest, Slice2DVectorizeManual1) { // The extent of the innermost domain is just 2, and the outer // domain is sliced. This slicing should be vectorizable by a // factor of 4 as the two domains can be merged and vectorized. 
- const std::vector shape({1024 * 1024, 2}); + const std::vector shape({1024L * 1024L, 2}); auto tv0 = makeContigConcreteTensor(shape); fusion.addInput(tv0); From 9856d93ce134626b414ff564fee27c02ac607971 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 08:55:39 -0400 Subject: [PATCH 05/24] Add simplifying comparison operators --- csrc/ir/builder.cpp | 90 +++++++++++++++++++++++++++++++++++++++++++++ csrc/ir/builder.h | 7 ++++ 2 files changed, 97 insertions(+) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index a7aebf763e9..5d58f7c9af2 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -330,6 +330,9 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) { + if (lhs->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } return addExpr(lhs, negExpr(rhs)); } @@ -373,6 +376,9 @@ Val* SimplifyingIrBuilder::divExpr(Val* lhs, Val* rhs) { if (rhs->isOneInt()) { return lhs; } + if (lhs->sameAs(rhs)) { + return lhs->fusion()->oneVal(lhs->dtype()); + } return IrBuilder::divExpr(lhs, rhs); } @@ -606,6 +612,90 @@ Val* SimplifyingIrBuilder::gcdExpr(Val* lhs, Val* rhs) { return IrBuilder::gcdExpr(lhs, rhs); } +Val* SimplifyingIrBuilder::ltExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() < rhs->evaluateBool()) + ? lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::ltExpr(lhs, rhs); +} + +Val* SimplifyingIrBuilder::leExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() <= rhs->evaluateBool()) + ? 
lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::leExpr(lhs, rhs); +} + +Val* SimplifyingIrBuilder::eqExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() == rhs->evaluateBool()) + ? lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::eqExpr(lhs, rhs); +} + +Val* SimplifyingIrBuilder::neExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() != rhs->evaluateBool()) + ? lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::neExpr(lhs, rhs); +} + +Val* SimplifyingIrBuilder::geExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() >= rhs->evaluateBool()) + ? lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::geExpr(lhs, rhs); +} + +Val* SimplifyingIrBuilder::gtExpr(Val* lhs, Val* rhs) { + NVF_ERROR( + lhs->dtype() == rhs->dtype(), + "Comparison expressions require same dtype for inputs"); + + if (lhs->isConstScalar() && rhs->isConstScalar()) { + return (lhs->evaluateBool() > rhs->evaluateBool()) + ? 
lhs->fusion()->trueVal() + : lhs->fusion()->falseVal(); + } + + return IrBuilder::gtExpr(lhs, rhs); +} + Val* SimplifyingIrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { NVF_ERROR( pred->dtype() == DataType::Bool, diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 9d8578a1507..faf0564a347 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -198,6 +198,13 @@ class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { static Val* minExpr(Val* lhs, Val* rhs); static Val* gcdExpr(Val* lhs, Val* rhs); + static Val* ltExpr(Val* lhs, Val* rhs); + static Val* leExpr(Val* lhs, Val* rhs); + static Val* eqExpr(Val* lhs, Val* rhs); + static Val* neExpr(Val* lhs, Val* rhs); + static Val* geExpr(Val* lhs, Val* rhs); + static Val* gtExpr(Val* lhs, Val* rhs); + static Val* whereExpr(Val* pred, Val* lhs, Val* rhs); }; From 85446508efb51a6724140550d562b7824366c63f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 08:55:53 -0400 Subject: [PATCH 06/24] Clean up normalized slice start/stop expressions --- csrc/ops/alias.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 3b2ceb2d681..7bd0d33e22e 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -758,20 +758,27 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { out_rf_id = out_root_id; } else { // Clip the start and stop values to the extent of the input - auto clipped_start = - SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.start); + auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + const auto clip_to_extent = [&zero](Val* a, Val* ext) { + return SimplifyingIrBuilder::minExpr( + ext, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(a, zero), + SimplifyingIrBuilder::maxExpr(a, zero), + a)); + }; + auto clipped_start = clip_to_extent(range.start, inp_root_id->extent()); // stop is clipped the same as start, then we additionally clip so that // 
stop >= start auto clipped_stop = SimplifyingIrBuilder::maxExpr( - SimplifyingIrBuilder::minExpr(inp_root_id->extent(), range.stop), - clipped_start); + clip_to_extent(range.stop, inp_root_id->extent()), clipped_start); out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); out_rf_id = IterDomain::resize( out_root_id, SimplifyingIrBuilder::negExpr(clipped_start), - sub(clipped_stop, inp_root_id->extent()), + SimplifyingIrBuilder::subExpr(clipped_stop, inp_root_id->extent()), true); needs_real_slicing = true; } From eee9ff20989f9031b4fd901d5bc465ecbeed1cda Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 09:02:00 -0400 Subject: [PATCH 07/24] Fix wrong ostream in preseg ir dump --- csrc/kernel_cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index f8dc8574dab..8b3322993b8 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -895,8 +895,8 @@ FusionKernelRuntime::FusionKernelRuntime( fusion.get()); if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) { - std::cout << "Fusion IR after pre-segmenter optimization passes:" - << std::endl; + debug() << "Fusion IR after pre-segmenter optimization passes:" + << std::endl; fusion->printMath(); } From efcb2033d435085b4ed944319fd71017646c7984 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 09:02:14 -0400 Subject: [PATCH 08/24] Remove debug print --- csrc/ir/nodes.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 23a9c416e0c..f858c186362 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2847,8 +2847,6 @@ IterDomain* IterDomain::resize( } if (resized_id_size->dtype() != DataType::Index) { - std::cout << "Casting resized extent " << resized_id_size->toInlineString() - << " to Index" << std::endl; resized_id_size = castOp(DataType::Index, resized_id_size); } From bf0a4b6e08bbf0d858fac02f69c7e69a2014e661 Mon Sep 17 00:00:00 
2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 09:04:47 -0400 Subject: [PATCH 09/24] Handle NE in runBinaryOp --- csrc/evaluator_common.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 155bbe01462..58e9314512c 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -629,6 +629,9 @@ void NaiveValueMachine::runBinaryOp(int index) { case BinaryOpType::Eq: dest = lhs == rhs; break; + case BinaryOpType::NE: + dest = lhs != rhs; + break; case BinaryOpType::GE: dest = lhs >= rhs; break; From 121a3b9a800dfea6d539fd9289b21c2db89c4a40 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 09:31:49 -0400 Subject: [PATCH 10/24] Fix simplified comparison ops --- csrc/ir/builder.cpp | 148 +++++++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 64 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 5d58f7c9af2..c51f9a04e7f 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -612,95 +612,116 @@ Val* SimplifyingIrBuilder::gcdExpr(Val* lhs, Val* rhs) { return IrBuilder::gcdExpr(lhs, rhs); } -Val* SimplifyingIrBuilder::ltExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); +namespace { +enum class ScalarComparisonResult { LT, EQ, GT, NONCONST }; - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() < rhs->evaluateBool()) - ? 
lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); +double evaluateAsDouble(Val* a) { + NVF_ERROR( + a->isScalar(), + "evaluateAsDouble expects scalar but found ", + a->toString()); + if (a->isIntegralScalar()) { + return static_cast(a->evaluateInt()); + } else if (a->isABool()) { + return static_cast(a->evaluateBool()); + } else if (a->isFloatingPointScalar()) { + return a->evaluateDouble(); + } else { + NVF_ERROR( + false, + "Unhandled dtype ", + a->dtype(), + " in evaluateAsDouble for input ", + a->toInlineString()); + } +} + +//! Compares a to b after evaluation and conversion to double +ScalarComparisonResult compareConstScalars(Val* a, Val* b) { + if (!a->isConstScalar() || !b->isConstScalar()) { + return ScalarComparisonResult::NONCONST; + } + auto ad = evaluateAsDouble(a); + auto bd = evaluateAsDouble(b); + if (ad < bd) { + return ScalarComparisonResult::LT; + } else if (ad == bd) { + return ScalarComparisonResult::EQ; + } else { + return ScalarComparisonResult::GT; } +} +} // namespace - return IrBuilder::ltExpr(lhs, rhs); +Val* SimplifyingIrBuilder::ltExpr(Val* lhs, Val* rhs) { + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::ltExpr(lhs, rhs); + case ScalarComparisonResult::LT: + return lhs->fusion()->trueVal(); + default: + return lhs->fusion()->falseVal(); + } } Val* SimplifyingIrBuilder::leExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); - - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() <= rhs->evaluateBool()) - ? 
lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::leExpr(lhs, rhs); + case ScalarComparisonResult::LT: + case ScalarComparisonResult::EQ: + return lhs->fusion()->trueVal(); + default: + return lhs->fusion()->falseVal(); } - - return IrBuilder::leExpr(lhs, rhs); } Val* SimplifyingIrBuilder::eqExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); - - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() == rhs->evaluateBool()) - ? lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::eqExpr(lhs, rhs); + case ScalarComparisonResult::EQ: + return lhs->fusion()->trueVal(); + default: + return lhs->fusion()->falseVal(); } - - return IrBuilder::eqExpr(lhs, rhs); } Val* SimplifyingIrBuilder::neExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); - - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() != rhs->evaluateBool()) - ? lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::neExpr(lhs, rhs); + case ScalarComparisonResult::EQ: + return lhs->fusion()->falseVal(); + default: + return lhs->fusion()->trueVal(); } - - return IrBuilder::neExpr(lhs, rhs); } Val* SimplifyingIrBuilder::geExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); - - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() >= rhs->evaluateBool()) - ? 
lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::geExpr(lhs, rhs); + case ScalarComparisonResult::GT: + case ScalarComparisonResult::EQ: + return lhs->fusion()->trueVal(); + default: + return lhs->fusion()->falseVal(); } - - return IrBuilder::geExpr(lhs, rhs); } Val* SimplifyingIrBuilder::gtExpr(Val* lhs, Val* rhs) { - NVF_ERROR( - lhs->dtype() == rhs->dtype(), - "Comparison expressions require same dtype for inputs"); - - if (lhs->isConstScalar() && rhs->isConstScalar()) { - return (lhs->evaluateBool() > rhs->evaluateBool()) - ? lhs->fusion()->trueVal() - : lhs->fusion()->falseVal(); + switch (compareConstScalars(lhs, rhs)) { + case ScalarComparisonResult::NONCONST: + return IrBuilder::gtExpr(lhs, rhs); + case ScalarComparisonResult::GT: + return lhs->fusion()->trueVal(); + default: + return lhs->fusion()->falseVal(); } - - return IrBuilder::gtExpr(lhs, rhs); } Val* SimplifyingIrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { - NVF_ERROR( - pred->dtype() == DataType::Bool, - "Where requires a predicate as an input, but received"); - if (pred->isConstScalar() && pred->isABool()) { if (pred->evaluateBool()) { return lhs; @@ -708,7 +729,6 @@ Val* SimplifyingIrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { return rhs; } } - return IrBuilder::whereExpr(pred, lhs, rhs); } From 93f3737eefd938d8496e08e5dcce963d73ebccd2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 09:50:45 -0400 Subject: [PATCH 11/24] Use maybeCastExpr in resize --- csrc/ir/nodes.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index f858c186362..6469301d3c2 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2846,9 +2846,8 @@ IterDomain* IterDomain::resize( } } - if (resized_id_size->dtype() != DataType::Index) { - resized_id_size = castOp(DataType::Index, resized_id_size); - } + resized_id_size 
= + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, resized_id_size); auto resized_id = IterDomainBuilder(in->container()->zeroVal(), resized_id_size) From ec315a1f0f0b242d5b1cc47022516000f7c8ab23 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 10:00:01 -0400 Subject: [PATCH 12/24] Update doc comment on slice op --- csrc/ops/alias.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 374d2d04329..8cfb3e75cb1 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -108,7 +108,9 @@ TORCH_CUDA_CU_API TensorView* cat( std::optional iter_type_opt = std::nullopt); //! Return a tensor where each dimension is sliced as specified by the -//! ranges parameter. Stepping must be one at this moment. +//! ranges parameter. Stepping must be one at this moment. The semantics of +//! slicing with negative values and values >= extent follow those of numpy and +//! PyTorch. TORCH_CUDA_CU_API TensorView* slice( TensorView* inp, const std::vector& ranges); From 98f3519c14902742a5fe837d9da4e3fc2647f4ef Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 10:18:42 -0400 Subject: [PATCH 13/24] Add input range test --- test/test_resize.cpp | 73 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index b28c5c9669e..5fd27989e6a 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1141,8 +1141,8 @@ TEST_F(ResizeTest, FusionResizeSlice5) { testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); } -// Test slice with a variety of (constant) inputs -TEST_F(NVFuserTest, FusionResizeSliceShmoo_CUDA) { +// Test slice with a variety of constant ranges +TEST_F(NVFuserTest, FusionResizeSliceConstantShmoo_CUDA) { for (auto [start, stop] : std::vector>( {// Slice with end beyond size of input. This should clip to input, // not pad. 
@@ -1179,8 +1179,8 @@ TEST_F(NVFuserTest, FusionResizeSliceShmoo_CUDA) { } } -// Slice with start beyond size of input. This should produce zero-size tensor. -TEST_F(NVFuserTest, FusionResizeSlice7_CUDA) { +// Test slice with a variety of non-constant input ranges +TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1188,24 +1188,72 @@ TEST_F(NVFuserTest, FusionResizeSlice7_CUDA) { // concrete shapes to avoid dynamic Fusion auto tv0 = makeConcreteTensor(shape); + auto s0 = IrBuilder::create(DataType::Index); + auto s1 = IrBuilder::create(DataType::Index); fusion.addInput(tv0); + fusion.addInput(s0); + fusion.addInput(s1); - auto tv1 = - slice(tv0, {{IrBuilder::create(11), IrBuilder::create(13)}}); + auto tv1 = slice(tv0, {{s0, s1}}); fusion.addOutput(tv1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - std::vector aten_inputs({t0}); + { + // Concretize so that we set output IterType as Iteration. We should now + // have expressions that work with any input range. + ExpressionEvaluator expr_eval; + + expr_eval.bind(tv0->axis(0)->extent(), 9); + expr_eval.bind(s0, 0); + expr_eval.bind(s1, 9); + + auto initial_info = DynamicTransform::getInitialInfo(&fusion); + auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); + + DynamicTransform::concretizeFusion(&fusion, &info); + NVF_CHECK( + !fusion.hasDynamicTransform(), "Expected to have no dynamic transform"); + } FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); + fe.compileFusion(&fusion); - auto ref = t0.index({at::indexing::Slice(11, 13)}); + auto t0 = at::randn(shape, options); + for (auto [start, stop] : std::vector>({ + // Slice with end beyond size of input. This should clip to input, + // not pad. 
+ {0, 5}, + 3, 9}, + + , 4}, + + 5}, + + 11}, + + 13}, + + 8}, + + -1}, + + -5}, + + -1}, + + , 9}, + + , 0}, + + { + std::vector aten_inputs({t0, start, stop}); + auto cg_outputs = fe.runFusion(aten_inputs); - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + auto ref = t0.index({at::indexing::Slice(start, stop)}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + } } // Auto scheduled version of Slice1 @@ -2896,3 +2944,4 @@ TEST_F(ResizeTest, CatOfExpandedBroadcast) { } } // namespace nvfuser + \ No newline at end of file From b557e380d381e857fbf48db383dae795c6ce98b8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Sep 2023 10:21:05 -0400 Subject: [PATCH 14/24] Reformat test --- test/test_resize.cpp | 43 +++++++++++++++---------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 5fd27989e6a..1f93643f3ea 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1220,33 +1220,21 @@ TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) { fe.compileFusion(&fusion); auto t0 = at::randn(shape, options); - for (auto [start, stop] : std::vector>({ - // Slice with end beyond size of input. This should clip to input, - // not pad. - {0, 5}, - 3, 9}, - - , 4}, - - 5}, - - 11}, - - 13}, - - 8}, - - -1}, - - -5}, - - -1}, - - , 9}, - - , 0}, - - { + for (auto [start, stop] : std::vector>( + {// Slice with end beyond size of input. This should clip to input, + // not pad. 
+ {0, 5}, + {3, 9}, + {3, 4}, + {7, 5}, + {0, 11}, + {11, 13}, + {-3, 8}, + {-3, -1}, + {-3, -5}, + {13, -1}, + {-11, 9}, + {-11, 0}})) { std::vector aten_inputs({t0, start, stop}); auto cg_outputs = fe.runFusion(aten_inputs); @@ -2944,4 +2932,3 @@ TEST_F(ResizeTest, CatOfExpandedBroadcast) { } } // namespace nvfuser - \ No newline at end of file From 72493d20e19b9df2f99aabcf4500ecd1d87a212c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 08:42:48 -0400 Subject: [PATCH 15/24] Simplify clipping exprs, clean up op --- csrc/ops/alias.cpp | 78 ++++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 7bd0d33e22e..909a96234ef 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -701,34 +701,51 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { ", Expected: ", ndims); - const auto normalize_arg = [](Val* a, Val* extent, Val* def) -> Val* { - if (a == nullptr) { - return def; - } else if (a->isZeroInt() || a->sameAs(extent)) { - // These do not need normalization + const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { + // Cast inputs to Index first + if (extent->dtype() != DataType::Index) { + extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); + } + + auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + + // norm_start = max(0, start < 0 ? 
start + extent : start) + if (range.start == nullptr) { + range.start = zero; } else { - // Negative start and stop values are relative to end of axis - a = where( - lt(a, FusionGuard::getCurFusion()->zeroVal()), add(a, extent), a); + if (range.start->dtype() != DataType::Index) { + range.start = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); + } + range.start = SimplifyingIrBuilder::maxExpr( + zero, + where(lt(range.start, zero), add(range.start, extent), range.start)); } - if (a->dtype() != DataType::Index) { - a = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, a); + + // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) + if (range.stop == nullptr) { + range.stop = extent; + } else { + if (range.stop->dtype() != DataType::Index) { + range.stop = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); + } + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + SimplifyingIrBuilder::minExpr( + extent, + where( + lt(range.stop, zero), add(range.stop, extent), range.stop))); } - return a; - }; - const auto normalize_slice_range = [&normalize_arg]( - Slice range, Val* extent) -> Slice { - range.start = normalize_arg( - range.start, extent, FusionGuard::getCurFusion()->zeroVal()); - range.stop = normalize_arg(range.stop, extent, extent); + // Ensure step is of type Index if (range.step == nullptr) { - range.step = FusionGuard::getCurFusion()->oneVal(); - } - if (range.step->dtype() != DataType::Index) { + range.step = FusionGuard::getCurFusion()->oneVal(DataType::Index); + } else if (range.step->dtype() != DataType::Index) { range.step = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.step); } + return range; }; @@ -736,7 +753,7 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { // Step not supported yet NVF_CHECK( range.step == nullptr || range.step->isOneInt(), - "Unsupported step: ", + "Unsupported step (must be 1 or null): ", range.step->toString()); } @@ -758,27 
+775,12 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { out_rf_id = out_root_id; } else { // Clip the start and stop values to the extent of the input - auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); - const auto clip_to_extent = [&zero](Val* a, Val* ext) { - return SimplifyingIrBuilder::minExpr( - ext, - SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::ltExpr(a, zero), - SimplifyingIrBuilder::maxExpr(a, zero), - a)); - }; - auto clipped_start = clip_to_extent(range.start, inp_root_id->extent()); - // stop is clipped the same as start, then we additionally clip so that - // stop >= start - auto clipped_stop = SimplifyingIrBuilder::maxExpr( - clip_to_extent(range.stop, inp_root_id->extent()), clipped_start); - out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); out_rf_id = IterDomain::resize( out_root_id, - SimplifyingIrBuilder::negExpr(clipped_start), - SimplifyingIrBuilder::subExpr(clipped_stop, inp_root_id->extent()), + SimplifyingIrBuilder::negExpr(range.start), + SimplifyingIrBuilder::subExpr(range.stop, inp_root_id->extent()), true); needs_real_slicing = true; } From a9d9e5878a759113e9c42e672d22bd6c3bc030fc Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 09:02:21 -0400 Subject: [PATCH 16/24] Simplify maybe cast exprs --- csrc/ops/alias.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 909a96234ef..208f34e0e81 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -703,9 +703,7 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { // Cast inputs to Index first - if (extent->dtype() != DataType::Index) { - extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); - } + extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = 
FusionGuard::getCurFusion()->zeroVal(DataType::Index); @@ -713,10 +711,8 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { if (range.start == nullptr) { range.start = zero; } else { - if (range.start->dtype() != DataType::Index) { - range.start = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); - } + range.start = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); range.start = SimplifyingIrBuilder::maxExpr( zero, where(lt(range.start, zero), add(range.start, extent), range.start)); @@ -726,10 +722,8 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { if (range.stop == nullptr) { range.stop = extent; } else { - if (range.stop->dtype() != DataType::Index) { - range.stop = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - } + range.stop = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); range.stop = SimplifyingIrBuilder::maxExpr( range.start, SimplifyingIrBuilder::minExpr( @@ -741,7 +735,7 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { // Ensure step is of type Index if (range.step == nullptr) { range.step = FusionGuard::getCurFusion()->oneVal(DataType::Index); - } else if (range.step->dtype() != DataType::Index) { + } else { range.step = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.step); } From 90f96165644364a0c273f2e07298a032429c7188 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 09:02:30 -0400 Subject: [PATCH 17/24] Remove unneeded change to nodes.cpp --- csrc/ir/nodes.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 69d1e53c03e..02c9eab0d67 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2912,9 +2912,6 @@ IterDomain* IterDomain::resize( } } - resized_id_size = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, resized_id_size); - auto resized_id = IterDomainBuilder(in->container()->zeroVal(), resized_id_size) 
.is_rfactor_domain(mark_as_rfactor) From ca3a674bca33396ce5a35d43a204467942e5d8c1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 10:47:03 -0400 Subject: [PATCH 18/24] Restore check for trivial slice --- csrc/ops/alias.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 208f34e0e81..f27ac49dba8 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -703,33 +703,39 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { // Cast inputs to Index first - extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); + auto cast_extent = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; - } else { + } else if (!range.start->isZeroInt()) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); range.start = SimplifyingIrBuilder::maxExpr( zero, - where(lt(range.start, zero), add(range.start, extent), range.start)); + where( + lt(range.start, zero), + add(range.start, cast_extent), + range.start)); } // norm_stop = max(norm_start, min(extent, stop < 0 ? 
stop + extent : stop) if (range.stop == nullptr) { - range.stop = extent; - } else { + range.stop = cast_extent; + } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); range.stop = SimplifyingIrBuilder::maxExpr( range.start, SimplifyingIrBuilder::minExpr( - extent, + cast_extent, where( - lt(range.stop, zero), add(range.stop, extent), range.stop))); + lt(range.stop, zero), + add(range.stop, cast_extent), + range.stop))); } // Ensure step is of type Index From 2fd5c23dcd1900ed50a4452091738c7066b96f1f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 11:12:12 -0400 Subject: [PATCH 19/24] Use SimplifyingIrBuilder fixes bcast error --- csrc/ops/alias.cpp | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index f27ac49dba8..dbd0814c828 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -703,39 +703,36 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { // Cast inputs to Index first - auto cast_extent = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); - auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; } else if (!range.start->isZeroInt()) { - range.start = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); range.start = SimplifyingIrBuilder::maxExpr( zero, - where( - lt(range.start, zero), - add(range.start, cast_extent), + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.start, zero), + SimplifyingIrBuilder::addExpr(range.start, extent), range.start)); + range.start = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); } // norm_stop = max(norm_start, min(extent, stop < 0 ? 
stop + extent : stop) if (range.stop == nullptr) { - range.stop = cast_extent; + range.stop = extent; } else if (!range.stop->sameAs(extent)) { - range.stop = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); range.stop = SimplifyingIrBuilder::maxExpr( range.start, SimplifyingIrBuilder::minExpr( - cast_extent, - where( - lt(range.stop, zero), - add(range.stop, cast_extent), + extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, extent), range.stop))); + range.stop = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); } // Ensure step is of type Index From 2bc47485739dcd6e82860f69612eff396f6352e7 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 11:15:47 -0400 Subject: [PATCH 20/24] Cast extent first --- csrc/ops/alias.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index dbd0814c828..bfd831e0d8b 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -702,37 +702,39 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { ndims); const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { - // Cast inputs to Index first + auto cast_extent = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); + auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); // norm_start = max(0, start < 0 ? 
start + extent : start) if (range.start == nullptr) { range.start = zero; } else if (!range.start->isZeroInt()) { + range.start = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); range.start = SimplifyingIrBuilder::maxExpr( zero, SimplifyingIrBuilder::whereExpr( SimplifyingIrBuilder::ltExpr(range.start, zero), - SimplifyingIrBuilder::addExpr(range.start, extent), + SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); - range.start = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); } // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) if (range.stop == nullptr) { - range.stop = extent; + range.stop = cast_extent; } else if (!range.stop->sameAs(extent)) { + range.stop = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); range.stop = SimplifyingIrBuilder::maxExpr( range.start, SimplifyingIrBuilder::minExpr( - extent, + cast_extent, SimplifyingIrBuilder::whereExpr( SimplifyingIrBuilder::ltExpr(range.stop, zero), - SimplifyingIrBuilder::addExpr(range.stop, extent), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), range.stop))); - range.stop = - SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); } // Ensure step is of type Index From 233198993bd3dab87684ee4fad39da6a6ed9b879 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 26 Sep 2023 10:00:00 -0700 Subject: [PATCH 21/24] Adding a reshape example (#944) --- test/test_tutorial.cpp | 184 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/test/test_tutorial.cpp b/test/test_tutorial.cpp index d591062e046..757ecadc2af 100644 --- a/test/test_tutorial.cpp +++ b/test/test_tutorial.cpp @@ -446,4 +446,188 @@ TEST_F(Tutorial, ReductionRFactor) { } } +TEST_F(Tutorial, Reshape) { + { + // Simple reshape example + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // Shape of tv0 is assumed to be [4, 8], which is 
then reshaped to [32] + auto tv1 = reshape(tv0, {4, 8}, {32}); + fusion.addOutput(tv1); + + if (verbose_) { + // Notice that tv1 has root and rfactor domains. The root domain + // should consist of two IterDomains, whereas the rfactor domain + // consists of a single IterDomain that is an output of a merge + // operation of the two root IterDomains + fusion.print(); + } + + // Check if the tv1 domains are generated as expected + ASSERT_TRUE(tv1->hasRFactor()); + ASSERT_EQ(tv1->getRFactorDomain().size(), 1); + ASSERT_TRUE(tv1->getRFactorDomain().at(0)->definition()->isA()); + Merge* tv1_merge = tv1->getRFactorDomain().at(0)->definition()->as(); + ASSERT_EQ(tv1_merge->inner(), tv1->getRootDomain().at(1)); + ASSERT_EQ(tv1_merge->outer(), tv1->getRootDomain().at(0)); + } + + { + // Reshape example with broadcast domains + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a 3D tensor with a broadcast domain + auto tv0 = makeConcreteTensor({1, -1, -1}); + fusion.addInput(tv0); + + // tv0 is first squeezed and then reshaped and unsqueezed + auto tv1 = reshape(tv0, {1, 2, 3}, {3, 2, 1}); + fusion.addOutput(tv1); + + if (verbose_) { + fusion.print(); + } + + // The fusion should look like: + // + // tv1 = unsqueeze(reshape(squeeze(tv0))); + ASSERT_TRUE(tv1->definition()->isA()); + auto reshape_output = tv1->definition()->input(0)->as(); + ASSERT_TRUE(reshape_output->definition()->isA()); + auto squeeze_output = + reshape_output->definition()->input(0)->as(); + ASSERT_TRUE(squeeze_output->definition()->isA()); + + ASSERT_TRUE(reshape_output->hasRFactor()); + ASSERT_EQ(reshape_output->getRFactorDomain().size(), 2); + ASSERT_TRUE( + reshape_output->getRFactorDomain().at(0)->definition()->isA()); + auto reshape_output_split = + reshape_output->getRFactorDomain().at(0)->definition()->as(); + ASSERT_EQ( + reshape_output_split->outer(), + reshape_output->getRFactorDomain().at(0)); + ASSERT_EQ( + reshape_output_split->inner(), + reshape_output->getRFactorDomain().at(1)); + 
ASSERT_TRUE(reshape_output_split->in()->definition()->isA()); + auto reshape_output_merge = + reshape_output_split->in()->definition()->as(); + ASSERT_EQ( + reshape_output_merge->outer(), reshape_output->getRootDomain().at(0)); + ASSERT_EQ( + reshape_output_merge->inner(), reshape_output->getRootDomain().at(1)); + + // So far, the fusion has transformations as part of its + // definition. It can be further extended with scheduling transformations. + reshape_output->merge(0, 1); + reshape_output->split(0, 128); + + ASSERT_TRUE( + reshape_output->getLeafDomain().at(0)->definition()->isA()); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->inner(), + reshape_output->getLeafDomain().at(1)); + ASSERT_TRUE(reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->isA()); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->outer(), + reshape_output->getRFactorDomain().at(0)); + ASSERT_EQ( + reshape_output->getLeafDomain() + .at(0) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->inner(), + reshape_output->getRFactorDomain().at(1)); + + // Here's how we propagate the transformations of reshape_output + // to all other tensors in the fusion + TransformPropagatorWithCheck propagator(reshape_output); + MaxRootDomainInfoSpanningTree(reshape_output).traverse(&propagator); + + // Now, all tensors, including those before the reshape op, should + // be transformed to 2D tensors with an inner domain of extent + // 128. + if (verbose_) { + fusion.print(); + } + + // Notice that all transformations of the reshape tensor, + // including both the reshape and scheduling transformations, are + // propagated. For example, squeeze_output should have the merge and split + // for the reshape, followed by another merge and split for + // scheduling. 
Specifically: + // + // Root domain: [b0, i1, i2] + // merge(1, 2) -> [b0, i1*i2] + // outer split(1, 3) -> [b0, 3, i1*i2/3] + // merge(1, 2) -> [b0, 3*i1*i2/3] + // split(1, 128) -> [b0, 3*i1*i2/3/128, 128] + ASSERT_TRUE( + squeeze_output->getLeafDomain().at(0)->definition()->isA()); + auto squeeze_output_second_split = + squeeze_output->getLeafDomain().at(0)->definition()->as(); + ASSERT_EQ( + squeeze_output_second_split->outer(), + squeeze_output->getLeafDomain().at(0)); + ASSERT_EQ( + squeeze_output_second_split->inner(), + squeeze_output->getLeafDomain().at(1)); + + ASSERT_TRUE(squeeze_output_second_split->in()->definition()->isA()); + auto squeeze_output_second_merge = + squeeze_output_second_split->in()->definition()->as(); + + ASSERT_TRUE( + squeeze_output_second_merge->outer()->definition()->isA()); + auto squeeze_output_first_split = + squeeze_output_second_merge->outer()->definition()->as(); + ASSERT_EQ( + squeeze_output_first_split->outer(), + squeeze_output_second_merge->outer()); + ASSERT_EQ( + squeeze_output_first_split->inner(), + squeeze_output_second_merge->inner()); + + ASSERT_TRUE(squeeze_output_first_split->in()->definition()->isA()); + auto squeeze_output_first_merge = + squeeze_output_first_split->in()->definition()->as(); + ASSERT_EQ( + squeeze_output_first_merge->outer(), + squeeze_output->getRootDomain().at(0)); + ASSERT_EQ( + squeeze_output_first_merge->inner(), + squeeze_output->getRootDomain().at(1)); + + // Note that all the transformations of squeeze_output are scheduling + // transformations, thus it should not have a rfactor domain + ASSERT_FALSE(squeeze_output->hasRFactor()); + } +} + } // namespace nvfuser From 74eb074ee7f472f21bc29dfa01092e6538100345 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 13:42:22 -0400 Subject: [PATCH 22/24] Remove manual refs and add FEC test --- test/test_resize.cpp | 57 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git 
a/test/test_resize.cpp b/test/test_resize.cpp index 0e0154c673c..f5f92f002fc 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1156,9 +1156,7 @@ TEST_F(NVFuserTest, FusionResizeSliceConstantShmoo_CUDA) { fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - auto ref = t0.index({at::indexing::Slice(start, stop)}); - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } } @@ -1217,13 +1215,60 @@ TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) { {-3, -5}, {13, -1}, {-11, 9}, - {-11, 0}})) { + {-11, 0}, + {-13, -11}})) { std::vector aten_inputs({t0, start, stop}); auto cg_outputs = fe.runFusion(aten_inputs); - auto ref = t0.index({at::indexing::Slice(start, stop)}); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); + } +} + +// Same as FusionResizeSliceInputShmoo_CUDA but use FusionExecutorCache, which +// might re-concretize when output sizes change +TEST_F(NVFuserTest, FusionResizeSliceInputShmooFusionExecutorCache_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + std::vector shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + auto s0 = IrBuilder::create(DataType::Index); + auto s1 = IrBuilder::create(DataType::Index); + fusion->addInput(tv0); + fusion->addInput(s0); + fusion->addInput(s1); + + auto tv1 = slice(tv0, {{s0, s1}}); + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto t0 = at::randn(shape, options); + for (auto [start, stop] : std::vector>( + {// Slice with end beyond size of input. This should clip to input, + // not pad. 
+ {0, 5}, + {3, 9}, + {3, 4}, + {7, 5}, + {0, 11}, + {11, 13}, + {-3, 8}, + {-3, -1}, + {-3, -5}, + {13, -1}, + {-11, 9}, + {-11, 0}, + {-13, -11}})) { + std::vector aten_inputs({t0, start, stop}); + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); + testValidate(fec.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } } From d7a4b5698ead6bc017e5116cdb05b769e3c9e6ae Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 13:42:33 -0400 Subject: [PATCH 23/24] Change check for invalid extents to be >= 0 --- csrc/dynamic_transform.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b4f205ca487..60cc4d28b7a 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -316,7 +316,7 @@ void DynamicTransformConcretizationInfo::analyzeResizes( out_id->toString()); auto extent_int = extent_val.as(); NVF_ERROR( - extent_int > 0, + extent_int >= 0, "Invalid resized domain extent ", extent_int, " for domain ", From 87dae14442d44dc302987e5ea1dae7e4a91153a8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 26 Sep 2023 14:00:49 -0400 Subject: [PATCH 24/24] Use same set of slice cases for all three tests --- test/test_resize.cpp | 58 ++++++++++++++------------------------------ 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index f5f92f002fc..7d04e555040 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1124,16 +1124,24 @@ TEST_F(ResizeTest, FusionResizeSlice5) { testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); } +std::vector> slice_cases( + {{0, 5}, + {3, 9}, + {3, 4}, + {7, 5}, + {0, 11}, + {11, 13}, + {-3, 8}, + {-3, -1}, + {-3, -5}, + {13, -1}, + {-11, 9}, + {-11, 0}, + {-13, -11}}); + // Test slice with a variety of constant ranges TEST_F(NVFuserTest, 
FusionResizeSliceConstantShmoo_CUDA) { - for (auto [start, stop] : std::vector>( - {// Slice with end beyond size of input. This should clip to input, - // not pad. - {0, 11}, - {11, 13}, - {-3, 8}, - {-3, -1}, - {13, -1}})) { + for (auto [start, stop] : slice_cases) { Fusion fusion; FusionGuard fg(&fusion); @@ -1201,22 +1209,7 @@ TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) { fe.compileFusion(&fusion); auto t0 = at::randn(shape, options); - for (auto [start, stop] : std::vector>( - {// Slice with end beyond size of input. This should clip to input, - // not pad. - {0, 5}, - {3, 9}, - {3, 4}, - {7, 5}, - {0, 11}, - {11, 13}, - {-3, 8}, - {-3, -1}, - {-3, -5}, - {13, -1}, - {-11, 9}, - {-11, 0}, - {-13, -11}})) { + for (auto [start, stop] : slice_cases) { std::vector aten_inputs({t0, start, stop}); auto cg_outputs = fe.runFusion(aten_inputs); @@ -1249,22 +1242,7 @@ TEST_F(NVFuserTest, FusionResizeSliceInputShmooFusionExecutorCache_CUDA) { FusionExecutorCache fec(std::move(fusion_ptr)); auto t0 = at::randn(shape, options); - for (auto [start, stop] : std::vector>( - {// Slice with end beyond size of input. This should clip to input, - // not pad. - {0, 5}, - {3, 9}, - {3, 4}, - {7, 5}, - {0, 11}, - {11, 13}, - {-3, 8}, - {-3, -1}, - {-3, -5}, - {13, -1}, - {-11, 9}, - {-11, 0}, - {-13, -11}})) { + for (auto [start, stop] : slice_cases) { std::vector aten_inputs({t0, start, stop}); auto cg_outputs = fec.runFusionWithInputs(aten_inputs);