Skip to content

[Bug] Long lowering time after #13217 #13508

@comaniac

Description

@comaniac

Expected behavior

The lowering time of the given case should be around 10 seconds.

Actual behavior

The lowering time is more than 550 seconds.

Environment

Any environment with commit commit 101e3a4 (#13217) or later.

Steps to reproduce

The script:

import time

import tvm
from tvm import topi

class Timer:
    def __init__(self, msg):
        self.msg = msg
        print(f"{msg}...", flush=True)

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, *args):
        print(f"{self.msg}...{time.time() - self.start:.2f}s", flush=True)

def resize2d_dx_compute(inp, dy):
    """compute definition for resize2d_dx op"""
    size = (64, 32)
    layout = "NCHW"
    method = "cubic"
    coord_trans = "half_pixel"
    rounding_method = ""
    cubic_alpha = -0.75
    cubic_exclude = 0
    out_dtype = "float32"

    out = topi.image.resize2d(
        inp,
        (None, None, None, None),
        size,
        layout,
        method,
        coord_trans,
        rounding_method,
        bicubic_alpha=cubic_alpha,
        bicubic_exclude=cubic_exclude,
        out_dtype=out_dtype,
    )
    grads = tvm.te.gradient(out, [inp], head=dy)
    return grads

inp = tvm.te.placeholder((32, 3, 32, 32), name="inp")
dy = tvm.te.placeholder((32, 3, 64, 32), name="dy")
with Timer("te.gradient"):
    grads = resize2d_dx_compute(inp, dy)

# This problem is platform-independent.
with Timer("schedule"):
    sch = topi.x86.injective.schedule_injective(grads)

with Timer("lower"):
    print(tvm.lower(sch, [inp, dy, grads[0]], simple_mode=True))
  1. Switch to a commit before 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run the script.
  2. Checkout the commit 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run again.

Here are also the lowered IR without and with this commit:

Without this commit:

@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
             dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
             resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
  buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
  for (ax0.ax1.fused: int32, 0, 96) "parallel" {
    for (ax2: int32, 0, 32) {
      for (ax3.outer: int32, 0, 2) {
        resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
        for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
          for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
            for (ax3.inner.s: int32, 0, 16) {
              let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
              let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
              let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
              if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
              }
            }
          }
        }
      }
    }
  }
}

With this commit:

@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
             dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
             resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
  buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
  for (ax0.ax1.fused: int32, 0, 96) "parallel" {
    for (ax2: int32, 0, 32) {
      for (ax3.outer: int32, 0, 2) {
        resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
        for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
          for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
            for (ax3.inner.s: int32, 0, 16) {
              let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
              let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
              let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
              if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
              }
            }
          }
        }
      }
    }
  }
}

The IRs are pretty much identical, so it may be due to the change of lowering passes.

cc @Lunderberg @masahi

Triage

  • needs-triage

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triagePRs or issues that need to be investigated by maintainers to find the right assignees to address ittype: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions