Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions csrc/device_lower/pass/double_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,9 @@ class DoubleBufferFusionInspector : private IterVisitor {

// The epilogue loop is only created when the producer of a double
// buffer tensor is on smem, in which case it would otherwise require
// an additional predicate to guard buffer overruns. When it's on
// gmem, that isn't the case, so it does not need to create an
// epilogue loop.
// an additional predicate to guard buffer overruns.
bool requireEpilogue(const std::vector<Expr*>& exprs) {
return true;
return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
return expr->input(0)->as<TensorView>()->getMemoryType() ==
MemoryType::Shared;
Expand Down Expand Up @@ -214,13 +213,13 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
loop_type_ == DoubleBufferLoopStage::Main &&
requireEpilogue(double_buffer_load_exprs_)) {
stop = IrBuilder::subExpr(
double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal());
double_buffer_loop_->stop(),
IrBuilder::create<Val>(int64_t(stage_depth - 1), DataType::Index));
} else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
NVF_ERROR(requireEpilogue(double_buffer_load_exprs_));
start = IrBuilder::subExpr(
double_buffer_loop_->stop(),
SimplifyingIrBuilder::create<Val>(
int64_t(stage_depth - 1), DataType::Index));
IrBuilder::create<Val>(int64_t(stage_depth - 1), DataType::Index));
}

cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
Expand Down Expand Up @@ -596,6 +595,34 @@ class DoubleBufferInserter : private kir::ExprMutator {
loads,
DoubleBufferLoopStage::Epilog,
alloc_in_main);
if (has_cpasync) {
// Insert a cp.async.wait_group at the end of each epilogue loop
// iteration. This should require all but the remaining future iteration
// loads to be completed.
// for i in (N-(D-1))..N: // epilog
// for j in ...
// .. = x[(i%D)*S+j]
// cp.async.wait N-1-i; // Ensure all but the future loads are
// complete
// For D=4 this unrolls as
// for j in ...
// .. = x[((N-3)%4)*S+j]
// cp.async.wait 2;
// for j in ...
// .. = x[((N-2)%4)*S+j]
// cp.async.wait 1;
// for j in ...
// .. = x[((N-1)%4)*S+j]
// cp.async.wait 0;
auto cp_async_wait = IrBuilder::create<kir::AsyncWait>(
AsyncOpType::CpAsync,
SimplifyingIrBuilder::subExpr(
epilogue_loop->stop(),
SimplifyingIrBuilder::addExpr(
epilogue_loop->index(),
FusionGuard::getCurFusion()->oneVal())));
epilogue_loop->body().push_back(cp_async_wait);
}
registerInsertAfter(double_buffer_loop, epilogue_loop);
}
}
Expand Down
18 changes: 11 additions & 7 deletions csrc/device_lower/pass/double_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@
// iteration. However, the value loaded by the invalid load would not
// be used, so instead of adding the additional predicate, the Epilogue
// loop is replicated from the original loop, except for the load
// expression since it's not used. Note that this overrun does not
// happen when the producer is on gmem, so in that case, this
// additional replication is not done.
// expression since it's not used.
//
// When creating those three types of loops, additional care must be
// taken when multiple tensors are double buffered. When multiple
Expand Down Expand Up @@ -109,16 +107,16 @@
// if pred:
// x[i*S+j] = y[i, j];
//
// for i in 0..N: // main loop
// for i in 0..(N-(D-1)): // main loop
// for j in ...
// if pred:
// x[((i+D-1)%D)*S+j] = y[i+D-1, j];
// for j in ...
// .. = x[(i%D)*S+j]
//
// (Epilog omitted since this only makes sense in using
// cp.async, where producer will be in global mem and consumer will
// be in shared mem).
// for i in (N-(D-1))..N: // epilog
// for j in ...
// .. = x[(i%D)*S+j]
//
// The profitability of this optimization comes from extra tolerance
// of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]`
Expand Down Expand Up @@ -162,6 +160,12 @@
// ensure completion of its own async copies so
// would need to sync to this point to ensure
// completion of the whole tile.
//
// for i in (N-(D-1))..N: // epilog
// for j in ...
// // TODO: is a sync required here?
// .. = x[(i%D)*S+j]
// cp.async.wait N-1-i; // Ensure all but the future loads are complete

