From f95c51d8e6a3bb428c90301cd478fab81fb4ed83 Mon Sep 17 00:00:00 2001 From: xqdan Date: Tue, 23 Oct 2018 17:01:25 +0800 Subject: [PATCH 1/4] [intrin]support fmod for cuda --- python/tvm/intrin.py | 16 +++++++++++ src/codegen/intrin_rule_cuda.cc | 2 ++ src/lang/ir_operator.cc | 6 ++++ tests/python/integration/test_ewise.py | 38 ++++++++++++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 30da873b5dcf..5c194b2e3d16 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -376,6 +376,22 @@ def popcount(x): """ return call_pure_intrin(x.dtype, "popcount", x) +def mod(x, y): + """Take mod cast of input x and y + + Parameters + ---------- + x : Expr + Input argument. + y : Expr + Input argument. + + Returns + ------- + z : Expr + The result. + """ + return call_pure_intrin(x.dtype, "fmod", x, y) # Intrinsic rule related code def register_intrin_rule(target, intrin, f=None, override=False): diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index ee98a54329ab..a6867c7f201c 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -91,6 +91,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod") +.set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index 275752644be9..c2ad8539dcbf 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -450,4 +450,10 @@ Expr prod(Expr source, Array rdom) { return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } +Expr mod(Expr x, Expr y) { + BinaryOpMatchTypes(x, y); + CHECK(x.type().is_float()) << "mod only applies to float"; + return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); +} + } // namespace tvm diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 0f58c2367576..8da94bca2804 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -38,6 +38,43 @@ def check_device(device, host="stackvm"): check_device("cuda", "llvm") check_device("vulkan") +def test_mod(): + # graph + def run(dtype): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A', dtype=dtype) + B = tvm.placeholder((n,), name='B', dtype=dtype) + C = tvm.compute(A.shape, lambda *i: tvm.mod(A(*i), B(*i)), name='C') + s = tvm.create_schedule(C.op) + # create iter var and assign them tags. + num_thread = 8 + bx, tx = s[C].split(C.op.axis[0], factor=num_thread) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + target = tvm.target.create(device) + if "cpu" not in target.keys: + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fmod = tvm.build(s, [A, B, C], device, name="mymod") + + # launch the kernel. + n = 1024 + a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx) + b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1) + tcost = ftimer(a, b, c).mean + #fmod(a, b, c) + np.testing.assert_allclose( + c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5) + + check_device("cuda") + + run("float32") def test_multiple_cache_write(): # graph @@ -245,3 +282,4 @@ def check_device(device): test_add() test_log_pow_llvm() test_popcount() + test_mod() From 219bfdd13abdee1a5114c027199d445881cff19f Mon Sep 17 00:00:00 2001 From: xqdan Date: Wed, 24 Oct 2018 09:47:03 +0800 Subject: [PATCH 2/4] rename mod as fmod --- python/tvm/intrin.py | 2 +- src/lang/ir_operator.cc | 2 +- tests/python/integration/test_ewise.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 5c194b2e3d16..9e2f18f29067 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -376,7 +376,7 @@ def popcount(x): """ return call_pure_intrin(x.dtype, "popcount", x) -def mod(x, y): +def fmod(x, y): """Take mod cast of input x and y Parameters diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index c2ad8539dcbf..d33fb30de938 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -450,7 +450,7 @@ Expr prod(Expr source, Array rdom) { return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } -Expr mod(Expr x, Expr y) { +Expr fmod(Expr x, Expr y) { BinaryOpMatchTypes(x, y); CHECK(x.type().is_float()) << "mod only applies to float"; return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 8da94bca2804..571978c20b19 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -38,13 +38,13 @@ def check_device(device, host="stackvm"): check_device("cuda", "llvm") check_device("vulkan") -def test_mod(): +def test_fmod(): # graph def run(dtype): n = tvm.var('n') A = tvm.placeholder((n,), name='A', dtype=dtype) B = tvm.placeholder((n,), name='B', dtype=dtype) - C = tvm.compute(A.shape, lambda *i: tvm.mod(A(*i), B(*i)), name='C') + C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C') s = tvm.create_schedule(C.op) # create iter var and assign them tags. num_thread = 8 @@ -59,7 +59,7 @@ def check_device(device): if "cpu" not in target.keys: s[C].bind(bx, tvm.thread_axis("blockIdx.x")) s[C].bind(tx, tvm.thread_axis("threadIdx.x")) - fmod = tvm.build(s, [A, B, C], device, name="mymod") + fmod = tvm.build(s, [A, B, C], device, name="myfmod") # launch the kernel. n = 1024 From 2fd55a2dcf5a674a4962bbc2dff87cf956a0aa70 Mon Sep 17 00:00:00 2001 From: xqdan Date: Wed, 24 Oct 2018 16:47:28 +0800 Subject: [PATCH 3/4] fix --- python/tvm/intrin.py | 2 +- src/lang/ir_operator.cc | 2 +- tests/python/integration/test_ewise.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 9e2f18f29067..3207b6112b1d 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -377,7 +377,7 @@ def popcount(x): return call_pure_intrin(x.dtype, "popcount", x) def fmod(x, y): - """Take mod cast of input x and y + """Return the remainder of x divided by y with the same sign as x. Parameters ---------- diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index d33fb30de938..9ae2912901be 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -452,7 +452,7 @@ Expr prod(Expr source, Array rdom) { Expr fmod(Expr x, Expr y) { BinaryOpMatchTypes(x, y); - CHECK(x.type().is_float()) << "mod only applies to float"; + CHECK(x.type().is_float()) << "fmod only applies to float"; return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); } diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 571978c20b19..321b66bbec1d 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -282,4 +282,4 @@ def check_device(device): test_add() test_log_pow_llvm() test_popcount() - test_mod() + test_fmod() From f298ed7a7e3c24fdd65027871564ff71e473c71b Mon Sep 17 00:00:00 2001 From: xiaoqiang dan Date: Sat, 27 Oct 2018 17:16:38 +0800 Subject: [PATCH 4/4] enable metal and opencl --- src/codegen/intrin_rule_metal.cc | 3 +++ src/codegen/intrin_rule_opencl.cc | 3 +++ tests/python/integration/test_ewise.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc index 8b499fb9ea9b..2e65d5537dd2 100644 --- a/src/codegen/intrin_rule_metal.cc +++ b/src/codegen/intrin_rule_metal.cc @@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") +.set_body(DispatchExtern); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index 1cb1aed01102..e4cf11bf6e64 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") +.set_body(DispatchExtern); + // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension struct IntelShuffle { diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 321b66bbec1d..b3f17b7c1bb1 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -73,6 +73,8 @@ def check_device(device): c.asnumpy(), np.mod(a.asnumpy(), b.asnumpy()), rtol=1e-5) check_device("cuda") + check_device("opencl -device=intel_graphics") + check_device("metal") run("float32")