From e74794c08de4ad161a2973fa7879fddcfeae522e Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 23 Mar 2021 20:56:12 -0700 Subject: [PATCH 01/11] factor out some common code for DF rewriting, add ConcretizeLike --- python/tvm/relay/transform/transform.py | 13 ++ src/relay/transforms/concretize_like.cc | 182 ++++++++++++++++++ src/relay/transforms/simplify_expr.cc | 68 ++----- src/relay/transforms/simplify_expr.h | 80 ++++++++ .../python/relay/test_pass_concretize_like.py | 122 ++++++++++++ 5 files changed, 417 insertions(+), 48 deletions(-) create mode 100644 src/relay/transforms/concretize_like.cc create mode 100644 src/relay/transforms/simplify_expr.h create mode 100644 tests/python/relay/test_pass_concretize_like.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 5b0e480f5f28..f527cf65da90 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -830,6 +830,19 @@ def FirstOrderGradient(): return _ffi_api.FirstOrderGradient() +def ConcretizeLike(): + """ + Transforms `op_like` functions to their explicit-shape equivalent (e.g. `zeros_like(x, y)` + to `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary + dependencies and can enable more opportunities for operator fusion. + Returns + ------- + ret : tvm.transform.Pass + The registered ConcretizeLike pass. + """ + return _ffi_api.ConcretizeLike() + + def Defunctionalization(func, mod): """ Performs defunctionalization on func, diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc new file mode 100644 index 000000000000..7bae6f1dcbbb --- /dev/null +++ b/src/relay/transforms/concretize_like.cc @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file concretize_like.cc + * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to + * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies + * and can enable more opportunities for operator fusion. + */ + +#include + +#include "pattern_utils.h" +#include "simplify_expr.h" + +namespace tvm { +namespace relay { + +class ConcretizeLikeRewrite : public DFPatternRewrite { + public: + ConcretizeLikeRewrite(const Op& op) { + ICHECK(op->num_inputs == 1 || op->num_inputs == 2) + << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op; + like_pat_ = IsWildcard(); + data_pat_ = IsWildcard(); + if (op->num_inputs == 1) { + pattern_ = IsExpr(op)({like_pat_}); + } else { + pattern_ = IsExpr(op)({data_pat_, like_pat_}); + } + require_type_ = true; + } + + virtual bool Check(const Expr& pre, const Expr& post, + const Map>& node_map) const { + const CallNode* call_node = pre.as(); + ICHECK(call_node); + + if (!call_node->checked_type_.defined()) { + // TODO(@altanh): maybe because of the input being rewritten? + return false; + } + + const TensorTypeNode* like_ty = call_node->checked_type().as(); + ICHECK(like_ty) << "got non-Tensor *_like call type " << PrettyPrint(call_node->checked_type()); + + return true; + } + + virtual Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const = 0; + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + if (!Check(pre, post, node_map)) { + return post; + } + + const TensorTypeNode* like_ty = pre->checked_type().as(); + Array cshape; + for (const auto& dim : like_ty->shape) { + if (const auto* imm = dim.as()) { + cshape.push_back(Integer(GetRef(imm))); + } else { + return post; + } + } + + return Concretize(node_map, cshape, like_ty->dtype); + } + + protected: + DFPattern data_pat_; + DFPattern like_pat_; +}; + +class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeZeros(shape, dtype); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite); +}; + +class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeOnes(shape, dtype); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite); +}; + +class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeReshape(node_map[data_pat_][0], shape); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite); +}; + +class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + ICHECK_LE(shape.size(), std::numeric_limits::max()); + static const Op& op = Op::Get("collapse_sum_to"); + auto attrs = make_object(); + attrs->shape = shape; + auto cshape = + MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); + return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs)); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite); +}; + +class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeBroadCastTo(node_map[data_pat_][0], shape); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite); +}; + +Expr ConcretizeLike(const Expr& expr, const IRModule& mod) { + static Array callbacks = + MakeCallbacks({ConcretizeZerosLikeRewrite::Get(), ConcretizeOnesLikeRewrite::Get(), + ConcretizeReshapeLikeRewrite::Get(), ConcretizeCollapseSumLikeRewrite::Get(), + ConcretizeBroadcastToLikeRewrite::Get()}); + return RewritePatterns(callbacks, expr, mod); +} + +namespace transform { + +Pass ConcretizeLike() { + runtime::TypedPackedFunc pass_func = + [](Function f, IRModule m, PassContext pc) { + return Downcast(ConcretizeLike(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); + +} // namespace transform + +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index b4f4cc16e9df..82f640071045 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -22,6 +22,8 @@ * \brief A pass for simplifying the Relay expression. */ +#include "simplify_expr.h" + #include #include #include @@ -34,32 +36,21 @@ namespace tvm { namespace relay { -class SimplifyPattern { - public: - virtual Expr callback(const Expr& pre, const Expr& post, - const Map>& node_map) const = 0; - - DFPattern pattern() const { return pattern_; } - - protected: - /*! \brief Pattern for rewriting */ - DFPattern pattern_; -}; - /*! * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops, * and merges into one reshape op. */ -class SimplifyReshape : public SimplifyPattern { +class SimplifyReshape : public DFPatternRewrite { public: SimplifyReshape() { x_ = IsWildcard(); auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape"); auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape"); pattern_ = reshape1({reshape2({x_})}); + require_type_ = true; } - Expr callback(const Expr& pre, const Expr& post, + Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { auto x = node_map[x_][0]; bool const_shape = true; @@ -77,6 +68,8 @@ class SimplifyReshape : public SimplifyPattern { return post; } + TVM_DF_PATTERN_REWRITE_GETTER(SimplifyReshape) + private: /*! \brief Pattern input */ DFPattern x_; @@ -86,16 +79,17 @@ class SimplifyReshape : public SimplifyPattern { * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. */ -class SimplifyTranspose : public SimplifyPattern { +class SimplifyTranspose : public DFPatternRewrite { public: SimplifyTranspose() { x_ = IsWildcard(); auto trans1 = IsOp("transpose") || IsOp("layout_transform"); auto trans2 = IsOp("transpose") || IsOp("layout_transform"); pattern_ = trans1({trans2({x_})}); + require_type_ = true; } - Expr callback(const Expr& pre, const Expr& post, + Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { // Helper function to get the axes from call node attribute auto get_axes_from_call = [](const Call trans_call, int ndim) { @@ -170,6 +164,8 @@ class SimplifyTranspose : public SimplifyPattern { return x; } + TVM_DF_PATTERN_REWRITE_GETTER(SimplifyTranspose); + private: /*! \brief Pattern input */ DFPattern x_; @@ -178,7 +174,7 @@ class SimplifyTranspose : public SimplifyPattern { /*! * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op */ -class FullElementwise : public SimplifyPattern { +class FullElementwise : public DFPatternRewrite { public: FullElementwise() { x_ = IsWildcard(); @@ -194,9 +190,10 @@ class FullElementwise : public SimplifyPattern { DFPattern op = IsWildcard().HasAttr(attrs); DFPattern full = full_ || ones_ || zeros_; pattern_ = op({full, x_}) || op({x_, full}); + require_type_ = true; } - Expr callback(const Expr& pre, const Expr& post, + Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { const CallNode* call = pre.as(); ICHECK(call); @@ -233,6 +230,8 @@ class FullElementwise : public SimplifyPattern { return post; } + TVM_DF_PATTERN_REWRITE_GETTER(FullElementwise); + private: /*! \brief binary argument */ DFPattern x_; @@ -248,37 +247,10 @@ class FullElementwise : public SimplifyPattern { DFPattern zeros_; }; -/*! - * \brief ExprSimplifier simplifies the Relay expression. - */ -class ExprSimplifier { - public: - explicit ExprSimplifier(IRModule mod) : mod_(mod) { - CreateCallback(SimplifyReshape()); - CreateCallback(SimplifyTranspose()); - CreateCallback(FullElementwise()); - } - template - void CreateCallback(const T& pattern) { - auto func = [pattern](TVMArgs args, TVMRetValue* rv) { - Expr pre = args[0]; - Expr post = args[1]; - Map> node_map = args[2]; - *rv = pattern.callback(pre, post, node_map); - }; - callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true)); - } - - Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); } - - private: - IRModule mod_; - /*! \brief Callbacks for expr simplification */ - Array callbacks_; -}; - Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { - return ExprSimplifier(mod).Simplify(expr); + static Array callbacks = + MakeCallbacks({SimplifyReshape::Get(), SimplifyTranspose::Get(), FullElementwise::Get()}); + return RewritePatterns(callbacks, expr, mod); } namespace transform { diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h new file mode 100644 index 000000000000..3d87513cef5a --- /dev/null +++ b/src/relay/transforms/simplify_expr.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/simplify_expr.h + * \brief Utility data structures for simplifying Relay expressions. + */ +#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ +#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized + * instance of RewriteType. */ +#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \ + static DFPatternRewrite* Get() { \ + static RewriteType rw; \ + return &rw; \ + } + +/*! \brief A wrapper class defining a rewrite matching a specific pattern. */ +class DFPatternRewrite { + public: + /*! \brief Returns the rewritten expression. */ + virtual Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const = 0; + + /*! \brief Returns the pattern to be used for matching and rewriting. */ + inline DFPattern Pattern() const { return pattern_; } + + inline bool RequireType() const { return require_type_; } + + protected: + /*! \brief The pattern for matching and rewriting. */ + DFPattern pattern_; + bool require_type_; +}; + +/*! \brief Returns an array of DFPatternCallbacks using the given rewrites. */ +inline Array MakeCallbacks(const std::vector& rewrites) { + Array callbacks; + for (const auto& rewrite : rewrites) { + auto func = [rewrite](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map> node_map = args[2]; + *rv = rewrite->Callback(pre, post, node_map); + }; + callbacks.push_back( + DFPatternCallback(rewrite->Pattern(), PackedFunc(func), rewrite->RequireType())); + } + return callbacks; +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ diff --git a/tests/python/relay/test_pass_concretize_like.py b/tests/python/relay/test_pass_concretize_like.py new file mode 100644 index 000000000000..792475035b88 --- /dev/null +++ b/tests/python/relay/test_pass_concretize_like.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for the ConcretizeLike pass.""" +import tvm +import tvm.relay.testing +from tvm import relay +from tvm.relay.testing import run_infer_type + + +def test_reshape_like(): + data = relay.var("data", shape=(2, 3, 4), dtype="float32") + shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") + f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_reshape_like_attrs(): + data = relay.var("data", shape=(2, 3, 4), dtype="float32") + shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") + f = relay.Function( + [data, shape_like], relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1) + ) + f_expected = relay.Function([data, shape_like], relay.reshape(data, (2, 3, 2, 2))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_zeros_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + f = relay.Function([shape_like], relay.zeros_like(shape_like)) + f_expected = relay.Function([shape_like], relay.zeros((3, 4, 5), dtype)) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_ones_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + f = relay.Function([shape_like], relay.ones_like(shape_like)) + f_expected = relay.Function([shape_like], relay.ones((3, 4, 5), dtype)) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_collapse_sum_like(): + data = relay.var("data", shape=(3, 3, 3), dtype="float32") + shape_like = relay.var("shape_like", shape=(3,), dtype="float32") + f = relay.Function([data, shape_like], relay.collapse_sum_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.collapse_sum_to(data, (3,))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_broadcast_to_like(): + data = relay.var("data", shape=(3,), dtype="float32") + shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32") + f = relay.Function([data, shape_like], relay.broadcast_to_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.broadcast_to(data, (3, 3, 3))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_multiple(): + x = relay.var("x", shape=(2, 3), dtype="float32") + y = relay.var("x", shape=(3,), dtype="float32") + l = x + y + + dl = relay.ones_like(l) + dx = relay.zeros_like(x) + dy = relay.zeros_like(y) + dx = dx + relay.collapse_sum_like(dl, dx) + dy = dy + relay.collapse_sum_like(dl, dy) + ret = relay.Tuple([dx, dy]) + f = relay.Function([x, y], ret) + + dl_c = relay.ones((2, 3), "float32") + dx_c = relay.zeros((2, 3), "float32") + dy_c = relay.zeros((3,), "float32") + dx_c = dx_c + relay.collapse_sum_to(dl_c, (2, 3)) + dy_c = dy_c + relay.collapse_sum_to(dl_c, (3,)) + ret_c = relay.Tuple([dx_c, dy_c]) + f_expected = relay.Function([x, y], ret_c) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) From 8bd1047880a0aa97f46ef32ad85288252882ec7d Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 23 Mar 2021 22:47:44 -0700 Subject: [PATCH 02/11] slight refactoring, add EliminateIdentity pass --- python/tvm/relay/transform/transform.py | 15 ++++ src/relay/transforms/concretize_like.cc | 14 ++-- src/relay/transforms/simplify_expr.cc | 70 ++++++++++++++++++- src/relay/transforms/simplify_expr.h | 38 +++++----- .../python/relay/test_pass_concretize_like.py | 5 ++ tests/python/relay/test_pass_simplify_expr.py | 38 ++++++++++ 6 files changed, 150 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index f527cf65da90..13b66d96f949 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1136,6 +1136,21 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def EliminateIdentity(): + """ + Eliminates any expressions that are equivalent to identity, such as x + 0 + and x * 1. Note that these expressions cannot be eliminated when they + broadcast x to a new shape (although they could be replaced with explicit + broadcasting operations). + + Returns + ------- + ret : tvm.transform.Pass + The registered EliminateIdentity pass. + """ + return _ffi_api.EliminateIdentity() + + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc index 7bae6f1dcbbb..09d038bc35a0 100644 --- a/src/relay/transforms/concretize_like.cc +++ b/src/relay/transforms/concretize_like.cc @@ -157,10 +157,10 @@ class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { }; Expr ConcretizeLike(const Expr& expr, const IRModule& mod) { - static Array callbacks = - MakeCallbacks({ConcretizeZerosLikeRewrite::Get(), ConcretizeOnesLikeRewrite::Get(), - ConcretizeReshapeLikeRewrite::Get(), ConcretizeCollapseSumLikeRewrite::Get(), - ConcretizeBroadcastToLikeRewrite::Get()}); + static Array callbacks = { + ConcretizeZerosLikeRewrite::GetCallback(), ConcretizeOnesLikeRewrite::GetCallback(), + ConcretizeReshapeLikeRewrite::GetCallback(), ConcretizeCollapseSumLikeRewrite::GetCallback(), + ConcretizeBroadcastToLikeRewrite::GetCallback()}; return RewritePatterns(callbacks, expr, mod); } @@ -168,9 +168,9 @@ namespace transform { Pass ConcretizeLike() { runtime::TypedPackedFunc pass_func = - [](Function f, IRModule m, PassContext pc) { - return Downcast(ConcretizeLike(f, m)); - }; + [](Function f, IRModule m, PassContext pc) { + return Downcast(ConcretizeLike(f, m)); + }; return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); } diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 82f640071045..fdf5b0f8c534 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -172,7 +172,8 @@ class SimplifyTranspose : public DFPatternRewrite { }; /*! - * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op + * \brief FullElementwise finds full like ops followed by broadcasting ops, and eliminates + * the full op by directly passing the fill value into the broadcasting op. */ class FullElementwise : public DFPatternRewrite { public: @@ -247,14 +248,77 @@ class FullElementwise : public DFPatternRewrite { DFPattern zeros_; }; +/*! \brief Eliminates expressions that are just identity. */ +class EliminateIdentity : public DFPatternRewrite { + public: + EliminateIdentity() { + x_ = IsWildcard(); + + DFPattern add_op = IsOp("add"); + DFPattern mul_op = IsOp("multiply"); + DFPattern zeros_call = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}); + DFPattern ones_call = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}); + + DFPattern add_id = add_op({x_, zeros_call}) || add_op({zeros_call, x_}); + DFPattern mul_id = mul_op({x_, ones_call}) || mul_op({ones_call, x_}); + DFPattern sub_id = IsOp("subtract")({x_, zeros_call}); + DFPattern div_id = IsOp("divide")({x_, ones_call}); + + pattern_ = add_id || mul_id || sub_id || div_id; + require_type_ = true; + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call = pre.as(); + ICHECK(call); + Type pre_type = pre->checked_type_; + ICHECK(pre_type.as()); + auto x = node_map[x_][0]; + bool is_left = post.as()->args[1] == x; + Type x_type; + if (is_left) { + x_type = call->args[1]->checked_type_; + } else { + x_type = call->args[0]->checked_type_; + } + + if (StructuralEqual()(x_type, pre_type)) { + return x; + } + + return post; + } + + TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentity); + + private: + DFPattern x_; +}; + +Expr EliminateIdentity(const Expr& expr, const IRModule& mod) { + return RewritePatterns({EliminateIdentity::GetCallback()}, expr, mod); +} + Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { - static Array callbacks = - MakeCallbacks({SimplifyReshape::Get(), SimplifyTranspose::Get(), FullElementwise::Get()}); + static Array callbacks = {SimplifyReshape::GetCallback(), + SimplifyTranspose::GetCallback(), + FullElementwise::GetCallback()}; return RewritePatterns(callbacks, expr, mod); } namespace transform { +Pass EliminateIdentity() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(EliminateIdentity(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "EliminateIdentity", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.EliminateIdentity").set_body_typed(EliminateIdentity); + Pass SimplifyExpr() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index 3d87513cef5a..913fbc1f7ba3 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -34,10 +34,14 @@ namespace relay { /*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized * instance of RewriteType. */ -#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \ - static DFPatternRewrite* Get() { \ - static RewriteType rw; \ - return &rw; \ +#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \ + static DFPatternRewrite* Get() { \ + static RewriteType rw; \ + return &rw; \ + } \ + static DFPatternCallback GetCallback() { \ + static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \ + return cb; \ } /*! \brief A wrapper class defining a rewrite matching a specific pattern. */ @@ -52,27 +56,21 @@ class DFPatternRewrite { inline bool RequireType() const { return require_type_; } - protected: - /*! \brief The pattern for matching and rewriting. */ - DFPattern pattern_; - bool require_type_; -}; - -/*! \brief Returns an array of DFPatternCallbacks using the given rewrites. */ -inline Array MakeCallbacks(const std::vector& rewrites) { - Array callbacks; - for (const auto& rewrite : rewrites) { - auto func = [rewrite](TVMArgs args, TVMRetValue* rv) { + inline DFPatternCallback MakeCallback() const { + auto func = [this](TVMArgs args, TVMRetValue* rv) { Expr pre = args[0]; Expr post = args[1]; Map> node_map = args[2]; - *rv = rewrite->Callback(pre, post, node_map); + *rv = this->Callback(pre, post, node_map); }; - callbacks.push_back( - DFPatternCallback(rewrite->Pattern(), PackedFunc(func), rewrite->RequireType())); + return DFPatternCallback(pattern_, PackedFunc(func), require_type_); } - return callbacks; -} + + protected: + /*! \brief The pattern for matching and rewriting. */ + DFPattern pattern_; + bool require_type_; +}; } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_concretize_like.py b/tests/python/relay/test_pass_concretize_like.py index 792475035b88..4079c45352cf 100644 --- a/tests/python/relay/test_pass_concretize_like.py +++ b/tests/python/relay/test_pass_concretize_like.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Tests for the ConcretizeLike pass.""" +import pytest import tvm import tvm.relay.testing from tvm import relay @@ -120,3 +121,7 @@ def test_multiple(): mod = tvm.IRModule.from_expr(f) mod_concrete = relay.transform.ConcretizeLike()(mod) assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 897f90b9ee2a..12de0153d4ad 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -181,7 +181,45 @@ def after_right(x, elem_op, value): validate(shape, value, dtype) +def test_eliminate_identity(): + def check(x, y, do_nothing=False): + after = run_opt_pass(y, transform.EliminateIdentity()) + if do_nothing: + assert tvm.ir.structural_equal(after, y) + else: + assert tvm.ir.structural_equal(after, x) + + shape = [2, 3, 4] + dtype = "float32" + x = relay.var("x", shape=shape, dtype=dtype) + x = run_opt_pass(x, transform.InferType()) + + for (op, op_like, id_op) in [ + (relay.zeros, relay.zeros_like, relay.add), + (relay.ones, relay.ones_like, relay.multiply), + ]: + check(x, id_op(op_like(x), x)) + check(x, id_op(op(shape, dtype), x)) + check(x, id_op(x, op_like(x))) + check(x, id_op(x, op(shape, dtype))) + check(x, id_op(x, op(shape[1:], dtype))) + check(x, id_op(x, op([2] + shape, dtype)), do_nothing=True) + check(x, id_op(op([2] + shape, dtype), x), do_nothing=True) + + for (op, op_like, id_op) in [ + (relay.zeros, relay.zeros_like, relay.subtract), + (relay.ones, relay.ones_like, relay.divide), + ]: + check(x, id_op(x, op_like(x))) + check(x, id_op(x, op(shape, dtype))) + check(x, id_op(x, op(shape[1:], dtype))) + check(x, id_op(x, op([2] + shape, dtype)), do_nothing=True) + check(x, id_op(op(shape, dtype), x), do_nothing=True) + check(x, id_op(op_like(x), x), do_nothing=True) + + if __name__ == "__main__": test_simplify_reshape() test_simplify_transpose() test_simplify_full_elementwise() + test_eliminate_identity() From 8a4d3351af4a76c68fd74e5a68be8df5b982c7f8 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 23 Mar 2021 23:39:24 -0700 Subject: [PATCH 03/11] lint --- src/relay/transforms/concretize_like.cc | 4 ++-- src/relay/transforms/simplify_expr.cc | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc index 09d038bc35a0..8dbb4d91fa7e 100644 --- a/src/relay/transforms/concretize_like.cc +++ b/src/relay/transforms/concretize_like.cc @@ -34,7 +34,7 @@ namespace relay { class ConcretizeLikeRewrite : public DFPatternRewrite { public: - ConcretizeLikeRewrite(const Op& op) { + explicit ConcretizeLikeRewrite(const Op& op) { ICHECK(op->num_inputs == 1 || op->num_inputs == 2) << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op; like_pat_ = IsWildcard(); @@ -179,4 +179,4 @@ TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(Concretize } // namespace transform } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index fdf5b0f8c534..b10fcb94c3a6 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -30,6 +30,8 @@ #include #include +#include + #include "../op/tensor/transform.h" #include "pattern_utils.h" From 0dd1997a7f95b25be8ce564b6825737e5a78397f Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 11:57:25 -0700 Subject: [PATCH 04/11] merge ConcretizeLike and EliminateIdentity into SimplifyExpr --- python/tvm/relay/transform/transform.py | 28 --- src/relay/transforms/concretize_like.cc | 182 ---------------- src/relay/transforms/pattern_utils.h | 4 +- src/relay/transforms/simplify_expr.cc | 197 +++++++++++++++--- src/relay/transforms/simplify_expr.h | 2 +- .../python/relay/test_pass_concretize_like.py | 127 ----------- tests/python/relay/test_pass_simplify_expr.py | 140 +++++++++++-- 7 files changed, 292 insertions(+), 388 deletions(-) delete mode 100644 src/relay/transforms/concretize_like.cc delete mode 100644 tests/python/relay/test_pass_concretize_like.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 13b66d96f949..5b0e480f5f28 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -830,19 +830,6 @@ def FirstOrderGradient(): return _ffi_api.FirstOrderGradient() -def ConcretizeLike(): - """ - Transforms `op_like` functions to their explicit-shape equivalent (e.g. `zeros_like(x, y)` - to `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary - dependencies and can enable more opportunities for operator fusion. - Returns - ------- - ret : tvm.transform.Pass - The registered ConcretizeLike pass. - """ - return _ffi_api.ConcretizeLike() - - def Defunctionalization(func, mod): """ Performs defunctionalization on func, @@ -1136,21 +1123,6 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() -def EliminateIdentity(): - """ - Eliminates any expressions that are equivalent to identity, such as x + 0 - and x * 1. Note that these expressions cannot be eliminated when they - broadcast x to a new shape (although they could be replaced with explicit - broadcasting operations). - - Returns - ------- - ret : tvm.transform.Pass - The registered EliminateIdentity pass. - """ - return _ffi_api.EliminateIdentity() - - def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc deleted file mode 100644 index 8dbb4d91fa7e..000000000000 --- a/src/relay/transforms/concretize_like.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file concretize_like.cc - * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to - * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies - * and can enable more opportunities for operator fusion. - */ - -#include - -#include "pattern_utils.h" -#include "simplify_expr.h" - -namespace tvm { -namespace relay { - -class ConcretizeLikeRewrite : public DFPatternRewrite { - public: - explicit ConcretizeLikeRewrite(const Op& op) { - ICHECK(op->num_inputs == 1 || op->num_inputs == 2) - << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op; - like_pat_ = IsWildcard(); - data_pat_ = IsWildcard(); - if (op->num_inputs == 1) { - pattern_ = IsExpr(op)({like_pat_}); - } else { - pattern_ = IsExpr(op)({data_pat_, like_pat_}); - } - require_type_ = true; - } - - virtual bool Check(const Expr& pre, const Expr& post, - const Map>& node_map) const { - const CallNode* call_node = pre.as(); - ICHECK(call_node); - - if (!call_node->checked_type_.defined()) { - // TODO(@altanh): maybe because of the input being rewritten? - return false; - } - - const TensorTypeNode* like_ty = call_node->checked_type().as(); - ICHECK(like_ty) << "got non-Tensor *_like call type " << PrettyPrint(call_node->checked_type()); - - return true; - } - - virtual Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const = 0; - - Expr Callback(const Expr& pre, const Expr& post, - const Map>& node_map) const override { - if (!Check(pre, post, node_map)) { - return post; - } - - const TensorTypeNode* like_ty = pre->checked_type().as(); - Array cshape; - for (const auto& dim : like_ty->shape) { - if (const auto* imm = dim.as()) { - cshape.push_back(Integer(GetRef(imm))); - } else { - return post; - } - } - - return Concretize(node_map, cshape, like_ty->dtype); - } - - protected: - DFPattern data_pat_; - DFPattern like_pat_; -}; - -class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite { - public: - ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {} - - Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const override { - return MakeZeros(shape, dtype); - } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite); -}; - -class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { - public: - ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {} - - Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const override { - return MakeOnes(shape, dtype); - } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite); -}; - -class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { - public: - ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {} - - Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const override { - return MakeReshape(node_map[data_pat_][0], shape); - } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite); -}; - -class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite { - public: - ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {} - - Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const override { - ICHECK_LE(shape.size(), std::numeric_limits::max()); - static const Op& op = Op::Get("collapse_sum_to"); - auto attrs = make_object(); - attrs->shape = shape; - auto cshape = - MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); - return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs)); - } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite); -}; - -class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { - public: - ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {} - - Expr Concretize(const Map>& node_map, Array shape, - DataType dtype) const override { - return MakeBroadCastTo(node_map[data_pat_][0], shape); - } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite); -}; - -Expr ConcretizeLike(const Expr& expr, const IRModule& mod) { - static Array callbacks = { - ConcretizeZerosLikeRewrite::GetCallback(), ConcretizeOnesLikeRewrite::GetCallback(), - ConcretizeReshapeLikeRewrite::GetCallback(), ConcretizeCollapseSumLikeRewrite::GetCallback(), - ConcretizeBroadcastToLikeRewrite::GetCallback()}; - return RewritePatterns(callbacks, expr, mod); -} - -namespace transform { - -Pass ConcretizeLike() { - runtime::TypedPackedFunc pass_func = - [](Function f, IRModule m, PassContext pc) { - return Downcast(ConcretizeLike(f, m)); - }; - return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index c1eebde15fba..cde91e217c09 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -394,7 +394,9 @@ static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) return reinterpret_cast(array->data)[i]; } } else if (array->dtype.code == kDLUInt) { - if (array->dtype.bits == 8) { + if (array->dtype.bits == 1) { // bool + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 8) { return reinterpret_cast(array->data)[i]; } else if (array->dtype.bits == 16) { return reinterpret_cast(array->data)[i]; diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index b10fcb94c3a6..84470c9516fa 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -49,7 +49,6 @@ class SimplifyReshape : public DFPatternRewrite { auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape"); auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape"); pattern_ = reshape1({reshape2({x_})}); - require_type_ = true; } Expr Callback(const Expr& pre, const Expr& post, @@ -88,7 +87,6 @@ class SimplifyTranspose : public DFPatternRewrite { auto trans1 = IsOp("transpose") || IsOp("layout_transform"); auto trans2 = IsOp("transpose") || IsOp("layout_transform"); pattern_ = trans1({trans2({x_})}); - require_type_ = true; } Expr Callback(const Expr& pre, const Expr& post, @@ -193,7 +191,6 @@ class FullElementwise : public DFPatternRewrite { DFPattern op = IsWildcard().HasAttr(attrs); DFPattern full = full_ || ones_ || zeros_; pattern_ = op({full, x_}) || op({x_, full}); - require_type_ = true; } Expr Callback(const Expr& pre, const Expr& post, @@ -250,24 +247,162 @@ class FullElementwise : public DFPatternRewrite { DFPattern zeros_; }; -/*! \brief Eliminates expressions that are just identity. */ -class EliminateIdentity : public DFPatternRewrite { +/*! + * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to + * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies + * and can enable more opportunities for operator fusion. + */ +class ConcretizeLikeRewrite : public DFPatternRewrite { public: - EliminateIdentity() { + explicit ConcretizeLikeRewrite(const Op& op) { + ICHECK(op->num_inputs == 1 || op->num_inputs == 2) + << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op; + like_pat_ = IsWildcard(); + data_pat_ = IsWildcard(); + if (op->num_inputs == 1) { + pattern_ = IsExpr(op)({like_pat_}); + } else { + pattern_ = IsExpr(op)({data_pat_, like_pat_}); + } + } + + virtual bool Check(const Expr& pre, const Expr& post, + const Map>& node_map) const { + const CallNode* call_node = pre.as(); + ICHECK(call_node); + + if (!call_node->checked_type().as()) { + return false; + } + + return true; + } + + virtual Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const = 0; + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + if (!Check(pre, post, node_map)) { + return post; + } + + const TensorTypeNode* like_ty = pre->checked_type().as(); + Array cshape; + for (const auto& dim : like_ty->shape) { + if (const auto* imm = dim.as()) { + cshape.push_back(Integer(GetRef(imm))); + } else { + // shape is not static, don't concretize + return post; + } + } + + return Concretize(node_map, cshape, like_ty->dtype); + } + + protected: + DFPattern data_pat_; + DFPattern like_pat_; +}; + +class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeZeros(shape, dtype); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite); +}; + +class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeOnes(shape, dtype); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite); +}; + +class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeReshape(node_map[data_pat_][0], shape); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite); +}; + +class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + ICHECK_LE(shape.size(), std::numeric_limits::max()); + static const Op& op = Op::Get("collapse_sum_to"); + auto attrs = make_object(); + attrs->shape = shape; + auto cshape = + MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); + return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs)); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite); +}; + +class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + return MakeBroadCastTo(node_map[data_pat_][0], shape); + } + + TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite); +}; + +/*! \brief Eliminates expressions that are equivalent to identity. */ +class EliminateIdentityRewrite : public DFPatternRewrite { + public: + EliminateIdentityRewrite() { x_ = IsWildcard(); + const_ = IsConstant(); DFPattern add_op = IsOp("add"); DFPattern mul_op = IsOp("multiply"); - DFPattern zeros_call = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}); - DFPattern ones_call = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}); + DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_; + DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_; - DFPattern add_id = add_op({x_, zeros_call}) || add_op({zeros_call, x_}); - DFPattern mul_id = mul_op({x_, ones_call}) || mul_op({ones_call, x_}); - DFPattern sub_id = IsOp("subtract")({x_, zeros_call}); - DFPattern div_id = IsOp("divide")({x_, ones_call}); + DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_}); + DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_}); + DFPattern sub_id = IsOp("subtract")({x_, zeros_expr}); + DFPattern div_id = IsOp("divide")({x_, ones_expr}); pattern_ = add_id || mul_id || sub_id || div_id; - require_type_ = true; + } + + bool CheckConstant(const OpNode* op, const ConstantNode* constant) const { + if (!IsScalar(GetRef(constant))) { + return false; + } + long double value = ToScalar(constant->data); + if (op->name == "add" || op->name == "subtract") { + return value == 0.0; + } else if (op->name == "multiply" || op->name == "divide") { + return value == 1.0; + } + return false; } Expr Callback(const Expr& pre, const Expr& post, @@ -285,6 +420,17 @@ class EliminateIdentity : public DFPatternRewrite { x_type = call->args[0]->checked_type_; } + if (node_map.count(const_)) { + // the other argument is a Constant in this case + const ConstantNode* constant = node_map[const_][0].as(); + const OpNode* op = call->op.as(); + ICHECK(constant); + ICHECK(op); + if (!CheckConstant(op, constant)) { + return post; + } + } + if (StructuralEqual()(x_type, pre_type)) { return x; } @@ -292,18 +438,21 @@ class EliminateIdentity : public DFPatternRewrite { return post; } - TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentity); + TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite); private: DFPattern x_; + DFPattern const_; }; -Expr EliminateIdentity(const Expr& expr, const IRModule& mod) { - return RewritePatterns({EliminateIdentity::GetCallback()}, expr, mod); -} - Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { - static Array callbacks = {SimplifyReshape::GetCallback(), + static Array callbacks = {ConcretizeZerosLikeRewrite::GetCallback(), + ConcretizeOnesLikeRewrite::GetCallback(), + ConcretizeReshapeLikeRewrite::GetCallback(), + ConcretizeCollapseSumLikeRewrite::GetCallback(), + ConcretizeBroadcastToLikeRewrite::GetCallback(), + EliminateIdentityRewrite::GetCallback(), + SimplifyReshape::GetCallback(), SimplifyTranspose::GetCallback(), FullElementwise::GetCallback()}; return RewritePatterns(callbacks, expr, mod); @@ -311,16 +460,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { namespace transform { -Pass EliminateIdentity() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(EliminateIdentity(f, m)); - }; - return CreateFunctionPass(pass_func, 0, "EliminateIdentity", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.EliminateIdentity").set_body_typed(EliminateIdentity); - Pass SimplifyExpr() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index 913fbc1f7ba3..3aa05133b536 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -69,7 +69,7 @@ class DFPatternRewrite { protected: /*! \brief The pattern for matching and rewriting. */ DFPattern pattern_; - bool require_type_; + bool require_type_ = true; }; } // namespace relay diff --git a/tests/python/relay/test_pass_concretize_like.py b/tests/python/relay/test_pass_concretize_like.py deleted file mode 100644 index 4079c45352cf..000000000000 --- a/tests/python/relay/test_pass_concretize_like.py +++ /dev/null @@ -1,127 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Tests for the ConcretizeLike pass.""" -import pytest -import tvm -import tvm.relay.testing -from tvm import relay -from tvm.relay.testing import run_infer_type - - -def test_reshape_like(): - data = relay.var("data", shape=(2, 3, 4), dtype="float32") - shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") - f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_reshape_like_attrs(): - data = relay.var("data", shape=(2, 3, 4), dtype="float32") - shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") - f = relay.Function( - [data, shape_like], relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1) - ) - f_expected = relay.Function([data, shape_like], relay.reshape(data, (2, 3, 2, 2))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_zeros_like(): - dtype = "int32" - shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) - f = relay.Function([shape_like], relay.zeros_like(shape_like)) - f_expected = relay.Function([shape_like], relay.zeros((3, 4, 5), dtype)) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_ones_like(): - dtype = "int32" - shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) - f = relay.Function([shape_like], relay.ones_like(shape_like)) - f_expected = relay.Function([shape_like], relay.ones((3, 4, 5), dtype)) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_collapse_sum_like(): - data = relay.var("data", shape=(3, 3, 3), dtype="float32") - shape_like = relay.var("shape_like", shape=(3,), dtype="float32") - f = relay.Function([data, shape_like], relay.collapse_sum_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.collapse_sum_to(data, (3,))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_broadcast_to_like(): - data = relay.var("data", shape=(3,), dtype="float32") - shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32") - f = relay.Function([data, shape_like], relay.broadcast_to_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.broadcast_to(data, (3, 3, 3))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_multiple(): - x = relay.var("x", shape=(2, 3), dtype="float32") - y = relay.var("x", shape=(3,), dtype="float32") - l = x + y - - dl = relay.ones_like(l) - dx = relay.zeros_like(x) - dy = relay.zeros_like(y) - dx = dx + relay.collapse_sum_like(dl, dx) - dy = dy + relay.collapse_sum_like(dl, dy) - ret = relay.Tuple([dx, dy]) - f = relay.Function([x, y], ret) - - dl_c = relay.ones((2, 3), "float32") - dx_c = relay.zeros((2, 3), "float32") - dy_c = relay.zeros((3,), "float32") - dx_c = dx_c + relay.collapse_sum_to(dl_c, (2, 3)) - dy_c = dy_c + relay.collapse_sum_to(dl_c, (3,)) - ret_c = relay.Tuple([dx_c, dy_c]) - f_expected = relay.Function([x, y], ret_c) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 12de0153d4ad..d015cdd36c2d 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import relay from tvm.relay import transform -from tvm.relay.testing import run_opt_pass +from tvm.relay.testing import run_opt_pass, run_infer_type import numpy as np @@ -123,12 +124,22 @@ def before_left(x, elem_op, full): return elem_op(full, x) def after_left(x, elem_op, value): + if elem_op == relay.add and value == 0: + return x + elif elem_op == relay.multiply and (value == 1 or (value > 1 and dtype == "bool")): + return x return elem_op(relay.const(value, dtype), x) def before_right(x, elem_op, full): return elem_op(x, full) def after_right(x, elem_op, value): + if elem_op in [relay.add, relay.subtract] and value == 0: + return x + elif elem_op in [relay.multiply, relay.divide] and ( + value == 1 or (value > 1 and dtype == "bool") + ): + return x return elem_op(x, relay.const(value, dtype)) x = relay.var("x", shape=shape, dtype=dtype) @@ -182,44 +193,133 @@ def after_right(x, elem_op, value): def test_eliminate_identity(): - def check(x, y, do_nothing=False): - after = run_opt_pass(y, transform.EliminateIdentity()) + def check(x, y=None, do_nothing=False): + expected = run_infer_type(x) if do_nothing: - assert tvm.ir.structural_equal(after, y) + actual = run_opt_pass(x, transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) else: - assert tvm.ir.structural_equal(after, x) + assert y is not None + actual = run_opt_pass(y, transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) shape = [2, 3, 4] dtype = "float32" x = relay.var("x", shape=shape, dtype=dtype) x = run_opt_pass(x, transform.InferType()) - for (op, op_like, id_op) in [ - (relay.zeros, relay.zeros_like, relay.add), - (relay.ones, relay.ones_like, relay.multiply), + for (op, op_like, id_op, const) in [ + (relay.zeros, relay.zeros_like, relay.add, relay.const(0, dtype)), + (relay.ones, relay.ones_like, relay.multiply, relay.const(1, dtype)), ]: check(x, id_op(op_like(x), x)) check(x, id_op(op(shape, dtype), x)) + check(x, id_op(const, x)) + check(x, id_op(op(shape[1:], dtype), x)) check(x, id_op(x, op_like(x))) check(x, id_op(x, op(shape, dtype))) + check(x, id_op(x, const)) check(x, id_op(x, op(shape[1:], dtype))) - check(x, id_op(x, op([2] + shape, dtype)), do_nothing=True) - check(x, id_op(op([2] + shape, dtype), x), do_nothing=True) + check(id_op(x, op([2] + shape, dtype)), do_nothing=True) + check(id_op(op([2] + shape, dtype), x), do_nothing=True) - for (op, op_like, id_op) in [ - (relay.zeros, relay.zeros_like, relay.subtract), - (relay.ones, relay.ones_like, relay.divide), + for (op, op_like, id_op, const) in [ + (relay.zeros, relay.zeros_like, relay.subtract, relay.const(0, dtype)), + (relay.ones, relay.ones_like, relay.divide, relay.const(1, dtype)), ]: check(x, id_op(x, op_like(x))) + check(x, id_op(x, const)) check(x, id_op(x, op(shape, dtype))) check(x, id_op(x, op(shape[1:], dtype))) - check(x, id_op(x, op([2] + shape, dtype)), do_nothing=True) - check(x, id_op(op(shape, dtype), x), do_nothing=True) - check(x, id_op(op_like(x), x), do_nothing=True) + check(id_op(x, op([2] + shape, dtype)), do_nothing=True) + check(id_op(const, x), id_op(op(shape, dtype), x)) + check(id_op(const, x), id_op(op_like(x), x)) + + +def test_concretize_reshape_like(): + data = relay.var("data", shape=(2, 3, 4), dtype="float32") + shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") + expr = relay.reshape_like(data, shape_like) + + expected = run_infer_type(relay.reshape(data, (6, 2, 2))) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_reshape_like_attrs(): + data = relay.var("data", shape=(2, 3, 4), dtype="float32") + shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") + expr = relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1) + + expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2))) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_zeros_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + expr = relay.zeros_like(shape_like) + + expected = run_infer_type(relay.zeros((3, 4, 5), dtype)) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_ones_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + expr = relay.ones_like(shape_like) + + expected = run_infer_type(relay.ones((3, 4, 5), dtype)) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_collapse_sum_like(): + data = relay.var("data", shape=(3, 3, 3), dtype="float32") + shape_like = relay.var("shape_like", shape=(3,), dtype="float32") + expr = relay.collapse_sum_like(data, shape_like) + + expected = run_infer_type(relay.collapse_sum_to(data, (3,))) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_broadcast_to_like(): + data = relay.var("data", shape=(3,), dtype="float32") + shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32") + expr = relay.broadcast_to_like(data, shape_like) + + expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3))) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_concretize_multiple(): + x = relay.var("x", shape=(2, 3), dtype="float32") + y = relay.var("y", shape=(3,), dtype="float32") + l = x + y + + dl = relay.ones_like(l) + dx = relay.zeros_like(x) + dy = relay.zeros_like(y) + dx = dx + relay.collapse_sum_like(dl, dx) + dy = dy + relay.collapse_sum_like(dl, dy) + ret = relay.Tuple([dx, dy]) + + dl_c = relay.ones((2, 3), "float32") + # NOTE: these are removed by EliminateIdentity + # dx_c = relay.zeros((2, 3), "float32") + # dy_c = relay.zeros((3,), "float32") + dx_c = relay.collapse_sum_to(dl_c, (2, 3)) + dy_c = relay.collapse_sum_to(dl_c, (3,)) + ret_c = relay.Tuple([dx_c, dy_c]) + + expected = run_infer_type(ret_c) + actual = run_opt_pass(ret, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) if __name__ == "__main__": - test_simplify_reshape() - test_simplify_transpose() - test_simplify_full_elementwise() - test_eliminate_identity() + pytest.main([__file__]) From ed1cba420d940c209b20ce142abf0c86ccbfc1f8 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 12:49:09 -0700 Subject: [PATCH 05/11] nits and lint --- src/relay/transforms/simplify_expr.cc | 1 + src/relay/transforms/simplify_expr.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 84470c9516fa..57fdca4aa3a9 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -31,6 +31,7 @@ #include #include +#include #include "../op/tensor/transform.h" #include "pattern_utils.h" diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index 3aa05133b536..8ea82b45544b 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -69,6 +69,7 @@ class DFPatternRewrite { protected: /*! \brief The pattern for matching and rewriting. */ DFPattern pattern_; + /*! \brief Whether or not the rewrite requires types to be inferred. */ bool require_type_ = true; }; From 0090fe2a61c75cfc73d5b13383b2065c7eb9dc7a Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 15:14:25 -0700 Subject: [PATCH 06/11] remove static stuff --- src/relay/transforms/simplify_expr.cc | 47 +++++++++------------------ src/relay/transforms/simplify_expr.h | 33 ++++++++++++------- 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 57fdca4aa3a9..19973e0f44a9 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -30,8 +30,8 @@ #include #include -#include #include +#include #include "../op/tensor/transform.h" #include "pattern_utils.h" @@ -70,8 +70,6 @@ class SimplifyReshape : public DFPatternRewrite { return post; } - TVM_DF_PATTERN_REWRITE_GETTER(SimplifyReshape) - private: /*! \brief Pattern input */ DFPattern x_; @@ -165,8 +163,6 @@ class SimplifyTranspose : public DFPatternRewrite { return x; } - TVM_DF_PATTERN_REWRITE_GETTER(SimplifyTranspose); - private: /*! \brief Pattern input */ DFPattern x_; @@ -231,8 +227,6 @@ class FullElementwise : public DFPatternRewrite { return post; } - TVM_DF_PATTERN_REWRITE_GETTER(FullElementwise); - private: /*! \brief binary argument */ DFPattern x_; @@ -315,8 +309,6 @@ class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite { DataType dtype) const override { return MakeZeros(shape, dtype); } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite); }; class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { @@ -327,8 +319,6 @@ class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { DataType dtype) const override { return MakeOnes(shape, dtype); } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite); }; class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { @@ -339,8 +329,6 @@ class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { DataType dtype) const override { return MakeReshape(node_map[data_pat_][0], shape); } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite); }; class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite { @@ -357,8 +345,6 @@ class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite { MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs)); } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite); }; class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { @@ -369,8 +355,6 @@ class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { DataType dtype) const override { return MakeBroadCastTo(node_map[data_pat_][0], shape); } - - TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite); }; /*! \brief Eliminates expressions that are equivalent to identity. */ @@ -385,8 +369,10 @@ class EliminateIdentityRewrite : public DFPatternRewrite { DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_; DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_; - DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_}); - DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_}); + // add and multiply are commutative so we don't need another pattern for reversed args + DFPattern add_id = add_op({x_, zeros_expr}); + DFPattern mul_id = mul_op({x_, ones_expr}); + DFPattern sub_id = IsOp("subtract")({x_, zeros_expr}); DFPattern div_id = IsOp("divide")({x_, ones_expr}); @@ -439,24 +425,23 @@ class EliminateIdentityRewrite : public DFPatternRewrite { return post; } - TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite); - private: DFPattern x_; DFPattern const_; }; Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { - static Array callbacks = {ConcretizeZerosLikeRewrite::GetCallback(), - ConcretizeOnesLikeRewrite::GetCallback(), - ConcretizeReshapeLikeRewrite::GetCallback(), - ConcretizeCollapseSumLikeRewrite::GetCallback(), - ConcretizeBroadcastToLikeRewrite::GetCallback(), - EliminateIdentityRewrite::GetCallback(), - SimplifyReshape::GetCallback(), - SimplifyTranspose::GetCallback(), - FullElementwise::GetCallback()}; - return RewritePatterns(callbacks, expr, mod); + DFPatternRewriteComposer composer; + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); + return RewritePatterns(composer.MakeCallbacks(), expr, mod); } namespace transform { diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index 8ea82b45544b..c0efa73f91b0 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -28,22 +28,11 @@ #include #include +#include namespace tvm { namespace relay { -/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized - * instance of RewriteType. */ -#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \ - static DFPatternRewrite* Get() { \ - static RewriteType rw; \ - return &rw; \ - } \ - static DFPatternCallback GetCallback() { \ - static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \ - return cb; \ - } - /*! \brief A wrapper class defining a rewrite matching a specific pattern. */ class DFPatternRewrite { public: @@ -73,6 +62,26 @@ class DFPatternRewrite { bool require_type_ = true; }; +/*! \brief Helper class for composing rewrites and getting callbacks. */ +class DFPatternRewriteComposer { + public: + template + inline void AddRewrite(Args... args) { + rewrites_.push_back(std::make_shared(&args...)); + } + + inline Array MakeCallbacks() const { + Array callbacks; + for (const auto rewrite : rewrites_) { + callbacks.push_back(rewrite->MakeCallback()); + } + return callbacks; + } + + private: + std::vector> rewrites_; +}; + } // namespace relay } // namespace tvm From 5700fe3cfbb723596e0255887396d32e65a50f7a Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 15:19:45 -0700 Subject: [PATCH 07/11] document --- src/relay/transforms/simplify_expr.cc | 1 + src/relay/transforms/simplify_expr.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 19973e0f44a9..f649e027b16c 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -431,6 +431,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite { }; Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { + // the rewrites will be applied in the given order, and repeated until fixed point DFPatternRewriteComposer composer; composer.AddRewrite(); composer.AddRewrite(); diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index c0efa73f91b0..952dcc87f8a0 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -79,6 +79,7 @@ class DFPatternRewriteComposer { } private: + /*! \brief the rewrites to be composed. */ std::vector> rewrites_; }; From a14142ef339d80906379a9fb3fcc6e3e5b04bd95 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 16:25:13 -0700 Subject: [PATCH 08/11] definitely ran clang-format but ok --- src/relay/transforms/simplify_expr.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index 952dcc87f8a0..b24fe1d2d6d2 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -27,8 +27,8 @@ #include #include -#include #include +#include namespace tvm { namespace relay { @@ -65,7 +65,7 @@ class DFPatternRewrite { /*! \brief Helper class for composing rewrites and getting callbacks. */ class DFPatternRewriteComposer { public: - template + template inline void AddRewrite(Args... args) { rewrites_.push_back(std::make_shared(&args...)); } From 4887dcabb4ab9c32885c3d0a4c07d369d9c7fb78 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 24 Mar 2021 21:06:08 -0700 Subject: [PATCH 09/11] make ToScalar return optional, fix missing virtual destructor --- src/relay/op/tensor/transform.cc | 9 +++-- src/relay/transforms/dynamic_to_static.cc | 28 ++++++++------ src/relay/transforms/pattern_utils.h | 45 ++++++++++++----------- src/relay/transforms/simplify_expr.cc | 10 +++-- src/relay/transforms/simplify_expr.h | 2 + 5 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b65068bd0506..1ed92f53b752 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1438,10 +1438,11 @@ bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, if ((cstart = attrs->start.as()) && (cstop = attrs->stop.as()) && (cstep = attrs->step.as())) { - double start = ToScalar(cstart->data); - double stop = ToScalar(cstop->data); - double step = ToScalar(cstep->data); - int32_t num_elem = static_cast(std::ceil((stop - start) / step)); + auto start = ToScalar(cstart->data); + auto stop = ToScalar(cstop->data); + auto step = ToScalar(cstep->data); + int32_t num_elem = + static_cast(std::ceil((stop.value() - start.value()) / step.value())); ICHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start << ", " << attrs->stop << ", " << attrs->step; reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 815e4d224cc5..cd88df99fa4c 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -60,8 +60,9 @@ class DynamicToStaticMutator : public MixedModeMutator { if (const ConstantNode* k = args[1].as()) { const TopKAttrs* param = call_node->attrs.as(); ICHECK(param); - return MakeTopK(call_node->args[0], static_cast(ToScalar(k->data, 0)), - param->axis, param->ret_type, param->is_ascend, param->dtype); + auto k_val = ToScalar(k->data, 0); + return MakeTopK(call_node->args[0], static_cast(k_val.value()), param->axis, + param->ret_type, param->is_ascend, param->dtype); } return Expr(nullptr); }}, @@ -100,9 +101,9 @@ class DynamicToStaticMutator : public MixedModeMutator { if (const ConstantNode* depth = args[3].as()) { const OneHotAttrs* param = call_node->attrs.as(); ICHECK(param); + auto depth_val = ToScalar(depth->data, 0); return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2], - static_cast(ToScalar(depth->data, 0)), param->axis, - param->dtype); + static_cast(depth_val.value()), param->axis, param->dtype); } return Expr(nullptr); }}, @@ -143,9 +144,10 @@ class DynamicToStaticMutator : public MixedModeMutator { ICHECK_EQ(scale_w->data->ndim, 0); const UpSamplingAttrs* param = call_node->attrs.as(); ICHECK(param); - return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), - ToScalar(scale_w->data), param->layout, param->method, - param->align_corners); + auto scale_h_val = ToScalar(scale_h->data); + auto scale_w_val = ToScalar(scale_w->data); + return MakeUpSampling(call_node->args[0], scale_h_val.value(), scale_w_val.value(), + param->layout, param->method, param->align_corners); } return Expr(nullptr); }}, @@ -161,10 +163,11 @@ class DynamicToStaticMutator : public MixedModeMutator { ICHECK_EQ(scale_w->data->ndim, 0); const UpSampling3DAttrs* param = call_node->attrs.as(); ICHECK(param); - - return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data), - ToScalar(scale_h->data), ToScalar(scale_w->data), - param->layout, param->method, + auto scale_d_val = ToScalar(scale_d->data); + auto scale_h_val = ToScalar(scale_h->data); + auto scale_w_val = ToScalar(scale_w->data); + return MakeUpSampling3D(call_node->args[0], scale_d_val.value(), scale_h_val.value(), + scale_w_val.value(), param->layout, param->method, param->coordinate_transformation_mode); } return Expr(nullptr); @@ -180,7 +183,8 @@ class DynamicToStaticMutator : public MixedModeMutator { const PadAttrs* param = call_node->attrs.as(); ICHECK(param); - return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data), + auto pad_fill_val = ToScalar(pad_fill->data); + return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_fill_val.value(), param->pad_mode); } return Expr(nullptr); diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index cde91e217c09..9c4051f728f8 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -27,6 +27,7 @@ #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ #include +#include #include #include #include @@ -382,43 +383,45 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \param i element index * \return Converted scalar value. */ -static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) { +static inline dmlc::optional ToScalar(const runtime::NDArray& array, size_t i = 0, bool allow_fail = false) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } } else if (array->dtype.code == kDLUInt) { if (array->dtype.bits == 1) { // bool - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } } else if (array->dtype.code == kDLFloat) { if (array->dtype.bits == 16) { - return __extendXfYf2__( - reinterpret_cast(array->data)[i]); + return dmlc::optional( + __extendXfYf2__( + reinterpret_cast(array->data)[i])); } if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[i]; + return dmlc::optional(reinterpret_cast(array->data)[i]); } } - LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); - // make compiler happy - return -std::numeric_limits::infinity(); + if (!allow_fail) { + LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); + } + return dmlc::optional(); } /*! @@ -432,8 +435,8 @@ static inline Array ToVector(const runtime::NDArray& array) { size_t len = array.Shape().front(); Array out; for (size_t i = 0; i < len; ++i) { - long double elem_val = ToScalar(array, i); - out.push_back(Integer(IntImm(DataType::Int(32), static_cast(elem_val)))); + auto elem_val = ToScalar(array, i); + out.push_back(Integer(IntImm(DataType::Int(32), static_cast(elem_val.value())))); } return out; } @@ -454,8 +457,8 @@ static inline Array> ToMatrix(const runtime::NDArray& array) { for (size_t i = 0; i < dim1; ++i) { Array inner_out; for (size_t j = 0; j < dim2; ++j) { - double elem_val = ToScalar(array, i * dim2 + j); - inner_out.push_back(Integer(static_cast(elem_val))); + auto elem_val = ToScalar(array, i * dim2 + j); + inner_out.push_back(Integer(static_cast(elem_val.value()))); } out.push_back(inner_out); } diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index f649e027b16c..f1fa5a5d4428 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -383,11 +383,15 @@ class EliminateIdentityRewrite : public DFPatternRewrite { if (!IsScalar(GetRef(constant))) { return false; } - long double value = ToScalar(constant->data); + auto value = ToScalar(constant->data, 0, true); + if (!value) { + // unsupported dtype + return false; + } if (op->name == "add" || op->name == "subtract") { - return value == 0.0; + return value.value() == 0.0; } else if (op->name == "multiply" || op->name == "divide") { - return value == 1.0; + return value.value() == 1.0; } return false; } diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h index b24fe1d2d6d2..6b3925e6b007 100644 --- a/src/relay/transforms/simplify_expr.h +++ b/src/relay/transforms/simplify_expr.h @@ -40,6 +40,8 @@ class DFPatternRewrite { virtual Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const = 0; + virtual ~DFPatternRewrite() = default; + /*! \brief Returns the pattern to be used for matching and rewriting. */ inline DFPattern Pattern() const { return pattern_; } From a2b8bd7ce991277c0a6442ad61dfbf72784ba270 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 25 Mar 2021 00:17:24 -0700 Subject: [PATCH 10/11] lint --- src/relay/transforms/pattern_utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 9c4051f728f8..da46ceeb247f 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -383,7 +383,8 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \param i element index * \return Converted scalar value. */ -static inline dmlc::optional ToScalar(const runtime::NDArray& array, size_t i = 0, bool allow_fail = false) { +static inline dmlc::optional ToScalar(const runtime::NDArray& array, size_t i = 0, + bool allow_fail = false) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { return dmlc::optional(reinterpret_cast(array->data)[i]); From 59a9900c27350eca32211fac2bb18b4bbb939177 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 30 Mar 2021 15:38:29 -0700 Subject: [PATCH 11/11] tweak scalar conversion API to maintain compatibility --- src/relay/op/tensor/transform.cc | 9 ++++---- src/relay/transforms/dynamic_to_static.cc | 27 +++++++++------------- src/relay/transforms/pattern_utils.h | 28 +++++++++++++++-------- src/relay/transforms/simplify_expr.cc | 2 +- 4 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1ed92f53b752..b65068bd0506 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1438,11 +1438,10 @@ bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, if ((cstart = attrs->start.as()) && (cstop = attrs->stop.as()) && (cstep = attrs->step.as())) { - auto start = ToScalar(cstart->data); - auto stop = ToScalar(cstop->data); - auto step = ToScalar(cstep->data); - int32_t num_elem = - static_cast(std::ceil((stop.value() - start.value()) / step.value())); + double start = ToScalar(cstart->data); + double stop = ToScalar(cstop->data); + double step = ToScalar(cstep->data); + int32_t num_elem = static_cast(std::ceil((stop - start) / step)); ICHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start << ", " << attrs->stop << ", " << attrs->step; reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index cd88df99fa4c..0590b41550ce 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -60,9 +60,8 @@ class DynamicToStaticMutator : public MixedModeMutator { if (const ConstantNode* k = args[1].as()) { const TopKAttrs* param = call_node->attrs.as(); ICHECK(param); - auto k_val = ToScalar(k->data, 0); - return MakeTopK(call_node->args[0], static_cast(k_val.value()), param->axis, - param->ret_type, param->is_ascend, param->dtype); + return MakeTopK(call_node->args[0], static_cast(ToScalar(k->data, 0)), + param->axis, param->ret_type, param->is_ascend, param->dtype); } return Expr(nullptr); }}, @@ -101,9 +100,9 @@ class DynamicToStaticMutator : public MixedModeMutator { if (const ConstantNode* depth = args[3].as()) { const OneHotAttrs* param = call_node->attrs.as(); ICHECK(param); - auto depth_val = ToScalar(depth->data, 0); return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2], - static_cast(depth_val.value()), param->axis, param->dtype); + static_cast(ToScalar(depth->data, 0)), param->axis, + param->dtype); } return Expr(nullptr); }}, @@ -144,10 +143,9 @@ class DynamicToStaticMutator : public MixedModeMutator { ICHECK_EQ(scale_w->data->ndim, 0); const UpSamplingAttrs* param = call_node->attrs.as(); ICHECK(param); - auto scale_h_val = ToScalar(scale_h->data); - auto scale_w_val = ToScalar(scale_w->data); - return MakeUpSampling(call_node->args[0], scale_h_val.value(), scale_w_val.value(), - param->layout, param->method, param->align_corners); + return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), + ToScalar(scale_w->data), param->layout, param->method, + param->align_corners); } return Expr(nullptr); }}, @@ -163,11 +161,9 @@ class DynamicToStaticMutator : public MixedModeMutator { ICHECK_EQ(scale_w->data->ndim, 0); const UpSampling3DAttrs* param = call_node->attrs.as(); ICHECK(param); - auto scale_d_val = ToScalar(scale_d->data); - auto scale_h_val = ToScalar(scale_h->data); - auto scale_w_val = ToScalar(scale_w->data); - return MakeUpSampling3D(call_node->args[0], scale_d_val.value(), scale_h_val.value(), - scale_w_val.value(), param->layout, param->method, + return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data), + ToScalar(scale_h->data), ToScalar(scale_w->data), + param->layout, param->method, param->coordinate_transformation_mode); } return Expr(nullptr); @@ -183,8 +179,7 @@ class DynamicToStaticMutator : public MixedModeMutator { const PadAttrs* param = call_node->attrs.as(); ICHECK(param); - auto pad_fill_val = ToScalar(pad_fill->data); - return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_fill_val.value(), + return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data), param->pad_mode); } return Expr(nullptr); diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index da46ceeb247f..8d9f723dffea 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -381,10 +381,9 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \brief Convert an element of a NDArray with type int or float to scalar. * \param array Input NDArray * \param i element index - * \return Converted scalar value. + * \return Converted scalar value, or None if conversion failed */ -static inline dmlc::optional ToScalar(const runtime::NDArray& array, size_t i = 0, - bool allow_fail = false) { +static inline dmlc::optional TryToScalar(const runtime::NDArray& array, size_t i = 0) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { return dmlc::optional(reinterpret_cast(array->data)[i]); @@ -419,12 +418,21 @@ static inline dmlc::optional ToScalar(const runtime::NDArray& array return dmlc::optional(reinterpret_cast(array->data)[i]); } } - if (!allow_fail) { - LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); - } return dmlc::optional(); } +/*! + * \brief Convert an element of a NDArray with type int or float to scalar. + * \param array Input NDArray + * \param i element index + * \return Converted scalar value + */ +static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) { + auto try_value = TryToScalar(array, i); + ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); + return try_value.value(); +} + /*! * \brief Convert a NDArray with type int or float to Array. * \param array Input NDArray @@ -436,8 +444,8 @@ static inline Array ToVector(const runtime::NDArray& array) { size_t len = array.Shape().front(); Array out; for (size_t i = 0; i < len; ++i) { - auto elem_val = ToScalar(array, i); - out.push_back(Integer(IntImm(DataType::Int(32), static_cast(elem_val.value())))); + long double elem_val = ToScalar(array, i); + out.push_back(Integer(IntImm(DataType::Int(32), static_cast(elem_val)))); } return out; } @@ -458,8 +466,8 @@ static inline Array> ToMatrix(const runtime::NDArray& array) { for (size_t i = 0; i < dim1; ++i) { Array inner_out; for (size_t j = 0; j < dim2; ++j) { - auto elem_val = ToScalar(array, i * dim2 + j); - inner_out.push_back(Integer(static_cast(elem_val.value()))); + double elem_val = ToScalar(array, i * dim2 + j); + inner_out.push_back(Integer(static_cast(elem_val))); } out.push_back(inner_out); } diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index f1fa5a5d4428..762aa58f7298 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -383,7 +383,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite { if (!IsScalar(GetRef(constant))) { return false; } - auto value = ToScalar(constant->data, 0, true); + auto value = TryToScalar(constant->data, 0); if (!value) { // unsupported dtype return false;