namespace nvfuser {

Expand Down
5 changes: 3 additions & 2 deletions csrc/device_lower/pass/inline_ptx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class LowerToInlinePtx : public kir::ExprMutator {

void handle(kir::AsyncWait* wait) override {
if (wait->asyncOpType() == AsyncOpType::CpAsync &&
wait->keepStages() == 0) {
wait->keepStages()->isConstInt() &&
wait->keepStages()->evaluate().as<int64_t>() == 0) {
// cp.async uses wait_all for zero keep stages, other instructions use a
// unified interface for all keep stages.
registerReplace(
Expand All @@ -48,7 +49,7 @@ class LowerToInlinePtx : public kir::ExprMutator {
IrBuilder::create<kir::Asm>(
wait->ptx(),
std::vector<Val*>{},
std::vector<Val*>{IrBuilder::create<Val>(wait->keepStages())},
std::vector<Val*>{wait->keepStages()},
kir::Asm::Options{/*volatile=*/true, /*memory=*/wait->memory()}));
}
}
Expand Down
16 changes: 13 additions & 3 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,19 +612,29 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(BlockSerializeRelease)
AsyncWait::AsyncWait(
IrBuilderPasskey passkey,
AsyncOpType async_op_type,
int64_t keep_stages)
Val* keep_stages)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
addDataAttribute(async_op_type);
addDataAttribute(keep_stages);
addAttribute(keep_stages);
}

AsyncWait::AsyncWait(
IrBuilderPasskey passkey,
AsyncOpType async_op_type,
int64_t keep_stages)
: AsyncWait(
passkey,
async_op_type,
IrBuilder::create<Val>(keep_stages, DataType::Int)) {}

std::string AsyncWait::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << ptx() << " " << keepStages() << "\n";
indent(ss, indent_size) << ptx() << " " << keepStages()->toInlineString()
<< "\n";
return ss.str();
}

