diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index fc0bc1f1abd2..714733a196c6 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -63,6 +63,14 @@ bool HasSideEffect(const Expr& e); */ bool ExprUseVar(const Expr& e, const Var& v); +/*! + * \brief Whether expression e uses any variable in the variable set. + * \param e The expression to be checked. + * \param vset The variable set. + * \return Whether e uses any variable in vset. + */ +bool ExprUseVar(const Expr& e, const std::unordered_set& vset); + /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. @@ -77,6 +85,24 @@ Stmt ConvertSSA(Stmt stmt); */ Stmt CanonicalSimplify(Stmt stmt); +/*! + * \brief Substitute the var specified in key->var to be value. + * \param stmt The source statement to be substituted + * \param value_map The map of new values. + * \return The converted form. + */ +Stmt Substitute(Stmt stmt, + const std::unordered_map& value_map); + +/*! + * \brief Substitute the var specified in key->var to be value. + * \param expr The source expression to be substituted + * \param value_map The map of new values. + * \return The converted expression. + */ +Expr Substitute(Expr expr, + const std::unordered_map& value_map); + /*! * \brief Substitute the var specified in key->var to be value. * \param stmt The source statement to be substituted diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index e9c8a179c95f..92bbc08aec7f 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -7,7 +7,6 @@ from . import collections as _collections from ._ffi.function import _init_api - @register_node class Buffer(NodeBase): """Symbolic data buffer in TVM. @@ -24,16 +23,19 @@ class Buffer(NodeBase): """ pass + @register_node class Split(NodeBase): """Split operation on axis.""" pass + @register_node class Fuse(NodeBase): """Fuse operation on axis.""" pass + @register_node class IterVar(NodeBase, _expr.ExprOp): """Represent iteration variable. 
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1ba8d3077482..d8e27288e274 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -30,6 +30,11 @@ TVM_REGISTER_API("ir_pass.Equal") } }); +TVM_REGISTER_API("ir_pass.ExprUseVar") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var()); + }); + TVM_REGISTER_API("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; @@ -69,7 +74,6 @@ REGISTER_PASS1(CanonicalSimplify); REGISTER_PASS4(Inline); REGISTER_PASS2(StorageFlatten); REGISTER_PASS1(VectorizeLoop); -REGISTER_PASS2(ExprUseVar); REGISTER_PASS4(UnrollLoop); REGISTER_PASS2(StorageSync); REGISTER_PASS4(MakeAPI); diff --git a/src/codegen/stack_vm/codegen_stack_vm.cc b/src/codegen/stack_vm/codegen_stack_vm.cc index ca96b1fd7991..34a8601276a3 100644 --- a/src/codegen/stack_vm/codegen_stack_vm.cc +++ b/src/codegen/stack_vm/codegen_stack_vm.cc @@ -215,11 +215,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, if (t.is_int()) { this->PushOp(op_int64); } else if (t.is_uint()) { - if (t.bits() <= 32) { - this->PushOp(op_int64); - } else { - LOG(FATAL) << "Cannot handle uint64_t in StackVM"; - } + this->PushOp(op_int64); } else { this->PushOp(StackVM::CodeI64ToF64(op_int64)); } diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc new file mode 100644 index 000000000000..4ac7998d6c34 --- /dev/null +++ b/src/pass/arg_binder.cc @@ -0,0 +1,196 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file arg_binder.cc + * \brief Helper utility to match and bind arguments. 
+ */ +#include +#include +#include +#include "./ir_util.h" +#include "./arg_binder.h" +#include "../arithmetic/compute_expr.h" + +namespace tvm { +namespace ir { + +void BinderAddAssert(Expr cond, + const std::string& arg_name, + std::vector* asserts) { + cond = Simplify(cond); + if (is_zero(cond)) { + LOG(FATAL) << "Bind have an unmet assertion: " + << cond << ", " << " on argument " << arg_name; + } + if (!is_one(cond)) { + std::ostringstream os; + os << "Argument " << arg_name << " has an unsatisfied constraint"; + asserts->emplace_back(AssertStmt::make(cond, os.str())); + } +} + +bool ArgBinder::Bind_(const Expr& arg, + const Expr& value, + const std::string& arg_name, + bool with_lets) { + CHECK_EQ(arg.type(), value.type()); + if (const Variable* v = arg.as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + Var v_arg(arg.node_); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = arg; + init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0))); + } else { + (*def_map_)[v] = value; + } + return true; + } else { + BinderAddAssert(it->second == value, arg_name, &asserts_); + } + } else { + BinderAddAssert(arg == value, arg_name, &asserts_); + } + return false; +} + +void ArgBinder::Bind(const Expr& arg, + const Expr& value, + const std::string& arg_name, + bool with_let) { + Bind_(arg, value, arg_name, with_let); +} + +void ArgBinder::BindArray(const Array& arg, + const Array& value, + const std::string& arg_name) { + CHECK_EQ(arg.size(), value.size()) + << "Argument " << arg_name << " array size mismatch"; + for (size_t i = 0; i < arg.size(); ++i) { + std::ostringstream os; + os << arg_name << "[" << i << "]"; + this->Bind(arg[i], value[i], os.str()); + } +} + +void ArgBinder::BindBuffer(const Buffer& arg, + const Buffer& value, + const std::string& arg_name) { + CHECK_EQ(arg->scope, value->scope) + << "Argument " << arg_name + << " Buffer bind scope mismatch"; + this->Bind(arg->data, value->data, arg_name + 
".data"); + this->BindArray(arg->shape, value->shape, arg_name + ".shape"); + this->BindArray(arg->strides, value->strides, arg_name + ".strides"); + this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset"); +} + +inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { + return TVMStructGet(t, arr, 0, kind); +} + +inline Stmt AssertNull(Var handle, std::string msg) { + return AssertStmt::make(Call::make( + Bool(1), intrinsic::tvm_handle_is_null, + {handle}, Call::PureIntrinsic), msg); +} + +void ArgBinder::BindDLTensor(const Buffer& buffer, + const Expr& device_type, + const Expr& device_id, + const Var& handle, + const std::string& arg_name) { + const Type tvm_shape_type = TVMShapeIndexType(); + const Type tvm_ndim_type = Int(32); + const Stmt nop = Evaluate::make(0); + // dimension checks + Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); + Expr a_ndim = make_const(tvm_ndim_type, + static_cast(buffer->shape.size())); + std::ostringstream ndim_err_msg; + ndim_err_msg << arg_name + << ".ndim is expected to equal " + << buffer->shape.size(); + asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str())); + // type checks + Type dtype = buffer->dtype; + std::ostringstream type_err_msg; + type_err_msg << arg_name << ".dtype is expected to be " << dtype; + Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) == + UIntImm::make(UInt(8), dtype.code()) && + TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) == + UIntImm::make(UInt(8), dtype.bits()) && + TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) == + UIntImm::make(UInt(16), dtype.lanes())); + asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str())); + // data field + if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData), + arg_name + ".data", true)) { + Var vptr(buffer->data); + def_handle_dtype_.Set(vptr, make_const(buffer->dtype, 0)); + // mark alignment of external bufs + 
init_nest_.emplace_back(AttrStmt::make( + vptr, ir::attr::storage_alignment, + IntImm::make(Int(32), runtime::kAllocAlignment), nop)); + } + // shape field + Var v_shape(arg_name + ".shape", Handle()); + def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); + init_nest_.emplace_back(LetStmt::make( + v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop)); + for (size_t k = 0; k < buffer->shape.size(); ++k) { + std::ostringstream field_name; + field_name << v_shape->name_hint << '[' << k << ']'; + Bind_(buffer->shape[k], + cast(buffer->shape[k].type(), + Load::make(tvm_shape_type, v_shape, + IntImm::make(Int(32), k), const_true(1))), + field_name.str(), true); + } + // strides field: each loaded stride is cast to the type of strides[k] itself, so the type check in Bind_ holds + Var v_strides(arg_name + ".strides", Handle()); + def_handle_dtype_.Set(v_strides, make_const(tvm_shape_type, 0)); + init_nest_.emplace_back(LetStmt::make( + v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides), + nop)); + if (buffer->strides.size() == 0) { + std::ostringstream stride_err_msg; + stride_err_msg << arg_name << ".strides:" + << " expected to be nullptr for contiguous array"; + init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str())); + } else { + for (size_t k = 0; k < buffer->strides.size(); ++k) { + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + Bind_(buffer->strides[k], + cast(buffer->strides[k].type(), + Load::make(tvm_shape_type, v_strides, + IntImm::make(Int(32), k), const_true(1))), + field_name.str(), true); + } + } + // Byte_offset field. 
+ int data_bytes = GetVectorBytes(buffer->dtype); + int64_t const_offset; + if (arith::GetConst(buffer->elem_offset, &const_offset)) { + Bind_(make_const(UInt(64), const_offset * data_bytes), + TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset), + arg_name + ".byte_offset", true); + } else { + Bind_(buffer->elem_offset, + cast(buffer->elem_offset.type(), + (TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) / + make_const(UInt(64), data_bytes))), + arg_name + ".elem_offset", true); + } + // device info. + Bind_(device_type, + TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType), + arg_name + ".device_type", true); + Bind_(device_id, + TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId), + arg_name + ".device_id", true); +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h new file mode 100644 index 000000000000..59e4eab55a1f --- /dev/null +++ b/src/pass/arg_binder.h @@ -0,0 +1,138 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file arg_binder.h + * \brief Helper utility to match and bind arguments. + */ +#ifndef TVM_PASS_ARG_BINDER_H_ +#define TVM_PASS_ARG_BINDER_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +/*! + * \brief Helper utility to generate match and bind of arguments. + * + * \note There are many places in TVM IR where we need argument bindings. + * + * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2))). + * Here n is an undefined variable that is decided by the outside, tB imposes + * a constraint such that it can only take tensor with shape 3, tC imposes + * another constraint that its shape must equal n + 2. + * So if we call it with f(bufferA, bufferB, bufferC), we need to generate + * the following binding sequence: + * - define n = bufferA.shape[0] + * - assert bufferB.shape[0] == 3 + * - assert bufferC.shape[0] == n + 2 + * + * In general, this is a constraint solving problem. 
We make a simplifying assumption + * over the binding declaration, such that we require any variable that occurs in + * a constraint to be declared in the argument list. So it is illegal to have signature + * f(tA(shape=(n+3))) without any argument variable that corresponds to n, even though + * it is already enough to derive n from the input argument. + */ +class ArgBinder { + public: + /*! + * \brief Constructor + * \param def_map A definition map that contains definition of known variables. + * ArgBinder will update this def_map when adding new definitions. + */ + explicit ArgBinder( + std::unordered_map* def_map) + : def_map_(def_map) { + } + /*! + * \brief Try to bind arg to value, generate constraint if necessary. + * \param arg The argument to be bound. + * \param value The target expression value + * \param arg_name argument name. + * \param with_let Whether to add let statements during bind + */ + void Bind(const Expr& arg, + const Expr& value, + const std::string& arg_name, + bool with_let = false); + /*! + * \brief Bind array to array + * \param arg The argument to be bound. + * \param value The target expression value + * \param arg_name argument name. + */ + void BindArray(const Array& arg, + const Array& value, + const std::string& arg_name); + /*! + * \brief Bind symbolic buffer to another symbolic buffer + * \param arg The argument to be bound. + * \param value The target expression value + * \param arg_name argument name. + */ + void BindBuffer(const Buffer& arg, + const Buffer& value, + const std::string& arg_name); + /*! + * \brief Bind symbolic buffer to a DLTensor handle. + * \param buffer The argument buffer to be bound. + * \param device_type The device type to be bound. + * \param device_id The device id to be bound. + * \param handle The DLTensor handle. + * \param arg_name argument name. + */ + void BindDLTensor(const Buffer& buffer, + const Expr& device_type, + const Expr& device_id, + const Var& handle, + const std::string& arg_name); + + /*! 
\return The defs generated in binding. */ + const std::vector& defs() const { + return defs_; + } + /*! \return The asserts generated in binding */ + const std::vector& asserts() const { + return asserts_; + } + /*! + * \brief Initialization nest generated + * Only non-empty after BindDLTensor or Bind with with_let = true is called. + * + * \note The binder may choose to generate a let statement + * and simply put def_map to map Variable to itself, + * or update def_map to directly map to new value and not generate let statement. + * + * A let statement is usually generated when binding to a DLTensor, where a memory load is involved. + * \return The initialization nest generated during binding. + */ + const std::vector& init_nest() const { + return init_nest_; + } + /*! \return Handle data type of the data */ + const Map& def_handle_dtype() const { + return def_handle_dtype_; + } + + private: + // Internal bind function; returns true only when arg is a Variable newly defined by this call + bool Bind_(const Expr& arg, + const Expr& value, + const std::string& arg_name, + bool with_lets); + /*! \brief The definition map, can be used to substitute */ + std::unordered_map* def_map_; + /*! \brief defs generated in the current binder */ + std::vector defs_; + /*! \brief Initialization nest */ + std::vector init_nest_; + /*! \brief handle data type in the definitions */ + Map def_handle_dtype_; + /*! \brief asserts generated */ + std::vector asserts_; +}; +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_ARG_BINDER_H_ diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc new file mode 100644 index 000000000000..8c5136a4d108 --- /dev/null +++ b/src/pass/ir_util.cc @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file ir_util.cc + * \brief Helper functions to construct and compose IR nodes. 
+ */ +#include "./ir_util.h" + +namespace tvm { +namespace ir { + +Stmt MergeNest(const std::vector& nest, Stmt body) { + // use reverse iteration + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + Stmt s = *ri; + if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->then_case)); + CHECK(!n->else_case.defined()); + n->then_case = body; + body = Stmt(n); + } else if (s.as()) { + body = Block::make(s, body); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else { + LOG(FATAL) << "not supported nest type"; + } + } + return body; +} + +Stmt MergeNest(const std::vector >& nest, Stmt body) { + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + body = MergeNest(*ri, body); + } + return body; +} + +Stmt MergeSeq(const std::vector& seq) { + if (seq.size() == 0) return Evaluate::make(0); + Stmt body = seq[0]; + for (size_t i = 1; i < seq.size(); ++i) { + body = Block::make(body, seq[i]); + } + return body; +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 472b408e32d5..bb09aa88b29c 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -11,6 +11,28 @@ namespace tvm { namespace ir { +/*! + * \brief combine the nest stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. + * \param body body + * \return The combined Stmt + */ +Stmt MergeNest(const std::vector& nest, Stmt body); + +/*! + * \brief combine the nest stmt, whose body is not defined. 
+ * \param nest A list of For and LetStmt, whose body is not defined. + * \param body body + * \return The combined Stmt + */ +Stmt MergeNest(const std::vector >& nest, Stmt body); + +/*! + * \brief combine sequence of operations. + * \param seq The sequence. + * \return The combined Stmt + */ +Stmt MergeSeq(const std::vector& seq); /*! * \brief update array with an unary function @@ -38,79 +60,6 @@ inline Array UpdateArray(Array arr, F fupdate) { } } -/*! - * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. - * \param body body - * \return The combined Stmt - */ -inline Stmt MergeNest(std::vector nest, Stmt body) { - // use reverse iteration - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - Stmt s = *ri; - if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->then_case)); - CHECK(!n->else_case.defined()); - n->then_case = body; - body = Stmt(n); - } else if (s.as()) { - body = Block::make(s, body); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else { - LOG(FATAL) << "not supported nest type"; - } - } - return body; -} - -/*! - * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. - * \param body body - * \return The combined Stmt - */ -inline Stmt MergeNest(std::vector > nest, Stmt body) { - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - body = MergeNest(*ri, body); - } - return body; -} - - -/*! 
- * \brief combine sequence of operations. - * \param seq The sequence. - * \return The combined Stmt - */ -inline Stmt MergeSeq(const std::vector& seq) { - if (seq.size() == 0) return Evaluate::make(0); - Stmt body = seq[0]; - for (size_t i = 1; i < seq.size(); ++i) { - body = Block::make(body, seq[i]); - } - return body; -} - /*! * \brief Get construct from struct * \param dtype The data type. @@ -176,7 +125,6 @@ inline Type APIType(Type t) { CHECK(t.is_float()); return Float(64); } - } // namespace ir } // namespace tvm #endif // TVM_PASS_IR_UTIL_H_ diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index b845b24b014e..8ca7b590e9b3 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -12,21 +12,12 @@ #include #include "./ir_util.h" +#include "./arg_binder.h" #include "../arithmetic/compute_expr.h" namespace tvm { namespace ir { -inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) { - return TVMStructGet(t, arr, 0, kind); -} - -inline Stmt AssertNull(Var handle, std::string msg) { - return AssertStmt::make(Call::make( - Bool(1), intrinsic::tvm_handle_is_null, - {handle}, Call::PureIntrinsic), msg); -} - inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { return AssertStmt::make(lhs == rhs, msg); } @@ -35,8 +26,6 @@ LoweredFunc MakeAPI(Stmt body, std::string name, Array api_args, int num_unpacked_args) { - const Type tvm_shape_type = TVMShapeIndexType(); - const Type tvm_ndim_type = Int(32); const Stmt nop = Evaluate::make(0); int num_args = static_cast(api_args.size()); CHECK_LE(num_unpacked_args, num_args); @@ -48,14 +37,13 @@ LoweredFunc MakeAPI(Stmt body, Var v_num_packed_args("num_args", Int(32)); // The arguments of the function. 
Array args; + // The device context + Var device_type("dev_type"), device_id("dev_id"); // seq_init gives sequence of initialization // seq_check gives sequence of later checks after iniit std::vector seq_init, seq_check; - std::unordered_set visited; - // the handle data types - Map handle_data_type; - // The device context - Var device_id, device_type; + std::unordered_map vmap; + ArgBinder binder(&vmap); // --------------------------- // local function defintiions // load i-th argument as type t @@ -81,25 +69,6 @@ LoweredFunc MakeAPI(Stmt body, const Variable* v = api_args[i].as(); return Var(os.str(), v ? v->type: Handle()); }; - // Push related into assertions or variable defintion - // given the symbolic declaration and concrete value - auto f_push = [&](Expr sym, Expr value, std::string field) { - if (sym.as()) { - // If sym is a Variable and this Variable is not yet defined - // add this to defintion. - Var v(sym.node_); - if (!visited.count(v.get())) { - seq_init.emplace_back(LetStmt::make(v, value, nop)); - visited.insert(v.get()); - return true; - } - } - // otherwise, assume sym is already defined, insert assertion. - std::ostringstream os; - os << "Field " << field << " has a unsatisfied constraint"; - seq_check.emplace_back(MakeAssertEQ(sym, value, os.str())); - return false; - }; // --------------------------- // start of logics // add signiture for packed arguments. @@ -112,7 +81,6 @@ LoweredFunc MakeAPI(Stmt body, seq_init.emplace_back( MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } - for (int i = 0; i < static_cast(api_args.size()); ++i) { Var v_arg = f_arg_decl(i); if (i < num_packed_args) { @@ -148,117 +116,30 @@ LoweredFunc MakeAPI(Stmt body, } // add checks for functions. 
if (api_args[i].as()) { - f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint); + binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true); } else { // Buffer checks CHECK(api_args[i].as()) << "api_args can only be Buffer or Var"; Buffer buf(api_args[i].node_); - // dimension checks - Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kArrNDim); - std::ostringstream ndim_err_msg; - ndim_err_msg << "arg_" << i - << ".ndim is expected to equal " - << buf->shape.size(); - seq_init.emplace_back( - MakeAssertEQ(v_ndim, - make_const(tvm_ndim_type, - static_cast(buf->shape.size())), - ndim_err_msg.str())); - // type checks - Type dtype = buf->dtype; - std::ostringstream type_err_msg; - type_err_msg << "arg" << i << ".dtype is expected to be " << dtype; - Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeCode) == - UIntImm::make(UInt(8), dtype.code()) && - TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeBits) == - UIntImm::make(UInt(8), dtype.bits()) && - TVMArrayGet(UInt(16), v_arg, intrinsic::kArrTypeLanes) == - UIntImm::make(UInt(16), dtype.lanes())); - seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str())); - // Data Field - if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kArrData), - v_arg->name_hint + ".data")) { - Var vptr(buf->data); - handle_data_type.Set(vptr, make_const(buf->dtype, 0)); - // mark storage alignment of external buffer arguments. 
- seq_init.emplace_back(AttrStmt::make( - vptr, ir::attr::storage_alignment, - IntImm::make(Int(32), runtime::kAllocAlignment), nop)); - } - // shape field - Var v_shape(v_arg->name_hint + ".shape", Handle()); - handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0)); - seq_init.emplace_back(LetStmt::make( - v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kArrShape), nop)); - for (size_t k = 0; k < buf->shape.size(); ++k) { - std::ostringstream field_name; - field_name << v_shape->name_hint << '[' << k << ']'; - f_push(buf->shape[k], - cast(buf->shape[k].type(), - Load::make(tvm_shape_type, v_shape, - IntImm::make(Int(32), k), const_true(1))), - field_name.str()); - } - // strides field - Var v_strides(v_arg->name_hint + ".strides", Handle()); - handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0)); - seq_init.emplace_back(LetStmt::make( - v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kArrStrides), - nop)); - if (buf->strides.size() == 0) { - std::ostringstream stride_err_msg; - stride_err_msg << "arg_" << i << ".strides:" - << " expected to be nullptr for contiguous array"; - seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str())); - } else { - for (size_t k = 0; k < buf->strides.size(); ++k) { - std::ostringstream field_name; - field_name << v_strides->name_hint << '[' << k << ']'; - f_push(buf->strides[k], - cast(buf->shape[k].type(), - Load::make(tvm_shape_type, v_strides, - IntImm::make(Int(32), k), const_true(1))), - field_name.str()); - } - } - // Byte_offset field. 
- int data_bytes = GetVectorBytes(buf->dtype); - int64_t const_offset; - if (arith::GetConst(buf->elem_offset, &const_offset)) { - f_push(make_const(buf->elem_offset.type(), const_offset * data_bytes), - TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset), - v_arg->name_hint + ".byte_offset"); - } else { - f_push(buf->elem_offset, - cast(buf->elem_offset.type(), - (TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset) / - make_const(UInt(64), data_bytes))), - v_arg->name_hint + ".elem_offset"); - } - // device info. - f_push(device_id, - TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId), - v_arg->name_hint + ".device_id"); - f_push(device_type, - TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceType), - v_arg->name_hint + ".device_type"); + binder.BindDLTensor( + buf, device_type, device_id, v_arg, v_arg->name_hint); } } std::shared_ptr n = std::make_shared(); n->name = name; n->args = args; - n->handle_data_type = handle_data_type; + n->handle_data_type = binder.def_handle_dtype(); n->is_packed_func = num_unpacked_args == 0; // Set device context - if (visited.count(device_id.get())) { + if (vmap.count(device_id.get())) { Expr node = StringImm::make("default"); - CHECK(visited.count(device_type.get())); - seq_init.push_back(AttrStmt::make( + CHECK(vmap.count(device_type.get())); + seq_check.push_back(AttrStmt::make( node, attr::device_context_id, device_id, nop)); - seq_init.push_back(AttrStmt::make( + seq_check.push_back(AttrStmt::make( node, attr::device_context_type, device_type, nop)); Stmt set_device = IfThenElse::make( device_type != kCPU, Evaluate::make(Call::make( @@ -267,7 +148,8 @@ LoweredFunc MakeAPI(Stmt body, device_type, device_id}, Call::Intrinsic))); body = Block::make(set_device, body); } - n->body = MergeNest({seq_init, seq_check}, body); + n->body = MergeNest( + {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); LoweredFunc f(n); Array undefined = UndefinedVars(f->body, f->args); if (undefined.size() != 0) { diff --git 
a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index bf91e6ebfb67..a3ca2904b842 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -37,67 +37,110 @@ bool HasSideEffect(const Expr& e) { class IRSubstitue : public IRMutator { public: + explicit IRSubstitue( + const std::unordered_map& smap) + : smap_(smap) { + } + Expr Mutate_(const Variable* op, const Expr& e) final { - auto it = smap.find(op); - if (it != smap.end()) { + auto it = smap_.find(op); + if (it != smap_.end()) { return it->second; } else { return e; } } - std::unordered_map smap; + + private: + const std::unordered_map& smap_; }; -Stmt Substitute(Stmt stmt, const Map& value_map) { +Stmt Substitute(Stmt stmt, + const std::unordered_map& value_map) { if (value_map.size() == 0) return stmt; - IRSubstitue m; - for (auto kv : value_map) { - m.smap[kv.first.get()] = kv.second; + return IRSubstitue(value_map).Mutate(stmt); +} + +Expr Substitute(Expr expr, + const std::unordered_map& value_map) { + if (value_map.size() == 0) return expr; + return IRSubstitue(value_map).Mutate(expr); +} + +Stmt Substitute(Stmt stmt, const Map& value_map) { + std::unordered_map vmap; + for (const auto& kv : value_map) { + vmap[kv.first.get()] = kv.second; } - return m.Mutate(stmt); + return Substitute(stmt, vmap); } Expr Substitute(Expr expr, const Map& value_map) { - if (value_map.size() == 0) return expr; - IRSubstitue m; - for (auto kv : value_map) { - m.smap[kv.first.get()] = kv.second; + std::unordered_map vmap; + for (const auto& kv : value_map) { + vmap[kv.first.get()] = kv.second; } - return m.Mutate(expr); + return Substitute(expr, vmap); } -class ExprUseVarVisitor : public IRVisitor { +class VarTouchVisitor : public IRVisitor { public: - explicit ExprUseVarVisitor(const Variable* var) - : var_(var) {} - void Visit(const NodeRef& e) final { if (use_var_) return; IRVisitor::Visit(e); } void Visit_(const Variable* op) final { - if (op == var_) { - use_var_ = true; - } + Handle(op); } void 
Visit_(const Load* op) final { - if (op->buffer_var.get() == var_) { - use_var_ = true; - } + Handle(op->buffer_var.get()); IRVisitor::Visit_(op); } - const Variable* var_; + virtual void Handle(const Variable* var) = 0; + bool use_var_{false}; }; +class ExprUseVarVisitor : public VarTouchVisitor { + public: + explicit ExprUseVarVisitor(const Variable* var) + : var_(var) {} + + void Handle(const Variable* var) final { + if (var == var_) use_var_ = true; + } + private: + const Variable* var_; +}; + +class ExprUseVSetVisitor : public VarTouchVisitor { + public: + explicit ExprUseVSetVisitor( + const std::unordered_set& vset) + : vset_(vset) {} + + void Handle(const Variable* var) final { + if (vset_.count(var)) use_var_ = true; + } + private: + const std::unordered_set& vset_; +}; + bool ExprUseVar(const Expr& e, const Var& v) { ExprUseVarVisitor visitor(v.get()); visitor.Visit(e); return visitor.use_var_; } +bool ExprUseVar(const Expr& e, + const std::unordered_set& vset) { + ExprUseVSetVisitor visitor(vset); + visitor.Visit(e); + return visitor.use_var_; +} + } // namespace ir } // namespace tvm diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 75ddf75beaf7..a33e0349621f 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -8,6 +8,8 @@ #include #include #include +#include "./ir_util.h" +#include "./arg_binder.h" #include "../arithmetic/compute_expr.h" #include "../runtime/thread_storage_scope.h" @@ -156,30 +158,6 @@ class StorageFlattener : public IRMutator { } private: - // Bind the symbol sym to value if it is a Variable - // send a sequence of asserts if it is a constant constrant. 
- // hint_name: used for error message - // add_keys: a list of newly binded keys - // add_asserts: a list of asserts during the bind - void BindSymbol(Expr sym, - Expr value, - std::string hint_name, - std::vector* add_keys, - std::vector* add_asserts) { - if (const Variable* v = sym.as()) { - auto it = var_remap_.find(v); - if (it == var_remap_.end()) { - add_keys->push_back(v); - var_remap_[v] = value; - return; - } - } - // add assertions - std::ostringstream os; - os << "BufferBind constaint fail " << hint_name; - add_asserts->emplace_back( - AssertStmt::make(sym == value, os.str())); - } // Start bind Stmt HandleBufferBindScope(const AttrStmt* op) { Array arr(op->node.node_); @@ -215,47 +193,16 @@ class StorageFlattener : public IRMutator { } else { slice = slice.MakeStrideView(); } - CHECK_EQ(slice->strides.size(), buffer->strides.size()); // start binding - std::vector keys; - std::vector asserts; - BindSymbol(buffer->data, slice->data, - buffer->name + ".data", - &keys, &asserts); - for (size_t i = 0; i < buffer->shape.size(); ++i) { - std::ostringstream field_name; - field_name << buffer->name << ".shape[" << i << ']'; - BindSymbol(buffer->shape[i], slice->shape[i], - field_name.str(), - &keys, &asserts); - } - for (size_t i = 0; i < buffer->strides.size(); ++i) { - std::ostringstream field_name; - field_name << buffer->name << ".strides[" << i << ']'; - BindSymbol(buffer->strides[i], slice->strides[i], - field_name.str(), - &keys, &asserts); - } - BindSymbol(buffer->elem_offset, slice->elem_offset, - buffer->name + ".elem_offset", - &keys, &asserts); - CHECK_EQ(buffer->scope, slice->scope) - << "Buffer bind scope mismatch"; + ArgBinder binder(&var_remap_); + binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name); // Apply the remaps - Stmt body = this->Mutate(op->body); - for (size_t i = 0; i < asserts.size(); ++i) { - Stmt ret = Simplify(this->Mutate(asserts[i])); - if (const AssertStmt* assert_op = ret.as()) { - if 
(!is_zero(assert_op->condition)) { - body = Block::make(ret, body); - } else { - LOG(FATAL) << "BindBuffer have unmet assertion: " << ret; - } - } - } + Stmt body = MergeNest(binder.asserts(), op->body); + body = MergeNest(binder.init_nest(), body); + body = this->Mutate(body); // remove the binds - for (const Variable* op : keys) { - var_remap_.erase(op); + for (const Var& v : binder.defs()) { + var_remap_.erase(v.get()); } return body; }