diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a491e33035..5044a890bcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ endif() # nvfuser codegen sources set(NVFUSER_SRCS) list(APPEND NVFUSER_SRCS + ${NVFUSER_SRCS_DIR}/assume.cpp ${NVFUSER_SRCS_DIR}/compute_at.cpp ${NVFUSER_SRCS_DIR}/inlining.cpp ${NVFUSER_SRCS_DIR}/compute_at_map.cpp diff --git a/csrc/assume.cpp b/csrc/assume.cpp new file mode 100644 index 00000000000..08b146cc349 --- /dev/null +++ b/csrc/assume.cpp @@ -0,0 +1,47 @@ +#include +#include +#include + +#include + +namespace nvfuser::assume { + +Bool* tensorsAreNotEmpty(Val* value) { + std::vector todo{value}; + std::vector tensor_sizes; + while (!todo.empty()) { + auto v = todo.back(); + todo.pop_back(); + if (auto ns = dynamic_cast(v)) { + if (ns->isTensorSize()) { + tensor_sizes.emplace_back(v); + continue; + } + } + if (auto def = v->definition()) { + for (auto inp : def->inputs()) { + todo.emplace_back(inp); + } + } + } + Bool* result = nullptr; + // tensor_sizes might contain duplicate, and we should remove this duplication + std::vector tensor_sizes_applied; + for (auto ts : tensor_sizes) { + bool is_duplicate = false; + for (auto existing : tensor_sizes_applied) { + if (existing->sameAs(ts)) { + is_duplicate = true; + break; + } + } + if (!is_duplicate) { + tensor_sizes_applied.emplace_back(ts); + result = SimplifyingIrBuilder::andExpr( + result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); + } + } + return result; +} + +} // namespace nvfuser::assume diff --git a/csrc/assume.h b/csrc/assume.h new file mode 100644 index 00000000000..40b93d55d5d --- /dev/null +++ b/csrc/assume.h @@ -0,0 +1,18 @@ +#include + +// Return boolean predicates representing the conditional you want to assume. +// The return value is typically used as the `assumptions` argument of +// `simplifyExpr` + +namespace nvfuser::assume { + +// Return a boolean predicate stating that all tensor sizes appearing in `value` +// are positive. Return nullptr if `value` does not depend on any tensor size. +// For example: +// tensorsAreNotEmpty(ceilDiv(T0.size[0], 5) * T0.size[1]) +// -> T0.size[0] > 0 && T0.size[1] > 0 +// tensorsAreNotEmpty(ceilDiv(i1, 5) * i2) +// -> nullptr +Bool* tensorsAreNotEmpty(Val* value); + +} // namespace nvfuser::assume diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index 576b6c8930c..176021fa462 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -1296,6 +1296,13 @@ bool isPositiveHelper(Val* value, const Context& context) { } return true; } + } else if (auto bop = dynamic_cast(value->definition())) { + auto op = bop->getBinaryOpType(); + if (op == BinaryOpType::CeilDiv) { + return isPositive(bop->lhs(), context) && + isValidDenominator(bop->rhs(), context) && + isNonNegative(bop->rhs(), context); + } } for (const auto& [a, b] : context.getKnownLessThan()) { if (a->isZero() && b->sameAs(value)) { @@ -1611,6 +1618,33 @@ Val* eliminateTrivialComputation(Val* value, const Context& context) { } } } + { // max(a, b) -> a if a >= b, min(a, b) -> b if a >= b + if (op == BinaryOpType::Max || op == BinaryOpType::Min) { + std::vector simplified_input; + for (auto v : fop->inputs()) { + bool found_redundant = false; + for (auto& v2 : simplified_input) { + if ((op == BinaryOpType::Max && prove::lessEqual(v, v2, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v2, v, context))) { + found_redundant = true; + break; + } else if ( + (op == BinaryOpType::Max && prove::lessEqual(v2, v, context)) || + (op == BinaryOpType::Min && prove::lessEqual(v, v2, context))) { + found_redundant = true; + v2 = v; + break; + } + } + if (!found_redundant) { + simplified_input.emplace_back(v); + } + } + if (simplified_input.size() < fop->inputs().size()) { + return maybeFlattenedOpOf(op, std::move(simplified_input)); + } + } + } } else if (auto bop = dynamic_cast(value->definition())) { auto lhs = foldConstants(bop->lhs()); auto rhs = foldConstants(bop->rhs()); diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 9888f2fc848..634ca31e656 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -69,7 +70,16 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - v = simplifyExpr(v); + // Well, this isn't really correct, but we need this assumption to better + // handle non-empty cases. If this turn out to be an issue, I believe we + // then need to find a more systematic way to handle empty tensor, rather + // than just disable this assumption. + auto assume = assume::tensorsAreNotEmpty(v); + if (assume != nullptr) { + v = simplifyExpr(v, {}, {assume}); + } else { + v = simplifyExpr(v); + } } // Compute exact_types_ diff --git a/csrc/utils.h b/csrc/utils.h index cc6c17b44fb..8f95ab6243c 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -374,6 +375,13 @@ std::string toDelimitedString( return toDelimitedString(vec.begin(), vec.end(), delim); } +template +std::string toDelimitedString( + const std::deque& dq, + std::string delim = ", ") { + return toDelimitedString(dq.begin(), dq.end(), delim); +} + template void unrolled_for(func_t fun) { if constexpr (index < stop) { diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index b289dbbcbd9..775dcc95e33 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -7,12 +7,14 @@ // clang-format on #include +#include #include #include #include #include #include +#include #include #include #include @@ -80,13 +82,24 @@ namespace stupid_simple_compiler { using fun1_t = Val* (*)(Val*); using fun2_t = Val* (*)(Val*, Val*); -struct LeftParenthesis {}; +struct LeftParenthesis { + int64_t prev_lparen_pos; +}; +struct FunctionCall { + int64_t prev_lparen_pos; + std::string_view name; +}; +struct Comma {}; +struct LowestPrecedence {}; using token_t = std::variant< Val*, // variable or constant fun1_t, // unary op fun2_t, // binary op - LeftParenthesis>; + LeftParenthesis, + FunctionCall, + Comma, + LowestPrecedence>; Val* parseIdentifier(std::string_view token_str) { if (token_str == "true") { @@ -154,6 +167,23 @@ Val* parseNumber(std::string_view token_str) { } } +Val* functionCall(std::string_view name, std::deque args) { + if (name == "max") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::maxExpr(args.at(0), args.at(1)); + } else if (name == "min") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::minExpr(args.at(0), args.at(1)); + } else if (name == "ceilDiv") { + TORCH_CHECK( + args.size() == 2, "Invalid argument: ", toDelimitedString(args)); + return IrBuilder::ceilDivExpr(args.at(0), args.at(1)); + } + TORCH_CHECK(false, "Unknown function: ", name); +} + token_t parseToken(std::string_view token_str, bool& expect_val) { if (std::isalpha(token_str.at(0))) { TORCH_CHECK( @@ -219,6 +249,12 @@ token_t parseToken(std::string_view token_str, bool& expect_val) { // https://en.cppreference.com/w/cpp/language/operator_precedence int getOpPrecedence(token_t op) { + if (std::holds_alternative(op)) { + return std::numeric_limits::max(); + } + if (std::holds_alternative(op)) { + return 17; + } if (std::holds_alternative(op)) { auto uop = std::get(op); if (uop == fun1_t(neg) || uop == fun1_t(notOp)) { @@ -279,7 +315,19 @@ Val* parse(const char* str) { op); }; + auto eval_all_top = [&](token_t token) { + TORCH_CHECK(current != nullptr, "Expect value to evaluate top"); + while (!op_stack.empty() && + (std::holds_alternative(op_stack.back()) || + std::holds_alternative(op_stack.back())) && + getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { + eval_top(); + } + }; + bool expect_val = true; + int64_t last_lparen_pos = -1; + while (!remaining.empty()) { const auto end_pos = remaining.find_first_of(' '); const auto token_str = remaining.substr(0, end_pos); @@ -287,19 +335,48 @@ Val* parse(const char* str) { if (token_str == "(") { TORCH_CHECK( expect_val, "Syntax error: not expecting ( but get ", token_str); - op_stack.push_back(LeftParenthesis{}); + op_stack.push_back(LeftParenthesis{last_lparen_pos}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str.back() == '(') { + TORCH_CHECK( + expect_val, + "Syntax error: not expecting function call but get ", + token_str); + op_stack.push_back(FunctionCall{ + last_lparen_pos, token_str.substr(0, token_str.size() - 1)}); + last_lparen_pos = op_stack.size() - 1; + } else if (token_str == ",") { + TORCH_CHECK(!expect_val, "Syntax error: not expecting comma"); + expect_val = true; + auto comma = Comma{}; + eval_all_top(comma); + value_stack.emplace_back(current); + op_stack.emplace_back(comma); + current = nullptr; } else if (token_str == ")") { TORCH_CHECK( !expect_val, "Syntax error: not expecting ) but get ", token_str); - // pop stack until meets matching ( - TORCH_CHECK(current != nullptr, "Expect value before )"); - while (true) { - TORCH_CHECK(!op_stack.empty(), "Unmatched )"); - if (std::holds_alternative(op_stack.back())) { + eval_all_top(LowestPrecedence{}); + auto last_lparen = op_stack.at(last_lparen_pos); + TORCH_CHECK(!op_stack.empty(), "Unmatched )"); + if (std::holds_alternative(last_lparen)) { + TORCH_INTERNAL_ASSERT(last_lparen_pos == (int64_t)op_stack.size() - 1); + auto lparen = std::get(op_stack.back()); + last_lparen_pos = lparen.prev_lparen_pos; + op_stack.pop_back(); + } else if (std::holds_alternative(last_lparen)) { + std::deque args{current}; + while (std::holds_alternative(op_stack.back())) { op_stack.pop_back(); - break; + args.push_front(value_stack.back()); + value_stack.pop_back(); } - eval_top(); + auto fc = std::get(op_stack.back()); + last_lparen_pos = fc.prev_lparen_pos; + op_stack.pop_back(); + current = functionCall(fc.name, std::move(args)); + } else { + TORCH_CHECK(false, "Unknown left parenthesis type"); } } else { token_t token = parseToken(token_str, expect_val); @@ -307,15 +384,9 @@ Val* parse(const char* str) { TORCH_CHECK(current == nullptr, "Don't expect value"); current = std::get(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current == nullptr, "Don't expect value"); op_stack.push_back(token); } else if (std::holds_alternative(token)) { - TORCH_CHECK(current != nullptr, "Expect value before binary op"); - while (!op_stack.empty() && - !std::holds_alternative(op_stack.back()) && - getOpPrecedence(op_stack.back()) <= getOpPrecedence(token)) { - eval_top(); - } + eval_all_top(token); value_stack.push_back(current); op_stack.push_back(token); current = nullptr; @@ -416,50 +487,61 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); - TORCH_CHECK(simplifyExpr("1 * i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("1.0 * d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i * 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d * 1.0"_)->sameAs("d"_)); + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + + // constant folding + ASSERT_TRUE(simplifyExpr("ceilDiv( 5 , 3 ) * 5"_)->sameAs("10"_)); + + ASSERT_TRUE(simplifyExpr("1 * i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("1.0 * d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i * 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d * 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 * i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i * 0"_)->getInt(), 0); - TORCH_CHECK(simplifyExpr("0 + i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("0.0 + d"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("i + 0"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d + 0.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("0 + i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("0.0 + d"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i + 0"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d + 0.0"_)->sameAs("d"_)); - TORCH_CHECK(simplifyExpr("true && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && true"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("true && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b && true"_)->sameAs("b"_)); ASSERT_EQ(simplifyExpr("false && b"_)->getBool(), false); ASSERT_EQ(simplifyExpr("b && false"_)->getBool(), false); ASSERT_EQ(simplifyExpr("true || b"_)->getBool(), true); ASSERT_EQ(simplifyExpr("b || true"_)->getBool(), true); - TORCH_CHECK(simplifyExpr("false || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || false"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("false || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || false"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::maxExpr("i"_, "i"_))->sameAs("i"_)); - TORCH_CHECK(simplifyExpr(IrBuilder::minExpr("i"_, "i"_))->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("b && b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("b || b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("max( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("min( i , i )"_)->sameAs("i"_)); + ASSERT_TRUE(simplify("max( i1 , i2 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("max( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i2"_)); + ASSERT_TRUE(simplify("min( i1 , i2 )"_, "i1 <= i2"_)->sameAs("i1"_)); + ASSERT_TRUE(simplify("min( i2 , i1 )"_, "i1 <= i2"_)->sameAs("i1"_)); - TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); + ASSERT_TRUE(simplifyExpr("i / 1"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("d / 1.0"_)->sameAs("d"_)); ASSERT_EQ(simplifyExpr("0 / i"_)->getInt(), 0); ASSERT_EQ(simplifyExpr("i % 1"_)->getInt(), 0); // -(-a) -> a - TORCH_CHECK(simplifyExpr("- - i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("~ ~ i"_)->sameAs("i"_)); - TORCH_CHECK(simplifyExpr("! ! b"_)->sameAs("b"_)); + ASSERT_TRUE(simplifyExpr("- - i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("~ ~ i"_)->sameAs("i"_)); + ASSERT_TRUE(simplifyExpr("! ! b"_)->sameAs("b"_)); // Test constant folding - TORCH_CHECK(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); - TORCH_CHECK(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); + ASSERT_TRUE(simplifyExpr("1 + i + 1"_)->sameAs("i + 2"_)); + ASSERT_TRUE(simplifyExpr("1.0 + d + 1.0"_)->sameAs("d + 2.0"_)); // Test that FlattenedAssocCommOp::sameAs ignores order - TORCH_CHECK(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); + ASSERT_TRUE(simplifyExpr("( i1 * i2 ) - ( i2 * i1 )"_)->isZeroInt()); } TEST_F(ExprSimplifierTest, SimplifyDivisibleDivMod_CUDA) { @@ -773,6 +855,13 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) { ASSERT_TRUE(*simplify("i1 >= i1 * i2"_, "i1 <= 0 && i2 > 0"_)); ASSERT_TRUE(*simplify("d1 <= d1 * d2"_, "d1 >= 0.0 && d2 >= 1.0"_)); ASSERT_TRUE(*simplify("d1 >= d1 * d2"_, "d1 <= 0.0 && d2 >= 1.0"_)); + ASSERT_TRUE( + *simplifyExpr( + "ceilDiv( T0.size[0] , 128 ) * 4 >= ceilDiv( T0.size[0] , 128 )"_) + ->getBool()); + + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_)); + ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_)); } TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty_CUDA) { @@ -956,4 +1045,36 @@ TEST_F(ExprSimplifierTest, ReducePredicateRegisterUsage_CUDA) { } } +TEST_F(ExprSimplifierTest, MinMax_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto simplify = [](Val* x, Val* assumption) { + return simplifyExpr(x, {}, {assumption->as()}); + }; + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[0] , 128 ) ) , 4 )"_; + ASSERT_TRUE(simplify(expr, assume::tensorsAreNotEmpty(expr)) + ->sameAs("ceilDiv( T0.size[0] , 128 ) * 4"_)); +} + +TEST_F(ExprSimplifierTest, Assume_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto expr = + "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[1] , 128 ) ) , 4 )"_; + ASSERT_EQ( + simplifyExpr(IrBuilder::eqExpr( + assume::tensorsAreNotEmpty(expr), + "T0.size[0] > 0 && T0.size[1] > 0"_)) + ->getBool(), + true); + expr = "ceilDiv( T0.size[0] , T0.size[0] ) * T0.size[0]"_; + ASSERT_TRUE(assume::tensorsAreNotEmpty(expr)->sameAs("T0.size[0] > 0"_)); +} + } // namespace nvfuser diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index 43bf5f410c5..e8a40e3d54b 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i241; - i241 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i241 < T0.size[0])) { + int64_t i244; + i244 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i244 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i241]; + = T1[i244]; float T4[1]; T4[0] = 0; T4[0] - = T0[i241]; + = T0[i244]; float T2[1]; T2[0] = T4[0] @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i241] + T3[i244] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 93194c781cb..45ee2722023 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9029,27 +9029,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1409; - i1409 = T0.size[2] * T0.size[1]; - int64_t i1412; - i1412 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1414; - i1414 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1446; - i1446 = i1412 % i1414; - int64_t i1423; - i1423 = T0.size[2] * T0.size[3]; - int64_t i1447; - i1447 = i1446 % i1423; - if ((i1412 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1419; + i1419 = T0.size[2] * T0.size[1]; + int64_t i1422; + i1422 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1424; + i1424 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1456; + i1456 = i1422 % i1424; + int64_t i1433; + i1433 = T0.size[2] * T0.size[3]; + int64_t i1457; + i1457 = i1456 % i1433; + if ((i1422 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1409 * T0.size[3]) * (i1412 / i1414)) + (i1409 * (i1447 % T0.size[3]))) + (T0.size[2] * (i1446 / i1423))) + (i1447 / T0.size[3]))]; + = T2[(((((i1419 * T0.size[3]) * (i1422 / i1424)) + (i1419 * (i1457 % T0.size[3]))) + (T0.size[2] * (i1456 / i1433))) + (i1457 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1412]; + = T0[i1422]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9069,7 +9069,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1412] + T7[i1422] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index da88ba59b9d..64845d82971 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1749,21 +1749,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - int64_t i194; - i194 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i197; + i197 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b324; - b324 = i194 < i7; + bool b327; + b327 = i197 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b324) { + if (b327) { T1[0] - = sinf(T0[i194]); + = sinf(T0[i197]); } - if (b324) { - T2[i194] + if (b327) { + T2[i197] = T1[0] + f8; }