From ff3f95b1ae504e15b551e04b5b756e6dcaa464ca Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Mon, 12 Jul 2021 15:00:53 +0000 Subject: [PATCH 1/2] Pass resource_handle to operators with unpacked API This patch passes the resource_handle variable through the AOT executor down to the backend operators. In order to model this on the `PrimFunc` the resource_handle was added as a property, the resource_handle property is then added to the arguments when transformed via MakeUnpackedAPI. The flow of the resource_handle looks similar to this: ```c int32_t __tvm_main__(void* args, void* type_code, int num_args, void* out_value, void* out_type_code, void* resource_handle) { return tvmgen_run_model( ((DLTensor*)(((TVMValue*)args)[0].v_handle))[0].data, ((DLTensor*)(((TVMValue*)args)[1].v_handle))[0].data, ((DLTensor*)(((TVMValue*)args)[2].v_handle))[0].data, ((DLTensor*)(((TVMValue*)args)[3].v_handle))[0].data, resource_handle ); } TVM_DLL int32_t tvmgen_run_model(void* arg0, void* arg1, void* arg2, void* arg3, void* resource_handle) { void* input = arg0; void* input1 = arg1; void* input2 = arg2; void* output = arg3; (void)tvmgen_fused_concatenate_add(input, input1, input2, output, resource_handle); return 0; } ``` --- include/tvm/tir/function.h | 11 ++++++++-- python/tvm/tir/function.py | 22 +++++++++++++++++-- src/relay/backend/aot_executor_codegen.cc | 6 ++++- src/target/source/source_module.cc | 22 +++++-------------- src/tir/ir/function.cc | 9 +++++--- src/tir/transforms/make_packed_api.cc | 2 +- src/tir/transforms/make_unpacked_api.cc | 3 +++ .../test_tir_transform_make_unpacked_api.py | 19 +++++++++++----- 8 files changed, 63 insertions(+), 31 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 55f4fc62649c..ee3fad7054ae 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -90,6 +90,8 @@ class PrimFuncNode : public BaseFuncNode { * will make program analysis much easier. */ Map buffer_map; + /*! \brief The resource handle to be used by the function when accessing platform resources */ + tir::Var resource_handle; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); @@ -97,6 +99,7 @@ class PrimFuncNode : public BaseFuncNode { v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); v->Visit("attrs", &attrs); + v->Visit("resource_handle", &resource_handle); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -105,7 +108,7 @@ class PrimFuncNode : public BaseFuncNode { // visit params and buffer_map first as they contains defs. return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && equal(ret_type, other->ret_type) && equal(body, other->body) && - equal(attrs, other->attrs); + equal(attrs, other->attrs) && equal.DefEqual(resource_handle, other->resource_handle); } void SHashReduce(SHashReducer hash_reduce) const { @@ -114,6 +117,7 @@ class PrimFuncNode : public BaseFuncNode { hash_reduce(ret_type); hash_reduce(body); hash_reduce(attrs); + hash_reduce.DefHash(resource_handle); } /*! * \brief Return the derived function annotation of this function. @@ -141,11 +145,14 @@ class PrimFunc : public BaseFunc { * \param ret_type The return type of the function. * \param buffer_map The buffer map for parameter buffer unpacking. * \param attrs Additional function attributes. + * \param resource_handle Handle for passing resources to the function * \param span The location of this object in the source code. */ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = Map(), - DictAttrs attrs = NullValue(), Span span = Span()); + DictAttrs attrs = NullValue(), + tir::Var resource_handle = tir::Var("resource_handle", DataType::Handle()), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 68d967aa497d..e9516084f1f5 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -48,13 +48,31 @@ class PrimFunc(BaseFunc): attrs: Optional[tvm.Attrs] Attributes of the function, can be None + resource_handle: Optional[tvm.tir.Var] + The resource handle to be used by the function when accessing platform resources, + if not passed a Var will be created for it + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None): + def __init__( + self, + params, + body, + ret_type=None, + buffer_map=None, + attrs=None, + resource_handle=None, + span=None, + ): param_list = [] buffer_map = {} if buffer_map is None else buffer_map + + # This is bound later as it relies on the FFI API having defined "Var" + if resource_handle is None: + resource_handle = Var("resource_handle", dtype="handle") + for x in params: x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): @@ -67,7 +85,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, resource_handle, span # type: ignore ) def with_body(self, new_body, span=None): diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 221df958a8cb..70a990e43397 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -310,6 +310,7 @@ class AOTExecutorCodegen : public ExprVisitor { auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); if (use_unpacked_api_) { calling_pattern = tvm::tir::builtin::call_extern(); + args.push_back(resource_handle_); } create_func_call_stmts.push_back( @@ -643,7 +644,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Make the PrimFunc return tir::PrimFunc(main_signature_, body, VoidType(), Map(), - DictAttrs(dict_attrs)); + DictAttrs(dict_attrs), resource_handle_); } protected: @@ -651,6 +652,8 @@ class AOTExecutorCodegen : public ExprVisitor { runtime::Module* mod_; /*! \brief list of input expressions (i.e., variable passed by the user) */ std::vector input_vars_; + /*! \brief resource handle to be passed into operator functions */ + tir::Var resource_handle_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ @@ -699,6 +702,7 @@ class AOTExecutorCodegen : public ExprVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) : mod_(mod), + resource_handle_("resource_handle", DataType::Handle()), targets_(targets), target_host_(target_host), use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))), diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 7728773b13d7..dc6f49e25186 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -195,29 +195,19 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); - for (unsigned int i = 0; i < total_args; ++i) { - code_ << "void* arg" << i; - if (i + 1 != total_args) { - code_ << ","; - } + int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + for (int i = 0; i < total_args; ++i) { + code_ << "void* arg" << i << ","; } - code_ << ");\n"; + code_ << "void* resource_handle);\n"; code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { + for (int i = 0; i < total_args; ++i) { code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } - for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->inputs.size() + i; - code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; - if (i + 1 != metadata_->num_outputs) { - code_ << ","; - } - } - code_ << ");\n"; + code_ << "resource_handle);\n"; code_ << "}\n"; } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52ea1..fc092530dfd1 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -37,7 +37,8 @@ LinkedParam::LinkedParam(int64_t id, ::tvm::runtime::NDArray param) { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { + Map buffer_map, DictAttrs attrs, tir::Var resource_handle, + Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -50,6 +51,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->buffer_map = std::move(buffer_map); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); + n->resource_handle = std::move(resource_handle); n->span = std::move(span); data_ = std::move(n); } @@ -81,8 +83,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { - return PrimFunc(params, body, ret_type, buffer_map, attrs, span); + Map buffer_map, DictAttrs attrs, tir::Var resource_handle, + Span span) { + return PrimFunc(params, body, ret_type, buffer_map, attrs, resource_handle, span); }); } // namespace tir diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 393ce6c286b4..fd4b5f2ac233 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -161,7 +161,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { }; // --------------------------- // start of logics - // add signiture for packed arguments. + // add signature for packed arguments. if (pack_args) { args.push_back(v_packed_args); args.push_back(v_packed_arg_type_ids); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6e8793fbd367..796eb21037eb 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -80,6 +80,9 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { args.push_back(v_arg); } + // Add resource handle to function parameters + args.push_back(func_ptr->resource_handle); + // Bind variables then bind buffers to them to ensure correct ordering for (const auto& kv : var_def) { binder.Bind(kv.second, kv.first, kv.first->name_hint, true); diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 9d917466758b..0fbcefb1a11b 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -57,8 +57,9 @@ def test_fails_if_no_target(mod_without_attrs): def test_device_setup(mod, target, dev): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 1 + assert len(f.params) == 2 assert f.params[0].name == "arg0" + assert f.params[1].name == "resource_handle" assert f.body.node == "default" assert f.body.attr_key == "device_id" assert f.body.value == 0 @@ -76,15 +77,18 @@ def test_no_buffers_no_device_setup(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 1 + assert len(f.params) == 2 + assert f.params[0].name == "arg0" + assert f.params[1].name == "resource_handle" assert f.body.var.name == "A" assert f.body.value.name == "arg0" def test_argument_mapping(mod): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 1 + assert len(f.params) == 2 assert f.params[0].name == "arg0" + assert f.params[1].name == "resource_handle" assert f.body.body.body.var.name == "A" assert f.body.body.body.value.name == "arg0" @@ -100,9 +104,10 @@ def test_argument_mapping_multiple(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 2 + assert len(f.params) == 3 assert f.params[0].name == "arg0" assert f.params[1].name == "arg1" + assert f.params[2].name == "resource_handle" assert f.body.body.body.var.name == "A" assert f.body.body.body.value.name == "arg0" assert f.body.body.body.body.var.name == "B" @@ -119,9 +124,10 @@ def test_argument_mapping_multiple_matching(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 2 + assert len(f.params) == 3 assert f.params[0].name == "arg0" assert f.params[1].name == "arg1" + assert f.params[2].name == "resource_handle" assert f.body.body.body.var.name == "A" assert f.body.body.body.value.name == "arg0" assert f.body.body.body.body.condition.a.name == "A" @@ -139,10 +145,11 @@ def test_body(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] - assert len(f.params) == 3 + assert len(f.params) == 4 assert f.params[0].name == "arg0" assert f.params[1].name == "arg1" assert f.params[2].name == "arg2" + assert f.params[3].name == "resource_handle" assert f.body.body.body.var.name == "A" assert f.body.body.body.value.name == "arg2" assert f.body.body.body.body.var.name == "B" From c69c99475a3f1b7f21573079c9bf6e62e6c5dce1 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Tue, 13 Jul 2021 15:33:39 +0000 Subject: [PATCH 2/2] Only add type ignore on untyped variable in function.py Black autoformats this to a longer line than pylint allows, only marking the relevant variable causes the formatting to run correctly. I think this is fine based on the assumption that the ignores should be removed in favour of proper typing at some point. --- python/tvm/tir/function.py | 9 ++++++++- tests/scripts/task_mypy.sh | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index e9516084f1f5..e7fe3aa76815 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -85,7 +85,14 @@ def __init__( raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, resource_handle, span # type: ignore + _ffi_api.PrimFunc, # type: ignore + param_list, + body, + ret_type, + buffer_map, + attrs, + resource_handle, + span, ) def with_body(self, new_body, span=None): diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index b05acb090c2f..8e18e5f3f41c 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -23,5 +23,5 @@ mypy --check-untyped-defs python/tvm/tir/schedule echo "Checking MyPy Type defs in the analysis package." mypy --check-untyped-defs python/tvm/tir/analysis/ -echo "Checking MyPy Type defs in the transofrm package." +echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/