Expand Down
9 changes: 7 additions & 2 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,11 @@ class AsyncWait final : public Expr {
public:
using Expr::Expr;

explicit AsyncWait(
IrBuilderPasskey passkey,
AsyncOpType async_op_type,
Val* keep_stages);

explicit AsyncWait(
IrBuilderPasskey passkey,
AsyncOpType async_op_type,
Expand All @@ -651,8 +656,8 @@ class AsyncWait final : public Expr {

//! Returns the remaining number of stages that are not synchronized
//! after this op.
int64_t keepStages() const {
return attribute<int64_t>(1);
Val* keepStages() const {
return attributeVal(1);
Comment on lines -654 to +660
Copy link
Collaborator Author

@jacobhinkle jacobhinkle Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I attempted to make keepStages() a Val*. The issue is that this must be a compile time constant argument for the inline asm instruction. I think it still might work if we are able to unroll the epilogue loop, but that might not always be preferable/acceptable. So instead, we could also have a runtime function/kir node that calls a runtime helper function that wraps cp.async.wait_group N for variable N and handles values up to say 5 or 6 inside a switch statement.

Copy link
Collaborator Author

@jacobhinkle jacobhinkle Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UPDATE: the switch statement helper approach seems to work and doesn't require unrolling the epilogue (we can't get away from this switch statement even with unrolling the epilogue), but it means we need to set an upper limit on the number of unsynched stages. We could set that to something high like 10. We only need it to be num_stages, i.e. we don't need a switch statement with 10 cases if we have only 3 circular buffering stages. However, the requirement to have a constant N in cp.async.wait_group, along with the requirement for inline asm to have string literal inputs has stumped me. I tried all kinds of combinations of templates and macros but got nowhere.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another alternative is to, instead of making the epilogue

for i in (N-(D-1))..N // epilogue
  cp.async.wait N-1-i
  for j in ...
    ... = x[((i%D)*S+j)]

we change it to

for i in 0..(D-1) // epilogue
  cp.async.wait D-2-i
  for j in ...
    ... = x[(((N-(D-1) + i)%D)*S+j)]

so that the epilogue is naturally unrolled.

And at the same time

Fuser/csrc/kernel_ir.cpp

Lines 321 to 323 in b108bca

if (in->isConst()) {
constraint = "n";
} else {

needs change to generate an n for PTX constraints

also, you will need update

Fuser/csrc/index_compute.cpp

Lines 2354 to 2400 in b108bca

if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
auto db_loop =
gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
auto stage_depth = (int64_t)gpu_lower->doubleBufferInfo().getStageDepthFor(
db_loop->iter_domain());
bool is_circular_buffer_loop = stage_depth > 2;
bool is_prolog =
db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;
Val* db_switch_index = nullptr;
// In double buffered we don't materialize the prolog loop as there will
// be only one iteration. In circular buffer case we materialize the
// prolog loop as well covering the first N-1 iterations, N being the
// stage depth.
if (!is_prolog || is_circular_buffer_loop) {
if (is_prolog && is_circular_buffer_loop) {
// The buffer switching logic is the same as original index
// in the case of circular buffer prolog.
db_switch_index = db_loop->indexOrStartIfTrivial();
if (rotated_loops.count(db_loop)) {
db_switch_index =
SimplifyingIrBuilder::addExpr(db_switch_index, db_loop->step());
}
} else {
auto loop_index = db_loop->indexOrStartIfTrivial();
if (rotated_loops.count(db_loop)) {
loop_index =
SimplifyingIrBuilder::addExpr(loop_index, db_loop->step());
}
// Switching index generated for main loop or epilog component.
db_switch_index = SimplifyingIrBuilder::modExpr(
SimplifyingIrBuilder::addExpr(
loop_index,
SimplifyingIrBuilder::create<Val>(
stage_depth - 1, DataType::Index)),
SimplifyingIrBuilder::create<Val>(stage_depth, DataType::Index));
}
// Use the generated switching buffer index to access the buffer space.
auto original_alloc_size =
gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
auto db_strided_index =
SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
strided_inds.push_back(db_strided_index);
}
}

to change loop index

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding correctly, you mean to use the unrolled loop index as the argument to wait_group. That fails: even if the loop variable is the actual argument, we get

CUDA NVRTC compile error: __tmp_kernel_none_f0_c0_r0_g0.cu(10249): error: an asm operand must be an integral constant expression.
      asm volatile("cp.async.wait_group %0;"::"n"(ii)); 

That is, in PTX it is seeing this as a non-constant argument. I have tried interpolating it into that string, but for inline assembly the command must be a string literal...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mmigdal-nv helped come up with this solution:

template <int num_stages>
__inline__ __device__ void cpAsyncPartialBarrier(int keep_stages) {
  if constexpr (num_stages < 0) {
    return;
  }
  if (keep_stages == num_stages) {
    asm volatile("cp.async.wait_group %0;"::"n"(num_stages));
  } else {
    cpAsyncPartialBarrier<num_stages - 1>(keep_stages);
  }
}

template <>
__inline__ __device__ void cpAsyncPartialBarrier<-1>(int keep_stages) {
}

...
  #pragma unroll
  for(nvfuser_index_t i13 = 12793; i13 < 12800LL; ++i13) {
    ...
    __syncthreads();
    cpAsyncPartialBarrier<8>((12800 - 2) - i13);
  }
  NVFUSER_UPDATE_MAGIC_ZERO;

We can replace ((12800 - 2) - i13) with the Val* we have currently. The compiler will evaluate the recursive template and prune the dead branches. I think the only downside to this is that we need to unroll the new loop, will probably hurt compilation time.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fun C++...

template<int i>
void f() {}

int main() {
    #pragma unroll
    for (int i = 0; i < 10; i++) {
        f<i>();
    }
    return 0;
}
<source>: In function 'int main()':
<source>:7:13: error: no matching function for call to 'f<i>()'
    7 |         f<i>();
      |         ~~~~^~
<source>:2:6: note: candidate: 'template<int i> void f()'
    2 | void f() {}
      |      ^
<source>:2:6: note:   template argument deduction/substitution failed:
<source>:7:13: error: the value of 'i' is not usable in a constant expression
    7 |         f<i>();
      |         ~~~~^~
<source>:6:14: note: 'int i' is not const
    6 |     for (int i = 0; i < 10; i++) {
      |              ^
<source>:7:13: note: in template argument for type 'int'
    7 |         f<i>();
      |         ~~~~^~

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I looked to see what CUTLASS does. It seems they do not peel off an epilogue loop. Instead they just wait_group 0 after the main loop. https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/gemm/threadblock/mma_multistage.h

}
};

Expand Down