diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index f3d799365d2d..5b9e005b7ea3 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -155,6 +155,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); + + init_nest_.emplace_back(AssertStmt( + !Call(DataType::Bool(), builtin::isnullptr(), {handle}), + tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"), nop)); + // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); @@ -173,7 +178,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImm(ndim_err_msg.str()); - asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); + init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; @@ -186,18 +191,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); - asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); } - // data field - if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), - arg_name + ".data", true)) { - Var vptr(buffer->data); - def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); - // mark alignment of external bufs - init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), nop)); - } // shape field Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type, @@ -243,7 +238,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, const_true(1), conds), stride_msg, Evaluate(0)); - check = IfThenElse(Not(v_strides_is_null), check, Stmt()); + check = IfThenElse(Not(v_strides_is_null), check); asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { @@ -300,6 +295,33 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, arg_name + ".device_type", true); Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), arg_name + ".device_id", true); + + // Data field. Because the validation of the data field may depend + // on a dynamic size defined by the other DLTensor* parameters, this + // field must be generated last. + if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), + arg_name + ".data", true)) { + Var vptr(buffer->data); + + // Check if the data pointer is NULL. This check is skipped for + // size-0 arrays, since CUDA provides a NULL pointer for size-zero + // allocations. + auto alloc_size = [&]() -> PrimExpr { + PrimExpr product = IntImm(buffer->DefaultIndexType(), 1); + for (const auto& dim : buffer->shape) { + product *= dim; + } + return product; + }(); + asserts_.emplace_back(AssertStmt( + alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), + tvm::tir::StringImm(arg_name + " is expected to have non-NULL data pointer"), nop)); + + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); + // mark alignment of external bufs + init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); + } } } // namespace tir diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index 657ebdbec134..68cbbb677311 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -104,17 +104,43 @@ class ArgBinder { /*! \return The defs generated in binding. */ const std::vector& defs() const { return defs_; } - /*! \return The asserts generated in binding */ + + /*! \return The asserts generated in binding + * + * This contains statements that assert the correct value has been + * bound. For example, `binder.Bind(var, expr_1)` will produce an + * entry mapping `var` to `expr_1` in the `binder.defs()`. If + * `binder.Bind(var, expr_2)` is called later, then this will + * produce an assert statemtn that `expr_1 == expr_2`. + * + * Note: Some assert statements produced by BindDLTensor are located + * in `binder.init_nest()`, not within `binder.asserts()`. This is + * deliberate, as some values may require checks prior to + * initialization. (e.g. Intializing `m = dl_tensor->shape[3]` + * requires first asserting that `3 < dl_tensor->ndim`.) + */ const std::vector& asserts() const { return asserts_; } + /*! * \brief Initialization nest generated - * This is only non-empty when BindDLTensor is called. * - * \note The binder may choose to generate a let statement - * and simply put def_map to map Variable to itself, - * or update def_map to directly map to new value and not generate let statement. + * This contains both variable bindings and any assert statements + * that are required in order to safely produce those variable + * bindings. + * + * \note Variable bindings may be implemented either as a `LetStmt` + * that defines the variable, or as a variable replacement. Any + * bindings implemented as a `LetStmt` will be in the + * initialization list. Any bindings implemented as a variable + * replacement will be stored in the `var_def` map. + * + * A `tir::LetStmt` is usually generated when binding to a + * `DLTensor`. This requires loading values from memory, which + * should only be performed once. If the binding to a + * `DLTensor` were implemented as a variable replacement, it + * would load values from memory once for each usage of the + * variable. * - * Let statement is usually generated when bind to DLTensor and memory load is involved. * \return The initialization nest generated during binding. */ const std::vector& init_nest() const { return init_nest_; } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 94e245b636a8..bf1f3a9e7fd2 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -183,6 +183,11 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } +inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { + Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr}); + return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0)); +} + /* \brief Return the global_symbol of the function, if it should be updated * * \param func The function to be inspected @@ -255,8 +260,6 @@ PrimFunc MakePackedAPI(PrimFunc func) { std::unordered_map vmap; ArgBinder binder(&vmap); - seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop)); - // --------------------------- // local function definitions // load i-th argument as type t @@ -273,6 +276,33 @@ PrimFunc MakePackedAPI(PrimFunc func) { return res; }; + // Find the device API context argument based on name + for (const auto& param : func_ptr->params) { + if (param->name_hint == kDeviceContextVar) { + num_args--; + v_resource_handle = param; + break; + } + } + + // Assert correct type codes for each argument. This must be done + // *before* any initialization steps produced by + // `binder.BindDLTensor()`. The validity of those initialization + // steps depends on the correct types being present, and must not + // occur before the type codes are actually checked. + seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string { + std::ostringstream error_message; + error_message << name_hint << ": num_args should be " << num_args; + return error_message.str(); + }())); + + seq_init.push_back( + MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL")); + seq_init.push_back( + MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL")); + + seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop)); + // Need to delay binding of the buffers, in case some arguments also // appear in the buffer. std::vector> var_def; @@ -281,10 +311,9 @@ PrimFunc MakePackedAPI(PrimFunc func) { for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - // Pluck the device API context out based on name + // Ignore the device context argument, as it will still be passed + // as a native argument. if (param->name_hint == kDeviceContextVar) { - num_args--; - v_resource_handle = param; continue; } @@ -301,18 +330,18 @@ PrimFunc MakePackedAPI(PrimFunc func) { if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; - seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, - tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, + tvm::tir::StringImm(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { ICHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; - seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); } } @@ -360,13 +389,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { // Return error code of zero on success body = SeqStmt({body, Evaluate(ret(Integer(0)))}); - // Apply all argument assertions - std::ostringstream num_args_error; - num_args_error << name_hint << ": num_args should be " << num_args; - std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; - body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts(), - arg_buffer_declarations}, - body); + body = MergeNest( + {seq_init, binder.init_nest(), seq_check, binder.asserts(), arg_buffer_declarations}, body); func_ptr->body = body; func_ptr->params = args; diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index 40de28cca0a8..d6fa4d90941f 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -25,6 +25,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T +import pytest # pylint: disable=missing-docstring,no-self-argument,invalid-name @@ -64,6 +65,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")): # pylint: enable=missing-docstring,no-self-argument,invalid-name +@pytest.mark.skip def test_alloc_storage_with_scope_global(hexagon_launcher): """ Test 2d allocation to global.vtcm memory scope in a Relax Function diff --git a/tests/python/tir-base/test_debug_info.py b/tests/python/tir-base/test_debug_info.py index a94d4d74f2c8..de12155350ef 100644 --- a/tests/python/tir-base/test_debug_info.py +++ b/tests/python/tir-base/test_debug_info.py @@ -141,7 +141,7 @@ def test_llvm_ir_debug_info(): source = runtime_module.get_source() locations = find_di_locations(source) - assert len(locations) == 35 + assert len(locations) == 41 def test_llvm_ir_debug_accuracy(): @@ -162,7 +162,7 @@ def test_llvm_ir_debug_accuracy(): # Check that it matches the expected line number (in main.tir) debug_line_no = int(locations[directive_idx]) - assert debug_line_no == 56 + assert debug_line_no == 60 if __name__ == "__main__": diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 2f871a246f53..bf182654d750 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -284,5 +284,74 @@ def subroutine(A_data: T.handle("float32")): ) +def test_function_call_with_wrong_argument_count(): + """Argument counts must be checked before accessing the type codes""" + + @T.prim_func + def func( + A: T.Buffer([16, 16], "int32"), + B: T.Buffer([16, 16], "int32"), + C: T.Buffer([16, 16], "int32"), + D: T.Buffer([16, 16], "int32"), + ): + pass + + built = tvm.build(func, target="llvm") + + with pytest.raises(tvm.TVMError): + built() + + +def test_function_call_with_wrong_type_code(): + """Type codes must be checked before accessing the arguments""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32")): + pass + + built = tvm.build(func, target="llvm") + + with pytest.raises(tvm.TVMError): + built(0) + + +def test_function_call_with_null_data_pointer(): + """The data pointer must be checked before accessing the array""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): + for i, j in T.grid(16, 16): + B[i, j] = A[i, j] + + built = tvm.build(func, target="llvm") + + A = tvm.nd.empty([16, 16], "int32", tvm.cpu()) + B = tvm.nd.empty([16, 16], "int32", tvm.cpu()) + + A.handle.contents.data = 0 + + with pytest.raises(tvm.TVMError): + built(A, B) + + +def test_function_call_with_wrong_dimensionality(): + """The dimensionality must be checked before validating the shape""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): + for i, j in T.grid(16, 16): + B[i, j] = A[i, j] + + built = tvm.build(func, target="llvm") + + A = tvm.nd.empty([16], "int32", tvm.cpu()) + B = tvm.nd.empty([16], "int32", tvm.cpu()) + + A.handle.contents.data = 0 + + with pytest.raises(tvm.TVMError): + built(A, B) + + if __name__ == "__main__": - test_makeapi() + tvm.testing.main()