diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 815e4d224cc5..0590b41550ce 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -161,7 +161,6 @@ class DynamicToStaticMutator : public MixedModeMutator {
       ICHECK_EQ(scale_w->data->ndim, 0);
       const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
       ICHECK(param);
-
       return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data), ToScalar(scale_h->data),
                               ToScalar(scale_w->data), param->layout, param->method,
                               param->coordinate_transformation_mode);
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index c1eebde15fba..8d9f723dffea 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -27,6 +27,7 @@
 #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
 
 #include <builtin_fp16.h>
+#include <dmlc/optional.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/nn.h>
@@ -380,43 +381,56 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \brief Convert an element of a NDArray with type int or float to scalar.
  * \param array Input NDArray
  * \param i element index
- * \return Converted scalar value.
+ * \return Converted scalar value, or None if conversion failed
  */
-static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
-      return reinterpret_cast<int8_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<int16_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<int32_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<int64_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLUInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<uint8_t*>(array->data)[i];
+    if (array->dtype.bits == 1) {  // bool
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+    } else if (array->dtype.bits == 8) {
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<uint16_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<uint32_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<uint64_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLFloat) {
     if (array->dtype.bits == 16) {
-      return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
-          reinterpret_cast<uint16_t*>(array->data)[i]);
+      return dmlc::optional<long double>(
+          __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
+              reinterpret_cast<uint16_t*>(array->data)[i]));
     }
     if (array->dtype.bits == 32) {
-      return reinterpret_cast<float*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<double*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
     }
   }
-  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
-  // make compiler happy
-  return -std::numeric_limits<long double>::infinity();
+  return dmlc::optional<long double>();
+}
+
+/*!
+ * \brief Convert an element of a NDArray with type int or float to scalar.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value
+ */
+static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+  auto try_value = TryToScalar(array, i);
+  ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  return try_value.value();
 }
 
 /*!
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index b4f4cc16e9df..762aa58f7298 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -22,35 +22,28 @@
  * \brief A pass for simplifying the Relay expression.
  */
 
+#include "simplify_expr.h"
+
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/transform.h>
+
+#include <limits>
+#include <memory>
+
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
-class SimplifyPattern {
- public:
-  virtual Expr callback(const Expr& pre, const Expr& post,
-                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
-
-  DFPattern pattern() const { return pattern_; }
-
- protected:
-  /*! \brief Pattern for rewriting */
-  DFPattern pattern_;
-};
-
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  * and merges into one reshape op.
  */
-class SimplifyReshape : public SimplifyPattern {
+class SimplifyReshape : public DFPatternRewrite {
  public:
   SimplifyReshape() {
     x_ = IsWildcard();
@@ -59,7 +52,7 @@ class SimplifyReshape : public SimplifyPattern {
     pattern_ = reshape1({reshape2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     auto x = node_map[x_][0];
     bool const_shape = true;
@@ -86,7 +79,7 @@ class SimplifyReshape : public SimplifyPattern {
  * \brief SimplifyTranspose matches the pattern of consecutive transpose op,
  * and merges or cancels them.
  */
-class SimplifyTranspose : public SimplifyPattern {
+class SimplifyTranspose : public DFPatternRewrite {
  public:
   SimplifyTranspose() {
     x_ = IsWildcard();
@@ -95,7 +88,7 @@ class SimplifyTranspose : public SimplifyPattern {
     pattern_ = trans1({trans2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     // Helper function to get the axes from call node attribute
     auto get_axes_from_call = [](const Call trans_call, int ndim) {
@@ -176,9 +169,10 @@ class SimplifyTranspose : public SimplifyPattern {
 };
 
 /*!
- * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
+ * \brief FullElementwise finds full-like ops followed by broadcasting elementwise ops, and
+ * eliminates the full op by directly passing the fill value into the elementwise op.
  */
-class FullElementwise : public SimplifyPattern {
+class FullElementwise : public DFPatternRewrite {
  public:
   FullElementwise() {
     x_ = IsWildcard();
@@ -196,7 +190,7 @@ class FullElementwise : public SimplifyPattern {
     pattern_ = op({full, x_}) || op({x_, full});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     const CallNode* call = pre.as<CallNode>();
     ICHECK(call);
@@ -249,36 +243,210 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)`
+ * to `zeros(x, y.shape)`), when the target shape is concrete.
+ * This removes unnecessary dependencies and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
+  }
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    // add and multiply are commutative, so we don't need another pattern for reversed args
+    DFPattern add_id = add_op({x_, zeros_expr});
+    DFPattern mul_id = mul_op({x_, ones_expr});
+
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {
+      return false;
+    }
+    auto value = TryToScalar(constant->data, 0);
+    if (!value) {
+      // unsupported dtype
+      return false;
+    }
+    if (op->name == "add" || op->name == "subtract") {
+      return value.value() == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {
+      return value.value() == 1.0;
+    }
+    return false;
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      return x;
+    }
+
+    return post;
+  }
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;
+  DFPattern const_;
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  // the rewrites will be applied in the given order, and repeated until fixed point
+  DFPatternRewriteComposer composer;
+  composer.AddRewrite<ConcretizeZerosLikeRewrite>();
+  composer.AddRewrite<ConcretizeOnesLikeRewrite>();
+  composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
+  composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
+  composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
+  composer.AddRewrite<EliminateIdentityRewrite>();
+  composer.AddRewrite<SimplifyReshape>();
+  composer.AddRewrite<SimplifyTranspose>();
+  composer.AddRewrite<FullElementwise>();
+  return RewritePatterns(composer.MakeCallbacks(), expr, mod);
 }
 
 namespace transform {
diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h
new file mode 100644
index 000000000000..6b3925e6b007
--- /dev/null
+++ b/src/relay/transforms/simplify_expr.h
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/dataflow_pattern.h>
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief A wrapper class defining a rewrite matching a specific pattern. */
+class DFPatternRewrite {
+ public:
+  /*! \brief Returns the rewritten expression. */
+  virtual Expr Callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+
+  virtual ~DFPatternRewrite() = default;
+
+  /*! \brief Returns the pattern to be used for matching and rewriting. */
+  inline DFPattern Pattern() const { return pattern_; }
+
+  inline bool RequireType() const { return require_type_; }
+
+  inline DFPatternCallback MakeCallback() const {
+    auto func = [this](TVMArgs args, TVMRetValue* rv) {
+      Expr pre = args[0];
+      Expr post = args[1];
+      Map<DFPattern, Array<Expr>> node_map = args[2];
+      *rv = this->Callback(pre, post, node_map);
+    };
+    return DFPatternCallback(pattern_, PackedFunc(func), require_type_);
+  }
+
+ protected:
+  /*! \brief The pattern for matching and rewriting. */
+  DFPattern pattern_;
+  /*! \brief Whether or not the rewrite requires types to be inferred. */
+  bool require_type_ = true;
+};
+
+/*! \brief Helper class for composing rewrites and getting callbacks. */
+class DFPatternRewriteComposer {
+ public:
+  template <typename T, typename... Args>
+  inline void AddRewrite(Args... args) {
+    rewrites_.push_back(std::make_shared<T>(args...));
+  }
+
+  inline Array<DFPatternCallback> MakeCallbacks() const {
+    Array<DFPatternCallback> callbacks;
+    for (const auto& rewrite : rewrites_) {
+      callbacks.push_back(rewrite->MakeCallback());
+    }
+    return callbacks;
+  }
+
+ private:
+  /*! \brief The rewrites to be composed. */
+  std::vector<std::shared_ptr<DFPatternRewrite>> rewrites_;
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 897f90b9ee2a..d015cdd36c2d 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 import tvm
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import run_opt_pass
+from tvm.relay.testing import run_opt_pass, run_infer_type
 
 import numpy as np
 
@@ -123,12 +124,22 @@ def before_left(x, elem_op, full):
             return elem_op(full, x)
 
         def after_left(x, elem_op, value):
+            if elem_op == relay.add and value == 0:
+                return x
+            elif elem_op == relay.multiply and (value == 1 or (value > 1 and dtype == "bool")):
+                return x
             return elem_op(relay.const(value, dtype), x)
 
         def before_right(x, elem_op, full):
             return elem_op(x, full)
 
         def after_right(x, elem_op, value):
+            if elem_op in [relay.add, relay.subtract] and value == 0:
+                return x
+            elif elem_op in [relay.multiply, relay.divide] and (
+                value == 1 or (value > 1 and dtype == "bool")
+            ):
+                return x
             return elem_op(x, relay.const(value, dtype))
 
         x = relay.var("x", shape=shape, dtype=dtype)
@@ -181,7 +192,134 @@ def after_right(x, elem_op, value):
                 validate(shape, value, dtype)
 
 
+def test_eliminate_identity():
+    def check(x, y=None, do_nothing=False):
+        expected = run_infer_type(x)
+        if do_nothing:
+            actual = run_opt_pass(x, transform.SimplifyExpr())
+            assert tvm.ir.structural_equal(actual, expected)
+        else:
+            assert y is not None
+            actual = run_opt_pass(y, transform.SimplifyExpr())
+            assert tvm.ir.structural_equal(actual, expected)
+
+    shape = [2, 3, 4]
+    dtype = "float32"
+    x = relay.var("x", shape=shape, dtype=dtype)
+    x = run_opt_pass(x, transform.InferType())
+
+    for (op, op_like, id_op, const) in [
+        (relay.zeros, relay.zeros_like, relay.add, relay.const(0, dtype)),
+        (relay.ones, relay.ones_like, relay.multiply, relay.const(1, dtype)),
+    ]:
+        check(x, id_op(op_like(x), x))
+        check(x, id_op(op(shape, dtype), x))
+        check(x, id_op(const, x))
+        check(x, id_op(op(shape[1:], dtype), x))
+        check(x, id_op(x, op_like(x)))
+        check(x, id_op(x, op(shape, dtype)))
+        check(x, id_op(x, const))
+        check(x, id_op(x, op(shape[1:], dtype)))
+        check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
+        check(id_op(op([2] + shape, dtype), x), do_nothing=True)
+
+    for (op, op_like, id_op, const) in [
+        (relay.zeros, relay.zeros_like, relay.subtract, relay.const(0, dtype)),
+        (relay.ones, relay.ones_like, relay.divide, relay.const(1, dtype)),
+    ]:
+        check(x, id_op(x, op_like(x)))
+        check(x, id_op(x, const))
+        check(x, id_op(x, op(shape, dtype)))
+        check(x, id_op(x, op(shape[1:], dtype)))
+        check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
+        check(id_op(const, x), id_op(op(shape, dtype), x))
+        check(id_op(const, x), id_op(op_like(x), x))
+
+
+def test_concretize_reshape_like():
+    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+    expr = relay.reshape_like(data, shape_like)
+
+    expected = run_infer_type(relay.reshape(data, (6, 2, 2)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_reshape_like_attrs():
+    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+    expr = relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1)
+
+    expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_zeros_like():
+    dtype = "int32"
+    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+    expr = relay.zeros_like(shape_like)
+
+    expected = run_infer_type(relay.zeros((3, 4, 5), dtype))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_ones_like():
+    dtype = "int32"
+    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+    expr = relay.ones_like(shape_like)
+
+    expected = run_infer_type(relay.ones((3, 4, 5), dtype))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_collapse_sum_like():
+    data = relay.var("data", shape=(3, 3, 3), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(3,), dtype="float32")
+    expr = relay.collapse_sum_like(data, shape_like)
+
+    expected = run_infer_type(relay.collapse_sum_to(data, (3,)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_broadcast_to_like():
+    data = relay.var("data", shape=(3,), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32")
+    expr = relay.broadcast_to_like(data, shape_like)
+
+    expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_multiple():
+    x = relay.var("x", shape=(2, 3), dtype="float32")
+    y = relay.var("y", shape=(3,), dtype="float32")
+    l = x + y
+
+    dl = relay.ones_like(l)
+    dx = relay.zeros_like(x)
+    dy = relay.zeros_like(y)
+    dx = dx + relay.collapse_sum_like(dl, dx)
+    dy = dy + relay.collapse_sum_like(dl, dy)
+    ret = relay.Tuple([dx, dy])
+
+    dl_c = relay.ones((2, 3), "float32")
+    # NOTE: these are removed by EliminateIdentity
+    # dx_c = relay.zeros((2, 3), "float32")
+    # dy_c = relay.zeros((3,), "float32")
+    dx_c = relay.collapse_sum_to(dl_c, (2, 3))
+    dy_c = relay.collapse_sum_to(dl_c, (3,))
+    ret_c = relay.Tuple([dx_c, dy_c])
+
+    expected = run_infer_type(ret_c)
+    actual = run_opt_pass(ret, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
 if __name__ == "__main__":
-    test_simplify_reshape()
-    test_simplify_transpose()
-    test_simplify_full_elementwise()
+    pytest.main([__file__])
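Note (not part of the patch): for reviewers who want to see the new rewrites interact in one place, below is a minimal sketch that mirrors the pattern of test_eliminate_identity above, using only helpers the modified test file already imports (run_opt_pass, run_infer_type from tvm.relay.testing). ConcretizeZerosLike first turns zeros_like(x) into zeros with a static shape, and EliminateIdentity then drops the add-with-zeros, so the whole expression should simplify back to x.

    import tvm
    from tvm import relay
    from tvm.relay import transform
    from tvm.relay.testing import run_opt_pass, run_infer_type

    # Build a typed variable and the expression zeros_like(x) + x.
    x = relay.var("x", shape=(2, 3), dtype="float32")
    x = run_opt_pass(x, transform.InferType())
    expr = relay.add(relay.zeros_like(x), x)

    # After SimplifyExpr, the whole expression should reduce to x itself.
    expected = run_infer_type(x)
    actual = run_opt_pass(expr, transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)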