Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,40 @@ class SEqualReducer {
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs) const;
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
bool operator()(const int& lhs, const int& rhs) const;
bool operator()(const bool& lhs, const bool& rhs) const;
bool operator()(const std::string& lhs, const std::string& rhs) const;
bool operator()(const DataType& lhs, const DataType& rhs) const;
bool operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;

template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
bool operator()(const ENum& lhs, const ENum& rhs,
Optional<ObjectPathPair> paths = NullOpt) const {
using Underlying = typename std::underlying_type<ENum>::type;
static_assert(std::is_same<Underlying, int>::value,
"Enum must have `int` as the underlying type");
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
}

template <typename T, typename Callable,
typename = std::enable_if_t<
std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
if (IsPathTracingEnabled()) {
ObjectPathPair current_paths = GetCurrentObjectPaths();
ObjectPathPair new_paths = {callable(current_paths->lhs_path),
callable(current_paths->rhs_path)};
return (*this)(lhs, rhs, new_paths);
} else {
return (*this)(lhs, rhs);
}
}

/*!
Expand Down Expand Up @@ -310,7 +330,8 @@ class SEqualReducer {
void RecordMismatchPaths(const ObjectPathPair& paths) const;

private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
Optional<ObjectPathPair> paths = NullOpt) const;

bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;
Expand All @@ -321,7 +342,8 @@ class SEqualReducer {

template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data);
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths = NullOpt);

/*! \brief Internal class pointer. */
Handler* handler_ = nullptr;
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
* \sa tvm::support::With
*/
static IRBuilder Current();
/*! \brief See if the current thread-local scope has an IRBuilder. */
static bool IsInScope();
/*!
* \brief Give a string name to the `obj`
* \tparam TObjectRef The type of the object to name.
Expand Down
14 changes: 11 additions & 3 deletions include/tvm/script/ir_builder/ir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@ namespace ir {
*/
class IRModuleFrameNode : public IRBuilderFrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;
/*! \brief A map from string names to global variables that ensures global uniqueness. */
Map<String, GlobalVar> global_var_map;
/*!
* \brief A map from GlobalVar to all global functions.
* \note Only defined functions are in the map, while declared functions are not included.
*/
Map<GlobalVar, BaseFunc> functions;
/*! \brief IRModule's attributes. */
Map<String, ObjectRef> attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
v->Visit("attrs", &attrs);
}

static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/script/ir_builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ namespace ir {
*/
TVM_DLL IRModuleFrame IRModule();

/*!
* \brief Declare a Function without given the specific function implementation.
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
* \param func_name The function unique name.
* \param func_signature A Function w/o body, which used to specify the function signature
* (i.e. func params and func return type/shape).
* \return The corresponding GlobalVar.
*/
TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);

