From 3c44f7bb9f8e6910d926affdae5052c9c86cb9cb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Feb 2022 17:05:41 +0900 Subject: [PATCH 01/32] add get_c_struct_name() method to Metadata to distinguish struct type name in llvm --- include/tvm/runtime/metadata.h | 2 + include/tvm/runtime/metadata_base.h | 31 ++++- src/runtime/metadata.cc | 15 ++- src/target/metadata.h | 11 +- src/target/source/source_module.cc | 168 +++++++++++++++++----------- tests/cpp/aot_metadata_test.cc | 12 +- 6 files changed, 153 insertions(+), 86 deletions(-) diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h index cd65f6fb7486..b7f7c6c0a458 100644 --- a/include/tvm/runtime/metadata.h +++ b/include/tvm/runtime/metadata.h @@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode { public: explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} static constexpr const char* _type_key = "metadata.MetadataNode"; + const char* get_c_struct_name() const override; inline int64_t version() const { return int64_t(data_->version); } inline int64_t num_inputs() const { return data_->num_inputs; } ArrayAccessor inputs(); @@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode { public: explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} static constexpr const char* _type_key = "metadata.TensorInfoNode"; + const char* get_c_struct_name() const override; inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } inline int64_t num_shape() const { return data_->num_shape; } inline ::tvm::support::Span shape() const { diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 96743199fe28..698f56d46d28 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -44,6 +44,8 @@ namespace metadata { */ class MetadataBaseNode : public ::tvm::runtime::Object { public: + virtual const char* get_c_struct_name() const = 0; + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); }; @@ -157,7 +159,7 @@ class ArrayAccessor { * * These are separate from TIR DataType because TIR does not model structs. */ -enum MetadataTypeIndex : uint8_t { +enum MetadataKind : uint8_t { kUint64 = 0, kInt64 = 1, kBool = 2, @@ -173,12 +175,29 @@ enum MetadataTypeIndex : uint8_t { */ class MetadataArrayNode : public MetadataBaseNode { public: - MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) - : array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} + MetadataArrayNode(Array array, MetadataKind kind, const char* type_key) + : array(::std::move(array)), kind{kind}, type_key{type_key} {} + + const char* get_c_struct_name() const final; + + std::string get_element_c_struct_name() const { + CHECK(kind == MetadataKind::kMetadata) + << "cannot get struct name for MetadataArray with kind=" << kind; + constexpr int prefix_size = sizeof("metadata.") - 1; + constexpr int suffix_size = sizeof("Node") - 1; + std::string type_key_str(type_key); + return std::string("TVM") + + type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size); + } Array array; - MetadataTypeIndex type_index; - const char* struct_name; + + /*! \brief Describes the storage class of the emitted struct member. */ + MetadataKind kind; + + /*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. 
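+   * For example, an array of TensorInfo objects carries type_key "metadata.TensorInfoNode",
+   * which get_element_c_struct_name() maps to the emitted C struct name "TVMTensorInfo".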
*/ + const char* type_key; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); }; @@ -186,7 +205,7 @@ class MetadataArrayNode : public MetadataBaseNode { /*! \brief Reference class for MetadataArray. */ class MetadataArray : public MetadataBase { public: - MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); + MetadataArray(Array array, MetadataKind kind, const char* struct_name); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); }; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 90469fabad2c..c08f2872fe8a 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -18,7 +18,7 @@ */ /*! - * \file tvm/runtime/metadata.h + * \file src/runtime/metadata.cc * \brief Defines implementations of TVM metadata which can exist in the runtime. */ @@ -47,20 +47,27 @@ ArrayAccessor MetadataNode::pools() { TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); -MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, - const char* struct_name) - : MetadataBase{make_object(array, type_index, struct_name)} {} +MetadataArray::MetadataArray(Array array, MetadataKind kind, const char* struct_name) + : MetadataBase{make_object(array, kind, struct_name)} {} +const char* MetadataArrayNode::get_c_struct_name() const { + ICHECK(false) << "MetadataArrayNode get_c_struct_name is unimplemented"; + return nullptr; +} TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); Metadata::Metadata(const struct ::TVMMetadata* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(MetadataNode); +const char* MetadataNode::get_c_struct_name() const { return "TVMMetadata"; } + TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) : MetadataBase{make_object(data)} {} TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); +const char* TensorInfoNode::get_c_struct_name() const { return "TVMTensorInfo"; } + } // namespace metadata class MetadataModuleNode : public ::tvm::runtime::ModuleNode { diff --git a/src/target/metadata.h b/src/target/metadata.h index b8ca24580f15..5dc1c9d0eec5 100644 --- a/src/target/metadata.h +++ b/src/target/metadata.h @@ -56,7 +56,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{ - inputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + inputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("inputs", &inputs_metadata_array); int64_t num_inputs_cpp = num_inputs(); v->Visit("num_inputs", &num_inputs_cpp); @@ -67,7 +68,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{ - outputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + outputs_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("outputs", &outputs_metadata_array); int64_t num_outputs_cpp = num_outputs(); v->Visit("num_outputs", &num_outputs_cpp); @@ -78,7 +80,8 @@ class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { 
pools_array.push_back(::tvm::runtime::metadata::TensorInfo{pools_accessor[i]}); } ::tvm::runtime::metadata::MetadataArray pools_metadata_array{ - pools_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + pools_array, ::tvm::runtime::metadata::MetadataKind::kMetadata, + ::tvm::runtime::metadata::TensorInfoNode::_type_key}; v->Visit("pools", &pools_metadata_array); int64_t num_pools_cpp = num_pools(); v->Visit("num_pools", &num_pools_cpp); @@ -156,7 +159,7 @@ class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); } ::tvm::runtime::metadata::MetadataArray shape_metadata_array{ - shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr}; + shape_array, ::tvm::runtime::metadata::MetadataKind::kInt64, nullptr}; v->Visit("shape", &shape_metadata_array); int64_t num_shape_cpp = num_shape(); v->Visit("num_shape", &num_shape_cpp); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 80b4f1b970f3..b8a6f789ddaf 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,6 +23,7 @@ */ #include "source_module.h" +#include #include #include #include @@ -585,7 +586,7 @@ class MetadataQueuer : public AttrVisitor { class MetadataSerializer : public AttrVisitor { public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; - using MetadataTypeIndex = ::tvm::runtime::metadata::MetadataTypeIndex; + using MetadataKind = ::tvm::runtime::metadata::MetadataKind; MetadataSerializer() : is_first_item_{true} {} @@ -653,29 +654,54 @@ class MetadataSerializer : public AttrVisitor { ICHECK(false) << "do not support serializing NDArray as metadata"; } - void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + void VisitArray(runtime::metadata::MetadataArray array) { auto old_is_first_item = is_first_item_; is_first_item_ = true; for (unsigned int i = 0; i < array->array.size(); ++i) { ObjectRef o = array->array[i]; - if (o->IsInstance()) { - int64_t i = Downcast(o); - Visit(nullptr, &i); - continue; - } - if (o->IsInstance()) { - std::string s = Downcast(o); - Visit(nullptr, &s); - continue; + switch (array->kind) { + case MetadataKind::kUint64: { + int64_t i = Downcast(o); + CHECK_GT(i, 0) + << "Metadata is of type uint64_t, but array type contains a negative number"; + uint64_t ui = static_cast(i); + Visit(nullptr, &ui); + continue; + } + case MetadataKind::kInt64: { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + case MetadataKind::kBool: { + bool b = Downcast(o); + Visit(nullptr, &b); + break; + } + case MetadataKind::kString: { + std::string s = Downcast(o); + Visit(nullptr, &s); + break; + } + case MetadataKind::kHandle: + CHECK(false) << "Don't know how to serialize handle"; + break; + + case MetadataKind::kMetadata: { + runtime::metadata::MetadataBase metadata = Downcast(o); + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); + break; + } + default: + CHECK(false) << "Unknown MetadataKind for array: " << array->kind; + break; } - - runtime::metadata::MetadataBase metadata = Downcast(o); - std::stringstream i_str; - i_str << i; - address_.push_back(i_str.str()); - Visit(nullptr, &metadata); - address_.pop_back(); + is_first_item_ = false; } is_first_item_ = old_is_first_item; } @@ -688,7 +714,7 @@ class MetadataSerializer : public AttrVisitor { if (key != nullptr) { 
address_.push_back(key); } - code_ << address_from_parts(address_); + code_ << metadata::address_from_parts(address_); if (key != nullptr) { address_.pop_back(); } @@ -705,59 +731,69 @@ class MetadataSerializer : public AttrVisitor { } } + private: + void EmitCType(const runtime::metadata::MetadataArrayNode* arr, std::ostream& os) { + switch (arr->kind) { + case MetadataKind::kUint64: + os << "uint64_t"; + break; + case MetadataKind::kInt64: + os << "int64_t"; + break; + case MetadataKind::kBool: + os << "bool"; + break; + case MetadataKind::kString: + os << "const char*"; + break; + case MetadataKind::kHandle: + os << "void*"; + break; + case MetadataKind::kMetadata: + os << "struct " << arr->get_element_c_struct_name(); + break; + default: + CHECK(false) << "Unknown kind in MetadataArray: " << arr->kind + << " (struct_name=" << arr->get_c_struct_name() << ")"; + break; + } + } + + public: void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { decl_ << "#include " << std::endl << "#include " << std::endl << "#include " << std::endl; - std::vector queue; - MetadataQueuer queuer{&queue}; - queuer.Visit(kGlobalSymbol, &metadata); - - for (MetadataQueuer::QueueItem item : queue) { - auto struct_name = std::get<0>(item); - auto obj = std::get<1>(item); - auto arr = obj.as(); - is_first_item_ = true; - address_.push_back(struct_name); - if (arr != nullptr) { - const char* const_part = "const "; - if (arr->type_index == MetadataTypeIndex::kString) { - const_part = ""; - } - code_ << const_part; - switch (arr->type_index) { - case MetadataTypeIndex::kUint64: - code_ << "uint64_t"; - break; - case MetadataTypeIndex::kInt64: - code_ << "int64_t"; - break; - case MetadataTypeIndex::kBool: - code_ << "bool"; - break; - case MetadataTypeIndex::kString: - code_ << "const char*"; - break; - case MetadataTypeIndex::kHandle: - code_ << "void*"; - break; - case MetadataTypeIndex::kMetadata: - code_ << "struct " << arr->struct_name; - break; - default: - CHECK(false) << "Unknown type_index in array: " << arr->type_index - << " (struct_name=" << arr->struct_name << ")"; - break; - } - code_ << " " << struct_name << "[" << arr->array.size() << "] = {" << std::endl; - VisitArray(arr); - } else { - code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; - Visit(nullptr, &obj); + std::vector queue; + metadata::DiscoverArraysVisitor array_discover{&queue}; + array_discover.Visit(metadata::kMetadataGlobalSymbol, &metadata); + + for (auto item : queue) { + auto struct_address = std::get<0>(item); + address_.push_back(struct_address); + + auto arr = std::get<1>(item); + + // Prepend const with everything except C-string, which needs appending. + if (arr->kind != MetadataKind::kString) { + code_ << "const "; + } + EmitCType(arr.operator->(), code_); + if (arr->kind == MetadataKind::kString) { + code_ << " const"; } + code_ << " " << struct_address << "[" << arr->array.size() << "] = {" << std::endl; + is_first_item_ = true; + + VisitArray(arr); address_.pop_back(); code_ << "};" << std::endl; } + + // Finally, emit overall struct. 
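+    // The arrays emitted above are referenced by name from this initializer; roughly
+    // (field order illustrative, following the reflection-based Visit calls):
+    //   const struct TVMMetadata kTvmgenMetadata = {version, kTvmgenMetadata_inputs, num_inputs, ...};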
+ code_ << "const struct TVMMetadata " << metadata::kMetadataGlobalSymbol << " = {" << std::endl; + Visit(nullptr, &metadata); + code_ << "};" << std::endl; } std::string GetOutput() { return decl_.str() + code_.str(); } @@ -804,8 +840,8 @@ runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metad << "(TVMValue* arg_values, int* arg_tcodes, int " "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" << std::endl; - lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol - << ";" << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << metadata::kMetadataGlobalSymbol << ";" + << std::endl; lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; lookup_func << " return 0;" << std::endl; lookup_func << "};" << std::endl; diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index abf37ce4569a..0fa03af3b738 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -150,8 +150,8 @@ TEST(Metadata, Visitor) { // Just identify the tensor. auto input_array = Downcast(v.values[1]); - EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo")); + EXPECT_THAT(input_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + EXPECT_THAT(input_array->type_key, StrEq("metadata.TensorInfoNode")); EXPECT_THAT(input_array->array.size(), Eq(2)); auto input1 = Downcast(input_array->array[0]); @@ -168,8 +168,8 @@ TEST(Metadata, Visitor) { EXPECT_THAT(num_inputs->value, Eq(2)); auto output_array = Downcast(v.values[3]); - EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo")); + EXPECT_THAT(output_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + EXPECT_THAT(output_array->type_key, StrEq("metadata.TensorInfoNode")); auto output1 = Downcast(output_array->array[0]); EXPECT_THAT(output1->name(), Eq("output1")); @@ -178,8 +178,8 @@ TEST(Metadata, Visitor) { EXPECT_THAT(num_outputs->value, Eq(1)); auto pool_array = Downcast(v.values[5]); - EXPECT_THAT(pool_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); - EXPECT_THAT(pool_array->struct_name, StrEq("TVMTensorInfo")); + EXPECT_THAT(pool_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + EXPECT_THAT(pool_array->type_key, StrEq("metadata.TensorInfoNode")); auto pool1 = Downcast(pool_array->array[0]); EXPECT_THAT(pool1->name(), Eq("pool1")); From 783aa439b215d6a3c22817861a6c692ddcd114e3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Feb 2022 17:07:14 +0900 Subject: [PATCH 02/32] add metadata serialization support to llvm codegen --- src/target/llvm/codegen_cpu.cc | 385 ++++++++++++++++++++++++- src/target/llvm/codegen_cpu.h | 10 +- src/target/llvm/codegen_llvm.cc | 87 +++--- src/target/llvm/codegen_llvm.h | 37 ++- src/target/llvm/llvm_common.cc | 7 + src/target/llvm/llvm_common.h | 2 + src/target/llvm/llvm_module.cc | 46 ++- src/target/llvm/llvm_module.h | 3 + src/target/metadata_module.cc | 6 + tests/python/relay/aot/test_cpp_aot.py | 21 +- 10 files changed, 530 insertions(+), 74 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 53c8f7754602..4188c26f7d6f 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -30,6 +30,7 @@ #include #include #include 
+#include #include "../func_registry_generator.h" @@ -73,10 +74,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // TVMValue* out_ret_value, int* out_ret_tcode, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( - t_int_, - {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, - t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_}, - false); + t_int_, {t_void_p_, t_void_p_, t_int_, t_void_p_, t_void_p_, t_void_p_}, false); t_tvm_crt_func_registry_ = llvm::StructType::create( {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()}); @@ -802,10 +800,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end) { + const int64_t begin, const int64_t end, + bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = nullptr; + if (use_string_lookup) { + handle = GetPackedFuncHandle(func_name); + } // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -822,14 +824,45 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& TypedPointer ret_tcode = CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32)); + llvm::FunctionType* callee_ftype = nullptr; + llvm::Value* callee_value = nullptr; + std::vector call_args; + + if (use_string_lookup) { + callee_ftype = ftype_tvm_func_call_; + callee_value = RuntimeTVMFuncCall(); + call_args.push_back(handle); + } else { + callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_value = module_->getFunction(func_name); + if (callee_value == nullptr) { + callee_value = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + func_name, module_.get()); + } + } + + if (use_string_lookup) { + call_args.insert(call_args.end(), + {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + } else { + nargs -= 1; + call_args.insert(call_args.end(), { + builder_->CreateBitCast(arg_value, t_void_p_), + builder_->CreateBitCast(arg_tcode.addr, t_void_p_), + ConstInt32(nargs), + builder_->CreateBitCast(ret_value, t_void_p_), + builder_->CreateBitCast(ret_tcode.addr, t_void_p_), + }); + call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + } #if TVM_LLVM_VERSION >= 90 - auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); + auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); #else - auto call_callee = RuntimeTVMFuncCall(); + auto call_callee = callee_value; #endif - llvm::Value* call = builder_->CreateCall( - call_callee, - {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::Value* call = builder_->CreateCall(call_callee, call_args); + llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). 
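+  // For reference, the two callee signatures used above, as declared in TVM's C runtime
+  // headers (c_runtime_api.h and c_backend_api.h):
+  //   int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes,
+  //                   int num_args, TVMValue* ret_val, int* ret_type_code);
+  //   typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args,
+  //                                        TVMValue* out_ret_value, int* out_ret_tcode,
+  //                                        void* resource_handle);
+  // The cpacked path calls the symbol directly and passes a null resource_handle.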
@@ -858,17 +891,19 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { - ICHECK_EQ(op->args.size(), 5U); +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { + LOG(INFO) << "CreateCallPacked: " << GetRef(op); + auto expected_num_args = use_string_lookup ? 5U : 6U; + ICHECK_EQ(op->args.size(), expected_num_args); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, true); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. @@ -914,6 +949,321 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { return GetContextPtr(gv_tvm_parallel_barrier_); } +/*! \brief Defines LLVM Types for each Metadata member type. */ +struct MetadataLlvmTypes { + llvm::Type* t_float64; + llvm::Type* t_uint8; + llvm::Type* t_int64; + llvm::Type* t_bool; + llvm::Type* t_cstring; + llvm::Type* t_void_p; + llvm::StructType* t_data_type; + + /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */ + ::std::unordered_map structs_by_type_key; +}; + +class MetadataTypeDefiner : public AttrVisitor { + public: + MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types) + : ctx_{ctx}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.emplace_back(llvm_types_->t_float64); + } + void Visit(const char* key, int64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, uint64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); } + void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); } + void Visit(const char* key, std::string* value) final { + elements_.emplace_back(llvm_types_->t_cstring); + } + void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); } + void Visit(const char* key, DataType* value) final { + elements_.emplace_back(llvm_types_->t_data_type); + } + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + private: + void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(llvm::PointerType::getUnqual( + llvm::StructType::create(*ctx_, metadata->get_c_struct_name()))); + if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) { + return; + } + + if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) { + return; + } + to_visit_[metadata->get_c_struct_name()] = metadata; + } + + public: + using MetadataKind = runtime::metadata::MetadataKind; + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + switch (arr->kind) { + case MetadataKind::kUint64: // LLVM encodes signed and unsigned with same types. 
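+      // (deliberate fallthrough: kUint64 and kInt64 both lower to an int64 pointer below)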
+ case MetadataKind::kInt64: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64)); + break; + case MetadataKind::kBool: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool)); + break; + case MetadataKind::kString: + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring)); + break; + case MetadataKind::kHandle: + CHECK(false) << "Do not support handle"; + break; + case MetadataKind::kMetadata: + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key])); + break; + default: + CHECK(false) << "Unsupported metadata kind " << arr->kind; + break; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + } else { + elements_.emplace_back( + llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()])); + } + } + + void DefineType(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":"; + for (auto e : elements_) { + std::string value; + llvm::raw_string_ostream os(value); + e->print(os, true); + // LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " << value; + // e->dump(); + } + llvm_types_->structs_by_type_key[metadata->GetTypeKey()] = + llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name()); + elements_.clear(); + } + + llvm::LLVMContext* ctx_; + struct MetadataLlvmTypes* llvm_types_; + ::std::unordered_set<::std::string> visited_; + ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; + ::std::vector elements_; +}; + +class MetadataSerializerLLVM : public AttrVisitor { + using MetadataKind = runtime::metadata::MetadataKind; + + public: + MetadataSerializerLLVM(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) + : codegen_{codegen}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); + } + void Visit(const char* key, int64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_int64, static_cast(*value), true /* isSigned */)); + } + void Visit(const char* key, uint64_t* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); + } + void Visit(const char* key, int* value) final { + elements_.back().emplace_back( + llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); + } + void Visit(const char* key, bool* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get( + llvm_types_->t_uint8, static_cast(*value), false /* isSigned */)); + } + void Visit(const char* key, std::string* value) final { + elements_.back().emplace_back(codegen_->GetConstString(*value)); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not support serializing void*"; + } + void Visit(const char* key, DataType* value) final { + elements_.back().emplace_back(llvm::ConstantStruct::get( + llvm_types_->t_data_type, + {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); + } + + void Visit(const 
char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + void VisitMetadata(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(std::vector()); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + auto struct_elements = elements_.back(); + elements_.pop_back(); + auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()]; + ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key=" + << metadata->GetTypeKey(); + std::string ty_value; + llvm::raw_string_ostream ty_os(ty_value); + struct_ty->print(ty_os, true); + LOG(INFO) << "Get LLVM ConstantStruct (" << struct_elements.size() << " elements)"; + LOG(INFO) << " Type (" << metadata->GetTypeKey() << "==" << struct_ty->getName().data() + << "): " << ty_value; + for (auto e : struct_elements) { + std::string value; + llvm::raw_string_ostream os(value); + e->print(os); + LOG(INFO) << " - " << value; + } + CHECK_EQ(struct_elements.size(), struct_ty->getNumElements()); + auto out = llvm::ConstantStruct::get(struct_ty, struct_elements); + if (elements_.size() > 0) { + elements_.back().push_back(out); + } else { + last_production_ = out; + } + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + llvm::Type* element_type; + switch (arr->kind) { + case MetadataKind::kInt64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kUint64: + element_type = llvm_types_->t_int64; + break; + case MetadataKind::kBool: + element_type = llvm_types_->t_uint8; + break; + case MetadataKind::kString: + element_type = llvm_types_->t_cstring; + break; + case MetadataKind::kMetadata: { + element_type = llvm_types_->structs_by_type_key[arr->type_key]; + ICHECK(element_type != nullptr) + << "Did not find LLVM StructType* for type_key=" << arr->type_key; + break; + } + default: + LOG(FATAL) << "unknown metadata kind " << arr->kind; + break; + } + + elements_.emplace_back(std::vector()); + for (auto o : arr->array) { + if (o->IsInstance()) { + double value = Downcast(o)->value; + Visit(nullptr, &value); + } + if (o->IsInstance()) { + auto value = Downcast(o)->value; + Visit(nullptr, &value); + } else if (o->IsInstance()) { + ::std::string value = Downcast(o); + Visit(nullptr, &value); + } else { + // nested array not possible. 
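+        // Any remaining element must itself be a MetadataBase, serialized as a nested struct.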
+ VisitMetadata(Downcast(o)); + } + } + auto array = elements_.back(); + elements_.pop_back(); + CHECK(element_type != nullptr); + auto arr_ty = llvm::ArrayType::get(element_type, array.size()); + auto llvm_arr = llvm::ConstantArray::get(arr_ty, array); + + if (elements_.size() > 0) { + elements_.back().emplace_back( + codegen_->GetGlobalConstant(llvm_arr, "", llvm::GlobalValue::PrivateLinkage)); + } else { + last_production_ = llvm_arr; + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + VisitMetadata(metadata); + } + + llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { + Visit(nullptr, &metadata); + ICHECK(last_production_); + return codegen_->GetGlobalConstant(last_production_); + } + + CodeGenLLVM* codegen_; + MetadataLlvmTypes* llvm_types_; + llvm::LLVMContext* ctx_; + llvm::Module* module_; + std::vector> elements_; + llvm::Constant* last_production_; +}; + +void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + MetadataLlvmTypes llvm_types{ + t_float64_ /* t_float64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + t_int64_ /* t_int64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + t_char_->getPointerTo() /* t_cstring */, + t_void_p_ /* t_void_p */, + llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + }; + + std::vector queue; + metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; + discover_complex.Discover(metadata); + + MetadataTypeDefiner definer{ctx_, &llvm_types}; + for (auto md : queue) { + if (md.defined()) { + definer.DefineType(md); + } + } + + MetadataSerializerLLVM serializer{this, &llvm_types}; + auto metadata_constant_gv = serializer.Serialize(metadata); + + function_ = + llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, + "get_c_metadata", module_.get()); + function_->setCallingConv(llvm::CallingConv::C); + function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + + auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); + builder_->CreateStore(builder_->CreateBitCast(metadata_constant_gv, t_void_p_), ret_values_p); + + auto ret_tcode = builder_->CreateBitCast(GetArg(function_, 4), t_int_->getPointerTo()); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); + + builder_->CreateRet(ConstInt32(0)); +} + void CodeGenCPU::DefineFunctionRegistry(Array func_names) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; Array symbols; @@ -980,9 +1330,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op); + return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + return CreateCallPacked(op, false /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if 
(op->op.same_as(builtin::tvm_throw_last_error())) { @@ -1052,6 +1404,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); + #if TVM_LLVM_VERSION >= 90 auto err_callee = llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 26f251f1a9c8..a491d539a6ea 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -56,6 +56,12 @@ class CodeGenCPU : public CodeGenLLVM { */ void DefineFunctionRegistry(Array func_names); + /*! + * \brief Serialize the metadata object as data, and implement get_c_metadata function. + * \param metadata The metadata which should be serialized. + */ + void DefineMetadata(runtime::metadata::Metadata metadata); + protected: void AddStartupFunction() final; // meta data @@ -117,9 +123,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end); + const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op); + llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); // Create trace call into tvm packed function. llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bacfbc9947a5..47cb3dd711d7 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -37,6 +37,7 @@ #include "codegen_cpu.h" #include "codegen_params.h" #include "llvm/Support/raw_os_ostream.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -134,11 +135,11 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - ICHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) - << "Function " << global_symbol << " already exist in module"; - - function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = module_->getFunction(static_cast(global_symbol.value())); + if (function_ == nullptr) { + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); + } function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -191,6 +192,19 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } +llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name, + llvm::ConstantArray* array) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; + llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name, true /* AllowInternal */); + if (var == nullptr) { + CHECK(array != nullptr) << "Expect param symbol " << symbol_name + << " to either be defined or for the array to be supplied"; + var = new llvm::GlobalVariable(*module_, static_cast(array->getType()), true, + llvm::GlobalValue::InternalLinkage, array, 
symbol_name); + } + return var; +} + void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... @@ -209,22 +223,13 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - auto getArg = [function](int i) -> llvm::Argument* { -#if TVM_LLVM_VERSION >= 100 - return function->getArg(i); -#elif TVM_LLVM_VERSION >= 50 - return &function->arg_begin()[i]; -#else - return &*std::next(function->arg_begin(), i); -#endif - }; - llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); - llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + llvm::Value* sid = + builder_->CreateLoad(t_int64_, builder_->CreateBitCast(GetArg(function, 0), t_int64_p)); - auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); - auto ret_value = - builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); + auto ret_tcode = builder_->CreateBitCast(GetArg(function, 4), t_int_p); + auto ret_value = builder_->CreateBitCast(GetArg(function, 3), + t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); @@ -236,9 +241,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { // Add data to the global section. for (auto kv : params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -246,8 +249,10 @@ void CodeGenLLVM::LinkParameters(const Map params) { #else param_symbol->setAlignment(align); #endif + param_symbol->setInitializer(array); - llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); + llvm::BasicBlock* case_block = + llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); @@ -388,6 +393,7 @@ void CodeGenLLVM::Optimize() { fpass.run(*it); } fpass.doFinalization(); + // PrintModule(module_.get()); mpass.run(*module_); } @@ -770,21 +776,27 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } -llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { - auto it = str_map_.find(str); - if (it != str_map_.end()) return it->second; - llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, nullptr, ".str"); +llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name, + llvm::GlobalValue::LinkageTypes linkage_type) { + llvm::Type* ty = const_data->getType(); + llvm::GlobalVariable* 
global = + new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else global->setAlignment(1); #endif - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(ty, global, indices); + return ptr; +} + +llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { + auto it = str_map_.find(str); + if (it != str_map_.end()) return it->second; + auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; } @@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { + // LOG(INFO) << "Visit Call:" << GetRef(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { + if (op->op.same_as(builtin_lookup_param_)) { + // return llvm::ConstantInt::get(t_void_p_, 0); + return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); @@ -1412,7 +1428,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], op->args, false); } else { - return CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic: " << GetRef(op); + auto x = CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic done"; + return x; } } else { ICHECK(op->op.as()); @@ -1557,7 +1576,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - size_t constant_size = op->ConstantAllocationSize(); + int32_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 7a7ca6578f28..172eb5ef1019 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ +#include #ifdef TVM_LLVM_VERSION #include @@ -141,7 +142,10 @@ class CodeGenLLVM : public ExprFunctor, * \param e The expression to be created value for. * \return created value. 
*/ - llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } + llvm::Value* MakeValue(const PrimExpr& e) { + auto a = VisitExpr(e); /* LOG(INFO) << "MakeValue (" << e << "): " << a; */ + return a; + } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); @@ -190,6 +194,13 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + // Get constant string + llvm::Constant* GetConstString(const std::string& str); + + llvm::Constant* GetGlobalConstant( + llvm::Constant* const_data, const std::string& name = "", + llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); + protected: /*! * \brief Address and type pair to assist in handling opaque pointers. @@ -341,6 +352,14 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a + * given linked_param() CallNode. + * \param param_name Parameter name (e.g. unmangled, from lookup_param node). + * \return the GlobalVariable indicated in the brief. + */ + llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name, + llvm::ConstantArray* array); /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. @@ -353,8 +372,6 @@ class CodeGenLLVM : public ExprFunctor, int* p_native_bits); // Returns whether the LLVM type has padding for alignment bool HasAlignmentPadding(DataType dtype); - // Get constant string - llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); @@ -389,6 +406,16 @@ class CodeGenLLVM : public ExprFunctor, unsigned int shared_address_space, int alignment, llvm::GlobalValue::LinkageTypes linkage); + llvm::Argument* GetArg(const llvm::Function* function, int i) const { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return const_cast(&function->arg_begin()[i]); +#else + return const_cast(&*std::next(function->arg_begin(), i)); +#endif + } + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function @@ -447,6 +474,8 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_lookup_param_ = builtin::lookup_param(); + const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos. 
*/ struct DebugInfo { @@ -481,6 +510,8 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu return name_a < name_b; }); for (auto& f : funcs) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + LOG(INFO) << "Adding " << static_cast(global_symbol.value()); AddFunction(f); } } diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 06b2be2d9fb6..f13e8563e053 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -189,6 +189,13 @@ std::string LLVMTargetToString(const Target& target) { return os.str(); } +void PrintModule(const llvm::Module* mod) { + std::string modpe_str; + llvm::raw_string_ostream rso(modpe_str); + mod->print(rso, nullptr); + LOG(INFO) << rso.str(); +} + } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 556f05d2e33a..e2e3384c1a19 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -126,6 +126,8 @@ std::unique_ptr GetLLVMTargetMachine(const Target& target, */ std::string LLVMTargetToString(const Target& target); +void PrintModule(const llvm::Module* mod); + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index cf8b59357b47..066e091e637c 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -308,14 +308,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->SetFastMathFlag(fmf); + if (found_linked_params) { + cg->LinkParameters(linked_params); + } cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } - if (found_linked_params) { - cg->LinkParameters(linked_params); - } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); @@ -527,6 +527,46 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + auto ctx = std::make_shared(); + std::unique_ptr cg{new CodeGenCPU()}; + + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, + false /* target_c_runtime */); + + cg->DefineMetadata(metadata); + auto mod = cg->Finish(); + mod->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx, LLVMTargetToString(target))); + mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); + + if (tm->getTargetTriple().isOSDarwin()) { + mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } + + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + + // std::string tmp; + // llvm::raw_string_ostream stream(tmp); + // mod->print(stream, nullptr); + // LOG(INFO) << "LLVM metadata IR: " << stream.str(); + + auto n = make_object(); + n->Init(std::move(mod), ctx); + + auto meta_mod = MetadataModuleCreate(metadata); + meta_mod->Import(runtime::Module(n)); + return meta_mod; +} + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target 
target, tvm::relay::Runtime runtime) { Array func_names; diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 933030e213d2..660d81400b0d 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -33,6 +33,9 @@ namespace tvm { namespace codegen { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime); + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 8abd18c1d8f3..70b1896fec0e 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule( auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata); metadata_module->Import(target_module); target_module = metadata_module; +#ifdef TVM_LLVM_VERSION + } else if (target->kind->name == "llvm") { + auto metadata_module = CreateLLVMCppMetadataModule(runtime_metadata, target, runtime); + metadata_module->Import(target_module); + target_module = metadata_module; +#endif // TVM_LLVM_VERSION } else { CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); } diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 48057404dd4c..5820efe6237a 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -24,20 +24,8 @@ import pytest import tvm -from tvm import relay, TVMError -from tvm.ir.module import IRModule -from tvm.relay import backend, testing, transform -from tvm.relay.testing import byoc -from tvm.relay.op.annotation import compiler_begin, compiler_end -from aot_test_utils import ( - AOTTestModel, - AOT_DEFAULT_RUNNER, - generate_ref_data, - convert_to_relay, - compile_and_run, - compile_models, - parametrize_aot_options, -) +from tvm.relay import backend, testing +from aot_test_utils import generate_ref_data def test_error_c_interface(): @@ -67,9 +55,10 @@ def test_error_c_interface(): enable_usmp = tvm.testing.parameter(True, False) +target_kind = tvm.testing.parameter("c", "llvm") -def test_conv2d(enable_usmp): +def test_conv2d(enable_usmp, target_kind): RELAY_MODEL = textwrap.dedent( """\ #[version = "0.0.5"] @@ -117,7 +106,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) From f991e199abbed145c8165a3504aeaf70bd81e070 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 4 Mar 2022 10:21:29 -0800 Subject: [PATCH 03/32] Organize MetadataQueuer into a separate file. 
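
MetadataQueuer and the address_from_parts() helper move out of
source_module.cc into metadata_utils.{h,cc} so the LLVM metadata
codegen can reuse them. address_from_parts() joins the components of a
metadata "address" with underscores to name the emitted globals; a
rough usage sketch (names as declared in the new header):

    std::vector<MetadataQueuer::QueueItem> queue;
    MetadataQueuer queuer{&queue};
    queuer.Visit(kMetadataGlobalSymbol, &metadata);
    // Post-order traversal: each entry is an (address, object) pair such as
    // ("kTvmgenMetadata_inputs", <inputs array>), with nested objects queued
    // before the structs that reference them.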
--- src/target/llvm/codegen_cpu.cc | 2 +- src/target/metadata_utils.cc | 84 ++++++++++++++++++++++++++++++ src/target/metadata_utils.h | 68 ++++++++++++++++++++++++ src/target/source/source_module.cc | 61 +--------------------- 4 files changed, 154 insertions(+), 61 deletions(-) create mode 100644 src/target/metadata_utils.cc create mode 100644 src/target/metadata_utils.h diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 4188c26f7d6f..fe9c2f27594c 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -33,6 +33,7 @@ #include #include "../func_registry_generator.h" +#include "../metadata_utils.h" namespace tvm { namespace codegen { @@ -892,7 +893,6 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& } llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { - LOG(INFO) << "CreateCallPacked: " << GetRef(op); auto expected_num_args = use_string_lookup ? 5U : 6U; ICHECK_EQ(op->args.size(), expected_num_args); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, diff --git a/src/target/metadata_utils.cc b/src/target/metadata_utils.cc new file mode 100644 index 000000000000..92724e1636c8 --- /dev/null +++ b/src/target/metadata_utils.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.cc + * \brief Defines utility functions and classes for emitting metadata. + */ +#include "metadata_utils.h" + +namespace tvm { +namespace codegen { + +MetadataQueuer::MetadataQueuer(std::vector* queue) : queue_{queue} {} + +std::string address_from_parts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +void MetadataQueuer::Visit(const char* key, double* value) {} +void MetadataQueuer::Visit(const char* key, int64_t* value) {} +void MetadataQueuer::Visit(const char* key, uint64_t* value) {} +void MetadataQueuer::Visit(const char* key, int* value) {} +void MetadataQueuer::Visit(const char* key, bool* value) {} +void MetadataQueuer::Visit(const char* key, std::string* value) {} +void MetadataQueuer::Visit(const char* key, DataType* value) {} +void MetadataQueuer::Visit(const char* key, runtime::NDArray* value) {} +void MetadataQueuer::Visit(const char* key, void** value) {} + +void MetadataQueuer::Visit(const char* key, ObjectRef* value) { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + std::cout << "Is array? 
" << arr << std::endl; + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" + << o.operator->() << ")" << std::endl; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + Downcast(*value))); + } + address_parts_.pop_back(); +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h new file mode 100644 index 000000000000..c305a0671c07 --- /dev/null +++ b/src/target/metadata_utils.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata_utils.h + * \brief Declares utilty functions and classes for emitting metadata. 
+ */ +#ifndef TVM_TARGET_METADATA_UTILS_H_ +#define TVM_TARGET_METADATA_UTILS_H_ + +#include +#include +#include + +#include +#include +#include + +#include "metadata.h" + +namespace tvm { +namespace codegen { + +std::string address_from_parts(const std::vector& parts); +static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata"; + +class MetadataQueuer : public AttrVisitor { + public: + using QueueItem = std::tuple; + explicit MetadataQueuer(std::vector* queue); + + void Visit(const char* key, double* value) final; + void Visit(const char* key, int64_t* value) final; + void Visit(const char* key, uint64_t* value) final; + void Visit(const char* key, int* value) final; + void Visit(const char* key, bool* value) final; + void Visit(const char* key, std::string* value) final; + void Visit(const char* key, DataType* value) final; + void Visit(const char* key, runtime::NDArray* value) final; + void Visit(const char* key, void** value) final; + + void Visit(const char* key, ObjectRef* value) final; + + private: + std::vector* queue_; + std::vector address_parts_; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_METADATA_UTILS_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index b8a6f789ddaf..018df5d70af9 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -30,7 +30,6 @@ #include #include -#include #include #include #include @@ -41,6 +40,7 @@ #include "../../support/str_escape.h" #include "../func_registry_generator.h" #include "../metadata.h" +#include "../metadata_utils.h" #include "codegen_source_base.h" namespace tvm { @@ -524,65 +524,6 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } }; -static std::string address_from_parts(const std::vector& parts) { - std::stringstream ss; - for (unsigned int i = 0; i < parts.size(); ++i) { - if (i > 0) { - ss << "_"; - } - ss << parts[i]; - } - return ss.str(); -} - -class MetadataQueuer : public AttrVisitor { - public: - using QueueItem = std::tuple; - explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} - - void Visit(const char* key, double* value) final {} - void Visit(const char* key, int64_t* value) final {} - void Visit(const char* key, uint64_t* value) final {} - void Visit(const char* key, int* value) final {} - void Visit(const char* key, bool* value) final {} - void Visit(const char* key, std::string* value) final {} - void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final {} - void Visit(const char* key, void** value) final {} - - void Visit(const char* key, ObjectRef* value) final { - address_parts_.push_back(key); - if (value->as() != nullptr) { - auto metadata = Downcast(*value); - const runtime::metadata::MetadataArrayNode* arr = - value->as(); - if (arr != nullptr) { - for (unsigned int i = 0; i < arr->array.size(); i++) { - ObjectRef o = arr->array[i]; - if (o.as() != nullptr) { - std::stringstream ss; - ss << i; - address_parts_.push_back(ss.str()); - runtime::metadata::MetadataBase metadata = Downcast(o); - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - address_parts_.pop_back(); - } - } - } else { - ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); - } - - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), - Downcast(*value))); - } - address_parts_.pop_back(); - } - - private: - std::vector* queue_; - std::vector address_parts_; -}; - class MetadataSerializer : 
public AttrVisitor { public: static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; From f6b1314607a68040fdc483392189c7afbe1a9cd1 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 08:52:01 -0700 Subject: [PATCH 04/32] Add DiscoverArraysVisitor to metadata_utils --- src/target/metadata_utils.cc | 107 +++++++++++++++++++++++++++++------ src/target/metadata_utils.h | 68 ++++++++++++++++++++-- 2 files changed, 153 insertions(+), 22 deletions(-) diff --git a/src/target/metadata_utils.cc b/src/target/metadata_utils.cc index 92724e1636c8..1034fd570007 100644 --- a/src/target/metadata_utils.cc +++ b/src/target/metadata_utils.cc @@ -25,8 +25,9 @@ namespace tvm { namespace codegen { +namespace metadata { -MetadataQueuer::MetadataQueuer(std::vector* queue) : queue_{queue} {} +DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector* queue) : queue_{queue} {} std::string address_from_parts(const std::vector& parts) { std::stringstream ss; @@ -39,28 +40,25 @@ std::string address_from_parts(const std::vector& parts) { return ss.str(); } -void MetadataQueuer::Visit(const char* key, double* value) {} -void MetadataQueuer::Visit(const char* key, int64_t* value) {} -void MetadataQueuer::Visit(const char* key, uint64_t* value) {} -void MetadataQueuer::Visit(const char* key, int* value) {} -void MetadataQueuer::Visit(const char* key, bool* value) {} -void MetadataQueuer::Visit(const char* key, std::string* value) {} -void MetadataQueuer::Visit(const char* key, DataType* value) {} -void MetadataQueuer::Visit(const char* key, runtime::NDArray* value) {} -void MetadataQueuer::Visit(const char* key, void** value) {} - -void MetadataQueuer::Visit(const char* key, ObjectRef* value) { +void DiscoverArraysVisitor::Visit(const char* key, double* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverArraysVisitor::Visit(const char* key, int* value) {} +void DiscoverArraysVisitor::Visit(const char* key, bool* value) {} +void DiscoverArraysVisitor::Visit(const char* key, std::string* value) {} +void DiscoverArraysVisitor::Visit(const char* key, DataType* value) {} +void DiscoverArraysVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverArraysVisitor::Visit(const char* key, void** value) {} + +void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) { address_parts_.push_back(key); if (value->as() != nullptr) { auto metadata = Downcast(*value); const runtime::metadata::MetadataArrayNode* arr = value->as(); - std::cout << "Is array? 
" << arr << std::endl; if (arr != nullptr) { for (unsigned int i = 0; i < arr->array.size(); i++) { ObjectRef o = arr->array[i]; - std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" - << o.operator->() << ")" << std::endl; if (o.as() != nullptr) { std::stringstream ss; ss << i; @@ -70,15 +68,88 @@ void MetadataQueuer::Visit(const char* key, ObjectRef* value) { address_parts_.pop_back(); } } + + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + Downcast(metadata))); } else { ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); } - - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), - Downcast(*value))); } address_parts_.pop_back(); } +void DiscoverComplexTypesVisitor::Visit(const char* key, double* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int64_t* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, uint64_t* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, int* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, bool* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, std::string* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, DataType* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, runtime::NDArray* value) {} +void DiscoverComplexTypesVisitor::Visit(const char* key, void** value) {} + +bool DiscoverComplexTypesVisitor::DiscoverType(std::string type_key) { + VLOG(2) << "DiscoverType " << type_key; + auto position_it = type_key_to_position_.find(type_key); + if (position_it != type_key_to_position_.end()) { + return false; + } + + queue_->emplace_back(tvm::runtime::metadata::MetadataBase()); + type_key_to_position_[type_key] = queue_->size() - 1; + return true; +} + +void DiscoverComplexTypesVisitor::DiscoverInstance(runtime::metadata::MetadataBase md) { + auto position_it = type_key_to_position_.find(md->GetTypeKey()); + ICHECK(position_it != type_key_to_position_.end()) + << "DiscoverInstance requires that DiscoverType has already been called: type_key=" + << md->GetTypeKey(); + + int queue_position = (*position_it).second; + if (!(*queue_)[queue_position].defined() && md.defined()) { + VLOG(2) << "DiscoverInstance " << md->GetTypeKey() << ":" << md; + (*queue_)[queue_position] = md; + } +} + +void DiscoverComplexTypesVisitor::Visit(const char* key, ObjectRef* value) { + ICHECK_NOTNULL(value->as()); + + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + + if (arr == nullptr) { + VLOG(2) << "No array, object-traversing " << metadata->GetTypeKey(); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); + return; + } + + if (arr->kind != tvm::runtime::metadata::MetadataKind::kMetadata) { + return; + } + + bool needs_instance = DiscoverType(arr->type_key); + for (unsigned int i = 0; i < arr->array.size(); i++) { + tvm::runtime::metadata::MetadataBase o = + Downcast(arr->array[i]); + if (needs_instance) { + DiscoverInstance(o); + needs_instance = false; + } + ReflectionVTable::Global()->VisitAttrs(o.operator->(), this); + } +} + +void DiscoverComplexTypesVisitor::Discover(runtime::metadata::MetadataBase metadata) { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + DiscoverType(metadata->GetTypeKey()); + DiscoverInstance(metadata); +} + +} // namespace metadata } // namespace codegen } // namespace tvm diff --git 
a/src/target/metadata_utils.h b/src/target/metadata_utils.h
index c305a0671c07..4902f330a239 100644
--- a/src/target/metadata_utils.h
+++ b/src/target/metadata_utils.h
@@ -31,19 +31,30 @@
 #include
 #include
 #include
+#include <unordered_map>

 #include "metadata.h"

 namespace tvm {
 namespace codegen {
+namespace metadata {
+
 std::string address_from_parts(const std::vector<std::string>& parts);
 static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata";

-class MetadataQueuer : public AttrVisitor {
+/*!
+ * \brief Post-order traverse metadata to discover arrays which need to be forward-defined.
+ */
+class DiscoverArraysVisitor : public AttrVisitor {
  public:
-  using QueueItem = std::tuple<std::string, runtime::metadata::MetadataBase>;
-  explicit MetadataQueuer(std::vector<QueueItem>* queue);
+  /*! \brief Models a single array discovered in this visitor.
+   * Contains two fields:
+   * 0. An address which uniquely identifies the array in this Metadata instance.
+   * 1. The discovered MetadataArray.
+   */
+  using DiscoveredArray = std::tuple<std::string, runtime::metadata::MetadataArray>;
+  explicit DiscoverArraysVisitor(std::vector<DiscoveredArray>* queue);

   void Visit(const char* key, double* value) final;
   void Visit(const char* key, int64_t* value) final;
@@ -58,10 +69,59 @@ class MetadataQueuer : public AttrVisitor {
   void Visit(const char* key, ObjectRef* value) final;

  private:
-  std::vector<QueueItem>* queue_;
+  /*! \brief The queue to be filled with discovered arrays. */
+  std::vector<DiscoveredArray>* queue_;
+
+  /*! \brief Tracks the preceding address pieces. */
   std::vector<std::string> address_parts_;
 };

+/*!
+ * \brief Post-order traverse Metadata to discover all complex types which need to be
+ * forward-defined. This visitor finds one defined() MetadataBase instance for each unique subclass
+ * present inside Metadata in the order in which the subclass was first discovered.
+ */
+class DiscoverComplexTypesVisitor : public AttrVisitor {
+ public:
+  /*! \brief Models a single complex type discovered in this visitor.
+   * Contains two fields:
+   * 0. The struct_name for this Metadata instance.
+   * 1. The discovered MetadataBase instance.
+   */
+  using DiscoveredComplexType = std::tuple<std::string, runtime::metadata::MetadataBase>;
+
+  /*! \brief Construct a new instance.
+   * \param queue An ordered list to be filled with the complex types discovered, in post-order.
+   */
+  explicit DiscoverComplexTypesVisitor(std::vector<runtime::metadata::MetadataBase>* queue)
+      : queue_{queue} {}
+
+  void Visit(const char* key, double* value) final;
+  void Visit(const char* key, int64_t* value) final;
+  void Visit(const char* key, uint64_t* value) final;
+  void Visit(const char* key, int* value) final;
+  void Visit(const char* key, bool* value) final;
+  void Visit(const char* key, std::string* value) final;
+  void Visit(const char* key, DataType* value) final;
+  void Visit(const char* key, runtime::NDArray* value) final;
+  void Visit(const char* key, void** value) final;
+
+  void Visit(const char* key, ObjectRef* value) final;
+
+  void Discover(runtime::metadata::MetadataBase metadata);
+
+ private:
+  bool DiscoverType(std::string type_key);
+
+  void DiscoverInstance(runtime::metadata::MetadataBase md);
+
+  std::vector<runtime::metadata::MetadataBase>* queue_;
+
+  /*! \brief Maps type_key to index in queue_. */
+  std::unordered_map<std::string, int> type_key_to_position_;
+};
+
+}  // namespace metadata
 }  // namespace codegen
 }  // namespace tvm

From e8e93fc12e168923859f533eaa8b465f6c36b910 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Tue, 22 Mar 2022 08:53:30 -0700
Subject: [PATCH 05/32] Fill DLTensor metadata in LegalizePackedCalls.
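
Previously the legalizer populated only the data, device_type and device_id fields of each
stack-allocated DLTensor; shape, ndim, dtype and the offsets were left unset. As a rough sketch
(TIR-style pseudocode; the function name and shapes are hypothetical), a call that reaches the
legalizer as

    tvm_call_cpacked("fused_add", input_data, output_data, device_context)

now leaves it as

    tvm_call_cpacked("fused_add",
                     tvm_stack_make_array(input_data, tvm_stack_make_shape(1, 224), 0, 2, float32(0), 0),
                     tvm_stack_make_array(output_data, tvm_stack_make_shape(1, 224), 0, 2, float32(0), 0),
                     device_context)

where the tvm_stack_make_array arguments are (data, shape, strides (null), ndim, a zero cast
carrying the dtype, elem_offset). LowerTVMBuiltin later collects these stack allocations and
minimizes the storage needed by reusing DLTensors.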
--- src/relay/backend/aot_executor_codegen.cc | 440 +++++++++++++++----- src/target/llvm/codegen_cpu.cc | 8 +- src/target/metadata_utils.h | 2 +- src/tir/transforms/legalize_packed_calls.cc | 62 ++- tests/python/relay/aot/test_crt_aot.py | 2 +- 5 files changed, 378 insertions(+), 136 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 542bcd163995..515bf4c28669 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { std::vector return_ttypes_; }; -/*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public MixedModeVisitor { - protected: +namespace { + +/*! + * \brief Utility function to convert a concrete integer to a PrimExpr. + * \param num the number to convert + * \return PrimExpr representing num + */ +inline PrimExpr ConstInt32(int32_t num) { + ICHECK_LE(num, std::numeric_limits::max()); + return tir::make_const(DataType::Int(32), static_cast(num)); +} + +/*! + * \brief Emit a call to the C Device API. + * \param device_name Name of the device, used to prefix the function name. + * \param hook Name of the Device API function. + * \param context void* context arg passed to this API function. + */ +tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook, + PrimExpr context) { + Array sections = {"Device", device_name, hook}; + String device_hook = ToCFunctionStyle(PrefixName(sections)); + + return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook), context})); +} +} // namespace + +class AOTCallGenerator { + public: + explicit AOTCallGenerator(std::string func_name) + : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {} + + tir::Var PushArg(PrimExpr arg) { + if (!arg->IsInstance()) { + arg = MakeLetBind(arg); + } + args_.push_back(arg); + return Downcast(arg); + } + + void PushStackDLTensor(const TensorType& ttype, PrimExpr data) { + auto dltensor_var = MakeLetBind(StackAlloca("array", 1)); + auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size())); + + // Populate DLTensor.data + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrData, data}))); + + // Populate DLTensor.device + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrDeviceType, kDLCPU}))); + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrDeviceId, 0}))); + + // Populate DLTensor.ndim + prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrNDim, static_cast(ttype->shape.size())}))); + + // Populate DLTensor.dtype + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrTypeCode, + IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.code())}))); + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrTypeBits, + IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.bits())}))); + prep_stmts_.push_back( + 
tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrTypeLanes, + IntImm(DataType(kDLUInt, 16, 1), ttype->dtype.lanes())}))); + + // Populate DLTensor.shape + for (size_t i = 0; i < ttype->shape.size(); ++i) { + prep_stmts_.push_back(tvm::tir::Store( + shape_var, IntImm(DataType(kDLInt, 64, 1), Downcast(ttype->shape[i])->value), + IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true())); + } + + prep_stmts_.push_back( + tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrShape, shape_var}))); + + // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset + prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 64, 1), 0)}))); + + // Populate DLTensor.byte_offset. DNS -- TODO actually pull correct byte_offset + prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), + {dltensor_var, 0, tir::builtin::kArrByteOffset, IntImm(DataType(kDLUInt, 64, 1), 0)}))); + + args_.push_back(dltensor_var); + } + + void PushStackDLTensors(const Expr& expr, std::vector sids) { + const TupleNode* t = expr.as(); + if (t != nullptr) { + CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " + "handle this type of Relay Expr in a CallNode."; + for (size_t i = 0; i < sids.size(); i++) { + PushStackDLTensor(Downcast(t->fields[i]->checked_type()), sids[i]); + } + } else { + PushStackDLTensor(Downcast(expr->checked_type()), sids[0]); + } + } + + tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) { + auto make_call = [this] { + return tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args_)); + }; + if (device_context.defined()) { + tir::Var context_var = PushArg(device_context); + return Generate(tir::SeqStmt({ + MakeDeviceHookCall(device_name, "Open", context_var), + make_call(), + MakeDeviceHookCall(device_name, "Close", context_var), + })); + } else { + return Generate(make_call()); + } + } + + tir::Stmt GeneratePacked() { + return Generate( + tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), args_))); + } + + tir::Stmt GenerateCPacked() { + // call_cpacked calling convention does not use a context + PushArg(tir::make_zero(DataType::Handle())); + return Generate( + tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), args_))); + } + + private: + tir::Stmt Generate(tir::Stmt call_stmts) { + tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts); + + for (auto bind : let_binds_) { + body = tir::LetStmt(bind.first, bind.second, body); + } + + return body; + } + + tir::Var MakeLetBind(PrimExpr expr) { + std::stringstream ss; + ss << func_name_ << "_let" << let_binds_.size(); + tir::Var v{ss.str(), DataType::Handle()}; + let_binds_.emplace_back(std::make_pair(v, expr)); + return v; + } + /*! * \brief Utility function to allocate a DLTensor or TVMValue * \param type the type of allocation @@ -274,15 +433,81 @@ class AOTExecutorCodegen : public MixedModeVisitor { return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); } - /*! - * \brief Utility function to convert a concrete integer to a PrimExpr. 
- * \param num the number to convert
- * \return PrimExpr representing num
- */
-  inline PrimExpr ConstInt32(int32_t num) {
-    ICHECK_LE(num, std::numeric_limits<int32_t>::max());
-    return tir::make_const(DataType::Int(32), static_cast<int32_t>(num));
-  }
+  std::string func_name_;
+  tvm::Array<PrimExpr> args_;
+  std::vector<std::pair<tir::Var, PrimExpr>> let_binds_;
+  Array<tir::Stmt> prep_stmts_;
+};
+
+/*! \brief Code generator for AOT executor */
+class AOTExecutorCodegen : public MixedModeVisitor {
+ protected:
+  /*! \brief Describes the type of kernel call emitted. */
+  enum CallType {
+    /*!
+     * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions.
+     *
+     * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the
+     * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those
+     * functions are of type TVMBackendPackedCFunc.
+     *
+     * The following code is emitted at call sites to call a function named `func`:
+     *   void* func_ptr = TVMBackendGetFuncFromEnv("func");
+     *   TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes)
+     *
+     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and
+     * `num_args` by the LowerTVMBuiltin TIR transform.
+     *
+     * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often,
+     * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when
+     * `func` is implemented in C).
+     *
+     * Compatible with both C++ and C runtimes, implemented with the C runtime only.
+     */
+    kPacked,  // Emit tir.call_packed and wrap all arguments in DLTensor.
+
+    /*!
+     * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call.
+     *
+     * When this type is selected, assumes all operators are implemented in functions of type
+     * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of
+     * downstream compilation that there is a symbol named after the 0th arg to tir::Call of
+     * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target.
+     *
+     * The following code is emitted at call sites to call a function named `func`:
+     *   func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle)
+     *
+     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and
+     * `num_args` by the LowerTVMBuiltin TIR transform.
+     *
+     * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
+     * always the device context parameter when not null. At present, the implementation does not
+     * support forwarding device context parameters to CPacked.
+     *
+     * Compatible with the C runtime and C++ runtime (so long as target_host == target).
+     * Implemented in the same scenarios.
+     */
+    kCPacked,  // Emit tir.call_cpacked and wrap all arguments in DLTensor.
+
+    /*! \brief Directly call a function accepting the `data` arrays as args.
+     *
+     * When this type is selected, assumes all operators are implemented in C functions whose
+     * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the
+     * `data` parameters (i.e. no DLTensor object is passed along).
+     *
+     * The following code is emitted at call sites to a function named `func`:
+     *   func(void* arg0, void* arg1, ..., void* argN)  // no resource_handle
+     * -or-
+     *   func(void* arg0, void* arg1, ..., void* argN, void* resource_handle)  // with resource_handle
+     *
+     * `resource_handle` is encoded as the final argument to the tir::Call node.
In practice, it is + * always the device context parameter when not null. + * + * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented + * with the C runtime only. + */ + kUnpacked, // Emit tir.call_extern passing only the `data` part of DLTensors. + }; /*! * \brief Return a vector of variables that represents the sids for the given Relay Expr @@ -323,6 +548,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! + * \brief Reverse lookup the device name in devices_ map. + * \param device_context Value in devices_ to find. + * \return Key matching device_context in devices_. + */ + std::string FindDeviceName(tir::Var device_context) { + for (std::pair kv : devices_) { + if (kv.second->name_hint == device_context->name_hint) { + return kv.first; + } + } + ICHECK(false) << "Did not find a device name associated with " << device_context; + return ""; + } + void PushArgs(const Expr& expr, const std::vector& sids, Array* args) { const TupleNode* t = expr.as(); if (t != nullptr) { @@ -338,12 +578,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { * returns the passed Call */ tir::Call AddCheckReturn(tir::Call existing_call) { - if (use_unpacked_api_) { - Array args = {ConstInt32(0), ConstInt32(-1), existing_call}; - return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); - } - - return existing_call; + Array args = {ConstInt32(0), ConstInt32(-1), existing_call}; + return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); } /*! @@ -378,56 +614,60 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto result_expr_sid = PackSid(result_expr); PushArgs(result_expr, result_expr_sid, &args); - // Choose call style based on Runtime/Executor config. - Op calling_pattern; - if (use_unpacked_api_) { - calling_pattern = tvm::tir::builtin::call_extern(); - } else if (use_call_cpacked_) { - calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); - } else { - calling_pattern = tvm::tir::builtin::tvm_call_packed(); - } - GlobalVar global_var = call_lowered_props.lowered_func; - tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; + bool call_device_hooks = false; + tir::Var device_context; + tir::Stmt func_call; + + switch (call_type_) { + case CallType::kUnpacked: { + // call_extern calling convention with optional context + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); + break; + } + case CallType::kCPacked: { + if (has_c_device_api_context) { + device_context = device_contexts_.Get(global_var).value(); + args.push_back(device_context); + } else { + // NOTE: LowerTVMBuiltin expects some device_context placeholder. + args.push_back(tir::make_zero(DataType::Handle())); + } + func_call = tir::Evaluate( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); + create_func_call_stmts.push_back(func_call); + break; + } + case CallType::kPacked: { + // call_packed does not accept a device context. 
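+        // (With kPacked, the runtime resolves resource_handle itself inside TVMFuncCall,
+        // so there is nothing for the codegen to forward here.)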
+ CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context"; + func_call = tir::Evaluate(AddCheckReturn( + tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); + create_func_call_stmts.push_back(func_call); + break; + } + default: + ICHECK(false) << "Unknown CallType: " << call_type_; + } - // The device context is passed to the operator in one of the following calling patterns: - // * Unpacked / direct function call with context: - // operator(arg0, arg1, device_context); - // * Unpacked / direct function call without context: - // operator(arg0, arg1); - // * Type-erased packed function call with context: - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // device_context_my_device) - // * Type-erased packed function call without context (we create an empty var for codegen): - // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, - // no_device_context) - if (has_c_device_api_context) { - // call_extern calling convention with context - tir::Var context = device_contexts_.Get(global_var).value(); - args.push_back(context); - - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(tir::SeqStmt({ - GenerateDeviceHook(context, "Open"), + ICHECK(func_call.defined()) << "Must define func_call"; + + if (call_device_hooks) { + func_call = tir::SeqStmt(Array({ + GenerateDeviceHook(device_context, "Open"), func_call, - GenerateDeviceHook(context, "Close"), + GenerateDeviceHook(device_context, "Close"), })); - } else if (use_call_cpacked_) { - // call_cpacked calling convention needs a blank context - args.push_back(tir::make_zero(DataType::Handle())); - tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); - create_func_call_stmts.push_back(func_call); - } else { - // call_extern calling convention without context - tir::Evaluate func_call( - AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args))); - create_func_call_stmts.push_back(func_call); } - tir::Stmt body = tir::SeqStmt(create_func_call_stmts); + tir::Stmt body = tir::SeqStmt({func_call}); + LOG(INFO) << "CreateFuncCall: " << call_lowered_props.lowered_func->name_hint << " -> " << body; stmts_.push_back(body); } @@ -517,12 +757,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { return it.second->name_hint == context->name_hint; }); const String& device_name = (*it).first; - Array sections = {"Device", device_name, hook}; - String device_hook = ToCFunctionStyle(PrefixName(sections)); - - return tir::Evaluate( - AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook), context}))); + return MakeDeviceHookCall(device_name, hook, context); } /*! @@ -855,30 +1090,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { /*! \brief target host */ Target target_host_; /*! - * \brief unpacked api toggle - * When set to true, the generated code will use unpacked calls to functions: - * func(void* arg0, void* arg1) - * Rather than packed calls (in which arg0 and arg1 are in `arg_values`). - * func(TVMValue* arg_values, int* arg_type_codes, int num_args, ...) - * Defaults to using the packed calling convention - * - * Unpacked API is supported when runtime == "c" and interface_api is "c". + * \brief The type of kernel call to be emitted. + * See CallType for more documentation. */ - Bool use_unpacked_api_; - /*! 
- * \brief cpacked api toggle - * When set to true, the generated code will use call_cpacked to call functions directly, assuming - * they exist in a DSO-exportable module: - * func(...) - * Rather than through the traditional call_packed calls, which should use function pointers - * looked-up through TVMBackendGetFuncFromEnv: - * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); - * func_ptr(...) - * Defaults to using the packed calling convention - * - * call_cpacked is required when runtime is "c++" and supported when runtime is "c" - */ - Bool use_call_cpacked_; + CallType call_type_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -907,11 +1122,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), - targets_(targets), - target_host_(target_host), - use_unpacked_api_(Bool(false)), - use_call_cpacked_(Bool(false)) {} + : mod_(mod), targets_(targets), target_host_(target_host) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -923,23 +1134,36 @@ class AOTExecutorCodegen : public MixedModeVisitor { Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); - String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); + std::string interface_api = + executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); - use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); - use_call_cpacked_ = !use_unpacked_api_; + bool unpacked_api = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); // Validate choice of use_unpacked_api_ and use_call_cpacked_ if (runtime_config->name == kTvmRuntimeCrt) { - ICHECK(interface_api == "packed" || static_cast(use_unpacked_api_) == true) - << "Either need interface_api == \"packed\" (got: " << interface_api - << ") or unpacked-api == true (got: " << use_unpacked_api_ - << ") when targeting c runtime"; + if (unpacked_api == true) { + call_type_ = CallType::kUnpacked; + } else if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(interface_api == "packed" || unpacked_api == true) + << "Either need interface_api == \"packed\" (got: " << interface_api + << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else if (runtime_config->name == kTvmRuntimeCpp) { - ICHECK(static_cast(use_unpacked_api_) == false) - << "Need unpacked-api == false (got: " << use_unpacked_api_ - << ") and interface-api == \"packed\" (got: " << interface_api - << ") when targeting c++ runtime"; + if (unpacked_api == false && interface_api == "packed") { + call_type_ = CallType::kCPacked; + } else { + CHECK(static_cast(unpacked_api) == false && interface_api == "packed") + << "Need unpacked-api == false (got: " << unpacked_api + << ") and interface-api == \"packed\" (got: " << interface_api + << ") when targeting c++ runtime"; + ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api + << ", unpacked-api=" << unpacked_api; + } } else { ICHECK(false) << "runtime_config (" << runtime_config->name << 
") is not one of the expected values"; @@ -1037,7 +1261,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Legalize AOT if needed. This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) - if (!use_unpacked_api_) { + if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) { auto pack_calls = tir::transform::LegalizePackedCalls(); lowered_mod = pack_calls(lowered_mod); } @@ -1106,7 +1330,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.metadata = ExecutorCodegenMetadata( inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices, - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_, pool_var_info); + runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, pool_var_info); return ret; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index fe9c2f27594c..eda96d0db50b 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -75,7 +75,9 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // TVMValue* out_ret_value, int* out_ret_tcode, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( - t_int_, {t_void_p_, t_void_p_, t_int_, t_void_p_, t_void_p_, t_void_p_}, false); + t_int_, + {t_void_p_, t_int_->getPointerTo(), t_int_, t_void_p_, t_int_->getPointerTo(), t_void_p_}, + false); t_tvm_crt_func_registry_ = llvm::StructType::create( {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()}); @@ -850,10 +852,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& nargs -= 1; call_args.insert(call_args.end(), { builder_->CreateBitCast(arg_value, t_void_p_), - builder_->CreateBitCast(arg_tcode.addr, t_void_p_), + arg_tcode.addr, ConstInt32(nargs), builder_->CreateBitCast(ret_value, t_void_p_), - builder_->CreateBitCast(ret_tcode.addr, t_void_p_), + ret_tcode.addr, }); call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); } diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h index 4902f330a239..1486a885e502 100644 --- a/src/target/metadata_utils.h +++ b/src/target/metadata_utils.h @@ -30,8 +30,8 @@ #include #include -#include #include +#include #include "metadata.h" diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 2d8b6681fa84..3163b11369b4 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -43,10 +43,9 @@ using InputMap = */ class PackedCallLegalizer : public StmtExprMutator { public: - Stmt Legalize(const InputMap& params, tir::Stmt body) { - inputs_ = params; - return StmtExprMutator::VisitStmt(body); - } + PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {} + + Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); } Stmt VisitStmt_(const EvaluateNode* op) final { if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); @@ -61,29 +60,45 @@ class PackedCallLegalizer : public StmtExprMutator { if (call->op.same_as(builtin::tvm_call_cpacked())) { Array packed_args{call->args[0]}; std::vector tvm_values; - for (unsigned i = 1; i < call->args.size(); i++) { + VLOG(2) << "Legalize call:" << call; + BaseFunc base_func = mod_->Lookup(Downcast(call->args[0])->value); + const PrimFuncNode* prim_func = 
base_func.as(); + VLOG(2) << " to func " << base_func; + for (unsigned i = 1; i < call->args.size() - 1; i++) { // No need to pack inputs of the prim_func if (inputs_[call->args[i]] == true) { packed_args.push_back(call->args[i]); } else { - // Pack the argument inside a TVMValue - std::stringstream ss; - ss << "tvm_value_" << tvm_value_index_++; - auto sid_array = tir::Var(ss.str(), DataType::Handle()); - tvm_values.push_back(sid_array); - - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrData, call->args[i]}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceType, kDLCPU}))); - new_stmts.push_back(tir::Evaluate( - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {sid_array, 0, tir::builtin::kArrDeviceId, 0}))); - packed_args.push_back(sid_array); + // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect + // all such stack-allocated tensors and minimize the storage needed by reusing + // DLTensors. + Array call_args{call->args[i]}; + if (prim_func != nullptr) { + Buffer param = prim_func->preflattened_buffer_map[prim_func->params[i - 1]]; + PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); + Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(param->elem_offset /* elem_offset */); + } else { + // When the PrimFunc cannot be found, most DLTensor information cannot be populated. + PrimExpr shape = tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array()); + Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0)); + call_args.push_back(shape /* shape */); + call_args.push_back(make_zero(DataType::Handle()) /* strides */); + call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */); + call_args.push_back(var_type /* carries dtype */); + call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */); + } + packed_args.push_back(tvm::tir::Call( + DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args)); } } + packed_args.push_back(call->args[call->args.size() - 1]); // push device_context // Evaluate the packed call new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args))); tir::Stmt call_stmt = tir::SeqStmt(new_stmts); @@ -99,6 +114,7 @@ class PackedCallLegalizer : public StmtExprMutator { } private: + IRModule mod_; InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. int tvm_value_index_; // Index of the actual tvm_value variable }; @@ -109,12 +125,12 @@ Pass LegalizePackedCalls() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - // Create the + // Note which Var are inputs and exclude them from packing. 
InputMap inputs; for (auto i : f->params) { inputs[i] = true; } - n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); + n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 51a503ecfe38..3c44d2bf1bc8 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -60,7 +60,7 @@ def test_error_c_interface_with_packed_api(): tvm.TVMError, match=re.escape( 'Either need interface_api == "packed" (got: c) or ' - "unpacked-api == true (got: (bool)0) when targeting " + "unpacked-api == true (got: 0) when targeting " "c runtime" ), ): From 14773837bcc9ed0f188656681344cb1f500f4dea Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 08:57:24 -0700 Subject: [PATCH 06/32] Improve error message from Call asserts --- src/tir/ir/expr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 07b341dfd2c7..f4dbc238c120 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -810,7 +810,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Call Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { - ICHECK(args[i].defined()); + ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } ObjectPtr node = make_object(); From 1641cfd4028be21c6f20dc05d6f78e9e5e86e918 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 08:57:48 -0700 Subject: [PATCH 07/32] Pass non-String device_context down to codegen. * this is necessary to allow CodeGenCPU to emit calls that include resource_handle. --- src/tir/transforms/lower_tvm_builtin.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e474683b39fc..58c1b8b2d763 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -436,8 +436,14 @@ class BuiltinLower : public StmtExprMutator { // cpacked call resource_handle if (!use_string_lookup) { - tir::Var resource_handle = Downcast(op->args[arg_count]); - packed_args.push_back(StringImm(resource_handle->name_hint)); + PrimExpr last_arg = op->args[arg_count]; + const VarNode* var_node = last_arg.as(); + if (var_node != nullptr) { + tir::Var resource_handle = GetRef(var_node); + packed_args.push_back(StringImm(resource_handle->name_hint)); + } else { + packed_args.push_back(last_arg); + } } auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() From 58da1082b69349e6cbfcbe4b6e7d26144f955fcf Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 09:02:32 -0700 Subject: [PATCH 08/32] Scope usage of lvalue refs in LowerTVMBuiltin to avoid corrupt memory. 
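
The corruption being fixed is the classic dangling reference into a std::vector:
alloca_scope_.emplace_back() may reallocate the vector's storage while an
`auto& scope = alloca_scope_.back();` reference taken earlier is still live. A minimal
standalone reproduction of the bug class (hypothetical code, not from TVM):

    #include <vector>

    struct Scope { int stack_shape = 0; };

    int main() {
      std::vector<Scope> scopes;
      scopes.emplace_back();
      Scope& scope = scopes.back();  // reference into the vector's current buffer
      scopes.emplace_back();         // may reallocate, freeing that buffer
      scope.stack_shape = 1;         // undefined behavior: write through a dangling reference
    }

The change below braces each such reference so it goes out of scope before alloca_scope_
can be mutated again.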
--- src/tir/transforms/lower_tvm_builtin.cc | 83 +++++++++++++++---------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 58c1b8b2d763..7ebc685090fc 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -109,11 +109,14 @@ class BuiltinLower : public StmtExprMutator { precheck.device_type_ = this->device_type_; precheck.alloca_scope_.emplace_back(); - auto& scope = precheck.alloca_scope_.back(); - scope.stack_shape = - decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); - scope.stack_tcode = - decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = precheck.alloca_scope_.back(); + scope.stack_shape = + decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape"); + scope.stack_tcode = + decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode"); + } precheck.VisitStmt(stmt); @@ -130,31 +133,35 @@ class BuiltinLower : public StmtExprMutator { } alloca_scope_.emplace_back(); - auto& scope = alloca_scope_.back(); - - // Initial check to identify maximum stack sizes. These are used - // to construct Buffer objects to hold the stack, which are then - // used when mutating. - scope.max_sizes = GetMaxStack(stmt); - - if (scope.max_sizes.shape_stack != -1) { - scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, - DataType::Int(64), "stack_shape"); - stmt = - LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt); - } + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + + // Initial check to identify maximum stack sizes. These are used + // to construct Buffer objects to hold the stack, which are then + // used when mutating. 
+ scope.max_sizes = GetMaxStack(stmt); + + if (scope.max_sizes.shape_stack != -1) { + scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, + DataType::Int(64), "stack_shape"); + stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), + stmt); + } - if (scope.max_sizes.array_stack != 0) { - stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); - } + if (scope.max_sizes.array_stack != 0) { + stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt); + } - if (scope.max_sizes.arg_stack != 0) { - scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, - DataType::Int(32), "stack_tcode"); - stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); + if (scope.max_sizes.arg_stack != 0) { + scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, + DataType::Int(32), "stack_tcode"); + stmt = + LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); - stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), - stmt); + stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), + stmt); + } } stmt = this->VisitStmt(stmt); @@ -169,14 +176,22 @@ class BuiltinLower : public StmtExprMutator { // allocate space to hold prepare stmts before s prep_seq_stack_.emplace_back(std::vector()); + auto scope_size = alloca_scope_.size(); auto stmt = StmtExprMutator::VisitStmt(s); - auto& scope = alloca_scope_.back(); - // This invariant asserts the assumption that - // make_stack_shape only happens within a call_packed. - // We could relax this in the future if we want to - // introduce root scope as a separate scope - ICHECK_EQ(scope.run_sizes.shape_stack, -1); - ICHECK_EQ(scope.run_sizes.array_stack, 0); + { + // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. + auto& scope = alloca_scope_.back(); + // This invariant asserts the assumption that + // make_stack_shape only happens within a call_packed. 
+ // We could relax this in the future if we want to + // introduce root scope as a separate scope + ICHECK_EQ(alloca_scope_.size(), scope_size) + << "alloca_scope_ length is different before and after recursion"; + ICHECK_EQ(scope.run_sizes.shape_stack, -1) + << "Expect no tvm_stack_make_shape outside of CallNodes"; + ICHECK_EQ(scope.run_sizes.array_stack, 0) + << "Expect no tvm_stack_make_array outside of CallNodes"; + } auto prep_seq = std::move(prep_seq_stack_.back()); prep_seq_stack_.pop_back(); From cab2df8854cbe93ac8c05bc3cc94c6c7c45a2523 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 09:07:18 -0700 Subject: [PATCH 09/32] test fixes --- tests/python/relay/aot/test_cpp_aot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index 5820efe6237a..0968870e67cc 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -120,7 +120,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all() -def test_mobilenet(): +def test_mobilenet(target_kind): ir_mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") @@ -131,7 +131,7 @@ def test_mobilenet(): mod = tvm.relay.build( ir_mod, params=params, - target="c", + target=target_kind, executor=backend.Executor("aot", {"interface-api": "packed"}), ) From 43ad6d4e6844e8ac6208e512465bd31de527b82f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 09:08:03 -0700 Subject: [PATCH 10/32] Also fill preflattened_buffer_map (TODO, maybe don't do this) --- src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index b73534090ab5..48ac25a1b6b9 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -223,8 +223,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( if (emit_tvmscript_printable_) { original_attrs = DictAttrs(); } - PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, {}, - original_attrs); + PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, + si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } From acba246f76a686a3dce5e4131a896a59232b6557 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 08:32:29 -0700 Subject: [PATCH 11/32] Fix C codegen. 
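
The C backend previously assumed arg 5 of a lowered cpacked call was always a StringImm naming
the resource_handle variable; after the preceding changes it may also be reinterpret(0), meaning
"no context". A sketch of the two call forms the codegen now emits (function and variable names
hypothetical):

    /* arg 5 is a StringImm naming the device context from the interface API: */
    fused_op(values, tcodes, 3, ret_values, ret_tcodes, device_context_my_device);
    /* arg 5 is reinterpret(0); NULL is emitted instead: */
    fused_op(values, tcodes, 3, ret_values, ret_tcodes, NULL);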
--- src/target/source/codegen_c_host.cc | 28 ++++++++++++++++++++++++---- src/target/source/source_module.cc | 5 ++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 0b74a1a1c4d9..d7a121c631f5 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -273,7 +273,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) { CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, bool has_resource_handle) { const StringImmNode* s = op->args[0].as(); - ICHECK(s != nullptr) << "tvm_call_{c}packed_lowered expects first argument as function name"; + ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; int64_t end = op->args[4].as()->value; int64_t num_args = end - begin; @@ -281,10 +281,30 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, std::string func_name = s->value; if (has_resource_handle) { - std::string resource_handle_name = op->args[5].as()->value; - return {func_name, num_args - 1, resource_handle_name}; + const StringImmNode* resource_handle_var = op->args[5].as(); + if (resource_handle_var != nullptr) { + std::string resource_handle_name = resource_handle_var->value; + return {func_name, num_args - 1, resource_handle_name}; + } else { + // The final arg should be "(void*) NULL" to indicate the empty resource_handle. + num_args--; + + const CallNode* reinterpret_call = op->args[5].as(); + ICHECK_NE(reinterpret_call, (void*)nullptr) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK_EQ(reinterpret_call->op, builtin::reinterpret()) + << "At CallNode to " << s + << "arg 5: Expect either StringImm naming the resource_handle var from interface API or " + << "reinterpret(0); got: " << op->args[5]; + ICHECK(is_zero(reinterpret_call->args[0])) << "At CallNode to " << s + << " arg 5: Expect either StringImm naming the " + "resource_handle var from interface API, or " + << "zero; got " << op->args[5]; + } } - return {func_name, num_args}; + return {func_name, num_args, "NULL"}; } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 018df5d70af9..635775a7777d 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -732,9 +732,12 @@ class MetadataSerializer : public AttrVisitor { } // Finally, emit overall struct. - code_ << "const struct TVMMetadata " << metadata::kMetadataGlobalSymbol << " = {" << std::endl; + address_.push_back(metadata::kMetadataGlobalSymbol); + code_ << "const struct TVMMetadata " << metadata::address_from_parts(address_) << " = {" + << std::endl; Visit(nullptr, &metadata); code_ << "};" << std::endl; + address_.pop_back(); } std::string GetOutput() { return decl_.str() + code_.str(); } From fe910e9d72d00d96b80b74121a34f4a9edca490f Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Mar 2022 08:33:15 -0700 Subject: [PATCH 12/32] Set USMP elem_offset to 0. 
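
Buffer's fifth positional argument is elem_offset (the offset, in elements, of the buffer
contents from the data pointer), not an alignment; passing 1 there claimed every pool buffer
began one element past its allocation. This matters because LowerTVMBuiltin derives the DLTensor
byte offset from it, roughly (C++ sketch of the lowering as clarified later in this series):

    int data_bytes = GetVectorBytes(dtype);
    PrimExpr byte_offset = is_zero(elem_offset)
        ? elem_offset
        : elem_offset * make_const(elem_offset.dtype(), data_bytes);

so a pool whose storage starts exactly at its data pointer must carry elem_offset == 0.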
---
 .../usmp/transform/convert_pool_allocations_to_offsets.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index 48ac25a1b6b9..ba5ab891baa4 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -200,8 +200,11 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
     int pool_size = all_pools_sizes_[pool_info];
     String buffer_var_name = pool_ref_name + "_buffer_var";
-    si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name,
-                                       16, 1, BufferType::kDefault));
+    si.buffer_map.Set(pool_var,
+                      Buffer(buffer_var /* data */, elem_dtype /* dtype */, {pool_size} /* shape */,
+                             {1} /* strides */, 0 /* elem_offset */, buffer_var_name /* name */,
+                             16 /* data_alignment */, 1 /* offset_factor */,
+                             BufferType::kDefault /* buffer_type */));
   }
   if (resource_handle) {
     si.params.push_back(resource_handle.value());

From 1558cf7c256ee09e7221bbc99b95311c8563b437 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Tue, 22 Mar 2022 08:33:27 -0700
Subject: [PATCH 13/32] Clarify calculation of byte_offset from elem_offset.

---
 src/tir/transforms/lower_tvm_builtin.cc | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 7ebc685090fc..9d0087cc7a0b 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -384,9 +384,12 @@ class BuiltinLower : public StmtExprMutator {
                                        make_const(DataType::UInt(16), dtype.lanes())));
     // set byte offset
     int data_bytes = GetVectorBytes(dtype);
-    PrimExpr byte_offset = op->args[5];
-    if (!is_zero(byte_offset)) {
-      byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
+    PrimExpr elem_offset = op->args[5];
+    PrimExpr byte_offset;
+    if (!is_zero(elem_offset)) {
+      byte_offset = elem_offset * make_const(elem_offset.dtype(), data_bytes);
+    } else {
+      byte_offset = elem_offset;
     }
     prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
                                        cast(DataType::UInt(64), byte_offset)));
@@ -582,6 +585,7 @@ Pass LowerTVMBuiltin() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     n->body = BuiltinLower().Build(n->body);
+    VLOG(2) << "LowerTVMBuiltin: " << f;
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});

From 4290e28b33622d92572272f47655c48143f8e0ef Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Tue, 22 Mar 2022 08:35:32 -0700
Subject: [PATCH 14/32] Fix tests

---
 tests/python/relay/aot/test_cpp_aot.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py
index 0968870e67cc..2a11e7e28748 100644
--- a/tests/python/relay/aot/test_cpp_aot.py
+++ b/tests/python/relay/aot/test_cpp_aot.py
@@ -24,8 +24,10 @@
 import pytest
 
 import tvm
+from tvm import IRModule
+from tvm import relay
 from tvm.relay import backend, testing
-from aot_test_utils import generate_ref_data
+from aot_test_utils import AOT_DEFAULT_RUNNER, AOTTestModel, generate_ref_data, compile_and_run
 
 
 def test_error_c_interface():
@@ -39,18 +41,14 @@ def test_error_c_interface():
     with pytest.raises(
         tvm.TVMError,
         match=re.escape(
-            'Either need interface_api == "packed" (got: c) or '
-            "unpacked-api == true (got: (bool)0) when targeting "
-            "c runtime"
+            'Need unpacked-api == false (got: 0) and interface-api == "packed" (got: c) when '
+            "targeting c++ runtime"
         ),
     ):
-        compile_and_run(
-            AOTTestModel(
-                module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {})
-            ),
-            test_runner,
-            interface_api,
-            use_unpacked_api,
+        tvm.relay.build(
+            IRModule.from_expr(func),
+            target="llvm",
+            executor=backend.Executor("aot", {"interface-api": "c"}),
         )
 
 
@@ -120,14 +118,16 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
     assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all()
 
 
-def test_mobilenet(target_kind):
+def test_mobilenet(enable_usmp, target_kind):
    ir_mod, params = testing.mobilenet.get_workload(batch_size=1)
     data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape]
     data = np.random.uniform(size=data_shape).astype("float32")
     inputs = {"data": data}
     ref_outputs = generate_ref_data(ir_mod, inputs, params)
-    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+    with tvm.transform.PassContext(
+        opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp}
+    ):
         mod = tvm.relay.build(
             ir_mod,
             params=params,

From 74283a7011fd3620f10bef5163ae458d88ceb09a Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Wed, 23 Mar 2022 17:20:49 -0700
Subject: [PATCH 15/32] Fix arm compile warning

---
 src/target/llvm/codegen_cpu.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index eda96d0db50b..5b54f5716d3a 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -862,6 +862,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
 #if TVM_LLVM_VERSION >= 90
   auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value);
 #else
+  (void)callee_ftype;  // use callee_ftype to avoid an unused-variable warning when using older LLVM.
   auto call_callee = callee_value;
 #endif
   llvm::Value* call = builder_->CreateCall(call_callee, call_args);

From 37643855cf8b43fcd75bd00146dead965d2e5ed4 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Thu, 24 Mar 2022 11:57:46 -0700
Subject: [PATCH 16/32] Fix hexagon test.

* previously I believe we required interface_api == "c", but this really
  means to generate C API bindings, and we are generating "packed"
  bindings.
* I think "c" was chosen here because the distinction between
  interface-api and use-unpacked-api is confusing. "c" interface-api
  means to generate an entrypoint API for microcontrollers that accepts
  bare data buffers. "packed" interface-api means to generate a
  TVMBackendPackedCFunc entrypoint. use-unpacked-api makes the same
  determination for the operator functions.
* A further confusion here is that there are two ways to call "packed"
  operator functions: tir.tvm_builtin_call_packed and
  tir.tvm_builtin_call_cpacked. This distinction describes whether or not
  to late-bind calls via TVMBackendGetFuncFromEnv. Right now, AOT only
  ever requires call_cpacked because target_host == target, and for all
  suitable target_host, we expect a single DSO-exportable runtime.Module.
  When we move away from this by introducing heterogeneous target support
  to AOT, we can use this as a condition to help us choose between
  call_cpacked and call_packed (and possibly add a compile-time option to
  assert it is call_cpacked, for situations where we really don't want
  call_packed).
--- tests/python/contrib/test_hexagon/test_launcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 3e72c38f1909..86574bfa3f9b 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -350,7 +350,7 @@ def test_aot_executor(hexagon_launcher, hexagon_session): params=params, target=tvm.target.Target(target_hexagon, host="c"), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) # Uncomment this once the workaround is not needed. # lowered.export_library( @@ -442,7 +442,7 @@ def test_aot_executor_multiple_conv2d(hexagon_launcher, hexagon_session): params=params, target=tvm.target.Target(target_hexagon, host="c"), runtime=Runtime("cpp"), - executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}), + executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}), ) # Uncomment this once the workaround is not needed. # lowered.export_library( From 48478c7421780d9f544b3820477ef762fb4b4c74 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 31 Mar 2022 14:37:00 -0700 Subject: [PATCH 17/32] Document T.preflattened_buffer --- python/tvm/script/tir/special_stmt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 0148bd0b4243..ea523f1810ec 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -871,7 +871,8 @@ class PreflattenedBufferMap(SpecialStmt): Example ------- .. 
code-block:: python
 
-            T.preflattened_buffer_map({})
+            A0 = T.match_buffer(A, (48,), dtype="float32")
+            T.preflattened_buffer(A0, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
     """
 
     def __init__(self):

From 4bf22e9ee7367751ce44b4b7a03a1b0dcd8c4bd4 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Thu, 31 Mar 2022 14:37:13 -0700
Subject: [PATCH 18/32] Fix test_aot_legalize_packed_calls

---
 src/tir/transforms/legalize_packed_calls.cc   | 11 +---
 .../unittest/test_aot_legalize_packed_call.py | 56 ++++++++++---------
 2 files changed, 30 insertions(+), 37 deletions(-)

diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc
index 3163b11369b4..c8b1377162ef 100644
--- a/src/tir/transforms/legalize_packed_calls.cc
+++ b/src/tir/transforms/legalize_packed_calls.cc
@@ -55,11 +55,9 @@ class PackedCallLegalizer : public StmtExprMutator {
     //  let B_packed = set_struct(tvm_value2, B)
     //  let C_packed = set_struct(tvm_value3, C)
     //  call_packed(f, A_packed, B_packed, C_packed)
-    std::vector<tir::Stmt> new_stmts;
     if (call) {
       if (call->op.same_as(builtin::tvm_call_cpacked())) {
         Array<PrimExpr> packed_args{call->args[0]};
-        std::vector<tir::Var> tvm_values;
         VLOG(2) << "Legalize call:" << call;
         BaseFunc base_func = mod_->Lookup(Downcast<StringImm>(call->args[0])->value);
         const PrimFuncNode* prim_func = base_func.as<PrimFuncNode>();
@@ -100,14 +98,7 @@ class PackedCallLegalizer : public StmtExprMutator {
         }
         packed_args.push_back(call->args[call->args.size() - 1]);  // push device_context
         // Evaluate the packed call
-        new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
-        tir::Stmt call_stmt = tir::SeqStmt(new_stmts);
-
-        // Allocate the TVMValues on the stack and define the variables
-        for (auto v : tvm_values) {
-          call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt);
-        }
-        return call_stmt;
+        return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args));
       }
     }
     return StmtExprMutator::VisitStmt_(op);
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index 54561ade23e4..a3d18c95475c 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -24,11 +24,22 @@
 
 @tvm.script.ir_module
 class Module:
+    @T.prim_func
+    def tvm_test_cpacked(A : T.handle, B: T.handle, C: T.handle, device_context: T.handle) -> T.handle:
+        A_0 = T.match_buffer(A, (1,), dtype="float32")
+        A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
+        B_0 = T.match_buffer(B, (1,), dtype="float32")
+        B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
+        C_0 = T.match_buffer(C, (1,), dtype="float32")
+        C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
+        T.evaluate(C)
+
     @T.prim_func
     def tir_packed_call() -> None:
         A = T.var("handle")
         B = T.var("handle")
         C = T.var("handle")
+        device_context = T.var("handle")
         # body
         T.evaluate(
             T.tvm_call_cpacked(
@@ -36,6 +47,7 @@ def tir_packed_call() -> None:
                 A,
                 B,
                 C,
+                device_context,
                 dtype="int32",
             )
         )
@@ -43,40 +55,30 @@ def tir_packed_call() -> None:
 
 @tvm.script.ir_module
 class Expected:
+    @T.prim_func
+    def tvm_test_cpacked(A : T.handle, B: T.handle, C: T.handle, device_context: T.handle) -> T.handle:
+        A_0 = T.match_buffer(A, (1,), dtype="float32")
+        A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
+        B_0 = T.match_buffer(B, (1,), dtype="float32")
+        B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
+        C_0 = T.match_buffer(C, (1,), dtype="float32")
+        C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
dtype="float32") + T.evaluate(C) + @T.prim_func def tir_packed_call() -> None: A = T.var("handle") B = T.var("handle") C = T.var("handle") + device_context = T.var("handle") # body - tvm_value_2 = T.var("handle") - tvm_value_1 = T.var("handle") - tvm_value_0 = T.var("handle") - with T.let(tvm_value_2, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_1, T.tvm_stack_alloca("array", 1, dtype="handle")): - with T.let(tvm_value_0, T.tvm_stack_alloca("array", 1, dtype="handle")): - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 9, 0, dtype="handle")) - - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 9, 0, dtype="handle")) - - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 10, 1, dtype="handle")) - T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 9, 0, dtype="handle")) - - T.evaluate( - T.tvm_call_cpacked( - "tvm_test_cpacked", - tvm_value_0, - tvm_value_1, - tvm_value_2, - dtype="int32", - ) - ) + T.evaluate(T.tvm_call_cpacked("tvm_test_cpacked", + T.tvm_stack_make_array(A, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), + T.tvm_stack_make_array(B, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), + T.tvm_stack_make_array(C, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), + device_context, + dtype="int32")) def test_aot_packed_call(): From 5de35ef8e5f87714cf546e4a68e443653743cfe8 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 5 Apr 2022 16:40:18 -0700 Subject: [PATCH 19/32] Address manupa comments --- src/relay/backend/aot_executor_codegen.cc | 191 +--------------------- src/target/llvm/codegen_cpu.cc | 28 +--- src/target/llvm/codegen_llvm.h | 11 ++ src/target/llvm/llvm_module.cc | 5 - src/target/metadata_module.cc | 2 +- src/target/metadata_utils.h | 22 ++- 6 files changed, 44 insertions(+), 215 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 515bf4c28669..eba4260e7f41 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -260,185 +260,6 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { std::vector return_ttypes_; }; -namespace { - -/*! - * \brief Utility function to convert a concrete integer to a PrimExpr. - * \param num the number to convert - * \return PrimExpr representing num - */ -inline PrimExpr ConstInt32(int32_t num) { - ICHECK_LE(num, std::numeric_limits::max()); - return tir::make_const(DataType::Int(32), static_cast(num)); -} - -/*! - * \brief Emit a call to the C Device API. - * \param device_name Name of the device, used to prefix the function name. - * \param hook Name of the Device API function. - * \param context void* context arg passed to this API function. 
- */ -tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook, - PrimExpr context) { - Array sections = {"Device", device_name, hook}; - String device_hook = ToCFunctionStyle(PrefixName(sections)); - - return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook), context})); -} -} // namespace - -class AOTCallGenerator { - public: - explicit AOTCallGenerator(std::string func_name) - : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {} - - tir::Var PushArg(PrimExpr arg) { - if (!arg->IsInstance()) { - arg = MakeLetBind(arg); - } - args_.push_back(arg); - return Downcast(arg); - } - - void PushStackDLTensor(const TensorType& ttype, PrimExpr data) { - auto dltensor_var = MakeLetBind(StackAlloca("array", 1)); - auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size())); - - // Populate DLTensor.data - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrData, data}))); - - // Populate DLTensor.device - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrDeviceType, kDLCPU}))); - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrDeviceId, 0}))); - - // Populate DLTensor.ndim - prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrNDim, static_cast(ttype->shape.size())}))); - - // Populate DLTensor.dtype - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrTypeCode, - IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.code())}))); - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrTypeBits, - IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.bits())}))); - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrTypeLanes, - IntImm(DataType(kDLUInt, 16, 1), ttype->dtype.lanes())}))); - - // Populate DLTensor.shape - for (size_t i = 0; i < ttype->shape.size(); ++i) { - prep_stmts_.push_back(tvm::tir::Store( - shape_var, IntImm(DataType(kDLInt, 64, 1), Downcast(ttype->shape[i])->value), - IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true())); - } - - prep_stmts_.push_back( - tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrShape, shape_var}))); - - // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset - prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 64, 1), 0)}))); - - // Populate DLTensor.byte_offset. 
DNS -- TODO actually pull correct byte_offset - prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call( - DataType::Handle(), tvm::tir::builtin::tvm_struct_set(), - {dltensor_var, 0, tir::builtin::kArrByteOffset, IntImm(DataType(kDLUInt, 64, 1), 0)}))); - - args_.push_back(dltensor_var); - } - - void PushStackDLTensors(const Expr& expr, std::vector sids) { - const TupleNode* t = expr.as(); - if (t != nullptr) { - CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " - "handle this type of Relay Expr in a CallNode."; - for (size_t i = 0; i < sids.size(); i++) { - PushStackDLTensor(Downcast(t->fields[i]->checked_type()), sids[i]); - } - } else { - PushStackDLTensor(Downcast(expr->checked_type()), sids[0]); - } - } - - tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) { - auto make_call = [this] { - return tir::Evaluate( - tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args_)); - }; - if (device_context.defined()) { - tir::Var context_var = PushArg(device_context); - return Generate(tir::SeqStmt({ - MakeDeviceHookCall(device_name, "Open", context_var), - make_call(), - MakeDeviceHookCall(device_name, "Close", context_var), - })); - } else { - return Generate(make_call()); - } - } - - tir::Stmt GeneratePacked() { - return Generate( - tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), args_))); - } - - tir::Stmt GenerateCPacked() { - // call_cpacked calling convention does not use a context - PushArg(tir::make_zero(DataType::Handle())); - return Generate( - tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), args_))); - } - - private: - tir::Stmt Generate(tir::Stmt call_stmts) { - tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts); - - for (auto bind : let_binds_) { - body = tir::LetStmt(bind.first, bind.second, body); - } - - return body; - } - - tir::Var MakeLetBind(PrimExpr expr) { - std::stringstream ss; - ss << func_name_ << "_let" << let_binds_.size(); - tir::Var v{ss.str(), DataType::Handle()}; - let_binds_.emplace_back(std::make_pair(v, expr)); - return v; - } - - /*! - * \brief Utility function to allocate a DLTensor or TVMValue - * \param type the type of allocation - * \param num the number of variable to allocate on the stack - * \return PrimExpr representing the allocated object - */ - PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {tir::StringImm(type), ConstInt32(num)}; - return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args); - } - - std::string func_name_; - tvm::Array args_; - std::vector> let_binds_; - Array prep_stmts_; -}; - /*! 
\brief Code generator for AOT executor */ class AOTExecutorCodegen : public MixedModeVisitor { protected: @@ -578,7 +399,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { * returns the passed Call */ tir::Call AddCheckReturn(tir::Call existing_call) { - Array args = {ConstInt32(0), ConstInt32(-1), existing_call}; + Array args = {tir::make_const(DataType::Int(32, 1), 0, Span()), tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); } @@ -687,7 +508,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto retval_i = tir::BufferLoad(tmp_read, {loop_idx}); // Copy the variable from the input to the output tir::Stmt copy = - tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, + tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial, tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx})); stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy)); } @@ -757,7 +578,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { return it.second->name_hint == context->name_hint; }); const String& device_name = (*it).first; - return MakeDeviceHookCall(device_name, hook, context); + Array sections = {"Device", device_name, hook}; + String device_hook = ToCFunctionStyle(PrefixName(sections)); + + return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook), context})); } /*! @@ -927,7 +752,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (int i = 0; i < ndim; i++) { int shape = kv.second->data->shape[i]; - extents.push_back(tir::make_const(DataType::Int(32), shape)); + extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); } body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 5b54f5716d3a..033275ae5286 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -807,10 +807,6 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = nullptr; - if (use_string_lookup) { - handle = GetPackedFuncHandle(func_name); - } // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -834,7 +830,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& if (use_string_lookup) { callee_ftype = ftype_tvm_func_call_; callee_value = RuntimeTVMFuncCall(); - call_args.push_back(handle); + call_args.push_back(GetPackedFuncHandle(func_name)); + call_args.insert(call_args.end(), + {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); } else { callee_ftype = ftype_tvm_backend_packed_c_func_; callee_value = module_->getFunction(func_name); @@ -843,12 +841,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); } - } - if (use_string_lookup) { - call_args.insert(call_args.end(), - {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); - } else { nargs -= 1; call_args.insert(call_args.end(), { builder_->CreateBitCast(arg_value, t_void_p_), @@ -1048,13 +1041,10 @@ class MetadataTypeDefiner : public AttrVisitor { void DefineType(runtime::metadata::MetadataBase metadata) { 
     ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
-    LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":";
     for (auto e : elements_) {
       std::string value;
       llvm::raw_string_ostream os(value);
       e->print(os, true);
-      // LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " << value;
-      // e->dump();
     }
     llvm_types_->structs_by_type_key[metadata->GetTypeKey()] =
         llvm::StructType::create(*ctx_, elements_, metadata->get_c_struct_name());
@@ -1120,18 +1110,6 @@ class MetadataSerializerLLVM : public AttrVisitor {
     auto struct_ty = llvm_types_->structs_by_type_key[metadata->GetTypeKey()];
     ICHECK(struct_ty != nullptr) << "Did not find LLVM StructType* for type_key="
                                  << metadata->GetTypeKey();
-    std::string ty_value;
-    llvm::raw_string_ostream ty_os(ty_value);
-    struct_ty->print(ty_os, true);
-    LOG(INFO) << "Get LLVM ConstantStruct (" << struct_elements.size() << " elements)";
-    LOG(INFO) << " Type (" << metadata->GetTypeKey() << "==" << struct_ty->getName().data()
-              << "): " << ty_value;
-    for (auto e : struct_elements) {
-      std::string value;
-      llvm::raw_string_ostream os(value);
-      e->print(os);
-      LOG(INFO) << " - " << value;
-    }
     CHECK_EQ(struct_elements.size(), struct_ty->getNumElements());
     auto out = llvm::ConstantStruct::get(struct_ty, struct_elements);
     if (elements_.size() > 0) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 172eb5ef1019..8675b824a914 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -406,6 +406,17 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
                                        unsigned int shared_address_space, int alignment,
                                        llvm::GlobalValue::LinkageTypes linkage);
 
+  /*!
+   * \brief Get the `i`th argument to the given function, respecting LLVM API changes.
+   *
+   * NOTE: in LLVM < 10.0, the underlying API returns a const llvm::Argument*. To provide a uniform
+   * API, const is removed here. Proper usage of LLVM APIs depends on having a non-const Argument*,
+   * so we take this approach here rather than adding const.
+   *
+   * \param function The function containing the arguments.
+   * \param i The index of the argument to retrieve.
+   * \return The retrieved argument.
+   */
   llvm::Argument* GetArg(const llvm::Function* function, int i) const {
 #if TVM_LLVM_VERSION >= 100
     return function->getArg(i);
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 066e091e637c..ab679bdedd1f 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -554,11 +554,6 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata
       << "LLVM module verification failed with the following errors: \n"
       << verify_errors.str();
 
-  // std::string tmp;
-  // llvm::raw_string_ostream stream(tmp);
-  // mod->print(stream, nullptr);
-  // LOG(INFO) << "LLVM metadata IR: " << stream.str();
-
   auto n = make_object<LLVMModuleNode>();
   n->Init(std::move(mod), ctx);
 
diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc
index 70b1896fec0e..5457946322c3 100644
--- a/src/target/metadata_module.cc
+++ b/src/target/metadata_module.cc
@@ -144,7 +144,7 @@ static runtime::Module CreateCppMetadataModule(
     auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata);
     metadata_module->Import(target_module);
     target_module = metadata_module;
-#ifdef TVM_LLVM_VERSION
+#ifdef TVM_LLVM_VERSION  // defining TVM_LLVM_VERSION indicates TVM was compiled with USE_LLVM ON.
   } else if (target->kind->name == "llvm") {
     auto metadata_module = CreateLLVMCppMetadataModule(runtime_metadata, target, runtime);
     metadata_module->Import(target_module);
diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h
index 1486a885e502..d3f24e0888ab 100644
--- a/src/target/metadata_utils.h
+++ b/src/target/metadata_utils.h
@@ -37,10 +37,30 @@
 
 namespace tvm {
 namespace codegen {
-
 namespace metadata {
+/*!
+ * \brief Construct a unique string "address" for a struct member from a vector of pieces.
+ *
+ * In codegen, it is frequently necessary to assemble a C-style identifier for an
+ * otherwise-anonymous member of Metadata. For instance, suppose Metadata declares an array:
+ *   struct TVMMetadata {
+ *     int64_t* shape;
+ *   };
+ *
+ * In order to properly initialize this struct, the array must be declared separately with a global name.
+ * This function produces such a name, here termed "address."
+ *
+ * \param parts A vector of pieces, typically the struct member names which identify the path to
+ * this member.
+ * \return The joined pieces.
+ */
 std::string address_from_parts(const std::vector<std::string>& parts);
+
+/*!
+ * \brief A prefix in metadata symbol names.
+ * This prefix is typically given to address_from_parts as the 0th item in parts.
+ */
 static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata";
 
 /*!

From d756d79d71c78c0c08b55be029a753ce7253ed34 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Tue, 5 Apr 2022 16:40:24 -0700
Subject: [PATCH 20/32] Fix convert_pool_allocations_to_offsets test.

---
 ...orm_convert_pool_allocations_to_offsets.py | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 4ed02615cd44..99fff94d9c73 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -141,8 +141,8 @@ class LinearStructurePlanned:
     @T.prim_func
     def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None:
-        fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
-        slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8",
strides=[1], elem_offset=1, align=16) + fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -174,8 +174,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -186,8 +186,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) - slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -371,7 +371,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -383,7 +383,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_26 = 
T.match_buffer(placeholder_24, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -406,7 +406,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -429,7 +429,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -451,7 +451,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): @@ -469,7 +469,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla @T.prim_func def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr[T.uint8], output: T.handle) -> None: - global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) From 65635346a6025dc5536bd922602fa07073a2864c Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 7 Apr 2022 11:29:58 -0700 Subject: [PATCH 21/32] lint --- 
 src/relay/backend/aot_executor_codegen.cc     |  9 ++--
 src/target/metadata_utils.h                   |  4 +-
 .../unittest/test_aot_legalize_packed_call.py | 48 +++++++++++++++----
 3 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index eba4260e7f41..4f3b7644ef62 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -399,7 +399,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
    * returns the passed Call
    */
   tir::Call AddCheckReturn(tir::Call existing_call) {
-    Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()), tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call};
+    Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()),
+                            tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call};
     return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args);
   }
 
@@ -507,9 +508,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     te::Var loop_idx("i", DataType::Int(32));
     auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
     // Copy the variable from the input to the output
-    tir::Stmt copy =
-        tir::For(loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
-                 tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
+    tir::Stmt copy = tir::For(
+        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
+        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
     stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
   }
 
diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h
index d3f24e0888ab..3ad05b6dcb18 100644
--- a/src/target/metadata_utils.h
+++ b/src/target/metadata_utils.h
@@ -48,8 +48,8 @@ namespace metadata {
  *   struct TVMMetadata {
  *     int64_t* shape;
  *   };
 *
- * In order to properly initialize this struct, the array must be declared separately with a global name.
- * This function produces such a name, here termed "address."
+ * In order to properly initialize this struct, the array must be declared separately with a global
+ * name. This function produces such a name, here termed "address."
 *
 * \param parts A vector of pieces, typically the struct member names which identify the path to
 * this member.
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index a3d18c95475c..c7c0daa30e2f 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -25,7 +25,9 @@ @tvm.script.ir_module class Module: @T.prim_func - def tvm_test_cpacked(A : T.handle, B: T.handle, C: T.handle, device_context: T.handle) -> T.handle: + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: A_0 = T.match_buffer(A, (1,), dtype="float32") A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") B_0 = T.match_buffer(B, (1,), dtype="float32") @@ -56,7 +58,9 @@ def tir_packed_call() -> None: @tvm.script.ir_module class Expected: @T.prim_func - def tvm_test_cpacked(A : T.handle, B: T.handle, C: T.handle, device_context: T.handle) -> T.handle: + def tvm_test_cpacked( + A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + ) -> T.handle: A_0 = T.match_buffer(A, (1,), dtype="float32") A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") B_0 = T.match_buffer(B, (1,), dtype="float32") @@ -73,12 +77,40 @@ def tir_packed_call() -> None: device_context = T.var("handle") # body - T.evaluate(T.tvm_call_cpacked("tvm_test_cpacked", - T.tvm_stack_make_array(A, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), - T.tvm_stack_make_array(B, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), - T.tvm_stack_make_array(C, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"), - device_context, - dtype="int32")) + T.evaluate( + T.tvm_call_cpacked( + "tvm_test_cpacked", + T.tvm_stack_make_array( + A, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + T.tvm_stack_make_array( + B, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + T.tvm_stack_make_array( + C, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.cast(0, dtype="float32"), + 0, + dtype="handle", + ), + device_context, + dtype="int32", + ) + ) def test_aot_packed_call(): From e39deedbeb5c1f7e110d178c752a06aea46a462e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 7 Apr 2022 15:16:43 -0700 Subject: [PATCH 22/32] Fix T.preflattened_buffer --- python/tvm/script/tir/special_stmt.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index ea523f1810ec..9cb83817da84 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -894,12 +894,30 @@ def preflattened_buffer( for key, value in self.context.func_buffer_map.items(): if value.same_as(postflattened): param = key + break assert ( param is not None ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map." 
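+            # If no explicit data pointer was given, default to the data Var of
+            # the flattened buffer bound to this parameter.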
+            if data is None:
+                data = self.context.func_buffer_map[param].data
+
             buffer_name: str = f"{postflattened.name}_preflatten"
+            if align != -1:
+                if isinstance(align, IntImm):
+                    align = align.value
+                else:
+                    assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
+
+            if offset_factor != 0:
+                if isinstance(offset_factor, IntImm):
+                    offset_factor = offset_factor.value
+                else:
+                    assert isinstance(
+                        offset_factor, int
+                    ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
+
             preflattened = tvm.tir.decl_buffer(
                 shape,
                 dtype,

From f2138d5cb8a2a502ec30ed95780ccba250ec241c Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Thu, 7 Apr 2022 15:17:55 -0700
Subject: [PATCH 23/32] Add preflattened_buffer_map to TIRTextPrinter

---
 src/printer/tir_text_printer.cc | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 1ef62c257648..fe829016b6b5 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
     doc << Doc::Indent(
         2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
   }
+
+  if (op->preflattened_buffer_map.size() != 0) {
+    // print preflattened_buffer_map
+    std::vector<Doc> preflattened_buffer_map_doc;
+    for (auto& v : op->preflattened_buffer_map) {
+      preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second));
+    }
+    doc << Doc::Indent(2, Doc::NewLine()
+                              << "preflattened_buffer_map = {"
+                              << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}");
+  }
   doc << PrintBody(op->body);
   return doc;
 }

From c257f7fc0f742663b53038855e337b352799e684 Mon Sep 17 00:00:00 2001
From: Andrew Reusch
Date: Thu, 7 Apr 2022 15:18:08 -0700
Subject: [PATCH 24/32] Fix tests

---
 ...orm_convert_pool_allocations_to_offsets.py | 71 ++++++++++++++++++-
 1 file changed, 70 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 99fff94d9c73..ce8675f575ee 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -74,8 +74,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
     placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+    T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
     placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     # body
     for ax0_ax1_fused_1 in T.serial(0, 224):
         for ax2_1, ax3_inner_1 in T.grid(224, 3):
@@ -86,9 +89,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias":
True}) placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -108,7 +115,9 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -140,7 +149,7 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: + def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory_1_var: T.Ptr[T.uint8], output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body @@ -155,9 +164,13 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") + T.preflattened_buffer(placeholder_29, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") + T.preflattened_buffer(T_cast_7, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -172,10 
+185,15 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") + T.preflattened_buffer(placeholder_4, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") + T.preflattened_buffer(placeholder_5, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") + T.preflattened_buffer(T_subtract_1, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -183,11 +201,17 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") + T.preflattened_buffer(placeholder_65, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") + T.preflattened_buffer(placeholder_66, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") + T.preflattened_buffer(placeholder_67, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") + T.preflattened_buffer(T_cast_21, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -251,8 +275,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") + T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") + T.preflattened_buffer(T_cast_1, [215], dtype="int16") # body for 
ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -262,9 +289,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") + T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -283,9 +314,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") + T.preflattened_buffer(T_add_1, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -305,10 +340,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") + T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -345,9 +385,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = 
T.match_buffer(placeholder_4, [360000], dtype="int16") + T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") + T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): @@ -369,9 +413,13 @@ class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") + T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") + T.preflattened_buffer(T_cast_1, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -379,11 +427,17 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") + T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") + T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") + T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") + T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") + T.preflattened_buffer(T_cast_7, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -403,10 +457,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") + T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") + T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") + T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") + T.preflattened_buffer(T_add_1, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -426,10 +485,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") + T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") + T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") + T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") + T.preflattened_buffer(T_cast_3, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -448,10 +512,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") + T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") + T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") + T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") + T.preflattened_buffer(T_cast_5, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) + T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], 
elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): From 4705a18f4aec97bdd8b3734698e9b781a66fe1b3 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 7 Apr 2022 16:28:22 -0700 Subject: [PATCH 25/32] Fix BYOC --- src/tir/transforms/legalize_packed_calls.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index c8b1377162ef..43cb1fb03fa2 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -71,8 +71,13 @@ class PackedCallLegalizer : public StmtExprMutator { // all such stack-allocated tensors and minimize the storage needed by reusing // DLTensors. Array call_args{call->args[i]}; + tvm::runtime::Map::iterator param_buf_it; if (prim_func != nullptr) { - Buffer param = prim_func->preflattened_buffer_map[prim_func->params[i - 1]]; + auto param_var = prim_func->params[i - 1]; + param_buf_it = prim_func->preflattened_buffer_map.find(param_var); + } + if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) { + Buffer param = (*param_buf_it).second; PrimExpr shape = tvm::tir::Call( DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); From 9642548e1f11c565dbca337ca0b1c16ca75fb1d7 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Fri, 8 Apr 2022 12:08:23 -0700 Subject: [PATCH 26/32] Fix invoking C device API. --- src/relay/backend/aot_executor_codegen.cc | 8 +++---- tests/python/relay/aot/test_c_device_api.py | 23 ++++++--------------- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 4f3b7644ef62..c2b2ac0fc5e2 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -438,7 +438,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { GlobalVar global_var = call_lowered_props.lowered_func; bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - bool call_device_hooks = false; tir::Var device_context; tir::Stmt func_call; @@ -480,7 +479,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { ICHECK(func_call.defined()) << "Must define func_call"; - if (call_device_hooks) { + if (has_c_device_api_context) { func_call = tir::SeqStmt(Array({ GenerateDeviceHook(device_context, "Open"), func_call, @@ -582,8 +581,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { Array sections = {"Device", device_name, hook}; String device_hook = ToCFunctionStyle(PrefixName(sections)); - return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), - {tvm::tir::StringImm(device_hook), context})); + return tir::Evaluate( + AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook), context}))); } /*! 
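[Editor's note] Patch 26 above wraps each generated C Device API hook call in AddCheckReturn, so a failing hook now aborts the AOT entrypoint instead of having its status silently ignored. For orientation, here is a minimal sketch of a hook an application could link against; the void* context type and the body are illustrative assumptions — only the naming convention (ToCFunctionStyle over {"Device", device_name, hook}) and the zero-on-success contract come from the code above and the test expectations that follow.

#include <cstdint>

// Illustrative sketch only: a hook satisfying the generated call
//   tvm_check_return(0, -1, call_extern("TVMDeviceEthosUOpen", device_context_ethos_u))
// asserted in test_c_device_api.py below. The context type is a stand-in.
extern "C" int32_t TVMDeviceEthosUOpen(void* context) {
  if (context == nullptr) {
    return -1;  // any non-zero value trips the tvm_check_return guard in the caller
  }
  // Claim or power up the accelerator here.
  return 0;
}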
diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 6a12a38d35c2..d547b52e85c3 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -143,6 +143,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): + " device_context_ethos_u))\n" ) # Open Device + print("main func", repr(main_func.body)) assert ( str(main_func.body[1][0][0][0]) == "tir.tvm_check_return(0, -1, tir.call_extern(" @@ -239,23 +240,11 @@ def test_without_device_api_packed_api(non_device_api_main_func): main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) assert str(main_func.body) == ( - 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' - 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' - "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n" - "tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n" - "tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n" - 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", ' + "tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " + "tir.reinterpret((uint64)0))\n" ) From 66f08987331a93c8cf2fb29ee3155bf2170e9d1c Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sun, 10 Apr 2022 22:11:10 -0700 Subject: [PATCH 27/32] remove comments --- src/target/llvm/codegen_llvm.cc | 2 -- src/target/llvm/codegen_llvm.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 47cb3dd711d7..342204032403 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1411,11 +1411,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - // LOG(INFO) << "Visit Call:" << GetRef(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); if (op->op.same_as(builtin_lookup_param_)) { - // return llvm::ConstantInt::get(t_void_p_, 0); return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 8675b824a914..9d79ec572e8f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -143,7 +143,7 @@ class CodeGenLLVM : public ExprFunctor, * \return created value. 
*/ llvm::Value* MakeValue(const PrimExpr& e) { - auto a = VisitExpr(e); /* LOG(INFO) << "MakeValue (" << e << "): " << a; */ + auto a = VisitExpr(e); return a; } // Short hande code to get a constant int 32 From 1b36e6ecd325120ab5406f5fca647a45695f3b50 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 13 Apr 2022 12:15:04 -0700 Subject: [PATCH 28/32] Address Mousius comments --- CMakeLists.txt | 11 +- python/tvm/testing/tir.py | 11 ++ src/target/llvm/codegen_llvm.h | 5 +- src/target/metadata_utils.cc | 8 +- src/target/metadata_utils.h | 11 +- src/target/source/source_module.cc | 4 +- tests/cpp/aot_metadata_test.cc | 145 +++++++++++++++--- .../unittest/test_tvmscript_error_report.py | 14 ++ .../unittest/test_tvmscript_syntax_sugar.py | 13 ++ 9 files changed, 177 insertions(+), 45 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e59a112fab04..2bd33d249fb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,15 @@ if(USE_GTEST) find_package(GTest REQUIRED) endif() if(GTEST_FOUND) + # GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory, + # and require that folks compiling against GMock::GMock also link against GTest::GTest + # (for the includes dir). + add_library(GMock::GMock STATIC IMPORTED GLOBAL) + get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION) + get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY) + set_target_properties(GMock::GMock PROPERTIES + IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a") + enable_testing() include(CTest) endif() @@ -626,7 +635,7 @@ if(GTEST_FOUND) add_executable(cpptest ${TEST_SRCS}) # include runtime files for unit testing target_include_directories(cpptest PUBLIC "src/runtime") - target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl) + target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GMock::GMock pthread dl) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) # For some reason, compile definitions are not propagated correctly, so we manually add them here diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index f9115fc61bfa..d208a42b8360 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -17,10 +17,14 @@ # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Common utility functions in TVM tir""" import inspect +import re import tvm from tvm.ir.diagnostics import override_renderer +CHECK_ERROR_RE = re.compile(r'^.*# check_error: (.+)$') + + def check_error(func, rel_lineno): """check if TIR script throws error""" # Override the default renderer to accumulate errors @@ -46,3 +50,10 @@ def render(e): assert ( d.span.line - 1 == rel_lineno ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" + + error_line = source_code.split('\n')[rel_lineno] + m = CHECK_ERROR_RE.match(error_line) + if m: + expected_error_text = m.group(1) + errors = [e.message for e in errors] + assert expected_error_text in errors, f'check_error expects "{expected_error_text}" in str(errors): {errors}' diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 9d79ec572e8f..3c80952b1fd7 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -142,10 +142,7 @@ class CodeGenLLVM : public ExprFunctor, * \param e The expression to be created value for. * \return created value.
*/ - llvm::Value* MakeValue(const PrimExpr& e) { - auto a = VisitExpr(e); - return a; - } + llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); diff --git a/src/target/metadata_utils.cc b/src/target/metadata_utils.cc index 1034fd570007..db17d1862846 100644 --- a/src/target/metadata_utils.cc +++ b/src/target/metadata_utils.cc @@ -27,9 +27,7 @@ namespace tvm { namespace codegen { namespace metadata { -DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector* queue) : queue_{queue} {} - -std::string address_from_parts(const std::vector& parts) { +std::string AddressFromParts(const std::vector& parts) { std::stringstream ss; for (unsigned int i = 0; i < parts.size(); ++i) { if (i > 0) { @@ -40,6 +38,8 @@ std::string address_from_parts(const std::vector& parts) { return ss.str(); } +DiscoverArraysVisitor::DiscoverArraysVisitor(std::vector* queue) : queue_{queue} {} + void DiscoverArraysVisitor::Visit(const char* key, double* value) {} void DiscoverArraysVisitor::Visit(const char* key, int64_t* value) {} void DiscoverArraysVisitor::Visit(const char* key, uint64_t* value) {} @@ -69,7 +69,7 @@ void DiscoverArraysVisitor::Visit(const char* key, ObjectRef* value) { } } - queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + queue_->push_back(std::make_tuple(AddressFromParts(address_parts_), Downcast(metadata))); } else { ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); diff --git a/src/target/metadata_utils.h b/src/target/metadata_utils.h index 3ad05b6dcb18..977a0f412bb5 100644 --- a/src/target/metadata_utils.h +++ b/src/target/metadata_utils.h @@ -55,11 +55,11 @@ namespace metadata { * this member. * \return The joined pieces. */ -std::string address_from_parts(const std::vector& parts); +std::string AddressFromParts(const std::vector& parts); /*! * \brief A prefix in metadata symbol names. - * This prefix is typically given to address_from_parts as the 0th item in parts. + * This prefix is typically given to AddressFromParts as the 0th item in parts. */ static constexpr const char* kMetadataGlobalSymbol = "kTvmgenMetadata"; @@ -103,13 +103,6 @@ class DiscoverArraysVisitor : public AttrVisitor { */ class DiscoverComplexTypesVisitor : public AttrVisitor { public: - /*! \brief Models a single complex type discovered in this visitor. - * Contains two fields: - * 0. The struct_name for this Metadata instance. - * 1. The discovered MetadataArray. - */ - using DiscoveredComplexType = std::tuple; - /*! \brief Construct a new instance. * \param queue An ordered map which holds the */ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 635775a7777d..ef5755f3e84b 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -655,7 +655,7 @@ class MetadataSerializer : public AttrVisitor { if (key != nullptr) { address_.push_back(key); } - code_ << metadata::address_from_parts(address_); + code_ << metadata::AddressFromParts(address_); if (key != nullptr) { address_.pop_back(); } @@ -733,7 +733,7 @@ class MetadataSerializer : public AttrVisitor { // Finally, emit overall struct. 
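// [Editor's illustration, not part of this hunk] With the rename above, the
// top-level symbol comes from AddressFromParts({kMetadataGlobalSymbol}) and
// each array found by DiscoverArraysVisitor gets an '_'-joined derivative,
// so the emitted C source looks roughly like (field values elided, assumed):
//
//   const int64_t kTvmgenMetadata_inputs_0_shape[] = {/* ... */};
//   const struct TVMTensorInfo kTvmgenMetadata_inputs[] = {/* ... */};
//   const struct TVMMetadata kTvmgenMetadata = {/* ... */};
//
// The naming scheme matches the symbols asserted in the DiscoverArraysVisitor
// test added to aot_metadata_test.cc later in this series.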
address_.push_back(metadata::kMetadataGlobalSymbol); - code_ << "const struct TVMMetadata " << metadata::address_from_parts(address_) << " = {" + code_ << "const struct TVMMetadata " << metadata::AddressFromParts(address_) << " = {" << std::endl; Visit(nullptr, &metadata); code_ << "};" << std::endl; diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index 0fa03af3b738..28dab3b33582 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,6 +24,7 @@ #include #include "../src/target/metadata.h" +#include "../src/target/metadata_utils.h" namespace { @@ -46,12 +46,28 @@ const struct TVMMetadata kNormal = { } // namespace using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; +using ::testing::Matcher; +using ::testing::MatcherInterface; +using ::testing::MatchResultListener; using ::testing::StrEq; + +using ::tvm::codegen::metadata::DiscoverArraysVisitor; +using ::tvm::codegen::metadata::DiscoverComplexTypesVisitor; +using ::tvm::codegen::metadata::kMetadataGlobalSymbol; + +using ::tvm::runtime::Array; using ::tvm::runtime::Downcast; +using ::tvm::runtime::ObjectRef; + +using ::tvm::runtime::metadata::TensorInfo; +using ::tvm::runtime::metadata::Metadata; +using ::tvm::runtime::metadata::MetadataArray; +using ::tvm::runtime::metadata::MetadataKind; TEST(Metadata, ParseStruct) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); EXPECT_THAT(md->num_inputs(), Eq(2)); @@ -137,7 +153,7 @@ class TestVisitor : public tvm::AttrVisitor { }; TEST(Metadata, Visitor) { - tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + Metadata md = Metadata(&kNormal); TestVisitor v; ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); @@ -149,17 +165,17 @@ TEST(Metadata, Visitor) { EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); // Just identify the tensor. 
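// [Editor's note] ReflectionVTable::Global()->VisitAttrs presents MetadataNode
// fields to the visitor in declaration order, so the assertions below can
// index v.values positionally: values[0] is version, values[1] the inputs
// array, values[2] num_inputs, values[3] the outputs array, values[4]
// num_outputs, and values[5] the pools array.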
- auto input_array = Downcast(v.values[1]); - EXPECT_THAT(input_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->kind, Eq(MetadataKind::kMetadata)); EXPECT_THAT(input_array->type_key, StrEq("metadata.TensorInfoNode")); EXPECT_THAT(input_array->array.size(), Eq(2)); - auto input1 = Downcast(input_array->array[0]); + auto input1 = Downcast(input_array->array[0]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); - auto input2 = Downcast(input_array->array[1]); + auto input2 = Downcast(input_array->array[1]); EXPECT_THAT(input1->name(), StrEq("input1")); EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); @@ -167,20 +183,20 @@ TEST(Metadata, Visitor) { auto num_inputs = Downcast(v.values[2]); EXPECT_THAT(num_inputs->value, Eq(2)); - auto output_array = Downcast(v.values[3]); - EXPECT_THAT(output_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + auto output_array = Downcast(v.values[3]); + EXPECT_THAT(output_array->kind, Eq(MetadataKind::kMetadata)); EXPECT_THAT(output_array->type_key, StrEq("metadata.TensorInfoNode")); - auto output1 = Downcast(output_array->array[0]); + auto output1 = Downcast(output_array->array[0]); EXPECT_THAT(output1->name(), Eq("output1")); auto num_outputs = Downcast(v.values[4]); EXPECT_THAT(num_outputs->value, Eq(1)); - auto pool_array = Downcast(v.values[5]); - EXPECT_THAT(pool_array->kind, Eq(tvm::runtime::metadata::MetadataKind::kMetadata)); + auto pool_array = Downcast(v.values[5]); + EXPECT_THAT(pool_array->kind, Eq(MetadataKind::kMetadata)); EXPECT_THAT(pool_array->type_key, StrEq("metadata.TensorInfoNode")); - auto pool1 = Downcast(pool_array->array[0]); + auto pool1 = Downcast(pool_array->array[0]); EXPECT_THAT(pool1->name(), Eq("pool1")); @@ -193,23 +209,22 @@ TEST(Metadata, Visitor) { using ::tvm::runtime::make_object; TEST(Metadata, InMemory) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( + Metadata md = Metadata(make_object( TVM_METADATA_VERSION, - std::vector( - {tvm::runtime::metadata::TensorInfo( + std::vector( + {TensorInfo( make_object( tvm::String("Input1"), std::vector{1, 5, 5, 3}, tvm::runtime::DataType(DLDataType{1, 2, 3}))), - tvm::runtime::metadata::TensorInfo( + TensorInfo( make_object( tvm::String("Input2"), std::vector{1, 5, 5, 3}, tvm::runtime::DataType(DLDataType{2, 3, 4})))}), - std::vector({tvm::runtime::metadata::TensorInfo( + std::vector({TensorInfo( make_object( tvm::String("Output1"), std::vector{3, 8, 8}, tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({tvm::runtime::metadata::TensorInfo( + std::vector({TensorInfo( make_object( tvm::String("Pool1"), std::vector{5, 10, 10}, tvm::runtime::DataType(DLDataType{3, 4, 7})))}), @@ -251,14 +266,14 @@ TEST(Metadata, InMemory) { } TEST(Metadata, ZeroElementLists) { - tvm::runtime::metadata::Metadata md = - tvm::runtime::metadata::Metadata(make_object( - TVM_METADATA_VERSION, std::vector({}), - std::vector({tvm::runtime::metadata::TensorInfo( + Metadata md = + Metadata(make_object( + TVM_METADATA_VERSION, std::vector({}), + std::vector({TensorInfo( make_object( tvm::String("Output1"), std::vector{}, tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({}), "default")); + std::vector({}), "default")); 
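// [Editor's note] The zero-element case is worth pinning down separately:
// empty input/pool lists must still round-trip to a well-formed TVMMetadata
// (count fields of 0, accessors yielding empty ranges) rather than tripping
// on null data pointers, which is what the assertions below check.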
EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -274,3 +289,83 @@ TEST(Metadata, ZeroElementLists) { EXPECT_THAT(md->num_pools(), Eq(0)); EXPECT_THAT(md->pools(), ElementsAre()); } + +TEST(MetadataArray, GetElementCStructName) { + MetadataArray arr_struct{make_object( + Array(), MetadataKind::kMetadata, "metadata.FooMetadataNode")}; + EXPECT_THAT(arr_struct->kind, Eq(MetadataKind::kMetadata)); + EXPECT_THAT(arr_struct->get_element_c_struct_name(), StrEq("TVMFooMetadata")); + + MetadataArray arr_int{make_object( + Array(), MetadataKind::kInt64, nullptr)}; + EXPECT_THROW(arr_int->get_element_c_struct_name(), std::runtime_error); +} + +namespace { +std::string ExplainDiscoveredNameEq(bool negation, std::string expected_name) { + std::stringstream ss; + ss << "std::get<0>(discovered_array) " << (negation ? "isn't" : "is") << " equal to " << expected_name; + return ss.str(); +} +} + +MATCHER_P(DiscoveredNameEq, expected_name, ExplainDiscoveredNameEq(negation, expected_name)) { + return std::string{std::get<0>(arg)} == expected_name; +} + +TEST(DiscoverArraysVisitor, DiscoverArrays) { + std::vector q; + DiscoverArraysVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Visit(kMetadataGlobalSymbol, &md); + + EXPECT_THAT(q, ElementsAreArray( + {DiscoveredNameEq("kTvmgenMetadata_inputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs_1_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs"), + DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_outputs"), + DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_pools")})); +} + +template ::value, bool> = true> +class TVMObjectIsInstanceMatcher : public MatcherInterface { + public: + using is_gtest_matcher = void; + + bool MatchAndExplain(tvm::runtime::metadata::MetadataBase arg, MatchResultListener* os) const override { + bool result = arg->IsInstance(); + if (!result) { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + return result; + } + + void DescribeTo(std::ostream* os) const override { + (*os) << "is an instance of type " << T::ContainerType::_type_key; + } + + void DescribeNegationTo(std::ostream* os) const override { + (*os) << "is not an instance of type " << T::ContainerType::_type_key; + } +}; + +template +Matcher TVMObjectIsInstance() { + return Matcher(new TVMObjectIsInstanceMatcher()); +} + +TEST(DiscoverComplexTypesVisitor, DiscoverComplexTypes) { + std::vector q; + DiscoverComplexTypesVisitor visitor(&q); + + Metadata md = Metadata(&kNormal); + visitor.Discover(md); + + EXPECT_THAT(q, ElementsAre( + TVMObjectIsInstance(), + TVMObjectIsInstance())); +} diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 73be9d8cdc58..ea0fd01d2e66 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -635,6 +635,20 @@ def non_integer_typed_block_iter(): def test_non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) +def preflattened_buffer_map_align_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + T.preflattened_buffer(foo_1, [1], align="bar") # check_error: align: want int or IntImm, got 'bar' + +def test_preflattened_buffer_map_align(): + check_error(preflattened_buffer_map_align_nonint, 3) + +def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): + foo_1 = T.match_buffer(foo, [1]) + 
T.preflattened_buffer(foo_1, [1], offset_factor="bar") # check_error: offset_factor: want int or IntImm, got 'bar' + +def test_preflattened_buffer_map_offset_factor(): + check_error(preflattened_buffer_map_offset_factor_nonint, 3) + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 0e77b2a49454..c17c7b4348e4 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -180,6 +180,19 @@ def test_dynamic_shape_gemm(): gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script()) assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) +@T.prim_func +def preflattened_buffer_map(A: T.handle, B: T.handle): + A_1 = T.match_buffer(A, [1]) + T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2)) + B_1 = T.match_buffer(B, [1]) + T.preflattened_buffer(B_1, [1]) + B_1[0] = A_1[0] + +def test_preflattened_buffer_map(): + A_var = [k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A"][0] + assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 + assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 792245a894d719d046f654d9fbc765e23df0fa91 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 13 Apr 2022 12:20:18 -0700 Subject: [PATCH 29/32] lint --- python/tvm/testing/tir.py | 8 +- tests/cpp/aot_metadata_test.cc | 88 +++++++++---------- .../unittest/test_tvmscript_error_report.py | 12 ++- .../unittest/test_tvmscript_syntax_sugar.py | 6 +- 4 files changed, 63 insertions(+), 51 deletions(-) diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index d208a42b8360..cedaafe80a52 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -22,7 +22,7 @@ from tvm.ir.diagnostics import override_renderer -CHECK_ERROR_RE = re.compile(r'^.*# check_error: (.+)$') +CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$") def check_error(func, rel_lineno): @@ -51,9 +51,11 @@ def render(e): assert ( d.span.line - 1 == rel_lineno ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" - error_line = source_code.split('\n')[rel_lineno] + error_line = source_code.split("\n")[rel_lineno] m = CHECK_ERROR_RE.match(error_line) if m: expected_error_text = m.group(1) errors = [e.message for e in errors] - assert expected_error_text in errors, f'check_error expects "{expected_error_text}" in str(errors): {errors}' + assert ( + expected_error_text in errors + ), f'check_error expects "{expected_error_text}" in str(errors): {errors}' diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index 28dab3b33582..a5b3cd0fac14 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -61,10 +61,10 @@ using ::tvm::runtime::Array; using ::tvm::runtime::Downcast; using ::tvm::runtime::ObjectRef; -using ::tvm::runtime::metadata::TensorInfo; using ::tvm::runtime::metadata::Metadata; using ::tvm::runtime::metadata::MetadataArray; using ::tvm::runtime::metadata::MetadataKind; +using ::tvm::runtime::metadata::TensorInfo; TEST(Metadata, ParseStruct) { Metadata md = Metadata(&kNormal); @@ -210,25 +210,23 @@ using ::tvm::runtime::make_object; TEST(Metadata, InMemory) { Metadata md = Metadata(make_object( -
TVM_METADATA_VERSION, - std::vector( - {TensorInfo( - make_object( - tvm::String("Input1"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{1, 2, 3}))), - TensorInfo( - make_object( - tvm::String("Input2"), std::vector{1, 5, 5, 3}, - tvm::runtime::DataType(DLDataType{2, 3, 4})))}), - std::vector({TensorInfo( - make_object( - tvm::String("Output1"), std::vector{3, 8, 8}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({TensorInfo( - make_object( - tvm::String("Pool1"), std::vector{5, 10, 10}, - tvm::runtime::DataType(DLDataType{3, 4, 7})))}), - "default")); + TVM_METADATA_VERSION, + std::vector( + {TensorInfo(make_object( + tvm::String("Input1"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{1, 2, 3}))), + TensorInfo(make_object( + tvm::String("Input2"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{2, 3, 4})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{3, 8, 8}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector( + {TensorInfo(make_object( + tvm::String("Pool1"), std::vector{5, 10, 10}, + tvm::runtime::DataType(DLDataType{3, 4, 7})))}), + "default")); auto md_data = md->data(); EXPECT_THAT(md_data->version, Eq(TVM_METADATA_VERSION)); @@ -266,14 +264,13 @@ TEST(Metadata, InMemory) { } TEST(Metadata, ZeroElementLists) { - Metadata md = - Metadata(make_object( - TVM_METADATA_VERSION, std::vector({}), - std::vector({TensorInfo( - make_object( - tvm::String("Output1"), std::vector{}, - tvm::runtime::DataType(DLDataType{3, 4, 5})))}), - std::vector({}), "default")); + Metadata md = Metadata(make_object( + TVM_METADATA_VERSION, std::vector({}), + std::vector( + {TensorInfo(make_object( + tvm::String("Output1"), std::vector{}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + std::vector({}), "default")); EXPECT_THAT(md->data()->num_inputs, Eq(0)); EXPECT_THAT(md->inputs().size(), Eq(0)); @@ -304,13 +301,14 @@ TEST(MetadataArray, GetElementCStructName) { namespace { std::string ExplainDiscoveredNameEq(bool negation, std::string expected_name) { std::stringstream ss; - ss << "std::get<0>(discovered_array) " << (negation ? "isn't" : "is") << " equal to " << expected_name; + ss << "std::get<0>(discovered_array) " << (negation ? 
"isn't" : "is") << " equal to " + << expected_name; return ss.str(); } -} +} // namespace MATCHER_P(DiscoveredNameEq, expected_name, ExplainDiscoveredNameEq(negation, expected_name)) { - return std::string{std::get<0>(arg)} == expected_name; + return std::string {std::get<0>(arg)} == expected_name; } TEST(DiscoverArraysVisitor, DiscoverArrays) { @@ -320,22 +318,24 @@ TEST(DiscoverArraysVisitor, DiscoverArrays) { Metadata md = Metadata(&kNormal); visitor.Visit(kMetadataGlobalSymbol, &md); - EXPECT_THAT(q, ElementsAreArray( - {DiscoveredNameEq("kTvmgenMetadata_inputs_0_shape"), - DiscoveredNameEq("kTvmgenMetadata_inputs_1_shape"), - DiscoveredNameEq("kTvmgenMetadata_inputs"), - DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), - DiscoveredNameEq("kTvmgenMetadata_outputs"), - DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), - DiscoveredNameEq("kTvmgenMetadata_pools")})); + EXPECT_THAT(q, ElementsAreArray({DiscoveredNameEq("kTvmgenMetadata_inputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs_1_shape"), + DiscoveredNameEq("kTvmgenMetadata_inputs"), + DiscoveredNameEq("kTvmgenMetadata_outputs_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_outputs"), + DiscoveredNameEq("kTvmgenMetadata_pools_0_shape"), + DiscoveredNameEq("kTvmgenMetadata_pools")})); } -template ::value, bool> = true> +template ::value, bool> = + true> class TVMObjectIsInstanceMatcher : public MatcherInterface { public: using is_gtest_matcher = void; - bool MatchAndExplain(tvm::runtime::metadata::MetadataBase arg, MatchResultListener* os) const override { + bool MatchAndExplain(tvm::runtime::metadata::MetadataBase arg, + MatchResultListener* os) const override { bool result = arg->IsInstance(); if (!result) { (*os) << "is an instance of type " << T::ContainerType::_type_key; @@ -344,7 +344,7 @@ class TVMObjectIsInstanceMatcher : public MatcherInterface(), - TVMObjectIsInstance())); + EXPECT_THAT(q, ElementsAre(TVMObjectIsInstance(), TVMObjectIsInstance())); } diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index ea0fd01d2e66..0610559a05d8 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -635,16 +635,24 @@ def non_integer_typed_block_iter(): def test_non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) + def preflattened_buffer_map_align_nonint(foo: T.handle): foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer(foo_1, [1], align="bar") # check_error: align: want int or IntImm, got 'bar' + T.preflattened_buffer( + foo_1, [1], align="bar" + ) # check_error: align: want int or IntImm, got 'bar' + def test_preflattened_buffer_map_align(): check_error(preflattened_buffer_map_align_nonint, 3) + def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer(foo_1, [1], offset_factor="bar") # check_error: offset_factor: want int or IntImm, got 'bar' + T.preflattened_buffer( + foo_1, [1], offset_factor="bar" + ) # check_error: offset_factor: want int or IntImm, got 'bar' + def test_preflattened_buffer_map_offset_factor(): check_error(preflattened_buffer_map_offset_factor_nonint, 3) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index c17c7b4348e4..d8fe8b75a108 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -180,6 +180,7 @@ def 
test_dynamic_shape_gemm(): gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script()) assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) + @T.prim_func def preflattened_buffer_map(A: T.handle, B: T.handle): A_1 = T.match_buffer(A, [1]) @@ -188,8 +189,11 @@ def preflattened_buffer_map(A: T.handle, B: T.handle): T.preflattened_buffer(B_1, [1]) B_1[0] = A_1[0] + def test_preflattened_buffer_map(): - A_var = [k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A"][0] + A_var = [ + k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A" + ][0] assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 From 131722e18fcde963b65d4d624d63fea54a971f72 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Wed, 13 Apr 2022 14:14:22 -0700 Subject: [PATCH 30/32] lint --- tests/cpp/aot_metadata_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc index a5b3cd0fac14..b1dea64aaa9c 100644 --- a/tests/cpp/aot_metadata_test.cc +++ b/tests/cpp/aot_metadata_test.cc @@ -308,7 +308,7 @@ std::string ExplainDiscoveredNameEq(bool negation, std::string expected_name) { } // namespace MATCHER_P(DiscoveredNameEq, expected_name, ExplainDiscoveredNameEq(negation, expected_name)) { - return std::string {std::get<0>(arg)} == expected_name; + return std::string(std::get<0>(arg)) == expected_name; } TEST(DiscoverArraysVisitor, DiscoverArrays) { From ee1877cc9d0c888349fc0110617b9245a5701bd3 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 18 Apr 2022 14:08:31 -0700 Subject: [PATCH 31/32] Fix GMock linking on new CMake --- CMakeLists.txt | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bd33d249fb3..09b9aeb4db3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,14 +431,24 @@ if(USE_GTEST) find_package(GTest REQUIRED) endif() if(GTEST_FOUND) - # GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory, - # and require that folks compiling against GMock::GMock also link against GTest::GTest - # (for the includes dir). - add_library(GMock::GMock STATIC IMPORTED GLOBAL) - get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION) - get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY) - set_target_properties(GMock::GMock PROPERTIES - IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a") + if(NOT TARGET GTest::gmock) + # GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory, + # and require that folks compiling against GTest::gmock also link against GTest::GTest + # (for the includes dir). + add_library(GTest::gmock STATIC IMPORTED GLOBAL) + get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + # CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in + # GTest::gtest. 
+ get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION) + if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND") + message(FATAL_ERROR "Neither GTest::GTest nor GTest::gtest targets defined IMPORTED_LOCATION") + endif() + endif() + get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY) + set_target_properties(GTest::gmock PROPERTIES + IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a") + endif() enable_testing() include(CTest) @@ -635,7 +645,7 @@ if(GTEST_FOUND) add_executable(cpptest ${TEST_SRCS}) # include runtime files for unit testing target_include_directories(cpptest PUBLIC "src/runtime") - target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GMock::GMock pthread dl) + target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) # For some reason, compile definitions are not propagated correctly, so we manually add them here From 374da005e1da414def82998bd639711c33f23a00 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 18 Apr 2022 14:19:18 -0700 Subject: [PATCH 32/32] address masahi comment --- src/target/llvm/codegen_llvm.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 3c80952b1fd7..7f84119345db 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -519,7 +519,6 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu }); for (auto& f : funcs) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - LOG(INFO) << "Adding " << static_cast(global_symbol.value()); AddFunction(f); } }
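[Editor's recap] Two GMock idioms carry the new C++ tests in this series: MATCHER_P for ad-hoc predicates (DiscoveredNameEq, which matches only the symbol half of each discovered (name, array) tuple), and a MatcherInterface subclass (TVMObjectIsInstance) for cases where the matched type must be a template parameter, something the plain MATCHER_P macro cannot express. A self-contained sketch of the first idiom, with hypothetical names, in case it is unfamiliar:

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for the (symbol, array) tuples that
// DiscoverArraysVisitor emits; only the tuple shape matters to the matcher.
using NamedThing = std::tuple<std::string, int>;

// Same shape as DiscoveredNameEq in patch 28: compare only the first tuple
// element, so tests can assert symbol order without spelling out payloads.
MATCHER_P(NameEq, expected, "") { return std::get<0>(arg) == expected; }

TEST(EditorSketch, MatchersComposeWithElementsAre) {
  std::vector<NamedThing> q;
  q.emplace_back("alpha", 1);
  q.emplace_back("beta", 2);
  EXPECT_THAT(q, ::testing::ElementsAre(NameEq("alpha"), NameEq("beta")));
}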