From cd3448f57c4b1701358d019b7a5ec3eafd6b9d21 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 13:04:19 -0700 Subject: [PATCH 01/10] Better test --- csrc/utils.h | 13 +++- test/test_expr_simplifier.cpp | 111 ++++++++++++++++++++++++++++------ 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/csrc/utils.h b/csrc/utils.h index cc6c17b44fb..db922adc6c2 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -267,9 +268,8 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) -> decltype( - std::declval::type>().toString(), - std::true_type{}); +static auto hasToStringHelper(int) + -> decltype(std::declval::type>().toString(), std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; @@ -374,6 +374,13 @@ std::string toDelimitedString( return toDelimitedString(vec.begin(), vec.end(), delim); } +template +std::string toDelimitedString( + const std::deque& dq, + std::string delim = ", ") { + return toDelimitedString(dq.begin(), dq.end(), delim); +} + template void unrolled_for(func_t fun) { if constexpr (index < stop) { diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index b289dbbcbd9..9971b58cac3 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -80,13 +81,24 @@ namespace stupid_simple_compiler { using fun1_t = Val* (*)(Val*); using fun2_t = Val* (*)(Val*, Val*); -struct LeftParenthesis {}; +struct LeftParenthesis { + int64_t prev_lparen_pos; +}; +struct FunctionCall { + int64_t prev_lparen_pos; + std::string_view name; +}; +struct Comma {}; +struct LowestPrecedence {}; using token_t = std::variant< Val*, // variable or constant fun1_t, // unary op fun2_t, // binary op - LeftParenthesis>; + LeftParenthesis, + FunctionCall, + Comma, + LowestPrecedence>; 
Val* parseIdentifier(std::string_view token_str) { if (token_str == "true") { @@ -154,6 +166,23 @@ Val* parseNumber(std::string_view token_str) { } } +Val* functionCall(std::string_view name, std::deque args) { + if (name == "max") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::maxExpr(args.at(0), args.at(1)); + } else if (name == "min") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::minExpr(args.at(0), args.at(1)); + } else if (name == "ceilDiv") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::ceilDivExpr(args.at(0), args.at(1)); + } + TORCH_CHECK(false, "Unknown function: ", name); +} + token_t parseToken(std::string_view token_str, bool& expect_val) { if (std::isalpha(token_str.at(0))) { TORCH_CHECK( @@ -219,6 +248,12 @@ token_t parseToken(std::string_view token_str, bool& expect_val) { // https://en.cppreference.com/w/cpp/language/operator_precedence int getOpPrecedence(token_t op) { + if (std::holds_alternative(op)) { + return std::numeric_limits::max(); + } + if (std::holds_alternative(op)) { + return 17; + } if (std::holds_alternative(op)) { auto uop = std::get(op); if (uop == fun1_t(neg) || uop == fun1_t(notOp)) { @@ -279,7 +314,19 @@ Val* parse(const char* str) { op); }; + auto eval_all_top = [&](token_t token) { + TORCH_CHECK(current != nullptr, "Expect value to evaluate top"); + while (!op_stack.empty() && + (std::holds_alternative(op_stack.back()) || + std::holds_alternative(op_stack.back())) && + getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { + eval_top(); + } + }; + bool expect_val = true; + int64_t last_lparen_pos = -1; + while (!remaining.empty()) { const auto end_pos = remaining.find_first_of(' '); const auto token_str = remaining.substr(0, end_pos); @@ -287,19 +334,48 @@ Val* parse(const char* str) { if (token_str == "(") { TORCH_CHECK( expect_val, "Syntax error: 
not expecting ( but get ", token_str); - op_stack.push_back(LeftParenthesis{}); + op_stack.push_back(LeftParenthesis{last_lparen_pos}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str.back() == '(') { + TORCH_CHECK( + expect_val, + "Syntax error: not expecting function call but get ", + token_str); + op_stack.push_back(FunctionCall{ + last_lparen_pos, token_str.substr(0, token_str.size() - 1)}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str == ",") { + TORCH_CHECK(!expect_val, "Syntax error: not expecting comma"); + expect_val = true; + auto comma = Comma{}; + eval_all_top(comma); + value_stack.emplace_back(current); + op_stack.emplace_back(comma); + current = nullptr; } else if (token_str == ")") { TORCH_CHECK( !expect_val, "Syntax error: not expecting ) but get ", token_str); - // pop stack until meets matching ( - TORCH_CHECK(current != nullptr, "Expect value before )"); - while (true) { - TORCH_CHECK(!op_stack.empty(), "Unmatched )"); - if (std::holds_alternative(op_stack.back())) { + eval_all_top(LowestPrecedence{}); + auto last_lparen = op_stack.at(last_lparen_pos); + TORCH_CHECK(!op_stack.empty(), "Unmatched )"); + if (std::holds_alternative(last_lparen)) { + TORCH_INTERNAL_ASSERT(last_lparen_pos == (int64_t)op_stack.size() - 1); + auto lparen = std::get(op_stack.back()); + last_lparen_pos = lparen.prev_lparen_pos; + op_stack.pop_back(); + } else if (std::holds_alternative(last_lparen)) { + std::deque args{current}; + while (std::holds_alternative(op_stack.back())) { op_stack.pop_back(); - break; + args.push_front(value_stack.back()); + value_stack.pop_back(); } - eval_top(); + auto fc = std::get(op_stack.back()); + last_lparen_pos = fc.prev_lparen_pos; + op_stack.pop_back(); + current = functionCall(fc.name, std::move(args)); + } else { + TORCH_CHECK(false, "Unknown left parenthesis type"); } } else { token_t token = parseToken(token_str, expect_val); @@ -307,15 +383,9 @@ Val* parse(const char* str) { TORCH_CHECK(current == 
nullptr, "Don't expect value"); current = std::get(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current == nullptr, "Don't expect value"); op_stack.push_back(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current != nullptr, "Expect value before binary op"); - while (!op_stack.empty() && - !std::holds_alternative(op_stack.back()) && - getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { - eval_top(); - } + eval_all_top(token); value_stack.push_back(current); op_stack.push_back(token); current = nullptr; @@ -416,6 +486,9 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); + // constant folding + TORCH_CHECK(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); + TORCH_CHECK(simplifyExpr("1 * i"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("1.0 * d"_)->sameAs("d"_)); TORCH_CHECK(simplifyExpr("i * 1"_)->sameAs("i"_)); @@ -440,8 +513,8 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::maxExpr("i"_, "i"_))->sameAs("i"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::minExpr("i"_, "i"_))->sameAs("i"_)); + TORCH_CHECK(simplifyExpr("max( i , i )"_)->sameAs("i"_)); + TORCH_CHECK(simplifyExpr("min( i , i )"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); From 81dcd4877122d33dec21601aa36bb702a75931d6 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:00:44 -0700 Subject: [PATCH 02/10] min-max simplifier --- csrc/expr_simplifier.cpp | 27 +++++++++++++++ test/test_expr_simplifier.cpp | 62 +++++++++++++++++++++-------------- 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 576b6c8930c..465fc0f7a4f 100644 --- a/csrc/expr_simplifier.cpp +++ 
b/csrc/expr_simplifier.cpp @@ -1611,6 +1611,33 @@ Val* eliminateTrivialComputation(Val* value, const Context& context) { } } } + { // max(a, b) -> a if a >= b, min(a, b) -> b if a >= b + if (op == BinaryOpType::Max || op == BinaryOpType::Min) { + std::vector simplified_input; + for (auto v : fop->inputs()) { + bool found_redundant = false; + for (auto& v2 : simplified_input) { + if ((op == BinaryOpType::Max && prove::lessEqual(v, v2, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v2, v, context))) { + found_redundant = true; + break; + } else if ( + (op == BinaryOpType::Max && prove::lessEqual(v2, v, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v, v2, context))) { + found_redundant = true; + v2 = v; + break; + } + } + if (!found_redundant) { + simplified_input.emplace_back(v); + } + } + if (simplified_input.size() < fop->inputs().size()) { + return maybeFlattenedOpOf(op, std::move(simplified_input)); + } + } + } } else if (auto bop = dynamic_cast(value->definition())) { auto lhs = foldConstants(bop->lhs()); auto rhs = foldConstants(bop->rhs()); diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index 9971b58cac3..96f71657895 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -486,53 +486,61 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + // constant folding - TORCH_CHECK(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); + ASSERT_TRUE(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); - TORCH_CHECK(simplifyExpr("1 * i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("1.0 * d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i * 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d * 1.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("1 * i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("1.0 * 
d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i * 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d * 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 * i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i * 0"_)->getInt(), 0); - TORCH_CHECK(simplifyExpr("0 + i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("0.0 + d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i + 0"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d + 0.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("0 + i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("0.0 + d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i + 0"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d + 0.0"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("true && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && true"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("true && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b && true"_)->sameAs("b"_)); ASSERT_EQ(simplifyExpr("false && b"_)->getBool(), false); ASSERT_EQ(simplifyExpr("b && false"_)->getBool(), false); ASSERT_EQ(simplifyExpr("true || b"_)->getBool(), true); ASSERT_EQ(simplifyExpr("b || true"_)->getBool(), true); - TORCH_CHECK(simplifyExpr("false || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || false"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("false || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || false"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("max( i , i )"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("min( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("b && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("max( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("min( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplify("max( i1 , i2 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("max( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("min( i1 , i2 )"_, "i1 <= 
i2"_)->sameAs("i1"_)); + ASSERT_TRUE(simplify("min( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i1"_)); - TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i / 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d / 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 / i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i % 1"_)->getInt(), 0); // -(-a) -> a - TORCH_CHECK(simplifyExpr("- - i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("~ ~ i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("! ! b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("- - i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("~ ~ i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("! ! b"_)->sameAs("b"_)); // Test constant folding - TORCH_CHECK(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); - TORCH_CHECK(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); + ASSERT_TRUE(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); + ASSERT_TRUE(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); // Test that FlattenedAssocCommOp::sameAs ignores order - TORCH_CHECK(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); + ASSERT_TRUE(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); } TEST_F(ExprSimplifierTest, SimplifyDivisibleDivMod_CUDA) { @@ -846,6 +854,10 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { ASSERT_TRUE(*simplify("i1 >= i1 * i2"_, "i1 <= 0 && i2 > 0"_)); ASSERT_TRUE(*simplify("d1 <= d1 * d2"_, "d1 >= 0.0 && d2 >= 1.0"_)); ASSERT_TRUE(*simplify("d1 >= d1 * d2"_, "d1 <= 0.0 && d2 >= 1.0"_)); + ASSERT_TRUE( + *simplifyExpr( + "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) + ->getBool()); } TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty_CUDA) { From da109f2f2c561adaaa266837380c46e4b694790c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:34:59 -0700 Subject: [PATCH 03/10] assume --- CMakeLists.txt | 1 + csrc/assume.cpp | 35 +++++++++++++++++++++++++++++++++ csrc/assume.h | 15 
++++++++++++++ csrc/expr_simplifier.cpp | 7 +++++++ csrc/parallel_dimension_map.cpp | 3 ++- test/test_expr_simplifier.cpp | 34 ++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 csrc/assume.cpp create mode 100644 csrc/assume.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a491e33035..5044a890bcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ endif() # nvfuser codegen sources set(NVFUSER_SRCS) list(APPEND NVFUSER_SRCS + ${NVFUSER_SRCS_DIR}/assume.cpp ${NVFUSER_SRCS_DIR}/compute_at.cpp ${NVFUSER_SRCS_DIR}/inlining.cpp ${NVFUSER_SRCS_DIR}/compute_at_map.cpp diff --git a/csrc/assume.cpp b/csrc/assume.cpp new file mode 100644 index 00000000000..b70da84c742 --- /dev/null +++ b/csrc/assume.cpp @@ -0,0 +1,35 @@ +#include +#include +#include + +#include + +namespace nvfuser::assume { + +Bool* tensorsAreNotEmpty(Val* value) { + std::vector todo{value}; + std::vector tensor_sizes; + while (!todo.empty()) { + auto v = todo.back(); + todo.pop_back(); + if (auto ns = dynamic_cast(v)) { + if (ns->isTensorSize()) { + tensor_sizes.emplace_back(v); + continue; + } + } + if (auto def = v->definition()) { + for (auto inp : def->inputs()) { + todo.emplace_back(inp); + } + } + } + Bool* result = nullptr; + for (auto ts : tensor_sizes) { + result = SimplifyingIrBuilder::andExpr( + result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); + } + return result; +} + +} // namespace nvfuser::assume \ No newline at end of file diff --git a/csrc/assume.h b/csrc/assume.h new file mode 100644 index 00000000000..a0d5399614c --- /dev/null +++ b/csrc/assume.h @@ -0,0 +1,15 @@ +#include + +// Return boolean values representing the conditional you want to assume + +namespace nvfuser::assume { + +// Assume that all tensor sizes appearing in `value` are positive. Return +// nullptr if not applicable. 
For example: +// tensorsAreNotEmpty(ceilDiv(T0.size[0], 5) * T0.size[1]) +// -> T0.size[0] > 0 && T0.size[1] > 0 +// tensorsAreNotEmpty(ceilDiv(i1, 5) * i2) +// -> nullptr +Bool* tensorsAreNotEmpty(Val* value); + +} // namespace nvfuser::assume diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 465fc0f7a4f..176021fa462 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -1296,6 +1296,13 @@ bool isPositiveHelper(Val* value, const Context& context) { } return true; } + } else if (auto bop = dynamic_cast(value->definition())) { + auto op = bop->getBinaryOpType(); + if (op == BinaryOpType::CeilDiv) { + return isPositive(bop->lhs(), context) && + isValidDenominator(bop->rhs(), context) && + isNonNegative(bop->rhs(), context); + } } for (const auto& [a, b] : context.getKnownLessThan()) { if (a->isZero() && b->sameAs(value)) { diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 9888f2fc848..37ff328894b 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -69,7 +70,7 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - v = simplifyExpr(v); + v = simplifyExpr(v, {}, {assume::tensorsAreNotEmpty(v)}); } // Compute exact_types_ diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index 96f71657895..d4b80a2689c 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include #include @@ -858,6 +859,9 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { *simplifyExpr( "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) ->getBool()); + + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_)); + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_)); } TEST_F(ExprSimplifierTest, 
FundamentalDivisionWithRemainderProperty_CUDA) { @@ -1041,4 +1045,34 @@ TEST_F(ExprSimplifierTest, ReducePredicateRegisterUsage_CUDA) { } } +TEST_F(ExprSimplifierTest, MinMax_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[0] , 128 ) ) , 4 )"_; + ASSERT_TRUE(simplify(expr, assume::tensorsAreNotEmpty(expr)) + ->sameAs("ceilDiv( T0.size[0] , 128 ) * 4"_)); +} + +TEST_F(ExprSimplifierTest, Assume_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[1] , 128 ) ) , 4 )"_; + ASSERT_EQ( + simplifyExpr(IrBuilder::eqExpr( + assume::tensorsAreNotEmpty(expr), + "T0.size[0] > 0 && T0.size[1] > 0"_)) + ->getBool(), + true); +} + } // namespace nvfuser From 85d846af0cc24dcb78ad01140f024fb0d58d69fc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:37:58 -0700 Subject: [PATCH 04/10] fix --- csrc/parallel_dimension_map.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 37ff328894b..56092455361 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -70,7 +70,12 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - v = simplifyExpr(v, {}, {assume::tensorsAreNotEmpty(v)}); + auto assume = assume::tensorsAreNotEmpty(v); + if (assume != nullptr) { + v = simplifyExpr(v, {}, {assume}); + } else { + v = simplifyExpr(v); + } } // Compute exact_types_ From 3e8befec706bbf1f32d292ab2b788b035b6bb783 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:40:03 -0700 Subject: [PATCH 05/10] comment --- 
csrc/parallel_dimension_map.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 56092455361..634ca31e656 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -70,6 +70,10 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { + // Well, this isn't really correct, but we need this assumption to better + // handle non-empty cases. If this turns out to be an issue, I believe we + // then need to find a more systematic way to handle empty tensors, rather + // than just disabling this assumption. auto assume = assume::tensorsAreNotEmpty(v); if (assume != nullptr) { v = simplifyExpr(v, {}, {assume}); } else { v = simplifyExpr(v); } } // Compute exact_types_ From 8804619c75135ef8b16189e8f54c046f7c1676f4 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:41:47 -0700 Subject: [PATCH 06/10] newline --- csrc/assume.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/assume.cpp b/csrc/assume.cpp index b70da84c742..8da6ba6df70 100644 --- a/csrc/assume.cpp +++ b/csrc/assume.cpp @@ -32,4 +32,4 @@ Bool* tensorsAreNotEmpty(Val* value) { return result; } -} // namespace nvfuser::assume \ No newline at end of file +} // namespace nvfuser::assume From 5cc53e94352a9e0d83d95632d6152bc7d61291bc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:44:49 -0700 Subject: [PATCH 07/10] unformat --- csrc/utils.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/utils.h b/csrc/utils.h index db922adc6c2..8f95ab6243c 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -268,8 +268,9 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) - -> decltype(std::declval::type>().toString(), std::true_type{}); +static auto hasToStringHelper(int) -> decltype( + std::declval::type>().toString(), + std::true_type{}); template static auto hasToStringHelper(long) -> 
std::false_type; From 15f097fd2d411dd132708c51c736ffdeb4379d92 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:52:11 -0700 Subject: [PATCH 08/10] fix tests --- test/test_gpu1.cpp | 12 ++++++------ test/test_gpu2.cpp | 32 ++++++++++++++++---------------- test/test_gpu3.cpp | 16 ++++++++-------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index 43bf5f410c5..e8a40e3d54b 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i241; - i241 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i241 < T0.size[0])) { + int64_t i244; + i244 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i244 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i241]; + = T1[i244]; float T4[1]; T4[0] = 0; T4[0] - = T0[i241]; + = T0[i244]; float T2[1]; T2[0] = T4[0] @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i241] + T3[i244] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 93194c781cb..45ee2722023 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9029,27 +9029,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1409; - i1409 = T0.size[2] * T0.size[1]; - int64_t i1412; - i1412 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1414; - i1414 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1446; - i1446 = i1412 % i1414; - int64_t i1423; - i1423 = T0.size[2] * T0.size[3]; - int64_t i1447; - i1447 = i1446 % i1423; - if ((i1412 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1419; + i1419 = T0.size[2] * T0.size[1]; + int64_t i1422; + i1422 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1424; + i1424 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1456; + i1456 = i1422 % i1424; + int64_t i1433; + i1433 = T0.size[2] * T0.size[3]; + int64_t i1457; + i1457 = i1456 % i1433; + if ((i1422 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1409 * T0.size[3]) * (i1412 / i1414)) + (i1409 * (i1447 % T0.size[3]))) + (T0.size[2] * (i1446 / i1423))) + (i1447 / T0.size[3]))]; + = T2[(((((i1419 * T0.size[3]) * (i1422 / i1424)) + (i1419 * (i1457 % T0.size[3]))) + (T0.size[2] * (i1456 / i1433))) + (i1457 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1412]; + = T0[i1422]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9069,7 +9069,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1412] + T7[i1422] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index da88ba59b9d..64845d82971 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1749,21 +1749,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - 
int64_t i194; - i194 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i197; + i197 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b324; - b324 = i194 < i7; + bool b327; + b327 = i197 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b324) { + if (b327) { T1[0] - = sinf(T0[i194]); + = sinf(T0[i197]); } - if (b324) { - T2[i194] + if (b327) { + T2[i197] = T1[0] + f8; } From d91025c447c2c7e4753f3076349c97e9a04f570a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 31 Mar 2023 11:00:03 -0700 Subject: [PATCH 09/10] doc --- csrc/assume.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/assume.h b/csrc/assume.h index a0d5399614c..40b93d55d5d 100644 --- a/csrc/assume.h +++ b/csrc/assume.h @@ -1,11 +1,14 @@ #include -// Return boolean values representing the conditional you want to assume +// Return boolean predicates representing the conditional you want to assume. +// The return value is typically used as the `assumptions` argument of +// `simplifyExpr` namespace nvfuser::assume { -// Assume that all tensor sizes appearing in `value` are positive. Return -// nullptr if not applicable. For example: +// Return a boolean predicate stating that all tensor sizes appearing in `value` +// are positive. Return nullptr if `value` does not depend on any tensor size. 
+// For example: // tensorsAreNotEmpty(ceilDiv(T0.size[0], 5) * T0.size[1]) // -> T0.size[0] > 0 && T0.size[1] > 0 // tensorsAreNotEmpty(ceilDiv(i1, 5) * i2) // -> nullptr Bool* tensorsAreNotEmpty(Val* value); From 5954219cb049f9be0ce61264fbce951ffd56fb98 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 31 Mar 2023 13:39:00 -0700 Subject: [PATCH 10/10] assume dedup --- csrc/assume.cpp | 16 ++++++++++++++-- test/test_expr_simplifier.cpp | 2 ++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/csrc/assume.cpp b/csrc/assume.cpp index 8da6ba6df70..08b146cc349 100644 --- a/csrc/assume.cpp +++ b/csrc/assume.cpp @@ -25,9 +25,21 @@ Bool* tensorsAreNotEmpty(Val* value) { } } Bool* result = nullptr; + // tensor_sizes might contain duplicates, and we should remove the duplication + std::vector tensor_sizes_applied; for (auto ts : tensor_sizes) { - result = SimplifyingIrBuilder::andExpr( - result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); + bool is_duplicate = false; + for (auto existing : tensor_sizes_applied) { + if (existing->sameAs(ts)) { + is_duplicate = true; + break; + } + } + if (!is_duplicate) { + tensor_sizes_applied.emplace_back(ts); + result = SimplifyingIrBuilder::andExpr( + result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); + } } return result; } diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index d4b80a2689c..775dcc95e33 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -1073,6 +1073,8 @@ TEST_F(ExprSimplifierTest, Assume_CUDA) { "T0.size[0] > 0 && T0.size[1] > 0"_)) ->getBool(), true); + expr = "ceilDiv( T0.size[0] , T0.size[0] ) * T0.size[0]"_; + ASSERT_TRUE(assume::tensorsAreNotEmpty(expr)->sameAs("T0.size[0] > 0"_)); } } // namespace nvfuser