diff --git a/CMakeLists.txt b/CMakeLists.txt index c667750ed5d6..f56e8946d398 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON) tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF) +tvm_option(USE_AOT_EXECUTOR "Build with AOT executor" ON) tvm_option(USE_PROFILER "Build profiler for the VM and graph executor" ON) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) @@ -392,6 +393,13 @@ if(USE_PROFILER) list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS}) endif(USE_PROFILER) +if(USE_AOT_EXECUTOR) + message(STATUS "Build with AOT Executor support...") + file(GLOB RUNTIME_AOT_EXECUTOR_SRCS src/runtime/aot_executor/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_AOT_EXECUTOR_SRCS}) + +endif(USE_AOT_EXECUTOR) + # Enable ctest if gtest is available if(USE_GTEST) # Check env var for backward compatibility. A better way to specify package diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index c4cabf5a5548..50f0b07e5005 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -44,6 +44,12 @@ class AttrRegistry; namespace relay { +/*! \brief Value used with Runtime::name to indicate the C++ runtime. */ +static constexpr const char* kTvmRuntimeCpp = "cpp"; + +/*! \brief Value used with Runtime::name to indicate the C runtime. */ +static constexpr const char* kTvmRuntimeCrt = "crt"; + /*! * \brief Runtime information. * diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h new file mode 100644 index 000000000000..3fab610ef055 --- /dev/null +++ b/include/tvm/runtime/metadata.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata.h + * \brief Defines types which can be used in Metadata. + */ +#ifndef TVM_RUNTIME_METADATA_H_ +#define TVM_RUNTIME_METADATA_H_ + +#include + +#include +#include +#include + +// TODO(areusch): idk what's up here. 
+#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) + +#define TVM_METADATA_VERSION 1 +static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; +#ifdef __cplusplus +extern "C" { +#endif + +struct TVMMetadata { + int64_t version; + const struct TVMTensorInfo* inputs; + int64_t num_inputs; + const struct TVMTensorInfo* outputs; + int64_t num_outputs; + const char** devices; + int64_t num_devices; + const char* executor; + const char* mod_name; + const char* interface_api; + bool use_unpacked_api; +}; + +struct TVMTensorInfo { + const char* name; + const int64_t* shape; + int64_t num_shape; + DLDataType dtype; +}; +#ifdef __cplusplus +} // extern "C" +#include +namespace tvm { +namespace runtime { +namespace metadata { + +class Metadata; +class TensorInfo; + +class MetadataNode : public MetadataBaseNode { + public: + explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.MetadataNode"; + std::string get_name() override; + inline int64_t version() const { return int64_t(data_->version); } + inline int64_t num_inputs() const { return data_->num_inputs; } + ArrayAccessor inputs(); + inline int64_t num_outputs() const { return data_->num_outputs; } + ArrayAccessor outputs(); + inline int64_t num_devices() const { return data_->num_devices; } + ArrayAccessor devices(); + inline ::tvm::runtime::String executor() const { return ::tvm::runtime::String(data_->executor); } + inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } + inline ::tvm::runtime::String interface_api() const { + return ::tvm::runtime::String(data_->interface_api); + } + inline bool use_unpacked_api() const { return static_cast(data_->use_unpacked_api); } + const struct ::TVMMetadata* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); + + private: + const struct ::TVMMetadata* data_; +}; + +class Metadata : public MetadataBase { + public: + explicit Metadata(const struct ::TVMMetadata* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); +}; + +class TensorInfoNode : public MetadataBaseNode { + public: + explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.TensorInfoNode"; + std::string get_name() override; + inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } + inline int64_t num_shape() const { return data_->num_shape; } + inline ::tvm::support::Span shape() const { + return ::tvm::support::Span(data_->shape, + data_->shape + data_->num_shape); + } + inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } + const struct ::TVMTensorInfo* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); + + private: + const struct ::TVMTensorInfo* data_; +}; + +class TensorInfo : public MetadataBase { + public: + explicit TensorInfo(const struct ::TVMTensorInfo* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm +#endif // defined(__cplusplus) + +#endif // TVM_RUNTIME_METADATA_H_ diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h new file mode 100644 index 000000000000..b707bdf68a96 --- /dev/null +++ 
b/include/tvm/runtime/metadata_base.h @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata_base.h + * \brief Defines types which can be used in Metadata. + */ +#ifndef TVM_RUNTIME_METADATA_BASE_H_ +#define TVM_RUNTIME_METADATA_BASE_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +class MetadataBaseNode : public ::tvm::runtime::Object { + public: + virtual std::string get_name() = 0; + + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); +}; + +class MetadataBase : public ::tvm::runtime::ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); +}; + +template +class ArrayAccessor; + +template +class ArrayIterator { + public: + using value_type = Ref; + + ArrayIterator(size_t index, const ArrayAccessor* parent) : index_{index}, parent_{parent} {} + + inline Ref operator*() { return (*parent_)[index_]; } + + inline ArrayIterator& operator++() { + if (index_ < parent_->size()) { + index_++; + } + + return *this; + } + + inline bool operator==(const ArrayIterator& other) { + return parent_ == other.parent_ && index_ == other.index_; + } + + inline bool operator!=(const ArrayIterator& other) { return !operator==(other); } + + private: + size_t index_; + const ArrayAccessor* parent_; +}; + +template +class ArrayAccessor { + public: + using value_type = Ref; + using iterator = ArrayIterator; + using const_iterator = ArrayIterator; + + template ::value>::type> + ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} + + inline size_t size() const { return num_data_; } + + inline Ref operator[](size_t index) const { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + + return Ref(&data_[index]); + } + + inline ArrayIterator begin() const { return ArrayIterator{0, this}; } + + inline ArrayIterator end() const { return ArrayIterator{num_data_, this}; } + + private: + const C* data_; + size_t num_data_; +}; + +template <> +class ArrayAccessor { + public: + using value_type = ::tvm::runtime::String; + using iterator = ArrayIterator; + using const_iterator = ArrayIterator; + + ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} + + inline size_t size() const { return num_data_; } + + inline ::tvm::runtime::String operator[](size_t index) const { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + + return ::tvm::runtime::String(data_[index]); + } + + inline ArrayIterator begin() const { + return ArrayIterator{0, this}; + } + + inline 
ArrayIterator end() const { + return ArrayIterator{num_data_, this}; + } + + private: + const char** data_; + size_t num_data_; +}; + +enum MetadataTypeIndex : uint8_t { + kUint64 = 0, + kInt64 = 1, + kBool = 2, + kString = 3, + kMetadata = 4, +}; + +class MetadataArrayNode : public MetadataBaseNode { + public: + MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) : + array{array}, type_index{type_index}, struct_name{struct_name} {} +// MetadataArrayNode(Array array, const char* c_type) +// : array(std::move(array)), c_type{c_type} {} + + std::string get_name() override; + + Array array; + MetadataTypeIndex type_index; + const char* struct_name; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); +}; + +class MetadataArray : public MetadataBase { + public: + // MetadataArray(Array array, MetadataTypeIndex type_index); + MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_METADATA_BASE_H_ diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h new file mode 100644 index 000000000000..2bd94b8bf338 --- /dev/null +++ b/include/tvm/support/span.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file tvm/support/span.h + * \brief Reimplementation of part of C++-20 style span. 
+ */ +#ifndef TVM_SUPPORT_SPAN_H_ +#define TVM_SUPPORT_SPAN_H_ + +#include +#include +#include + +namespace tvm { +namespace support { + +template +class Span { + public: + using value_type = W; + using reference = W&; + using const_reference = const W&; + using pointer = W*; + using const_pointer = const W*; + + class iterator : public std::iterator { + public: + inline iterator(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); } + + inline W operator*() { return W(*ptr_); } + + inline iterator& operator++() { + if (ptr_ != end_) ptr_++; + return *this; + } + + inline bool operator==(iterator other) const { return ptr_ == other.ptr_ && end_ == other.end_; } + + inline bool operator!=(iterator other) const { return !(*this == other); } + + protected: + T* ptr_; + T* end_; + }; + + using const_iterator = iterator; + + inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} + inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} + + inline iterator begin() const { return iterator(begin_, end_); } + + inline iterator end() const { return iterator(end_, end_); } + + inline W operator[](int i) { + T* to_return = begin_ + i; + ICHECK_LT(to_return, end_) << "Span access out of bounds: " << i; + return W(*to_return); + } + + inline operator std::vector() { return std::vector(begin(), end()); } + + protected: + T* begin_; + T* end_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_SPAN_H_ diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index e802a3088d2d..6e9c8445695c 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -162,12 +162,6 @@ class TargetKindAttrMap : public AttrRegistryMap { explicit TargetKindAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; -/*! \brief Value used with --runtime in target specs to indicate the C++ runtime. */ -static constexpr const char* kTvmRuntimeCpp = "c++"; - -/*! \brief Value used with --runtime in target specs to indicate the C runtime. */ -static constexpr const char* kTvmRuntimeCrt = "c"; - /*! * \brief Helper structure to register TargetKind * \sa TVM_REGISTER_TARGET_KIND diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index 6337c6e6fec5..5c97e4c33b50 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -188,7 +188,7 @@ def set_input(self, key=None, value=None, **params): keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: # TODO(zhiics) Skip the weights for submodule in a better way. 
- # We should use MetadataModule for initialization and remove + # We should use ConstLoaderModule for initialization and remove # params from set_input val = self._get_input(k) if val: diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 5f4a134270ac..b836ce914696 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -105,6 +105,13 @@ def __init__( function_metadata, devices, ): + fcreate = get_global_func("tvm.aot_executor_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + + self.module = fcreate(libmod, libmod_name, *args) self.ir_mod = ir_mod self.lowered_ir_mods = lowered_ir_mods self.target = target @@ -128,6 +135,9 @@ def get_executor_config(self): def get_lib(self): return self.lib + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): + return self.module.export_library(file_name, fcompile, addons, **kwargs) + class GraphExecutorFactoryModule(ExecutorFactoryModule): """Graph executor factory module. diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index b3504dbac506..ab0fc1709fa9 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -31,3 +31,5 @@ from .module import load_module, enabled, system_lib from .container import String, ShapeTuple from .params import save_param_dict, load_param_dict + +from . import executor diff --git a/python/tvm/runtime/executor/__init__.py b/python/tvm/runtime/executor/__init__.py new file mode 100644 index 000000000000..0748bbd00aec --- /dev/null +++ b/python/tvm/runtime/executor/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Top-level file for the executor module.""" + +from .aot_executor import AotModule diff --git a/python/tvm/runtime/executor/aot_executor.py b/python/tvm/runtime/executor/aot_executor.py new file mode 100644 index 000000000000..91f056ff25fa --- /dev/null +++ b/python/tvm/runtime/executor/aot_executor.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A Python wrapper for the Module-based Model Runtime Interface for Ahead-of-Time compilation."""
+
+import numpy as np
+
+
+class AotModule(object):
+    """Wraps the AOT executor runtime.Module.
+
+    This is a thin wrapper of the underlying TVM module.
+    You can also directly call set_input, run, and get_output
+    on the underlying module functions.
+
+    Parameters
+    ----------
+    module : tvm.runtime.Module
+        The internal tvm module that holds the compiled functions.
+
+    Attributes
+    ----------
+    module : tvm.runtime.Module
+        The internal tvm module that holds the compiled functions.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import tvm
+        from tvm import relay
+        from tvm.runtime.executor import AotModule
+
+        # build the library using the AOT executor
+        lib = relay.build(...)
+        lib.export_library("compiled_lib.so")
+        # load it back as a runtime
+        lib: tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so")
+        # Call the library factory function for default and create
+        # a new runtime.Module, then wrap it with AotModule.
+        amod = AotModule(lib["default"](dev))
+        # use the AOT module.
+        amod.set_input("x", data)
+        amod.run()
+    """
+
+    def __init__(self, module):
+        self.module = module
+        self._set_input = module["set_input"]
+        self._run = module["run"]
+        self._get_output = module["get_output"]
+        self._get_input = module["get_input"]
+        self._get_num_outputs = module["get_num_outputs"]
+        self._get_input_index = module["get_input_index"]
+        self._get_num_inputs = module["get_num_inputs"]
+
+    def set_input(self, key=None, value=None, **params):
+        """Set inputs to the module via kwargs
+
+        Parameters
+        ----------
+        key : int or str
+           The input key
+
+        value : the input value.
+           The input data
+
+        params : dict of str to NDArray
+           Additional arguments
+        """
+        if key is not None:
+            v = self._get_input(key)
+            if v is None:
+                raise RuntimeError("Could not find '%s' in graph's inputs" % key)
+            v.copyfrom(value)
+
+        if params:
+            # upload big arrays first to avoid memory issue in rpc mode
+            keys = list(params.keys())
+            keys.sort(key=lambda x: -np.prod(params[x].shape))
+            for k in keys:
+                # TODO(zhiics) Skip the weights for submodule in a better way.
+                # We should use ConstLoaderModule for initialization and remove
+                # params from set_input
+                val = self._get_input(k)
+                if val:
+                    self._get_input(k).copyfrom(params[k])
+
+    def run(self, **input_dict):
+        """Run forward execution of the graph
+
+        Parameters
+        ----------
+        input_dict: dict of str to NDArray
+            List of input values to be fed to the module
+        """
+        if input_dict:
+            self.set_input(**input_dict)
+        self._run()
+
+    def get_num_outputs(self):
+        """Get the number of outputs from the graph
+
+        Returns
+        -------
+        count : int
+            The number of outputs.
+        """
+        return self._get_num_outputs()
+
+    def get_num_inputs(self):
+        """Get the number of inputs to the graph
+
+        Returns
+        -------
+        count : int
+            The number of inputs.
+        """
+        return self._get_num_inputs()
+
+    def get_input(self, index, out=None):
+        """Get index-th input to out
+
+        Parameters
+        ----------
+        index : int
+            The input index
+
+        out : NDArray
+            The output array container
+        """
+        if out:
+            self._get_input(index).copyto(out)
+            return out
+
+        return self._get_input(index)
+
+    def get_input_index(self, name):
+        """Get input index via input name.
+
+        Parameters
+        ----------
+        name : str
+           The input key name
+
+        Returns
+        -------
+        index: int
+            The input index.
-1 will be returned if the given input name is not found. + """ + return self._get_input_index(name) + + def get_output(self, index, out=None): + """Get index-th output to out + + Parameters + ---------- + index : int + The output index + + out : NDArray + The output array container + """ + if out: + self._get_output(index, out) + return out + + return self._get_output(index) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d901f8a26c4f..a1b86f32c2b5 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include #include +#include "../../target/metadata.h" #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" @@ -68,6 +70,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void Run(const Function& func) { VisitExpr(func); } std::vector GetReturnIds() const { return return_ids_; } + std::vector GetReturnTtypes() const { return return_ttypes_; } StorageMap GetStorageMap() const { return storage_device_map_; } @@ -147,6 +150,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { sid->storage_sizes_in_bytes.begin(), sid->storage_sizes_in_bytes.end()); } + LOG(INFO) << "Visit tuple: " << GetRef(op); storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); AssignReturnSid(expr); } @@ -155,6 +159,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { Expr expr = GetRef(op); auto sids = GetStorage(op->tuple); ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); + LOG(INFO) << "Visit TupleGetItem: " << expr; storage_device_map_[expr] = StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, {sids->storage_sizes_in_bytes[op->index]}); @@ -170,11 +175,19 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { private: void AssignReturnSid(Expr e) { if (storage_device_map_.find(e) != storage_device_map_.end()) { +// LOG(INFO) << "AssignReturnSid: is now " << e; StorageInfo& sinfo = storage_device_map_[e]; +// LOG(INFO) << "AssignReturnSid: storage_device_map_ " << sinfo; return_ids_.clear(); for (auto sid : sinfo->storage_ids) { return_ids_.push_back(sid); } + return_ttypes_.clear(); + auto ttypes = FlattenTupleType(e->checked_type()); + return_ttypes_.reserve(ttypes.size()); + for (auto ttype : ttypes) { + return_ttypes_.push_back(ttype); + } } } /*! @@ -240,6 +253,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { virtual_devices.push_back(virtual_device); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } +// LOG(INFO) << "CreateStorage: " << expr; storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), std::move(storage_sizes_in_bytes)); } @@ -250,6 +264,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { int next_available_sid_{0}; /*! \brief the set of intermediate tensors that are return variables */ std::vector return_ids_; + /*! \brief the data types of the return values */ + std::vector return_ttypes_; }; /*! \brief Code generator for AOT executor */ @@ -315,6 +331,45 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /*! \brief Return a PrimExpr which contains the arg to be passed down to a PrimFunc. 
+ * + * TODO(areusch): Document the various cases which could necessitate us synthesizing + * a DLTensor on stack. + */ + PrimExpr MakeDLTensor(Expr relay_arg, TensorType ttype, PrimExpr data) { +// LOG(INFO) << "MakeDLTensor: " << relay_arg << " (ttype " << ttype << "): " << data; + return data; + } + // for (Var v : input_vars_) { + // if (v == relay_arg) { + // return data; + // } + // } + // for (int return_sid : return_sid_) { + // auto return_expr = sids_table_[return_sid]; + // if (return_expr == relay_arg) { + // return data; + // } + // } + // return data; + // } + + void PushTuple(Expr tuple, std::vector sids, Array* args) { +// CHECK_EQ(sids.size(), tuple->fields.size()) +// << "Relay tuple does not map 1:1 into TIR; AOT can't handle this type of Relay Expr in a " +// "CallNode."; + StorageInfo& sinfo = storage_device_map_[tuple]; + for (unsigned int i = 0; i < sids.size(); ++i) { + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[i]) != + return_sid_.end()) { + args->push_back(sids[i]); + } else { + args->push_back(sids[i]); //MakeDLTensor( +// tuple->fields[i], Downcast(tuple->fields[i]->checked_type()), sids[i])); + } + } + } + /*! * brief Create a function call * \param call_lowered_props The lowered function and the arguments to call it with @@ -329,32 +384,68 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Pack the inputs for (const Expr& arg : call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[arg])}); - args.push_back(param_handle); + args.push_back(MakeDLTensor( + arg, Downcast(arg->checked_type()), + tir::Cast(runtime::DataType(DataType::TypeCode::kHandle, 32, 1), + tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[arg])})))); } else { - auto var_arg = FindExpr(arg); - for (const auto& var : var_arg) { - args.push_back(var); + auto sids = FindExpr(arg); + if (sids.size() > 1) { +// auto tuple = Downcast(arg); + PushTuple(arg, sids, &args); + } else { + StorageInfo& sinfo = storage_device_map_[arg]; + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != + return_sid_.end()) { + args.push_back(sids[0]); + } else { + args.push_back(MakeDLTensor(arg, Downcast(arg->checked_type()), sids[0])); + } } } } // Pack the return(s) value. 
A call node can produce multiple outputs - for (const auto& var : PackSid(result_expr)) { - args.push_back(var); + auto result_expr_sid = PackSid(result_expr); + if (result_expr_sid.size() > 1) { + LOG(INFO) << "RESULT EXPR " << result_expr; + LOG(INFO) << "RESULT TYPE " << result_expr->checked_type(); + auto result_storage_device_map = storage_device_map_[result_expr]; + LOG(INFO) << "RESULT STORAGE DEVICE MAP " << result_storage_device_map; + std::stringstream rsid; + for (auto s : result_expr_sid) { + rsid << s << ","; + } + LOG(INFO) << "RESULT_EXPR SID " << rsid.str() << "(end)"; +// auto tuple = Downcast(result_expr); + + PushTuple(result_expr, result_expr_sid, &args); + + } else { + StorageInfo& sinfo = storage_device_map_[result_expr]; + if (std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]) != + return_sid_.end()) { + args.push_back(result_expr_sid[0]); + } else { + args.push_back(MakeDLTensor(result_expr, Downcast(result_expr->checked_type()), + result_expr_sid[0])); + } } - // Use tvm_call_packed to execute the function unless we're calling directly - auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + // Choose call style based on Runtime/Executor config. + Op calling_pattern; if (use_unpacked_api_) { calling_pattern = tvm::tir::builtin::call_extern(); + } else if (use_call_cpacked_) { + calling_pattern = tvm::tir::builtin::tvm_call_cpacked(); + } else { + calling_pattern = tvm::tir::builtin::tvm_call_packed(); } GlobalVar global_var = call_lowered_props.lowered_func; tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; - bool use_cpacked_api = !use_unpacked_api_; // The device context is passed to the operator in one of the following calling patterns: // * Unpacked / direct function call with context: @@ -378,9 +469,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { func_call, GenerateDeviceHook(context, "Close"), })); - } else if (use_cpacked_api) { + } else if (use_call_cpacked_ && !use_unpacked_api_) { // call_cpacked calling convention needs a blank context - args.push_back(tir::make_zero(DataType::Handle())); + // TOOD only c runtime +// args.push_back(tir::make_zero(DataType::Handle())); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); create_func_call_stmts.push_back(func_call); } else { @@ -391,6 +483,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { tir::Stmt body = tir::SeqStmt(create_func_call_stmts); stmts_.push_back(body); + LOG(INFO) << "Create func call " << body; } /*! @@ -688,13 +781,25 @@ class AOTExecutorCodegen : public MixedModeVisitor { Target target_host_; /*! * \brief unpacked api toggle - * When set to true the code generated will use unpacked calls to functions: + * When set to true, the generated code will use unpacked calls to functions: * func(void* arg0, void* arg1) * Rather than packed calls: * func(void* args) * Defaults to using the packed calling convention */ Bool use_unpacked_api_; + /*! + * \brief cpacked api toggle + * When set to true, the generated code will use call_cpacked to call functions directly, assuming + * they exist in a DSO-exportable module. + * func(...) + * Rather than through the traditional call_packed calls, which should use function pointers + * looked-up through TVMBackendGetFuncFromEnv: + * TVMBackendPackedCFunc* func_ptr = TVMBackendGetFuncFromEnv("func"); + * func_ptr(...) 
+ * Defaults to using the packed calling convention + */ + Bool use_call_cpacked_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). @@ -721,7 +826,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {} + : mod_(mod), + targets_(targets), + target_host_(target_host), + use_unpacked_api_(Bool(false)), + use_call_cpacked_(Bool(false)) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { VLOG_CONTEXT << "AOT"; @@ -731,14 +840,40 @@ class AOTExecutorCodegen : public MixedModeVisitor { ICHECK(target_host_.defined()) << "require a target_host to be given for AOT codegen"; VLOG(1) << "target host: " << target_host_->ToDebugString(); + Runtime runtime_config = mod->GetAttr(tvm::attr::kRuntime).value(); Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = executor_config->GetAttr("workspace-byte-alignment").value_or(16); use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); + use_call_cpacked_ = + (!use_unpacked_api_ || + // for now, C runtime does not support calling functions on other devices. therefore, + // opt to call PackedFunc directly by name rather than TVMBackendGetFuncFromEnv. + runtime_config->name == kTvmRuntimeCrt); + LOG(INFO) << "Use call cpacked? " << bool(use_call_cpacked_) << "; " << interface_api << ", unpacked=" << use_unpacked_api_; + + // Validate choice of use_unpacked_api_ and use_call_cpacked_ + if (runtime_config->name == kTvmRuntimeCrt) { + if (interface_api == "c") { + CHECK(static_cast(use_unpacked_api_) == true) + << "When interface_api == \"c\", need unpacked-api == true (got: " + << use_unpacked_api_ << ") when targeting c runtime"; + } + } else if (runtime_config->name == kTvmRuntimeCpp) { + CHECK(static_cast(use_unpacked_api_) == true || + static_cast(use_call_cpacked_) == true) + << "Need unpacked-api == false (got: " << use_unpacked_api_ + << ") and interface-api == \"c\" (got: " << interface_api + << ") when targeting c++ runtime"; + } else { + ICHECK(false) << "runtime_config (" << runtime_config->name + << ") is not one of the expected values"; + } // TODO(mbs): Plumb from compiler config VirtualDevice host_virtual_device = VirtualDevice::ForTarget(target_host_); + VLOG(1) << "relay mod:" << std::endl << PrettyPrint(mod); IRModule lowered_mod = tec::LowerTEPass( mod_name, @@ -797,6 +932,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); + VLOG(1) << "lowered_main_func:" << std::endl << PrettyPrint(lowered_main_func); VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet @@ -836,8 +972,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Legalize AOT if needed. 
This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) if (!use_unpacked_api_) { + LOG(INFO) << "Legalize Packed " << mod_run; auto pack_calls = tir::transform::LegalizePackedCalls(); mod_run = pack_calls(mod_run); + LOG(INFO) << "Legalize Packed done " << mod_run; } ret.function_metadata = std::move(function_metadata_); @@ -867,12 +1005,35 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.lowered_funcs.Set(target_host_, mod_run); } - std::vector input_var_names(input_vars_.size()); - std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), - [](Var input_var) -> String { return input_var->name_hint(); }); - ret.metadata = - runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); + std::vector inputs; + for (auto v : input_vars_) { + auto ttype = Downcast(v->type_annotation); + inputs.push_back( + runtime::metadata::TensorInfo(make_object( + v->name_hint(), ShapeToJSON(ttype->shape), ttype->dtype))); + } + + LOG(INFO) << "MAKE METADATA? "; + std::vector outputs; + auto output_ttypes = final_aot_allocator.GetReturnTtypes(); + for (unsigned int i = 0; i < output_ttypes.size(); i++) { + auto ttype = Downcast(output_ttypes[i]); + std::stringstream name; + name << "output" << i; + outputs.push_back( + runtime::metadata::TensorInfo(make_object( + name.str(), ShapeToJSON(ttype->shape), ttype->dtype))); + } + auto devices = ListDevices(); + std::vector devices_vector; + for (auto d : devices) { + devices_vector.push_back(d.operator std::string()); + } + auto n = make_object( + kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorAot, mod_name, + interface_api, use_unpacked_api_); + ret.metadata = runtime::metadata::Metadata(std::move(n)); + LOG(INFO) << "MAKE METADATA: " << ret.metadata; return ret; } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ccfd30476f67..39c56d66947d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -103,7 +103,9 @@ struct ExecutorCodegen { Array ListDevices() { return CallFunc>("get_devices"); } - runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } + runtime::metadata::Metadata GetMetadata() { + return CallFunc("get_metadata"); + } virtual ~ExecutorCodegen() {} protected: @@ -408,6 +410,7 @@ class RelayBuildModule : public runtime::ModuleNode { Function func = Downcast(relay_module->Lookup("main")); IRModule func_module = WithAttrs(IRModule::FromExpr(func), {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}}); + LOG(INFO) << "Executor " << executor_; // Generate code for the updated function. 
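Note: the branch on `executor_->name` just below is what routes compilation into the new `AOTExecutorCodegen`; the executor and runtime attributes it reads are set when the user builds the model. A minimal sketch of selecting the AOT executor together with the C++ runtime from Python (target, model, and option values here are illustrative, not mandated by this patch):

    import tvm
    from tvm import relay
    from tvm.relay.backend import Executor, Runtime

    # mod/params: any Relay module and its parameters (illustrative).
    # "interface-api" and "unpacked-api" are the executor options validated in
    # AOTExecutorCodegen::Codegen above; the C++ runtime accepts the packed
    # interface with cpacked calls, so the defaults below suffice.
    lib = relay.build(
        mod,
        target="llvm",
        params=params,
        executor=Executor("aot", {"interface-api": "packed", "unpacked-api": False}),
        runtime=Runtime("cpp"),
    )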
executor_codegen_ = MakeExecutorCodegen(executor_->name); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index f61fe9b402b3..f9e5f4eb4a5b 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -290,25 +290,18 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator - * - * \param shape - * \return std::vector - */ - std::vector _ShapeToJSON(tvm::Array shape) { - std::vector ret; - for (IndexExpr dim : shape) { - const int64_t* pval = tir::as_const_int(dim); - ret.push_back(*pval); - } + std::vector inputs; + std::vector outputs; + std::vector devices_vector; + auto n = make_object( + kMetadataVersion, inputs, outputs, devices_vector, runtime::kTvmExecutorGraph, mod_name_, + "packed", Bool(false)); + ret.metadata = runtime::metadata::Metadata(std::move(n)); return ret; } + protected: /*! * \brief Add node to graph * @@ -352,7 +345,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfields.size(); ++i) { if (const auto* typ = tuple_type->fields[i].as()) { ret.push_back(GraphNodeRef(node_id, i)); - shape.emplace_back(_ShapeToJSON(typ->shape)); + shape.emplace_back(ShapeToJSON(typ->shape)); dtype.emplace_back(DType2String(typ->dtype)); } else { LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported"; @@ -369,7 +362,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator()) { ShapeVector shape; std::vector dtype; - shape.emplace_back(_ShapeToJSON(tensor_type->shape)); + shape.emplace_back(ShapeToJSON(tensor_type->shape)); dtype.emplace_back(DType2String(tensor_type->dtype)); node->attrs_["shape"] = shape; node->attrs_["dtype"] = dtype; diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 786d6f937f14..923c9b2d5f65 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME("crt").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME("cpp").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 608d4cdb9f85..d8f3c6d2bc93 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -275,6 +275,15 @@ void UpdateAutoSchedulerOpWeights(const IRModule& module) { (*te_compiler_update_weights)(weight_map); } +std::vector ShapeToJSON(tvm::Array shape) { + std::vector ret; + for (IndexExpr dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(*pval); + } + return ret; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 658283b5dc36..4a0221809a22 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -42,6 +42,7 @@ #include #include "../../runtime/meta_data.h" +#include "../../target/metadata.h" namespace tvm { namespace relay { @@ -147,7 +148,7 @@ struct LoweredOutput { Array external_mods; Map function_metadata; std::unordered_map> params; - runtime::Metadata metadata; + runtime::metadata::Metadata metadata; // points to InMemoryMetadataNode }; /*! 
@@ -527,6 +528,14 @@ Map TargetStrModuleMapToTargetModuleMap( */ void UpdateAutoSchedulerOpWeights(const IRModule& module); +/*! + * \brief Extract shape from expr to vector + * + * \param shape + * \return std::vector + */ +std::vector ShapeToJSON(tvm::Array shape); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 73f4b672a81c..1253e8527739 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -46,6 +46,7 @@ #include #include +#include "../../../target/metadata.h" #include "../../../target/metadata_module.h" #include "../../../target/source/codegen_source_base.h" #include "../../op/annotation/annotation.h" @@ -1161,8 +1162,9 @@ void VMCompiler::Codegen() { lib = tvm::build(per_tvm_target_modules, config_->host_target); } - lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, - Runtime::Create("cpp"), runtime::Metadata()); + lib = codegen::CreateMetadataModule( + params_, lib, ext_mods, config_->host_target, Runtime::Create("cpp"), + runtime::metadata::Metadata(make_object())); exec_->SetLib(lib); } diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc new file mode 100644 index 000000000000..24a6a7328890 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.cc @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.cc + */ + +#include "aot_executor.h" + +#include + +// TODO(areusch): idk what's up here... +#include // NOLINT(build/include_order) + +namespace tvm { +namespace runtime { + +AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector& devs) + : module_{module}, devices_{devs} { + auto fmetadata = module->GetFunction("get_metadata"); + CHECK(fmetadata != nullptr) << "Expected a module with PackedFunc get_metadata"; + auto ret_value = fmetadata(); + metadata_ = ret_value.AsObjectRef(); + + for (auto input : metadata_->inputs()) { + // TODO(areusch): Encode device information in Metadata. + args_.emplace_back(NDArray::Empty(ShapeTuple(input->shape().begin(), input->shape().end()), + input->dtype(), devices_[0])); + } + + for (auto output : metadata_->outputs()) { + args_.emplace_back(NDArray::Empty(ShapeTuple(output->shape().begin(), output->shape().end()), + output->dtype(), devices_[0])); + } +} + +PackedFunc AotExecutor::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + // Return member functions during query. 
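Note: the constructor above pre-allocates one NDArray per model input followed by one per output into `args_`; `Run()` (further below) forwards them, in that order, as DLTensor handles to the generated `<mod_name>_run_model` PackedFunc. A rough sketch of that calling convention, with illustrative names and assuming the compiled library is already loaded:

    # The AOT entrypoint behaves like a PackedFunc taking inputs then outputs,
    # all passed as NDArray/DLTensor handles (names below are illustrative).
    run_model = compiled_lib.get_function("default_run_model", query_imports=True)
    run_model(input0_nd, input1_nd, output0_nd)  # results are written into output0_nd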
+ if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); + } else if (name == "set_input_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); + } else { + this->SetInputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args.num_args == 2) { + this->CopyOutputTo(args[0], args[1]); + } else { + *rv = this->GetOutput(args[0]); + } + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + int in_idx = 0; + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); + } else { + in_idx = args[0]; + } + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } + }); + } else if (name == "get_num_outputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_num_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); + } else { + return PackedFunc(); + } +} + +void AotExecutor::Run() { + LOG(INFO) << "Get entrypoint " << metadata_->mod_name() << "_run_model"; + auto pf = module_.GetFunction(metadata_->mod_name() + "_run_model", true /* query_imports */); + ICHECK(pf != nullptr) << "Module entrypoint is not defined"; + + const int num_args = args_.size(); + ::std::unique_ptr call_values{new TVMValue[num_args]}; + ::std::unique_ptr call_type_codes{new int[num_args]}; + for (int i = 0; i < num_args; ++i) { + auto managed = args_[i].ToDLPack(); + call_values.get()[i].v_handle = &managed->dl_tensor; + call_type_codes.get()[i] = kTVMDLTensorHandle; + } + + TVMArgs args{call_values.get(), call_type_codes.get(), num_args}; + TVMRetValue rv; + pf.CallPacked(args, &rv); +} + +int AotExecutor::GetInputIndex(const std::string& name) { + auto inputs = metadata_->inputs(); + for (unsigned int i = 0; i < inputs.size(); i++) { + if (inputs[i]->name() == name) { + return i; + } + } + return -1; +} + +int AotExecutor::GetOutputIndex(const std::string& name) { + auto outputs = metadata_->outputs(); + for (unsigned int i = 0; i < outputs.size(); i++) { + if (outputs[i]->name() == name) { + return i; + } + } + return -1; +} + +void AotExecutor::SetInput(int index, DLTensor* data_ref) { args_[index].CopyFrom(data_ref); 
} + +void AotExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +void AotExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { + ICHECK(false) << "not implemented"; +} + +int AotExecutor::NumOutputs() const { return metadata_->num_outputs(); } + +int AotExecutor::NumInputs() const { return metadata_->num_inputs(); } + +NDArray AotExecutor::GetInput(int index) const { return args_[index]; } + +NDArray AotExecutor::GetOutput(int index) const { return args_[metadata_->num_inputs() + index]; } + +void AotExecutor::CopyOutputTo(int index, DLTensor* data_out) { GetOutput(index).CopyTo(data_out); } + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h new file mode 100644 index 000000000000..a213ef83bdc2 --- /dev/null +++ b/src/runtime/aot_executor/aot_executor.h @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Defines an implementation of Module-based Model Runtime Interface that works with + * Ahead-of-Time compilation. + * \file aot_executor.h + */ +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutor : public ModuleNode { + public: + /*! + * \brief Implements member function lookup for this Module for the frontend. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const final { return "AotExecutor"; } + + void Run(); + + /*! + * \brief Initialize the AOT executor with metadata, runtime::Module, and device. + * \param module The module containing the compiled functions for the host + * processor. + * \param devs The device of the host and devices where graph nodes will be + * executed on. + * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters + * by storage_id. If not given, linked parameters are looked-up using an internal implementation, + * which is not compatible with RPCModules. Default is nullptr. + */ + AotExecutor(tvm::runtime::Module module, const std::vector& devs); + + /*! + * \brief Get the input index given the name of input. + * \param name The name of the input. + * \return The index of input. + */ + int GetInputIndex(const std::string& name); + + /*! + * \brief Get the output index given the name of output. + * \param name The name of the output. 
+ * \return The index of output. + */ + int GetOutputIndex(const std::string& name); + + /*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data_in The input data. + */ + void SetInput(int index, DLTensor* data_in); + /*! + * \brief set index-th input to the graph without copying the data + * \param index The input index. + * \param data_ref The input data that is referred. + */ + void SetInputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ + void SetOutputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief Get the number of outputs + * + * \return The number of outputs from graph. + */ + int NumOutputs() const; + /*! + * \brief Get the number of inputs + * + * \return The number of inputs to the graph. + */ + int NumInputs() const; + /*! + * \brief Return NDArray for given input index. + * \param index The input index. + * + * \return NDArray corresponding to given input node index. + */ + NDArray GetInput(int index) const; + /*! + * \brief Return NDArray for given output index. + * \param index The output index. + * + * \return NDArray corresponding to given output node index. + */ + NDArray GetOutput(int index) const; + /*! + * \brief Copy index-th output to data_out. + * \param index The output index. + * \param data_out the output data. + */ + void CopyOutputTo(int index, DLTensor* data_out); + + private: + /*! \brief Metadata provided to the runtime from the compiler. */ + metadata::Metadata metadata_; + + /*! \brief Runtime module which contains the AOT top-level function. */ + Module module_; + + /*! \brief The devices which should be used to execute the computations. */ + std::vector devices_; + + /*! \brief Holds one NDArray per function argument in the same order. */ + std::vector args_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_ diff --git a/src/runtime/aot_executor/aot_executor_factory.cc b/src/runtime/aot_executor/aot_executor_factory.cc new file mode 100644 index 000000000000..4cb3026991fe --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file aot_executor_factory.cc
+ * \brief AOT executor factory implementations
+ */
+
+#include "./aot_executor_factory.h"
+
+#include
+#include
+#include
+
+#include
+#include
+
+namespace tvm {
+namespace runtime {
+
+AotExecutorFactory::AotExecutorFactory(
+    const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
+    const std::string& module_name) {
+  params_ = params;
+  module_name_ = module_name;
+}
+
+PackedFunc AotExecutorFactory::GetFunction(
+    const std::string& name, const tvm::runtime::ObjectPtr<Object>& sptr_to_self) {
+  if (name == module_name_) {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument";
+      std::vector<Device> devices;
+      for (int i = 0; i < args.num_args; ++i) {
+        devices.emplace_back(args[i].operator Device());
+      }
+      *rv = this->ExecutorCreate(devices);
+    });
+  } else if (name == "remove_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{};
+      auto exec = make_object<AotExecutorFactory>(empty_params, this->module_name_);
+      exec->Import(this->imports_[0]);
+      *rv = Module(exec);
+    });
+  } else {
+    return PackedFunc();
+  }
+}
+
+void AotExecutorFactory::SaveToBinary(dmlc::Stream* stream) {
+  std::vector<std::string> names;
+  std::vector<DLTensor*> arrays;
+  for (const auto& v : params_) {
+    names.emplace_back(v.first);
+    arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
+  }
+  uint64_t sz = arrays.size();
+  ICHECK(sz == names.size());
+  stream->Write(sz);
+  stream->Write(names);
+  for (size_t i = 0; i < sz; ++i) {
+    tvm::runtime::SaveDLTensor(stream, arrays[i]);
+  }
+  stream->Write(module_name_);
+}
+
+Module AotExecutorFactory::ExecutorCreate(const std::vector<Device>& devs) {
+  auto exec = make_object<AotExecutor>(this->imports_[0], devs);
+  // set params
+  SetParams(exec.get(), this->params_);
+  return Module(exec);
+}
+
+Module AotExecutorFactoryModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::unordered_map<std::string, tvm::runtime::NDArray> params;
+  std::string module_name;
+  uint64_t sz;
+  ICHECK(stream->Read(&sz));
+  std::vector<std::string> names;
+  ICHECK(stream->Read(&names));
+  ICHECK(sz == names.size());
+  for (size_t i = 0; i < sz; ++i) {
+    tvm::runtime::NDArray temp;
+    temp.Load(stream);
+    params[names[i]] = temp;
+  }
+  ICHECK(stream->Read(&module_name));
+  auto exec = make_object<AotExecutorFactory>(params, module_name);
+  return Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("tvm.aot_executor_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) {
+  ICHECK_GE(args.num_args, 2) << "Expected at least 2 arguments to "
+                                 "aot_executor_factory.create, but got "
+                              << args.num_args;
+  // The argument order is module, module_name, param0_name, param0_tensor,
+  // [param1_name, param1_tensor], ...
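Note: the comment above describes the argument layout that the Python-side `AOTExecutorFactoryModule` (see `executor_factory.py` earlier in this patch) relies on when it invokes this global. A minimal sketch of calling it directly, with illustrative module and parameter names:

    import numpy as np
    import tvm

    fcreate = tvm.get_global_func("tvm.aot_executor_factory.create")
    # libmod: the compiled runtime.Module; "default" is the module name;
    # then alternating parameter-name / NDArray pairs.
    factory = fcreate(libmod, "default", "p0", tvm.nd.array(np.zeros((4,), "float32")))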
+ ICHECK_EQ((args.size() - 2) % 2, 0); + std::unordered_map params; + for (size_t i = 2; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + auto exec = make_object(params, args[1]); + exec->Import(args[0]); + *rv = Module(exec); +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_AotExecutorFactory") + .set_body_typed(AotExecutorFactoryModuleLoadBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h new file mode 100644 index 000000000000..1d6a0a62776e --- /dev/null +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/aot_executor/aot_executor_factory.h + * \brief Aot executor factory creating aot executor. + */ + +#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ +#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "./aot_executor.h" + +namespace tvm { +namespace runtime { + +class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { + public: + /*! + * \brief Construct the AotExecutorFactory. + * \param params The params of aot. + * \param module_name The module name of aot. + */ + AotExecutorFactory(const std::unordered_map& params, + const std::string& module_name); + + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const override { return "AotExecutorFactory"; } + + /*! + * \brief Save the module to binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) override; + + /*! + * \brief Create a specific executor module + * \param devs The device of the host and devices where the model will be + * executed. + * \return created executor module + */ + Module ExecutorCreate(const std::vector& devs); + + /*! + * \brief Set params. + * \param aot_executor The aot executor we want to set the params into. + * \param params The aot params value we want to set. 
+ */ + void SetParams(AotExecutor* aot_executor, + const std::unordered_map& params) const { + std::unordered_map value = params; + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : value) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_size = GetDataSize(*value[lhs].operator->()); + auto rhs_size = GetDataSize(*value[rhs].operator->()); + return lhs_size > rhs_size; + }); + for (const auto& key : keys) { + int in_idx = aot_executor->GetInputIndex(key); + if (in_idx >= 0) { + aot_executor->SetInput(in_idx, const_cast(value[key].operator->())); + } + } + } + + protected: + /*! \brief The params. */ + std::unordered_map params_; + /*! \brief module name */ + std::string module_name_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_FACTORY_H_ diff --git a/src/runtime/metadata_module.cc b/src/runtime/const_loader_module.cc similarity index 73% rename from src/runtime/metadata_module.cc rename to src/runtime/const_loader_module.cc index 7cb986bba62c..818525def22c 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/const_loader_module.cc @@ -43,17 +43,18 @@ namespace runtime { /*! * \brief The metadata module is designed to manage initialization of the - * imported submodules. + * imported submodules for the C++ runtime. */ -class MetadataModuleNode : public ModuleNode { +class ConstLoaderModuleNode : public ModuleNode { public: - MetadataModuleNode(const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) - : metadata_(metadata), sym_vars_(sym_vars) { + ConstLoaderModuleNode( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) + : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { // Only the related submodules are cached to reduce the number of runtime // symbol lookup for initialization. Otherwise, symbols/primitives in the // DSO module will also be cached but they never need to be initialized. - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { initialized_[it.first] = false; } } @@ -78,7 +79,7 @@ class MetadataModuleNode : public ModuleNode { return PackedFunc(nullptr); } - const char* type_key() const { return "metadata"; } + const char* type_key() const { return "const_loader"; } /*! * \brief Get the list of metadata that is required by the given module. @@ -87,11 +88,11 @@ class MetadataModuleNode : public ModuleNode { */ Array GetRequiredMetadata(const std::string& symbol) { Array ret; - ICHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol; - std::vector vars = sym_vars_[symbol]; + ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is recorded for " << symbol; + std::vector vars = const_vars_by_symbol_[symbol]; for (const auto& it : vars) { - ICHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; - ret.push_back(metadata_[it]); + ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded constant variable: " << it; + ret.push_back(const_var_ndarray_[it]); } return ret; } @@ -102,7 +103,7 @@ class MetadataModuleNode : public ModuleNode { * for runtime lookup. 
* * \note A module could be like the following: - * MetadataModuleNode (contains all the metadata) + * ConstLoaderModuleNode (contains all the metadata) * - CSourceModule * - JSON runtime module * @@ -128,32 +129,32 @@ class MetadataModuleNode : public ModuleNode { void SaveToBinary(dmlc::Stream* stream) final { std::vector variables; - std::vector metadata; - for (const auto& it : metadata_) { + std::vector const_var_ndarray; + for (const auto& it : const_var_ndarray_) { String var_name = it.first; variables.push_back(var_name); - metadata.push_back(it.second); + const_var_ndarray.push_back(it.second); } // Save all variables in the function. stream->Write(variables); // Save all constant data. - uint64_t sz = static_cast(metadata.size()); + uint64_t sz = static_cast(const_var_ndarray.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); + const_var_ndarray[i].Save(stream); } // Save the symbol to list of required constant variables mapping std::vector symbols; std::vector> const_vars; - for (const auto& it : sym_vars_) { + for (const auto& it : const_vars_by_symbol_) { symbols.push_back(it.first); const_vars.push_back(it.second); } stream->Write(symbols); - sz = static_cast(sym_vars_.size()); + sz = static_cast(const_vars_by_symbol_.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { stream->Write(const_vars[i]); @@ -165,9 +166,9 @@ class MetadataModuleNode : public ModuleNode { // Load the variables. std::vector variables; - ICHECK(stream->Read(&variables)) << "Loading variables failed"; + ICHECK(stream->Read(&variables)) << "Loading variable names failed"; uint64_t sz; - ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed"; + ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of vars failed"; ICHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. @@ -178,10 +179,10 @@ class MetadataModuleNode : public ModuleNode { arrays.push_back(temp); } - std::unordered_map metadata; + std::unordered_map const_var_ndarray; for (uint64_t i = 0; i < sz; i++) { - ICHECK_EQ(metadata.count(variables[i]), 0U); - metadata[variables[i]] = arrays[i]; + ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U); + const_var_ndarray[variables[i]] = arrays[i]; } // Load the symbol to list of required constant variables mapping @@ -196,12 +197,12 @@ class MetadataModuleNode : public ModuleNode { const_vars.push_back(vars); } - std::unordered_map> sym_vars; + std::unordered_map> const_vars_by_symbol; for (uint64_t i = 0; i < sz; i++) { - sym_vars[symbols[i]] = const_vars[i]; + const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(metadata, sym_vars); + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } @@ -212,19 +213,21 @@ class MetadataModuleNode : public ModuleNode { */ std::unordered_map initialized_; /*! \brief Variable name to NDArray mapping. */ - std::unordered_map metadata_; + std::unordered_map const_var_ndarray_; /*! \brief Symbol name to required constant variables mapping. 
*/ - std::unordered_map> sym_vars_; + std::unordered_map> const_vars_by_symbol_; }; -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars) { - auto n = make_object(metadata, sym_vars); +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol) { + auto n = make_object(const_var_ndarray, const_vars_by_symbol); return Module(n); } TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") - .set_body_typed(MetadataModuleNode::LoadFromBinary); + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") + .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h new file mode 100644 index 000000000000..eb548dfcf370 --- /dev/null +++ b/src/runtime/const_loader_module.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file const_loader_module.h + * \brief Defines an interface to use the ConstLoaderModule. + */ + +#ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ +#define TVM_RUNTIME_CONST_LOADER_MODULE_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Create a ConstLoader module object. + * + * \param const_var_ndarray Maps consts var name to NDArray containing data for the var. + * \param const_vars_by_symbol Maps the name of a module init function to a list of names of + * const vars whose data will be passed to that init function. + * + * \return The created ConstLoaderModule. + */ +Module ConstLoaderModuleCreate( + const std::unordered_map& const_var_ndarray, + const std::unordered_map>& const_vars_by_symbol); + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONST_LOADER_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 8996d1b76e1f..766b93261ac0 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -49,69 +50,22 @@ inline String get_name_mangled(const String& module_name, const String& name) { return ss.str(); } -/*! - * \brief Structure that can be optionally used by the executor codegen - */ -class MetadataNode : public Object { - public: - /*! \brief input information for the main function */ - Array inputs; - /*! \brief number of outputs of the main function */ - int num_outputs = 1; - /*! \brief device contexts information for the main function */ - Array devices; - /*! \brief the executor to be used to run the model */ - String executor = kTvmExecutorGraph; - /*! 
\brief The external API (packed or c) in use */ - String interface_api; - /*! \brief The internal API (packed or unpacked) in use */ - bool unpacked_api; - - String mod_name = ""; - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "MetadataObj"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, Object); -}; - -/*! - * \brief Managed reference to MetadataNode. - */ -class Metadata : public ObjectRef { - public: - TVM_DLL Metadata(Array inputs, Array devices, int num_outputs, String executor, - String mod_name, String interface_api = "packed", bool unpacked_api = false) { - auto n = make_object(); - n->inputs = inputs; - n->devices = devices; - n->num_outputs = num_outputs; - n->executor = executor; - n->interface_api = interface_api; - n->unpacked_api = unpacked_api; - n->mod_name = mod_name; - data_ = std::move(n); - } - - TVM_DEFINE_OBJECT_REF_METHODS(Metadata, ObjectRef, MetadataNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(MetadataNode); -}; - /*! * \brief Create a metadata module object. * - * \param metadata The variable name to ndarray mapping. - * \param sym_vars The symbol to the list of required constant variables - * mapping. + * \param metadata Exported metadata structure. * * \return The created metadata module. */ -Module MetadataModuleCreate( - const std::unordered_map& metadata, - const std::unordered_map>& sym_vars); +Module MetadataModuleCreate(metadata::Metadata metadata); + +namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +} // namespace launch_param + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc new file mode 100644 index 000000000000..021e8244d5bf --- /dev/null +++ b/src/runtime/metadata.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata.h + * \brief Defines implementations of TVM metadata which can exist in the runtime. 
+ */ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name) + : MetadataBase{make_object(array, type_index, struct_name)} {} + +std::string MetadataArrayNode::get_name() { return "MetadataArray"; } + +TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); +TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); + +ArrayAccessor MetadataNode::inputs() { + return ArrayAccessor(data_->inputs, data_->num_inputs); +} +ArrayAccessor MetadataNode::outputs() { + return ArrayAccessor(data_->outputs, data_->num_outputs); +} +ArrayAccessor MetadataNode::devices() { + return ArrayAccessor(data_->devices, data_->num_devices); +} +Metadata::Metadata(const struct ::TVMMetadata* data) + : MetadataBase{make_object(data)} {} +std::string MetadataNode::get_name() { return std::string{"Metadata"}; } +TVM_REGISTER_OBJECT_TYPE(MetadataNode); +TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) + : MetadataBase{make_object(data)} {} +std::string TensorInfoNode::get_name() { return std::string{"TensorInfo"}; } + +} // namespace metadata + +class MetadataModuleNode : public ::tvm::runtime::ModuleNode { + public: + explicit MetadataModuleNode(runtime::metadata::Metadata metadata) { + // CHECK((metadata.defined() && code.size() > 0) || (!metadata.defined() && code.size() == 0)) + // << "metadata and code must both be either defined (when passed from compiler) or undefined + // " + // << "(when passed from runtime)"; + metadata_ = metadata; + // code_ = code; + } + + const char* type_key() const { return "metadata_module"; } + + static Module LoadFromBinary() { + return Module(make_object(runtime::metadata::Metadata())); + } + + void SaveToBinary(dmlc::Stream* stream) final {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "get_metadata") { + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + if (!metadata_.defined()) { + TVMFunctionHandle f_handle; + int32_t ret_code = TVMBackendGetFuncFromEnv(this, "get_c_metadata", &f_handle); + CHECK_EQ(ret_code, 0) << "Unable to locate get_c_metadata PackedFunc"; + + TVMValue ret_value; + int ret_type_code; + ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value, &ret_type_code); + CHECK_EQ(ret_code, 0) << "Invoking get_c_metadata: TVMFuncCall returned " << ret_code; + + CHECK_EQ(ret_type_code, kTVMOpaqueHandle) + << "Expected kOpaqueHandle returned; got " << ret_type_code; + CHECK(ret_value.v_handle != nullptr) << "get_c_metadata returned nullptr"; + + metadata_ = runtime::metadata::Metadata( + static_cast(ret_value.v_handle)); + } + + *rv = metadata_; + return; + }); + } + + return PackedFunc(); + } + + private: + runtime::metadata::Metadata metadata_; +}; + +Module MetadataModuleCreate(metadata::Metadata metadata) { + return Module(make_object(metadata)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata_module") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = MetadataModuleNode::LoadFromBinary(); }); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index d577770db1a9..4122f9d0798e 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -24,6 +24,7 @@ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ +#include #include #include @@ -205,7 +206,7 @@ class LaunchParamConfig { 
std::vector filled(6, false); for (size_t i = 0; i < launch_param_tags.size(); ++i) { const std::string& tag = launch_param_tags[i]; - if (tag == kUseDynamicSharedMemoryTag) { + if (tag == launch_param::kUseDynamicSharedMemoryTag) { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; diff --git a/src/target/build_common.h b/src/target/build_common.h index c66c2b52822e..6c94ec8703b7 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -58,7 +58,7 @@ inline std::unordered_map ExtractFuncInfo(co } if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { if (opt.value()) { - info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); + info.launch_param_tags.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 8f1d76a937bf..702aa3e38495 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -74,7 +74,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // void* resource_handle); ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( t_int_, - {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + {t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_}, false); t_tvm_crt_func_registry_ = llvm::StructType::create( @@ -795,10 +795,13 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end) { + const int64_t begin, const int64_t end, bool use_string_lookup) { PackedCall pc; std::string func_name = args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = nullptr; + if (use_string_lookup) { + handle = GetPackedFuncHandle(func_name); + } // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -813,14 +816,31 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& ConstInt32(end)); TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + llvm::FunctionType* callee_ftype = nullptr; + llvm::Value* callee_value = nullptr; + std::vector call_args; + if (use_string_lookup) { + callee_ftype = ftype_tvm_func_call_; + callee_value = RuntimeTVMFuncCall(); + call_args.push_back(handle); + } else { + callee_ftype = ftype_tvm_backend_packed_c_func_; + callee_value = module_->getFunction(func_name); + if (callee_value == nullptr) { + callee_value = llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); + } + } + call_args.insert(call_args.end(), {arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + if (!use_string_lookup) { + call_args.push_back(llvm::ConstantPointerNull::get(t_void_p_)); + } #if TVM_LLVM_VERSION >= 90 - auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); + auto call_callee = llvm::FunctionCallee(callee_ftype, callee_value); #else - auto call_callee = RuntimeTVMFuncCall(); + auto call_callee = callee_value; #endif - llvm::Value* call = builder_->CreateCall( - call_callee, - {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, 
ret_tcode.addr}); + llvm::Value* call = builder_->CreateCall(call_callee, call_args); + llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). @@ -849,17 +869,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& return pc; } -llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lookup) { + LOG(INFO) << "CreateCallPacked: " << GetRef(op); ICHECK_EQ(op->args.size(), 5U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, use_string_lookup); return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + op->args[4].as()->value, true); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. @@ -905,6 +926,251 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { return GetContextPtr(gv_tvm_parallel_barrier_); } +struct MetadataLlvmTypes { + llvm::Type* t_float64; + llvm::Type* t_uint8; + llvm::Type* t_int64; + llvm::Type* t_bool; + llvm::Type* t_cstring; + llvm::Type* t_void_p; + llvm::StructType* t_data_type; + ::std::unordered_map<::std::string, llvm::StructType*> structs; +}; + +class MetadataTypeDefiner : public AttrVisitor { +public: + MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types) : ctx_{ctx}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.emplace_back(llvm_types_->t_float64); + } + void Visit(const char* key, int64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, uint64_t* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, int* value) final { + elements_.emplace_back(llvm_types_->t_int64); + } + void Visit(const char* key, bool* value) final { + elements_.emplace_back(llvm_types_->t_bool); + } + void Visit(const char* key, std::string* value) final { + elements_.emplace_back(llvm_types_->t_cstring); + } + void Visit(const char* key, void** value) final { + elements_.emplace_back(llvm_types_->t_void_p); + } + void Visit(const char* key, DataType* value) final { + elements_.emplace_back(llvm_types_->t_data_type); + } + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + +private: + void VisitMetadataBase(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(llvm::PointerType::getUnqual(llvm::StructType::create(*ctx_, metadata->get_name()))); + if (visited_.find(metadata->get_name()) != visited_.end()) { + return; + } + + if (to_visit_.find(metadata->get_name()) != to_visit_.end()) { + return; + } + to_visit_[metadata->get_name()] = metadata; + } + +public: + void VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + for (auto o : arr->array) { + if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_float64)); + } if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64)); + } else if (o->IsInstance()) { + elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring)); + } else { + 
runtime::metadata::MetadataBase metadata = Downcast(o); + VisitMetadataBase(metadata); + } + } + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + VisitArray(arr); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + VisitMetadataBase(metadata); + } + + void DefineTypes(runtime::metadata::Metadata metadata) { + to_visit_[metadata->get_name()] = metadata; + + while (to_visit_.size() > 0) { + auto it = to_visit_.begin(); + runtime::metadata::MetadataBase node = (*it).second; + visited_.insert((*it).first); + to_visit_.erase(it); + ReflectionVTable::Global()->VisitAttrs(node.operator->(), this); + llvm_types_->structs[metadata->get_name()] = llvm::StructType::create(*ctx_, elements_, metadata->get_name()); + elements_.clear(); + } + } + + llvm::LLVMContext* ctx_; + struct MetadataLlvmTypes* llvm_types_; + ::std::unordered_set<::std::string> visited_; + ::std::unordered_map<::std::string, runtime::metadata::MetadataBase> to_visit_; + ::std::vector elements_; +}; + +class MetadataSerializer : public AttrVisitor { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; +public: + MetadataSerializer(CodeGenLLVM* codegen, struct MetadataLlvmTypes* llvm_types) : codegen_{codegen}, llvm_types_{llvm_types} {} + + void Visit(const char* key, double* value) final { + elements_.back().emplace_back(llvm::ConstantFP::get(llvm_types_->t_float64, *value)); + } + void Visit(const char* key, int64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, static_cast(*value), true /* isSigned */)); + } + void Visit(const char* key, uint64_t* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, *value, false /* isSigned */)); + } + void Visit(const char* key, int* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_int64, *value, true /* isSigned */)); + } + void Visit(const char* key, bool* value) final { + elements_.back().emplace_back(llvm::ConstantInt::get(llvm_types_->t_uint8, static_cast(*value), false /* isSigned */)); + } + void Visit(const char* key, std::string* value) final { + elements_.back().emplace_back(codegen_->GetConstString(*value)); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not support serializing void*"; + } + void Visit(const char* key, DataType* value) final { + elements_.back().emplace_back(llvm::ConstantStruct::get( + llvm_types_->t_data_type, + {llvm::ConstantInt::get(llvm_types_->t_uint8, value->code(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->bits(), false /* isSigned */), + llvm::ConstantInt::get(llvm_types_->t_uint8, value->lanes(), false /* isSigned */)})); + } + + void Visit(const char* key, runtime::NDArray* value) final { + CHECK(false) << "Do not support serializing NDArray"; + } + + llvm::Constant* VisitMetadata(runtime::metadata::MetadataBase metadata) { + elements_.emplace_back(std::vector()); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + auto struct_elements = elements_.back(); + elements_.pop_back(); + return llvm::ConstantStruct::get(llvm_types_->structs[metadata->get_name()], struct_elements); + } + + llvm::Constant* VisitArray(const runtime::metadata::MetadataArrayNode* arr) { + llvm::Type* element_type; + switch (arr->type_index) { + case MetadataTypeIndex::kInt64: + element_type = llvm_types_->t_int64; + break; + case 
MetadataTypeIndex::kUint64: + element_type = llvm_types_->t_int64; + break; + case MetadataTypeIndex::kBool: + element_type = llvm_types_->t_uint8; + break; + case MetadataTypeIndex::kString: + element_type = llvm_types_->t_cstring; + break; + case MetadataTypeIndex::kMetadata: + element_type = llvm_types_->structs[arr->struct_name]; + break; + default: + LOG(FATAL) << "unknown metadata type_index " << arr->type_index; + } + + elements_.emplace_back(std::vector()); + for (auto o : arr->array) { + if (o->IsInstance()) { + double value = Downcast(o)->value;; + Visit(nullptr, &value); + } if (o->IsInstance()) { + auto value = Downcast(o)->value; + Visit(nullptr, &value); + } else if (o->IsInstance()) { + ::std::string value = Downcast(o); + Visit(nullptr, &value); + } else { + // nested array not possible. + runtime::metadata::MetadataBase metadata = Downcast(o); + VisitMetadata(metadata); + } + } + auto array = elements_.back(); + elements_.pop_back(); + return llvm::ConstantArray::get(llvm::ArrayType::get(element_type, array.size()), array); + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + if (arr != nullptr) { + elements_.back().emplace_back(VisitArray(arr)); + return; + } + + runtime::metadata::MetadataBase metadata = Downcast(*value); + elements_.back().emplace_back(VisitMetadata(metadata)); + } + + llvm::Constant* Serialize(runtime::metadata::MetadataBase metadata) { + Visit(nullptr, &metadata); + return last_production_; + } + + CodeGenLLVM* codegen_; + MetadataLlvmTypes* llvm_types_; + llvm::LLVMContext* ctx_; + llvm::Module* module_; + std::vector> elements_; + llvm::Constant* last_production_; +}; + +void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + MetadataLlvmTypes llvm_types{ + t_float64_ /* t_float64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + t_int64_ /* t_int64 */, + llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + t_char_->getPointerTo() /* t_cstring */, + t_void_p_ /* t_void_p */, + llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + }; + + MetadataTypeDefiner definer{ctx_, &llvm_types}; + definer.DefineTypes(metadata); + + MetadataSerializer serializer{this, &llvm_types}; + llvm::Constant* metadata_constant = serializer.Serialize(metadata); + + llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + "get_c_metadata", module_.get()); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + builder_->CreateRet(builder_->CreateBitCast(metadata_constant, t_void_p_)); +} + void CodeGenCPU::DefineFunctionRegistry(Array func_names) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; Array symbols; @@ -971,9 +1237,11 @@ void CodeGenCPU::AddStartupFunction() { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - return CreateCallPacked(op); + return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + return CreateCallPacked(op, false /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); } else if 
(op->op.same_as(builtin::tvm_throw_last_error())) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 58e314ec0c6e..af56088125d5 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -56,6 +56,11 @@ class CodeGenCPU : public CodeGenLLVM { */ void DefineFunctionRegistry(Array func_names); + /*! + * \brief Serialize the metadata object as data, and implement get_c_metadata function. + * \param metadata The metadata which should be serialized. + */ + void DefineMetadata(runtime::metadata::Metadata metadata); protected: void AddStartupFunction() final; // meta data @@ -116,9 +121,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* end_block; }; PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, - const int64_t begin, const int64_t end); + const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. - llvm::Value* CreateCallPacked(const CallNode* op); + llvm::Value* CreateCallPacked(const CallNode* op, bool use_string_lookup); // Create trace call into tvm packed function. llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6c64f6798e47..510e89eee1ca 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -190,6 +190,19 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } +llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name, llvm::ConstantArray* array) { + std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name; + llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name); + if (var == nullptr) { + CHECK(array != nullptr) << "Expect param symbol " << symbol_name << " to either be defined or for the array to be supplied"; + var = new llvm::GlobalVariable( + *module_, static_cast(array->getType()), true, llvm::GlobalValue::CommonLinkage, array, symbol_name); + } + //(array != nullptr ? : static_cast(t_void_p_)), + return var; +} + + void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... @@ -235,9 +248,7 @@ void CodeGenLLVM::LinkParameters(const Map params) { // Add data to the global section. 
for (auto kv : params) { auto array = NDArrayToLLVMArray(ctx_, kv.second->param); - std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array); auto dtype = tvm::runtime::DataType(kv.second->param->dtype); size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); #if TVM_LLVM_VERSION >= 100 @@ -245,8 +256,9 @@ void CodeGenLLVM::LinkParameters(const Map params) { #else param_symbol->setAlignment(align); #endif + param_symbol->setInitializer(array); - llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); + llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); @@ -387,6 +399,10 @@ void CodeGenLLVM::Optimize() { fpass.run(*it); } fpass.doFinalization(); + std::string tmp; + llvm::raw_string_ostream stream(tmp); + module_->print(stream, nullptr); +// LOG(INFO) << "LLVM IR: " << stream.str(); mpass.run(*module_); } @@ -1259,9 +1275,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { +// LOG(INFO) << "Visit Call:" << GetRef(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { + if (op->op.same_as(builtin_lookup_param_)) { +// return llvm::ConstantInt::get(t_void_p_, 0); + return GetLinkedParamSymbol(Downcast(op->args[0])->value, nullptr); + } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); @@ -1272,7 +1292,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], op->args, false); } else { - return CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic: " << GetRef(op); + auto x = CreateIntrinsic(op); + VLOG(2) << "CreateIntrinsic done"; + return x; } } else { ICHECK(op->op.as()); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 4a9df65951c0..81828735d10b 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -141,7 +141,7 @@ class CodeGenLLVM : public ExprFunctor, * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } + llvm::Value* MakeValue(const PrimExpr& e) { auto a = VisitExpr(e); /* LOG(INFO) << "MakeValue (" << e << "): " << a; */ return a; } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); @@ -187,6 +187,9 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + // Get constant string + llvm::Constant* GetConstString(const std::string& str); + protected: /*! * \brief Address and type pair to assist in handling opaque pointers. 
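// A rough sketch of what the linked-parameter path amounts to: lookup_param("p0") no
// longer goes through a runtime lookup but lowers to the address of the linked global,
// so the emitted module behaves roughly like the following ("p0" and its bytes are
// hypothetical; the real global is a typed constant built by NDArrayToLLVMArray):
//
//   static const uint8_t __tvm_param__p0[] = {/* NDArray data bytes */};
//   void* arg = (void*)__tvm_param__p0;  // value produced for lookup_param("p0")
//
// with the symbol name formed from tvm::runtime::symbol::tvm_param_prefix plus the
// parameter name, matching GetLinkedParamSymbol above.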
@@ -288,6 +291,13 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a + * given linked_param() CallNode. + * \param param_name Parameter name (e.g. unmangled, from lookup_param node). + * \return the GlobalVariable indicated in the brief. + */ + llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name, llvm::ConstantArray* array); /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. @@ -298,8 +308,6 @@ class CodeGenLLVM : public ExprFunctor, // Get alignment given index. void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); - // Get constant string - llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); @@ -391,6 +399,8 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); + const Op& builtin_lookup_param_ = builtin::lookup_param(); + const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); /*! \brief Helper struct for debug infos. */ struct DebugInfo { diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index dc10d7885c25..5ddeb18c9ee0 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -308,14 +308,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->SetFastMathFlag(fmf); + if (found_linked_params) { + cg->LinkParameters(linked_params); + } cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } - if (found_linked_params) { - cg->LinkParameters(linked_params); - } module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, LLVMTargetToString(target))); @@ -512,6 +512,37 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + auto ctx = std::make_shared(); + std::unique_ptr cg{new CodeGenCPU()}; + + cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, false /* target_c_runtime */); + + cg->DefineMetadata(metadata); + auto mod = cg->Finish(); + mod->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx, LLVMTargetToString(target))); + mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); + + if (tm->getTargetTriple().isOSDarwin()) { + mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } + + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + + auto n = make_object(); + n->Init(std::move(mod), ctx); + return runtime::Module(n); +} + 
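// A minimal usage sketch, assuming the module built by CreateLLVMCppMetadataModule is
// exported into a hypothetical "compiled_model.so": the metadata emitted by
// DefineMetadata can be read back through the "get_metadata" PackedFunc provided by
// MetadataModuleNode::GetFunction above.
//
//   tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile("compiled_model.so");
//   tvm::runtime::PackedFunc f = lib.GetFunction("get_metadata", /*query_imports=*/true);
//   tvm::runtime::metadata::Metadata md = f();
//   LOG(INFO) << md->mod_name() << " has " << md->num_inputs() << " inputs";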
runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime) { Array func_names; diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 933030e213d2..660d81400b0d 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -33,6 +33,9 @@ namespace tvm { namespace codegen { +runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, + tvm::relay::Runtime runtime); + runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, tvm::relay::Runtime runtime); diff --git a/src/target/metadata.cc b/src/target/metadata.cc new file mode 100644 index 000000000000..adf4cba3e610 --- /dev/null +++ b/src/target/metadata.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file metadata.cc + * \brief Implementations of the compiler extensions for Metadata. + */ + +#include "metadata.h" + +#include + +namespace tvm { +namespace target { +namespace metadata { + +TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +} // namespace metadata +} // namespace target +} // namespace tvm diff --git a/src/target/metadata.h b/src/target/metadata.h new file mode 100644 index 000000000000..adab4810f3d5 --- /dev/null +++ b/src/target/metadata.h @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata.h + * \brief Extends Metadata for use in the compiler. 
+ */ +#ifndef TVM_TARGET_METADATA_H_ +#define TVM_TARGET_METADATA_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace target { +namespace metadata { + +class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { + public: + explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {} + VisitableMetadataNode() : MetadataNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + int64_t version_cpp{version()}; + v->Visit("version", &version_cpp); + auto inputs_array = Array(); + auto inputs_accessor = inputs(); + inputs_array.reserve(num_inputs()); + for (int64_t i = 0; i < num_inputs(); ++i) { + auto ti = ::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}; + inputs_array.push_back(ti); + } + ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{::std::move(inputs_array), + runtime::metadata::MetadataTypeIndex::kMetadata, + "TVMTensorInfo"}; + Downcast<::tvm::runtime::metadata::TensorInfo>(inputs_metadata_array->array[0]); + v->Visit("inputs", &inputs_metadata_array); + auto outputs_array = Array(); + auto outputs_accessor = outputs(); + outputs_array.reserve(num_outputs()); + for (int64_t i = 0; i < num_outputs(); ++i) { + outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{outputs_array, + runtime::metadata::MetadataTypeIndex::kMetadata, + "TVMTensorInfo"}; + v->Visit("outputs", &outputs_metadata_array); + auto devices_array = Array(); + auto devices_accessor = devices(); + devices_array.reserve(num_devices()); + for (int64_t i = 0; i < num_devices(); ++i) { + devices_array.push_back(::tvm::runtime::String{devices_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray devices_metadata_array{devices_array, runtime::metadata::MetadataTypeIndex::kString, "const char*"}; + v->Visit("devices", &devices_metadata_array); + ::std::string executor_cpp{data()->executor}; + v->Visit("executor", &executor_cpp); + ::std::string mod_name_cpp{data()->mod_name}; + v->Visit("mod_name", &mod_name_cpp); + ::std::string interface_api_cpp{data()->interface_api}; + v->Visit("interface_api", &interface_api_cpp); + bool use_unpacked_api_cpp{use_unpacked_api()}; + v->Visit("use_unpacked_api", &use_unpacked_api_cpp); + } +}; + +class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { + public: + InMemoryMetadataNode() + : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, {} /* devices */, + "" /* executor */, "" /* mod_name */, "" /* interface_api */, + false /* use_unpacked_api */ + ) {} + InMemoryMetadataNode(int64_t version, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, + const ::std::vector<::std::string>& devices, + const ::tvm::runtime::String executor, const ::tvm::runtime::String mod_name, + const ::tvm::runtime::String interface_api, bool use_unpacked_api) + : VisitableMetadataNode{&storage_}, + inputs_{new struct TVMTensorInfo[inputs.size()]()}, + inputs_objs_{inputs}, + outputs_{new struct TVMTensorInfo[outputs.size()]()}, + outputs_objs_{outputs}, + devices_{new const char*[devices.size()]()}, + devices_objs_{devices}, + executor_{executor}, + mod_name_{mod_name}, + interface_api_{interface_api}, + storage_{version, + nullptr, + 0, + nullptr, + 0, + nullptr, + 0, + executor_.c_str(), + mod_name_.c_str(), + interface_api_.c_str(), + use_unpacked_api} { + storage_.inputs = 
inputs_.get(); + storage_.num_inputs = inputs.size(); + for (unsigned int i = 0; i < inputs.size(); ++i) { + inputs_.get()[i] = *inputs[i]->data(); + } + storage_.outputs = outputs_.get(); + storage_.num_outputs = outputs.size(); + for (unsigned int i = 0; i < outputs.size(); ++i) { + outputs_.get()[i] = *outputs[i]->data(); + } + storage_.devices = devices_.get(); + storage_.num_devices = devices.size(); + for (unsigned int i = 0; i < devices.size(); ++i) { + devices_.get()[i] = devices_objs_[i].c_str(); + } + } + + private: + ::std::unique_ptr inputs_; + std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_; + ::std::unique_ptr outputs_; + std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; + ::std::unique_ptr devices_; + std::vector<::std::string> devices_objs_; + ::std::string executor_; + ::std::string mod_name_; + ::std::string interface_api_; + struct ::TVMMetadata storage_; +}; + +class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode { + public: + explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {} + VisitableTensorInfoNode() : TensorInfoNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + ::std::string name_cpp{data()->name}; + v->Visit("name", &name_cpp); + auto shape_array = Array(); + auto shape_accessor = shape(); + shape_array.reserve(num_shape()); + for (int64_t i = 0; i < num_shape(); ++i) { + shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); + } + ::tvm::runtime::metadata::MetadataArray shape_metadata_array{shape_array, runtime::metadata::MetadataTypeIndex::kInt64, "int64_t"}; + v->Visit("shape", &shape_metadata_array); + ::tvm::runtime::DataType dtype_cpp{dtype()}; + v->Visit("dtype", &dtype_cpp); + } +}; + +class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode { + public: + InMemoryTensorInfoNode() : InMemoryTensorInfoNode("", {}, ::tvm::runtime::DataType(0, 0, 0)) {} + InMemoryTensorInfoNode(const ::tvm::runtime::String& name, const ::std::vector& shape, + ::tvm::runtime::DataType dtype) + : VisitableTensorInfoNode{&storage_}, + name_{name}, + shape_{new int64_t[shape.size()]()}, + storage_{name_.c_str(), nullptr, 0, dtype} { + storage_.shape = shape_.get(); + storage_.num_shape = shape.size(); + for (unsigned int i = 0; i < shape.size(); ++i) { + shape_.get()[i] = shape[i]; + } + } + + private: + ::std::string name_; + ::std::unique_ptr shape_; + struct ::TVMTensorInfo storage_; +}; + +} // namespace metadata +} // namespace target +} // namespace tvm + +#endif // TVM_TARGET_METADATA_H_ diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 2b190e5d66ed..9e1a8dcf69e1 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -27,6 +27,7 @@ #include +#include "../runtime/const_loader_module.h" #include "../runtime/meta_data.h" #include "llvm/llvm_module.h" #include "source/source_module.h" @@ -34,10 +35,99 @@ namespace tvm { namespace codegen { +static runtime::Module CreateCrtMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata, Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + if (!non_crt_exportable_modules.empty()) { + std::string non_exportable_modules; + for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { + if (i > 0) { + non_exportable_modules += ", "; + } + auto mod = non_crt_exportable_modules[i]; + auto 
pf_sym = mod.GetFunction("get_symbol"); + if (pf_sym != nullptr) { + non_exportable_modules += pf_sym().operator std::string(); + } else { + non_exportable_modules += + std::string{"(module type_key="} + mod->type_key() + std::string{")"}; + } + } + CHECK(false) << "These " << non_crt_exportable_modules.size() + << " modules are not exportable to C-runtime: " << non_exportable_modules; + } + + if (target->kind->name == "c") { + crt_exportable_modules.push_back(target_module); + target_module = + CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); + } else if (target->kind->name == "llvm") { +#ifdef TVM_LLVM_VERSION + crt_exportable_modules.push_back(target_module); + target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); +#else // TVM_LLVM_VERSION + LOG(FATAL) << "TVM was not built with LLVM enabled."; +#endif // TVM_LLVM_VERSION + } + + return target_module; +} + +static runtime::Module CreateCppMetadataModule( + runtime::Module target_module, Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata, + const std::unordered_map>& const_vars_by_symbol, + Array non_crt_exportable_modules, + Array crt_exportable_modules, + const std::unordered_map& const_var_ndarray) { + if (!non_crt_exportable_modules.empty()) { + runtime::Module const_loader_mod = + runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + const_loader_mod.Import(target_module); + for (const auto& it : non_crt_exportable_modules) { + const_loader_mod.Import(it); + } + target_module = const_loader_mod; + } + + if (metadata->executor() == runtime::kTvmExecutorAot && runtime->name == relay::kTvmRuntimeCpp) { + if (target->kind->name == "c") { + auto metadata_module = CreateCSourceCppMetadataModule(metadata); + metadata_module->Import(target_module); + target_module = metadata_module; + } +#ifdef TVM_LLVM_VERSION + else if (target->kind->name == "llvm") { + auto metadata_module = CreateLLVMCppMetadataModule(metadata, target, runtime); + metadata_module->Import(target_module); + target_module = metadata_module; + } +#endif // TVM_LLVM_VERSION + else { + CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); + } + } + + return target_module; +} + +/*! + * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph executor codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param target_module the internal module that is compiled by tvm. + * \param ext_modules The external modules that needs to be imported inside the metadata + * module(s). + * \param target The target that all the modules are compiled for + * \return The created metadata module that manages initialization of metadata. + */ runtime::Module CreateMetadataModule( - const std::unordered_map& params, + const std::unordered_map& const_var_ndarray, tvm::runtime::Module target_module, const Array& ext_modules, Target target, - tvm::relay::Runtime runtime, runtime::Metadata metadata) { + tvm::relay::Runtime runtime, runtime::metadata::Metadata metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). @@ -52,19 +142,19 @@ runtime::Module CreateMetadataModule( bool is_targeting_crt = runtime->name == "crt"; // Wrap all submodules in the initialization wrapper. 
- std::unordered_map> sym_metadata; + std::unordered_map> const_vars_by_symbol; for (tvm::runtime::Module mod : ext_modules) { auto pf_sym = mod.GetFunction("get_symbol"); auto pf_var = mod.GetFunction("get_const_vars"); - std::vector arrays; + std::vector symbol_const_vars; if (pf_sym != nullptr && pf_var != nullptr) { String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { - arrays.push_back(variables[i].operator std::string()); + symbol_const_vars.push_back(variables[i].operator std::string()); } - ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; - sym_metadata[symbol] = arrays; + ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; + const_vars_by_symbol[symbol] = symbol_const_vars; } // We only need loading of serialized constant data // if there are constants present and required by the @@ -74,7 +164,7 @@ runtime::Module CreateMetadataModule( // TODO(@manupa-arm) : we should be able to use csource_metadata // if the variables are empty when all the runtime modules implement get_func_names - if (arrays.empty() && is_targeting_crt && DSOExportable(mod) && + if (symbol_const_vars.empty() && is_targeting_crt && DSOExportable(mod) && (target->kind->name == "c" || target->kind->name == "llvm")) { crt_exportable_modules.push_back(mod); } else { @@ -83,49 +173,17 @@ runtime::Module CreateMetadataModule( } if (is_targeting_crt) { - if (!non_crt_exportable_modules.empty()) { - std::string non_exportable_modules; - for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { - if (i > 0) { - non_exportable_modules += ", "; - } - auto mod = non_crt_exportable_modules[i]; - auto pf_sym = mod.GetFunction("get_symbol"); - if (pf_sym != nullptr) { - non_exportable_modules += pf_sym().operator std::string(); - } else { - non_exportable_modules += - std::string{"(module type_key="} + mod->type_key() + std::string{")"}; - } - } - CHECK(false) << "These " << non_crt_exportable_modules.size() - << " modules are not exportable to C-runtime: " << non_exportable_modules; - } - - if (target->kind->name == "c") { - crt_exportable_modules.push_back(target_module); - target_module = - CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); - } else if (target->kind->name == "llvm") { -#ifdef TVM_LLVM_VERSION - crt_exportable_modules.push_back(target_module); - target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); -#else // TVM_LLVM_VERSION - LOG(FATAL) << "TVM was not built with LLVM enabled."; -#endif // TVM_LLVM_VERSION - } + LOG(INFO) << "Create CRT metadata: " << metadata.defined(); + return CreateCrtMetadataModule(target_module, target, runtime, metadata, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } else { - if (!non_crt_exportable_modules.empty()) { - runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata); - binary_meta_mod.Import(target_module); - for (const auto& it : non_crt_exportable_modules) { - binary_meta_mod.Import(it); - } - return binary_meta_mod; - } + return CreateCppMetadataModule(target_module, target, runtime, metadata, const_vars_by_symbol, + non_crt_exportable_modules, crt_exportable_modules, + const_var_ndarray); } - return target_module; } } // namespace codegen + } // namespace tvm diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index ee6f7231b3a1..9e0a25bb2421 100644 --- a/src/target/metadata_module.h +++ 
b/src/target/metadata_module.h @@ -26,6 +26,7 @@ #define TVM_TARGET_METADATA_MODULE_H_ #include +#include #include #include #include @@ -54,7 +55,7 @@ namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, runtime::Module target_module, const Array& ext_modules, Target target, tvm::relay::Runtime runtime, - runtime::Metadata metadata); + runtime::metadata::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 515cdccb88fb..4e21775a4a20 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -23,6 +23,7 @@ #include "codegen_c_host.h" #include +#include #include #include #include @@ -51,6 +52,10 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s CodeGenC::Init(output_ssa); } +void CodeGenCHost::InitGlobalContext() { + decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; +} + void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } void CodeGenCHost::AddFunction(const PrimFunc& f) { @@ -392,6 +397,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { bool emit_asserts = false; CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, target->str()); + LOG(INFO) << "CodegenCHost: " << mod; Map linked_params; bool found_linked_params = false; @@ -437,6 +443,11 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.AddFunction(aot_executor_fn); } + relay::Runtime runtime = mod->GetAttr(tvm::attr::kRuntime).value(); + if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { + cg.InitGlobalContext(); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index c94612cfeac3..8bd83444717d 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -40,6 +40,7 @@ class CodeGenCHost : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); + void InitGlobalContext(); void AddFunction(const PrimFunc& f); void DefineModuleName(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index d938469b8969..7d55273376be 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -25,6 +25,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include #include @@ -145,6 +146,19 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const Array& func_names, const Array& const_vars = {}); +/*! + * \brief Wrap the submodules in a metadata module. + * \param params The variable to constant mapping that is collected by the host + * module. + * \param target_module The main TIR-lowered internal runtime module + * \param modules All the external modules that needs to be imported inside the metadata module(s). + * \param target The target that all the modules are compiled for + * \return The wrapped module. + */ +runtime::Module CreateMetadataModule( + const std::unordered_map& params, runtime::Module target_module, + const Array& ext_modules, Target target, runtime::metadata::Metadata metadata); + /*! 
* \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. @@ -157,6 +171,16 @@ runtime::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); +/*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. + * \param modules The modules to be wrapped. + * \param target the target the modules are compiled for. + * \param metadata the metadata needed for code generation. + * \return The wrapped module. + */ +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::metadata::Metadata metadata); + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index e01a3d93d087..2b4780899fb4 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,17 +23,22 @@ */ #include "source_module.h" -#include -#include -#include - #include +#include #include +#include #include +#include + +// TODO(areusch): idk what's up here... +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" #include "../func_registry_generator.h" +#include "../metadata.h" #include "codegen_source_base.h" namespace tvm { @@ -130,7 +135,8 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target, relay::Runtime runtime, runtime::Metadata metadata) + Target target, relay::Runtime runtime, + runtime::metadata::Metadata metadata) : fmt_(fmt), func_names_(func_names), target_(target), @@ -164,7 +170,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { Array func_names_; Target target_; relay::Runtime runtime_; - runtime::Metadata metadata_; + runtime::metadata::Metadata metadata_; void CreateFuncRegistry() { code_ << "#include \n"; @@ -200,7 +206,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + unsigned int total_args = (metadata_->num_inputs() + metadata_->num_outputs()); for (unsigned int i = 0; i < total_args; ++i) { code_ << "void* arg" << i; if (i + 1 != total_args) { @@ -212,13 +218,13 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { + for (unsigned int i = 0; i < metadata_->num_inputs(); ++i) { code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } - for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->inputs.size() + i; + for (int i = 0; i < metadata_->num_outputs(); ++i) { + int j = metadata_->num_inputs() + i; code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; - if (i + 1 != metadata_->num_outputs) { + if (i + 1 != 
metadata_->num_outputs()) { code_ << ","; } } @@ -246,7 +252,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "#include <" << mod_name << ".h>\n"; code_ << "TVM_DLL int32_t " << run_func << "("; unsigned int total_args = - (metadata_->inputs.size() + metadata_->devices.size() + metadata_->num_outputs); + (metadata_->num_inputs() + metadata_->num_devices() + metadata_->num_outputs()); for (unsigned int i = 0; i < total_args; ++i) { code_ << "void* arg" << i; if (i + 1 != total_args) { @@ -256,7 +262,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ");\n"; code_ << "int32_t " << entrypoint_name << "("; code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; - if (!metadata_->devices.empty()) { + if (metadata_->num_devices() > 0) { code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; code_ << "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices"; } else { @@ -265,27 +271,28 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ") {" << "return " << run_func << "("; - for (const auto& input : metadata_->inputs) { - std::string sanitised_input = input; + for (const auto& input : metadata_->inputs()) { + std::string sanitised_input = input->name(); std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_'); code_ << "inputs->" << sanitised_input << ","; } - if (metadata_->num_outputs == 1) { + if (metadata_->num_outputs() == 1) { code_ << "outputs->output"; } else { - for (int i = 0; i < metadata_->num_outputs; ++i) { + for (int i = 0; i < metadata_->num_outputs(); ++i) { code_ << "outputs->output" << i; - if (i + 1 != metadata_->num_outputs) { + if (i + 1 != metadata_->num_outputs()) { code_ << ","; } } } - if (!metadata_->devices.empty()) { + if (metadata_->num_devices() > 0) { code_ << ","; - for (const String& device : metadata_->devices) { + auto devices = metadata_->devices(); + for (const String& device : devices) { code_ << "devices->" << device; - if (device != metadata_->devices.back()) { + if (device != devices[devices.size() - 1]) { code_ << ","; } } @@ -299,24 +306,25 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix; const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix; const std::string run_func_mangled = - runtime::get_name_mangled(metadata_->mod_name, run_func_suffix); + runtime::get_name_mangled(metadata_->mod_name(), run_func_suffix); const std::string entrypoint_mangled = - runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); - const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); + runtime::get_name_mangled(metadata_->mod_name(), tvm_entrypoint_suffix); + const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name(), "network"); code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - if (metadata_->unpacked_api) { - if (metadata_->interface_api == "c") { - GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); + if (metadata_->use_unpacked_api()) { + LOG(INFO) << "Generate AOT Descriptor: " << metadata_->interface_api(); + if (metadata_->interface_api() == "c") { + GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name()); } else 
{ GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); + } + } else { - ICHECK_EQ(metadata_->interface_api, "packed") + ICHECK_EQ(metadata_->interface_api(), "packed") << "Packed interface required for packed operators"; + GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); } @@ -331,15 +339,460 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { CreateFuncRegistry(); GenerateCrtSystemLib(); } - if (metadata_.defined() && metadata_->executor == runtime::kTvmExecutorAot) { + if (metadata_.defined() && metadata_->executor() == runtime::kTvmExecutorAot) { + LOG(INFO) << "Metadata " << metadata_.defined() << " exec " << metadata_->executor(); GenerateAOTDescriptor(); } code_ << ";"; } }; +class CMetadataWriterVisitor : public ::tvm::AttrVisitor { + private: + std::stringstream struct_defs_; + + std::vector streams_; + std::stringstream* current_stream_; + + void Visit(const char* key, double* value) override { (*current_stream_) << *value; } + + void Visit(const char* key, int64_t* value) override { (*current_stream_) << *value << "L"; } + + void Visit(const char* key, uint64_t* value) override { (*current_stream_) << *value << "UL"; } + + void Visit(const char* key, int* value) override { (*current_stream_) << *value; } + + void Visit(const char* key, bool* value) override { + (*current_stream_) << (*value ? "true" : "false"); + } + + void Visit(const char* key, std::string* value) override { + (*current_stream_) << "\"" << *value + << "\""; // todo: ->replace('\\', "\\\\").replace('\"', "\\\"") << "\""; + } + + void Visit(const char* key, void** value) override { (*current_stream_) << *value; } + + void Visit(const char* key, DataType* value) override { + (*current_stream_) << "DLDataType{" << value->code() << ", " << value->bits() << ", " + << value->lanes() << "}"; + } + + void Visit(const char* key, runtime::NDArray* value) override { + ICHECK(false) << "at key " << key << ": cannot emit metadata of type NDArray"; + } + + void Visit(const char* key, runtime::ObjectRef* value) override { + // if (value->as< + // todo + } +}; + +class MetadataStructDefiner : public AttrVisitor { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; + + public: + void Visit(const char* key, double* value) final { + // dns: mangle name + code_ << " double " << key << ";" << std::endl; + } + + void Visit(const char* key, int64_t* value) final { + // dns: mangle name + code_ << " int64_t " << key << ";" << std::endl; + } + + void Visit(const char* key, uint64_t* value) final { + // dns: mangle name + code_ << " uint64_t " << key << ";" << std::endl; + } + void Visit(const char* key, int* value) final { + // dns: mangle name + code_ << " int " << key << ";" << std::endl; + } + void Visit(const char* key, bool* value) final { + // dns: mangle name + code_ << " uint8_t " << key << ";" << std::endl; + } + void Visit(const char* key, std::string* value) final { + // dns: mangle name + code_ << " const char* " << key << ";" << std::endl; + } + void Visit(const char* key, void** value) final { + // dns: mangle name + code_ << " void* " << key << ";" << std::endl; + } + void Visit(const char* key, DataType* value) final { + // dns: mangle name + code_ << " DLDataType " << key << ";" << std::endl; + } + + void Visit(const char* key, runtime::NDArray* value) final { + // TODO(areusch): probably we could consolidate --link-params here, tho...
+ ICHECK(false) << "do not support serializing NDArray as metadata"; + } + + void WriteComma() { + if (!is_first_item_) { + code_ << ", "; + } + } + + void VisitArray(const char* key, const runtime::metadata::MetadataArrayNode* array) { + switch (array->type_index) { + case MetadataTypeIndex::kUint64: + code_ << " uint64_t** " << key << ";" << std::endl; + break; + case MetadataTypeIndex::kInt64: + code_ << " int64_t** " << key << ";" << std::endl; + break; + case MetadataTypeIndex::kBool: + code_ << " bool** " << key << ";" << std::endl; + break; + case MetadataTypeIndex::kString: + code_ << " const char** " << key << ";" << std::endl; + break; + default: + CHECK(false) << "Field " << key << ": unknown MetadataTypeIndex: " << array->type_index; + } + } + + // const ArrayNode* arr = value->as(); + // if (arr != nullptr) { + // // dns: mangle name + + // code_ << " " << "" << key << ";" << std::endl; + // WriteComma(); + // code_ << "{"; + // if (arr->size() > 0) { + // is_first_item_ = true; + // for (ObjectRef o : *arr) { + // // todo might have to switch on object type. + // WriteComma(); + // ReflectionVTable::Global()->VisitAttrs(o.get(), this); + // } + // } + // code_ << "}"; + // return; + // } + // } + + void Visit(const char* key, ObjectRef* value) final { + auto metadata = Downcast(*value); + auto arr = metadata.as(); + if (arr != nullptr) { + VisitArray(key, arr); + return; + } + + auto old_is_first_item = is_first_item_; + is_first_item_ = true; + code_ << "{"; + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + code_ << "}"; + is_first_item_ = old_is_first_item; + } + + std::string GetOutput() { return code_.str(); } + + private: + ::std::stringstream code_; + bool is_first_item_{true}; +}; + +static std::string address_from_parts(const std::vector& parts) { + std::stringstream ss; + for (unsigned int i = 0; i < parts.size(); ++i) { + if (i > 0) { + ss << "_"; + } + ss << parts[i]; + } + return ss.str(); +} + +class MetadataQueuer : public AttrVisitor { + public: + using QueueItem = std::tuple; + explicit MetadataQueuer(std::vector* queue) : queue_{queue} {} + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, void** value) final {} + + void Visit(const char* key, ObjectRef* value) final { + address_parts_.push_back(key); + if (value->as() != nullptr) { + auto metadata = Downcast(*value); + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + std::cout << "Is array? 
" << arr << std::endl; + if (arr != nullptr) { + for (unsigned int i = 0; i < arr->array.size(); i++) { + ObjectRef o = arr->array[i]; + std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " (" + << o.operator->() << ")" << std::endl; + if (o.as() != nullptr) { + std::stringstream ss; + ss << i; + address_parts_.push_back(ss.str()); + runtime::metadata::MetadataBase metadata = Downcast(o); + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + address_parts_.pop_back(); + } + } + } else { + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + + queue_->push_back(std::make_tuple(address_from_parts(address_parts_), + Downcast(*value))); + } + address_parts_.pop_back(); + } + + private: + std::vector* queue_; + std::vector address_parts_; +}; + +std::string MetadataArrayTypeToCType(const runtime::metadata::MetadataArrayNode* array) { + using MetadataTypeIndex = runtime::metadata::MetadataTypeIndex; + + switch (array->type_index) { + case MetadataTypeIndex::kInt64: + return "int64_t"; + break; + case MetadataTypeIndex::kUint64: + return "uint64_t"; + break; + case MetadataTypeIndex::kBool: + return "int8_t"; + break; + case MetadataTypeIndex::kString: + return "const char*"; + break; + case MetadataTypeIndex::kMetadata: + return ::std::string{"struct "} + array->struct_name; + break; + default: + ICHECK(false) << "Unexpected MetadataTypeIndex " << array->type_index; + return ""; + }; +} + +class MetadataSerializer : public AttrVisitor { + public: + static constexpr const char* kGlobalSymbol = "kTvmgenMetadata"; + + MetadataSerializer() : is_first_item_{true} {} + + void WriteComma() { + if (is_first_item_) { + is_first_item_ = false; + } else { + code_ << ", " << std::endl; + } + } + + void WriteKey(const char* key) { + if (key != nullptr) { + code_ << " /* " << key << "*/"; + } + } + + void Visit(const char* key, double* value) final { + WriteComma(); + code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, + std::ios::basefield | std::ios::showbase | std::ios::floatfield); + code_ << *value; + WriteKey(key); + } + + void Visit(const char* key, int64_t* value) final { + WriteComma(); + code_ << *value << "L"; + WriteKey(key); + } + + void Visit(const char* key, uint64_t* value) final { + WriteComma(); + code_ << *value << "UL"; + WriteKey(key); + } + void Visit(const char* key, int* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, bool* value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, std::string* value) final { + WriteComma(); + code_ << "\"" << *value << "\""; + WriteKey(key); + } + void Visit(const char* key, void** value) final { + WriteComma(); + code_ << *value; + WriteKey(key); + } + void Visit(const char* key, DataType* value) final { + WriteComma(); + code_ << "DLDataType{" << value->code() << ", " << value->bits() << ", " << value->lanes() + << "}"; + WriteKey(key); + } + + void Visit(const char* key, runtime::NDArray* value) final { + // TODO(areusch): probably we could consolidate --link-params here, tho... 
+ ICHECK(false) << "do not support serializing NDArray as metadata"; + } + + void VisitArray(const runtime::metadata::MetadataArrayNode* array) { + std::cout << "visit array " << array << ": " << array->type_index << " " << array->array.size() + << std::endl; + auto old_is_first_item = is_first_item_; + is_first_item_ = true; + for (unsigned int i = 0; i < array->array.size(); ++i) { // ObjectRef o : *(array->array)) { + ObjectRef o = array->array[i]; + std::cout << "visiting array element " << i << ": " << o->type_index() << " (" + << o.operator->() << ")" << std::endl; + if (o->IsInstance()) { + int64_t i = Downcast(o); + Visit(nullptr, &i); + continue; + } + + if (o->IsInstance()) { + std::string s = Downcast(o); + Visit(nullptr, &s); + continue; + } + + runtime::metadata::MetadataBase metadata = Downcast(o); + std::cout << "visit member " << metadata->get_name() << std::endl; + std::stringstream i_str; + i_str << i; + address_.push_back(i_str.str()); + Visit(nullptr, &metadata); + address_.pop_back(); + // ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + } + is_first_item_ = old_is_first_item; + } + + void Visit(const char* key, ObjectRef* value) final { + const runtime::metadata::MetadataArrayNode* arr = + value->as(); + std::cout << "Is array? " << arr << std::endl; + if (arr != nullptr) { + WriteComma(); + if (key != nullptr) { + address_.push_back(key); + } + code_ << address_from_parts(address_) << " , " << arr->array.size() << " /* " << key + << "_size */"; + if (key != nullptr) { + address_.pop_back(); + } + // VisitArray(key, Downcast(*value).operator->()); + // WriteComma(); + // code_ << "{"; + // if (arr->size() > 0) { + // is_first_item_ = true; + // for (ObjectRef* o : *arr) { + // // todo might have to switch on object type. + // WriteComma(); + // ReflectionVTable::Global()->VisitAttrs(o.get(), this); + // } + // } + // code_ << "}"; + return; + } + + std::cout << "downcast..." 
<< std::endl; + runtime::metadata::MetadataBase metadata = Downcast(*value); + std::cout << "downcast ok: " << metadata->get_name() << std::endl; + + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.push_back(key); + } + ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); + if (key != nullptr) { // NOTE: outermost call passes nullptr key + address_.pop_back(); + } + } + + // void EnterStruct(::tvm::runtime::metadata::MetadataBase metadata) { + // const char* type_key = metadata->GetTypeKey(); + // is_defining_struct_.emplace_back( + // !generated_struct_decls_.contains(type_key)); + // if (is_defining_struct()) { + // decl_ << "struct " << get_struct_name(metadata) << "{"; + // } + // is_first_item_.emplace_back(true); + // } + + // void ExitStruct(::tvm::runtime::metadata::MetadataBase metadata) { + // decl_ << "}; // struct " << get_struct_name(metadata); + // is_first_item_.pop_back(); + // } + + void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { + decl_ << "#include " << std::endl + << "#include " << std::endl + << "#include " << std::endl; + std::vector queue; + MetadataQueuer queuer{&queue}; + queuer.Visit(kGlobalSymbol, &metadata); + + for (MetadataQueuer::QueueItem item : queue) { + auto struct_name = std::get<0>(item); + auto obj = std::get<1>(item); + auto arr = obj.as(); + std::cout << "codegen: " << struct_name; + is_first_item_ = true; + address_.push_back(struct_name); + if (arr != nullptr) { + std::string c_type{"const "}; + if (arr->type_index == runtime::metadata::MetadataTypeIndex::kString) { + // note drop const + c_type = MetadataArrayTypeToCType(arr); + } else { + c_type += MetadataArrayTypeToCType(arr); + } + code_ << c_type << "[" << arr->array.size() + << "] = {" << std::endl; + VisitArray(arr); + } else { + code_ << "const struct TVMMetadata " << struct_name << " = {" << std::endl; + Visit(nullptr, &obj); + } + address_.pop_back(); + code_ << "};" << std::endl; + } + } + + std::string GetOutput() { return decl_.str() + code_.str(); } + + private: + std::vector address_; + std::stringstream decl_; + std::stringstream code_; + bool is_first_item_; + std::unordered_set generated_struct_decls_; + std::vector is_defining_struct_; +}; + runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata) { + relay::Runtime runtime, + runtime::metadata::Metadata metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -358,6 +811,34 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod return std::move(csrc_metadata_module); } +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) { + // MetadataStructDefiner definer; + // ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), &definer); + MetadataSerializer serializer; + serializer.CodegenMetadata(metadata); + std::stringstream lookup_func; + lookup_func << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; + + lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int* arg_tcodes, int " + "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" + << std::endl; + lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol + << ";" << std::endl; + lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; + lookup_func << " return 0;" << std::endl; + lookup_func << "};" << std::endl; + + auto mod = 
MetadataModuleCreate(metadata); + std::vector func_names{"get_c_metadata"}; + // definer.GetOutput() + + auto c = CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names, + Array()); + mod->Import(c); + return mod; +} + // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: @@ -423,7 +904,8 @@ TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target target, relay::Runtime runtime) { // Note that we don't need metadata when we compile a single operator - return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::Metadata()); + return CreateCSourceCrtMetadataModule(modules, target, runtime, + runtime::metadata::Metadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index fde363c1198a..9028f6dc410a 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -26,24 +26,32 @@ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #include +#include #include #include -#include "../../runtime/meta_data.h" namespace tvm { namespace codegen { /*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. * \param modules The modules to be wrapped. * \param target the target the modules are compiled for. * \param runtime the runtime to code generate against - * \param metadata the metadata needed for code generation. + * \param metadata Compiler-generated metadata exported to runtime. * \return The wrapped module. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata); + relay::Runtime runtime, + runtime::metadata::Metadata metadata); + +/*! + * \brief Create C++-runtime targeted metadata module for "c" backend. + * \param metadata Compiler-generated metadata. + */ +runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index a5ecf4ba8296..f8dca1a4f7c6 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -328,8 +328,8 @@ class BuiltinLower : public StmtExprMutator { // cpacked call resource_handle if (!use_string_lookup) { - tir::Var resource_handle = Downcast(op->args[arg_count]); - packed_args.push_back(StringImm(resource_handle->name_hint)); +// tir::Var resource_handle = Downcast(op->args[arg_count]); +// packed_args.push_back(StringImm(resource_handle->name_hint)); } auto builtin_call = use_string_lookup ? 
builtin::tvm_call_packed_lowered() diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index d5a4c91a3c43..ff3641cd6982 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -107,18 +107,22 @@ TEST(BuildModule, Heterogeneous) { auto elemwise_sub = compute( C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); - With cuda_scope(target_cuda); - auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + auto fcreate_s1 = [=]() { + With cuda_scope(target_cuda); + return topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + }; - With llvm_scope(target_llvm); - auto s2 = create_schedule({elemwise_sub->op}); + auto fcreate_s2 = [=]() { + With llvm_scope(target_llvm); + return create_schedule({elemwise_sub->op}); + }; auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); - auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc new file mode 100644 index 000000000000..a4afdc40e2e5 --- /dev/null +++ b/tests/cpp/test_metadata.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
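To make the serializer above concrete before the C++ tests that follow: the C source that MetadataSerializer::CodegenMetadata plus CreateCSourceCppMetadataModule are expected to emit looks roughly like the sketch below. Symbol names are derived from kGlobalSymbol and the visited field path; the include set, shapes, dtypes, and device list shown here are illustrative assumptions only.

#include <inttypes.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/metadata.h>

/* Leaf arrays first (they are queued depth-first by MetadataQueuer)... */
const int64_t kTvmgenMetadata_inputs_0_shape[4] = {1, 5, 5, 3};
const struct TVMTensorInfo kTvmgenMetadata_inputs[1] = {
    {"input1", kTvmgenMetadata_inputs_0_shape, 4, {kDLFloat, 32, 1}}};
const int64_t kTvmgenMetadata_outputs_0_shape[3] = {3, 8, 8};
const struct TVMTensorInfo kTvmgenMetadata_outputs[1] = {
    {"output1", kTvmgenMetadata_outputs_0_shape, 3, {kDLFloat, 32, 1}}};
const char* kTvmgenMetadata_devices[1] = {"cpu"};

/* ...then the top-level struct that ties them together. */
const struct TVMMetadata kTvmgenMetadata = {
    TVM_METADATA_VERSION,
    kTvmgenMetadata_inputs,  1,
    kTvmgenMetadata_outputs, 1,
    kTvmgenMetadata_devices, 1,
    "aot",     /* executor */
    "default", /* mod_name */
    "c",       /* interface_api */
    1,         /* use_unpacked_api */
};

/* Lookup entry point registered by CreateCSourceCppMetadataModule. */
TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int* arg_tcodes, int num_args,
                               TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {
  ret_values[0].v_handle = (void*)&kTvmgenMetadata;
  ret_tcodes[0] = kTVMOpaqueHandle;
  return 0;
}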
+ */ + +#include +#include +#include +#include +#include + +namespace { +const int64_t kNormalInput1Shape[4] = {1, 5, 5, 3}; +const struct TVMTensorInfo kNormalInputs[1] = {{"input1", kNormalInput1Shape, 4, DLDataType{1, 2, 3}}}; + +const int64_t kNormalOutput1Shape[3] = {3, 8, 8}; +const struct TVMTensorInfo kNormalOutputs[1] = {{"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}}; + +const char* kNormalDevices[2] = {"device1", "device2"}; + +const struct TVMMetadata kNormal = { + TVM_METADATA_VERSION, + kNormalInputs, + 1, + kNormalOutputs, + 1, + kNormalDevices, + 2, + "aot", + "default", + "c", + true, + }; +} + +using ::tvm::runtime::Downcast; +using ::testing::Eq; +using ::testing::ElementsAre; + +TEST(Metadata, ParseStruct) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); + EXPECT_THAT(md->num_inputs(), Eq(1)); + + auto input1 = md->inputs()[0]; + EXPECT_THAT(input1->name(), Eq("input1")); + EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input1->dtype(), Eq(tvm::runtime::DataType(DLDataType{1, 2, 3}))); + + EXPECT_THAT(md->num_outputs(), Eq(1)); + auto output1 = md->outputs()[0]; + EXPECT_THAT(output1->name(), Eq("output1")); + EXPECT_THAT(::std::vector(output1->shape()), ElementsAre(3, 8, 8)); + EXPECT_THAT(output1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5}))); + + auto devices = md->devices(); + EXPECT_THAT(devices, ElementsAre(::tvm::runtime::String("device1"), + ::tvm::runtime::String("device2"))); + + EXPECT_THAT(md->executor(), Eq("aot")); + EXPECT_THAT(md->mod_name(), Eq("default")); + EXPECT_THAT(md->interface_api(), Eq("c")); + EXPECT_THAT(md->use_unpacked_api(), Eq(true)); +} + +class TestVisitor : public tvm::AttrVisitor { + public: + using Element = ::std::tuple<::std::string, ::tvm::runtime::ObjectRef>; + void Visit(const char* key, double* value) final { + keys.push_back(key); + values.push_back(::tvm::FloatImm(::tvm::runtime::DataType(kDLFloat, 64, 1), *value)); + } + void Visit(const char* key, int64_t* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, uint64_t* value) final { + keys.push_back(key); + int64_t v; + *(reinterpret_cast(&v)) = *value; + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLUInt, 64, 1), v)); + } + void Visit(const char* key, int* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, bool* value) final { + keys.push_back(key); + values.push_back(::tvm::Bool(*value)); + } + void Visit(const char* key, std::string* value) final { + keys.push_back(key); + values.push_back(::tvm::runtime::String(*value)); + } + void Visit(const char* key, tvm::runtime::DataType* value) final { + keys.push_back(key); + values.push_back(::tvm::PrimType(*value)); + } + void Visit(const char* key, tvm::runtime::NDArray* value) final { + keys.push_back(key); + values.push_back(*value); + } + void Visit(const char* key, void** value) final { + CHECK(false) << "Do not expect this type"; + } + + void Visit(const char* key, ::tvm::runtime::ObjectRef* value) final { + keys.push_back(key); + values.push_back(*value); + } + + std::vector keys; + std::vector<::tvm::runtime::ObjectRef> values; +}; + +TEST(Metadata, Visitor) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + TestVisitor v; + 
::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); + + EXPECT_THAT(v.keys, ElementsAre( + Eq("version"), + Eq("inputs"), + Eq("outputs"), + Eq("devices"), + Eq("executor"), + Eq("mod_name"), + Eq("interface_api"), + Eq("use_unpacked_api") + )); + + EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); + + // Just identify the tensor. + auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(input_array->struct_name, Eq(std::string("TVMTensorInfo"))); + EXPECT_THAT(input_array->array.size(), Eq(1)); + auto array0 = input_array->array[0]; + + auto input1 = Downcast(array0); + EXPECT_THAT(input1->name(), Eq("input1")); + + auto output_array = Downcast(v.values[2]); + EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(output_array->struct_name, Eq("TVMTensorInfo")); + auto output1 = Downcast(output_array->array[0]); + + EXPECT_THAT(output1->name(), Eq("output1")); + + auto devices = Downcast(v.values[3]); + EXPECT_THAT(devices->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kString)); + EXPECT_THAT(Downcast(devices->array[0]), Eq("device1")); + EXPECT_THAT(Downcast(devices->array[1]), Eq("device2")); + +// EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); +} diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index d335528914b0..7dd65ca20ef4 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -206,6 +206,7 @@ def parametrize_aot_options(test): interface_api = ["packed", "c"] use_unpacked_api = [True, False] test_runner = [AOT_DEFAULT_RUNNER, AOT_CORSTONE300_RUNNER] + print("TEST RUNNERS", test_runner) all_combinations = itertools.product(interface_api, use_unpacked_api, test_runner) diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py new file mode 100644 index 000000000000..1a290caade7e --- /dev/null +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
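The Python test that follows drives the exported library through tvm.runtime.executor.AotModule. A rough C++ equivalent is sketched below, assuming the factory module keeps the graph-executor-style "default" / "set_input" / "run" / "get_output" convention; those names and the exact calling pattern are assumptions, not verified API.

#include <string>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

void RunAotFromCpp(const std::string& so_path, tvm::runtime::NDArray data) {
  tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(so_path);
  // "default" is the module name used by relay.build in the test below.
  tvm::runtime::Module aot_mod = lib.GetFunction("default")(DLDevice{kDLCPU, 0});
  aot_mod.GetFunction("set_input")("data", data);
  aot_mod.GetFunction("run")();
  tvm::runtime::NDArray output = aot_mod.GetFunction("get_output")(0);
  (void)output;  // compare against reference data here
}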
+ + +import sys +import textwrap + +import numpy as np +import pytest + +import tvm +from tvm import relay, TVMError +from tvm.ir.module import IRModule +from tvm.relay import backend, testing, transform +from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end +from aot_test_utils import ( + AOTTestModel, + AOT_DEFAULT_RUNNER, + generate_ref_data, + convert_to_relay, + compile_and_run, + compile_models, + parametrize_aot_options, +) + + +def print_mod_tree(m, indent=0): + print(f"{' ' * indent} - {m!r}") + for i in m.imported_modules: + print_mod_tree(i, indent + 2) + + +unpacked_api = tvm.testing.parameter(True, False) + + +target_kind = tvm.testing.parameter("c", "llvm") + + +def test_conv2d(target_kind, unpacked_api): + RELAY_MODEL = textwrap.dedent( + """\ + #[version = "0.0.5"] + def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8"); + %3 = nn.conv2d( + %2, + %weight, + padding=[2, 2], + channels=3, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %4 = nn.max_pool2d(%3, pool_size=[3, 3]); + %4 + } + """ + ) + ir_mod = tvm.parser.fromtext(RELAY_MODEL) + + main_func = ir_mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + + params = {"weight": weight_data} + inputs = {"data": input_data} + output_list = generate_ref_data(ir_mod, inputs, params) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + ir_mod, + params=params, + target=target_kind, + executor=backend.Executor("aot", {"unpacked-api": unpacked_api, "interface-api": "packed"}), + ) + + print_mod_tree(mod.module) + + with tvm.contrib.utils.TempDirectory.set_keep_for_debug(): + mod.export_library("test.so", options=["-fpermissive"]) + mod.export_library("test.tar") + runner = tvm.runtime.load_module("test.so") + print_mod_tree(runner) + runner = tvm.runtime.executor.AotModule(runner["default"](tvm.cpu(0))) + runner.set_input(**inputs) + runner.run() + assert (runner.get_output(0).asnumpy() == output_list[0]).all() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 8a2b1f1bb84d..fdcded355065 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -47,7 +47,7 @@ def test_error_c_interface_with_packed_api(): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) - with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"): + with pytest.raises(tvm.TVMError, match='When interface_api == "c", need unpacked-api == true'): compile_and_run( AOTTestModel( module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) diff --git a/tests/python/unittest/test_metadata.py b/tests/python/unittest/test_metadata.py new file mode 100644 index 000000000000..98e912cbfad4 --- /dev/null +++ b/tests/python/unittest/test_metadata.py @@ -0,0 +1,20 
@@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + + +def test_