Skip to content

[Bug] Missing predicate guarding reduction init for a tensor scheduled with compute_at #9598

@lazycal

Description

@lazycal
import tvm
from tvm import te
import numpy as np
import tvm.testing

# The split factor F does not divide the extent N = F + 1, so the last outer
# iteration of C's split loop visits indices F .. 2F-1, of which only index F
# is in bounds. B is computed at C's inner axis, so every statement of B
# (including its reduction init) has to be guarded by the bound predicate.
F = 100
N = F + 1
A = te.placeholder((N, N), name="A")
k = te.reduce_axis((0, N), name="k")
B = te.compute((N,), lambda i: te.sum(A[i, k], k), name="B")
C = te.compute((N,), lambda i: B[i], name="C")

s = te.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=F)
s[B].compute_at(s[C], xi)

# Print the lowered IR before building so it is visible even if a later step fails.
print(tvm.lower(s, [A, B, C], simple_mode=True))
foo = tvm.build(s, [A, B, C], "llvm")

dev = tvm.cpu()
anp = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), dev)
bnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), dev)
cnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), dev)
# With the missing init predicate, this call writes past the end of B and may crash.
foo(anp, bnp, cnp)

# C is a plain copy of B, so comparing the two alone would pass even if the
# reduction were wrong. Also check against a NumPy reference of the row sums;
# the looser rtol absorbs float32 summation-order differences.
expected = anp.numpy().sum(axis=1)
tvm.testing.assert_allclose(bnp.numpy(), cnp.numpy())
tvm.testing.assert_allclose(cnp.numpy(), expected, rtol=1e-5)

This triggers a segmentation fault. The generated IR is:

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [101], []),
             A: Buffer(A_2: Pointer(float32), float32, [101, 101], []),
             B: Buffer(B_2: Pointer(float32), float32, [101], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i.outer: int32, 0, 2) {
    for (i.inner: int32, 0, 100) {
      B_2[((i.outer*100) + i.inner)] = 0f32
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        for (k: int32, 0, 101) {
          B_2[((i.outer*100) + i.inner)] = ((float32*)B_2[((i.outer*100) + i.inner)] + (float32*)A_2[(((i.outer*10100) + (i.inner*101)) + k)])
        }
      }
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        C_2[((i.outer*100) + i.inner)] = (float32*)B_2[((i.outer*100) + i.inner)]
      }
    }
  }
}

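Here the reduction init `B_2[((i.outer*100) + i.inner)] = 0f32` is not guarded by the `likely` predicate that wraps the reduction body. When `i.outer = 1`, it therefore writes indices 101–199 of the 101-element buffer `B`, which is out of bounds.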

Investigation

The problem goes away if the bound check is not skipped, i.e. if `!stage->rolling_buffer` is replaced with `false` in

ret.init_predicates =
MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
. However, I'm not sure this is the right fix, because I don't fully understand the bound-checking logic. What confuses me is why the reduction body does not skip the bound checks (shown in
ret.main_predicates =
MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>());
) while the init does skip them.

There are two kinds of bound checks in the `MakeBoundCheck` function (L550–L560 and L561–L577):

std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, PrimExpr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
arith::Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<PrimExpr> preds;
Map<Var, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
}
for (auto entry : dom_map) {
analyzer.Bind(entry.first->var, entry.second);
}
for (const IterVar& iv : stage->all_iter_vars) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min;
PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
}
}
for (const IterVar& iv : stage->op->root_iter_vars()) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv);
ICHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min();
PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
preds.emplace_back(value >= 0);
}
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent);
}
}
}
return preds;
}
`skip_ivar_domain` only controls the second check: passing `false` keeps it, and passing `true` disables it. The first check, however, does not seem comprehensive. In the example above, `compute_at` implicitly binds B's axis to a split axis of C, but the first check cannot see that split relation. As a result, `PassUpBoundCheck` does not mark the axis as needing a check. I'm also curious whether this is expected.

Environment

OS: Ubuntu 18.04
TVM Version: ecd8a9c

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions