diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index f569e31cd8e0..209639dd8f83 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -78,11 +78,11 @@ class SimplifyReshape : public DFPatternRewrite { }; /*! - * \brief SimplifyCast matches the pattern of cast data to the same dtype. + * \brief SimplifySameCast matches the pattern of cast data to the same dtype. */ -class SimplifyCast : public DFPatternRewrite { +class SimplifySameCast : public DFPatternRewrite { public: - SimplifyCast() { + SimplifySameCast() { data_pat_ = IsWildcard(); like_pat_ = IsWildcard(); pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_}); @@ -104,6 +104,60 @@ class SimplifyCast : public DFPatternRewrite { DFPattern like_pat_; }; +/*! + * \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops + */ +class SimplifyConsecutiveCast : public DFPatternRewrite { + public: + SimplifyConsecutiveCast() { + data_ = IsWildcard(); + cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_}); + pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto data = node_map[data_][0]; + auto cast1 = Downcast(node_map[cast1_][0]); + auto data_type = Downcast(data->checked_type()); + DataType cast1_dtype = Downcast(cast1->checked_type())->dtype; + + if (!IsWidenCast(data_type->dtype, cast1_dtype)) { + // Cannot remove the narrow cast + return post; + } + + const CallNode* cast2 = post.as(); + DataType cast2_dtype = Downcast(cast2->checked_type())->dtype; + auto expr = MakeCast(data, cast2_dtype); + + // We need to set the checked type as it may be needed in the next callback + expr->checked_type_ = TensorType(data_type->shape, cast2_dtype); + return expr; + } + + bool IsWidenCast(DataType origin, DataType cast) const { + /* Return whether casting from origin to cast results in more or the same precision.*/ + if (origin.code() == cast.code() && origin.bits() <= cast.bits()) { + return true; + } + if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) { + // BFloat cast cannot be omitted + return false; + } + if (origin.code() < cast.code()) { + // Loosely have a hiearchy to datatypes + // e.g. int --> uint --> float has increasing range of numbers they can represent + return true; + } + return false; + } + + protected: + DFPattern data_; + DFPattern cast1_; +}; + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. @@ -640,7 +694,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); - composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index b88e0905dba5..47495aaa16c8 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -211,7 +211,7 @@ def test_byoc_microtvm(merge_compiler_regions): "model_url, usmp_algo, workspace_size,", [ (MOBILENET_V1_URL, "greedy_by_size", 4845696), - (MOBILENET_V1_URL, "greedy_by_conflicts", 4845696), + (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288), ], ) def test_tflite_model(model_url, usmp_algo, workspace_size): diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index f9367d5fd567..162ac6e73ddb 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -402,7 +402,7 @@ def check(x, y=None, do_nothing=False): check(id_op(const, x), id_op(op_like(x), x)) -def test_simplify_cast(): +def test_simplify_same_cast(): dtype = "int32" data = relay.var("data", shape=(3, 4, 5), dtype=dtype) expr1 = relay.cast(data, dtype) @@ -416,6 +416,36 @@ def test_simplify_cast(): assert tvm.ir.structural_equal(actual2, expected) +def test_simplify_consecutive_cast(): + x = relay.var("x", shape=(3, 4, 5), dtype="int8") + y = relay.var("y", shape=(3, 4), dtype="int64") + z = relay.var("z", shape=(3,), dtype="float32") + + expr1 = relay.cast(x, "int16") + expr2 = relay.cast(expr1, "int32") + expr3 = relay.cast_like(expr2, y) + expr4 = relay.cast_like(expr3, z) + + actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "int32")) + assert tvm.ir.structural_equal(actual1, expected) + actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "int64")) + assert tvm.ir.structural_equal(actual2, expected) + actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr()) + expected = run_infer_type(relay.cast(x, "float32")) + assert tvm.ir.structural_equal(actual3, expected) + + # cannot simplify the narrow cast + x = relay.var("x", shape=(3, 4, 5), dtype="float32") + y = relay.var("y", shape=(3, 4), dtype="float32") + expr1 = relay.cast(x, "int32") + expr2 = relay.cast_like(expr1, y) + actual = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + expected = run_infer_type(expr2) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_reshape_like(): data = relay.var("data", shape=(2, 3, 4), dtype="float32") shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")