-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Description
This issue was identified while working on workspace-size calculation, and it can be seen in the Depthwise Conv2D of quantized MobileNet V1.
The Conv2D of quantized MobileNet V1 produces the following Relay primitive function with fused elementwise operations that do not actually get fused into a single loop, and thus end up creating large intermediate feature maps. This can be seen in the generated Relay and TIR primfuncs below, which contain these allocates:
allocate(PaddedInput, int16, [430592]);
allocate(DepthwiseConv2d, int32, [401408])
These two allocates create roughly 2.4 MB of data. Moreover, the unusual cast operator at the end makes the inter-fused-operator tensors 16 bits wide, whereas the model description states they should be 8 bits (this is taken from a single operator within MobileNet V1):
fn (%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int16], %p2: Tensor[(1, 1, 1, 128), int32], Primitive=1) -> Tensor[(1, 56, 56, 128), int16] {
%0 = nn.conv2d(%p0, %p1, padding=[1, 1, 1, 1], groups=128, channels=128, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWOI", out_dtype="int32") /* ty=Tensor[(1, 56, 56, 128), int32] */;
%1 = add(%0, %p2) /* ty=Tensor[(1, 56, 56, 128), int32] */;
%2 = fixed_point_multiply(%1, multiplier=2080045879, shift=-4) /* ty=Tensor[(1, 56, 56, 128), int32] */;
%3 = clip(%2, a_min=0f, a_max=255f) /* ty=Tensor[(1, 56, 56, 128), int32] */;
%4 = cast(%3, dtype="uint8") /* ty=Tensor[(1, 56, 56, 128), uint8] */;
cast(%4, dtype="int16") /* ty=Tensor[(1, 56, 56, 128), int16] */
}
This gets translated to the following TIR primfunc:
primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, T_cast_1: handle) -> ()
attr = {"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_21", "tir.noalias": True}
buffers = {T_cast: Buffer(T_cast_2: Pointer(int16), int16, [1, 56, 56, 128], []),
placeholder_2: Buffer(placeholder_6: Pointer(int32), int32, [1, 1, 1, 128], []),
placeholder: Buffer(placeholder_7: Pointer(int16), int16, [1, 56, 56, 128], []),
placeholder_1: Buffer(placeholder_8: Pointer(int16), int16, [3, 3, 128, 1], [])}
buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, T_cast_1: T_cast} {
attr [PaddedInput: Pointer(int16)] "storage_scope" = "global";
allocate(PaddedInput, int16, [430592]);
attr [DepthwiseConv2d: Pointer(int32)] "storage_scope" = "global";
allocate(DepthwiseConv2d, int32, [401408]) {
for (i1: int32, 0, 58) {
for (i2: int32, 0, 58) {
for (i3: int32, 0, 128) {
PaddedInput[(((i1*7424) + (i2*128)) + i3)] = @tir.if_then_else(((((1 <= i1) && (i1 < 57)) && (1 <= i2)) && (i2 < 57)), (int16*)placeholder_7[((((i1*7168) + (i2*128)) + i3) - 7296)], 0i16, dtype=int16)
}
}
}
for (i: int32, 0, 56) {
for (j: int32, 0, 56) {
for (c: int32, 0, 128) {
DepthwiseConv2d[(((i*7168) + (j*128)) + c)] = 0
for (di: int32, 0, 3) {
for (dj: int32, 0, 3) {
DepthwiseConv2d[(((i*7168) + (j*128)) + c)] = ((int32*)DepthwiseConv2d[(((i*7168) + (j*128)) + c)] + (cast(int32, (int16*)PaddedInput[(((((i*7424) + (di*7424)) + (j*128)) + (dj*128)) + c)])*cast(int32, (int16*)placeholder_8[(((di*384) + (dj*128)) + c)])))
}
}
}
}
}
for (ax1: int32, 0, 56) {
for (ax2: int32, 0, 56) {
for (ax3: int32, 0, 128) {
DepthwiseConv2d[(((ax1*7168) + (ax2*128)) + ax3)] = ((int32*)DepthwiseConv2d[(((ax1*7168) + (ax2*128)) + ax3)] + (int32*)placeholder_6[ax3])
}
}
}
for (i1_1: int32, 0, 56) {
for (i2_1: int32, 0, 56) {
for (i3_1: int32, 0, 128) {
DepthwiseConv2d[(((i1_1*7168) + (i2_1*128)) + i3_1)] = @tir.q_multiply_shift((int32*)DepthwiseConv2d[(((i1_1*7168) + (i2_1*128)) + i3_1)], 2080045879, 31, -4, dtype=int32)
}
}
}
for (i1_2: int32, 0, 56) {
for (i2_2: int32, 0, 56) {
for (i3_2: int32, 0, 128) {
DepthwiseConv2d[(((i1_2*7168) + (i2_2*128)) + i3_2)] = max(min((int32*)DepthwiseConv2d[(((i1_2*7168) + (i2_2*128)) + i3_2)], 255), 0)
}
}
}
for (ax1_1: int32, 0, 56) {
for (ax2_1: int32, 0, 56) {
for (ax3_1: int32, 0, 128) {
PaddedInput[(((ax1_1*7168) + (ax2_1*128)) + ax3_1)] = cast(uint8, (int32*)DepthwiseConv2d[(((ax1_1*7168) + (ax2_1*128)) + ax3_1)])
}
}
}
for (ax1_2: int32, 0, 56) {
for (ax2_2: int32, 0, 56) {
for (ax3_2: int32, 0, 128) {
T_cast_2[(((ax1_2*7168) + (ax2_2*128)) + ax3_2)] = cast(int16, (uint8*)PaddedInput[(((ax1_2*7168) + (ax2_2*128)) + ax3_2)])
}
}
}
}
}