Skip to content

Commit 3de77f8

Browse files
committed
Merge remote-tracking branch 'main' into unity
2 parents 3f1347c + ffa0033 commit 3de77f8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+768
-272
lines changed

.github/workflows/main.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ jobs:
7171
python -m pytest -v tests/python/all-platform-minimal-test
7272
- name: Minimal Metal Compile-Only
7373
shell: bash -l {0}
74-
run: >-
74+
run: |
7575
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
76+
python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params'
7677
- name: Minimal Metal Compile-and-Run
7778
shell: bash -l {0}
7879
run: >-

cmake/utils/FindLLVM.cmake

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ macro(find_llvm use_llvm)
111111
message(FATAL_ERROR "Fatal error executing: ${LLVM_CONFIG} --libdir")
112112
endif()
113113
message(STATUS "LLVM libdir: ${__llvm_libdir}")
114+
execute_process(COMMAND ${LLVM_CONFIG} --cmakedir
115+
RESULT_VARIABLE __llvm_exit_code
116+
OUTPUT_VARIABLE __llvm_cmakedir
117+
OUTPUT_STRIP_TRAILING_WHITESPACE)
118+
if(NOT "${__llvm_exit_code}" STREQUAL "0")
119+
message(FATAL_ERROR "Fatal error executing: ${LLVM_CONFIG} --cmakedir")
120+
endif()
121+
message(STATUS "LLVM cmakedir: ${__llvm_cmakedir}")
114122
# map prefix => $
115123
# to handle the case when the prefix contains space.
116124
string(REPLACE ${__llvm_prefix} "$" __llvm_cxxflags ${__llvm_cxxflags_space})
@@ -165,6 +173,7 @@ macro(find_llvm use_llvm)
165173
find_package(ZLIB REQUIRED)
166174
list(APPEND LLVM_LIBS "ZLIB::ZLIB")
167175
elseif("${__flag}" STREQUAL "-lzstd" OR ("${__flag}" STREQUAL "zstd.dll.lib"))
176+
list(APPEND CMAKE_MODULE_PATH "${__llvm_cmakedir}")
168177
find_package(zstd REQUIRED)
169178
if (TARGET "zstd::libzstd_static")
170179
message(STATUS "LLVM links against static zstd")

docs/how_to/deploy/tensorrt.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ regular TVM CUDA compilation and code generation.
9696
.. code:: python
9797
9898
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
99-
mod, config = partition_for_tensorrt(mod, params)
99+
mod = partition_for_tensorrt(mod, params)
100100
101101
102102
Build the Relay graph, using the new module and config returned by partition_for_tensorrt. The
@@ -107,7 +107,7 @@ PassContext so the values can be read during compilation.
107107
.. code:: python
108108
109109
target = "cuda"
110-
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
110+
with tvm.transform.PassContext(opt_level=3):
111111
lib = relay.build(mod, target=target, params=params)
112112
113113

