From 4f15316344c88d674eac1f5737cac2767c1d32b2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 7 Oct 2019 22:14:27 +0000 Subject: [PATCH] [AlterOpLayout][x86] NHWC to NCHWc conv support. --- topi/python/topi/x86/conv2d_alter_op.py | 28 ++++++++++++------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index a6333d8940f1..aec3efc7a86c 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -158,21 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): return None return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs) - if data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW': - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - # Store altered operator's config - new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], - new_attrs['out_layout'], out_dtype], conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) - - if F.__name__ == 'nnvm.symbol': - return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) - return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) - return None + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) + + if F.__name__ == 'nnvm.symbol': + return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) @conv2d_legalize.register("cpu")