From f49ee302a55b96ca135f2d4d3691cdc82dff76d8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 15:06:11 -0500 Subject: [PATCH 1/9] [AOT] Const-correctness in argument for TVMAotExecutor_GetInputName The `TVMAotExecutor_GetInputName` function requires a `const char**`, but is provided with a `char**`. While C's automatic pointer conversion does allow this as an automatic conversion, following const-correctness avoids a compile-time warning. This warning was first caused following https://github.com/apache/tvm/pull/14529, which itself was to follow const-correctness within the implementation of `TVMAotExecutor_GetInputName`. --- src/runtime/crt/aot_executor_module/aot_executor_module.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/crt/aot_executor_module/aot_executor_module.c b/src/runtime/crt/aot_executor_module/aot_executor_module.c index a5c8105144f7..155f45bff5ce 100644 --- a/src/runtime/crt/aot_executor_module/aot_executor_module.c +++ b/src/runtime/crt/aot_executor_module/aot_executor_module.c @@ -154,7 +154,7 @@ int32_t TVMAotExecutorModule_GetInputName(TVMValue* args, int* tcodes, int nargs return kTvmErrorFunctionCallNumArguments; } - char* name; + const char* name; int ret = TVMAotExecutor_GetInputName(aot_executor.executor, args[0].v_int64, &name); if (ret < 0) { return kTvmErrorExecutorModuleNoSuchInput; From 26d504ef244c971b135198ffd9b7906bd3e4375e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 May 2023 09:22:28 -0500 Subject: [PATCH 2/9] [Rust] Remove alignment requirement in rust test PrimFunc By default, the LLVM codegen is allowed to assume that externally-allocated buffers are 64-byte aligned (from the value of `kAllocAlignment`). The buffers allocated in the Rust unit tests do not provide this alignment, and need to specify that no additional alignment is provided. This is a subset of changes made in https://github.com/apache/tvm/pull/14771, broken out into an independent PR for ease of testing/review. --- .../test_tvm_basic/src/build_test_lib.py | 24 ++++++++------ .../tests/test_tvm_dso/src/build_test_lib.py | 24 ++++++++------ .../tests/test_wasm32/src/build_test_lib.py | 24 ++++++++------ rust/tvm/tests/basics/src/tvm_add.py | 31 +++++++++++++------ 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py index d6e1922efa85..52b1dec2c9fe 100755 --- a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py @@ -23,19 +23,25 @@ import tvm from tvm.relay.backend import Runtime -from tvm import te +from tvm.script import tir as T def main(): - n = te.var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = tvm.te.create_schedule(C.op) - s[C].parallel(s[C].op.axis[0]) + @T.prim_func + def func(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (n,), align=1) + B = T.match_buffer(var_B, (n,), align=1) + C = T.match_buffer(var_C, (n,), align=1) + for i in T.parallel(n): + with T.block("C"): + vi = T.axis.spatial(n, i) + C[vi] = A[vi] + B[vi] + runtime = Runtime("cpp", {"system-lib": True}) - print(tvm.lower(s, [A, B, C], simple_mode=True)) - tvm.build(s, [A, B, C], "llvm", runtime=runtime).save(osp.join(sys.argv[1], "test.o")) + print(tvm.lower(func, simple_mode=True)) + tvm.build(func, target="llvm", runtime=runtime).save(osp.join(sys.argv[1], "test.o")) if __name__ == "__main__": diff --git a/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py index 4b270fa17cbc..415832a614f0 100755 --- a/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py +++ b/rust/tvm-graph-rt/tests/test_tvm_dso/src/build_test_lib.py @@ -22,20 +22,26 @@ import sys import tvm -from tvm import te from tvm.contrib import cc +from tvm.script import tir as T def main(): - n = te.var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = tvm.te.create_schedule(C.op) - s[C].parallel(s[C].op.axis[0]) - print(tvm.lower(s, [A, B, C], simple_mode=True)) + @T.prim_func + def func(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (n,), align=1) + B = T.match_buffer(var_B, (n,), align=1) + C = T.match_buffer(var_C, (n,), align=1) + for i in T.parallel(n): + with T.block("C"): + vi = T.axis.spatial(n, i) + C[vi] = A[vi] + B[vi] + + print(tvm.lower(func, simple_mode=True)) obj_file = osp.join(sys.argv[1], "test.o") - tvm.build(s, [A, B, C], "llvm").save(obj_file) + tvm.build(func, "llvm").save(obj_file) cc.create_shared(osp.join(sys.argv[1], "test.so"), [obj_file]) diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py index 2bf327a31b1b..6d744b28c67e 100755 --- a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py +++ b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py @@ -22,20 +22,26 @@ import sys import tvm -from tvm import te from tvm.relay.backend import Runtime +from tvm.script import tir as T def main(): - n = te.var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = tvm.te.create_schedule(C.op) - s[C].parallel(s[C].op.axis[0]) - print(tvm.lower(s, [A, B, C], simple_mode=True)) + @T.prim_func + def func(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (n,), align=1) + B = T.match_buffer(var_B, (n,), align=1) + C = T.match_buffer(var_C, (n,), align=1) + for i in T.parallel(n): + with T.block("C"): + vi = T.axis.spatial(n, i) + C[vi] = A[vi] + B[vi] + + print(tvm.lower(s, simple_mode=True)) runtime = Runtime("cpp", {"system-lib": True}) - tvm.build(s, [A, B, C], "llvm -mtriple=wasm32-unknown-unknown", runtime=runtime).save( + tvm.build(func, target="llvm -mtriple=wasm32-unknown-unknown", runtime=runtime).save( osp.join(sys.argv[1], "test.o") ) diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py index fc5c4213bd08..807e27db7794 100755 --- a/rust/tvm/tests/basics/src/tvm_add.py +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -20,23 +20,34 @@ import sys import tvm -from tvm import te from tvm.contrib import cc +from tvm.script import tir as T def main(target, out_dir): - n = te.var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) + @T.prim_func + def func(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_A, (n,), align=1) + B = T.match_buffer(var_B, (n,), align=1) + C = T.match_buffer(var_C, (n,), align=1) + # with T.block("root"): + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] if target == "cuda": - bx, tx = s[C].split(C.op.axis[0], factor=64) - s[C].bind(bx, te.thread_axis("blockIdx.x")) - s[C].bind(tx, te.thread_axis("threadIdx.x")) + sch = tvm.tir.Schedule(func) + i, j = sch.split(sch.get_loops("C")[0], [None, 64]) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + func = sch.mod["main"] - fadd = tvm.build(s, [A, B, C], tvm.target.Target(target, host="llvm"), name="myadd") + fadd = tvm.build(func, target=tvm.target.Target(target, host="llvm"), name="myadd") fadd.save(osp.join(out_dir, "test_add.o")) if target == "cuda": fadd.imported_modules[0].save(osp.join(out_dir, "test_add.ptx")) From 2198ae07c90b6a51d79548984706ed1eca3b8d5f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 May 2023 12:21:46 -0500 Subject: [PATCH 3/9] [AOT] Align constants to buffer's alignment, fallback kAllocAlignment Prior to this commit, constants were aligned to 16-byte regions within the allocation pool. As a result, the pointer passed to packed functions would usually not meet the default alignment requirement of `kAllocAlignment = 64` bytes. Because LLVM is allowed to assume that this alignment is met, this could result in runtime errors. This commit updates the default alignment to `kAllocAlignment`, which can be overridden on a per-buffer basis. This is a subset of changes made in https://github.com/apache/tvm/pull/14771, broken out into an independent PR for ease of testing/review. --- src/tir/usmp/analysis/extract_buffer_info.cc | 26 ++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 5abfe24f434d..f447ddaa200f 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -244,13 +244,25 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { "user-given arguments for memory pools, the default behaviour is a single size " "un-restricted pool is assigned"; PrimFunc func = scope_stack_.top().func; - Optional executor_config = - module_->GetAttr(tvm::attr::kExecutor); - Integer workspace_alignment = 16; - if (executor_config) { - workspace_alignment = - executor_config.value()->GetAttr("workspace-byte-alignment").value_or(16); - } + auto workspace_alignment = [&]() -> Integer { + if (const auto* decl_buffer = op->body.as()) { + ICHECK(decl_buffer->buffer->data.same_as(op->buffer_var)) + << "DeclBuffer of Buffer " << decl_buffer->buffer << " has data ptr " + << decl_buffer->buffer->data + << ", which is mismatched from the parent Allocate's buffer_var of " + << op->buffer_var; + return decl_buffer->buffer->data_alignment; + } + + if (auto executor_config = module_->GetAttr(tvm::attr::kExecutor)) { + if (auto config_alignment = + executor_config.value()->GetAttr("workspace-byte-alignment")) { + return config_alignment.value(); + } + } + + return tvm::runtime::kAllocAlignment; + }(); BufferInfoKind bi_kind = BufferInfoKind::kIntermediate; String buffer_info_name = op->buffer_var->name_hint; From c03b16d34ade2e65d021ca74563c5d5aa91465cd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 May 2023 11:26:15 -0500 Subject: [PATCH 4/9] [MicroTVM] Updated unit tests to use aligned memory The default TVM allocators provide allocations with 64-byte alignment. This alignment is provided as a guarantee to LLVM optimizations, and failure to provide aligned allocations may result in incorrect results. This commit updates the MicroTVM examples to provide 64-byte aligned memory allocations. This is a subset of changes made in https://github.com/apache/tvm/pull/14771, broken out into an independent PR for ease of testing/review. --- apps/bundle_deploy/bundle.c | 4 +++- apps/bundle_deploy/bundle_static.c | 4 +++- apps/bundle_deploy/test.cc | 8 ++++++-- apps/bundle_deploy/test_static.c | 9 ++++++--- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c index 6018d40dd300..e411c929f936 100644 --- a/apps/bundle_deploy/bundle.c +++ b/apps/bundle_deploy/bundle.c @@ -32,8 +32,10 @@ #define CRT_MEMORY_NUM_PAGES 16384 #define CRT_MEMORY_PAGE_SIZE_LOG2 10 +#define CRT_MEMORY_PAGE_SIZE_BYTES (1 << CRT_MEMORY_PAGE_SIZE_LOG2) -static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)]; +static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * CRT_MEMORY_PAGE_SIZE_BYTES] + __attribute__((aligned(CRT_MEMORY_PAGE_SIZE_BYTES))); static MemoryManagerInterface* g_memory_manager; /*! \brief macro to do C API call */ diff --git a/apps/bundle_deploy/bundle_static.c b/apps/bundle_deploy/bundle_static.c index 18a7b2bbb0ff..846e21b36883 100644 --- a/apps/bundle_deploy/bundle_static.c +++ b/apps/bundle_deploy/bundle_static.c @@ -33,8 +33,10 @@ #define CRT_MEMORY_NUM_PAGES 16384 #define CRT_MEMORY_PAGE_SIZE_LOG2 10 +#define CRT_MEMORY_PAGE_SIZE_BYTES (1 << CRT_MEMORY_PAGE_SIZE_LOG2) -static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)]; +static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * CRT_MEMORY_PAGE_SIZE_BYTES] + __attribute__((aligned(CRT_MEMORY_PAGE_SIZE_BYTES))); static MemoryManagerInterface* g_memory_manager; /*! \brief macro to do C API call */ diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index 25056f4d17a6..7194427426d2 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -52,7 +53,7 @@ char* read_all_or_die(const char* name, const char* file_path, size_t* out_size) *out_size = st.st_size; } - char* data = (char*)malloc(st.st_size); + char* data = (char*)std::aligned_alloc(64, st.st_size); FILE* fp = fopen(file_path, "rb"); size_t bytes_to_read = st.st_size; size_t bytes_read = 0; @@ -129,7 +130,7 @@ int main(int argc, char** argv) { ftvm_runtime_run(handle); gettimeofday(&t3, 0); - float output_storage[10 * 5]; + float* output_storage = static_cast(std::aligned_alloc(64, 10 * 5 * sizeof(float))); std::vector output_shape = {10, 5}; DLTensor output; output.data = output_storage; @@ -162,6 +163,9 @@ int main(int argc, char** argv) { (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); + free(output_storage); + free(result_storage); + free(input_storage); free(json_data); free(params_data); dlclose(bundle); diff --git a/apps/bundle_deploy/test_static.c b/apps/bundle_deploy/test_static.c index b9c980843ea1..4258b3a4e0f9 100644 --- a/apps/bundle_deploy/test_static.c +++ b/apps/bundle_deploy/test_static.c @@ -54,12 +54,12 @@ int main(int argc, char** argv) { void* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]); gettimeofday(&t1, 0); - float input_storage[10 * 5]; + float* input_storage = aligned_alloc(64, 10 * 5 * sizeof(float)); fp = fopen(argv[1], "rb"); fread(input_storage, 10 * 5, 4, fp); fclose(fp); - float result_storage[10 * 5]; + float* result_storage = aligned_alloc(64, 10 * 5 * sizeof(float)); fp = fopen(argv[2], "rb"); fread(result_storage, 10 * 5, 4, fp); fclose(fp); @@ -82,7 +82,7 @@ int main(int argc, char** argv) { tvm_runtime_run(handle); gettimeofday(&t3, 0); - float output_storage[10 * 5]; + float* output_storage = aligned_alloc(64, 10 * 5 * sizeof(float)); DLTensor output; output.data = output_storage; DLDevice out_dev = {kDLCPU, 0}; @@ -117,6 +117,9 @@ int main(int argc, char** argv) { (t4.tv_sec - t3.tv_sec) * 1000 + (t4.tv_usec - t3.tv_usec) / 1000.f, (t5.tv_sec - t4.tv_sec) * 1000 + (t5.tv_usec - t4.tv_usec) / 1000.f); + free(output_storage); + free(result_storage); + free(input_storage); free(json_data); free(params_data); From 7df1810aa34408d669bebc2ab5db57cf300a2e6d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 May 2023 15:08:29 -0500 Subject: [PATCH 5/9] [CRT] Updated CRT to allocate aligned memory The default TVM allocators provide allocations with 64-byte alignment. This alignment is provided as a guarantee to LLVM optimizations, and failure to provide aligned allocations may result in incorrect results. This commit updates the CRT examples to provide 64-byte aligned memory allocations. This is a subset of changes made in https://github.com/apache/tvm/pull/14771, broken out into an independent PR for ease of testing/review. --- cmake/utils/CRTConfig.cmake | 1 + src/runtime/crt/common/crt_runtime_api.c | 30 +++++++++++++++---- src/runtime/crt/common/ndarray.c | 6 ++-- src/runtime/crt/crt_config.h.template | 3 ++ .../crt/graph_executor/graph_executor.c | 11 ++----- src/runtime/crt/host/CMakeLists.txt.template | 12 ++++++-- 6 files changed, 45 insertions(+), 18 deletions(-) diff --git a/cmake/utils/CRTConfig.cmake b/cmake/utils/CRTConfig.cmake index 42c523b08786..7f6b058bfd0a 100644 --- a/cmake/utils/CRTConfig.cmake +++ b/cmake/utils/CRTConfig.cmake @@ -25,6 +25,7 @@ function(generate_crt_config platform output_path) set(TVM_CRT_MAX_STRLEN_DLTYPE 10) set(TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120) set(TVM_CRT_MAX_STRLEN_PARAM_NAME 80) + set(TVM_CRT_ALLOC_ALIGNMENT 64) if("${platform}" STREQUAL "zephyr") set(TVM_CRT_MAX_PACKET_SIZE_BYTES 512) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 99b3201b95b0..4e0e0d4fc240 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -90,10 +91,26 @@ int TVMArrayFree(TVMArrayHandle handle) { int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint, void** out_data) { - if (alignment != 1) { - nbytes = (nbytes + alignment - 1) / alignment * alignment; - } - return TVMPlatformMemoryAllocate(nbytes, dev, out_data); + // The TVMPlatformMemoryAllocate function does not guarantee the + // alignment of the allocation. Therefore, deliberately + // overallocate by (alignment-1) and return an aligned region from + // it. + size_t total_bytes = nbytes + sizeof(void*) + (alignment - 1); + void* allocated_buf; + int err = TVMPlatformMemoryAllocate(total_bytes, dev, &allocated_buf); + if (err) return err; + + void* first_allowed_data_ptr = ((uint8_t*)allocated_buf) + sizeof(void*); + uintptr_t offset = (alignment - ((uintptr_t)first_allowed_data_ptr) % alignment); + void* data_ptr = first_allowed_data_ptr + offset; + + // Must keep a pointer to the original allocation, so that it can be + // passed to TVMPlatformMemoryFree. + ((void**)data_ptr)[-1] = allocated_buf; + + *out_data = data_ptr; + + return err; } int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, DLDataType dtype, @@ -110,7 +127,10 @@ int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shap return TVMDeviceAllocDataSpace(dev, nbytes, align, dtype, out_data); } -int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) { return TVMPlatformMemoryFree(ptr, dev); } +int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) { + void* allocated_buf = ((void**)ptr)[-1]; + return TVMPlatformMemoryFree(allocated_buf, dev); +} TVM_ATTRIBUTE_UNUSED static bool IsContiguous(const DLTensor* arr) { if (arr->strides == NULL) return true; diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c index b0e869766bde..1c1e3c58f47a 100644 --- a/src/runtime/crt/common/ndarray.c +++ b/src/runtime/crt/common/ndarray.c @@ -63,8 +63,8 @@ int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, D return status; } int total_elem_bytes = TVMNDArray_DataSizeBytes(array); - array->dl_tensor.data = - TVMBackendAllocWorkspace(kDLCPU, 0, total_elem_bytes, dtype.code, dtype.bits); + TVMDeviceAllocDataSpace(dev, total_elem_bytes, TVM_CRT_ALLOC_ALIGNMENT, dtype, + &array->dl_tensor.data); memset(array->dl_tensor.data, 0, total_elem_bytes); return 0; } @@ -167,7 +167,7 @@ int TVMNDArray_Release(TVMNDArray* arr) { return 0; } - err = TVMPlatformMemoryFree(arr->dl_tensor.data, dev); + err = TVMDeviceFreeDataSpace(dev, arr->dl_tensor.data); if (err != kTvmErrorNoError) { return err; } diff --git a/src/runtime/crt/crt_config.h.template b/src/runtime/crt/crt_config.h.template index 1d32253282e8..70ac85b91fec 100644 --- a/src/runtime/crt/crt_config.h.template +++ b/src/runtime/crt/crt_config.h.template @@ -54,6 +54,9 @@ /*! Maximum supported string length in parameter names */ #define TVM_CRT_MAX_STRLEN_PARAM_NAME ${TVM_CRT_MAX_STRLEN_PARAM_NAME} +/*! Alignment (in bytes) for data buffer allocation */ +#define TVM_CRT_ALLOC_ALIGNMENT ${TVM_CRT_ALLOC_ALIGNMENT} + /*! Enable checks to enforce the stack allocator with a FIFO ordering. Off by default */ // #define TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 395a343ccb41..46704a332a31 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -841,15 +841,10 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl status = -1; } - if (executor->data_entry[eid].dl_tensor.shape) { - err = TVMPlatformMemoryFree(executor->data_entry[eid].dl_tensor.shape, dev); - if (err != kTvmErrorNoError) { - status = -1; - } - executor->data_entry[eid].dl_tensor.shape = 0; - } + // The memory in the executor->data_entry[eid].dl_tensor.shape is + // owned by attrs->shape, and should not be freed here. if (executor->data_entry[eid].dl_tensor.data) { - err = TVMPlatformMemoryFree(executor->data_entry[eid].dl_tensor.data, dev); + err = TVMDeviceFreeDataSpace(dev, executor->data_entry[eid].dl_tensor.data); if (err != kTvmErrorNoError) { status = -1; } diff --git a/src/runtime/crt/host/CMakeLists.txt.template b/src/runtime/crt/host/CMakeLists.txt.template index be0bce85513b..8059ea316524 100644 --- a/src/runtime/crt/host/CMakeLists.txt.template +++ b/src/runtime/crt/host/CMakeLists.txt.template @@ -34,6 +34,9 @@ set(CRT_LIBS microtvm_rpc_server memory ) + +add_library(tvm_model) + # Build CRT libraries foreach(crt_lib_name ${CRT_LIBS}) add_library(${crt_lib_name}) @@ -41,11 +44,16 @@ foreach(crt_lib_name ${CRT_LIBS}) target_sources(${crt_lib_name} PRIVATE ${crt_lib_srcs}) target_include_directories(${crt_lib_name} PRIVATE crt_config crt/include) target_compile_definitions(${crt_lib_name} PRIVATE -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE) - target_link_libraries(main PRIVATE ${crt_lib_name}) + # Circular dependencies result in the static libraries being listed + # twice in the link command, resolving circular dependencies between + # the libraries and the model. + # + # See https://cmake.org/cmake/help/latest/command/target_link_libraries.html#cyclic-dependencies-of-static-libraries + target_link_libraries(tvm_model PRIVATE ${crt_lib_name}) + target_link_libraries(${crt_lib_name} PRIVATE tvm_model) endforeach(crt_lib_name ${CRT_LIBS}) # Build model files -add_library(tvm_model) file(GLOB_RECURSE tvm_model_srcs model/codegen/host/src/*.c model/codegen/host/lib/*.o) target_sources(tvm_model PRIVATE ${tvm_model_srcs}) target_include_directories(tvm_model PRIVATE ${CMAKE_SOURCE_DIR}/include crt_config crt/include) From b3dd5d1976166898261c4fe4c64ff494d53eced0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 4 May 2023 15:57:22 -0500 Subject: [PATCH 6/9] [TIR][Runtime] Allow use of external non-compact/non-aligned buffers Prior to this commit, any use of `tvm.nd.from_dlpack` to create a strided `NDArray`, or a `NDArray` whose alignment was less than `tvm::runtime::kAllocAlignment` would raise an error. As a result, views into larger arrays, which are unlikely to be aligned and compact, could only be shared when copied into an aligned and compact buffer. This commit moves the compact/aligned check from the `NDArray` class into the generated TIR code as part of DLTensor unpacking. These checks were initially introduced in https://github.com/apache/tvm/pull/11391, to avoid segfaults caused by use of non-aligned buffers in code intended for aligned buffers. The new checks will provide the same safeguard as the alignment is checked prior to use, but allows the alignment requirement to be relaxed on a per-buffer basis. This approach also removes a potential bug resulting from compile-time configuration of `tvm::runtime::kAllocAlignment`, first introduced in https://github.com/apache/tvm/pull/13307. Since TVM supports cross-compiling, the installation of TVM used to compile a kernel may assume a larger value of `kAllocAlignment` than is provided by the runtime installation of TVM. By validating the alignment within the generated kernel, rather than as part of the runtime, this potential inconsistency would be caught. This check is also restricted to targets whose `void*` opaque pointer can be interpreted as a pointer to the data array. (e.g. No such check applies on Vulkan, as the `void*` is a pointer to a struct that contains additional bookkeeping.) --- src/runtime/ndarray.cc | 26 --- src/target/llvm/codegen_llvm.cc | 7 +- src/tir/transforms/arg_binder.cc | 48 ++++-- .../test_tir_argument_alignment.py | 162 ++++++++++++++++++ 4 files changed, 203 insertions(+), 40 deletions(-) create mode 100644 tests/python/tir-transform/test_tir_argument_alignment.py diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c2efa79c0c83..e055187bb467 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -182,30 +182,7 @@ struct NDArray::Internal { NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, uint64_t relative_byte_offset) { ICHECK(data_ != nullptr); - const DLTensor& orig = get_mutable()->dl_tensor; - CHECK(IsContiguous()) << [&orig]() { - std::stringstream ss; - ss << "Can only create view for compact tensor, but found strides "; - - ss << "["; - for (int i = 0; i < orig.ndim; i++) { - if (i) ss << ", "; - ss << orig.strides[i]; - } - ss << "]"; - - ss << ", for shape "; - ss << "["; - for (int i = 0; i < orig.ndim; i++) { - if (i) ss << ", "; - ss << orig.shape[i]; - } - ss << "]"; - return ss.str(); - }(); - const auto& curr_dl_tensor = get_mutable()->dl_tensor; - NDArray ret = Internal::Create(shape, dtype, curr_dl_tensor.device); size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); @@ -273,9 +250,6 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { data->SetDeleter(Internal::DLPackDeleter); // fill up content. data->manager_ctx = tensor; - ICHECK(::tvm::runtime::IsContiguous(tensor->dl_tensor)) << "DLManagedTensor must be contiguous."; - ICHECK(IsAligned(tensor->dl_tensor)) - << "Data in DLManagedTensor is not aligned as required by NDArray"; data->dl_tensor = tensor->dl_tensor; // update shape_ std::vector shape; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6fc083d17ccf..e378de6a9228 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1439,7 +1439,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); - return builder_->CreateBitCast(MakeValue(op->args[0]), target); + llvm::Value* value = MakeValue(op->args[0]); + if (op->args[0].dtype().is_handle() && (op->dtype.is_int() || op->dtype.is_uint())) { + return builder_->CreatePtrToInt(value, target); + } else { + return builder_->CreateBitCast(value, target); + } } else if (op->op.same_as(builtin::isnan())) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 5b9e005b7ea3..a4e39984ffd5 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -302,25 +302,47 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); + + // Mark alignment of external bufs. + AssertStmt satisfies_tvm_alignment = [&]() { + PrimExpr interpretable_as_pointer = + (device_type == kDLCPU) || (device_type == kDLCUDA) || (device_type == kDLCUDAHost) || + (device_type == kDLCUDAManaged) || (device_type == kDLHexagon); + PrimExpr is_aligned = + tvm::floormod(tvm::reinterpret(DataType::UInt(vptr->dtype.bits()), vptr), + buffer->data_alignment) == 0; + + std::ostringstream alignment_err_msg; + alignment_err_msg << arg_name << ".data is expected to be aligned to " + << buffer->data_alignment << " bytes"; + + return AssertStmt(!interpretable_as_pointer || is_aligned, StringImm(alignment_err_msg.str()), + nop); + }(); + asserts_.push_back(satisfies_tvm_alignment); // Check if the data pointer is NULL. This check is skipped for // size-0 arrays, since CUDA provides a NULL pointer for size-zero // allocations. - auto alloc_size = [&]() -> PrimExpr { - PrimExpr product = IntImm(buffer->DefaultIndexType(), 1); - for (const auto& dim : buffer->shape) { - product *= dim; - } - return product; + AssertStmt valid_data_ptr_for_non_empty_array = [&]() { + auto alloc_size = [&]() -> PrimExpr { + PrimExpr product = IntImm(buffer->DefaultIndexType(), 1); + for (const auto& dim : buffer->shape) { + product *= dim; + } + return product; + }(); + return AssertStmt( + alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), + tvm::tir::StringImm(arg_name + " is expected to have non-NULL data pointer"), nop); }(); - asserts_.emplace_back(AssertStmt( - alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), - tvm::tir::StringImm(arg_name + " is expected to have non-NULL data pointer"), nop)); + asserts_.push_back(valid_data_ptr_for_non_empty_array); - def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); - // mark alignment of external bufs - init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), nop)); + // AttrStmt provides a @llvm.assume that can be placed in the + // generated LLVM kernel + init_nest_.push_back(AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } } diff --git a/tests/python/tir-transform/test_tir_argument_alignment.py b/tests/python/tir-transform/test_tir_argument_alignment.py new file mode 100644 index 000000000000..081dae64254f --- /dev/null +++ b/tests/python/tir-transform/test_tir_argument_alignment.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring + +import pytest + +import tvm +import tvm.testing +from tvm.script import tir as T + + +alignment_required = tvm.testing.parameter(1, 32, 64) +offset = tvm.testing.parameter(0, 3, 32, 64) + + +def test_aligned_dltensor(alignment_required, offset): + """Alignment of buffer arguments checked during DLTensor unpacking + + TVM allocates buffers that are aligned according to the value of + `tvm::runtime::kAllocAlignment`. However, buffers may be + non-aligned, either when provided by an external source, or when + the TVM runtime used for compilation and for execution have a + different value of `tvm::runtime::kAllocAlignment` through the + `TVM_KALLOC_ALIGNMENT` macro definition. In addition, while + `tvm::runtime::kAllocAlignment` is the default alignment for TIR + buffers, it can be overridden on a per-buffer basis. + + This test varies the alignment required by a buffer argument and + the alignment provided by an externally-owned array, validating + that non-aligned buffers may always be converted to TVM, and must + have their alignment validated when calling a PrimFunc. + """ + torch = pytest.importorskip("torch") + + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, 16, dtype="int8", align=alignment_required) + T.evaluate(0) + + built = tvm.build(func) + + torch_tensor = torch.arange(128, dtype=torch.int8) + torch_view = torch_tensor[offset : offset + 16] + tvm_array = tvm.nd.from_dlpack(torch_view) + + satisfies_alignment = offset % alignment_required == 0 + if satisfies_alignment: + built(tvm_array) + else: + with pytest.raises(tvm.TVMError): + built(tvm_array) + + +contiguity_test_case = tvm.testing.parameter( + by_dict={ + "entire_first_row": ([4, 16], [0, 0]), + "entire_second_row": ([4, 16], [1, 0]), + "left_half_of_first_row": ([4, 32], [0, 0]), + "right_half_of_first_row": ([4, 32], [0, 16]), + } +) + + +def test_contiguous_dltensor(contiguity_test_case): + """Validate argument buffer is compact when strides are unspecified.""" + torch = pytest.importorskip("torch") + + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, [1, 16], dtype="int8", align=1) + T.evaluate(0) + + built = tvm.build(func) + + view_backing_shape, view_offset = contiguity_test_case + torch_tensor = torch.zeros(*view_backing_shape, dtype=torch.int8) + torch_view = torch_tensor[ + view_offset[0] : view_offset[0] + 1, + view_offset[1] : view_offset[1] + 16, + ] + tvm_array = tvm.nd.from_dlpack(torch_view) + + built(tvm_array) + + +strided_test_case = tvm.testing.parameter( + by_dict={ + "entire_buffer": (8, 16), + "split_in_slowest_changing_dim": (32, 16), + "split_in_fastest_changing_dim": (8, 64), + } +) + + +def test_dynamic_striding_on_external_dltensor(strided_test_case): + """External buffers may be strided. + + Validity is checked by the TIR unpacking of the DLTensor, based on + the requirements of the TIR buffer. + """ + torch = pytest.importorskip("torch") + + @T.prim_func + def func(a: T.handle): + stride_i = T.var("int32") + stride_j = T.var("int32") + A = T.match_buffer(a, [8, 16], strides=[stride_i, stride_j], dtype="int8", align=1) + T.evaluate(0) + + built = tvm.build(func) + + torch_tensor = torch.zeros(*strided_test_case, dtype=torch.int8) + torch_view = torch_tensor[:8, :16] + tvm_array = tvm.nd.from_dlpack(torch_view) + + built(tvm_array) + + +def test_static_striding_on_external_dltensor(strided_test_case): + """External buffers may be strided. + + Import of strided arrays from external sources is legal. The + validity for any given PrimFunc is checked by the TIR unpacking of + the DLTensor, based on the requirements of the TIR buffer. + """ + torch = pytest.importorskip("torch") + + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, [8, 16], dtype="int8", align=1) + T.evaluate(0) + + built = tvm.build(func) + + torch_tensor = torch.zeros(*strided_test_case, dtype=torch.int8) + torch_view = torch_tensor[:8, :16] + tvm_array = tvm.nd.from_dlpack(torch_view) + + has_correct_striding = strided_test_case[1] == 16 + if has_correct_striding: + built(tvm_array) + else: + with pytest.raises(tvm.TVMError): + built(tvm_array) + + +if __name__ == "__main__": + tvm.testing.main() From 2e176b4b54bef4316d4803019be30a730193f9e0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 9 May 2023 15:33:55 -0500 Subject: [PATCH 7/9] Update expected constant size, given the alignment --- tests/python/tir-usmp/test_tir_usmp_algo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/tir-usmp/test_tir_usmp_algo.py b/tests/python/tir-usmp/test_tir_usmp_algo.py index b9cfde485633..886b10c3b2fc 100644 --- a/tests/python/tir-usmp/test_tir_usmp_algo.py +++ b/tests/python/tir-usmp/test_tir_usmp_algo.py @@ -365,8 +365,8 @@ def run_model(input: T.handle, output: T.handle) -> None: @pytest.mark.parametrize( ["algorithm", "fast_memory_size", "slow_memory_size"], [ - ("greedy_by_size", 200704, 1418528), - ("greedy_by_conflicts", 200704, 1418528), + ("greedy_by_size", 200704, 1418560), + ("greedy_by_conflicts", 200704, 1418560), ("hill_climb", 200704, 1117462), ], ) From f066eea59055734ac81fa998f28150500e4357bb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 15:06:24 -0500 Subject: [PATCH 8/9] LLVM codegen test, disable optimization to preserve provable assume() --- tests/python/codegen/test_target_codegen_llvm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f1316ae3cee0..f81bb8a3b744 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -382,7 +382,7 @@ def test_alignment(): s = te.create_schedule(B.op) bx, tx = s[B].split(B.op.axis[0], factor=8) s[B].vectorize(tx) - f = tvm.build(s, [A, B], "llvm", name="test_alignment") + f = tvm.build(s, [A, B], "llvm -opt-level=0", name="test_alignment") lines = f.get_source().split("\n") From b265574be1232379daba02fb4ec0d8e6ede4f058 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 15:09:14 -0500 Subject: [PATCH 9/9] Updated USMP test with aligned size/offsets --- ...form_convert_pool_allocations_to_offsets.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 9e9fea7c8152..5de8c83f288e 100644 --- a/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/tir-usmp/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -106,7 +106,7 @@ class LinearStructure: def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1) + placeholder_4 = T.match_buffer(placeholder_2, [150560], dtype="uint8", elem_offset=0, align=64, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=64, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body @@ -118,7 +118,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) - placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=64, offset_factor=1) + placeholder_65 = T.match_buffer(placeholder_62, [150560], dtype="int16", elem_offset=0, align=64, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=64, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1) @@ -178,11 +178,11 @@ class LinearStructurePlanned: @T.prim_func def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle("uint8"), slow_memory_1_var: T.handle("uint8"), output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418560], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_9_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[1117504], dtype="handle") sid_8_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @@ -193,7 +193,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418560], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.Buffer([200704], dtype="uint8") with T.LetStmt(T.address_of(fast_memory_6_buffer_var[0], dtype="handle"), var=tensor_2_let.data): @@ -207,23 +207,23 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle("uint8"), slow_memory_3_var: T.handle("uint8")) -> None: - placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") + placeholder_4 = T.match_buffer(placeholder_2, [150560], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418560], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle("uint8"), slow_memory_5_var: T.handle("uint8")) -> None: - placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") + placeholder_65 = T.match_buffer(placeholder_62, [150560], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418560], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.Buffer([157323], "int16") with T.LetStmt(T.address_of(slow_memory_5_buffer_var[802816], dtype="handle"), var=PaddedInput_7_let.data):