diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 463f76995436..7a4a9bd9fae8 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -565,6 +565,36 @@ class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite { } }; +/*! + * \brief Converts cast_like operator to cast. Not inheriting from ConcretizeLikeRewrite + * because even if shape is not static, still can concretize. + */ +class ConcretizeCastLikeRewrite : public DFPatternRewrite { + public: + ConcretizeCastLikeRewrite() { + data_pat_ = IsWildcard(); + like_pat_ = IsWildcard(); + pattern_ = IsOp("cast_like")({data_pat_, like_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call_node = pre.as(); + ICHECK(call_node); + + if (!call_node->checked_type().as()) { + return post; + } + + const TensorTypeNode* like_ty = pre->checked_type().as(); + return MakeCast(node_map[data_pat_][0], like_ty->dtype); + } + + protected: + DFPattern data_pat_; + DFPattern like_pat_; +}; + /*! \brief Eliminates expressions that are equivalent to identity. */ class EliminateIdentityRewrite : public DFPatternRewrite { public: @@ -762,6 +792,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index dcd58602b0ac..16d5efe10c44 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -442,7 +442,7 @@ def test_simplify_consecutive_cast(): expr1 = relay.cast(x, "int32") expr2 = relay.cast_like(expr1, y) actual = run_opt_pass(expr2, relay.transform.SimplifyExpr()) - expected = run_infer_type(expr2) + expected = run_infer_type(relay.cast(expr1, "float32")) assert tvm.ir.structural_equal(actual, expected) @@ -517,6 +517,17 @@ def test_concretize_broadcast_to_like(): assert tvm.ir.structural_equal(actual, expected) +def test_concretize_cast_like(): + dim_any = tvm.tir.Any() + data = relay.var("data", shape=(3, dim_any, 5), dtype="float32") + dtype_like = relay.var("dtype_like", shape=(dim_any, 3, 3), dtype="int32") + expr = relay.cast_like(data, dtype_like) + + expected = run_infer_type(relay.cast(data, "int32")) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_multiple(): x = relay.var("x", shape=(2, 3), dtype="float32") y = relay.var("y", shape=(3,), dtype="float32")