5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
@@ -549,6 +549,11 @@ TVM_DLL const Op& vectorlow();
*/
TVM_DLL const Op& vectorcombine();

/*!
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
*/
TVM_DLL const Op& atomic_add();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
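
A note on usage: the new intrinsic is invoked from TIR as an ordinary call_intrin on an address. A minimal sketch, mirroring the atomic_add helper kept in python/tvm/topi/cuda/nms.py further down in this diff:

    import tvm

    def atomic_add(x, y):
        # x is a handle (the address of the destination), y is the value to add;
        # the call carries y's dtype, and the underlying atomic returns the previous value.
        return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)

    # Inside an ir_builder body the address is typically produced with tir.address_of, e.g.
    #   atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one)
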
2 changes: 0 additions & 2 deletions python/tvm/topi/cuda/nms.py
@@ -47,8 +47,6 @@ def opencl_atomic_add_rule(op):
"opencl", "atomic_add", opencl_atomic_add_rule, override=True
)

tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)


def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
19 changes: 19 additions & 0 deletions src/target/llvm/codegen_amdgpu.cc
@@ -183,6 +183,25 @@ class CodeGenAMDGPU : public CodeGenLLVM {

unsigned GetGlobalAddressSpace() const final { return 1; }

llvm::Value* CreateIntrinsic(const CallNode* op) final {
if (op->op.same_as(builtin::atomic_add())) {
ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
if (op->args[1]->dtype.is_float()) {
#if TVM_LLVM_VERSION >= 90
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
llvm::AtomicOrdering::Monotonic);
#else
LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
}
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
llvm::AtomicOrdering::Monotonic);
}
return CodeGenLLVM::CreateIntrinsic(op);
}

protected:
void InitTarget(llvm::TargetMachine* tm) final {
// Maximum vector lane = float4
4 changes: 4 additions & 0 deletions src/target/llvm/codegen_llvm.cc
@@ -955,6 +955,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else if (op->op.same_as(builtin::atomic_add())) {
// TODO(masahi): Support atomic for CPU backend
LOG(FATAL) << "CPU backend does not support atomic add yet.";
return nullptr;
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
return nullptr;
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_nvptx.cc
@@ -232,6 +232,20 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->op.same_as(builtin::atomic_add())) {
ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now";
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
if (op->args[1]->dtype.is_float()) {
#if TVM_LLVM_VERSION >= 90
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1,
llvm::AtomicOrdering::Monotonic);
#else
LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer";
#endif
}
return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::Add, v0, v1,
llvm::AtomicOrdering::Monotonic);
}
return CodeGenLLVM::CreateIntrinsic(op);
}
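
Both GPU backends lower tir.atomic_add to an LLVM atomicrmw instruction with monotonic ordering: an integer value argument selects AtomicRMWInst::Add, a float value selects AtomicRMWInst::FAdd (only available from LLVM 9 on), and anything other than 32 bits is rejected. A rough sketch of the two TIR calls and the instructions they are expected to lower to (dst is a placeholder handle here; the end-to-end version is test_llvm_gpu_lower_atomic near the bottom of this diff):

    import tvm

    # Placeholder destination address; in real TIR this comes from
    # tir.address_of on a buffer element inside an ir_builder body.
    dst = tvm.tir.Var("dst", "handle")

    i32_add = tvm.tir.call_intrin("int32", "tir.atomic_add", dst, tvm.tir.const(1, "int32"))
    # expected lowering: atomicrmw add i32* %dst, i32 1 monotonic

    f32_add = tvm.tir.call_intrin("float32", "tir.atomic_add", dst, tvm.tir.const(1.0, "float32"))
    # expected lowering: atomicrmw fadd float* %dst, float 1.000000e+00 monotonic (requires LLVM >= 9)
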
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
@@ -229,6 +229,9 @@ TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr<TCallEffectKind>("TCallEffectKind",
TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_BUILTIN_FUNC(atomic_add)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace builtin
} // namespace tir
} // namespace tvm
21 changes: 12 additions & 9 deletions tests/python/relay/test_op_level3.py
@@ -1002,30 +1002,33 @@ def ref_scatter_add(data, indices, updates, axis=0):
output[tuple(new_index)] += updates[index]
return output

def verify_scatter_add(dshape, ishape, axis=0):
d = relay.var("d", relay.TensorType(dshape, "float32"))
def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"):
d = relay.var("d", relay.TensorType(dshape, dtype))
i = relay.var("i", relay.TensorType(ishape, "int64"))
u = relay.var("u", relay.TensorType(ishape, "float32"))
u = relay.var("u", relay.TensorType(ishape, dtype))
z = relay.op.scatter_add(d, i, u, axis)

func = relay.Function([d, i, u], z)

data_np = np.random.uniform(size=dshape).astype("float32")
updates_np = np.random.uniform(size=ishape).astype("float32")
data_np = np.random.uniform(size=dshape).astype(dtype)
updates_np = np.random.uniform(size=ishape).astype(dtype)
indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")

ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
if target == "nvptx":
# TODO(masahi): support atomic in LLVM codegen
if target == "nvptx" and dtype == "float32" and len(dshape) == 1:
# scatter_add 1D on GPU is implemented via atomic.
# Floating point atomic requires LLVM 9 or newer for nvptx backend.
# But LLVM on CI is LLVM 8.
continue
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_scatter_add((10,), (10,), 0)
verify_scatter_add((1000,), (1000,), 0)
verify_scatter_add((10,), (10,), 0, dtype="int32")
verify_scatter_add((1000,), (1000,))
verify_scatter_add((1000,), (1000,), 0, dtype="int32")
verify_scatter_add((10, 5), (10, 5), -2)
verify_scatter_add((10, 5), (10, 5), -1)
verify_scatter_add((10, 5), (3, 5), 0)
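
For the 1-D cases added above, scatter_add on the GPU accumulates the updates into data with the new atomic_add intrinsic, which is why float32 1-D inputs are skipped for nvptx until CI moves from LLVM 8 to LLVM 9. For that case the reference semantics reduce to numpy's np.add.at; a small illustration with made-up values:

    import numpy as np

    data = np.zeros(5, dtype="int32")
    indices = np.array([1, 3, 1, 0], dtype="int64")
    updates = np.array([10, 20, 30, 40], dtype="int32")

    out = data.copy()
    np.add.at(out, indices, updates)  # duplicate indices accumulate
    # out == [40, 40, 0, 20, 0]
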
2 changes: 0 additions & 2 deletions tests/python/relay/test_op_level5.py
@@ -393,8 +393,6 @@ def verify_nms(
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
if target == "nvptx":
continue
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
10 changes: 3 additions & 7 deletions tests/python/topi/python/test_topi_vision.py
@@ -202,15 +202,11 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
if device in ["llvm", "cuda"]:
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
else:
f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "nvptx"]:
check_device(device)


68 changes: 67 additions & 1 deletion tests/python/unittest/test_target_codegen_llvm.py
@@ -74,7 +74,6 @@ def use_llvm_intrinsic(A, C):
C = tvm.te.extern(
(1, 1), [A], lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), name="C", dtype="int32"
)

s = tvm.te.create_schedule(C.op)
f = tvm.build(s, [A, C], target="llvm")

@@ -750,6 +749,72 @@ def test_llvm_crt_static_lib():
module.save("test.o")


def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


@tvm.testing.requires_llvm
def test_llvm_lower_atomic():
def do_atomic_add(A):
ib = tvm.tir.ir_builder.create()
n = A.shape[0]
atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local")
one = tvm.tir.const(1, A.dtype)
A_ptr = ib.buffer_ptr(A)
with ib.for_range(0, n, name="i", for_type="parallel") as i:
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one
)
return ib.get()

A = tvm.te.placeholder((100,), dtype="int32", name="A")
C = tvm.te.extern((100,), [A], lambda ins, _: do_atomic_add(ins[0]), name="C", dtype="int32")
s = tvm.te.create_schedule(C.op)
# This does not work because of pointer type mismatch
# TVMError: LLVM module verification failed with the following errors:
# Argument value type does not match pointer operand type!
# %21 = atomicrmw add i8* %7, i32 1 monotonic
# i8
# f = tvm.build(s, [A], target="llvm")


@tvm.testing.requires_llvm
@tvm.testing.requires_gpu
def test_llvm_gpu_lower_atomic():
def do_atomic_add(A):
ib = tvm.tir.ir_builder.create()
n = A.shape[0]
atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local")
one = tvm.tir.const(1, A.dtype)
A_ptr = ib.buffer_ptr(A)
nthread_tx = 64
with ib.new_scope():
nthread_bx = (n + nthread_tx - 1) // nthread_tx
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one
)
return ib.get()

size = 1024
# CI uses LLVM 8, which does not support float atomic
for dtype in ["int32"]:
A = tvm.te.placeholder((size,), dtype=dtype, name="A")
C = tvm.te.extern((size,), [A], lambda ins, _: do_atomic_add(ins[0]), dtype=dtype)
s = tvm.te.create_schedule(C.op)
f = tvm.build(s, [A], target="nvptx")

ctx = tvm.gpu()
a = tvm.nd.array(np.zeros((size,)).astype(A.dtype), ctx)
f(a)
ref = np.zeros((size,)).astype(A.dtype)
ref[0] = size
tvm.testing.assert_allclose(a.asnumpy(), ref, rtol=1e-5)


if __name__ == "__main__":
test_multiple_func()
test_llvm_large_uintimm()
@@ -774,3 +839,4 @@ def test_llvm_crt_static_lib():
test_llvm_shuffle()
test_llvm_bf16()
test_llvm_crt_static_lib()
test_llvm_gpu_lower_atomic()