From cd3448f57c4b1701358d019b7a5ec3eafd6b9d21 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 13:04:19 -0700 Subject: [PATCH 01/11] Better test --- csrc/utils.h | 13 +++- test/test_expr_simplifier.cpp | 111 ++++++++++++++++++++++++++++------ 2 files changed, 102 insertions(+), 22 deletions(-) diff --git a/csrc/utils.h b/csrc/utils.h index cc6c17b44fb..db922adc6c2 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -267,9 +268,8 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) -> decltype( - std::declval::type>().toString(), - std::true_type{}); +static auto hasToStringHelper(int) + -> decltype(std::declval::type>().toString(), std::true_type{}); template static auto hasToStringHelper(long) -> std::false_type; @@ -374,6 +374,13 @@ std::string toDelimitedString( return toDelimitedString(vec.begin(), vec.end(), delim); } +template +std::string toDelimitedString( + const std::deque& dq, + std::string delim = ", ") { + return toDelimitedString(dq.begin(), dq.end(), delim); +} + template void unrolled_for(func_t fun) { if constexpr (index < stop) { diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index b289dbbcbd9..9971b58cac3 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -80,13 +81,24 @@ namespace stupid_simple_compiler { using fun1_t = Val* (*)(Val*); using fun2_t = Val* (*)(Val*, Val*); -struct LeftParenthesis {}; +struct LeftParenthesis { + int64_t prev_lparen_pos; +}; +struct FunctionCall { + int64_t prev_lparen_pos; + std::string_view name; +}; +struct Comma {}; +struct LowestPrecedence {}; using token_t = std::variant< Val*, // variable or constant fun1_t, // unary op fun2_t, // binary op - LeftParenthesis>; + LeftParenthesis, + FunctionCall, + Comma, + LowestPrecedence>; 
Val* parseIdentifier(std::string_view token_str) { if (token_str == "true") { @@ -154,6 +166,23 @@ Val* parseNumber(std::string_view token_str) { } } +Val* functionCall(std::string_view name, std::deque args) { + if (name == "max") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::maxExpr(args.at(0), args.at(1)); + } else if (name == "min") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::minExpr(args.at(0), args.at(1)); + } else if (name == "ceilDiv") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::ceilDivExpr(args.at(0), args.at(1)); + } + TORCH_CHECK(false, "Unknown function: ", name); +} + token_t parseToken(std::string_view token_str, bool& expect_val) { if (std::isalpha(token_str.at(0))) { TORCH_CHECK( @@ -219,6 +248,12 @@ token_t parseToken(std::string_view token_str, bool& expect_val) { // https://en.cppreference.com/w/cpp/language/operator_precedence int getOpPrecedence(token_t op) { + if (std::holds_alternative(op)) { + return std::numeric_limits::max(); + } + if (std::holds_alternative(op)) { + return 17; + } if (std::holds_alternative(op)) { auto uop = std::get(op); if (uop == fun1_t(neg) || uop == fun1_t(notOp)) { @@ -279,7 +314,19 @@ Val* parse(const char* str) { op); }; + auto eval_all_top = [&](token_t token) { + TORCH_CHECK(current != nullptr, "Expect value to evaluate top"); + while (!op_stack.empty() && + (std::holds_alternative(op_stack.back()) || + std::holds_alternative(op_stack.back())) && + getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { + eval_top(); + } + }; + bool expect_val = true; + int64_t last_lparen_pos = -1; + while (!remaining.empty()) { const auto end_pos = remaining.find_first_of(' '); const auto token_str = remaining.substr(0, end_pos); @@ -287,19 +334,48 @@ Val* parse(const char* str) { if (token_str == "(") { TORCH_CHECK( expect_val, "Syntax error: 
not expecting ( but get ", token_str); - op_stack.push_back(LeftParenthesis{}); + op_stack.push_back(LeftParenthesis{last_lparen_pos}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str.back() == '(') { + TORCH_CHECK( + expect_val, + "Syntax error: not expecting function call but get ", + token_str); + op_stack.push_back(FunctionCall{ + last_lparen_pos, token_str.substr(0, token_str.size() - 1)}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str == ",") { + TORCH_CHECK(!expect_val, "Syntax error: not expecting comma"); + expect_val = true; + auto comma = Comma{}; + eval_all_top(comma); + value_stack.emplace_back(current); + op_stack.emplace_back(comma); + current = nullptr; } else if (token_str == ")") { TORCH_CHECK( !expect_val, "Syntax error: not expecting ) but get ", token_str); - // pop stack until meets matching ( - TORCH_CHECK(current != nullptr, "Expect value before )"); - while (true) { - TORCH_CHECK(!op_stack.empty(), "Unmatched )"); - if (std::holds_alternative(op_stack.back())) { + eval_all_top(LowestPrecedence{}); + auto last_lparen = op_stack.at(last_lparen_pos); + TORCH_CHECK(!op_stack.empty(), "Unmatched )"); + if (std::holds_alternative(last_lparen)) { + TORCH_INTERNAL_ASSERT(last_lparen_pos == (int64_t)op_stack.size() - 1); + auto lparen = std::get(op_stack.back()); + last_lparen_pos = lparen.prev_lparen_pos; + op_stack.pop_back(); + } else if (std::holds_alternative(last_lparen)) { + std::deque args{current}; + while (std::holds_alternative(op_stack.back())) { op_stack.pop_back(); - break; + args.push_front(value_stack.back()); + value_stack.pop_back(); } - eval_top(); + auto fc = std::get(op_stack.back()); + last_lparen_pos = fc.prev_lparen_pos; + op_stack.pop_back(); + current = functionCall(fc.name, std::move(args)); + } else { + TORCH_CHECK(false, "Unknown left parenthesis type"); } } else { token_t token = parseToken(token_str, expect_val); @@ -307,15 +383,9 @@ Val* parse(const char* str) { TORCH_CHECK(current == 
nullptr, "Don't expect value"); current = std::get(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current == nullptr, "Don't expect value"); op_stack.push_back(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current != nullptr, "Expect value before binary op"); - while (!op_stack.empty() && - !std::holds_alternative(op_stack.back()) && - getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { - eval_top(); - } + eval_all_top(token); value_stack.push_back(current); op_stack.push_back(token); current = nullptr; @@ -416,6 +486,9 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); + // constant folding + TORCH_CHECK(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); + TORCH_CHECK(simplifyExpr("1 * i"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("1.0 * d"_)->sameAs("d"_)); TORCH_CHECK(simplifyExpr("i * 1"_)->sameAs("i"_)); @@ -440,8 +513,8 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::maxExpr("i"_, "i"_))->sameAs("i"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::minExpr("i"_, "i"_))->sameAs("i"_)); + TORCH_CHECK(simplifyExpr("max( i , i )"_)->sameAs("i"_)); + TORCH_CHECK(simplifyExpr("min( i , i )"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); From 81dcd4877122d33dec21601aa36bb702a75931d6 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:00:44 -0700 Subject: [PATCH 02/11] min-max simplifier --- csrc/expr_simplifier.cpp | 27 +++++++++++++++ test/test_expr_simplifier.cpp | 62 +++++++++++++++++++++-------------- 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 576b6c8930c..465fc0f7a4f 100644 --- a/csrc/expr_simplifier.cpp +++ 
b/csrc/expr_simplifier.cpp @@ -1611,6 +1611,33 @@ Val* eliminateTrivialComputation(Val* value, const Context& context) { } } } + { // max(a, b) -> a if a >= b, min(a, b) -> b if a >= b + if (op == BinaryOpType::Max || op == BinaryOpType::Min) { + std::vector simplified_input; + for (auto v : fop->inputs()) { + bool found_redundant = false; + for (auto& v2 : simplified_input) { + if ((op == BinaryOpType::Max && prove::lessEqual(v, v2, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v2, v, context))) { + found_redundant = true; + break; + } else if ( + (op == BinaryOpType::Max && prove::lessEqual(v2, v, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v, v2, context))) { + found_redundant = true; + v2 = v; + break; + } + } + if (!found_redundant) { + simplified_input.emplace_back(v); + } + } + if (simplified_input.size() < fop->inputs().size()) { + return maybeFlattenedOpOf(op, std::move(simplified_input)); + } + } + } } else if (auto bop = dynamic_cast(value->definition())) { auto lhs = foldConstants(bop->lhs()); auto rhs = foldConstants(bop->rhs()); diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index 9971b58cac3..96f71657895 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -486,53 +486,61 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + // constant folding - TORCH_CHECK(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); + ASSERT_TRUE(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); - TORCH_CHECK(simplifyExpr("1 * i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("1.0 * d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i * 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d * 1.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("1 * i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("1.0 * 
d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i * 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d * 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 * i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i * 0"_)->getInt(), 0); - TORCH_CHECK(simplifyExpr("0 + i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("0.0 + d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i + 0"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d + 0.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("0 + i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("0.0 + d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i + 0"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d + 0.0"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("true && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && true"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("true && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b && true"_)->sameAs("b"_)); ASSERT_EQ(simplifyExpr("false && b"_)->getBool(), false); ASSERT_EQ(simplifyExpr("b && false"_)->getBool(), false); ASSERT_EQ(simplifyExpr("true || b"_)->getBool(), true); ASSERT_EQ(simplifyExpr("b || true"_)->getBool(), true); - TORCH_CHECK(simplifyExpr("false || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || false"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("false || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || false"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("max( i , i )"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("min( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("b && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("max( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("min( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplify("max( i1 , i2 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("max( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("min( i1 , i2 )"_, "i1 <= 
i2"_)->sameAs("i1"_)); + ASSERT_TRUE(simplify("min( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i1"_)); - TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i / 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d / 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 / i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i % 1"_)->getInt(), 0); // -(-a) -> a - TORCH_CHECK(simplifyExpr("- - i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("~ ~ i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("! ! b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("- - i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("~ ~ i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("! ! b"_)->sameAs("b"_)); // Test constant folding - TORCH_CHECK(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); - TORCH_CHECK(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); + ASSERT_TRUE(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); + ASSERT_TRUE(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); // Test that FlattenedAssocCommOp::sameAs ignores order - TORCH_CHECK(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); + ASSERT_TRUE(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); } TEST_F(ExprSimplifierTest, SimplifyDivisibleDivMod_CUDA) { @@ -846,6 +854,10 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { ASSERT_TRUE(*simplify("i1 >= i1 * i2"_, "i1 <= 0 && i2 > 0"_)); ASSERT_TRUE(*simplify("d1 <= d1 * d2"_, "d1 >= 0.0 && d2 >= 1.0"_)); ASSERT_TRUE(*simplify("d1 >= d1 * d2"_, "d1 <= 0.0 && d2 >= 1.0"_)); + ASSERT_TRUE( + *simplifyExpr( + "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) + ->getBool()); } TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty_CUDA) { From da109f2f2c561adaaa266837380c46e4b694790c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:34:59 -0700 Subject: [PATCH 03/11] assume --- CMakeLists.txt | 1 + csrc/assume.cpp | 35 +++++++++++++++++++++++++++++++++ csrc/assume.h | 15 
++++++++++++++ csrc/expr_simplifier.cpp | 7 +++++++ csrc/parallel_dimension_map.cpp | 3 ++- test/test_expr_simplifier.cpp | 34 ++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 csrc/assume.cpp create mode 100644 csrc/assume.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a491e33035..5044a890bcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ endif() # nvfuser codegen sources set(NVFUSER_SRCS) list(APPEND NVFUSER_SRCS + ${NVFUSER_SRCS_DIR}/assume.cpp ${NVFUSER_SRCS_DIR}/compute_at.cpp ${NVFUSER_SRCS_DIR}/inlining.cpp ${NVFUSER_SRCS_DIR}/compute_at_map.cpp diff --git a/csrc/assume.cpp b/csrc/assume.cpp new file mode 100644 index 00000000000..b70da84c742 --- /dev/null +++ b/csrc/assume.cpp @@ -0,0 +1,35 @@ +#include +#include +#include + +#include + +namespace nvfuser::assume { + +Bool* tensorsAreNotEmpty(Val* value) { + std::vector todo{value}; + std::vector tensor_sizes; + while (!todo.empty()) { + auto v = todo.back(); + todo.pop_back(); + if (auto ns = dynamic_cast(v)) { + if (ns->isTensorSize()) { + tensor_sizes.emplace_back(v); + continue; + } + } + if (auto def = v->definition()) { + for (auto inp : def->inputs()) { + todo.emplace_back(inp); + } + } + } + Bool* result = nullptr; + for (auto ts : tensor_sizes) { + result = SimplifyingIrBuilder::andExpr( + result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); + } + return result; +} + +} // namespace nvfuser::assume \ No newline at end of file diff --git a/csrc/assume.h b/csrc/assume.h new file mode 100644 index 00000000000..a0d5399614c --- /dev/null +++ b/csrc/assume.h @@ -0,0 +1,15 @@ +#include + +// Return boolean values representing the conditional you want to assume + +namespace nvfuser::assume { + +// Assume that all tensor sizes appearing in `value` are positive. Return +// nullptr if not applicable. 
For example: +// tensorsAreNotEmpty(ceilDiv(T0.size[0], 5) * T0.size[1]) +// -> T0.size[0] > 0 && T0.size[1] > 0 +// tensorsAreNotEmpty(ceilDiv(i1, 5) * i2) +// -> nullptr +Bool* tensorsAreNotEmpty(Val* value); + +} // namespace nvfuser::assume diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 465fc0f7a4f..176021fa462 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -1296,6 +1296,13 @@ bool isPositiveHelper(Val* value, const Context& context) { } return true; } + } else if (auto bop = dynamic_cast(value->definition())) { + auto op = bop->getBinaryOpType(); + if (op == BinaryOpType::CeilDiv) { + return isPositive(bop->lhs(), context) && + isValidDenominator(bop->rhs(), context) && + isNonNegative(bop->rhs(), context); + } } for (const auto& [a, b] : context.getKnownLessThan()) { if (a->isZero() && b->sameAs(value)) { diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 9888f2fc848..37ff328894b 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -69,7 +70,7 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - v = simplifyExpr(v); + v = simplifyExpr(v, {}, {assume::tensorsAreNotEmpty(v)}); } // Compute exact_types_ diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index 96f71657895..d4b80a2689c 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include #include @@ -858,6 +859,9 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { *simplifyExpr( "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) ->getBool()); + + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_)); + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_)); } TEST_F(ExprSimplifierTest, 
FundamentalDivisionWithRemainderProperty_CUDA) { @@ -1041,4 +1045,34 @@ TEST_F(ExprSimplifierTest, ReducePredicateRegisterUsage_CUDA) { } } +TEST_F(ExprSimplifierTest, MinMax_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[0] , 128 ) ) , 4 )"_; + ASSERT_TRUE(simplify(expr, assume::tensorsAreNotEmpty(expr)) + ->sameAs("ceilDiv( T0.size[0] , 128 ) * 4"_)); +} + +TEST_F(ExprSimplifierTest, Assume_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[1] , 128 ) ) , 4 )"_; + ASSERT_EQ( + simplifyExpr(IrBuilder::eqExpr( + assume::tensorsAreNotEmpty(expr), + "T0.size[0] > 0 && T0.size[1] > 0"_)) + ->getBool(), + true); +} + } // namespace nvfuser From 85d846af0cc24dcb78ad01140f024fb0d58d69fc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:37:58 -0700 Subject: [PATCH 04/11] fix --- csrc/parallel_dimension_map.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 37ff328894b..56092455361 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -70,7 +70,12 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - v = simplifyExpr(v, {}, {assume::tensorsAreNotEmpty(v)}); + auto assume = assume::tensorsAreNotEmpty(v); + if (assume != nullptr) { + v = simplifyExpr(v, {}, {assume}); + } else { + v = simplifyExpr(v); + } } // Compute exact_types_ From 3e8befec706bbf1f32d292ab2b788b035b6bb783 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:40:03 -0700 Subject: [PATCH 05/11] comment --- 
csrc/parallel_dimension_map.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 56092455361..634ca31e656 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -70,6 +70,10 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { + // Well, this isn't really correct, but we need this assumption to better + // handle non-empty cases. If this turns out to be an issue, I believe we + // then need to find a more systematic way to handle empty tensors, rather + // than just disabling this assumption. auto assume = assume::tensorsAreNotEmpty(v); if (assume != nullptr) { v = simplifyExpr(v, {}, {assume}); From 8804619c75135ef8b16189e8f54c046f7c1676f4 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:41:47 -0700 Subject: [PATCH 06/11] newline --- csrc/assume.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/assume.cpp b/csrc/assume.cpp index b70da84c742..8da6ba6df70 100644 --- a/csrc/assume.cpp +++ b/csrc/assume.cpp @@ -32,4 +32,4 @@ Bool* tensorsAreNotEmpty(Val* value) { return result; } -} // namespace nvfuser::assume \ No newline at end of file +} // namespace nvfuser::assume From 5cc53e94352a9e0d83d95632d6152bc7d61291bc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:44:49 -0700 Subject: [PATCH 07/11] unformat --- csrc/utils.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/utils.h b/csrc/utils.h index db922adc6c2..8f95ab6243c 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -268,8 +268,9 @@ std::vector getSortedKeys( // Based on https://stackoverflow.com/a/9154394 template -static auto hasToStringHelper(int) - -> decltype(std::declval::type>().toString(), std::true_type{}); +static auto hasToStringHelper(int) -> decltype( + std::declval::type>().toString(), + std::true_type{}); template static auto hasToStringHelper(long) ->
std::false_type; From 15f097fd2d411dd132708c51c736ffdeb4379d92 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 16:52:11 -0700 Subject: [PATCH 08/11] fix tests --- test/test_gpu1.cpp | 12 ++++++------ test/test_gpu2.cpp | 32 ++++++++++++++++---------------- test/test_gpu3.cpp | 16 ++++++++-------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index 43bf5f410c5..e8a40e3d54b 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i241; - i241 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i241 < T0.size[0])) { + int64_t i244; + i244 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i244 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i241]; + = T1[i244]; float T4[1]; T4[0] = 0; T4[0] - = T0[i241]; + = T0[i244]; float T2[1]; T2[0] = T4[0] @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i241] + T3[i244] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 93194c781cb..45ee2722023 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9029,27 +9029,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1409; - i1409 = T0.size[2] * T0.size[1]; - int64_t i1412; - i1412 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1414; - i1414 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1446; - i1446 = i1412 % i1414; - int64_t i1423; - i1423 = T0.size[2] * T0.size[3]; - int64_t i1447; - i1447 = i1446 % i1423; - if ((i1412 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1419; + i1419 = T0.size[2] * T0.size[1]; + int64_t i1422; + i1422 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1424; + i1424 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1456; + i1456 = i1422 % i1424; + int64_t i1433; + i1433 = T0.size[2] * T0.size[3]; + int64_t i1457; + i1457 = i1456 % i1433; + if ((i1422 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1409 * T0.size[3]) * (i1412 / i1414)) + (i1409 * (i1447 % T0.size[3]))) + (T0.size[2] * (i1446 / i1423))) + (i1447 / T0.size[3]))]; + = T2[(((((i1419 * T0.size[3]) * (i1422 / i1424)) + (i1419 * (i1457 % T0.size[3]))) + (T0.size[2] * (i1456 / i1433))) + (i1457 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1412]; + = T0[i1422]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9069,7 +9069,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1412] + T7[i1422] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index da88ba59b9d..64845d82971 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1749,21 +1749,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - 
int64_t i194; - i194 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i197; + i197 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b324; - b324 = i194 < i7; + bool b327; + b327 = i197 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b324) { + if (b327) { T1[0] - = sinf(T0[i194]); + = sinf(T0[i197]); } - if (b324) { - T2[i194] + if (b327) { + T2[i197] = T1[0] + f8; } From b3266e8da95c59de65ec8a435109031258ffa8b3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 18:28:41 -0700 Subject: [PATCH 09/11] Remove trivial predicate for matmul --- csrc/expr_simplifier.cpp | 16 ++++++++++++---- csrc/kernel_ir.h | 4 ++++ csrc/lower_scalar_hoist.cpp | 7 +++++++ csrc/predicate_compute.cpp | 8 ++++++++ test/test_expr_simplifier.cpp | 4 ++++ test/test_gpu_tensorcore.cpp | 24 ++++++++++++++++++++++++ 6 files changed, 59 insertions(+), 4 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 176021fa462..79d32ac4aff 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -138,6 +138,10 @@ std::unique_ptr createLogger(Val* value) { } // namespace debug_print +namespace assoc_comm { +Val* flatten(Val* value); +} // namespace assoc_comm + namespace { std::vector getAxioms() { @@ -222,16 +226,20 @@ class Context { if (auto bop = dynamic_cast(def)) { switch (bop->getBinaryOpType()) { case BinaryOpType::LT: - less_than_.emplace_back(bop->lhs(), bop->rhs()); + less_than_.emplace_back( + assoc_comm::flatten(bop->lhs()), assoc_comm::flatten(bop->rhs())); break; case BinaryOpType::LE: - less_equal_.emplace_back(bop->lhs(), bop->rhs()); + less_equal_.emplace_back( + assoc_comm::flatten(bop->lhs()), assoc_comm::flatten(bop->rhs())); break; case BinaryOpType::GT: - less_than_.emplace_back(bop->rhs(), bop->lhs()); + less_than_.emplace_back( + assoc_comm::flatten(bop->rhs()), assoc_comm::flatten(bop->lhs())); break; case 
BinaryOpType::GE: - less_equal_.emplace_back(bop->rhs(), bop->lhs()); + less_equal_.emplace_back( + assoc_comm::flatten(bop->rhs()), assoc_comm::flatten(bop->lhs())); break; default: TORCH_INTERNAL_ASSERT( diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 81510c211a8..6c452255fc6 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -113,6 +113,10 @@ class TORCH_CUDA_CU_API Predicate final : public Val { return hasValue() && value_->isConst(); } + bool isTrivial() const { + return isConst() && value_->getBool() == true; + } + private: PredicateType ptype_ = PredicateType::Manual; diff --git a/csrc/lower_scalar_hoist.cpp b/csrc/lower_scalar_hoist.cpp index 5af1290a415..855605c2cb2 100644 --- a/csrc/lower_scalar_hoist.cpp +++ b/csrc/lower_scalar_hoist.cpp @@ -247,6 +247,13 @@ std::list getVariableInfo( std::vector getAssumptions(const std::vector& loops) { std::vector assumptions; + // assumptions from parallel dimension + for (auto [p, extent] : + GpuLower::current()->parallelDimensionMap().getMap()) { + auto a = IrBuilder::ltExpr(NamedScalar::getParallelIndex(p), extent); + assumptions.emplace_back(a); + } + // assumptions from loop nesting for (auto loop : loops) { // Trivial loop is not generated, so there is no `if` or `for` in C++ to // guard its scope. So we should not assume index < stop. One real example diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 7ab87a99f6c..f710a5bd54a 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -61,6 +61,14 @@ Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { GpuLower::current()->caMap()->getConcreteMappedID( pred_id, IdMappingMode::EXACT)); auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); + // pt_ not being exact does not mean the predicate is not trivial. For + // example, if I have T1[blockIdx.x{3}] and T2[blockIdx.x{5}], then + // blockIdx.x will not be exact. However, the predicate blockIdx.x < 5 is + // still trivial. 
+ if (pred_id->extent()->sameAs( + GpuLower::current()->parallelDimensionMap().getRaw(pt_))) { + continue; + } pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as(); } diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index d4b80a2689c..c7ed0951844 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -862,6 +862,10 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_)); ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_)); + + ASSERT_TRUE(*simplify( + "blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_, + "blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_)); } TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty_CUDA) { diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 94da9900b67..44b6a4c931e 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -806,6 +806,30 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { TORCH_CHECK(gdimy == expected_gdimy); runtime = fe.kernelTimeMs(); + + // Check that mma op is not predicated + class PredicateChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + bool found_mma = false; + + private: + void handle(MmaOp* uop) final { + found_mma = true; + for (auto expr : scope_exprs_) { + TORCH_CHECK( + !expr->isA() || + expr->as()->predicate()->isTrivial(), + "MmaOp should't be predicated!", + " Get predicate ", + expr->as()->predicate()->toInlineString()); + } + } + } pred_checker; + + GpuLower gpulw(&fusion); + pred_checker.handle(gpulw.kernel()->topLevelExprs()); + ASSERT_TRUE(pred_checker.found_mma); }; // Checking only a single layout to keep runtime short (compilation overhead) From 8c0e32cb62efa10444f0f0967e6a6cba9f6f58d3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Mar 2023 18:33:32 -0700 Subject: [PATCH 10/11] fix tests --- test/test_gpu1.cpp | 12 ++++++------ test/test_gpu2.cpp | 32 
++++++++++++++++---------------- test/test_gpu3.cpp | 16 ++++++++-------- test/test_loop_rotation.cpp | 30 +++++++++++++++--------------- 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index e8a40e3d54b..3b8a85eb2d0 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i244; - i244 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i244 < T0.size[0])) { + int64_t i248; + i248 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i248 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i244]; + = T1[i248]; float T4[1]; T4[0] = 0; T4[0] - = T0[i244]; + = T0[i248]; float T2[1]; T2[0] = T4[0] @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i244] + T3[i248] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 45ee2722023..a1e542f0b77 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9029,27 +9029,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1419; - i1419 = T0.size[2] * T0.size[1]; - int64_t i1422; - i1422 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1424; - i1424 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1456; - i1456 = i1422 % i1424; - int64_t i1433; - i1433 = T0.size[2] * T0.size[3]; - int64_t i1457; - i1457 = i1456 % i1433; - if ((i1422 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1435; + i1435 = T0.size[2] * T0.size[1]; + int64_t i1438; + i1438 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1440; + i1440 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1472; + i1472 = i1438 % i1440; + int64_t i1449; + i1449 = T0.size[2] * T0.size[3]; + int64_t i1473; + i1473 = i1472 % i1449; + if ((i1438 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1419 * T0.size[3]) * (i1422 / i1424)) + (i1419 * (i1457 % T0.size[3]))) + (T0.size[2] * (i1456 / i1433))) + (i1457 / T0.size[3]))]; + = T2[(((((i1435 * T0.size[3]) * (i1438 / i1440)) + (i1435 * (i1473 % T0.size[3]))) + (T0.size[2] * (i1472 / i1449))) + (i1473 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1422]; + = T0[i1438]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9069,7 +9069,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1422] + T7[i1438] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 64845d82971..8c02d4c9064 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1749,21 +1749,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - 
int64_t i197; - i197 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i201; + i201 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b327; - b327 = i197 < i7; + bool b347; + b347 = i201 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b327) { + if (b347) { T1[0] - = sinf(T0[i197]); + = sinf(T0[i201]); } - if (b327) { - T2[i197] + if (b347) { + T2[i201] = T1[0] + f8; } diff --git a/test/test_loop_rotation.cpp b/test/test_loop_rotation.cpp index 29264b19427..cb3430e0370 100644 --- a/test/test_loop_rotation.cpp +++ b/test/test_loop_rotation.cpp @@ -206,8 +206,8 @@ TEST_F(LoopRotationTest, NonDivisibleSplit_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_DEFINE_MAGIC_ZERO - int64_t i1511; - i1511 = T0.size[0] * T0.size[1]; + int64_t i1529; + i1529 = T0.size[0] * T0.size[1]; float T1[5]; float T2[5]; #pragma unroll @@ -219,7 +219,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { for(nvfuser_index_t i36 = 0; i36 < 5; ++i36) { int64_t i154; i154 = i36 + nvfuser_zero; - if ((i154 < i1511)) { + if ((i154 < i1529)) { T1[i36] = T0[((T0.stride[0] * (i154 / T0.size[1])) + (T0.stride[1] * (i154 % T0.size[1])))]; } @@ -233,10 +233,10 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll 1 for(nvfuser_index_t i39 = 0; i39 < (ceilDiv((T0.size[0] * T0.size[1]), 5)); ++i39) { - int64_t i628; - i628 = 5 * i39; - int64_t i1218; - i1218 = 5 + i628; + int64_t i636; + i636 = 5 * i39; + int64_t i1230; + i1230 = 5 + i636; // Alias Allocation - register auto& T3 = T1; #pragma unroll @@ -247,10 +247,10 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll for(nvfuser_index_t i40 = 0; i40 < 5; ++i40) { - int64_t i629; - i629 = i628 + (i40 + nvfuser_zero); - if ((i629 < i1511)) { - T4[i629] + int64_t 
i637; + i637 = i636 + (i40 + nvfuser_zero); + if ((i637 < i1529)) { + T4[i637] = T3[i40]; } } @@ -262,11 +262,11 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) { NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll for(nvfuser_index_t i36 = 0; i36 < 5; ++i36) { - int64_t i1219; - i1219 = i1218 + (i36 + nvfuser_zero); - if ((i1219 < i1511)) { + int64_t i1231; + i1231 = i1230 + (i36 + nvfuser_zero); + if ((i1231 < i1529)) { T1[i36] - = T0[((T0.stride[0] * (i1219 / T0.size[1])) + (T0.stride[1] * (i1219 % T0.size[1])))]; + = T0[((T0.stride[0] * (i1231 / T0.size[1])) + (T0.stride[1] * (i1231 % T0.size[1])))]; } } NVFUSER_UPDATE_MAGIC_ZERO From b113fa622b400777cbab4caa4153104dd7657c4b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 31 Mar 2023 14:25:48 -0700 Subject: [PATCH 11/11] save --- csrc/predicate_compute.cpp | 17 +++++++++-------- test/test_gpu_tensorcore.cpp | 3 ++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index f710a5bd54a..86862848454 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -61,14 +61,6 @@ Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { GpuLower::current()->caMap()->getConcreteMappedID( pred_id, IdMappingMode::EXACT)); auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); - // pt_ not being exact does not mean the predicate is not trivial. For - // example, if I have T1[blockIdx.x{3}] and T2[blockIdx.x{5}], then - // blockIdx.x will not be exact. However, the predicate blockIdx.x < 5 is - // still trivial. 
- if (pred_id->extent()->sameAs( - GpuLower::current()->parallelDimensionMap().getRaw(pt_))) { - continue; - } pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as(); } @@ -173,6 +165,7 @@ ParallelizedDomainPredicate::getPredicateMap( gpu_lower->parallelDimensionMap().isExact(loop_ptype)) { continue; } + auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype); // Parallel dimensions need not be predicated if fully unswitched. if (within_unswitch && @@ -209,6 +202,14 @@ ParallelizedDomainPredicate::getPredicateMap( continue; } + // loop_ptype not being exact does not mean the predicate is not trivial. + // For example, if I have T1[blockIdx.x{3}] and T2[blockIdx.x{5}], then + // blockIdx.x will not be exact. However, the predicate blockIdx.x < 5 is + // still trivial. + if (tv_id->extent()->sameAs(parallel_dim)) { + continue; + } + // tv_id needs to be predicated. Adds it to the PredicateInfo map. auto& info = map.at(loop_ptype); info.addDomain(tv_id); diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 44b6a4c931e..7f5e7a7e87b 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -807,7 +807,8 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { runtime = fe.kernelTimeMs(); - // Check that mma op is not predicated + // Check that mma op is not predicated. This is a regression test for + // https://github.com/NVIDIA/Fuser/issues/95 class PredicateChecker : public kir::IrVisitor { public: using kir::IrVisitor::handle;