From 6d691f646996f43754f32050a4088562593ef250 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 7 Apr 2020 11:59:05 +0100 Subject: [PATCH 1/4] [RELAY][BYOC] Register pattern tables from external codegens This adds utility functions to support registering and retrieving pattern tables used by MergeComposite for external codegens. Change-Id: I5be165a321440e48b15ff6aff4970e0c67496aaa --- python/tvm/relay/op/contrib/__init__.py | 2 + python/tvm/relay/op/contrib/register.py | 49 ++++++++++++++++++++++++ tests/python/relay/test_pattern_table.py | 45 ++++++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 python/tvm/relay/op/contrib/register.py create mode 100644 tests/python/relay/test_pattern_table.py diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 4b6acceb3a83..3a3f6d5aa304 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -16,4 +16,6 @@ # under the License. # pylint: disable=wildcard-import """Contrib modules.""" +from .register import get_pattern_table, register_pattern_table + from .dnnl import * diff --git a/python/tvm/relay/op/contrib/register.py b/python/tvm/relay/op/contrib/register.py new file mode 100644 index 000000000000..b82abdb88804 --- /dev/null +++ b/python/tvm/relay/op/contrib/register.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Register utilities for external codegen.""" +_PATTERN_TABLES = {} + + +def register_pattern_table(compiler, table=None): + """Register a pattern table for an external compiler. + + Pattern tables are used to create composite functions. + See the MergeComposite pass. + + Parameters + ---------- + compiler : str + The name of compiler + + table : function, optional + A function that returns the pattern table + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(t): + """internal register function""" + _PATTERN_TABLES[compiler] = t() + return t + return _register(table) if table is not None else _register + + +def get_pattern_table(compiler): + """Get the pattern table associated with a compiler (if it's registered).""" + return _PATTERN_TABLES[compiler] if compiler in _PATTERN_TABLES else None diff --git a/tests/python/relay/test_pattern_table.py b/tests/python/relay/test_pattern_table.py new file mode 100644 index 000000000000..f5a5e4bd2d4f --- /dev/null +++ b/tests/python/relay/test_pattern_table.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit test for pattern table registry (BYOC).""" +from tvm.relay.op.contrib import get_pattern_table, register_pattern_table +from tvm import relay + + +@register_pattern_table("test_pattern_table") +def pattern_table(): + def _make_add_relu_pattern(): + x = relay.var('x') + y = relay.var('y') + add_node = relay.add(x, y) + r = relay.nn.relu(add_node) + return r + + def _check_add_relu_pattern(): + return True + + return [ + ("test_pattern_table.add_relu", _make_add_relu_pattern(), _check_add_relu_pattern) + ] + + +def test_retrieve_pattern_table(): + table = get_pattern_table("test_pattern_table") + assert table[0][0] == "test_pattern_table.add_relu" + + +if __name__ == "__main__": + test_retrieve_pattern_table() From 4582dbc60f5edad0c05fcd81adb64009e902c4eb Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 14 Apr 2020 12:28:38 +0100 Subject: [PATCH 2/4] Updated DNNL tests to use pattern table mechanism --- python/tvm/relay/op/contrib/dnnl.py | 22 +++++++++++++++++++ .../python/relay/test_pass_partition_graph.py | 19 +++------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 45a8c8331f72..412393eb291a 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -32,7 +32,9 @@ - The other way is to implement the function by themselves to check the attributes of the op and decide if it should be offloaded to DNNL. """ +from ... import expr as _expr from ... import op as reg +from .register import register_pattern_table def _register_external_op_helper(op_name, supported=True): @@ -63,3 +65,23 @@ def _func_wrapper(attrs, args): _register_external_op_helper("add") _register_external_op_helper("subtract") _register_external_op_helper("multiply") + + +def make_pattern(with_bias=True): + data = _expr.var("data") + weight = _expr.var("weight") + bias = _expr.var("bias") + conv = reg.nn.conv2d(data, weight) + if with_bias: + conv_out = reg.add(conv, bias) + else: + conv_out = conv + return reg.nn.relu(conv_out) + + +@register_pattern_table("dnnl") +def pattern_table(): + conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True)) + conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False)) + dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat] + return dnnl_patterns diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 2ee8538e30ed..274d8631845b 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -19,7 +19,6 @@ import sys import numpy as np -import pytest import tvm import tvm.relay.testing @@ -30,6 +29,7 @@ from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.op.contrib.register import get_pattern_table from tvm.relay.build_module import bind_params_by_name @@ -831,21 +831,8 @@ def expected(): def test_dnnl_fuse(): - def make_pattern(with_bias=True): - data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")) - weight = relay.var("weight") - bias = relay.var("bias") - conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), - channels=8, padding=(1, 1)) - if with_bias: - conv_out = relay.add(conv, bias) - else: - conv_out = conv - return relay.nn.relu(conv_out) - - conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True)) - conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False)) - dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat] + dnnl_patterns = get_pattern_table("dnnl") + conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False): From 690f6bdbc3647add9b4109e4f1038d1f3eb772d7 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 14 Apr 2020 12:31:25 +0100 Subject: [PATCH 3/4] Removed pattern table standalone test --- tests/python/relay/test_pattern_table.py | 45 ------------------------ 1 file changed, 45 deletions(-) delete mode 100644 tests/python/relay/test_pattern_table.py diff --git a/tests/python/relay/test_pattern_table.py b/tests/python/relay/test_pattern_table.py deleted file mode 100644 index f5a5e4bd2d4f..000000000000 --- a/tests/python/relay/test_pattern_table.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Unit test for pattern table registry (BYOC).""" -from tvm.relay.op.contrib import get_pattern_table, register_pattern_table -from tvm import relay - - -@register_pattern_table("test_pattern_table") -def pattern_table(): - def _make_add_relu_pattern(): - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - r = relay.nn.relu(add_node) - return r - - def _check_add_relu_pattern(): - return True - - return [ - ("test_pattern_table.add_relu", _make_add_relu_pattern(), _check_add_relu_pattern) - ] - - -def test_retrieve_pattern_table(): - table = get_pattern_table("test_pattern_table") - assert table[0][0] == "test_pattern_table.add_relu" - - -if __name__ == "__main__": - test_retrieve_pattern_table() From 826b8186cde1020e4342cf937e0e8030deaf9c4b Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 15 Apr 2020 13:35:39 +0100 Subject: [PATCH 4/4] Change reg to _op --- python/tvm/relay/op/contrib/dnnl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 412393eb291a..71ef430ec9c6 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,7 +33,7 @@ check the attributes of the op and decide if it should be offloaded to DNNL. """ from ... import expr as _expr -from ... import op as reg +from ... import op as _op from .register import register_pattern_table @@ -51,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True): f : callable A function that returns if the operator is supported by DNNL. """ - @reg.register(op_name, "target.dnnl") + @_op.register(op_name, "target.dnnl") def _func_wrapper(attrs, args): return supported @@ -71,12 +71,12 @@ def make_pattern(with_bias=True): data = _expr.var("data") weight = _expr.var("weight") bias = _expr.var("bias") - conv = reg.nn.conv2d(data, weight) + conv = _op.nn.conv2d(data, weight) if with_bias: - conv_out = reg.add(conv, bias) + conv_out = _op.add(conv, bias) else: conv_out = conv - return reg.nn.relu(conv_out) + return _op.nn.relu(conv_out) @register_pattern_table("dnnl")