Description
When running the AlterOpLayout pass on a conv2d followed by a strided_slice, with the conv's layout altered from NCHW4c to NCHW, the compiler does nothing to the strided_slice, while (I think) the only correct behavior would be to wrap it with two layout_transforms. This leads to an incorrect numerical result or a crash at InferType, depending on the concrete input shape, as shown in the following code snippet:
import tvm
import tvm.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing.temp_op_attr import TempOpAttr
import numpy as np


def test1(x_shape, w_shape):
    def before():
        x = relay.var("x", shape=x_shape)
        weight = relay.var("weight", shape=w_shape)
        y = relay.nn.conv2d(
            x,
            weight,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW4c",
            kernel_layout="OIHW4i4o",
        )
        y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
        y = relay.Function([x, weight], y)
        return tvm.IRModule.from_expr(y)

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW"
        new_attrs["kernel_layout"] = "OIHW"
        return relay.nn.conv2d(data, weight, **new_attrs)

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        be = transform.InferType()(before())
        print('=' * 40, 'before', '=' * 40)
        print(be)
        af = transform.AlterOpLayout()(be)
        print('=' * 40, 'after', '=' * 40)
        print(af)

        xnp = np.random.rand(*x_shape).astype(np.float32)
        wnp = np.random.rand(*w_shape).astype(np.float32)
        be_res = relay.create_executor("debug", be).evaluate()(xnp, wnp).numpy()
        af_res = relay.create_executor("debug", af).evaluate()(xnp, wnp).numpy()
        tvm.testing.assert_allclose(be_res, af_res, rtol=1e-3, atol=1e-3)


test1(x_shape=(1, 1, 1, 1, 4), w_shape=(9, 1, 3, 3, 4, 4))  # incorrect numerical result
# test1(x_shape=(1, 1, 1, 1, 4), w_shape=(11, 1, 3, 3, 4, 4))  # crash at InferType

The module before:
def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
%0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], kernel_size=[3, 3], data_layout="NCHW4c", kernel_layout="OIHW4i4o") /* ty=Tensor[(1, 9, 1, 1, 4), float32] */;
strided_slice(%0, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}

and after:
def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
%0 = layout_transform(%x, src_layout="NCHW4c", dst_layout="NCHW") /* ty=Tensor[(1, 4, 1, 1), float32] */;
%1 = layout_transform(%weight, src_layout="OIHW4i4o", dst_layout="OIHW") /* ty=Tensor[(36, 4, 3, 3), float32] */;
%2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 36, 1, 1), float32] */;
%3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}

Specifically, I am doing conv_NCHW4c_out[:, ::8, ...] (a stride-8 slice on the primal C dimension of NCHW4c). After the layout is altered to NCHW, the compiler neither wraps strided_slice with any layout_transforms nor adjusts its attributes, so the semantics change to conv_NCHW_out[:, ::8, ...], which picks 1 channel out of every 8, while what is actually needed is to pick 4 channels out of every 4*8 = 32 channels of conv_NCHW_out.
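To make the mismatch concrete, here is a small NumPy sketch (my own illustration, not TVM code) on a conv output with 36 primal channels, i.e. NCHW shape (1, 36, 1, 1) packed as NCHW4c shape (1, 9, 1, 1, 4); for simplicity it slices the whole channel range instead of using end=-1:

import numpy as np

nchw = np.arange(36, dtype=np.float32).reshape(1, 36, 1, 1)
# NCHW -> NCHW4c: split C into blocks of 4 and move the inner 4 to the last axis
nchw4c = nchw.reshape(1, 9, 4, 1, 1).transpose(0, 1, 3, 4, 2)

# Original semantics: stride-8 slice on the outer C axis of NCHW4c,
# i.e. keep blocks 0 and 8 -> channels 0-3 and 32-35 (4 channels out of every 32)
sliced_4c = nchw4c[:, ::8, ...]    # shape (1, 2, 1, 1, 4), 8 channels kept

# Altered semantics: stride-8 slice directly on the C axis of NCHW,
# i.e. keep channels 0, 8, 16, 24, 32 (1 channel out of every 8)
sliced_nchw = nchw[:, ::8, ...]    # shape (1, 5, 1, 1), 5 channels kept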
It seems that StridedSliceInferCorrectLayout is responsible for this.
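For reference, here is a hedged sketch (hand-built, under my assumption of what the pass should produce) of the IR I would expect after AlterOpLayout: the conv runs in NCHW, but its output is transformed back to NCHW4c before the original strided_slice is applied, so the slice keeps its NCHW4c semantics:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 1, 1, 4))
weight = relay.var("weight", shape=(9, 1, 3, 3, 4, 4))
x_nchw = relay.layout_transform(x, "NCHW4c", "NCHW")
w_oihw = relay.layout_transform(weight, "OIHW4i4o", "OIHW")
y = relay.nn.conv2d(x_nchw, w_oihw, kernel_size=(3, 3), padding=(1, 1))
y = relay.layout_transform(y, "NCHW", "NCHW4c")  # back to NCHW4c before slicing
y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
print(tvm.IRModule.from_expr(relay.Function([x, weight], y)))

(Here the function already returns NCHW4c, so only the transform on the slice's input is visible; if the slice's consumer expected NCHW, a second layout_transform would follow it.)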
BTW, the layout_transform in the after-pass IR also looks weird:
%3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */

The resulting tensor has a smaller shape, (1, 1, 1, 1, 4), than the pre-transform one, (1, 5, 1, 1). I think the reason is that (1, 5, 1, 1) is not a valid input for conversion to the NCHW4c layout (5 is not divisible by 4), and I would expect layout_transform to detect and reject that.
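As a sanity check, the NCHW -> NCHW4c transform conceptually splits C into C//4 blocks of 4, so it only makes sense when C is divisible by 4. A tiny NumPy sketch of that intended semantics (my own model of the transform, not TVM's implementation):

import numpy as np

def nchw_to_nchw4c(x):
    n, c, h, w = x.shape
    assert c % 4 == 0, "C must be divisible by 4"
    return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

nchw_to_nchw4c(np.zeros((1, 8, 1, 1), "float32"))  # fine: shape (1, 2, 1, 1, 4)
nchw_to_nchw4c(np.zeros((1, 5, 1, 1), "float32"))  # raises: 5 is not divisible by 4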
Environment
- TVM: commit e334942
- CUDA version: 10.0
- System: Ubuntu 16.04
- GCC 5.4
- Build options: -DUSE_RELAY_DEBUG=ON -DUSE_CUBLAS=ON -DUSE_LLVM=ON -DUSE_CUDA=ON -DUSE_CUDNN=ON