Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,31 @@
import pytest

import tvm
from tvm import tir, script
from tvm.script import ty

from tvm.script import tir as T
from tvm.tir import stmt_functor

# fmt: off
@tvm.script.tir
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None:
@T.prim_func
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
# body
PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global")
for i0_i1_fused_3 in tir.parallel(0, 28):
for i2_3, i3_3 in tir.grid(28, 192):
tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True)
for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784):
for ax3_2 in tir.serial(0, 16):
Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global")
tir.store(Conv2dOutput_3, 0, 0, True)
for rc_3 in tir.serial(0, 192):
tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True)
tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
for i0_i1_fused_3 in T.parallel(0, 28):
for i2_3, i3_3 in T.grid(28, 192):
T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True)
for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
for ax3_2 in T.serial(0, 16):
Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
T.store(Conv2dOutput_3, 0, 0, True)
for rc_3 in T.serial(0, 192):
T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True)
T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
# fmt: on


Expand Down