/*!
* \brief Define the function which is declared before.
* \param func_name The function unique name.
* \param func The given function implementation
*/
TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);

} // namespace ir
} // namespace ir_builder
} // namespace script
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,9 @@ void Evaluate(PrimExpr value);
*/
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), //
String storage_scope = "global", //
bool is_size_var = false) {
bool is_size_var = false, bool is_unknown_type = false) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
if (is_unknown_type && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
Expand Down
46 changes: 33 additions & 13 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,34 +272,54 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
* \sa tvm::attr
*/
namespace attr {

/*!
* \brief List of thread IterVar that a DeviceLaunch function corresponds to.
*
* Type: Array<tir::IterVar>
* Type: Array<String>
*
* We call a device kernel launch function f using the following convention:
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
* Here n = len(arg), m = len(work_size) = len(launch_params)-1.
*
* When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted.
* The list of kernel launch params indicates which additional
* parameters will be provided to the PackedFunc by the calling
* scope.
*
* The list of device_thread_axis indicates how can be bind the
* work_size arguments to the corresponding threads.
* - "threadIdx.x", "threadIdx.y", "threadIdx.z"
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
* \brief Whether or not use dynamic shared memory.
* The extent of the thread count in x/y/z, to be used when
* launching the compute kernel on the device. For example, the
* gridDimX/Y/Z parameters passed to cuLaunchKernel when launching a
* CUDA kernel, or the groupCountX/Y/Z parameters passed to
* vkCmdDispatch when dispatching a compute pipeline to Vulkan.
*
* Type: Integer
* - "blockIdx.x", "blockIdx.y", "blockIdx.z"
*
* The extent of the block iterators, to be used when launching the
* compute kernel on the device. For example, the blockDimX/Y/Z
* parameters passed to cuLaunchKernel when launching a CUDA kernel.
* For runtimes that do not require the block to be provided
* externally, this parameter is ignored. For example, the
* spv::ExecutionModeLocalSize for SPIR-V shaders on Vulkan, where
* this parameter is defined in the shader.
*
* - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag
*
* The size of the shared memory that may be allocated internally by
* the kernel. For example, exposed as the
* CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in
* cuda.
*
* Defined as "tir.use_dyn_shared_memory".
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";
constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params";

/*!
* \brief Whether to set noalias rule on the function arguments.
Expand Down
36 changes: 25 additions & 11 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -734,19 +734,33 @@ class SeqStmt : public Stmt {
public:
explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}

void operator()(size_t i, const Stmt& stmt) const {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else {
seq_->push_back(stmt);
template <typename T>
void operator()(size_t i, const T& stmt_or_seq) const {
if constexpr (std::is_base_of_v<ObjectRef, T>) {
// Early bail-out, applicable to any ObjectRef
if (!stmt_or_seq.defined()) return;
}
}

template <typename T>
void operator()(size_t i, const T& seq) const {
for (auto v : seq) {
this->operator()(0, v);
if constexpr (std::is_same_v<T, SeqStmt>) {
// No need for dynamic type-checking if the static type is a
// SeqStmt.
(*this)(0, stmt_or_seq->seq);
} else if constexpr (std::is_base_of_v<T, SeqStmt>) {
// Dynamic type-checking for a SeqStmt that could be
// flattened.
if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
operator()(0, op->seq);
} else {
seq_->push_back(stmt_or_seq);
}
} else if constexpr (std::is_base_of_v<Stmt, T>) {
// Any other Stmt type just gets appended.
seq_->push_back(stmt_or_seq);
} else {
// Anything else is treated as an iterable of Stmt.
for (auto v : stmt_or_seq) {
this->operator()(0, v);
}
}
}

Expand Down
13 changes: 13 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect();
*/
TVM_DLL Pass Simplify();

/*!
* \brief Convert an IRModule to be SSA form.
*
* This pass handles cases where the same tir::Var appears in
* multiple functions within the same module. For example, after
* extracting a fragment from one function into another, where the
* same `tir::Var` may be defined both as within the body of the
* original function, and as a parameter within the hoisted function.
*
* \return The pass.
*/
TVM_DLL Pass ConvertSSA();

/*!
* \brief Instruments bound checkers.
*
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

attrs = None if not attrs else attrs
if attrs is not None:
attrs = ast.literal_eval(str(attrs))
attrs = tvm.ir.make_node("DictAttrs", **attrs)
self.__init_handle_by_constructor__(
_ffi_api.IRModule,
functions,
type_definitions,
attrs,
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/script/ir_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame":
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument
if exc_type is None and exc_value is None:
# Do not execute `FrameExit` if the with scope exits because of exceptions
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member

def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
Expand Down Expand Up @@ -136,6 +138,17 @@ def current() -> "IRBuilder":
"""
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def is_in_scope() -> bool:
"""See if the current thread-local scope has an IRBuilder.

Returns
-------
bool
Whether the current thread-local scope has an IRBuilder
"""
return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member

def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/script/ir_builder/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
from .ir import ir_module
from .ir import (
decl_function,
def_function,
ir_module,
module_attrs,
)
Loading