16 changes: 16 additions & 0 deletions python/tvm/intrin.py
@@ -376,6 +376,22 @@ def popcount(x):
"""
return call_pure_intrin(x.dtype, "popcount", x)

def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.

Parameters
----------
x : Expr
Input argument.
y : Expr
Input argument.

Returns
-------
z : Expr
The result.
"""
return call_pure_intrin(x.dtype, "fmod", x, y)

# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
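For reference, a minimal sketch of how the new `tvm.fmod` intrinsic can be used from the Python API, mirroring the integration test added further down. It assumes a CUDA-enabled build; the split factor and the kernel name `fmod_example` are arbitrary choices, not part of this PR.

```python
import tvm

# Element-wise floating-point remainder via the new intrinsic.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.placeholder((n,), name="B", dtype="float32")
C = tvm.compute(A.shape, lambda i: tvm.fmod(A[i], B[i]), name="C")

# Simple GPU schedule so the expression is lowered through the CUDA rule below.
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B, C], "cuda", name="fmod_example")
```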
2 changes: 2 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
@@ -91,6 +91,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);

} // namespace intrin
} // namespace codegen
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_metal.cc
@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
.set_body(DispatchExtern<Direct>);

} // namespace intrin
} // namespace codegen
} // namespace tvm
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
.set_body(DispatchExtern<Direct>);

// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle {
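The three registrations above dispatch `fmod` straight to the target's math library. If another backend needed a custom lowering, a rule can also be registered from Python via `register_intrin_rule` (shown in the intrin.py hunk above). The following is only a sketch under that assumption; `mytarget` is a hypothetical target name and the `fmodf`/`fmod` mapping is illustrative, analogous to what `DispatchExtern` does in C++.

```python
import tvm

def fmod_rule(op):
    # op is the Call node produced by tvm.fmod; lower it to an extern
    # math-library call (hypothetical mapping for a hypothetical target).
    assert isinstance(op, tvm.expr.Call)
    name = "fmodf" if op.dtype == "float32" else "fmod"
    return tvm.call_pure_extern(op.dtype, name, op.args[0], op.args[1])

tvm.register_intrin_rule("mytarget", "fmod", fmod_rule, override=True)
```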
6 changes: 6 additions & 0 deletions src/lang/ir_operator.cc
@@ -450,4 +450,10 @@ Expr prod(Expr source, Array<IterVar> rdom) {
  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr fmod(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.type().is_float()) << "fmod only applies to float";
  return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
}

} // namespace tvm
40 changes: 40 additions & 0 deletions tests/python/integration/test_ewise.py
@@ -38,6 +38,45 @@ def check_device(device, host="stackvm"):
    check_device("cuda", "llvm")
    check_device("vulkan")

def test_fmod():
    # graph
    def run(dtype):
        n = tvm.var('n')
        A = tvm.placeholder((n,), name='A', dtype=dtype)
        B = tvm.placeholder((n,), name='B', dtype=dtype)
        C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C')
        s = tvm.create_schedule(C.op)
        # create iter var and assign them tags.
        num_thread = 8
        bx, tx = s[C].split(C.op.axis[0], factor=num_thread)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            target = tvm.target.create(device)
            if "cpu" not in target.keys:
                s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
                s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
            fmod = tvm.build(s, [A, B, C], device, name="myfmod")

            # launch the kernel.
            n = 1024
            a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
            b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx)
            c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
            ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1)
            tcost = ftimer(a, b, c).mean
            #fmod(a, b, c)
            np.testing.assert_allclose(
                c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5)

        check_device("cuda")
        check_device("opencl -device=intel_graphics")
        check_device("metal")

    run("float32")

def test_multiple_cache_write():
    # graph
@@ -245,3 +284,4 @@ def check_device(device):
    test_add()
    test_log_pow_llvm()
    test_popcount()
    test_fmod()