Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions csrc/expr_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ std::unique_ptr<debug_print::NoOpLogger> createLogger(Val* value) {

} // namespace debug_print

namespace assoc_comm {
Val* flatten(Val* value);
} // namespace assoc_comm

namespace {

std::vector<Bool*> getAxioms() {
Expand Down Expand Up @@ -222,16 +226,20 @@ class Context {
if (auto bop = dynamic_cast<BinaryOp*>(def)) {
switch (bop->getBinaryOpType()) {
case BinaryOpType::LT:
less_than_.emplace_back(bop->lhs(), bop->rhs());
less_than_.emplace_back(
assoc_comm::flatten(bop->lhs()), assoc_comm::flatten(bop->rhs()));
break;
case BinaryOpType::LE:
less_equal_.emplace_back(bop->lhs(), bop->rhs());
less_equal_.emplace_back(
assoc_comm::flatten(bop->lhs()), assoc_comm::flatten(bop->rhs()));
break;
case BinaryOpType::GT:
less_than_.emplace_back(bop->rhs(), bop->lhs());
less_than_.emplace_back(
assoc_comm::flatten(bop->rhs()), assoc_comm::flatten(bop->lhs()));
break;
case BinaryOpType::GE:
less_equal_.emplace_back(bop->rhs(), bop->lhs());
less_equal_.emplace_back(
assoc_comm::flatten(bop->rhs()), assoc_comm::flatten(bop->lhs()));
break;
default:
TORCH_INTERNAL_ASSERT(
Expand Down
4 changes: 4 additions & 0 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ class TORCH_CUDA_CU_API Predicate final : public Val {
return hasValue() && value_->isConst();
}

// Returns true iff this predicate is a compile-time constant that evaluates
// to `true`, i.e. it is guaranteed to pass and generates no real guard.
// NOTE(review): `getBool()` presumably returns an optional-like value, so the
// explicit `== true` both unwraps and compares — do not simplify it away.
bool isTrivial() const {
return isConst() && value_->getBool() == true;
}

private:
PredicateType ptype_ = PredicateType::Manual;

Expand Down
7 changes: 7 additions & 0 deletions csrc/lower_scalar_hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ std::list<VarInfo> getVariableInfo(

std::vector<Bool*> getAssumptions(const std::vector<kir::ForLoop*>& loops) {
std::vector<Bool*> assumptions;
// assumptions from parallel dimension
for (auto [p, extent] :
GpuLower::current()->parallelDimensionMap().getMap()) {
auto a = IrBuilder::ltExpr(NamedScalar::getParallelIndex(p), extent);
assumptions.emplace_back(a);
}
// assumptions from loop nesting
for (auto loop : loops) {
// Trivial loop is not generated, so there is no `if` or `for` in C++ to
// guard its scope. So we should not assume index < stop. One real example
Expand Down
9 changes: 9 additions & 0 deletions csrc/predicate_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ ParallelizedDomainPredicate::getPredicateMap(
gpu_lower->parallelDimensionMap().isExact(loop_ptype)) {
continue;
}
auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype);

// Parallel dimensions need not be predicated if fully unswitched.
if (within_unswitch &&
Expand Down Expand Up @@ -201,6 +202,14 @@ ParallelizedDomainPredicate::getPredicateMap(
continue;
}

// loop_ptype not being exact does not mean the predicate is not trivial.
// For example, if I have T1[blockIdx.x{3}] and T2[blockIdx.x{5}], then
// blockIdx.x will not be exact. However, the predicate blockIdx.x < 5 is
// still trivial.
if (tv_id->extent()->sameAs(parallel_dim)) {
continue;
}

// tv_id needs to be predicated. Adds it to the PredicateInfo map.
auto& info = map.at(loop_ptype);
info.addDomain(tv_id);
Expand Down
4 changes: 4 additions & 0 deletions test/test_expr_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,10 @@ TEST_F(ExprSimplifierTest, Compare_CUDA) {

ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) > 0"_, "i1 > 0 && i2 > 0"_));
ASSERT_TRUE(*simplify("ceilDiv( i1 , i2 ) >= 1"_, "i1 > 0 && i2 > 0"_));

ASSERT_TRUE(*simplify(
"blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_,
"blockIdx.x < ceilDiv( T0.size[0] , 128 ) * 4"_));
}

TEST_F(ExprSimplifierTest, FundamentalDivisionWithRemainderProperty_CUDA) {
Expand Down
12 changes: 6 additions & 6 deletions test/test_gpu1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
int64_t i244;
i244 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i244 < T0.size[0])) {
int64_t i248;
i248 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i248 < T0.size[0])) {
float T5[1];
T5[0] = 0;
T5[0]
= T1[i244];
= T1[i248];
float T4[1];
T4[0] = 0;
T4[0]
= T0[i244];
= T0[i248];
float T2[1];
T2[0]
= T4[0]
Expand All @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Te
T6[0]
= T2[0]
* T4[0];
T3[i244]
T3[i248]
= T6[0];
}
}
Expand Down
32 changes: 16 additions & 16 deletions test/test_gpu2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9029,27 +9029,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
int64_t i1419;
i1419 = T0.size[2] * T0.size[1];
int64_t i1422;
i1422 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
int64_t i1424;
i1424 = (T0.size[1] * T0.size[2]) * T0.size[3];
int64_t i1456;
i1456 = i1422 % i1424;
int64_t i1433;
i1433 = T0.size[2] * T0.size[3];
int64_t i1457;
i1457 = i1456 % i1433;
if ((i1422 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) {
int64_t i1435;
i1435 = T0.size[2] * T0.size[1];
int64_t i1438;
i1438 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
int64_t i1440;
i1440 = (T0.size[1] * T0.size[2]) * T0.size[3];
int64_t i1472;
i1472 = i1438 % i1440;
int64_t i1449;
i1449 = T0.size[2] * T0.size[3];
int64_t i1473;
i1473 = i1472 % i1449;
if ((i1438 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) {
__half T9[1];
T9[0] = 0;
T9[0]
= T2[(((((i1419 * T0.size[3]) * (i1422 / i1424)) + (i1419 * (i1457 % T0.size[3]))) + (T0.size[2] * (i1456 / i1433))) + (i1457 / T0.size[3]))];
= T2[(((((i1435 * T0.size[3]) * (i1438 / i1440)) + (i1435 * (i1473 % T0.size[3]))) + (T0.size[2] * (i1472 / i1449))) + (i1473 / T0.size[3]))];
__half T8[1];
T8[0] = 0;
T8[0]
= T0[i1422];
= T0[i1438];
float T3[1];
T3[0]
= __half2float(T9[0]);
Expand All @@ -9069,7 +9069,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
__half T10[1];
T10[0]
= __float2half(T6[0]);
T7[i1422]
T7[i1438]
= T10[0];
}
}
Expand Down
16 changes: 8 additions & 8 deletions test/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1749,21 +1749,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) {

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T2) {
int64_t i197;
i197 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x));
int64_t i201;
i201 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x));
int64_t i7;
i7 = T0.size[0] * T0.size[1];
bool b327;
b327 = i197 < i7;
bool b347;
b347 = i201 < i7;
float f8;
f8 = (float)(i7);
float T1[1];
if (b327) {
if (b347) {
T1[0]
= sinf(T0[i197]);
= sinf(T0[i201]);
}
if (b327) {
T2[i197]
if (b347) {
T2[i201]
= T1[0]
+ f8;
}
Expand Down
25 changes: 25 additions & 0 deletions test/test_gpu_tensorcore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,31 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) {
TORCH_CHECK(gdimy == expected_gdimy);

runtime = fe.kernelTimeMs();

// Check that mma op is not predicated. This is a regression test for
// https://github.com/NVIDIA/Fuser/issues/95
// Visitor that walks the lowered kernel IR and verifies that every MmaOp it
// finds is either unpredicated or guarded only by a trivial (always-true)
// predicate. Used as a regression check for
// https://github.com/NVIDIA/Fuser/issues/95.
class PredicateChecker : public kir::IrVisitor {
 public:
  using kir::IrVisitor::handle;
  // Set to true once at least one MmaOp has been visited, so the caller can
  // assert the check actually ran.
  bool found_mma = false;

 private:
  void handle(MmaOp* mma) final {
    found_mma = true;
    // scope_exprs_ holds the stack of enclosing scoped expressions; any
    // IfThenElse on that stack predicates this MmaOp.
    for (auto expr : scope_exprs_) {
      TORCH_CHECK(
          !expr->isA<kir::IfThenElse>() ||
              expr->as<kir::IfThenElse>()->predicate()->isTrivial(),
          "MmaOp shouldn't be predicated!",
          " Get predicate ",
          expr->as<kir::IfThenElse>()->predicate()->toInlineString());
    }
  }
} pred_checker;

GpuLower gpulw(&fusion);
pred_checker.handle(gpulw.kernel()->topLevelExprs());
ASSERT_TRUE(pred_checker.found_mma);
};

// Checking only a single layout to keep runtime short (compilation overhead)
Expand Down
30 changes: 15 additions & 15 deletions test/test_loop_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ TEST_F(LoopRotationTest, NonDivisibleSplit_CUDA) {
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO
int64_t i1511;
i1511 = T0.size[0] * T0.size[1];
int64_t i1529;
i1529 = T0.size[0] * T0.size[1];
float T1[5];
float T2[5];
#pragma unroll
Expand All @@ -219,7 +219,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
for(nvfuser_index_t i36 = 0; i36 < 5; ++i36) {
int64_t i154;
i154 = i36 + nvfuser_zero;
if ((i154 < i1511)) {
if ((i154 < i1529)) {
T1[i36]
= T0[((T0.stride[0] * (i154 / T0.size[1])) + (T0.stride[1] * (i154 % T0.size[1])))];
}
Expand All @@ -233,10 +233,10 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll 1
for(nvfuser_index_t i39 = 0; i39 < (ceilDiv((T0.size[0] * T0.size[1]), 5)); ++i39) {
int64_t i628;
i628 = 5 * i39;
int64_t i1218;
i1218 = 5 + i628;
int64_t i636;
i636 = 5 * i39;
int64_t i1230;
i1230 = 5 + i636;
// Alias Allocation - register
auto& T3 = T1;
#pragma unroll
Expand All @@ -247,10 +247,10 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i40 = 0; i40 < 5; ++i40) {
int64_t i629;
i629 = i628 + (i40 + nvfuser_zero);
if ((i629 < i1511)) {
T4[i629]
int64_t i637;
i637 = i636 + (i40 + nvfuser_zero);
if ((i637 < i1529)) {
T4[i637]
= T3[i40];
}
}
Expand All @@ -262,11 +262,11 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T4) {
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i36 = 0; i36 < 5; ++i36) {
int64_t i1219;
i1219 = i1218 + (i36 + nvfuser_zero);
if ((i1219 < i1511)) {
int64_t i1231;
i1231 = i1230 + (i36 + nvfuser_zero);
if ((i1231 < i1529)) {
T1[i36]
= T0[((T0.stride[0] * (i1219 / T0.size[1])) + (T0.stride[1] * (i1219 % T0.size[1])))];
= T0[((T0.stride[0] * (i1231 / T0.size[1])) + (T0.stride[1] * (i1231 % T0.size[1])))];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
Expand Down