diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index 91318ff7bfc9..8c0eb037d953 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -261,6 +261,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
     if (in_warp_seq.size() != 0) {
       Stmt warp_body = MergeSeq(in_warp_seq);
       seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+      seq.emplace_back(SyncThread("shared"));
     }
     return MergeSeq(seq);
   }
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index de990e086ac8..d57c9e10fe0e 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, unused-variable, trailing-whitespace
+# pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 import tvm
 
@@ -34,7 +34,7 @@ def schedule_softmax(outs):
     s[expsum].bind(s[expsum].op.axis[0], block_x)
     s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
     s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
-
+    s[expsum].set_store_predicate(thread_x.var.equal(0))
     tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
     s[softmax].bind(softmax.op.axis[0], block_x)
     s[softmax].bind(tx, thread_x)
diff --git a/tutorials/language/reduction.py b/tutorials/language/reduction.py
index f27139c3a158..531283e15213 100644
--- a/tutorials/language/reduction.py
+++ b/tutorials/language/reduction.py
@@ -108,8 +108,10 @@
 xo, xi = s[B].split(s[B].op.axis[0], factor=32)
 s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
 s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
-s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
+tx = tvm.thread_axis("threadIdx.x")
+s[B].bind(s[B].op.reduce_axis[0], tx)
 s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+s[B].set_store_predicate(tx.var.equal(0))
 fcuda = tvm.build(s, [A, B], "cuda")
 print(fcuda.imported_modules[0].get_source())
 
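
For context, a minimal standalone sketch of the row-sum reduction schedule that the tutorial hunk above modifies, assuming the pre-0.7 tvm.* API used in tutorials/language/reduction.py; the placeholder shapes and split factors here are illustrative, not taken from the patch. Binding the reduce axis to threadIdx.x turns the inner reduction into a cross-thread allreduce, and set_store_predicate(tx.var.equal(0)) restricts the final store of B to thread 0 so the reduced value is written exactly once.

import tvm

# Row-sum reduction: B[i] = sum_k A[i, k]
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

s = tvm.create_schedule(B.op)
# Factor the reduction so the inner iterations become a cross-thread reduce.
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)

xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# Only threadIdx.x == 0 holds the final reduced value, so store it once.
s[B].set_store_predicate(tx.var.equal(0))

fcuda = tvm.build(s, [A, B], "cuda")
print(fcuda.imported_modules[0].get_source())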