
[Bug] AlterLayout doesn't correctly wrap strided_slice with layout_transforms #8759

@lazycal

Description


When running the AlterOpLayout pass on a conv2d followed by a strided_slice, with the conv layout altered from NCHW4c to NCHW, the compiler leaves the strided_slice untouched, while (I think) the only correct behavior is to wrap it with two layout_transforms. This leads to an incorrect numerical result or a crash at InferType, depending on the concrete input shape,
as shown in the following code snippet:

import tvm
import tvm.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing.temp_op_attr import TempOpAttr
import numpy as np

def test1(x_shape, w_shape):
    def before():
        # conv2d in NCHW4c/OIHW4i4o layout, followed by a strided_slice
        # along the packed channel axis of the conv output.
        x = relay.var("x", shape=x_shape)
        weight = relay.var("weight", shape=w_shape)
        y = relay.nn.conv2d(
            x,
            weight,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW4c",
            kernel_layout="OIHW4i4o",
        )
        y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
        y = relay.Function([x, weight], y)
        return tvm.IRModule.from_expr(y)

    # Alter the conv2d layouts from NCHW4c/OIHW4i4o back to plain NCHW/OIHW.
    def alter_conv2d(attrs, inputs, tinfos, out_type):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW"
        new_attrs["kernel_layout"] = "OIHW"
        return relay.nn.conv2d(data, weight, **new_attrs)

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        be = transform.InferType()(before())
        print('='*40, 'before', '='*40)
        print(be)
        af = transform.AlterOpLayout()(be)
        print('='*40, 'after', '='*40)
        print(af)
        xnp = np.random.rand(*x_shape).astype(np.float32)
        wnp = np.random.rand(*w_shape).astype(np.float32)
        be_res = relay.create_executor("debug", be).evaluate()(xnp, wnp).numpy()
        af_res = relay.create_executor("debug", af).evaluate()(xnp, wnp).numpy()
        tvm.testing.assert_allclose(be_res, af_res, rtol=1e-3, atol=1e-3)

test1(x_shape=(1, 1, 1, 1, 4), w_shape=(9, 1, 3, 3, 4, 4)) # incorrect numerical result
# test1(x_shape=(1, 1, 1, 1, 4), w_shape=(11, 1, 3, 3, 4, 4)) # crash at InferType

The module before:

def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], kernel_size=[3, 3], data_layout="NCHW4c", kernel_layout="OIHW4i4o") /* ty=Tensor[(1, 9, 1, 1, 4), float32] */;
  strided_slice(%0, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}

and after:

def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
  %0 = layout_transform(%x, src_layout="NCHW4c", dst_layout="NCHW") /* ty=Tensor[(1, 4, 1, 1), float32] */;
  %1 = layout_transform(%weight, src_layout="OIHW4i4o", dst_layout="OIHW") /* ty=Tensor[(36, 4, 3, 3), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 36, 1, 1), float32] */;
  %3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
  layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}

Specifically, the slice is conv_NCHW4c_out[:, ::8, ...] (an 8-stride slice along the packed C axis of NCHW4c). After the layout is altered to NCHW, the compiler neither wraps the strided_slice with layout_transforms nor adjusts its attributes, so the semantics change to conv_NCHW_out[:, ::8, ...], which picks 1 channel every 8, whereas the original computation picks 4 channels every 4*8 = 32 channels of conv_NCHW_out.
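
To make the semantic difference concrete, here is a small standalone NumPy sketch (my own illustration, not part of the repro above; for simplicity it slices all 9 channel blocks with a full range and ignores the end=[1, -1] bound):

import numpy as np

# 36 primal channels, viewed either as NCHW or as NCHW4c (9 blocks of 4).
nchw = np.arange(36).reshape(1, 36, 1, 1)
nchw4c = nchw.reshape(1, 9, 4, 1, 1).transpose(0, 1, 3, 4, 2)  # N, C//4, H, W, 4c

# A stride-8 slice on the packed C axis picks blocks 0 and 8,
# i.e. primal channels 0-3 and 32-35: 4 channels every 32.
picked_4c = nchw4c[:, ::8].transpose(0, 1, 4, 2, 3).reshape(1, -1, 1, 1)
print(picked_4c.squeeze())     # [ 0  1  2  3 32 33 34 35]

# Reusing the same attrs on the NCHW tensor picks channels 0, 8, 16, 24, 32:
# 1 channel every 8, which is a different computation.
print(nchw[:, ::8].squeeze())  # [ 0  8 16 24 32]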

It seems that StridedSliceInferCorrectLayout is responsible for this.

BTW, the final layout_transform in the altered IR also looks suspicious:

%3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */

The resulting tensor has shape (1, 1, 1, 1, 4), which is smaller than the pre-transform shape (1, 5, 1, 1). I believe the reason is that (1, 5, 1, 1) is not a valid NCHW input for conversion to NCHW4c (5 is not divisible by 4), and I would have expected layout_transform to detect and reject that.
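
To check that in isolation, a minimal sketch like the one below (my own snippet, separate from the repro above) builds the same transform on a (1, 5, 1, 1) tensor; since 5 is not divisible by 4, I would expect InferType to reject it rather than infer a (1, 1, 1, 1, 4) type:

import tvm
from tvm import relay

# (1, 5, 1, 1) in NCHW cannot be packed into NCHW4c because 5 % 4 != 0,
# so ideally InferType should raise an error on this layout_transform.
x = relay.var("x", shape=(1, 5, 1, 1))
y = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW4c")
mod = tvm.IRModule.from_expr(relay.Function([x], y))
print(relay.transform.InferType()(mod))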

Environment

  • TVM: commit e334942
  • CUDA version: 10.0
  • System: Ubuntu 16.04
  • GCC 5.4
  • Build options: -DUSE_RELAY_DEBUG=ON -DUSE_CUBLAS=ON -DUSE_LLVM=ON -DUSE_CUDA=ON -DUSE_CUDNN=ON
