Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
return None
"""Legalize conv2d"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis

Expand Down Expand Up @@ -123,8 +125,52 @@ def expected():

assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_legalize_arm_layout_functional():
"""Test if the legalized conversion yields same result as original"""
def get_output(func, data_val, parameters):
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, target='llvm', params=parameters)
m = graph_runtime.create(graph, lib, tvm.cpu())
m.set_input("data", data_val)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
return out

def before():
n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
y = relay.nn.conv2d(data, kernel,
kernel_size=(kh, kw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
data_layout='NHWC',
kernel_layout='HWIO',
out_dtype='float32')
func = relay.Function([data, kernel], y)
return func

@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)

a = before()
b = run_opt_pass(a, transform.Legalize())
assert b.astext().count('transpose') == 3

wdata = np.random.rand(3, 3, 16, 32) * 10
parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
data_val = np.random.rand(1, 224, 224, 16).astype('float32')
ref_out = get_output(a, data_val, parameters)
legalized_out = get_output(b, data_val, parameters)
np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)


if __name__ == "__main__":
test_legalize()
test_legalize_none()
test_legalize_multi_input()
test_legalize_arm_layout_functional()
31 changes: 31 additions & 0 deletions topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
conv2d_winograd_without_weight_transform, \
conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw
from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices

Expand Down Expand Up @@ -783,3 +784,33 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# currently we only have contrib_spatial_pack and direct template
# add more schedule templates.
return None

@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types, F):
if F.__name__ != 'tvm.relay.op':
return None
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None

warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")
# Set new attrs for the tranposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW'

# Convert from NHWC to NCHW.
data = F.transpose(data, axes=(0, 3, 1, 2))
conv = F.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = F.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
22 changes: 22 additions & 0 deletions topi/python/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
raise ValueError("not support this layout {} yet".format(layout))


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
F: symbol
The context, can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anijain2305 why are NNVM and F argument still mentioned here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for pointing out. I copied the description from other functions which have this comment. Will send a separate PR to clean up the comments.

"""
# not to change by default
return None


@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout.
Expand Down