Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class ConstIntBoundAnalyzer::Impl
return Union(a, b);
}

// Bound of a broadcast: the result is whatever bound is computed for the
// single operand being broadcast, so the visit is forwarded to op->value.
Entry VisitExpr_(const BroadcastNode* op) final {
  const auto& scalar = op->value;
  return VisitExpr(scalar);
}

Entry VisitExpr_(const CastNode* op) final {
Entry a;

Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os)
}
}

// Emits a Select expression as a call to `select(...)` in the generated source.
// Operands are printed in the order (false_value, true_value, condition): the
// value used when the condition does not hold comes first and the condition
// comes last.
// NOTE(review): this matches the argument order of the Metal Shading Language
// `select` builtin — confirm against the MSL specification.
void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
  os << "select(";
  os << PrintExpr(op->false_value) << ", ";
  os << PrintExpr(op->true_value) << ", ";
  os << PrintExpr(op->condition) << ")";
}

void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->dtype, os);
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class CodeGenMetal final : public CodeGenC {
// print store of single element.
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
}

private:
using IRMutatorWithAnalyzer::VisitExpr;
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;

Expand Down
10 changes: 6 additions & 4 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value > -(Downcast<IntImm>(tvm::max_value(op->a->dtype))->value)) {
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
Expand Down Expand Up @@ -150,7 +151,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype, const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
Expand Down Expand Up @@ -214,7 +215,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value > -(Downcast<IntImm>(tvm::max_value(op->a->dtype))->value)) {
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
Expand Down Expand Up @@ -244,7 +246,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype, const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,5 +349,23 @@ def test_multiple_condition():
assert bound.min_value == 0


def test_broadcast_bound():
    """A Broadcast of a bounded scalar must report the scalar's own bound."""
    ana = tvm.arith.Analyzer()
    scalar = te.var("a")
    ana.update(scalar, tvm.arith.ConstIntBound(0, 128))

    # Broadcasting to 4 lanes must not widen the known [0, 128] range.
    vec_bound = ana.const_int_bound(tvm.tir.Broadcast(scalar, 4))
    assert vec_bound.min_value == 0
    assert vec_bound.max_value == 128


def test_ramp_bound():
    """A Ramp over a bounded base, plus a constant, must get a tight bound."""
    ana = tvm.arith.Analyzer()
    base = te.var("a")
    ana.update(base, tvm.arith.ConstIntBound(0, 128))

    # Ramp(base, stride=2, lanes=4) + 2
    ramp_bound = ana.const_int_bound(tvm.tir.Ramp(base, 2, 4) + 2)
    # Lower end: smallest base (0) plus the constant offset.
    assert ramp_bound.min_value == 2
    # Upper end: largest base, plus stride * (lanes - 1), plus the offset.
    assert ramp_bound.max_value == 128 + 2 * 3 + 2


if __name__ == "__main__":
tvm.testing.main()
46 changes: 35 additions & 11 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@
from tvm import te
import numpy as np

from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import tvm.script
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
Expand All @@ -37,9 +32,11 @@ def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
prim_func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(prim_func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand Down Expand Up @@ -88,9 +85,11 @@ def test_metal_erf():
def check_erf(dev, n, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand Down Expand Up @@ -125,6 +124,31 @@ def main(A: T.Buffer((1, 2), "int32")):
assert tuple(a_nd.numpy()[0, :]) == (0, 3)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_select_vectorize():
    """Build and run a vectorized kernel containing T.Select on Metal.

    The kernel copies A into B for even indices and writes 0 for odd indices;
    the result is compared against the same masking done in numpy.
    """

    @tvm.script.ir_module
    class IRModule:
        @T.prim_func
        def main(A: T.Buffer((6,), "float32"), B: T.Buffer((6,), "float32")):
            T.func_attr({"global_symbol": "main"})
            for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
                for i0_0 in T.vectorized(2):
                    with T.block("block"):
                        vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
                        B[vi0] = T.Select((vi0 % 2) == 0, A[vi0], T.float32(0))

    target = "metal"
    dev = tvm.metal()
    a = np.arange(6).astype("float32")
    a_nd = tvm.nd.array(a, dev)
    b_nd = tvm.nd.empty((6,), "float32", dev)
    f = tvm.build(IRModule, target=target)
    f(a_nd, b_nd)
    # Expected output: zero every odd index of `a` in place. reshape(3, 2)
    # groups the elements in pairs, and column 1 of each pair is the odd index.
    a.reshape(3, 2)[:, 1] = 0
    np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_ramp()
test_metal_inf_nan()
Expand Down