From d482a6201dbcf6d2d97da2af93d2ad6d12da4c17 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 31 Oct 2022 13:14:23 +0300 Subject: [PATCH] [Relay] Add ClipAndConsecutiveCast and CastClip to SimplifyExpr This commit adds SimplifyClipAndConsecutiveCast and SimplifyCastClip to SimplifyExpr Relay pass. These simplify sequence clip->cast->cast and cast->clip based on Clip min/max attributes and Cast target data type. 1) SimplifyClipAndConsecutiveCast example: %0 == [type=int32] %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] %2 = cast(%1, dtype="uint8") [type=uint8] %3 = cast(%2, dtype="int32") [type=int32] --> Here Clip dtype == Cast2 dtype and max_value("uint8") == 255 min_value("uint8") == 0 Optimized sequence (both casts can be removed): %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] 2) SimplifyCastClip example: %1 = cast(%0, dtype="uint8") [type=uint8] %2 = clip(%1, a_min=0f, a_max=255f) [type=uint8] Optimized sequence (remove Clip): %1 = cast(%0, dtype="uint8") [type=uint8] --- src/relay/transforms/simplify_expr.cc | 102 ++++++++++++++++++ tests/python/relay/test_pass_simplify_expr.py | 31 ++++++ 2 files changed, 133 insertions(+) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index cf594a09a266..89d9e07d3f1d 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -159,6 +159,106 @@ class SimplifyConsecutiveCast : public DFPatternRewrite { DFPattern cast1_; }; +bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) { + if (dtype.is_int() || dtype.is_uint()) { + double ubound = static_cast(Downcast(tvm::max_value(dtype))->value); + double lbound = static_cast(Downcast(tvm::min_value(dtype))->value); + return ubound == max_value && lbound == min_value; + } else if (dtype.is_float()) { + double ubound = Downcast(tvm::max_value(dtype))->value; + double lbound = Downcast(tvm::min_value(dtype))->value; + return ubound == max_value && lbound == 
min_value; + } + + return false; +} + +/*! + * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast and removes redundant + * casts. + * Analysis of "redundancy" is done based on clip min/max values and min/max values of casted data + * type. + */ +class SimplifyClipAndConsecutiveCast : public DFPatternRewrite { + public: + SimplifyClipAndConsecutiveCast() { + clip_ = IsOp("clip")({IsWildcard()}); + cast1_ = IsOp("cast")({clip_}); + pattern_ = IsOp("cast")({cast1_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto clip = Downcast(node_map[clip_][0]); + const CallNode* clip_node = clip.as(); + const ClipAttrs* clip_attrs = clip_node->attrs.as(); + DataType clip_dtype = Downcast(clip->checked_type())->dtype; + + auto cast1 = Downcast(node_map[cast1_][0]); + DataType cast1_dtype = Downcast(cast1->checked_type())->dtype; + + auto cast2 = Downcast(post); + DataType cast2_dtype = Downcast(cast2->checked_type())->dtype; + + if (clip_dtype == cast2_dtype && + CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, clip_attrs->a_max)) { + // Case 1: + // Data type of Clip == target data type of second Cast and min/max value of Clip == min/max + // value of first Cast target data type. In this case both Cast ops can be removed. + // Example: + // %0 == [type=int32] + // %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] + // %2 = cast(%1, dtype="uint8") [type=uint8] + // %3 = cast(%2, dtype="int32") [type=int32] + // + // Optimized to (both casts can be removed): + // %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] + return node_map[clip_][0]; + } + return post; + } + + protected: + DFPattern clip_, cast1_; +}; + +/*! + * \brief SimplifyCastClip matches the pattern cast->clip and removes redundant Clip based on Clip + * min/max values and min/max values of Cast target data type. 
+ * + * Example: + * %1 = cast(%0, dtype="uint8") [type=uint8] + * %2 = clip(%1, a_min=0f, a_max=255f) [type=uint8] + * + * Optimized to (remove Clip): + * %1 = cast(%0, dtype="uint8") [type=uint8] + */ +class SimplifyCastClip : public DFPatternRewrite { + public: + SimplifyCastClip() { + cast_ = IsOp("cast")({IsWildcard()}); + pattern_ = IsOp("clip")({cast_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto cast = Downcast(node_map[cast_][0]); + DataType cast_dtype = Downcast(cast->checked_type())->dtype; + + auto clip = Downcast(post); + const CallNode* clip_node = clip.as(); + const ClipAttrs* clip_attrs = clip_node->attrs.as(); + + if (CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max)) { + return node_map[cast_][0]; + } + return post; + } + + protected: + DFPattern clip_, cast_; +}; + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. @@ -804,6 +904,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); } diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index e84d238aaa75..b8275d467338 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -669,5 +669,36 @@ def expected(): assert tvm.ir.structural_equal(opt, after) +def test_simplify_clip_cast(): + x = relay.var("x", shape=(4, 8), dtype="int32") + + def before(): + clip = relay.clip(x, a_min=0.0, a_max=255.0) + cast = relay.cast(clip, "uint8") + return relay.cast(cast, "int32") + + def expected(): + return relay.clip(x, a_min=0.0, a_max=255.0) + + opt = run_opt_pass(before(), transform.SimplifyExpr()) + ref = run_infer_type(expected()) + assert tvm.ir.structural_equal(opt, 
ref) + + +def test_simplify_cast_clip(): + x = relay.var("x", shape=(4, 8), dtype="int32") + + def before(): + cast = relay.cast(x, "uint8") + return relay.clip(cast, a_min=0.0, a_max=255.0) + + def expected(): + return relay.cast(x, "uint8") + + opt = run_opt_pass(before(), transform.SimplifyExpr()) + ref = run_infer_type(expected()) + assert tvm.ir.structural_equal(opt, ref) + + if __name__ == "__main__": pytest.main([__file__])