include/tvm/runtime/memory/memory_manager.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class Allocator {
8989
* \param buffer The buffer to free.
9090
*/
9191
virtual void Free(const Buffer& buffer) = 0;
92+
/*! \brief Clear the allocated memory. */
93+
virtual void Clear();
9294
/*! \brief The amount of memory currently allocated.
9395
* \return The amount of memory currently allocated.
9496
*/
@@ -119,6 +121,8 @@ class MemoryManager {
119121
* \return The memory allocator.
120122
*/
121123
static Allocator* GetAllocator(Device dev, AllocatorType type);
124+
/*! \brief Clear the allocators. */
125+
static void Clear();
122126

123127
private:
124128
MemoryManager() {}

include/tvm/runtime/packed_func.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,31 @@ struct PackedFuncValueConverter {
11451145
} \
11461146
}
11471147

1148+
#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \
1149+
const char* type_key() const final { return TypeKey; } \
1150+
PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) final { \
1151+
using SelfPtr = std::remove_cv_t<decltype(this)>;
1152+
#define TVM_MODULE_VTABLE_END() \
1153+
return PackedFunc(nullptr); \
1154+
}
1155+
#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \
1156+
if (_name == Name) { \
1157+
return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
1158+
using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper<decltype(MemFunc)>; \
1159+
SelfPtr self = static_cast<SelfPtr>(_self.get()); \
1160+
CHECK_EQ(args.size(), Helper::LenArgs) \
1161+
<< "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \
1162+
<< " arguments, but got " << args.size(); \
1163+
Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \
1164+
}); \
1165+
}
1166+
#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Func) \
1167+
if (_name == Name) { \
1168+
auto f = (Func); \
1169+
using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
1170+
return TypedPackedFunc<FType>(std::move(f)).packed(); \
1171+
}
1172+
11481173
/*!
11491174
* \brief Export typed function as a PackedFunc
11501175
* that can be loaded by LibraryModule.
@@ -1330,6 +1355,61 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
13301355
for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
13311356
}
13321357

1358+
template <typename T>
1359+
struct ModuleVTableEntryHelper {};
1360+
1361+
template <typename T, typename R, typename... Args>
1362+
struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
1363+
using MemFnType = R (T::*)(Args...) const;
1364+
using IndexSeq = std::index_sequence_for<Args...>;
1365+
static constexpr const std::size_t LenArgs = sizeof...(Args);
1366+
1367+
template <std::size_t... Is>
1368+
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1369+
std::index_sequence<Is...>) {
1370+
*rv = (self->*f)(args[Is]...);
1371+
}
1372+
};
1373+
1374+
template <typename T, typename R, typename... Args>
1375+
struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
1376+
using MemFnType = R (T::*)(Args...);
1377+
using IndexSeq = std::index_sequence_for<Args...>;
1378+
static constexpr const std::size_t LenArgs = sizeof...(Args);
1379+
1380+
template <std::size_t... Is>
1381+
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1382+
std::index_sequence<Is...>) {
1383+
*rv = (self->*f)(args[Is]...);
1384+
}
1385+
};
1386+
1387+
template <typename T, typename... Args>
1388+
struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
1389+
using MemFnType = void (T::*)(Args...) const;
1390+
using IndexSeq = std::index_sequence_for<Args...>;
1391+
static constexpr const std::size_t LenArgs = sizeof...(Args);
1392+
1393+
template <std::size_t... Is>
1394+
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1395+
std::index_sequence<Is...>) {
1396+
(self->*f)(args[Is]...);
1397+
}
1398+
};
1399+
1400+
template <typename T, typename... Args>
1401+
struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
1402+
using MemFnType = void (T::*)(Args...);
1403+
using IndexSeq = std::index_sequence_for<Args...>;
1404+
static constexpr const std::size_t LenArgs = sizeof...(Args);
1405+
1406+
template <std::size_t... Is>
1407+
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1408+
std::index_sequence<Is...>) {
1409+
(self->*f)(args[Is]...);
1410+
}
1411+
};
1412+
13331413
namespace parameter_pack {
13341414

13351415
template <typename... EnumArgs>

include/tvm/runtime/vm/executable.h

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,28 @@ struct VMFunction;
5757
*/
5858
class TVM_DLL Executable : public ModuleNode {
5959
public:
60-
/*!
61-
* \brief Get a PackedFunc from an executable module.
62-
*
63-
* \param name the name of the function.
64-
* \param sptr_to_self The shared_ptr that points to this module node.
65-
*
66-
* \return PackedFunc or nullptr when it is not available.
67-
*/
68-
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
60+
TVM_MODULE_VTABLE_BEGIN("VMExecutable");
61+
TVM_MODULE_VTABLE_ENTRY("get_lib", &Executable::GetLib);
62+
TVM_MODULE_VTABLE_ENTRY("get_bytecode", &Executable::GetBytecode);
63+
TVM_MODULE_VTABLE_ENTRY("get_constants", &Executable::GetConstants);
64+
TVM_MODULE_VTABLE_ENTRY("get_virtual_devices", &Executable::GetVirtualDevices);
65+
TVM_MODULE_VTABLE_ENTRY("get_primitives", &Executable::GetPrimitives);
66+
TVM_MODULE_VTABLE_ENTRY("get_stats", &Executable::Stats);
67+
TVM_MODULE_VTABLE_ENTRY("save", &Executable::Save);
68+
TVM_MODULE_VTABLE_ENTRY("get_function_arity", &Executable::GetFunctionArity);
69+
TVM_MODULE_VTABLE_ENTRY("get_function_param_name", &Executable::GetFunctionParameterName);
70+
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
71+
TVM_MODULE_VTABLE_ENTRY("move_late_bound_consts", &Executable::MoveLateBoundConstantsToFile);
72+
TVM_MODULE_VTABLE_ENTRY("get_late_bound_consts", &Executable::GetLateBoundConstants);
73+
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts", &Executable::LoadLateBoundConstantsFromFile);
74+
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts_from_map",
75+
&Executable::LoadLateBoundConstantsFromMap);
76+
TVM_MODULE_VTABLE_END();
6977

7078
/*! \brief Get the property of the runtime module. */
7179
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };
72-
80+
/*! \brief Creates a VM that loads `this` as the executable. */
81+
Module VMLoadExecutable();
7382
/*!
7483
* \brief Write the Executable to the binary stream in serialized form.
7584
*
@@ -123,17 +132,17 @@ class TVM_DLL Executable : public ModuleNode {
123132
* Must be called before \p SaveToBinary and friends if late-bound constants are
124133
desired. Otherwise can be ignored.
125134
*/
126-
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit);
135+
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit);
127136

128137
/*!
129138
* \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path.
130139
*/
131-
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);
140+
void MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit);
132141

133142
/*!
134143
* \brief Get a map of all constants larger than byte_limit in size.
135144
*/
136-
Map<String, NDArray> GetLateBoundConstants(size_t byte_limit);
145+
Map<String, NDArray> GetLateBoundConstants(int64_t byte_limit);
137146

