5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -46,6 +46,7 @@ tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF)
tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF)
+tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON)

# 3rdparty libraries
tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
@@ -259,6 +260,10 @@ if(NOT USE_RTTI)
  add_definitions(-DDMLC_ENABLE_RTTI=0)
endif()

+if (INDEX_DEFAULT_I64)
+  add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
+endif()

list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)

if(USE_RPC)
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -73,6 +73,7 @@ function(add_lib_info src_file)
    TVM_INFO_USE_TARGET_ONNX="${USE_TARGET_ONNX}"
    TVM_INFO_USE_ARM_COMPUTE_LIB="${USE_ARM_COMPUTE_LIB}"
    TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME="${USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}"
+    TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
  )

endfunction()
6 changes: 4 additions & 2 deletions include/tvm/topi/detail/constant_utils.h
@@ -115,8 +115,10 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
  tvm::tir::ExprDeepEqual expr_equal;
  bool result = expr_equal(lhs, rhs);
  if (!result) {
-    PrimExpr zero(0);
-    result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero);
+    PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs);
+    if (const IntImmNode* i = t.as<IntImmNode>()) {
+      result = i->value == 0;
+    }
  }
  return result;
}
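The rewritten check sidesteps a dtype mismatch: ExprDeepEqual compares dtypes as well as structure, so the old int32 literal PrimExpr zero(0) could never equal an int64-typed difference, which is exactly what i64-indexed shape expressions simplify to. Checking whether the simplified difference is an IntImmNode with value 0 works for any index width. A minimal sketch of the dtype sensitivity from Python, assuming the standard tvm.tir.analysis.expr_deep_equal binding:

import tvm

# Deep equality distinguishes dtypes: an int64 zero is not an int32 zero.
zero_i32 = tvm.tir.IntImm("int32", 0)
zero_i64 = tvm.tir.IntImm("int64", 0)
assert not tvm.tir.analysis.expr_deep_equal(zero_i32, zero_i64)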
10 changes: 7 additions & 3 deletions python/tvm/relay/backend/compile_engine.py
@@ -23,6 +23,7 @@
import tvm
from tvm import te
from tvm.runtime import Object
+from tvm.support import libinfo
from ... import target as _target
from ... import autotvm
from .. import function as _function
@@ -80,9 +81,12 @@ def get_shape(shape):
    ret = []
    for dim in shape:
        if isinstance(dim, tvm.tir.IntImm):
-            val = int(dim)
-            assert val <= np.iinfo(np.int32).max
-            ret.append(tvm.tir.IntImm("int32", val))
+            if libinfo()["INDEX_DEFAULT_I64"] == "ON":
+                ret.append(dim)
+            else:
+                val = int(dim)
+                assert val <= np.iinfo(np.int32).max
+                ret.append(tvm.tir.IntImm("int32", val))
        elif isinstance(dim, tvm.tir.Any):
            ret.append(te.var("any_dim", "int32"))
        else:
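get_shape now consults the flag baked into the library instead of unconditionally narrowing IntImm dims to int32. A quick sketch of how to check which mode a given build uses, via the libinfo entry this PR registers:

from tvm.support import libinfo

# "ON": IntImm shape dims keep their original (typically int64) dtype;
# any other value: dims are narrowed to int32 as before.
print(libinfo()["INDEX_DEFAULT_I64"])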
16 changes: 4 additions & 12 deletions src/arith/bound_deducer.cc
@@ -92,21 +92,13 @@ class BoundDeducer : public ExprVisitor {
}
}

-  void VisitExpr_(const LTNode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const LTNode* op) final { success_ = false; }

-  void VisitExpr_(const LENode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const LENode* op) final { success_ = false; }

-  void VisitExpr_(const GTNode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const GTNode* op) final { success_ = false; }

-  void VisitExpr_(const GENode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const GENode* op) final { success_ = false; }

  void VisitExpr_(const AddNode* op) final {
    bool left = op->a.get() == path_[iter_];
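DeduceBound now fails soft: a condition containing a second comparison reports an undeducible bound instead of killing the process with LOG(FATAL). A minimal sketch, assuming the failure surfaces as an empty IntSet through the Python deduce_bound binding and that IntSet exposes is_nothing():

import tvm
from tvm import te

x, y = te.var("x"), te.var("y")
y_set = tvm.arith.IntervalSet(0, 10)
# The inner `x < 5` is a second comparison on the path to x.
res = tvm.arith.deduce_bound(x, (x < 5).astype("int32") + y < 10, {y: y_set}, {})
assert res.is_nothing()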
14 changes: 12 additions & 2 deletions src/arith/int_set.cc
@@ -289,8 +289,8 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, Int
  if (analyzer->CanProveGreaterEqual(divisor, 0)) {
    if (divisor.as<tir::IntImmNode>()) {
      // a mod b = a - (a / b) * b if a_max / b == a_min / b
-      auto qmax = floordiv(a->max_value, divisor);
-      auto qmin = floordiv(a->min_value, divisor);
+      auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf();
+      auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf();
      if (analyzer->CanProve(qmax == qmin)) {
        auto tmax = a->max_value - divisor * qmin;
        auto tmin = a->min_value - divisor * qmin;
@@ -441,6 +441,15 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
    return Union(analyzer_, false_set, true_set);
  }

+  IntervalSet VisitExpr_(const CastNode* op) final {
+    IntervalSet value_set = this->Eval(op->value);
+    PrimExpr min_value =
+        value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
+    PrimExpr max_value =
+        value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf();
+    return IntervalSet(min_value, max_value);
+  }

  IntervalSet VisitExprDefault_(const Object* op) final {
    DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
    return IntervalSet::Everything();
@@ -609,6 +618,7 @@ bool IntSet::MatchRange(const Range& b) const {
  const IntSet& a = *this;
  const IntervalSetNode* a_int = a.as<IntervalSetNode>();
  if (!a_int) return false;
+  if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false;
  Analyzer ana;
  return ProveEqual(&ana, a_int->min_value, b->min) &&
         ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
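The CastNode visitor fills a real gap: casts previously fell through to VisitExprDefault_ and the interval collapsed to "everything", and casts between int64 and int32 index expressions become routine once indices default to i64. A sketch of the effect from Python, assuming the Analyzer.int_set binding accepts a Var-to-IntSet map:

import tvm
from tvm import te
from tvm.tir import const

ana = tvm.arith.Analyzer()
x = te.var("x", dtype="int64")
dom = {x: tvm.arith.IntervalSet(const(0, "int64"), const(100, "int64"))}
# The [0, 100] bounds now survive the cast instead of degrading to (-inf, +inf).
s = ana.int_set(x.astype("int32"), dom)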
5 changes: 4 additions & 1 deletion src/relay/backend/compile_engine.cc
@@ -78,9 +78,13 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
  for (IndexExpr val : shape) {
    const int64_t* pval = tir::as_const_int(val);
    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
      CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
      CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
    } else if (val->IsInstance<tir::AnyNode>()) {
      res.push_back(val.as<tir::AnyNode>()->ToVar());
    } else {
@@ -131,7 +135,6 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
    candidate_name = truncated_name.str();
  }
  cache_node->func_name = candidate_name;
-
  CHECK(master_op_.defined());
  // Fusion over tupled results may leave identity relationships
  // between inputs and outputs, and those should not be scheduled.
3 changes: 3 additions & 0 deletions src/relay/op/tensor/reduce.cc
@@ -506,6 +506,9 @@ Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& input
  for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) {
    count *= inputs[0]->shape[i];
  }
+  // Although count is created with inputs[0]->dtype,
+  // its type may be changed (promoted) during the multiplication above.
+  count = cast(inputs[0]->dtype, count);
  auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
  return {topi::divide(res[0], count)};
}
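Context for the cast: MeanCompute divides the summed tensor by count, and multiplying count by (now possibly int64) shape dims can promote it away from inputs[0]->dtype, tripping the dtype check in topi::divide. A minimal Relay program that exercises this path, as a sketch:

import tvm
from tvm import relay

# mean = sum / count, where count is accumulated from the shape dims above.
x = relay.var("x", shape=(4, 4), dtype="float32")
f = relay.Function([x], relay.mean(x, axis=0))
mod = tvm.IRModule.from_expr(f)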
6 changes: 5 additions & 1 deletion src/support/libinfo.cc
@@ -200,6 +200,10 @@
#define TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "NOT-FOUND"
#endif

+#ifndef TVM_INFO_INDEX_DEFAULT_I64
+#define TVM_INFO_INDEX_DEFAULT_I64 "NOT-FOUND"
+#endif

namespace tvm {

/*!
@@ -253,7 +257,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX},
{"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB},
{"USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME},
};
{"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}};
return result;
}

4 changes: 2 additions & 2 deletions src/te/operation/op_util.cc
@@ -112,8 +112,8 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
    }
  }
  if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-    nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
-    value_map[iv] = dom->min;
+    nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op));
+    value_map[iv] = cast(var.dtype(), dom->min);
  } else if (is_zero(dom->min)) {
    nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
    value_map[iv] = var;
4 changes: 4 additions & 0 deletions src/te/schedule/message_passing.cc
@@ -207,6 +207,10 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
    if (!is_zero(inner_min)) {
      state[s->inner] = state[s->inner] + inner_min;
    }
+    // s->fused, s->outer and s->inner may have different dtypes,
+    // so cast the `state` entries back to their original dtypes.
+    state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]);
+    state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]);
  } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
    if (!state.count(s->rebased)) {
      CHECK(allow_missing);
13 changes: 12 additions & 1 deletion src/te/schedule/schedule_lang.cc
@@ -55,6 +55,16 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0;
}

+DataType MatchDataType(std::vector<DataType> dtypes) {
+  int max_bits = -1;
+  for (const auto& dtype : dtypes) {
+    CHECK(dtype.is_int());
+    CHECK(dtype.is_scalar());
+    max_bits = std::max(max_bits, dtype.bits());
+  }
+  return DataType::Int(max_bits);
+}

void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
IterVar* p_outer, IterVar* p_inner) {
// Check if split is valid.
@@ -228,8 +238,9 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT
  IterVarType iter_type = outer->iter_type;
  if (inner->iter_type > iter_type) iter_type = inner->iter_type;
  std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused";
+  DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()});

-  IterVar fused = IterVar(Range(), Var(fused_name, outer->var.dtype()), iter_type);
+  IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type);

  Array<IterVar>& all_vars = self->all_iter_vars;
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;
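MatchDataType picks the widest integer type among the fused axes, so fusing an int32 axis with an int64 one yields an int64 fused var rather than blindly inheriting the outer var's dtype. A small te-level sketch, assuming a symbolic int64 extent gives its axis an int64 loop var:

from tvm import te

n = te.var("n", dtype="int64")          # first axis extent is int64
A = te.placeholder((n, 32), name="A")   # second axis extent is a plain int32
B = te.compute((n, 32), lambda i, j: A[i, j] * 2, name="B")
s = te.create_schedule(B.op)
fused = s[B].fuse(B.op.axis[0], B.op.axis[1])
assert fused.var.dtype == "int64"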
49 changes: 49 additions & 0 deletions tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import te
+from tvm import relay
from tvm.tir import const


@@ -38,6 +39,7 @@ def lower_sch(sch, args, target_bits):
            arg_list.append(buf)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
+    sch = sch.normalize()
    bounds = te.schedule.InferBound(sch)
    stmt = te.schedule.ScheduleOps(sch, bounds)

@@ -189,9 +191,56 @@ def check(m, n, target_bits, target_dtype):
          target_bits=32, target_dtype='int64')


+def test_relay_basic():
+    engine = relay.backend.compile_engine.get()
+    def check(shapex, shapey, target_bits, target_dtype):
+        x = relay.var('x', shape=shapex)
+        y = relay.var('y', shape=shapey)
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+        mod = tvm.IRModule.from_expr(func)
+        func = mod["main"]
+        z = engine.lower(func, "llvm")
+        stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
+        # outer loop
+        assert stmt.loop_var.dtype == target_dtype
+        # inner loop
+        if len(shapex) > 1 or len(shapey) > 1:
+            assert stmt.body.loop_var.dtype == target_dtype
+
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), (1, const(2**15 + 1, 'int64')),
+          target_bits=32, target_dtype="int64")
+    check((const(2**16, 'int64'), const(2**15, 'int64')), (1, const(2**15, 'int64')),
+          target_bits=32, target_dtype="int32")
+    check((const(2**31, 'int64'),), (const(2**31, 'int64'),),
+          target_bits=32, target_dtype="int32")
+    check((const(2**31 + 1, 'int64'),), (const(2**31 + 1, 'int64'),),
+          target_bits=32, target_dtype="int64")


+def test_relay_take():
+    engine = relay.backend.compile_engine.get()
+    def check(shape, index, target_bits, target_dtype):
+        x = relay.var("x", shape=shape)
+        y = relay.op.take(x, indices=index)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule.from_expr(func)
+        func = mod["main"]
+        z = engine.lower(func, "llvm")
+        stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
+        assert stmt.value.index.dtype == target_dtype
+
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(0, dtype="int64"),
+          target_bits=32, target_dtype="int32")
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(2**31, dtype="int64"),
+          target_bits=32, target_dtype="int64")


if __name__ == "__main__":
    test_basic()
    test_thread_axis()
    test_multilanes()
    test_reduce()
    test_slice()
+    test_relay_basic()
+    test_relay_take()