From a178f628f1a2451b5941a1dd5559fdbdbffce43a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 May 2023 11:37:32 -0500 Subject: [PATCH] [TIR] Output DeclBuffer in MakePackedAPI For the `buf_shape` and `buf_strides` buffers, used while unpacking, generate a `DeclBuffer`. For each buffer that was in the `buffer_map`, generate a `DeclBuffer`. This is a subset of the changes made in https://github.com/apache/tvm/pull/14778, broken out for ease of testing and review. --- src/tir/transforms/arg_binder.cc | 2 ++ src/tir/transforms/make_packed_api.cc | 11 ++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 571acf8e092f..f3d799365d2d 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -206,6 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( LetStmt(buf_shape->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + init_nest_.emplace_back(DeclBuffer(buf_shape, nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::Int(1)) { @@ -221,6 +222,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmt( buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + init_nest_.emplace_back(DeclBuffer(buf_strides, nop)); PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); if (buffer->strides.size() == 0) { // Assert the buffer is compact diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index e387204045dd..94e245b636a8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -251,9 +251,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { Integer device_type(target_device_type); // seq_init gives sequence of initialization // seq_check gives sequence of later checks after init - std::vector seq_init, seq_check; + std::vector seq_init, seq_check, arg_buffer_declarations; std::unordered_map vmap; ArgBinder binder(&vmap); + + seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop)); + // --------------------------- // local function definitions // load i-th argument as type t @@ -331,6 +334,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { for (const auto& kv : buffer_def) { binder.BindDLTensor(kv.second, device_type, device_id, kv.first, name_hint + "." + kv.first->name_hint); + arg_buffer_declarations.push_back(DeclBuffer(kv.second, nop)); } func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)}, @@ -360,8 +364,9 @@ PrimFunc MakePackedAPI(PrimFunc func) { std::ostringstream num_args_error; num_args_error << name_hint << ": num_args should be " << num_args; std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; - body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); - + body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts(), + arg_buffer_declarations}, + body); func_ptr->body = body; func_ptr->params = args;