From 6025f61e5c920a270f69639c844ac3203c3bbae4 Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Tue, 20 Dec 2022 22:51:22 -0800
Subject: [PATCH 1/6] [CLML][RELAY] Enable Pad and Conv2d layer fusion

Enabled the CLML-supported nn.pad + nn.conv2d fusion pattern in the CLML pattern table.
---
 python/tvm/relay/op/contrib/clml.py        | 29 ++++++++++++++++++++++
 src/relay/backend/contrib/clml/codegen.cc  |  2 +-
 tests/python/contrib/test_clml/test_ops.py |  2 +-
 3 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index c3d4eb84700d..6b37686d1457 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -22,6 +22,7 @@
 from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.expr import Call, Var, Constant
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
@@ -147,6 +148,23 @@ def conv_pattern():
         pattern = pattern.optional(is_op("clip"))
         return pattern
 
+    def pad_conv_pattern():
+        """Create a pad with convolution pattern."""
+        pattern = is_op("nn.pad")(wildcard(), is_constant())
+        pattern = is_op("nn.conv2d")(pattern, is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        pattern = pattern.optional(
+            lambda x: is_tuple_get_item(
+                is_op("nn.batch_norm")(
+                    x, is_constant(), is_constant(), is_constant(), is_constant()
+                )
+            )
+        )
+        pattern = pattern.optional(is_op("nn.relu"))
+        pattern = pattern.optional(is_op("clip"))
+        return pattern
+
     def batch_norm_pattern():
         """Create a batch norm pattern."""
         pattern = is_op("nn.batch_norm")(
@@ -200,9 +218,18 @@ def check_conv(extract):
         while call.op.name != "nn.conv2d":
             call = call.args[0]
+
         attrs, args = call.attrs, call.args
         if attrs.data_layout != "NCHW":
             return False
+
+        if(
+            (isinstance(args[0], (Var, Constant)) == False)
+            and (args[0].op.name == "nn.pad")
+            and (len(args[0].attrs["pad_width"]) != 4)
+        ):
+            return False
+
         if (
             (not clip_found)
             and (attrs.kernel_size[0] == 3)
@@ -211,6 +238,7 @@ def check_conv(extract):
             and (attrs.channels == attrs.groups)
         ):
             return False
+
         data_typ = args[0].checked_type
         kernel_typ = args[1].checked_type
         is_depthwise = is_depthwise_conv2d(
@@ -246,6 +274,7 @@ def check_default_op(extract):
         return True
 
     return [
+        ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
         ("clml.dense", dense_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index 9ecec0c4531f..167c48e1baf5 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -83,7 +83,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "CLML JSON runtime only supports composite functions.";
     const std::string name = comp.value();
     std::shared_ptr<JSONGraphNode> json_node;
-    if (name == "clml.conv2d") {
+    if (name == "clml.conv2d" || name == "clml.pad_conv2d") {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "clml.batch_norm") {
       json_node = CreateBatchNormJSONNode(cn);
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index d2431d2dfd3b..75a11656748b 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -97,7 +97,7 @@ def test_conv2d(device, dtype):
     trials = [
         # Normal convolution
         [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (False, False, True)],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True)],
         [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
         [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True)],
         # Normal convolution
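The core of the new `clml.pad_conv2d` composite can be exercised on its own through the Relay dataflow-pattern API. The sketch below is illustrative only and not part of the patch: the variable names, shapes, and dtype are invented, and it covers just the pad-feeding-conv2d anchor that `pad_conv_pattern()` builds on, not the optional bias/batch-norm/activation tail.

```python
# Illustrative sketch, not part of the patch: match the nn.pad -> nn.conv2d
# core that pad_conv_pattern() anchors on, using the public dataflow-pattern API.
import numpy as np
from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
weight = relay.const(np.zeros((32, 16, 3, 3), dtype="float32"))

# Explicit pad on H and W only (NCHW layout), followed by an unpadded conv2d.
expr = relay.nn.conv2d(
    relay.nn.pad(data, pad_width=((0, 0), (0, 0), (1, 1), (1, 1))),
    weight,
    kernel_size=(3, 3),
    channels=32,
)

# nn.pad carries (data, pad_value) as call arguments; the pattern pins the pad
# value and the conv weight to constants, mirroring pad_conv_pattern().
pad_conv = is_op("nn.conv2d")(is_op("nn.pad")(wildcard(), is_constant()), is_constant())
assert pad_conv.match(expr)
```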
From e87d68547a06e56a656c5ad0e86092a64c85aa86 Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Tue, 20 Dec 2022 22:59:23 -0800
Subject: [PATCH 2/6] Fix pad testcase attributes

---
 tests/python/contrib/test_clml/test_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index 75a11656748b..da09715fbe4c 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -45,7 +45,7 @@ def _get_conv_model(
     a = relay.var(next(iter(var)), shape=shape, dtype=dtype)
     input_arr = var[next(iter(var))]
     if has_pad:
-        p = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        p = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
         a = relay.nn.pad(a, pad_width=p)
         padding = (0, 0, 0, 0)
     else:

From ae84e10b770beea4f5cfe25fbc368fe205656203 Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Tue, 20 Dec 2022 23:59:45 -0800
Subject: [PATCH 3/6] Fix the lint error

---
 python/tvm/relay/op/contrib/clml.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 6b37686d1457..106cd5d4245f 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -224,9 +224,9 @@ def check_conv(extract):
             return False
 
         if(
-            (isinstance(args[0], (Var, Constant)) == False)
-            and (args[0].op.name == "nn.pad")
-            and (len(args[0].attrs["pad_width"]) != 4)
+            (not isinstance(args[0], (Var, Constant)))
+            and (args[0].op.name == "nn.pad")
+            and (len(args[0].attrs["pad_width"]) != 4)
         ):
             return False

From b7c0b3dfd7ef775bea430167d57589c33dff7b1e Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Wed, 21 Dec 2022 00:09:17 -0800
Subject: [PATCH 4/6] Fix the lint error

---
 python/tvm/relay/op/contrib/clml.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 106cd5d4245f..2acab7eff41b 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -22,7 +22,7 @@
 from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
-from tvm.relay.expr import Call, Var, Constant
+from tvm.relay.expr import Var, Constant
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
@@ -223,7 +223,7 @@ def check_conv(extract):
         if attrs.data_layout != "NCHW":
             return False
 
-        if(
+        if (
             (not isinstance(args[0], (Var, Constant)))
             and (args[0].op.name == "nn.pad")
             and (len(args[0].attrs["pad_width"]) != 4)
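On the PATCH 2/6 change above: with NCHW inputs, the explicit padding in the test helper has to sit on the height and width axes (2 and 3), not on the channel axis, otherwise the padded graph does not compute the same convolution as the fused CLML path. A minimal sketch of the corrected construction (the shape and padding values here are arbitrary, not taken from the test):

```python
# Illustrative sketch: pad_width layout for NCHW data, where
# axis 0 = N, axis 1 = C, axis 2 = H, axis 3 = W.
from tvm import relay

a = relay.var("a", shape=(1, 14, 10, 10), dtype="float32")
padding = (1, 1)

# Pre-fix layout padded C and H: ((0, 0), (ph, ph), (pw, pw), (0, 0)).
# Post-fix layout pads H and W only, and the conv2d itself then runs unpadded.
p = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
a = relay.nn.pad(a, pad_width=p)
```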
From b97732e8748aa336fb25ede969c340b84fa3ecf3 Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Sun, 25 Dec 2022 21:34:06 -0800
Subject: [PATCH 5/6] Removed redundant check in clml pattern

---
 python/tvm/relay/op/contrib/clml.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 2acab7eff41b..9da9eedcef2f 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -223,13 +223,6 @@ def check_conv(extract):
         if attrs.data_layout != "NCHW":
             return False
 
-        if (
-            (not isinstance(args[0], (Var, Constant)))
-            and (args[0].op.name == "nn.pad")
-            and (len(args[0].attrs["pad_width"]) != 4)
-        ):
-            return False
-
         if (
             (not clip_found)
             and (attrs.kernel_size[0] == 3)

From 843898fd1144d3deb294424968749366ef37c420 Mon Sep 17 00:00:00 2001
From: kvegiraj
Date: Sun, 25 Dec 2022 21:51:03 -0800
Subject: [PATCH 6/6] Fix the lint error

---
 python/tvm/relay/op/contrib/clml.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 9da9eedcef2f..6453b8a06c9f 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -22,7 +22,6 @@
 from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
-from tvm.relay.expr import Var, Constant
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
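With the redundant pad-width check removed, the fusion is gated by `pad_conv_pattern()` plus the remaining conditions in `check_conv`. Below is a hypothetical end-to-end sketch of how the partitioned result could be inspected, assuming `partition_for_clml` from this module is driven the same way as in the CLML test infrastructure; the shapes, names, and dtype are invented, and actually running the offloaded graph still requires an OpenCL ML capable device.

```python
# Hypothetical sketch, not part of the patch: partition a small pad + conv2d
# graph for CLML and inspect the printed module for "clml.pad_conv2d" composites.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import clml

data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
weight = relay.const(np.random.uniform(-1, 1, (32, 16, 3, 3)).astype("float32"))

pad = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
conv = relay.nn.conv2d(pad, weight, kernel_size=(3, 3), channels=32, padding=(0, 0))

mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(conv), conv))
mod = clml.partition_for_clml(mod)

# If the pattern and check_conv accept the graph, the printed module is expected
# to contain a composite function tagged "clml.pad_conv2d".
print(mod)
```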