diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 69478b451893..ed224883478e 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -252,20 +252,25 @@ def intrin_func(ins, outs): assert ins[0].shape[0].value == n return tvm.tir.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0]) - intrin = te.decl_tensor_intrin(z.op, intrin_func) + intrin = te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={"offset_factor": n}) assert intrin.op == z.op assert intrin.reduce_init is None assert tuple(intrin.inputs) == tuple(z.op.input_tensors) assert intrin.buffers[0].shape[0].value == n m = 32 - x = te.placeholder((m,), name="x") - y = te.placeholder((m,), name="y") - z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") - s = te.create_schedule(z.op) - xo, xi = s[z].split(z.op.axis[0], factor=n) - s[z].tensorize(xi, intrin) - assert s[z].iter_var_attrs[xi].tensor_intrin == intrin - assert s[z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized + X = te.placeholder((m,), name="X") + Y = te.placeholder((m,), name="Y") + Z = te.compute(X.shape, lambda i: X[i] + Y[i], name="Z") + s = te.create_schedule(Z.op) + xo, xi = s[Z].split(Z.op.axis[0], factor=n) + s[Z].tensorize(xi, intrin) + stmt = tvm.lower(s, [X, Y, Z])["main"].body + assert isinstance(stmt.body, tvm.tir.Evaluate) + assert str(stmt.body.value.args[0]) == '"vadd"' + assert str(stmt.body.value.args[1]) == "X" + assert str(stmt.body.value.args[2]) == "Z" + assert s[Z].iter_var_attrs[xi].tensor_intrin == intrin + assert s[Z].iter_var_attrs[xi].iter_type == tvm.te.schedule.IterVar.Tensorized def test_tensor_intrin_scalar_params():