138147
/*!
139148
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
@@ -255,12 +264,10 @@ class TVM_DLL Executable : public ModuleNode {
255264
* \param index Parameter index.
256265
* \return The parameter name.
257266
*/
258-
std::string GetFunctionParameterName(std::string func, uint32_t index) const;
267+
std::string GetFunctionParameterName(std::string func, int index) const;
259268

260269
virtual ~Executable() {}
261270

262-
const char* type_key() const final { return "VMExecutable"; }
263-
264271
/*!
265272
* \brief The (compile-time, virtual) devices corresponding to each device index.
266273
* This vector contains a pair Device and its memory_scope.

python/tvm/_ffi/_ctypes/packed_func.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ def _init_pythonapi_inc_def_ref():
340340
register_func = _LIB.TVMBackendRegisterEnvCAPI
341341
register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef)
342342
register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef)
343+
register_func(c_str("PyGILState_Ensure"), ctypes.pythonapi.PyGILState_Ensure)
344+
register_func(c_str("PyGILState_Release"), ctypes.pythonapi.PyGILState_Release)
343345

344346

345347
_init_pythonapi_inc_def_ref()

python/tvm/_ffi/_cython/packed_func.pxi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import ctypes
1919
import traceback
20-
from cpython cimport Py_INCREF, Py_DECREF
20+
from cpython cimport Py_INCREF, Py_DECREF, PyGILState_Ensure, PyGILState_Release
2121
from numbers import Number, Integral
2222
from ..base import string_types, py2cerror
2323
from ..runtime_ctypes import DataType, Device, TVMByteArray, ObjectRValueRef
@@ -381,5 +381,7 @@ def _init_pythonapi_inc_def_ref():
381381
register_func = TVMBackendRegisterEnvCAPI
382382
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
383383
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
384+
register_func(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)
385+
register_func(c_str("PyGILState_Release"), <void*>PyGILState_Release)
384386

385387
_init_pythonapi_inc_def_ref()

python/tvm/_ffi/base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import sys
2525
import types
2626

27-
from typing import Callable, Sequence
27+
from typing import Callable, Sequence, Optional
2828

2929
import numpy as np
3030

@@ -340,15 +340,16 @@ def get_last_ffi_error():
340340
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
341341

342342

343-
def _append_traceback_frame(tb, func_name, filepath, lineno):
343+
def _append_traceback_frame(tb, func_name, filepath, lineno: Optional[int]):
344344
"""Append a dummy frame to appear in the Python traceback"""
345345

346346
# Compile a dummy function to Python bytecode, using the
347347
# filepath that we want to appear in the traceback. Any external
348348
# debugger (e.g. pdb) that catches the exception will use the
349349
# filepath to show code snippets from that FFI file.
350+
header = "" if lineno is None else "\n" * (lineno - 1)
350351
code = compile(
351-
"{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)),
352+
f"{header}def dummy_func(): raise NotImplementedError()",
352353
filepath,
353354
"exec",
354355
)
@@ -446,10 +447,14 @@ def raise_last_ffi_error():
446447
for frame in frames:
447448
if " at " in frame:
448449
func_name, frame = frame.split(" at ", 1)
449-
filename, lineno = frame.rsplit(":", 1)
450+
if ":" in frame:
451+
filename, lineno = frame.rsplit(":", 1)
452+
lineno = int(lineno.strip())
453+
else:
454+
filename = frame
455+
lineno = None
450456
func_name = func_name.strip()
451457
filename = filename.strip()
452-
lineno = int(lineno.strip())
453458

454459
tb = _append_traceback_frame(tb, func_name, filename, lineno)
455460

python/tvm/contrib/nvcc.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,32 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
6767
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
6868

6969
temp = utils.tempdir()
70+
file_name = "tvm_kernels"
7071
if target_format not in ["cubin", "ptx", "fatbin"]:
7172
raise ValueError("target_format must be in cubin, ptx, fatbin")
72-
temp_code = temp.relpath("my_kernel.cu")
73-
temp_target = temp.relpath(f"my_kernel.{target_format}")
73+
temp_code = temp.relpath(f"{file_name}.cu")
74+
temp_target = temp.relpath(f"{file_name}.{target_format}")
75+
76+
pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
77+
kernels_output_dir = (
78+
pass_context.config["cuda.kernels_output_dir"]
79+
if "cuda.kernels_output_dir" in pass_context.config
80+
else None
81+
)
82+
if kernels_output_dir is not None:
83+
if not os.path.isdir(kernels_output_dir):
84+
os.makedirs(kernels_output_dir)
85+
temp_code = os.path.join(kernels_output_dir, f"{file_name}.cu")
86+
temp_target = os.path.join(kernels_output_dir, f"{file_name}.{target_format}")
7487

7588
with open(temp_code, "w") as out_file:
7689
out_file.write(code)
7790

7891
file_target = path_target if path_target else temp_target
7992
cmd = ["nvcc"]
8093
cmd += [f"--{target_format}", "-O3"]
94+
if kernels_output_dir is not None:
95+
cmd += ["-lineinfo"]
8196
if isinstance(arch, list):
8297
cmd += arch
8398
elif isinstance(arch, str):

0 commit comments

Comments
 (0)