From 8b0a4913ff5dc3323eedcfb24a15d2f0d82818e4 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Wed, 19 May 2021 21:32:49 +0000
Subject: [PATCH 1/3] [Pass] Simplify consecutive casts in Relay

---
 src/relay/transforms/simplify_expr.cc         | 39 +++++++++++++++++--
 tests/python/relay/test_pass_simplify_expr.py | 23 ++++++++++-
 2 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index fb7a76f1ea7a..d79954818430 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -78,11 +78,11 @@ class SimplifyReshape : public DFPatternRewrite {
 };
 
 /*!
- * \brief SimplifyCast matches the pattern of cast data to the same dtype.
+ * \brief SimplifySameCast matches the pattern of casting data to the same dtype.
  */
-class SimplifyCast : public DFPatternRewrite {
+class SimplifySameCast : public DFPatternRewrite {
  public:
-  SimplifyCast() {
+  SimplifySameCast() {
     data_pat_ = IsWildcard();
     like_pat_ = IsWildcard();
     pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
@@ -104,6 +104,36 @@ class SimplifyCast : public DFPatternRewrite {
   DFPattern like_pat_;
 };
 
+/*!
+ * \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops.
+ */
+class SimplifyConsecutiveCast : public DFPatternRewrite {
+ public:
+  SimplifyConsecutiveCast() {
+    data_ = IsWildcard();
+    auto cast1 = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
+    pattern_ = IsOp("cast_like")({cast1, IsWildcard()}) || IsOp("cast")({cast1});
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    static const Op& cast_op = Op::Get("cast");
+    static const Op& cast_like_op = Op::Get("cast_like");
+    auto data = node_map[data_][0];
+    const CallNode* call = post.as<CallNode>();
+    if (call->op == cast_op) {
+      auto attr = call->attrs.as<CastAttrs>();
+      CHECK(attr);
+      return MakeCast(data, attr->dtype);
+    }
+    // cast_like op
+    return Call(cast_like_op, {data, call->args[1]}, Attrs(), {});
+  }
+
+ protected:
+  DFPattern data_;
+};
+
 /*!
  * \brief SimplifyTranspose matches the pattern of consecutive transpose op,
  * and merges or cancels them.
@@ -597,7 +627,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
   composer.AddRewrite<EliminateIdentityRewrite>();
   composer.AddRewrite<SimplifyReshape>();
   composer.AddRewrite<SimplifyTranspose>();
-  composer.AddRewrite<SimplifyCast>();
+  composer.AddRewrite<SimplifySameCast>();
+  composer.AddRewrite<SimplifyConsecutiveCast>();
   composer.AddRewrite<FullElementwise>();
   return RewritePatterns(composer.MakeCallbacks(), expr, mod);
 }
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 9f11d3827064..0d5c3d0d269f 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -402,7 +402,7 @@ def check(x, y=None, do_nothing=False):
     check(id_op(const, x), id_op(op_like(x), x))
 
 
-def test_simplify_cast():
+def test_simplify_same_cast():
     dtype = "int32"
     data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
     expr1 = relay.cast(data, dtype)
@@ -416,6 +416,27 @@ def test_simplify_cast():
     assert tvm.ir.structural_equal(actual2, expected)
 
 
+def test_simplify_consecutive_cast():
+    dtype = "int32"
+    x = relay.var("x", shape=(3, 4, 5), dtype="int32")
+    y = relay.var("y", shape=(3, 4), dtype="int8")
+    z = relay.var("z", shape=(3,), dtype="float32")
+    expr1 = relay.cast(x, "int64")
+    expr2 = relay.cast(expr1, "int16")
+    expr3 = relay.cast_like(expr2, y)
+    expr4 = relay.cast_like(expr3, z)
+
+    actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
+    expected = run_infer_type(relay.cast(x, "int16"))
+    assert tvm.ir.structural_equal(actual1, expected)
+    actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr())
+    expected = run_infer_type(relay.cast_like(x, y))
+    assert tvm.ir.structural_equal(actual2, expected)
+    actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr())
+    expected = run_infer_type(relay.cast_like(x, z))
+    assert tvm.ir.structural_equal(actual3, expected)
+
+
 def test_concretize_reshape_like():
     data = relay.var("data", shape=(2, 3, 4), dtype="float32")
     shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")

From 960c49c8920caa15872bb33c94fcde7e8e6414c8 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Fri, 21 May 2021 20:53:51 +0000
Subject: [PATCH 2/3] fix bug

---
 src/relay/transforms/simplify_expr.cc         | 51 +++++++++++++++----
 tests/python/frontend/pytorch/test_forward.py |  5 +-
 tests/python/relay/test_pass_simplify_expr.py | 24 ++++++---
 3 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index d79954818430..780b86f2ff01 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -111,27 +111,60 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
  public:
   SimplifyConsecutiveCast() {
     data_ = IsWildcard();
-    auto cast1 = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
-    pattern_ = IsOp("cast_like")({cast1, IsWildcard()}) || IsOp("cast")({cast1});
+    cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
+    pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_});
   }
 
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     static const Op& cast_op = Op::Get("cast");
-    static const Op& cast_like_op = Op::Get("cast_like");
     auto data = node_map[data_][0];
-    const CallNode* call = post.as<CallNode>();
-    if (call->op == cast_op) {
-      auto attr = call->attrs.as<CastAttrs>();
+    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
+    auto data_type = Downcast<TensorType>(data->checked_type());
+    DataType cast1_dtype;
+    if (cast1->op == cast_op) {
+      auto attr = cast1->attrs.as<CastAttrs>();
       CHECK(attr);
-      return MakeCast(data, attr->dtype);
+      cast1_dtype = attr->dtype;
+    } else {  // cast_like
+      cast1_dtype = Downcast<TensorType>(cast1->args[1]->checked_type())->dtype;
     }
-    // cast_like op
-    return Call(cast_like_op, {data, call->args[1]}, Attrs(), {});
+    if (!IsWidenCast(data_type->dtype, cast1_dtype)) {
+      // Cannot remove the narrow cast
+      return post;
+    }
+    const CallNode* cast2 = post.as<CallNode>();
+    DataType cast2_dtype;
+    if (cast2->op == cast_op) {
+      auto attr = cast2->attrs.as<CastAttrs>();
+      CHECK(attr);
+      cast2_dtype = attr->dtype;
+    } else {  // cast_like
+      cast2_dtype = Downcast<TensorType>(cast2->args[1]->checked_type())->dtype;
+    }
+    auto expr = MakeCast(data, cast2_dtype);
+    // We need to set the checked type as it may be needed in the next callback
+    expr->checked_type_ = TensorType(data_type->shape, cast2_dtype);
+    return expr;
+  }
+
+  bool IsWidenCast(DataType origin, DataType cast) const {
+    if (origin.code() == cast.code() && origin.bits() <= cast.bits()) {
+      return true;
+    }
+    if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) {
+      // BFloat cast cannot be omitted
+      return false;
+    }
+    if (origin.code() < cast.code()) {
+      return true;
+    }
+    return false;
   }
 
  protected:
   DFPattern data_;
+  DFPattern cast1_;
 };
 
 /*!
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 07f0d8e75c4d..fbf85fe641da 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -205,6 +205,7 @@ def verify_model(
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+    print(mod["main"])
     for arg in mod["main"].params[: len(input_names)]:
         assert arg.name_hint in input_names
     compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input]))
@@ -3710,6 +3711,7 @@ def model_fn(x, y):
     input_y = torch.randint(low=0, high=100, size=ishape, dtype=torch.int32)
     inputs = [input_x, input_y]
     script_module = torch.jit.trace(model_fn, inputs)
+    print(script_module)
     fname = "tmp.pt"
     torch.jit.save(script_module, fname)
@@ -3867,6 +3869,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
 
 
 if __name__ == "__main__":
+    """
     # some structural tests
     test_forward_traced_function()
     test_forward_dtypes()
@@ -4042,6 +4045,6 @@ def test_fn(is_sorted, return_inverse, return_counts):
 
     # Test bert model
     test_forward_pretrained_bert_base_uncased()
-
+    """
     # Test convert torch script(jit) with specific inputs' types
     test_convert_torch_script_with_input_types()
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 0d5c3d0d269f..1734af1b5518 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -417,25 +417,33 @@ def test_simplify_same_cast():
 
 
 def test_simplify_consecutive_cast():
-    dtype = "int32"
-    x = relay.var("x", shape=(3, 4, 5), dtype="int32")
-    y = relay.var("y", shape=(3, 4), dtype="int8")
+    x = relay.var("x", shape=(3, 4, 5), dtype="int8")
+    y = relay.var("y", shape=(3, 4), dtype="int64")
     z = relay.var("z", shape=(3,), dtype="float32")
-    expr1 = relay.cast(x, "int64")
-    expr2 = relay.cast(expr1, "int16")
+    expr1 = relay.cast(x, "int16")
+    expr2 = relay.cast(expr1, "int32")
     expr3 = relay.cast_like(expr2, y)
     expr4 = relay.cast_like(expr3, z)
 
     actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
-    expected = run_infer_type(relay.cast(x, "int16"))
+    expected = run_infer_type(relay.cast(x, "int32"))
     assert tvm.ir.structural_equal(actual1, expected)
     actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr())
-    expected = run_infer_type(relay.cast_like(x, y))
+    expected = run_infer_type(relay.cast(x, "int64"))
     assert tvm.ir.structural_equal(actual2, expected)
     actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr())
-    expected = run_infer_type(relay.cast_like(x, z))
+    expected = run_infer_type(relay.cast(x, "float32"))
     assert tvm.ir.structural_equal(actual3, expected)
 
+    # cannot simplify the narrow cast
+    x = relay.var("x", shape=(3, 4, 5), dtype="float32")
+    y = relay.var("y", shape=(3, 4), dtype="float32")
+    expr1 = relay.cast(x, "int32")
+    expr2 = relay.cast_like(expr1, y)
+    actual = run_opt_pass(expr2, relay.transform.SimplifyExpr())
+    expected = run_infer_type(expr2)
+    assert tvm.ir.structural_equal(actual, expected)
+
 
 def test_concretize_reshape_like():
     data = relay.var("data", shape=(2, 3, 4), dtype="float32")

From f408749634d2e59314394de79f59b740256d6a0a Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Fri, 21 May 2021 20:55:48 +0000
Subject: [PATCH 3/3] clean up

---
 tests/python/frontend/pytorch/test_forward.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index fbf85fe641da..07f0d8e75c4d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -205,7 +205,6 @@ def verify_model(
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
-    print(mod["main"])
     for arg in mod["main"].params[: len(input_names)]:
         assert arg.name_hint in input_names
     compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input]))
@@ -3711,7 +3710,6 @@ def model_fn(x, y):
     input_y = torch.randint(low=0, high=100, size=ishape, dtype=torch.int32)
     inputs = [input_x, input_y]
     script_module = torch.jit.trace(model_fn, inputs)
-    print(script_module)
     fname = "tmp.pt"
     torch.jit.save(script_module, fname)
@@ -3869,7 +3867,6 @@ def test_fn(is_sorted, return_inverse, return_counts):
 
 
 if __name__ == "__main__":
-    """
     # some structural tests
     test_forward_traced_function()
     test_forward_dtypes()
@@ -4045,6 +4042,6 @@ def test_fn(is_sorted, return_inverse, return_counts):
 
     # Test bert model
     test_forward_pretrained_bert_base_uncased()
-    """
+
     # Test convert torch script(jit) with specific inputs' types
     test_convert_torch_script_with_input_types()
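
Reviewer note (not part of the patches): a minimal standalone sketch of the
behavior the series is intended to produce, assuming a TVM build with all
three patches applied. The dtype choices mirror test_simplify_consecutive_cast;
the "simplify" helper below is ours for illustration, not a TVM API.

    import tvm
    from tvm import relay

    def simplify(expr):
        # Wrap the expression in an IRModule, run type inference first
        # (the rewrite reads checked_type), then run SimplifyExpr.
        mod = tvm.IRModule.from_expr(expr)
        mod = relay.transform.InferType()(mod)
        mod = relay.transform.SimplifyExpr()(mod)
        return mod["main"]

    x = relay.var("x", shape=(3, 4, 5), dtype="int8")
    # Widening chain int8 -> int16 -> int32 loses no information, so the
    # intermediate cast should fold away: expect a single cast(%x, "int32").
    print(simplify(relay.cast(relay.cast(x, "int16"), "int32")))

    f = relay.var("f", shape=(3, 4, 5), dtype="float32")
    # Narrowing chain float32 -> int32 -> float32 loses information in the
    # inner cast, so IsWidenCast rejects it and both casts should remain.
    print(simplify(relay.cast(relay.cast(f, "int32"), "float32")))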