-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Closed
Labels
needs-triagePRs or issues that need to be investigated by maintainers to find the right assignees to address itPRs or issues that need to be investigated by maintainers to find the right assignees to address ittype: bug
Description
Expected behavior
The lowering time of the given case should be around 10 seconds.
Actual behavior
The lowering time is more than 550 seconds.
Environment
Any environment with commit commit 101e3a4 (#13217) or later.
Steps to reproduce
The script:
import time
import tvm
from tvm import topi
class Timer:
def __init__(self, msg):
self.msg = msg
print(f"{msg}...", flush=True)
def __enter__(self):
self.start = time.time()
def __exit__(self, *args):
print(f"{self.msg}...{time.time() - self.start:.2f}s", flush=True)
def resize2d_dx_compute(inp, dy):
"""compute definition for resize2d_dx op"""
size = (64, 32)
layout = "NCHW"
method = "cubic"
coord_trans = "half_pixel"
rounding_method = ""
cubic_alpha = -0.75
cubic_exclude = 0
out_dtype = "float32"
out = topi.image.resize2d(
inp,
(None, None, None, None),
size,
layout,
method,
coord_trans,
rounding_method,
bicubic_alpha=cubic_alpha,
bicubic_exclude=cubic_exclude,
out_dtype=out_dtype,
)
grads = tvm.te.gradient(out, [inp], head=dy)
return grads
inp = tvm.te.placeholder((32, 3, 32, 32), name="inp")
dy = tvm.te.placeholder((32, 3, 64, 32), name="dy")
with Timer("te.gradient"):
grads = resize2d_dx_compute(inp, dy)
# This problem is platform-independent.
with Timer("schedule"):
sch = topi.x86.injective.schedule_injective(grads)
with Timer("lower"):
print(tvm.lower(sch, [inp, dy, grads[0]], simple_mode=True))- Switch to a commit before 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run the script.
- Checkout the commit 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run again.
Here are also the lowered IR without and with this commit:
Without this commit:
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
With this commit:
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
The IRs are pretty much identical, so it may be due to the change of lowering passes.
Triage
- needs-triage
masahi, Lunderberg and yzh119
Metadata
Metadata
Assignees
Labels
needs-triagePRs or issues that need to be investigated by maintainers to find the right assignees to address itPRs or issues that need to be investigated by maintainers to find the right assignees to address ittype: bug