From 805de21830ccda29ba8619608985efd9110ab2c3 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Wed, 26 May 2021 16:50:24 +0800
Subject: [PATCH 1/3] Add fast_softmax support in fast_math pass

---
 src/relay/transforms/fast_math.cc         |  9 ++++++++-
 src/relay/transforms/pattern_utils.h      |  5 +++++
 tests/python/relay/test_op_fast_math.py   | 11 +++++++++--
 tests/python/relay/test_pass_fast_math.py | 12 ++++++++++++
 4 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index 91fb4cfa8973..f6da52ebe30c 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -34,7 +34,11 @@ namespace relay {
 
 class FastMathMutator : public ExprRewriter {
  public:
-  FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {}
+  FastMathMutator()
+      : exp_op_(Op::Get("exp")),
+        erf_op_(Op::Get("erf")),
+        tanh_op_(Op::Get("tanh")),
+        softmax_op_(Op::Get("nn.softmax")) {}
 
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     if (pre->op == exp_op_) {
@@ -43,6 +47,8 @@ class FastMathMutator : public ExprRewriter {
       return FastErf(post.as<CallNode>()->args[0]);
     } else if (pre->op == tanh_op_) {
       return FastTanh(post.as<CallNode>()->args[0]);
+    } else if (pre->op == softmax_op_) {
+      return FastSoftmax(post.as<CallNode>()->args[0], post.as<CallNode>()->attrs);
     }
     return post;
   }
@@ -54,6 +60,7 @@ class FastMathMutator : public ExprRewriter {
   const Op& exp_op_;
   const Op& erf_op_;
   const Op& tanh_op_;
+  const Op& softmax_op_;
 };
 
 Expr FastMath(const Expr& e) {
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 50a695bf1d84..920ac153b63d 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -498,6 +498,11 @@ inline Expr FastTanh(Expr e) {
   return Call(op, {e});
 }
 
+inline Expr FastSoftmax(Expr e, tvm::Attrs attr) {
+  static const Op& op = Op::Get("nn.fast_softmax");
+  return Call(op, {e}, attr);
+}
+
 inline Expr Log(Expr e) {
   static const Op& op = Op::Get("log");
   return Call(op, {e});
diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
index 8e401bc5670a..28c2a5077555 100644
--- a/tests/python/relay/test_op_fast_math.py
+++ b/tests/python/relay/test_op_fast_math.py
@@ -27,7 +27,7 @@
 
 def test_fastmath():
     def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
-        a_np = np.arange(low, high, step).astype(dtype)
+        a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
         b_np = f_numpy(a_np)
 
         x = relay.var("x", shape=a_np.shape, dtype="float32")
@@ -56,7 +56,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
     test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
     test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
     test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
-
+    test_apply(
+        relay.nn.fast_softmax,
+        "nn_fast_softmax",
+        tvm.topi.testing.softmax_python,
+        low=-10,
+        high=10,
+        step=0.01,
+    )
 
 if __name__ == "__main__":
     test_fastmath()
diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py
index bb3fb84fc61f..f63b6ce0f23e 100644
--- a/tests/python/relay/test_pass_fast_math.py
+++ b/tests/python/relay/test_pass_fast_math.py
@@ -65,7 +65,19 @@ def test_erf():
     assert "fast_erf" in fast_mod[0].astext()
 
 
+def test_softmax():
+    x = relay.var("x", shape=(1, 16), dtype="float32")
+    y = relay.nn.softmax(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
+        fast_mod = relay.optimize(mod, target="llvm")
+    assert "nn.fast_softmax" in fast_mod[0].astext()
+
+
 if __name__ == "__main__":
     test_exp()
     test_tanh()
     test_erf()
+    test_softmax()

From ad94a227d448aef6a6282d98d98dd4b33f8cba09 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Wed, 26 May 2021 16:55:46 +0800
Subject: [PATCH 2/3] Lintfix

---
 tests/python/relay/test_op_fast_math.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
index 28c2a5077555..c9314fae37ac 100644
--- a/tests/python/relay/test_op_fast_math.py
+++ b/tests/python/relay/test_op_fast_math.py
@@ -65,5 +65,6 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
         step=0.01,
     )
 
+
 if __name__ == "__main__":
     test_fastmath()

From 2151db3f84af241040cdec9e08545b90235f02ec Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Wed, 26 May 2021 17:09:43 +0800
Subject: [PATCH 3/3] Update

---
 python/tvm/relay/op/strategy/generic.py |  2 +-
 python/tvm/topi/generic/nn.py           | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index a6ad06e544a6..0d6c3ef58cdf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -175,7 +175,7 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     strategy.add_implementation(
         wrap_compute_softmax(topi.nn.fast_softmax),
-        naive_schedule,
+        wrap_topi_schedule(topi.generic.schedule_fast_softmax),
         name="fast_softmax.generic",
     )
     return strategy
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 866887706862..04d649037fef 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -563,6 +563,23 @@ def schedule_softmax(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_fast_softmax(outs):
+    """Schedule for fast_softmax
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of fast_softmax
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_dense(outs):
     """Schedule for dense
 
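
Usage note (not part of the patches): a minimal sketch of how the new rewrite is
exercised end to end, mirroring the test_softmax test added in patch 1/3. Every
API call below already appears in the series; the "llvm" target and the (1, 16)
shape are placeholders, and required_pass=["FastMath"] is what forces the pass to
run, since it is not enabled at this opt_level by default.

    import tvm
    from tvm import relay

    # Build a small module whose only op is nn.softmax.
    x = relay.var("x", shape=(1, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.softmax(x)))

    # Force the FastMath pass; relay.optimize returns (module, params).
    with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
        opt_mod, _ = relay.optimize(mod, target="llvm")

    # After the pass, nn.softmax has been rewritten to nn.fast_softmax.
    assert "nn.fast_softmax" in opt_mod.astext()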