diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index bea53136fd54..a150595ab551 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -549,6 +549,11 @@ TVM_DLL const Op& vectorlow();
  */
 TVM_DLL const Op& vectorcombine();
 
+/*!
+ * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
+ */
+TVM_DLL const Op& atomic_add();
+
 /*! \brief The kind of structure field info used in intrinsic */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 46d7f9800c43..d51eb5ce1d11 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -47,8 +47,6 @@ def opencl_atomic_add_rule(op):
     "opencl", "atomic_add", opencl_atomic_add_rule, override=True
 )
 
-tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
-
 
 def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc
index 2890c1ce3e56..605870f48c52 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -183,6 +183,25 @@ class CodeGenAMDGPU : public CodeGenLLVM {
 
   unsigned GetGlobalAddressSpace() const final { return 1; }
 
+  llvm::Value* CreateIntrinsic(const CallNode* op) final {
+    if (op->op.same_as(builtin::atomic_add())) {
+      ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
+      llvm::Value* v0 = MakeValue(op->args[0]);
+      llvm::Value* v1 = MakeValue(op->args[1]);
+      if (op->args[1]->dtype.is_float()) {
+#if TVM_LLVM_VERSION >= 90
+        return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
+                                         llvm::AtomicOrdering::Monotonic);
+#else
+        LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
+#endif
+      }
+      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
+                                       llvm::AtomicOrdering::Monotonic);
+    }
+    return CodeGenLLVM::CreateIntrinsic(op);
+  }
+
  protected:
   void InitTarget(llvm::TargetMachine* tm) final {
     // Maximum vector lane = float4
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index d10ed311949c..70f094a186e7 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -955,6 +955,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
       indices.push_back(i);
     }
     return builder_->CreateShuffleVector(v0, v1, indices);
+  } else if (op->op.same_as(builtin::atomic_add())) {
+    // TODO(masahi): Support atomic for CPU backend
+    LOG(FATAL) << "CPU backend does not support atomic add yet.";
+    return nullptr;
   } else {
     LOG(FATAL) << "unknown intrinsic " << op->op;
     return nullptr;
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 22e612b11090..d8002a2b58a6 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -232,6 +232,20 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
     auto fty = llvm::FunctionType::get(t_int32_, false);
     auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
     return builder_->CreateCall(val);
+  } else if (op->op.same_as(builtin::atomic_add())) {
+    ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
+    llvm::Value* v0 = MakeValue(op->args[0]);
+    llvm::Value* v1 = MakeValue(op->args[1]);
+    if (op->args[1]->dtype.is_float()) {
+#if TVM_LLVM_VERSION >= 90
+      return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
+                                       llvm::AtomicOrdering::Monotonic);
+#else
+      LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
+#endif
+    }
+    return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
+                                     llvm::AtomicOrdering::Monotonic);
   }
   return CodeGenLLVM::CreateIntrinsic(op);
 }
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 3afb8810e774..796b113a4054 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -229,6 +229,9 @@ TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr<TCallEffectKind>("TCallEffectKind",
 TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TIR_DEFINE_BUILTIN_FUNC(atomic_add)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fc1929e9dc18..668285dfb882 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1002,30 +1002,33 @@ def ref_scatter_add(data, indices, updates, axis=0):
             output[tuple(new_index)] += updates[index]
         return output
 
-    def verify_scatter_add(dshape, ishape, axis=0):
-        d = relay.var("d", relay.TensorType(dshape, "float32"))
+    def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"):
+        d = relay.var("d", relay.TensorType(dshape, dtype))
         i = relay.var("i", relay.TensorType(ishape, "int64"))
-        u = relay.var("u", relay.TensorType(ishape, "float32"))
+        u = relay.var("u", relay.TensorType(ishape, dtype))
         z = relay.op.scatter_add(d, i, u, axis)
 
         func = relay.Function([d, i, u], z)
 
-        data_np = np.random.uniform(size=dshape).astype("float32")
-        updates_np = np.random.uniform(size=ishape).astype("float32")
+        data_np = np.random.uniform(size=dshape).astype(dtype)
+        updates_np = np.random.uniform(size=ishape).astype(dtype)
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
         for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
-                if target == "nvptx":
-                    # TODO(masahi): support atomic in LLVM codegen
+                if target == "nvptx" and dtype == "float32" and len(dshape) == 1:
+                    # scatter_add 1D on GPU is implemented via atomic.
+                    # Floating point atomic requires LLVM 9 or newer for nvptx backend.
+                    # But LLVM on CI is LLVM 8.
                     continue
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-    verify_scatter_add((10,), (10,), 0)
-    verify_scatter_add((1000,), (1000,), 0)
+    verify_scatter_add((10,), (10,), 0, dtype="int32")
+    verify_scatter_add((1000,), (1000,))
+    verify_scatter_add((1000,), (1000,), 0, dtype="int32")
     verify_scatter_add((10, 5), (10, 5), -2)
     verify_scatter_add((10, 5), (10, 5), -1)
     verify_scatter_add((10, 5), (3, 5), 0)
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 9e9aaf842669..f114957f3cab 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -393,8 +393,6 @@ def verify_nms(
         intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
         op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
-        if target == "nvptx":
-            continue
         op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
         tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
         op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py
index 6d6353eebce6..778843be37de 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -202,15 +202,11 @@ def check_device(device):
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
 
         tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
-        if device in ["llvm", "cuda"]:
-            f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
-            f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
-        else:
-            f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device)
-            f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
+        f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
+        f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
         tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "nvptx"]:
         check_device(device)
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 162481bfdb6e..4b67752367db 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -74,7 +74,6 @@ def use_llvm_intrinsic(A, C):
     C = tvm.te.extern(
         (1, 1), [A], lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), name="C", dtype="int32"
     )
-
     s = tvm.te.create_schedule(C.op)
     f = tvm.build(s, [A, C], target="llvm")
 
@@ -750,6 +749,72 @@ def test_llvm_crt_static_lib():
     module.save("test.o")
 
 
+def atomic_add(x, y):
+    return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
+
+
+@tvm.testing.requires_llvm
+def test_llvm_lower_atomic():
+    def do_atomic_add(A):
+        ib = tvm.tir.ir_builder.create()
+        n = A.shape[0]
+        atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local")
+        one = tvm.tir.const(1, A.dtype)
+        A_ptr = ib.buffer_ptr(A)
+        with ib.for_range(0, n, name="i", for_type="parallel") as i:
+            atomic_add_return[0] = atomic_add(
+                tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one
+            )
+        return ib.get()
+
+    A = tvm.te.placeholder((100,), dtype="int32", name="A")
+    C = tvm.te.extern((100,), [A], lambda ins, _: do_atomic_add(ins[0]), name="C", dtype="int32")
+    s = tvm.te.create_schedule(C.op)
+    # This does not work because of pointer type mismatch
+    # TVMError: LLVM module verification failed with the following errors:
+    # Argument value type does not match pointer operand type!
+    #   %21 = atomicrmw add i8* %7, i32 1 monotonic
+    #   i8
+    # f = tvm.build(s, [A], target="llvm")
+
+
+@tvm.testing.requires_llvm
+@tvm.testing.requires_gpu
+def test_llvm_gpu_lower_atomic():
+    def do_atomic_add(A):
+        ib = tvm.tir.ir_builder.create()
+        n = A.shape[0]
+        atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local")
+        one = tvm.tir.const(1, A.dtype)
+        A_ptr = ib.buffer_ptr(A)
+        nthread_tx = 64
+        with ib.new_scope():
+            nthread_bx = (n + nthread_tx - 1) // nthread_tx
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            atomic_add_return[0] = atomic_add(
+                tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one
+            )
+        return ib.get()
+
+    size = 1024
+    # CI uses LLVM 8, which does not support float atomic
+    for dtype in ["int32"]:
+        A = tvm.te.placeholder((size,), dtype=dtype, name="A")
+        C = tvm.te.extern((size,), [A], lambda ins, _: do_atomic_add(ins[0]), dtype=dtype)
+        s = tvm.te.create_schedule(C.op)
+        f = tvm.build(s, [A], target="nvptx")
+
+        ctx = tvm.gpu()
+        a = tvm.nd.array(np.zeros((size,)).astype(A.dtype), ctx)
+        f(a)
+        ref = np.zeros((size,)).astype(A.dtype)
+        ref[0] = size
+        tvm.testing.assert_allclose(a.asnumpy(), ref, rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_multiple_func()
     test_llvm_large_uintimm()
@@ -774,3 +839,4 @@ def test_llvm_crt_static_lib():
     test_llvm_shuffle()
     test_llvm_bf16()
     test_llvm_crt_static_lib()
+    test_llvm_gpu_lower_atomic()
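
Usage sketch (not part of the patch; illustrative only): the new tir.atomic_add intrinsic is driven from the TIR ir_builder, essentially a condensed form of the test_llvm_gpu_lower_atomic test added above. The helper name atomic_inc_ir and the (64,) shape are assumptions for this example; float dtypes additionally need LLVM 9 or newer on the nvptx/rocm paths.

import tvm
from tvm import te


def atomic_inc_ir(data):
    # Every thread atomically adds 1 to element 0 of the buffer via tir.atomic_add,
    # which CodeGenNVPTX / CodeGenAMDGPU lower to an atomicrmw add with monotonic ordering.
    ib = tvm.tir.ir_builder.create()
    n = data.shape[0]
    ret = ib.allocate(data.dtype, (1,), name="atomic_add_return", scope="local")
    ptr = ib.buffer_ptr(data)
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", n)
        ret[0] = tvm.tir.call_intrin(
            data.dtype,
            "tir.atomic_add",
            tvm.tir.call_intrin("handle", "tir.address_of", ptr[0]),
            tvm.tir.const(1, data.dtype),
        )
    return ib.get()


A = te.placeholder((64,), dtype="int32", name="A")
C = te.extern((64,), [A], lambda ins, _: atomic_inc_ir(ins[0]), name="C", dtype="int32")
s = te.create_schedule(C.op)
# After running the built function on a zero-initialized 64-element buffer,
# element 0 holds 64, mirroring the assertion pattern in the test above.
f = tvm.build(s, [A], target="nvptx")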