diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw index 57db5a718c74..87ce9acfae55 160000 --- a/3rdparty/vta-hw +++ b/3rdparty/vta-hw @@ -1 +1 @@ -Subproject commit 57db5a718c74a788c98120ebbe1230797be698c8 +Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c diff --git a/CMakeLists.txt b/CMakeLists.txt index 56170c693e3c..39035bba152e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) +tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime cuGraph launch mode" OFF) tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_RTTI "Build with RTTI" ON) @@ -321,6 +322,15 @@ if(USE_GRAPH_RUNTIME) set_source_files_properties(${RUNTIME_GRAPH_SRCS} PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG") endif(USE_GRAPH_RUNTIME_DEBUG) + + if(USE_GRAPH_RUNTIME_CUGRAPH) + message(STATUS "Build with Graph runtime cuGraph support...") + file(GLOB RUNTIME_CUGRAPH_SRCS src/runtime/graph/cugraph/*.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_CUGRAPH_SRCS}) + set_source_files_properties(${RUNTIME_GRAPH_SRCS} + PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_CUGRAPH") + endif(USE_GRAPH_RUNTIME_CUGRAPH) + endif(USE_GRAPH_RUNTIME) if(USE_VM_PROFILER) diff --git a/Jenkinsfile b/Jenkinsfile index bba3950aea87..506dcab4e306 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -50,7 +50,7 @@ ci_cpu = "tlcpack/ci-cpu:v0.72-t0" ci_wasm = "tlcpack/ci-wasm:v0.70" ci_i386 = "tlcpack/ci-i386:v0.72-t0" ci_qemu = "tlcpack/ci-qemu:v0.01" -ci_arm = "tlcpack/ci-arm:v0.01" +ci_arm = "tlcpack/ci-arm:v0.02" // <--- End of regex-scanned config. 
// tvm libraries diff --git a/apps/android_camera/models/prepare_model.py b/apps/android_camera/models/prepare_model.py index ab20e028c2ad..f155d46c31a4 100644 --- a/apps/android_camera/models/prepare_model.py +++ b/apps/android_camera/models/prepare_model.py @@ -106,7 +106,7 @@ def main(model_str, output_path): f.write(graph) print("dumping params...") with open(output_path_str + "/" + "deploy_param.params", "wb") as f: - f.write(relay.save_param_dict(params)) + f.write(runtime.save_param_dict(params)) print("dumping labels...") synset_url = "".join( [ diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 0991ac9ad94b..8fbc01bcf4a6 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -20,7 +20,7 @@ import os from tvm import relay import tvm -from tvm import te +from tvm import te, runtime import logging import json from tvm.contrib import cc as _cc @@ -70,7 +70,7 @@ def build_module(opts): with open( os.path.join(build_dir, file_format_str.format(name="params", ext="bin")), "wb" ) as f_params: - f_params.write(relay.save_param_dict(params)) + f_params.write(runtime.save_param_dict(params)) def build_test_module(opts): @@ -113,7 +113,7 @@ def build_test_module(opts): with open( os.path.join(build_dir, file_format_str.format(name="test_params", ext="bin")), "wb" ) as f_params: - f_params.write(relay.save_param_dict(lowered_params)) + f_params.write(runtime.save_param_dict(lowered_params)) with open( os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb" ) as fp: diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 3224028b60a1..2f7e3848b4bf 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -23,6 +23,7 @@ #include #include "../../src/runtime/c_runtime_api.cc" +#include "../../src/runtime/container.cc" #include "../../src/runtime/cpu_device_api.cc" #include "../../src/runtime/file_utils.cc" #include "../../src/runtime/graph/graph_runtime.cc" diff --git a/apps/sgx/src/build_model.py b/apps/sgx/src/build_model.py index 868d3bcb9fc4..1fc297d8a094 100755 --- a/apps/sgx/src/build_model.py +++ b/apps/sgx/src/build_model.py @@ -23,7 +23,7 @@ from os import path as osp import sys -from tvm import relay +from tvm import relay, runtime from tvm.relay import testing import tvm from tvm import te @@ -49,7 +49,7 @@ def main(): with open(osp.join(build_dir, "graph.json"), "w") as f_graph_json: f_graph_json.write(graph) with open(osp.join(build_dir, "params.bin"), "wb") as f_params: - f_params.write(relay.save_param_dict(params)) + f_params.write(runtime.save_param_dict(params)) if __name__ == "__main__": diff --git a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py index 42695d28fadb..3d8a349b8744 100644 --- a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py +++ b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py @@ -24,7 +24,7 @@ import onnx import tvm -from tvm import relay +from tvm import relay, runtime def _get_mod_and_params(model_file): @@ -60,7 +60,7 @@ def build_graph_lib(model_file, opt_level): f_graph.write(graph_json) with open(os.path.join(out_dir, "graph.params"), "wb") as f_params: - f_params.write(relay.save_param_dict(params)) + f_params.write(runtime.save_param_dict(params)) if __name__ == "__main__": diff --git a/cmake/config.cmake b/cmake/config.cmake index 872feb918a4f..257e62bd9b7d 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -102,6 
+102,9 @@ set(USE_GRAPH_RUNTIME ON) # Whether enable additional graph debug functions set(USE_GRAPH_RUNTIME_DEBUG OFF) +# Whether enable tiny graph runtime for cudaGraph Launch +set(USE_GRAPH_RUNTIME_CUGRAPH OFF) + # Whether enable additional vm profiler functions set(USE_VM_PROFILER OFF) diff --git a/cmake/utils/FindEthosN.cmake b/cmake/utils/FindEthosN.cmake index d33b55f0c7a9..26d00a462b39 100644 --- a/cmake/utils/FindEthosN.cmake +++ b/cmake/utils/FindEthosN.cmake @@ -59,6 +59,7 @@ macro(find_ethosn use_ethosn) find_library(ETHOSN_COMPILER_LIBRARY NAMES EthosNSupport) set(ETHOSN_PACKAGE_VERSION "0.1.1") + set(ETHOSN_DEFINITIONS -DETHOSN_API_VERSION=${USE_ETHOSN_API_VERSION}) if(${USE_ETHOSN_HW} MATCHES ${IS_TRUE_PATTERN}) # Runtime hardware support @@ -70,7 +71,7 @@ macro(find_ethosn use_ethosn) find_library(ETHOSN_RUNTIME_LIBRARY NAMES EthosNDriver PATHS ${__ethosn_stack}/lib) find_library(ETHOSN_RUNTIME_LIBRARY NAMES EthosNDriver) - set(ETHOSN_DEFINITIONS -DETHOSN_HW) + set(ETHOSN_DEFINITIONS -DETHOSN_HW -DETHOSN_API_VERSION=${USE_ETHOSN_API_VERSION}) endif () if(ETHOSN_COMPILER_LIBRARY) diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 020792700ee9..671ce04e8c1d 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -16,7 +16,7 @@ # under the License. # CI docker arm env -# tag: v0.10 +# tag: v0.02 FROM ubuntu:18.04 diff --git a/docker/install/ubuntu_install_python.sh b/docker/install/ubuntu_install_python.sh index 58d72f327aa6..d3af336491cc 100755 --- a/docker/install/ubuntu_install_python.sh +++ b/docker/install/ubuntu_install_python.sh @@ -34,7 +34,7 @@ apt-get install -y python-pip python-dev python3.6 python3.6-dev rm -f /usr/bin/python3 && ln -s /usr/bin/python3.6 /usr/bin/python3 # Install pip -cd /tmp && wget -q https://bootstrap.pypa.io/get-pip.py && python2 get-pip.py && python3.6 get-pip.py +cd /tmp && wget -q https://bootstrap.pypa.io/get-pip.py && python3.6 get-pip.py # Pin pip version pip3 install pip==19.3.1 diff --git a/docker/install/ubuntu_install_vitis_ai_packages_ci.sh b/docker/install/ubuntu_install_vitis_ai_packages_ci.sh index c34ed3addce2..774d85dcf68a 100644 --- a/docker/install/ubuntu_install_vitis_ai_packages_ci.sh +++ b/docker/install/ubuntu_install_vitis_ai_packages_ci.sh @@ -23,7 +23,7 @@ set -o pipefail export PYXIR_HOME=/opt/pyxir mkdir "$PYXIR_HOME" -pip3 install progressbar +pip3 install progressbar h5py==2.10.0 -git clone --recursive --branch v0.1.3 https://github.com/Xilinx/pyxir.git "${PYXIR_HOME}" +git clone --recursive --branch v0.1.6 --depth 1 https://github.com/Xilinx/pyxir.git "${PYXIR_HOME}" cd "${PYXIR_HOME}" && python3 setup.py install diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 8c8fcfb49679..256978d00607 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -31,7 +31,7 @@ The code below will save the compilation output which is required on android tar with open("deploy_graph.json", "w") as fo: fo.write(graph.json()) with open("deploy_param.params", "wb") as fo: - fo.write(relay.save_param_dict(params)) + fo.write(runtime.save_param_dict(params)) deploy_lib.so, deploy_graph.json, deploy_param.params will go to android target. diff --git a/golang/sample/gen_mobilenet_lib.py b/golang/sample/gen_mobilenet_lib.py index b82e0c476b9f..12f215b4fd9c 100644 --- a/golang/sample/gen_mobilenet_lib.py +++ b/golang/sample/gen_mobilenet_lib.py @@ -16,7 +16,7 @@ # under the License. 
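# For reference, a minimal hedged sketch of the relocated API (the patch below
# only switches the import; `params`, a dict of str -> NDArray, is assumed to
# come from relay.build as in these scripts):
#
#   from tvm import runtime
#   blob = runtime.save_param_dict(params)        # serialize to bytes
#   params_back = runtime.load_param_dict(blob)   # matching deserializer
#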
import os -from tvm import relay, transform +from tvm import relay, transform, runtime from tvm.contrib.download import download_testdata @@ -94,4 +94,4 @@ def extract(path): fo.write(graph) with open("./mobilenet.params", "wb") as fo: - fo.write(relay.save_param_dict(params)) + fo.write(runtime.save_param_dict(params)) diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 12b91cc033e5..f8e63ed5857a 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index 301d95636ca4..3f1096b10a8b 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -25,7 +25,7 @@ #define TVM_ARITH_PATTERN_H_ #include -#include +#include #include namespace tvm { diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index ec40611d49b4..c82ed076eca7 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 9e7d3aa2cd32..14bf55abb447 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -26,6 +26,7 @@ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #include +#include #include namespace tvm { @@ -120,6 +121,8 @@ class SearchTaskNode : public Object { HardwareParams hardware_params; /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; + /*! \brief Names of some user-defined input data used in program measuring. */ + Array<String> task_input_names; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -128,6 +131,7 @@ class SearchTaskNode : public Object { v->Visit("target_host", &target_host); v->Visit("hardware_params", &hardware_params); v->Visit("layout_rewrite_option", &layout_rewrite_option); + v->Visit("task_input_names", &task_input_names); } static constexpr const char* _type_key = "auto_scheduler.SearchTask"; @@ -148,9 +152,11 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. * \param layout_rewrite_option The layout rewrite option used for measuring programs. + * \param task_input_names Names of some user-defined input data used in program measuring.
*/ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option); + Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option, + Array<String> task_input_names); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 466a4f00fd5f..231c04e69821 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -29,8 +29,8 @@ #include #include -#include #include +#include #include #include diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 5302a55bfff3..2295baa0297b 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -26,8 +26,8 @@ #include #include -#include #include +#include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index d6fb6a20b58a..07d582a298e4 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 6557bbe31b8e..50c6f8dd8c3a 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -59,7 +59,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 19b1ad0a0d83..b93a41e0c098 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -50,8 +50,8 @@ #define TVM_IR_TYPE_H_ #include -#include #include +#include #include #include diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 552aa7114657..6acd2e7dbdd8 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -23,7 +23,7 @@ #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ #define TVM_NODE_ATTR_REGISTRY_MAP_H_ -#include +#include #include #include diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h deleted file mode 100644 index 10b47a92bdcf..000000000000 --- a/include/tvm/node/container.h +++ /dev/null @@ -1,1486 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/container.h - * \brief Array/Map container in the DSL graph.
- */ -#ifndef TVM_NODE_CONTAINER_H_ -#define TVM_NODE_CONTAINER_H_ - -#ifndef USE_FALLBACK_STL_MAP -#define USE_FALLBACK_STL_MAP 0 -#endif - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { - -using runtime::Array; -using runtime::ArrayNode; -using runtime::Downcast; -using runtime::IterAdapter; -using runtime::make_object; -using runtime::Object; -using runtime::ObjectEqual; -using runtime::ObjectHash; -using runtime::ObjectPtr; -using runtime::ObjectPtrEqual; -using runtime::ObjectPtrHash; -using runtime::ObjectRef; -using runtime::String; -using runtime::StringObj; - -#if (USE_FALLBACK_STL_MAP != 0) - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of the actual underlying container */ - using ContainerType = std::unordered_map; - /*! \brief Iterator class */ - using iterator = ContainerType::iterator; - /*! \brief Iterator class */ - using const_iterator = ContainerType::const_iterator; - /*! \brief Type of value stored in the hash map */ - using KVType = ContainerType::value_type; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return data_.size(); } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return data_.count(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return data_.at(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return data_.at(key); } - /*! \return begin iterator */ - iterator begin() { return data_.begin(); } - /*! \return const begin iterator */ - const_iterator begin() const { return data_.begin(); } - /*! \return end iterator */ - iterator end() { return data_.end(); } - /*! \return end iterator */ - const_iterator end() const { return data_.end(); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - const_iterator find(const key_type& key) const { return data_.find(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) { return data_.find(key); } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { data_.erase(position); } - /*! 
- * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { data_.erase(key); } - /*! - * \brief Create an empty container - * \return The object created - */ - static ObjectPtr Empty() { return make_object(); } - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static ObjectPtr CreateFromRange(IterType first, IterType last) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(first, last); - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - MapNode* map_node = static_cast(map->get()); - map_node->data_[kv.first] = kv.second; - } - /*! - * \brief Create an empty container with elements copying from another MapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(MapNode* from) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(from->data_.begin(), from->data_.end()); - return p; - } - /*! \brief The real container storing data */ - ContainerType data_; - template - friend class Map; -}; - -#else - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! 
- * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; - /*! \brief Default constructor */ - iterator() : index(0), self(nullptr) {} - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { return *((*this).operator->()); } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - protected: - /*! \brief Construct by value */ - iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} - /*! \brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapNode* self; - - friend class DenseMapNode; - friend class SmallMapNode; - }; - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapNode* from); - /*! \brief number of slots minus 1 */ - uint64_t slots_; - /*! \brief number of entries in the container */ - uint64_t size_; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapNode : public MapNode, - public runtime::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapNode::iterator; - using MapNode::KVType; - - /*! \brief Defaults to the destructor of InplaceArrayBase */ - ~SmallMapNode() = default; - /*! - * \brief Count the number of times a key exists in the SmallMapNode - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(AddressOf(0)); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (ObjectEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Remove a position in SmallMapNode - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(AddressOf(0)); - KVType* last = begin + (size_ - 1); - if (index + 1 == size_) { - last->first.ObjectRef::~ObjectRef(); - last->second.ObjectRef::~ObjectRef(); - } else { - *(begin + index) = std::move(*last); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::runtime::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->size_ = 0; - p->slots_ = n; - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->AddressOf(0)); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapNode* from) { - KVType* first = static_cast(from->AddressOf(0)); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! 
- * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) { - SmallMapNode* map_node = static_cast<SmallMapNode*>(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->slots_) { - KVType* ptr = static_cast<KVType*>(map_node->AddressOf(map_node->size_)); - new (ptr) KVType(kv); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - ICHECK_GT(next_size, map_node->slots_); - ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast<KVType*>(AddressOf(index)); } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapNode; - friend class DenseMapNode; - friend class runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType>; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. Overview - * - * DenseMapNode makes several improvements over a traditional separate-chaining hash table, - * in terms of cache locality, memory footprint and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using an explicit linked list - * for each bucket, we store list data in a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte of overhead per slot in the array for indexing and - * traversal. This can be divided into 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * the head of a linked list. - * 3) The remaining 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architectures, an ordinary pointer takes up to 8 bytes, which is unacceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on the common observation that lists in hash maps are - * relatively short (length <= 3), we follow [1]'s idea of only allowing the pointer to - * be one of 126 possible values, i.e. if the next element of the i-th slot is the (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array so that every 16 elements form a data block. - * The 16-byte metadata of those 16 elements is stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use a power-of-two table size to avoid - * modulo operations and allow more efficient arithmetic. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i; if the metadata at slot i - * indicates that it is a list head, then we have found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, it is the end of the - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if the - * element is already in it, and if not, we put it at the end by probing the next empty - * position among the 126 candidate positions. If the linked list does not even exist, but the - * slot for its head has been occupied by another linked list, we must relocate this intruder to another - * slot. - * - * B4. Quadratic probing with triangle numbers. In open-address hashing, it is provable that probing - * with triangle numbers can traverse a power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] to also use triangle numbers for the "next pointer", as well as for sparing slots for list - * heads. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapNode : public MapNode { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout<Block>::value, "Block is not standard layout"); - - public: - using MapNode::iterator; - - /*! - * \brief Destroy the DenseMapNode - */ - ~DenseMapNode() { this->Reset(); } - /*! \return The number of entries associated with the key, 0 or 1 */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*!
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->slots_) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { - if (slots_ == 0) { - return iterator(0, this); - } - for (uint64_t index = 0; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return iterator(index, this); - } - } - return iterator(slots_ + 1, this); - } - /*! \return end iterator */ - iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } - - private: - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (ObjectEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(ObjectHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? 
false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (ObjectEqual()(key, next.Key())) { - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(KVType(key, ObjectRef(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - empty.NewTail(std::move(r.Data())); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - iter.Data().KVType::~KVType(); - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - iter.Data() = std::move(last.Data()); - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! 
\brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->slots_); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->KVType::~KVType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - delete[] data_; - data_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots - 1); - Block* block = p->data_ = new Block[n_blocks]; - p->slots_ = n_slots - 1; - p->size_ = 0; - p->fib_shift_ = fib_shift; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapNode* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->slots_); - p->data_ = new Block[n_blocks]; - p->slots_ = from->slots_; - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->data_[bi].bytes; - KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); - uint8_t* meta_ptr_to = p->data_[bi].bytes; - KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) KVType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - DenseMapNode* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = kv.second; - return; - } - ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); - // Insert the given `kv` into the new hash map - InsertMaybeReHash(kv, &p); - uint64_t n_blocks = CalcNumBlocks(map_node->slots_); - // Then Insert data from the original block. 
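// Hedged note on the rehash loop below: each occupied slot is marked
// kEmptySlot, its KV pair is moved out into a local and re-inserted into the
// larger map `p` (which may itself grow again). ReleaseMemory() afterwards
// frees the old blocks without running destructors, which is safe because the
// moved-from ObjectRefs no longer hold references.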
- for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = map_node->data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - KVType kv = std::move(*data_ptr); - InsertMaybeReHash(kv, &p); - } - } - } - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - for (++index; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - while (index != 0) { - index -= 1; - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { - uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; - return (n_slots + kBlockCap - 1) / kBlockCap; - } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapNode* self) - : index(index), block(self->data_ + (index / kBlockCap)) {} - /*! 
\brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - KVType& Data() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(KVType))); - } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(KVType v) const { - Meta() = 0b00000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(KVType v) const { - Meta() = 0b10000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self, uint8_t meta) { - uint64_t offset = kNextProbeLocation[meta & 0b01111111]; - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - index = (index + offset) & (self->slots_); - block = self->data_ + (index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapNode* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(ObjectHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief array of data blocks */ - Block* data_; - /* clang-format off */ - /*! \brief Candidates of probing distance */ - TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. 
See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, - 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - friend class MapNode; -}; - -#define TVM_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapNode*; \ - using TDense = DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapNode*; \ - using TDense = const DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapNode::iterator::pointer MapNode::iterator::operator->() const { - TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapNode::iterator& MapNode::iterator::operator++() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapNode::iterator& MapNode::iterator::operator--() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapNode::count(const key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { - TVM_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapNode::iterator MapNode::begin() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapNode::iterator MapNode::end() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapNode::erase(const MapNode::iterator& position) { - TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); -} - -#undef TVM_DISPATCH_MAP -#undef TVM_DISPATCH_MAP_CONST - -inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } - -inline ObjectPtr MapNode::CopyFrom(MapNode* from) { - if (from->slots_ <= SmallMapNode::kMaxSize) { - return SmallMapNode::CopyFrom(static_cast(from)); - } 
else { - return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from)); - } -} - -template <typename IterType> -inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapNode::Empty(); - } - uint64_t cap = static_cast<uint64_t>(_cap); - if (cap < SmallMapNode::kMaxSize) { - return SmallMapNode::CreateFromRange(cap, first, last); - } - uint32_t fib_shift; - uint64_t n_slots; - DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapNode::InsertMaybeReHash(kv, &obj); - } - return obj; -} - -inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) { - constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; - MapNode* base = static_cast<MapNode*>(map->get()); - if (base->slots_ < kSmallMapMaxSize) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else if (base->slots_ == kSmallMapMaxSize) { - if (base->size_ < base->slots_) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else { - ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end()); - DenseMapNode::InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - } else { - DenseMapNode::InsertMaybeReHash(kv, map); - } -} - -namespace runtime { -template <> -inline ObjectPtr<MapNode> make_object<>() = delete; -} // namespace runtime - -#endif - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but a copy will happen when the map is referenced in more than one place. - * - * operator[] only provides const access; use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template <typename K, typename V, typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type, - typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type> -class Map : public ObjectRef { - public: - using key_type = K; - using mapped_type = V; - class iterator; - /*! - * \brief default constructor - */ - Map() { data_ = MapNode::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) { data_ = std::move(other.data_); } - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template <typename IterType> - Map(IterType begin, IterType end) { - data_ = MapNode::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Map(std::initializer_list<std::pair<K, V>> init) { - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template <typename Hash, typename Equal> - Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*) - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map.
- * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : GetMapNode()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapNode()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapNode()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - MapNode* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapNode::Empty(); - } else if (!data_.unique()) { - data_ = MapNode::CopyFrom(GetMapNode()); - } - return GetMapNode(); - } - /*! \brief specify container node */ - using ContainerType = MapNode; - - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - private: - iterator(const MapNode::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapNode::iterator itr; - }; - - private: - /*! \brief Return data_ as type of pointer of MapNode */ - MapNode* GetMapNode() const { return static_cast(data_.get()); } -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. 
- */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -} // namespace tvm - -namespace tvm { -namespace runtime { -// Additional overloads for PackedFunc checking. -template -struct ObjectTypeChecker> { - static Optional CheckAndGetMismatch(const Object* ptr) { - if (ptr == nullptr) return NullOpt; - if (!ptr->IsInstance()) return String(ptr->GetTypeKey()); - const MapNode* n = static_cast(ptr); - for (const auto& kv : *n) { - Optional key_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - Optional value_type = ObjectTypeChecker::CheckAndGetMismatch(kv.first.get()); - if (key_type.defined() || value_type.defined()) { - std::string key_name = - key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker::TypeName(); - std::string value_name = value_type.defined() ? std::string(value_type.value()) - : ObjectTypeChecker::TypeName(); - return String("Map[" + key_name + ", " + value_name + "]"); - } - } - return NullOpt; - } - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - const MapNode* n = static_cast(ptr); - for (const auto& kv : *n) { - if (!ObjectTypeChecker::Check(kv.first.get())) return false; - if (!ObjectTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static std::string TypeName() { - return "Map[" + ObjectTypeChecker::TypeName() + ", " + ObjectTypeChecker::TypeName() + - ']'; - } -}; -} // namespace runtime -} // namespace tvm -#endif // TVM_NODE_CONTAINER_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 59295c2ce427..7b2a9f8061b4 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -34,7 +34,6 @@ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ -#include #include #include #include diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 9424f6dc30f2..d5309bca894d 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -23,8 +23,8 @@ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ -#include #include +#include #include #include diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index ed89d841cd65..a661a852780d 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -23,8 +23,8 @@ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ #define TVM_NODE_STRUCTURAL_HASH_H_ -#include #include +#include #include #include diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index d53658f87f40..e6eec61a7e9d 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -88,7 +88,8 @@ class ExprFunctor { * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { - ICHECK(n.defined()); + ICHECK(n.defined()) << "Found null pointer node while traversing AST. 
The previous pass may " + "have generated invalid data."; static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 7df881938f50..4a5de33af4b9 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -25,8 +25,8 @@ #define TVM_RELAY_FEATURE_H_ #include -#include #include +#include #include #include diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 796ab7b113c1..336fef21ab88 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -24,7 +24,12 @@ #ifndef TVM_RUNTIME_CONTAINER_H_ #define TVM_RUNTIME_CONTAINER_H_ +#ifndef USE_FALLBACK_STL_MAP +#define USE_FALLBACK_STL_MAP 0 +#endif + #include +#include #include #include @@ -34,6 +39,7 @@ #include #include #include +#include // We use c++14 std::experimental::string_view for optimizing hash computation // only right now, its usage is limited in this file. Any broader usage of // std::experiment in our core codebase is discouraged and needs community @@ -1688,11 +1694,1413 @@ class Closure : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); }; +#if (USE_FALLBACK_STL_MAP != 0) + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of the actual underlying container */ + using ContainerType = std::unordered_map; + /*! \brief Iterator class */ + using iterator = ContainerType::iterator; + /*! \brief Iterator class */ + using const_iterator = ContainerType::const_iterator; + /*! \brief Type of value stored in the hash map */ + using KVType = ContainerType::value_type; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return data_.size(); } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return data_.count(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return data_.at(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return data_.at(key); } + /*! \return begin iterator */ + iterator begin() { return data_.begin(); } + /*! \return const begin iterator */ + const_iterator begin() const { return data_.begin(); } + /*! \return end iterator */ + iterator end() { return data_.end(); } + /*! \return end iterator */ + const_iterator end() const { return data_.end(); } + /*! 
+ * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + const_iterator find(const key_type& key) const { return data_.find(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) { return data_.find(key); } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { data_.erase(position); } + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { data_.erase(key); } + /*! + * \brief Create an empty container + * \return The object created + */ + static ObjectPtr Empty() { return make_object(); } + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static ObjectPtr CreateFromRange(IterType first, IterType last) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(first, last); + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + MapNode* map_node = static_cast(map->get()); + map_node->data_[kv.first] = kv.second; + } + /*! + * \brief Create an empty container with elements copying from another MapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(MapNode* from) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(from->data_.begin(), from->data_.end()); + return p; + } + /*! \brief The real container storing data */ + ContainerType data_; + template + friend class Map; +}; + +#else + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; + /*! \brief Default constructor */ + iterator() : index(0), self(nullptr) {} + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { return *((*this).operator->()); } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + protected: + /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapNode* self; + + friend class DenseMapNode; + friend class SmallMapNode; + }; + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapNode* from); + /*! \brief number of slots minus 1 */ + uint64_t slots_; + /*! \brief number of entries in the container */ + uint64_t size_; + // Reference class + template + friend class Map; +}; + +/*! 
\brief A specialization of small-sized hash map */ +class SmallMapNode : public MapNode, + public runtime::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapNode::iterator; + using MapNode::KVType; + + /*! \brief Defaults to the destructor of InplaceArrayBase */ + ~SmallMapNode() = default; + /*! + * \brief Count the number of times a key exists in the SmallMapNode + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(AddressOf(0)); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (ObjectEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Remove a position in SmallMapNode + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(AddressOf(0)); + KVType* last = begin + (size_ - 1); + if (index + 1 == size_) { + last->first.ObjectRef::~ObjectRef(); + last->second.ObjectRef::~ObjectRef(); + } else { + *(begin + index) = std::move(*last); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::runtime::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->size_ = 0; + p->slots_ = n; + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->AddressOf(0)); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; + } + /*! 
+ * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapNode* from) { + KVType* first = static_cast(from->AddressOf(0)); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + SmallMapNode* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->slots_) { + KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); + new (ptr) KVType(kv); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + next_size = std::min(next_size, uint64_t(kMaxSize)); + ICHECK_GT(next_size, map_node->slots_); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapNode; + friend class DenseMapNode; + friend class runtime::InplaceArrayBase; +}; + +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. Overview + * + * DenseMapNode did several improvements over traditional separate chaining hash, + * in terms of cache locality, memory footprints and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using linked list + * explicitly for each bucket, we store list data into a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and + * traversal. This can be divided in 3 parts. + * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is + * head of a linked list. + * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 16-byte ObjectRef pairs. 
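+ * (As a concrete reading of this encoding, derived from IsHead/HasNext below: a
+ * metadata byte of (0b00000101)_2 marks a list head whose successor sits
+ * kNextProbeLocation[5] = 5 slots ahead, while (0b10000000)_2 marks a non-head
+ * element that ends its list, since jump index 0 means "no next element".)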
Based on a commonly noticed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to + * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. + * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid + * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2. Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapNode : public MapNode { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout::value, "Block is not standard layout"); + + public: + using MapNode::iterator; + + /*! + * \brief Destroy the DenseMapNode + */ + ~DenseMapNode() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->slots_) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { + if (slots_ == 0) { + return iterator(0, this); + } + for (uint64_t index = 0; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return iterator(index, this); + } + } + return iterator(slots_ + 1, this); + } + /*! \return end iterator */ + iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } + + private: + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (ObjectEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; + return iter.Val(); + } + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(ObjectHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? 
false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (ObjectEqual()(key, next.Key())) { + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(KVType(key, ObjectRef(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list. + * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + empty.NewTail(std::move(r.Data())); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done moving the linked list + // fill data_ into `target` + target.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + iter.Data().KVType::~KVType(); + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + iter.Data() = std::move(last.Data()); + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! 
\brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->slots_); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + data_ptr->KVType::~KVType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + delete[] data_; + data_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots - 1); + Block* block = p->data_ = new Block[n_blocks]; + p->slots_ = n_slots - 1; + p->size_ = 0; + p->fib_shift_ = fib_shift; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another DenseMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapNode* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->slots_); + p->data_ = new Block[n_blocks]; + p->slots_ = from->slots_; + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->data_[bi].bytes; + KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); + uint8_t* meta_ptr_to = p->data_[bi].bytes; + KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + ICHECK(meta != kProtectedSlot); + if (meta != uint8_t(kEmptySlot)) { + new (data_ptr_to) KVType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + DenseMapNode* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = kv.second; + return; + } + ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); + // Insert the given `kv` into the new hash map + InsertMaybeReHash(kv, &p); + uint64_t n_blocks = CalcNumBlocks(map_node->slots_); + // Then Insert data from the original block. 
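+    // (Note on the loop below: it walks every 16-slot block of the old table; each
+    // slot whose metadata marks live data (neither kEmptySlot nor kProtectedSlot)
+    // is reset to empty and its key-value pair is moved into the freshly
+    // allocated, twice-as-large map `p`.)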
+ for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = map_node->data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + KVType kv = std::move(*data_ptr); + InsertMaybeReHash(kv, &p); + } + } + } + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + for (++index; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + while (index != 0) { + index -= 1; + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { + uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; + return (n_slots + kBlockCap - 1) / kBlockCap; + } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. + * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapNode* self) + : index(index), block(self->data_ + (index / kBlockCap)) {} + /*! 
\brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + KVType& Data() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(KVType))); + } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } + /*! \brief Set the entry to be protected */ + void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(KVType v) const { + Meta() = 0b00000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(KVType v) const { + Meta() = 0b10000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self, uint8_t meta) { + uint64_t offset = kNextProbeLocation[meta & 0b01111111]; + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + index = (index + offset) & (self->slots_); + block = self->data_ + (index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapNode* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(ObjectHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief array of data blocks */ + Block* data_; + /* clang-format off */ + /*! \brief Candidates of probing distance */ + TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. 
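+      // (The n-th triangle number is n * (n + 1) / 2, so the entries below continue
+      // the sequence at T(6) = 21, T(7) = 28, T(8) = 36, and so on; as noted in B4
+      // above, probing a power-of-2-sized table with these offsets is guaranteed
+      // to visit every slot.)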
See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, + 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + friend class MapNode; +}; + +#define TVM_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapNode*; \ + using TDense = DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +#define TVM_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapNode*; \ + using TDense = const DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapNode::iterator& MapNode::iterator::operator++() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapNode::iterator& MapNode::iterator::operator--() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); +} + +inline size_t MapNode::count(const key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { + TVM_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapNode::iterator MapNode::begin() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapNode::iterator MapNode::end() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapNode::erase(const MapNode::iterator& position) { + TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); +} + +#undef TVM_DISPATCH_MAP +#undef TVM_DISPATCH_MAP_CONST + +inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } + +inline ObjectPtr MapNode::CopyFrom(MapNode* from) { + if (from->slots_ <= SmallMapNode::kMaxSize) { + return SmallMapNode::CopyFrom(static_cast(from)); + } 
else {
+    return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from));
+  }
+}
+
+template <typename IterType>
+inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) {
+  int64_t _cap = std::distance(first, last);
+  if (_cap < 0) {
+    return SmallMapNode::Empty();
+  }
+  uint64_t cap = static_cast<uint64_t>(_cap);
+  if (cap < SmallMapNode::kMaxSize) {
+    return SmallMapNode::CreateFromRange(cap, first, last);
+  }
+  uint32_t fib_shift;
+  uint64_t n_slots;
+  DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots);
+  ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots);
+  for (; first != last; ++first) {
+    KVType kv(*first);
+    DenseMapNode::InsertMaybeReHash(kv, &obj);
+  }
+  return obj;
+}
+
+inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
+  constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
+  MapNode* base = static_cast<MapNode*>(map->get());
+  if (base->slots_ < kSmallMapMaxSize) {
+    SmallMapNode::InsertMaybeReHash(kv, map);
+  } else if (base->slots_ == kSmallMapMaxSize) {
+    if (base->size_ < base->slots_) {
+      SmallMapNode::InsertMaybeReHash(kv, map);
+    } else {
+      ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end());
+      DenseMapNode::InsertMaybeReHash(kv, &new_map);
+      *map = std::move(new_map);
+    }
+  } else {
+    DenseMapNode::InsertMaybeReHash(kv, map);
+  }
+}
+
+template <>
+inline ObjectPtr<MapNode> make_object<>() = delete;
+
+#endif
+
+/*!
+ * \brief Map container of NodeRef->NodeRef in DSL graph.
+ *  Map implements copy-on-write semantics: the map is mutable,
+ *  but a copy is made when the map is referenced in more than one place.
+ *
+ * operator[] only provides const access; use Set to mutate the content.
+ * \tparam K The key NodeRef type.
+ * \tparam V The value NodeRef type.
+ */
+template <typename K, typename V,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+class Map : public ObjectRef {
+ public:
+  using key_type = K;
+  using mapped_type = V;
+  class iterator;
+  /*!
+   * \brief default constructor
+   */
+  Map() { data_ = MapNode::Empty(); }
+  /*!
+   * \brief move constructor
+   * \param other source
+   */
+  Map(Map<K, V>&& other) { data_ = std::move(other.data_); }
+  /*!
+   * \brief copy constructor
+   * \param other source
+   */
+  Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
+  /*!
+   * \brief move assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Map& operator=(Map<K, V>&& other) {
+    data_ = std::move(other.data_);
+    return *this;
+  }
+  /*!
+   * \brief copy assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Map& operator=(const Map<K, V>& other) {
+    data_ = other.data_;
+    return *this;
+  }
+  /*!
+   * \brief constructor from pointer
+   * \param n the container pointer
+   */
+  explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief constructor from iterator
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template <typename IterType>
+  Map(IterType begin, IterType end) {
+    data_ = MapNode::CreateFromRange(begin, end);
+  }
+  /*!
+   * \brief constructor from initializer list
+   * \param init The initializer list
+   */
+  Map(std::initializer_list<std::pair<K, V>> init) {
+    data_ = MapNode::CreateFromRange(init.begin(), init.end());
+  }
+  /*!
+   * \brief constructor from unordered_map
+   * \param init The unordered_map
+   */
+  template <typename Hash, typename Equal>
+  Map(const std::unordered_map<K, V, Hash, Equal>& init) {  // NOLINT(*)
+    data_ = MapNode::CreateFromRange(init.begin(), init.end());
+  }
+  /*!
+   * \brief Read element from map.
+   * \param key The key
+   * \return the corresponding element.
+   */
+  const V at(const K& key) const { return DowncastNoCheck<V>(GetMapNode()->at(key)); }
+  /*!
+   * \brief Read element from map.
+   * \param key The key
+   * \return the corresponding element.
+   */
+  const V operator[](const K& key) const { return this->at(key); }
+  /*! \return The number of elements in the map */
+  size_t size() const {
+    MapNode* n = GetMapNode();
+    return n == nullptr ? 0 : n->size();
+  }
+  /*! \return The number of elements with the given key, 0 or 1 */
+  size_t count(const K& key) const {
+    MapNode* n = GetMapNode();
+    return n == nullptr ? 0 : GetMapNode()->count(key);
+  }
+  /*! \return whether the map is empty */
+  bool empty() const { return size() == 0; }
+  /*!
+   * \brief set the Map.
+   * \param key The index key.
+   * \param value The value to be set.
+   */
+  void Set(const K& key, const V& value) {
+    CopyOnWrite();
+    MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_);
+  }
+  /*! \return begin iterator */
+  iterator begin() const { return iterator(GetMapNode()->begin()); }
+  /*! \return end iterator */
+  iterator end() const { return iterator(GetMapNode()->end()); }
+  /*! \return find the key and return the associated iterator */
+  iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
+
+  void erase(const K& key) { CopyOnWrite()->erase(key); }
+
+  /*!
+   * \brief copy-on-write semantics
+   *  Do nothing if the current handle is the unique copy of the map.
+   *  Otherwise make a new copy of the map to ensure the current handle
+   *  holds a unique copy.
+   *
+   * \return Handle to the internal node container (which is guaranteed to be unique)
+   */
+  MapNode* CopyOnWrite() {
+    if (data_.get() == nullptr) {
+      data_ = MapNode::Empty();
+    } else if (!data_.unique()) {
+      data_ = MapNode::CopyFrom(GetMapNode());
+    }
+    return GetMapNode();
+  }
+  /*! \brief specify container node */
+  using ContainerType = MapNode;
+
+  /*! \brief Iterator of the hash map */
+  class iterator {
+   public:
+    using iterator_category = std::bidirectional_iterator_tag;
+    using difference_type = int64_t;
+    using value_type = const std::pair<K, V>;
+    using pointer = value_type*;
+    using reference = value_type;
+
+    iterator() : itr() {}
+
+    /*! \brief Compare iterators */
+    bool operator==(const iterator& other) const { return itr == other.itr; }
+    /*! \brief Compare iterators */
+    bool operator!=(const iterator& other) const { return itr != other.itr; }
+    /*! \brief De-referencing iterators through operator-> is not allowed */
+    pointer operator->() const = delete;
+    /*! \brief De-reference iterators */
+    reference operator*() const {
+      auto& kv = *itr;
+      return std::make_pair(DowncastNoCheck<K>(kv.first), DowncastNoCheck<V>(kv.second));
+    }
+    /*! \brief Prefix self increment, e.g. ++iter */
+    iterator& operator++() {
+      ++itr;
+      return *this;
+    }
+    /*! \brief Suffix self increment */
+    iterator operator++(int) {
+      iterator copy = *this;
+      ++(*this);
+      return copy;
+    }
+
+   private:
+    iterator(const MapNode::iterator& itr)  // NOLINT(*)
+        : itr(itr) {}
+
+    template <typename, typename, typename, typename>
+    friend class Map;
+
+    MapNode::iterator itr;
+  };
+
+ private:
+  /*! \brief Return data_ as type of pointer of MapNode */
+  MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
+};
+
+/*!
+ * \brief Merge two Maps.
+ * \param lhs the first Map to merge.
+ * \param rhs the second Map to merge.
+ * \return The merged Map. Original Maps are kept unchanged.
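+ *
+ * A minimal usage sketch (key and value strings are hypothetical; note that on
+ * duplicate keys the entries of `rhs` overwrite those of `lhs`):
+ * \code{.cpp}
+ *   Map<String, String> a{{"k1", "v1"}, {"k2", "v2"}};
+ *   Map<String, String> b{{"k2", "v2x"}, {"k3", "v3"}};
+ *   Map<String, String> c = Merge(a, b);  // c: k1->v1, k2->v2x, k3->v3
+ * \endcode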
+ */
+template <typename K, typename V,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
+  for (const auto& p : rhs) {
+    lhs.Set(p.first, p.second);
+  }
+  return std::move(lhs);
+}
+
 }  // namespace runtime

 // expose the functions to the root namespace.
+using runtime::Array;
+using runtime::ArrayNode;
+using runtime::Downcast;
+using runtime::IterAdapter;
+using runtime::make_object;
+using runtime::Map;
+using runtime::MapNode;
+using runtime::Object;
+using runtime::ObjectEqual;
+using runtime::ObjectHash;
+using runtime::ObjectPtr;
+using runtime::ObjectPtrEqual;
+using runtime::ObjectPtrHash;
+using runtime::ObjectRef;
 using runtime::Optional;
 using runtime::String;
+using runtime::StringObj;
 constexpr runtime::NullOptType NullOpt{};
 }  // namespace tvm
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index d705be6c4deb..7d914ce6bff9 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -160,12 +160,19 @@ class DataType {
   */
  static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
  /*!
-  * \brief Construct an uint type.
+  * \brief Construct a float type.
   * \param bits The number of bits in the type.
   * \param lanes The number of lanes
   * \return The constructed data type.
   */
  static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
+ /*!
+  * \brief Construct a bfloat type.
+  * \param bits The number of bits in the type.
+  * \param lanes The number of lanes
+  * \return The constructed data type.
+  */
+ static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
  /*!
  * \brief Construct a bool type.
  * \param lanes The number of lanes
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index cf30923aacb0..751a435c734a 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -450,6 +450,40 @@ struct ObjectTypeChecker<Array<T>> {
  }
  static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]"; }
};
+template <typename K, typename V>
+struct ObjectTypeChecker<Map<K, V>> {
+  static Optional<String> CheckAndGetMismatch(const Object* ptr) {
+    if (ptr == nullptr) return NullOpt;
+    if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());
+    const MapNode* n = static_cast<const MapNode*>(ptr);
+    for (const auto& kv : *n) {
+      Optional<String> key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
+      Optional<String> value_type = ObjectTypeChecker<V>::CheckAndGetMismatch(kv.second.get());
+      if (key_type.defined() || value_type.defined()) {
+        std::string key_name =
+            key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
+        std::string value_name = value_type.defined() ? std::string(value_type.value())
+                                                      : ObjectTypeChecker<V>::TypeName();
+        return String("Map[" + key_name + ", " + value_name + "]");
+      }
+    }
+    return NullOpt;
+  }
+  static bool Check(const Object* ptr) {
+    if (ptr == nullptr) return true;
+    if (!ptr->IsInstance<MapNode>()) return false;
+    const MapNode* n = static_cast<const MapNode*>(ptr);
+    for (const auto& kv : *n) {
+      if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
+      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
+    }
+    return true;
+  }
+  static std::string TypeName() {
+    return "Map[" + ObjectTypeChecker<K>::TypeName() + ", " + ObjectTypeChecker<V>::TypeName() +
+           ']';
+  }
+};

 /*!
* \brief Internal base class to diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h new file mode 100644 index 000000000000..45b60ea18acc --- /dev/null +++ b/include/tvm/runtime/profiling.h @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/runtime/profiling.h + * \brief Runtime profiling including timers. + */ +#ifndef TVM_RUNTIME_PROFILING_H_ +#define TVM_RUNTIME_PROFILING_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { + +/*! \brief Base class for all implementations. + * + * New implementations of this interface should make sure that `Start` and `Stop` + * are as lightweight as possible. Expensive state synchronization should be + * done in `SyncAndGetElapsedNanos`. + */ +class TimerNode : public Object { + public: + /*! \brief Start the timer. + * + * Note: this function should only be called once per object. + */ + virtual void Start() = 0; + /*! \brief Stop the timer. + * + * Note: this function should only be called once per object. + */ + virtual void Stop() = 0; + /*! \brief Synchronize timer state and return elapsed time between `Start` and `Stop`. + * \return The time in nanoseconds between `Start` and `Stop`. + * + * This function is necessary because we want to avoid timing the overhead of + * doing timing. When using multiple timers, it is recommended to stop all of + * them before calling `SyncAndGetElapsedNanos` on any of them. + * + * Note: this function should be only called once per object. It may incur + * a large synchronization overhead (for example, with GPUs). + */ + virtual int64_t SyncAndGetElapsedNanos() = 0; + + virtual ~TimerNode() {} + + static constexpr const char* _type_key = "TimerNode"; + TVM_DECLARE_BASE_OBJECT_INFO(TimerNode, Object); +}; + +/*! \brief Timer for a specific device. + * + * This is a managed reference to a TimerNode. + * + * \sa TimerNode + */ +class Timer : public ObjectRef { + public: + /*! + * \brief Get a device specific timer. + * \param ctx The device context to time. + * \return A `Timer` that has already been started. + * + * Use this function to time runtime of arbitrary regions of code on a specific + * device. The code that you want to time should be running on the device + * otherwise the timer will not return correct results. This is a lower level + * interface than TimeEvaluator and only runs the timed code once + * (TimeEvaluator runs the code multiple times). + * + * A default timer is used if a device specific one does not exist. This + * timer performs synchronization between the device and CPU, which can lead + * to overhead in the reported results. 
+ *
+ * Example usage:
+ * \code{.cpp}
+ * Timer t = Timer::Start(TVMContext{kDLCPU, 0});
+ * my_long_running_function();
+ * t->Stop();
+ * ... // some more computation
+ * int64_t nanosecs = t->SyncAndGetElapsedNanos();  // elapsed time in nanoseconds
+ * \endcode
+ *
+ * To add a new device-specific timer, register a new function
+ * "profiling.timer.my_device" (where `my_device` is the `DeviceName` of your
+ * device). This function should accept a `TVMContext` and return a new `Timer`
+ * that has already been started.
+ *
+ * For example, this is how the CPU timer is implemented:
+ * \code{.cpp}
+ * class CPUTimerNode : public TimerNode {
+ *  public:
+ *   virtual void Start() { start_ = std::chrono::high_resolution_clock::now(); }
+ *   virtual void Stop() { duration_ = std::chrono::high_resolution_clock::now() - start_; }
+ *   virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); }
+ *   virtual ~CPUTimerNode() {}
+ *
+ *   static constexpr const char* _type_key = "CPUTimerNode";
+ *   TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode);
+ *
+ *  private:
+ *   std::chrono::high_resolution_clock::time_point start_;
+ *   std::chrono::duration<int64_t, std::nano> duration_;
+ * };
+ * TVM_REGISTER_OBJECT_TYPE(CPUTimerNode);
+ *
+ * TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](TVMContext ctx) {
+ *   return Timer(make_object<CPUTimerNode>());
+ * });
+ * \endcode
+ */
+  static TVM_DLL Timer Start(TVMContext ctx);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Timer, ObjectRef, TimerNode);
+};
+
+/*!
+ * \brief Default timer if one does not exist for the context.
+ * \param ctx The context to time on.
+ *
+ * Note that this timer performs synchronization between the device and CPU,
+ * which can lead to overhead in the reported results.
+ */
+Timer DefaultTimer(TVMContext ctx);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_PROFILING_H_
diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h
index a4efa7a94990..32e74f6ef9d5 100644
--- a/include/tvm/te/schedule_pass.h
+++ b/include/tvm/te/schedule_pass.h
@@ -41,6 +41,13 @@ namespace te {
  */
 void AutoInlineElemWise(Schedule sch);

+/*!
+ * \brief To automatically inline the broadcast operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineBroarcast(Schedule sch);
+
 /*!
  * \brief To automatically inline operations with injective writes
  * (i.e. writes without reduction or sequential loops).
Note diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 2f9fa2f534c5..401ba102c2f4 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -25,7 +25,7 @@ #define TVM_TE_TENSOR_H_ #include -#include +#include #include #include diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 839e7c1b7c1c..83f228da9475 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -25,7 +25,7 @@ #define TVM_TIR_BUFFER_H_ #include -#include +#include #include #include diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index c7ff9e19014c..7cab1970f478 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -26,10 +26,10 @@ #define TVM_TIR_EXPR_H_ #include -#include #include #include #include +#include #include #include #include diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index ceebbbb305ce..d6303ae266e1 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -26,8 +26,8 @@ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ -#include #include +#include #include #include #include diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 261fdf9970a3..3ad230560f3a 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1134,6 +1134,9 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; ICHECK_EQ(ndim_d, ndim_i); + if (axis < 0) { + axis += ndim_d; + } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 06ca44d997e5..ff6d82a0242c 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -41,6 +41,7 @@ LocalRunner, RPCRunner, LocalRPCMeasureContext, + register_task_input_check_func, ) from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records from .relay_integration import ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 47ffde4327c4..959a9c5da82a 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -36,6 +36,7 @@ import shutil import tempfile import multiprocessing +import logging import tvm._ffi from tvm.runtime import Object, module, ndarray @@ -50,6 +51,7 @@ call_func_with_timeout, check_remote, get_const_tuple, + get_func_name, make_traceback_info, request_remote, ) @@ -58,6 +60,8 @@ deserialize_workload_registry_entry, ) +# pylint: disable=invalid-name +logger = logging.getLogger("auto_scheduler") # The time cost for measurements with errors # We use 1e10 instead of sys.float_info.max for better readability in log @@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False): target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, + task_inputs=list(task.task_input_names), ) if rebuild_state: @@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo return results +TASK_INPUT_CHECK_FUNC_REGISTRY = {} + + +def register_task_input_check_func(func_name, f=None, override=False): + """Register a function that checks the input buffer map. 
+
+    The input function should take a list of Tensors which indicates the
+    input/output Tensors of a TVM subgraph and return a Map from the input
+    Tensor to its buffer name.
+
+    Parameters
+    ----------
+    func_name : Union[Function, str]
+        The check function to be registered, or its function name.
+    f : Optional[Function]
+        The check function to be registered.
+    override : boolean = False
+        Whether to override an existing entry.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        @auto_scheduler.register_task_input_check_func
+        def check_task_input_by_placeholder_name(args: List[Tensor]):
+            tensor_input_map = {}
+            for arg in args:
+                if isinstance(arg.op, tvm.te.PlaceholderOp):
+                    if arg.op.name != "placeholder":
+                        tensor_input_map[arg] = arg.op.name
+            return tensor_input_map
+    """
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    if callable(func_name):
+        f = func_name
+        func_name = get_func_name(f)
+    if not isinstance(func_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        """internal register function"""
+        if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override:
+            raise RuntimeError("%s has been registered already" % func_name)
+        TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf
+        return myf
+
+    if f:
+        return register(f)
+    return register
+
+
+def _prepare_input_map(args):
+    """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
+    to a specific buffer name in the global buffer map.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensor of a TVM subgraph.
+
+    Returns
+    -------
+    Dict[Tensor, str] :
+        Map from the input Tensor to its buffer name.
+
+    Notes
+    -----
+    The buffer name is specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
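+
+    Examples
+    --------
+    A sketch of the expected result; the buffer names below are illustrative,
+    real names are generated by the registered check functions (e.g. the
+    sparse check in `tvm.topi.nn.sparse`):
+
+    .. code-block:: python
+
+        # args = [dense_data, w_data, w_indices, w_indptr]
+        # _prepare_input_map(args)
+        # -> {w_data: "sparse_dense_bsr_512_512_128_16_1_0.60_W_data",
+        #     w_indices: "sparse_dense_bsr_512_512_128_16_1_0.60_W_indices",
+        #     w_indptr: "sparse_dense_bsr_512_512_128_16_1_0.60_W_indptr"}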
+ """ + # pylint: disable=import-outside-toplevel + + global TASK_INPUT_CHECK_FUNC_REGISTRY + + # A dict that maps the input tensor arg to a buffer name + tensor_input_map = {} + + # Case 0: Check placeholder name + for arg in args: + if isinstance(arg.op, tvm.te.PlaceholderOp): + if arg.op.name != "placeholder": + tensor_input_map[arg] = arg.op.name + + # Case 1: Check specific tensor inputs + for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY: + func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] + tensor_input_map.update(func(args)) + + return tensor_input_map + + def _timed_eval_func( inp_serialized, build_res, @@ -729,7 +825,11 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + inp = MeasureInput.deserialize(inp_serialized) + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -758,11 +858,31 @@ def _timed_eval_func( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - for arg in args: - random_fill(arg) + + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" + ) ctx.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -911,7 +1031,11 @@ def _timed_rpc_run( enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + inp = MeasureInput.deserialize(inp_serialized) + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -943,18 +1067,36 @@ def _timed_rpc_run( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" + random_fill = remote.get_function("tvm.contrib.random.random_fill") + assert ( + random_fill + ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" + + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., 
task_inputs={...})`" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" ) - for arg in args: - random_fill(arg) ctx.sync() - costs = time_f(*args).results + # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index b39aba227a88..68f53125c7ae 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -283,10 +283,13 @@ def auto_schedule_topi(outs): key = register_workload_tensors(dag.workload_key(), io_tensors) target = tvm.target.Target.current() + dispatch_ctx = DispatchContext.current + state = dispatch_ctx.query(target, key, has_complex_op, dag) + schedule = None + env = TracingEnvironment.current if env is None: # in the final build mode - state = DispatchContext.current.query(target, key, has_complex_op, dag) if state is None: return None @@ -303,8 +306,6 @@ def auto_schedule_topi(outs): LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE and has_layout_free ): - dispatch_ctx = DispatchContext.current - state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: return None @@ -316,7 +317,7 @@ def auto_schedule_topi(outs): else: raise ValueError("Invalid tracing mode: " + env.tracing_mode) - return None + return schedule def tensor_no_check_call(self, *indices): diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 175c2fa06c39..57e239cf79e8 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -19,8 +19,12 @@ import json +import os +import logging +import numpy as np + import tvm._ffi -from tvm.runtime import Object +from tvm.runtime import Object, ndarray from tvm.driver.build_module import build from tvm.target import Target @@ -33,6 +37,9 @@ from .workload_registry import WORKLOAD_FUNC_REGISTRY, register_workload_tensors from . import _ffi_api +# pylint: disable=invalid-name +logger = logging.getLogger("auto_scheduler") + @tvm._ffi.register_object("auto_scheduler.HardwareParams") class HardwareParams(Object): @@ -157,6 +164,156 @@ def __init__( ) +# The map stores special registered buffer for measurement. +# This can be used for sparse workloads when we cannot use random tensors for measurment. +# { +# "workload_key_0": { +# "task_input_0": Tensor(...), +# "task_input_1": Tensor(...) +# }, +# "workload_key_1": { +# "task_input_2": Tensor(...), +# "task_input_3": Tensor(...) +# }, +# ... +# } +TASK_INPUT_BUFFER_TABLE = {} + + +def _save_buffer_to_file(buffer_name, buffer_data): + """Save the current Tensor buffer to a numpy file. + + File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type}.npy + """ + np_data = buffer_data.asnumpy() + + buffer_name += "." + for i in np_data.shape: + buffer_name += "%d_" % (i) + buffer_name += "%s" % (np_data.dtype) + buffer_name += ".npy" + + np_data.tofile(buffer_name, " ") + + +def _try_load_buffer_from_file(buffer_name): + """Try to load buffer from a numpy file, if not found, return None. + + File name has a same format as `_save_buffer_to_file`. 
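+
+    For example, a buffer registered as "W_data" with shape (1024, 16, 1) and
+    dtype "float32" is looked up under the file name (illustrative values,
+    following the format string in `_save_buffer_to_file`):
+
+        W_data.1024_16_1_float32.npy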
+ """ + filelist = os.listdir() + + for file in filelist: + if file.startswith(buffer_name + "."): + meta_info = file.split(".")[-2].split("_") + shape = [int(i) for i in meta_info[:-1]] + dtype = meta_info[-1] + buffer_data = np.fromfile(file, dtype=dtype, sep=" ") + buffer_data = buffer_data.reshape(shape) + return ndarray.array(buffer_data) + + return None + + +def register_task_input_buffer( + workload_key, + input_name, + input_data, + overwrite=False, + save_to_file=False, +): + """Register special buffer for measurement. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + input_data : tvm.nd.NDArray + The input Tensor data. + + overwrite : bool = False + Whether to overwrite the data if a name has already registered. + + save_to_file : bool = False + Whether to save the data to a local file as well. This can be reused to resume the last + tuning process. + + Returns + ------- + tvm.nd.NDArray + The actual registered Tensor data of this input_name. With `overwrite` set to False, will + return the original one if the name has already registered before. + """ + global TASK_INPUT_BUFFER_TABLE + + if workload_key not in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] + + if not overwrite: + if input_name not in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + logger.warning( + "Tensor %s exists in TASK_INPUT_BUFFER_TABLE, %s", + input_name, + "set overwrite to True or this Tensor will not be registered", + ) + return input_table[input_name] + + input_table[input_name] = input_data + if save_to_file: + _save_buffer_to_file(input_name, input_data) + return input_data + + +def get_task_input_buffer(workload_key, input_name): + """Get special buffer for measurement. + + The buffers are registered by `register_task_input_buffer`. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + Returns + ------- + tvm.nd.NDArray + The registered input buffer. + """ + global TASK_INPUT_BUFFER_TABLE + + if workload_key not in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] + + if input_name not in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + return input_table[input_name] + + raise ValueError( + "%s not found in TASK_INPUT_BUFFER_TABLE, " % (input_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + + @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): """The computation information and hardware parameters for a schedule search task. @@ -185,6 +342,16 @@ class SearchTask(Object): The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. + task_inputs : Union[Dict[str, tvm.nd.NDArray], List[str]] + A dict maps the input names to input tensors or a list of input names. + Some special Tensor used as inputs in program measuring. 
+        Usually we do not need to care about them, but for special workloads such as sparse
+        computation the sparse tensor inputs are meaningful and we cannot simply use random
+        inputs.
+    task_inputs_overwrite : bool = False
+        Whether to overwrite the data if the name is already in the global table.
+    task_inputs_save_to_file : bool = False
+        Whether to save the data to a local file as well. This can be reused to resume the last
+        tuning process.

     Examples
     --------
@@ -212,6 +379,9 @@ def __init__(
         target_host=None,
         hardware_params=None,
         layout_rewrite_option=None,
+        task_inputs=None,
+        task_inputs_overwrite=False,
+        task_inputs_save_to_file=False,
     ):
         assert (
             func is not None or workload_key is not None
@@ -231,6 +401,22 @@ def __init__(
         if layout_rewrite_option is None:
             layout_rewrite_option = LayoutRewriteOption.get_target_default(target)

+        task_input_names = []
+        if isinstance(task_inputs, list):
+            task_input_names = task_inputs
+        elif isinstance(task_inputs, dict):
+            for input_name in task_inputs:
+                register_task_input_buffer(
+                    workload_key,
+                    input_name,
+                    task_inputs[input_name],
+                    task_inputs_overwrite,
+                    task_inputs_save_to_file,
+                )
+                task_input_names.append(input_name)
+        elif task_inputs is not None:
+            raise ValueError("task_inputs should be a dict or a list.")
+
         self.__init_handle_by_constructor__(
             _ffi_api.SearchTask,
             compute_dag,
@@ -239,6 +425,7 @@
             target_host,
             hardware_params,
             layout_rewrite_option,
+            task_input_names,
         )

     def tune(self, tuning_options, search_policy=None):
@@ -326,6 +513,7 @@ def __getstate__(self):
             "target_host": self.target_host,
             "hardware_params": self.hardware_params,
             "layout_rewrite_option": self.layout_rewrite_option,
+            "task_input_names": self.task_input_names,
         }

     def __setstate__(self, state):
@@ -350,6 +538,7 @@ def __setstate__(self, state):
             state["target_host"],
             state["hardware_params"],
             state["layout_rewrite_option"],
+            state["task_input_names"],
         )
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index 8aa33e6775f8..14dc5b8984c3 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -201,6 +201,9 @@ def serialize_args(args):
     Currently this is mainly used for tvm.tensor.Tensor
     """
     ret = []
+    if args is None:
+        return tuple(ret)
+
     for t in args:
         if isinstance(t, Tensor):
             t = ("TENSOR", get_const_tuple(t.shape), t.dtype)
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 62fd811dc1ec..b68767bd0528 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -280,7 +280,13 @@ def get_build_kwargs(self):

     def run(self, measure_inputs, build_results):
         results = []
-        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
+        remote_kwargs = dict(
+            device_key=self.key,
+            host=self.host,
+            port=self.port,
+            priority=self.priority,
+            timeout=self.timeout,
+        )

         for i in range(0, len(measure_inputs), self.n_parallel):
             futures = []
@@ -300,7 +306,7 @@ def run(self, measure_inputs, build_results):
                     self.repeat,
                     self.min_repeat_ms,
                     self.cooldown_interval,
-                    remote_args,
+                    remote_kwargs,
                     self.enable_cpu_cache_flush,
                     module_loader,
                 )
diff --git a/python/tvm/contrib/cu_graph/__init__.py b/python/tvm/contrib/cu_graph/__init__.py
new file mode 100644
index 000000000000..245692337bc3
--- /dev/null
+++ b/python/tvm/contrib/cu_graph/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
diff --git a/python/tvm/contrib/cu_graph/cugraph_runtime.py b/python/tvm/contrib/cu_graph/cugraph_runtime.py
new file mode 100644
index 000000000000..e82d69056ac4
--- /dev/null
+++ b/python/tvm/contrib/cu_graph/cugraph_runtime.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph runtime test mode with cuGraph (CUDA graph) launch"""
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.contrib import graph_runtime
+
+
+def create(graph_json_str, libmod, ctx):
+    """Create a cuGraph-enabled graph runtime module for the given device context."""
+    assert isinstance(graph_json_str, string_types)
+    try:
+        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
+        if num_rpc_ctx == len(ctx):
+            # fcreate would otherwise be left unbound below
+            raise RuntimeError("cuGraph test mode is not supported for RPC contexts")
+        fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")
+    except ValueError:
+        raise ValueError(
+            "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in "
+            "config.cmake and rebuild TVM to enable cu_graph test mode"
+        )
+
+    func_obj = fcreate(graph_json_str, libmod, *device_type_id)
+    return GraphModuleCuGraph(func_obj, ctx, graph_json_str)
+
+
+class GraphModuleCuGraph(graph_runtime.GraphModule):
+    """Graph runtime module that can capture its ops into a CUDA graph and replay it."""
+
+    def __init__(self, module, ctx, graph_json_str):
+        self._start_capture = module["start_capture"]
+        self._end_capture = module["end_capture"]
+        self._run_cuda_graph = module["run_cuda_graph"]
+
+        graph_runtime.GraphModule.__init__(self, module)
+
+    def capture_cuda_graph(self):
+        # call cuModuleLoadData before cudaStream API
+        self._run()
+
+        print("====== Start Stream Capture ======")
+        self._start_capture()
+        print("====== Start Run Ops On Stream ======")
+        self._run()
+        print("====== End Stream Capture ======")
+        self._end_capture()
+
+    def run_cuda_graph(self):
+        self._run_cuda_graph()
+
diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py
index 3159ab34397a..f58947f0766f 100644
--- a/python/tvm/contrib/debugger/debug_result.py
+++ b/python/tvm/contrib/debugger/debug_result.py
@@ -264,8 +264,4 @@ def save_tensors(params):
     """
     _save_tensors = tvm.get_global_func("tvm.relay._save_param_dict")
-    args = []
-    for k, v in params.items():
-        args.append(k)
-
args.append(tvm.nd.array(v)) - return _save_tensors(*args) + return _save_tensors(params) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 2a97b0b31d1e..f33603b923a5 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -302,8 +302,22 @@ def have_tensorcore(compute_version=None, target=None): major, minor = compute_version.split("_")[1] compute_version = major + "." + minor major, _ = parse_compute_version(compute_version) + if major >= 7: + return True + + return False + + +def have_bf16(compute_version): + """Either bf16 support is provided in the compute capability or not - if major == 7: + Parameters + ---------- + compute_version : str + compute capability of a GPU (e.g. "8.0") + """ + major, _ = parse_compute_version(compute_version) + if major >= 8: return True return False diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index fc1805ee0ab4..83791e50f6d5 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -24,7 +24,7 @@ import tvm from tvm import autotvm, auto_scheduler -from tvm import relay +from tvm import relay, runtime from tvm.contrib import cc from tvm.contrib import utils @@ -282,7 +282,7 @@ def save_module(module_path, graph, lib, params, cross=None): with open(temp.relpath(param_name), "wb") as params_file: logger.debug("writing params to file to %s", params_file.name) - params_file.write(relay.save_param_dict(params)) + params_file.write(runtime.save_param_dict(params)) logger.debug("saving module as tar file to %s", module_path) with tarfile.open(module_path, "w") as tar: diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 87ea3be1436a..1d23ccfb0c00 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -24,11 +24,11 @@ import tempfile import numpy as np -import tvm from tvm import rpc from tvm.autotvm.measure import request_remote from tvm.contrib import graph_runtime as runtime from tvm.contrib.debugger import debug_runtime +from tvm.relay import load_param_dict from . 
import common from .common import TVMCException @@ -163,9 +163,8 @@ def get_input_info(graph_str, params): shape_dict = {} dtype_dict = {} - # Use a special function to load the binary params back into a dict - load_arr = tvm.get_global_func("tvm.relay._load_param_dict")(params) - param_names = [v.name for v in load_arr] + params_dict = load_param_dict(params) + param_names = [k for (k, v) in params_dict.items()] graph = json.loads(graph_str) for node_id in graph["arg_nodes"]: node = graph["nodes"][node_id] diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index a87d67992953..5222f7a97a7c 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -19,7 +19,7 @@ from tvm.runtime import Object from tvm.runtime.container import getitem_helper -from tvm.runtime import _ffi_node_api +from tvm.runtime import _ffi_api @tvm._ffi.register_object("Array") @@ -33,10 +33,10 @@ class Array(Object): """ def __getitem__(self, idx): - return getitem_helper(self, _ffi_node_api.ArrayGetItem, len(self), idx) + return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx) def __len__(self): - return _ffi_node_api.ArraySize(self) + return _ffi_api.ArraySize(self) @tvm._ffi.register_object @@ -49,18 +49,18 @@ class Map(Object): """ def __getitem__(self, k): - return _ffi_node_api.MapGetItem(self, k) + return _ffi_api.MapGetItem(self, k) def __contains__(self, k): - return _ffi_node_api.MapCount(self, k) != 0 + return _ffi_api.MapCount(self, k) != 0 def items(self): """Get the items from the map""" - akvs = _ffi_node_api.MapItems(self) + akvs = _ffi_api.MapItems(self) return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)] def __len__(self): - return _ffi_node_api.MapSize(self) + return _ffi_api.MapSize(self) def get(self, key, default=None): """Get an element with a default value. diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 7e49461dff52..48e9ce0643a9 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -20,9 +20,9 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ -from tvm.ir import IRModule -from tvm.relay import transform, build_module -from tvm.runtime.ndarray import cpu +from ...ir import IRModule +from ...relay import transform, build_module +from ...runtime.ndarray import cpu from . import _ffi_api from .feature import Feature diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py index 437b97b0fa16..a18ccb97836b 100644 --- a/python/tvm/relay/analysis/annotated_regions.py +++ b/python/tvm/relay/analysis/annotated_regions.py @@ -17,7 +17,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Regions used in Relay.""" -from tvm.runtime import Object +from ...runtime import Object from . import _ffi_api diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py index 966659aac494..fd9704d0af1f 100644 --- a/python/tvm/relay/analysis/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -17,8 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Call graph used in Relay.""" -from tvm.ir import IRModule -from tvm.runtime import Object +from ...ir import IRModule +from ...runtime import Object from ..expr import GlobalVar from . 
import _ffi_api diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a39f72e2e61f..68397cc0cef6 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -386,6 +386,18 @@ def items(self): assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] + def shape_func_items(self): + """List items in the shape_func_cache. + + Returns + ------- + item_list : List[Tuple[CCacheKey, CCacheValue]] + The list of shape_func_items. + """ + res = _backend._CompileEngineListShapeFuncItems(self) + assert len(res) % 2 == 0 + return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] + def get_current_ccache_key(self): return _backend._CompileEngineGetCurrentCCacheKey(self) @@ -405,7 +417,28 @@ def dump(self): res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) res += "func_name={}\n".format(v.cached_func.func_name) + res += "----relay function----\n" + res += k.source_func.astext() + "\n" + res += "----tir function----- \n" + res += "inputs={}\n".format(v.cached_func.inputs) + res += "outputs={}\n".format(v.cached_func.outputs) + res += "function: \n" + res += v.cached_func.funcs.astext() + "\n" + res += "===================================\n" + shape_func_items = self.shape_func_items() + res += "%d shape_func_items cached\n" % len(shape_func_items) + for k, v in shape_func_items: + res += "------------------------------------\n" + res += "target={}\n".format(k.target) + res += "use_count={}\n".format(v.use_count) + res += "func_name={}\n".format(v.cached_func.func_name) + res += "----relay function----\n" res += k.source_func.astext() + "\n" + res += "----tir function----- \n" + res += "inputs={}\n".format(v.cached_func.inputs) + res += "outputs={}\n".format(v.cached_func.outputs) + res += "function: \n" + res += v.cached_func.funcs.astext() + "\n" res += "===================================\n" return res diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 4c6ac47b71b4..3427a62cd491 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -21,7 +21,7 @@ from tvm.runtime import ndarray -class GraphRuntimeFactoryModule(object): +class GraphRuntimeFactoryModule: """Graph runtime factory module. This is a module of graph runtime factory diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index f05e105ed2a2..4c9a898f2374 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -25,7 +25,7 @@ from tvm.ir.transform import PassContext from tvm.tir import expr as tvm_expr -from .. import nd as _nd, autotvm +from .. import nd as _nd, autotvm, register_func from ..target import Target from ..contrib import graph_runtime as _graph_rt from . import _build_module @@ -194,6 +194,20 @@ def get_params(self): return ret +@register_func("tvm.relay.module_export_library") +def _module_export(module, file_name): # fcompile, addons, kwargs? + return module.export_library(file_name) + + +@register_func("tvm.relay.build") +def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): + """A wrapper around build which discards the Python GraphFactoryRuntime. + This wrapper is suitable to be used from other programming languages as + the runtime::Module can be freely passed between language boundaries. 
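+
+    A sketch of calling it through the FFI from Python (using the global
+    registered above as "tvm.relay.build"; arguments are positional):
+
+    .. code-block:: python
+
+        f = tvm.get_global_func("tvm.relay.build")
+        rt_mod = f(mod, "llvm", "llvm", {}, "default")  # a runtime.Module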
+ """ + return build(mod, target, target_host, params, mod_name).module + + def build(mod, target=None, target_host=None, params=None, mod_name="default"): # fmt: off # pylint: disable=line-too-long @@ -377,10 +391,20 @@ def _make_executor(self, expr=None): ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) - num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 mod = build(self.mod, target=self.target) gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) + def _unflatten(flat_iter, cur_type): + if isinstance(cur_type, _ty.TensorType): + return next(flat_iter) + if isinstance(cur_type, _ty.TupleType): + fields = [] + for field_type in cur_type.fields: + field = _unflatten(flat_iter, field_type) + fields.append(field) + return fields + raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) + def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) # Create map of inputs. @@ -388,13 +412,11 @@ def _graph_wrapper(*args, **kwargs): gmodule.set_input(i, arg) # Run the module, and fetch the output. gmodule.run() - # make a copy so multiple invocation won't hurt perf. - if num_outputs == 1: - return gmodule.get_output(0).copyto(_nd.cpu(0)) - outputs = [] - for i in range(num_outputs): - outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0))) - return outputs + flattened = [] + for i in range(gmodule.get_num_outputs()): + flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0))) + unflattened = _unflatten(iter(flattened), ret_type) + return unflattened return _graph_wrapper diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 7e16499ccc44..aa8ac4fc7434 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -20,9 +20,6 @@ Contains the model importers currently defined for Relay. """ - -from __future__ import absolute_import - from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0c9d2c4381ac..5415c77097a2 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1234,7 +1234,7 @@ def _mx_topk(inputs, attrs): new_attrs = {} new_attrs["k"] = attrs.get_int("k", 1) new_attrs["axis"] = attrs.get_int("axis", -1) - new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True) + new_attrs["is_ascend"] = attrs.get_bool("is_ascend", False) ret_type = attrs.get_str("ret_typ", "indices") if ret_type == "mask": raise tvm.error.OpAttributeUnimplemented( diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3c61749fc203..c709e2b4e7bd 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -34,6 +34,7 @@ from .. import expr as _expr from .. import function as _function from .. import op as _op +from .. import qnn from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .. 
import transform @@ -805,14 +806,35 @@ def log_sigmoid(self, inputs, input_types): data = inputs[0] return _op.log(_op.tensor.sigmoid(data)) - def hard_swish(self, inputs, input_types): - data = inputs[0] - dtype = input_types[0] + def hard_sigmoid(self, inputs, input_types): + def _relu6(x): + return _op.tensor.clip(x, 0.0, 6.0) - def _relu6(input_tensor): - return _op.tensor.clip(input_tensor, 0.0, 6.0) + def func(x): + return _relu6(x + _expr.const(3.0)) / _expr.const(6.0) + + if self.is_quantized_tensor(inputs[0]): + input_scale = _expr.const(inputs[1]) + input_zero_point = _expr.const(inputs[2]) + # PyTorch seems to use the following output qparams, but accuracy + # is broken if we use this. + # TODO(masahi): Revisit this parameter choice + # + # Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp + # output_scale = _expr.const(0.00390625) # 1.0 / 2^8 + # output_zero_point = _expr.const(-128) + output_scale = input_scale + output_zero_point = input_zero_point + + data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1) + out = func(data) + return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8") + + return func(inputs[0]) - return data * _relu6(data + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype) + def hard_swish(self, inputs, input_types): + data = inputs[0] + return data * self.hard_sigmoid(inputs, input_types) def adaptive_avg_pool_2d(self, inputs, input_types): data = inputs[0] @@ -1374,6 +1396,20 @@ def avg_pool3d(self, inputs, input_types): count_include_pad=count_include_pad, ) + def linear(self, inputs, input_types): + # https://pytorch.org/docs/stable/nn.functional.html#linear + # 0 - input + # 1 - weight + bias = inputs[2] + mm_out = self.matmul(inputs[:2], input_types[:2]) + if isinstance(bias, _expr.Expr): + bias_ndims = len(self.infer_shape_with_prelude(bias)) + if bias_ndims == 1: + return _op.nn.bias_add(mm_out, bias) + mm_dtype = self.infer_type_with_prelude(mm_out).dtype + return self.add([mm_out, bias], [mm_dtype, input_types[2]]) + return mm_out + def dropout(self, inputs, input_types): data = inputs[0] rate = float(inputs[1]) @@ -2289,6 +2325,7 @@ def create_convert_map(self): "aten::softplus": self.softplus, "aten::avg_pool2d": self.avg_pool2d, "aten::avg_pool3d": self.avg_pool3d, + "aten::linear": self.linear, "aten::dropout": self.dropout, "aten::dropout_": self.dropout, "aten::feature_dropout": self.dropout, @@ -2403,6 +2440,8 @@ def create_convert_map(self): "aten::__not__": self.logical_not, "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, + "aten::hardsigmoid_": self.hard_sigmoid, + "aten::hardsigmoid": self.hard_sigmoid, "aten::cumsum": self.cumsum, "aten::masked_fill": self.masked_fill, "aten::masked_select": self.masked_select, @@ -3201,5 +3240,16 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt # ListConstruct kept original python list. Convert to tuple. ret = _expr.Tuple(ret) - mod["main"] = tvm.relay.Function(_analysis.free_vars(ret), ret) + # Separate data inputs and parameters to make sure data inputs are always in the beginning. 
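+    # (Illustration: if free_vars(ret) yields [conv_w, x, conv_b] and both weights
+    # are recorded in tvm_params, the reordered signature becomes [x, conv_w, conv_b],
+    # so positional data inputs always come first.)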
+ func_args = [] + data_inputs = [] + for arg in _analysis.free_vars(ret): + if arg.name_hint not in tvm_params.keys(): + data_inputs.append(arg) + else: + func_args.append(arg) + func_args = data_inputs + func_args + + mod["main"] = tvm.relay.Function(func_args, ret) + return transform.RemoveUnusedFunctions()(mod), tvm_params diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index e3431043bc86..2b85a1f3a1be 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -191,6 +191,7 @@ def _get_quant_param_for_input(input_value): "quantized::cat": (2, 3), "quantized::mul_scalar": (2, 3), "quantized::add_scalar": (2, 3), + "quantized::hardswish": (1, 2), } def dfs(current_node): @@ -358,6 +359,8 @@ def add_input_quant_params_to_op_inputs(graph): "quantized::add_scalar": 1, "quantized::mul_scalar": 1, "quantized::relu6": 1, + "quantized::hardswish": 1, + "aten::hardsigmoid": 1, } need_input_quant_param = set(num_quantized_inputs.keys()) @@ -765,6 +768,7 @@ def _impl(inputs, _): out_zp = _expr.const(inputs[3]) if q_min > z - c_q or q_max < z - c_q: + # TODO(masahi): Replace this with integer only compute dequant = relay.qnn.op.dequantize(inputs[0], _expr.const(s), _expr.const(z)) dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s)) return relay.qnn.op.quantize( @@ -820,6 +824,35 @@ def _impl(inputs, _): return _impl +def _hswish(): + # refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp + # They fallback to fp32 + def _impl(inputs, _): + assert len(inputs) == 5, "Input quant params not found in op inputs" + # TODO(masahi): Replace this with integer only compute. + # We do not have to strictly follow how PyTorch does it. + + def relu6(x): + return _op.tensor.clip(x, 0.0, 6.0) + + def hardsigmoid(x): + dtype = "float32" + return relu6(x + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype) + + output_scale = _expr.const(inputs[1]) + output_zero_point = _expr.const(inputs[2]) + input_scale = _expr.const(inputs[3]) + input_zero_point = _expr.const(inputs[4]) + + dequant = relay.qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1) + dequantized_hswish = dequant * hardsigmoid(dequant) + return relay.qnn.op.quantize( + dequantized_hswish, output_scale, output_zero_point, out_dtype="uint8" + ) + + return _impl + + def _linear_dynamic(): def _calculate_qparam(inp): # reference ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -906,4 +939,5 @@ def _impl(inputs, _): "quantized::mul_scalar": _mul_scalar(), "quantized::relu6": _relu6(), "quantized::linear_dynamic": _linear_dynamic(), + "quantized::hardswish": _hswish(), } diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 20eb95ba7c00..c79c495b0360 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1051,10 +1051,11 @@ def _impl(inputs, attr, params, mod): def _sparse_tensor_dense_matmul(): - # Sparse utility from scipy - from scipy.sparse import csr_matrix - def _impl(inputs, attr, params, mod): + # Loading this by default causes TVM to not be loadable from other languages. 
+ # Sparse utility from scipy + from scipy.sparse import csr_matrix + assert len(inputs) == 4, "There should be 4 input tensors" indices_tensor = _infer_value(inputs[0], params, mod).asnumpy() @@ -1166,6 +1167,125 @@ def _impl(inputs, attr, params, mod): return _impl +def _math_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2, "There should be 2 input tensors" + return get_relay_op("segment_sum")(inputs[0], inputs[1]) + + return _impl + + +def _sparse_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + return _op.segment_sum(data, inputs[2]) + + return _impl + + +def _sparse_segment_sum_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + return _op.segment_sum(data, inputs[2], num_segments) + + return _impl + + +def row_wise_divide(multi_dim_tensor, one_dim_vector): + """ + This function enables row-wise division of multi_dim_tensor and one_dim_vector. + To achieve this, it is first tiled to the appropriate shape and then elemwise_division + """ + multi_dim_tensor_offrow_shape = _op.strided_slice( + _op.shape_of(multi_dim_tensor, "int32"), [1], [-1], slice_mode="size" + ) + one_dim_vector_tiled_shape = _op.concatenate( + [_op.reverse(multi_dim_tensor_offrow_shape, 0), _expr.const([1])], axis=0 + ) + one_dim_vector_tiled = _op.transpose(_op.tile(one_dim_vector, one_dim_vector_tiled_shape)) + return _op.divide(multi_dim_tensor, one_dim_vector_tiled) + + +def count_all_indices(segment_ids, counts_dtype, num_segments=None): + """ + This snippet calculates the sqrt count of each index among all valid indices + Valid indices are from 0 to max of [segment ids, num_segments] + """ + + max_segments = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1]) + if num_segments: + max_segments = _op.maximum(max_segments, _expr.const([num_segments])) + max_ones = _op.maximum(max_segments, _op.shape_of(segment_ids)) + counts = _op.segment_sum( + _op.ones(max_ones, counts_dtype), segment_ids, num_segments=num_segments + ) + real_counts = _op.clip(counts, 1, 2147483647) # Clip max doesn't work over int32 + return real_counts + + +def _sparse_segment_sum_sqrtn(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_sum_sqrtn_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_mean(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 
input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + +def _sparse_segment_mean_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + def _identity(): def _impl(inputs, attr, params, mod): return inputs[0] @@ -2660,6 +2780,13 @@ def _impl(inputs, attr, params, mod): "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), "SparseFillEmptyRows": _sparse_fill_empty_rows(), "SparseReshape": _sparse_reshape(), + "SegmentSum": _math_segment_sum(), + "SparseSegmentSum": _sparse_segment_sum(), + "SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(), + "SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(), + "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(), + "SparseSegmentMean": _sparse_segment_mean(), + "SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 7728d6e3efa4..5f68be84d46a 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -281,3 +281,4 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("clip", False, elemwise_shape_func) register_shape_func("log2", False, elemwise_shape_func) register_shape_func("sigmoid", False, elemwise_shape_func) +register_shape_func("tanh", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 97f45278f073..e90263d794bc 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -247,6 +247,31 @@ def strided_slice_shape_func(attrs, inputs, _): ] +@script +def _one_hot_shape_func(indices_shape, depth, axis): + in_ndim = indices_shape.shape[0] + out_ndim = in_ndim + 1 + true_axis = in_ndim if axis == -1 else axis + indices_i = 0 + out = output_tensor((out_ndim,), "int64") + for i in range(out_ndim): + if i == true_axis: + out[i] = int64(depth) + else: + out[i] = int64(indices_shape[indices_i]) + indices_i += 1 + return out + + +@_reg.register_shape_func("one_hot", False) +def one_hot_shape_func(attrs, inputs, _): + """ + Shape func for one_hot + """ + shape_func = [_one_hot_shape_func(inputs[0], convert(attrs.depth), convert(attrs.axis))] + return shape_func + + @script def _concatenate_shape_func(inputs, axis): ndim = inputs[0].shape[0] diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 85bbab692574..e0d0f165219e 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -354,6 +354,8 @@ def judge_winograd( OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1 nH, nW = (OH + tile_size - 1) // tile_size, (OW + tile_size - 1) // tile_size + if not isinstance(N, int): + return False, False, False P = N * nH * nW 
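+    # (The isinstance check above guards dynamic batch sizes: when N is a symbolic
+    # expression rather than an int, P and the winograd judgements below cannot be
+    # evaluated to constants, so winograd is simply disabled.)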
judge_winograd_tensorcore = ( @@ -655,7 +657,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): data, weights = inputs b, i = get_const_tuple(data.shape) o, _ = get_const_tuple(weights.shape) - if out_type.dtype == "int8": + if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32": strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_int8), wrap_topi_schedule(topi.cuda.schedule_dense_int8), diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 73508ddd2603..4129b610cb7c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1047,7 +1047,7 @@ def gather(data, axis, indices): The input data to the operator. axis: int - The axis along which to index. + The axis along which to index. negative axis is supported. indices: relay.Expr The indices of values to gather. @@ -1450,6 +1450,75 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape): return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2) +def segment_sum(data, segment_ids, num_segments=None): + """ + Computes the sum along segment_ids along axis 0. If multiple segment_ids reference the same + location their contributions add up. + result[index, j, k, ...] = Σi... data[i, j, k,..] where index = segment_ids[i] + This op is much better understood with visualization articulated in the following links and + examples at the end of this docstring. + + https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum + https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops + + Parameters + ---------- + data : relay.Expr + Input Tensor. It can be of any type and multi-dimensional + segment_ids : relay.Expr + A 1-D int32/int64 tensor containing the segment_ids of the rows to calculate the output + sum upon. It defines a mapping from the zeroth dimension of data onto segment_ids. 
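+
+        A quick sketch of the op's semantics (values illustrative):
+
+        .. code-block:: python
+
+            data = [[1, 2], [3, 4], [5, 6]]
+            segment_ids = [0, 0, 1]
+            segment_sum(data, segment_ids)  # -> [[4, 6], [5, 6]]
+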
+        The segment_ids tensor should be the size of the first dimension, d0, with consecutive
+        IDs in the range 0 to k, where k < d0.
+    A, weights = C.op.input_tensors
+    _, in_dim_weights = get_const_tuple(weights.shape)
+    _, in_dim_A = get_const_tuple(A.shape)
+
+    if isinstance(in_dim_A, int):
+        in_dim = in_dim_A
+    elif isinstance(in_dim_weights, int):
+        in_dim = in_dim_weights
+    else:
+        in_dim = None
+
+    if in_dim is not None:
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
+    else:
+        tile_k = 64
+        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)
-    _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
     CF = s.rfactor(C, kf)
     if C.op in s.outputs:
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index 8145ed80af47..1bf18df09da3 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -18,7 +18,7 @@
 """Sparse operators"""
 from __future__ import absolute_import
 import tvm
-from tvm import te
+from tvm import te, auto_scheduler

 from ..utils import get_const_tuple

@@ -197,7 +197,7 @@ def _compute_block(nb_j, j, i):


 def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr):
-    (m, _) = get_const_tuple(data.shape)
+    (m, k) = get_const_tuple(data.shape)
     (_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
     (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
     num_blocks = num_blocks_plus_1 - 1
@@ -218,7 +218,10 @@ def _compute_block(i, nb_j, j):
     idxm = tvm.tir.indexmod

     bsrmm_block = te.compute(
-        (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block"
+        (m, num_blocks, bs_r),
+        _compute_block,
+        tag="sparse_dense_sp_rhs_bsrmm_block",
+        attrs={"FLOP": 2 * m * num_blocks * bs_r * k},
     )
     return te.compute(
         (m, num_blocks * bs_r),
@@ -356,3 +359,112 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type):
     Unlike other TOPI functions, this function operates on both graph level and operator level.
     """
     return None
+
+
+@auto_scheduler.register_task_input_check_func
+def try_get_sparse_input(args):
+    """Analyze the input data from the given args.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensor of a TVM subgraph.
+
+    Returns
+    -------
+    Dict[Tensor, str] :
+        Map from the input Tensor to its buffer name.
+
+    Notes
+    -----
+    The buffer name is specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
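+
+    For example (names are illustrative; the prefix encodes m, n, k, bs_r, bs_c
+    and density exactly as computed in `_process_inputs` below):
+
+    .. code-block:: python
+
+        prefix = "sparse_dense_bsr_512_512_128_16_1_0.60_"
+        task = tvm.auto_scheduler.SearchTask(
+            func=..., args=..., target="llvm",
+            task_inputs={
+                prefix + "W_data": tvm.nd.array(w_data_np),
+                prefix + "W_indices": tvm.nd.array(w_indices_np),
+                prefix + "W_indptr": tvm.nd.array(w_indptr_np),
+            },
+        )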
+ """ + sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None + + def _process_inputs(input_tensors, m, n, prefix_init): + nonlocal sparse_prefix + nonlocal sparse_data + nonlocal sparse_indices + nonlocal sparse_indptr + + assert len(input_tensors) == 4 + unsure_tensors = list(input_tensors) + # Get the Dense data + dense_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 2: + assert dense_data is None + dense_data = tensor + assert m == dense_data.shape[0] + k = dense_data.shape[1] + unsure_tensors.remove(dense_data) + + # Get the Sparse data + sparse_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 3: + assert sparse_data is None + sparse_data = tensor + block_size, bs_r, bs_c = sparse_data.shape + unsure_tensors.remove(sparse_data) + + # Get the Sparse indptr & indices + sparse_indices = None + for tensor in unsure_tensors: + assert len(tensor.shape) == 1 + if tensor.shape[0] == block_size: + assert sparse_indices is None + sparse_indices = tensor + unsure_tensors.remove(sparse_indices) + assert len(unsure_tensors) == 1 + sparse_indptr = unsure_tensors[0] + + # Generate the sparse_prefix + density = 1.0 + for i in sparse_data.shape: + density *= i + density /= k * n + density = density.value + sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density) + + visited = set() + + def _traverse(t): + # We cannot directly add tensors to the set, because the comparison of + # two tensors with ndim=0 is ambiguous. + assert t.handle is not None + if t.handle.value in visited: + return + + if isinstance(t.op, te.ComputeOp): + # TODO(jcf94): Currently only support to one sparse op, add more support here + if t.op.tag == "sparse_dense_sp_rhs_bsrmm": + m, n = t.shape + assert len(t.op.input_tensors) == 1 + block_tensor = t.op.input_tensors[0] + _process_inputs(block_tensor.op.input_tensors, m, n, "sparse_dense_bsr") + if sparse_prefix is not None: + # Early stop if we find a sparse_prefix + # Notice: If any workload has more than one sparse input, this may get problem + return + for x in t.op.input_tensors: + _traverse(x) + visited.add(t.handle.value) + + try: + for arg in args: + _traverse(arg) + # pylint: disable=broad-except + except Exception: + return {} + + if sparse_data is None or sparse_indices is None or sparse_indptr is None: + return {} + + sparse_input_map = {} + sparse_input_map[sparse_data] = sparse_prefix + "W_data" + sparse_input_map[sparse_indices] = sparse_prefix + "W_indices" + sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr" + + return sparse_input_map diff --git a/python/tvm/topi/scatter_add.py b/python/tvm/topi/scatter_add.py index 4c77a0767785..6b04837b7766 100644 --- a/python/tvm/topi/scatter_add.py +++ b/python/tvm/topi/scatter_add.py @@ -32,8 +32,8 @@ def _scatter_add_1d(data, indices, updates): @hybrid.script def _scatter_add_2d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) - for i in const_range(data.shape[0]): - for j in const_range(data.shape[1]): + for i in range(data.shape[0]): + for j in range(data.shape[1]): out[i, j] = data[i, j] if axis == 0: for i in range(indices.shape[0]): @@ -54,14 +54,14 @@ def _scatter_add_2d(data, indices, updates, axis): @hybrid.script def _scatter_add_3d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) - for i in const_range(data.shape[0]): - for j in const_range(data.shape[1]): - for k in const_range(data.shape[2]): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + for k 
in range(data.shape[2]): out[i, j, k] = data[i, j, k] if axis == 0: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): + for k in range(indices.shape[2]): out[ indices[i, j, k] if indices[i, j, k] >= 0 @@ -72,7 +72,7 @@ def _scatter_add_3d(data, indices, updates, axis): elif axis == 1: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): + for k in range(indices.shape[2]): out[ i, indices[i, j, k] @@ -83,7 +83,7 @@ def _scatter_add_3d(data, indices, updates, axis): else: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): + for k in range(indices.shape[2]): out[ i, j, @@ -98,17 +98,17 @@ def _scatter_add_3d(data, indices, updates, axis): @hybrid.script def _scatter_add_4d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) - for i in const_range(data.shape[0]): - for j in const_range(data.shape[1]): - for k in const_range(data.shape[2]): - for l in const_range(data.shape[3]): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + for k in range(data.shape[2]): + for l in range(data.shape[3]): out[i, j, k, l] = data[i, j, k, l] if axis == 0: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ indices[i, j, k, l] if indices[i, j, k, l] >= 0 @@ -120,8 +120,8 @@ def _scatter_add_4d(data, indices, updates, axis): elif axis == 1: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, indices[i, j, k, l] @@ -133,8 +133,8 @@ def _scatter_add_4d(data, indices, updates, axis): elif axis == 2: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, j, @@ -146,8 +146,8 @@ def _scatter_add_4d(data, indices, updates, axis): else: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, j, diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index 646a20daaf5b..83fe37ea7970 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -483,7 +483,7 @@ named! { ) } -/// Loads a param dict saved using `relay.save_param_dict`. +/// Loads a param dict saved using `runtime.save_param_dict`. 
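This PR consistently moves `save_param_dict` from `tvm.relay` to `tvm.runtime`; a small round-trip sketch of the relocated Python API (the parameter values are illustrative):

```python
import numpy as np
import tvm
from tvm import runtime

# Serialize a parameter dict and load it back through tvm.runtime.
params = {"weight": tvm.nd.array(np.ones((2, 2), dtype="float32"))}
blob = runtime.save_param_dict(params)   # bytes, as written to *.params files
loaded = runtime.load_param_dict(blob)   # name -> NDArray
assert np.array_equal(loaded["weight"].asnumpy(), np.ones((2, 2), "float32"))
```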
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> { match parse_param_dict(bytes) { Ok((remaining_bytes, param_dict)) => { diff --git a/rust/tvm-graph-rt/tests/build_model.py b/rust/tvm-graph-rt/tests/build_model.py index d34b4403c936..969075929a42 100755 --- a/rust/tvm-graph-rt/tests/build_model.py +++ b/rust/tvm-graph-rt/tests/build_model.py @@ -23,7 +23,7 @@ import numpy as np import tvm from tvm import te -from tvm import relay +from tvm import relay, runtime from tvm.relay import testing CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) @@ -47,7 +47,7 @@ def main(): with open(osp.join(CWD, "graph.json"), "w") as f_resnet: f_resnet.write(graph) with open(osp.join(CWD, "graph.params"), "wb") as f_params: - f_params.write(relay.save_param_dict(params)) + f_params.write(runtime.save_param_dict(params)) if __name__ == "__main__": diff --git a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py index e743e48b01f8..0045b3b0557d 100755 --- a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py +++ b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py @@ -23,7 +23,7 @@ import numpy as np import tvm -from tvm import te +from tvm import te, runtime from tvm import relay from tvm.relay import testing @@ -49,7 +49,7 @@ def main(): f_resnet.write(graph) with open(osp.join(out_dir, "graph.params"), "wb") as f_params: - f_params.write(relay.save_param_dict(params)) + f_params.write(runtime.save_param_dict(params)) if __name__ == "__main__": diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md index a99eeaa578dd..58b1f8a30a39 100644 --- a/rust/tvm-rt/README.md +++ b/rust/tvm-rt/README.md @@ -17,8 +17,8 @@ # TVM Runtime Support -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime. -Currently this is tested on `1.42.0` and above. +This crate provides an idiomatic Rust API for the [TVM](https://github.com/apache/tvm) runtime; +see [here](https://github.com/apache/tvm/blob/main/rust/tvm/README.md) for more details. ## What Does This Crate Offer? diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 5abf66708f45..e8902b54f6ef 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -39,9 +39,9 @@ pub struct Array<T: IsObjectRef> { // TODO(@jroesch): convert to use generics instead of casting inside // the implementation. external! { - #[name("node.ArrayGetItem")] + #[name("runtime.ArrayGetItem")] fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; - #[name("node.ArraySize")] + #[name("runtime.ArraySize")] fn array_size(array: ObjectRef) -> i64; } @@ -69,8 +69,8 @@ impl<T: IsObjectRef> Array<T> { pub fn from_vec(data: Vec<T>) -> Result<Array<T>> { let iter = data.into_iter().map(T::into_arg_value).collect(); - let func = Function::get("node.Array").expect( - "node.Array function is not registered, this is most likely a build or linking error", + let func = Function::get("runtime.Array").expect( + "runtime.Array function is not registered, this is most likely a build or linking error", ); // let array_data = func.invoke(iter)?; diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 4b163eff9c8f..5f9ab1617378 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -99,7 +99,6 @@ pub mod map; pub mod module; pub mod ndarray; mod to_function; -pub mod value; /// Outputs the current TVM version.
pub fn version() -> &'static str { @@ -112,6 +111,8 @@ pub fn version() -> &'static str { #[cfg(test)] mod tests { use super::*; + use crate::{ByteArray, Context, DataType}; + use std::{convert::TryInto, str::FromStr}; #[test] fn print_version() { @@ -127,4 +128,29 @@ mod tests { errors::NDArrayError::EmptyArray.to_string() ); } + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } + + #[test] + fn ty() { + let t = DataType::from_str("int32").unwrap(); + let tvm: DataType = RetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = Context::from_str("gpu").unwrap(); + let tvm: Context = RetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } } diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index b8bfb4e5e644..d6dfaf3641b8 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -48,13 +48,13 @@ where // TODO(@jroesch): convert to use generics instead of casting inside // the implementation. external! { - #[name("node.MapSize")] + #[name("runtime.MapSize")] fn map_size(map: ObjectRef) -> i64; - #[name("node.MapGetItem")] + #[name("runtime.MapGetItem")] fn map_get_item(map_object: ObjectRef, key: ObjectRef) -> ObjectRef; - #[name("node.MapCount")] + #[name("runtime.MapCount")] fn map_count(map: ObjectRef, key: ObjectRef) -> ObjectRef; - #[name("node.MapItems")] + #[name("runtime.MapItems")] fn map_items(map: ObjectRef) -> Array; } @@ -81,8 +81,8 @@ where V: IsObjectRef, { pub fn from_data(data: Vec) -> Result> { - let func = Function::get("node.Map").expect( - "node.Map function is not registered, this is most likely a build or linking error", + let func = Function::get("runtime.Map").expect( + "runtime.Map function is not registered, this is most likely a build or linking error", ); let map_data: ObjectPtr = func.invoke(data)?.try_into()?; @@ -107,6 +107,18 @@ where let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?; oref.downcast() } + + pub fn empty() -> Self { + Self::from_iter(vec![].into_iter()) + } + + //(@jroesch): I don't think this is a correct implementation. + pub fn null() -> Self { + Map { + object: ObjectRef::null(), + _data: PhantomData, + } + } } pub struct IntoIter { diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index c0822a5045e6..6109819939af 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -26,21 +26,24 @@ use std::{ ptr, }; +use crate::object::Object; +use tvm_macros::Object; use tvm_sys::ffi; use crate::errors::Error; +use crate::String as TString; use crate::{errors, function::Function}; -const ENTRY_FUNC: &str = "__tvm_main__"; - /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. /// /// [`entry_func`]:struct.Module.html#method.entry_func -#[derive(Debug, Clone)] -pub struct Module { - pub(crate) handle: ffi::TVMModuleHandle, - entry_func: Option, +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "Module"] +#[type_key = "runtime.Module"] +pub struct ModuleNode { + base: Object, } crate::external! { @@ -49,21 +52,18 @@ crate::external! 
{ #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; + + #[name("runtime.ModuleSaveToFile")] + fn save_to_file(module: Module, name: TString, fmt: TString); + + // TODO(@jroesch): we need to refactor this + #[name("tvm.relay.module_export_library")] + fn export_library(module: Module, file_name: TString); } impl Module { - pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { - Self { - handle, - entry_func: None, - } - } - - pub fn entry(&mut self) -> Option { - if self.entry_func.is_none() { - self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); - } - self.entry_func.clone() + pub fn default_fn(&mut self) -> Result { + self.get_function("default", true) } /// Gets a function by name from a registered module. @@ -72,7 +72,7 @@ impl Module { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; check_call!(ffi::TVMModGetFunction( - self.handle, + self.handle(), name.as_ptr() as *const c_char, query_import as c_int, &mut fhandle as *mut _ @@ -87,7 +87,7 @@ impl Module { /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { - check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle())) } /// Loads a module shared library from path. @@ -110,6 +110,14 @@ impl Module { Ok(module) } + pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> { + save_to_file(self.clone(), name.into(), fmt.into()) + } + + pub fn export_library(&self, name: String) -> Result<(), Error> { + export_library(self.clone(), name.into()) + } + /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); @@ -118,13 +126,7 @@ impl Module { } /// Returns the underlying module handle. 
- pub fn handle(&self) -> ffi::TVMModuleHandle { - self.handle - } -} - -impl Drop for Module { - fn drop(&mut self) { - check_call!(ffi::TVMModFree(self.handle)); + pub unsafe fn handle(&self) -> ffi::TVMModuleHandle { + self.0.clone().unwrap().into_raw() as *mut _ } } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 8df6041956b8..264d5febd103 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -267,6 +267,10 @@ impl ObjectPtr { Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } } + + pub unsafe fn into_raw(self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -300,7 +304,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { use crate::ndarray::NDArrayContainer; match ret_value { - RetValue::ObjectHandle(handle) => { + RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); optr.downcast() @@ -329,6 +333,11 @@ impl<'a, T: IsObject> From> for ArgValue<'a> { assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } + "runtime.Module" => { + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::ModuleHandle(raw_ptr) + } _ => { let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); @@ -346,7 +355,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { use crate::ndarray::NDArrayContainer; match arg_value { - ArgValue::ObjectHandle(handle) => { + ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); optr.downcast() diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index affd81b0e7ed..c5ede7d224ce 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); +impl_typed_and_to_function!(6; A, B, C, D, E, G); #[cfg(test)] mod tests { diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs deleted file mode 100644 index b8cd190176c4..000000000000 --- a/rust/tvm-rt/src/value.rs +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module implements [`ArgValue`] and [`RetValue`] types -//! and their conversions needed for the types used in frontend crate. -//! `RetValue` is the owned version of `TVMPODValue`. 
- -use std::convert::TryFrom; - -use crate::{ArgValue, Module, RetValue}; -use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; - -macro_rules! impl_handle_val { - ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { - impl<'a> From<&'a $type> for ArgValue<'a> { - fn from(arg: &'a $type) -> Self { - ArgValue::$variant(arg.handle() as $inner_type) - } - } - - impl<'a> From<&'a mut $type> for ArgValue<'a> { - fn from(arg: &'a mut $type) -> Self { - ArgValue::$variant(arg.handle() as $inner_type) - } - } - - impl<'a> TryFrom> for $type { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) }) - } - } - - impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) }) - } - } - - impl From<$type> for RetValue { - fn from(val: $type) -> RetValue { - RetValue::$variant(val.handle() as $inner_type) - } - } - - impl TryFrom for $type { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) }) - } - } - }; -} - -impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); - -#[cfg(test)] -mod tests { - use std::{convert::TryInto, str::FromStr}; - - use crate::{ByteArray, Context, DataType}; - - use super::*; - - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } - - #[test] - fn ty() { - let t = DataType::from_str("int32").unwrap(); - let tvm: DataType = RetValue::from(t).try_into().unwrap(); - assert_eq!(tvm, t); - } - - #[test] - fn ctx() { - let c = Context::from_str("gpu").unwrap(); - let tvm: Context = RetValue::from(c).try_into().unwrap(); - assert_eq!(tvm, c); - } -} diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 29d2003b5089..9438f340f78f 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -50,9 +50,10 @@ tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" -pyo3 = { version = "0.11.1", optional = true } +pyo3 = { version = "^0.13", optional = true } codespan-reporting = "0.9.5" structopt = { version = "0.3" } +tracing = "^0.1" [[bin]] name = "tyck" diff --git a/rust/tvm/README.md b/rust/tvm/README.md index 26f9f1fbedfd..75fabe7d9a1b 100644 --- a/rust/tvm/README.md +++ b/rust/tvm/README.md @@ -15,221 +15,40 @@ -# TVM Runtime Frontend Support +# TVM -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm). +The code works on **Stable Rust** and is tested against `rustc 1.47`. -## What Does This Crate Offer? - -Here is a major workflow - -1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/) -2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. -3. 
Deploy your models using **Rust** :heart: - -### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k - -Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. - -Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM - -```python -block = get_model('resnet18_v1', pretrained=True) - -sym, params = relay.frontend.from_mxnet(block, shape_dict) -# compile the model -with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build( - net, target, params=params) -# same the model artifacts -lib.save(os.path.join(target_dir, "deploy_lib.o")) -cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), - [os.path.join(target_dir, "deploy_lib.o")]) - -with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: - fo.write(graph.json()) -with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(params)) -``` +You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/tvm/index.html). -Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image - -![cat](https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true) +## What Does This Crate Offer? -as demostrated in the following Rust snippet +The goal of this crate is to provide bindings to both the TVM compiler and runtime +APIs. First train your **Deep Learning** model using any major framework such as +[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/). +Then use **TVM** to build and deploy optimized model artifacts on supported devices such as CPU, GPU, OpenCL, and specialized accelerators. -```rust - let graph = fs::read_to_string("deploy_graph.json")?; - // load the built module - let lib = Module::load(&Path::new("deploy_lib.so"))?; - // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); - let runtime_create_fn_ret = call_packed!( - runtime_create_fn, - &graph, - &lib, - &ctx.device_type, - &ctx.device_id - )?; - // get graph runtime module - let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; - // get the registered `load_params` from runtime module - let ref load_param_fn = graph_runtime_module - .get_function("load_params", false) - .unwrap(); - // parse parameters and convert to TVMByteArray - let params: Vec = fs::read("deploy_param.params")?; - let barr = TVMByteArray::from(&params); - // load the parameters - call_packed!(load_param_fn, &barr)?; - // get the set_input function - let ref set_input_fn = graph_runtime_module - .get_function("set_input", false) - .unwrap(); +The Rust bindings are composed of a few crates: +- The [tvm](https://tvm.apache.org/docs/api/rust/tvm/index.html) crate which exposes Rust bindings to + both the compiler and runtime. +- The [tvm_macros](https://tvm.apache.org/docs/api/rust/tvm_macros/index.html) crate which provides macros + that generate unsafe boilerplate for TVM's data structures. +- The [tvm_rt](https://tvm.apache.org/docs/api/rust/tvm_rt/index.html) crate which exposes Rust + bindings to the TVM runtime APIs. +- The [tvm_sys] crate which provides raw bindings and linkage to the TVM C++ library. +- The [tvm_graph_rt] crate which implements a pure-Rust version of the TVM graph runtime (rather than wrapping the C++ one).
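For the build step the new README summarizes, a hedged Python sketch of producing the artifacts a graph runtime (Rust or otherwise) consumes, consistent with the `runtime.save_param_dict` change elsewhere in this PR (the model and file names are illustrative):

```python
import tvm
from tvm import relay, runtime
from tvm.relay import testing
from tvm.contrib import cc

# resnet-18 from relay.testing stands in for a real trained model.
mod, params = testing.resnet.get_workload(num_layers=18, batch_size=1)
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm", params=params)

lib.save("deploy_lib.o")
cc.create_shared("deploy_lib.so", ["deploy_lib.o"])
with open("deploy_graph.json", "w") as f_graph:
    f_graph.write(graph)
with open("deploy_param.params", "wb") as f_params:
    f_params.write(runtime.save_param_dict(params))
```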
- call_packed!(set_input_fn, "data", &input)?; - // get `run` function from runtime module - let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); - // execute the run function. Note that it has no argument - call_packed!(run_fn,)?; - // prepare to get the output - let output_shape = &mut [1, 1000]; - let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); - // get the `get_output` function from runtime module - let ref get_output_fn = graph_runtime_module - .get_function("get_output", false) - .unwrap(); - // execute the get output function - call_packed!(get_output_fn, &0, &output)?; - // flatten the output as Vec - let output = output.to_vec::()?; -``` +These crates have recently been refactored and reflect a different philosophy from the +previous bindings, as well as greatly expanded support for the TVM API, including +exposing the compiler internals. -and the model correctly predicts the input image as **tiger cat**. +These are still very much in development and should not be considered stable, but contributions +and usage are welcome and encouraged. If you want to discuss design issues, check our Discourse +[forum](https://discuss.tvm.ai); for bug reports, use our GitHub [repository](https://github.com/apache/tvm). -## Installations +## Install -Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. +Please follow the TVM [install](https://tvm.apache.org/docs/install/index.html) instructions, `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. *Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. - -## Supported TVM Functionalities - -### Use TVM to Generate Shared Library - -One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. - -```python -import os -import tvm -from tvm import te -from tvm.contrib import cc - -def test_add(target_dir): - if not tvm.runtime.enabled("cuda"): - print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) - return - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - bx, tx = s[C].split(C.op.axis[0], factor=64) - s[C].bind(bx, tvm.thread_axis("blockIdx.x")) - s[C].bind(tx, tvm.thread_axis("threadIdx.x")) - fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") - - fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) - fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) - cc.create_shared(os.path.join(target_dir, "add_gpu.so"), - [os.path.join(target_dir, "add_gpu.o")]) - - -if __name__ == "__main__": - import sys - if len(sys.argv) != 2: - sys.exit(-1) - test_add(sys.argv[1]) -``` - -### Run the Generated Shared Library - -The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
- -```rust -extern crate tvm_frontend as tvm; - -use tvm::*; - -fn main() { - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); - arr.copy_from_buffer(data.as_mut_slice()); - let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); - let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); - let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); - assert!(fadd.enabled("gpu")); - fadd.import_module(fadd_dep); - fadd.entry(); - function::Builder::from(&mut fadd) - .arg(&arr) - .arg(&arr) - .set_output(&mut ret)? - .invoke() - .unwrap(); - - assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); -} -``` - -**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by -`cargo:rustc-link-search=native=add_gpu`. - -See the tests and examples custom `build.rs` for more details. - -### Convert and Register a Rust Function as a TVM Packed Function - -One can use `register_global_func!` macro to convert and register a Rust -function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows - -```rust -#[macro_use] -extern crate tvm_frontend as tvm; -use std::convert::TryInto; -use tvm::*; - -fn main() { - register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { - let mut ret = 0f32; - let shape = &mut [2]; - for arg in args.iter() { - let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); - let arg: NDArray = arg.try_into()?; - let arr = arg.copy_to_ndarray(e).unwrap(); - let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); - ret += rnd.scalar_sum(); - } - let ret_val = TVMRetValue::from(&ret); - Ok(ret_val) - } - } - - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); - arr.copy_from_buffer(data.as_mut_slice()); - let mut registered = function::Builder::default(); - let ret: f64 = registered - .get_function("sum", true) - .arg(&arr) - .arg(&arr) - .invoke() - .unwrap() - .try_into() - .unwrap(); - - assert_eq!(ret, 14f64); -} -``` diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 03ac611a191a..fdacb5bb1fca 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -27,7 +27,7 @@ import tvm from tvm import te -from tvm import relay +from tvm import relay, runtime from tvm.relay import testing from tvm.contrib import graph_runtime, cc from PIL import Image @@ -88,7 +88,7 @@ def build(target_dir): fo.write(graph) with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(params)) + fo.write(runtime.save_param_dict(params)) def download_img_labels(): diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs new file mode 100644 index 000000000000..6b5873398cab --- /dev/null +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::convert::TryInto; +use std::io::Read; +use std::path::Path; + +use once_cell::sync::Lazy; +use thiserror::Error; + +use crate::ir::IRModule; +use crate::python; +use crate::runtime::{map::Map, Function, Module as RtModule, NDArray, String}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("{0}")] + TVM(#[from] crate::errors::Error), +} + +static TVM_BUILD: Lazy<Function> = Lazy::new(|| { + python::import("tvm").unwrap(); + python::import("tvm.relay").unwrap(); + Function::get("tvm.relay.build").unwrap() +}); + +fn _compile_module( + module: IRModule, + target: String, + target_host: String, + params: Map<String, NDArray>, + module_name: String, +) -> Result<RtModule, Error> { + // The RAW API is Fn(IRModule, String, String, Map<String, NDArray>, String); + let module = TVM_BUILD.invoke(vec![ + module.into(), + target.into(), + target_host.into(), + params.into(), + module_name.into(), + ])?; + let module: RtModule = module.try_into().unwrap(); + Ok(module) +} + +#[derive(Debug)] +pub struct CompilerConfig { + target: Option<String>, + target_host: Option<String>, + params: Map<String, NDArray>, + module_name: Option<String>, +} + +impl Default for CompilerConfig { + fn default() -> Self { + CompilerConfig { + target: None, + target_host: None, + params: Map::empty(), + module_name: None, + } + } +} + +/// Compile a module from a configuration and IRModule. +/// +/// # Arguments +/// +/// * `config` - The configuration for the compiler. +/// * `module` - The IRModule to compile. +pub fn compile_module(config: CompilerConfig, module: IRModule) -> Result<RtModule, Error> { + let target = config.target.unwrap_or("llvm".into()); + _compile_module( + module, + target, + "llvm".into(), + Map::<String, NDArray>::empty(), + "default".into(), + ) +} + +/// Compile an IRModule on disk and output a runtime module to disk. +/// +/// # Arguments +/// * `config` - The configuration for the compiler. +/// * `ir_mod_path` - The path to the serialized IRModule. +// +/// * `output_rt_mod_path` - The path to the output runtime module. +pub fn compile_from_disk<P1, P2>( + config: CompilerConfig, + ir_mod_path: P1, + output_rt_mod_path: P2, +) -> Result<(), Error> +where + P1: AsRef<Path>, + P2: AsRef<Path>, +{ + let mut input_file = std::fs::File::open(ir_mod_path.as_ref())?; + let mut input_module_text = std::string::String::new(); + input_file.read_to_string(&mut input_module_text)?; + let input_module = IRModule::parse("name", input_module_text)?; + let rt_module = compile_module(config, input_module)?; + let output_path_str = output_rt_mod_path.as_ref().display().to_string(); + rt_module.export_library(output_path_str)?; + Ok(()) +} diff --git a/rust/tvm/src/compiler/mod.rs b/rust/tvm/src/compiler/mod.rs new file mode 100644 index 000000000000..ed8b47edbad4 --- /dev/null +++ b/rust/tvm/src/compiler/mod.rs @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod graph_rt; diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index 8bcdf8f51e60..182ffd4d9081 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -35,7 +35,7 @@ use tvm_macros::{external, Object}; pub mod codespan; external! { - #[name("node.ArrayGetItem")] + #[name("runtime.ArrayGetItem")] fn get_renderer() -> DiagnosticRenderer; #[name("diagnostics.DiagnosticRenderer")] diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 653169def3a4..03d8a4920718 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -32,12 +32,14 @@ use super::span::Span; #[type_key = "Expr"] pub struct BaseExprNode { pub base: Object, + pub span: Span, } impl BaseExprNode { - pub fn base() -> BaseExprNode { + pub fn base(span: Span) -> BaseExprNode { BaseExprNode { base: Object::base::(), + span, } } } @@ -52,9 +54,9 @@ pub struct PrimExprNode { } impl PrimExprNode { - pub fn base(datatype: DataType) -> PrimExprNode { + pub fn base(datatype: DataType, span: Span) -> PrimExprNode { PrimExprNode { - base: BaseExprNode::base::(), + base: BaseExprNode::base::(span), datatype, } } @@ -70,9 +72,9 @@ pub struct GlobalVarNode { } impl GlobalVar { - pub fn new(name_hint: String, _span: Span) -> GlobalVar { + pub fn new(name_hint: String, span: Span) -> GlobalVar { let node = GlobalVarNode { - base: relay::ExprNode::base::(), + base: relay::ExprNode::base::(span), name_hint: name_hint.into(), }; GlobalVar(Some(ObjectPtr::new(node))) diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index 14c00ea02bf6..43aca869f385 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -17,12 +17,12 @@ * under the License. */ -use crate::ir::relay::ExprNode; -use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; - use tvm_macros::Object; -// Define Calling Convention. 
+use super::span::Span; + +use crate::ir::relay::ExprNode; +use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; // TODO(@jroesch): define DictAttrs pub type DictAttrs = ObjectRef; @@ -39,7 +39,7 @@ pub struct BaseFuncNode { impl BaseFuncNode { pub fn base() -> BaseFuncNode { BaseFuncNode { - base: ExprNode::base::(), + base: ExprNode::base::(Span::null()), attrs: ::null(), } } diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index a09f70dc25b9..513a906f6db4 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -279,8 +279,8 @@ mod tests { let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null()); let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); module.add_def(name.clone(), type_data, true)?; - let by_gtv = module.lookup_def(name)?; - let by_gv = module.lookup_def_str("my_type")?; + let _by_gtv = module.lookup_def(name)?; + let _by_gv = module.lookup_def_str("my_type")?; Ok(()) } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 9d2983237acb..f43967f28d60 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -23,7 +23,7 @@ use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::span::Span; -use super::ty::{Type, TypeNode}; +use super::ty::Type; use tvm_macros::Object; use tvm_rt::NDArray; @@ -39,19 +39,14 @@ pub mod attrs; #[type_key = "RelayExpr"] pub struct ExprNode { pub base: BaseExprNode, - pub span: ObjectRef, pub checked_type: Type, } impl ExprNode { - pub fn base() -> ExprNode { + pub fn base(span: Span) -> ExprNode { ExprNode { - base: BaseExprNode::base::(), - span: ObjectRef::null(), - checked_type: Type::from(TypeNode { - base: Object::base::(), - span: Span::null(), - }), + base: BaseExprNode::base::(span.clone()), + checked_type: Type::null(), } } } @@ -85,9 +80,9 @@ pub struct ConstantNode { } impl Constant { - pub fn new(data: NDArray, _span: ObjectRef) -> Constant { + pub fn new(data: NDArray, span: Span) -> Constant { let node = ConstantNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), data: data, }; Constant(Some(ObjectPtr::new(node))) @@ -104,9 +99,9 @@ pub struct TupleNode { } impl Tuple { - pub fn new(fields: Array, _span: ObjectRef) -> Tuple { + pub fn new(fields: Array, span: Span) -> Tuple { let node = TupleNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), fields, }; Tuple(Some(ObjectPtr::new(node))) @@ -124,9 +119,9 @@ pub struct VarNode { } impl Var { - pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var { + pub fn new(name_hint: String, type_annotation: Type, span: Span) -> Var { let node = VarNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), vid: Id::new(name_hint.into()), type_annotation: type_annotation, }; @@ -165,10 +160,10 @@ impl Call { args: Array, attrs: Attrs, type_args: Array, - _span: ObjectRef, + span: Span, ) -> Call { let node = CallNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), op: op, args: args, attrs: attrs, @@ -190,9 +185,9 @@ pub struct LetNode { } impl Let { - pub fn new(var: Var, value: Expr, body: Expr, _span: ObjectRef) -> Let { + pub fn new(var: Var, value: Expr, body: Expr, span: Span) -> Let { let node = LetNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), var, value, body, @@ -213,9 +208,9 @@ pub struct IfNode { } impl If { - pub fn new(cond: Expr, true_branch: Expr, false_branch: Expr, _span: ObjectRef) -> If { + pub fn new(cond: Expr, true_branch: 
Expr, false_branch: Expr, span: Span) -> If { let node = IfNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), cond, true_branch, false_branch, @@ -235,9 +230,9 @@ pub struct TupleGetItemNode { } impl TupleGetItem { - pub fn new(tuple: Expr, index: i32, _span: ObjectRef) -> TupleGetItem { + pub fn new(tuple: Expr, index: i32, span: Span) -> TupleGetItem { let node = TupleGetItemNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), tuple, index, }; @@ -255,9 +250,9 @@ pub struct RefCreateNode { } impl RefCreate { - pub fn new(value: Expr, _span: ObjectRef) -> RefCreate { + pub fn new(value: Expr, span: Span) -> RefCreate { let node = RefCreateNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), value, }; RefCreate(Some(ObjectPtr::new(node))) @@ -274,9 +269,9 @@ pub struct RefReadNode { } impl RefRead { - pub fn new(ref_value: Expr, _span: ObjectRef) -> RefRead { + pub fn new(ref_value: Expr, span: Span) -> RefRead { let node = RefReadNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), ref_value, }; RefRead(Some(ObjectPtr::new(node))) @@ -294,9 +289,9 @@ pub struct RefWriteNode { } impl RefWrite { - pub fn new(ref_value: Expr, value: Expr, _span: ObjectRef) -> RefWrite { + pub fn new(ref_value: Expr, value: Expr, span: Span) -> RefWrite { let node = RefWriteNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), ref_value, value, }; @@ -316,9 +311,9 @@ pub struct ConstructorNode { } impl Constructor { - pub fn new(name_hint: String, inputs: Array, tag: i32, _span: ObjectRef) -> Constructor { + pub fn new(name_hint: String, inputs: Array, tag: i32, span: Span) -> Constructor { let node = ConstructorNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), name_hint, inputs, tag, @@ -335,14 +330,14 @@ impl Constructor { #[type_key = "relay.Pattern"] pub struct PatternNode { pub base: Object, - pub span: ObjectRef, + pub span: Span, } impl PatternNode { - pub fn base() -> PatternNode { + pub fn base(span: Span) -> PatternNode { PatternNode { base: Object::base::(), - span: ObjectRef::null(), + span: span, } } } @@ -356,9 +351,9 @@ pub struct PatternWildcardNode { } impl PatternWildcard { - pub fn new(_span: ObjectRef) -> PatternWildcard { + pub fn new(span: Span) -> PatternWildcard { let node = PatternWildcardNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), }; PatternWildcard(Some(ObjectPtr::new(node))) } @@ -374,9 +369,9 @@ pub struct PatternVarNode { } impl PatternVar { - pub fn new(var: Var, _span: ObjectRef) -> PatternVar { + pub fn new(var: Var, span: Span) -> PatternVar { let node = PatternVarNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), var: var, }; PatternVar(Some(ObjectPtr::new(node))) @@ -397,10 +392,10 @@ impl PatternConstructor { pub fn new( constructor: Constructor, patterns: Array, - _span: ObjectRef, + span: Span, ) -> PatternConstructor { let node = PatternConstructorNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), constructor, patterns, }; @@ -418,9 +413,9 @@ pub struct PatternTupleNode { } impl PatternTuple { - pub fn new(patterns: Array, _span: ObjectRef) -> PatternTuple { + pub fn new(patterns: Array, span: Span) -> PatternTuple { let node = PatternTupleNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), patterns, }; PatternTuple(Some(ObjectPtr::new(node))) @@ -438,7 +433,7 @@ pub struct ClauseNode { } impl Clause { - pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { + pub fn new(lhs: 
Pattern, rhs: Expr, _span: Span) -> Clause { let node = ClauseNode { base: Object::base::(), lhs, @@ -460,9 +455,9 @@ pub struct MatchNode { } impl Match { - pub fn new(data: Expr, clauses: Array, complete: bool, _span: ObjectRef) -> Match { + pub fn new(data: Expr, clauses: Array, complete: bool, span: Span) -> Match { let node = MatchNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), data, clauses, complete, diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index ccbe30c95820..dcbec520d3b6 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -18,7 +18,9 @@ */ use super::{PrimExpr, PrimExprNode}; -use crate::runtime::String as TVMString; + +use crate::ir::span::Span; +use crate::runtime::{IsObjectRef, String as TVMString}; use crate::DataType; use tvm_macros::Object; @@ -36,7 +38,7 @@ macro_rules! define_node { impl $name { pub fn new(datatype: DataType, $($id : $t,)*) -> $name { - let base = PrimExprNode::base::<$node>(datatype); + let base = PrimExprNode::base::<$node>(datatype, Span::null()); let node = $node { base, $($id),* }; node.into() } @@ -56,7 +58,6 @@ impl From for IntImm { impl From for PrimExpr { fn from(i: i32) -> PrimExpr { - use crate::runtime::IsObjectRef; IntImm::from(i).upcast() } } diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index f7c52b51f332..83fdbfeb66aa 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -23,7 +23,7 @@ use tvm_rt::{array::Array, DataType}; use crate::ir::relay::Constructor; use crate::ir::span::Span; use crate::ir::PrimExpr; -use crate::runtime::{string::String as TString, IsObject, Object, ObjectPtr}; +use crate::runtime::{string::String as TString, IsObject, IsObjectRef, Object, ObjectPtr}; #[repr(C)] #[derive(Object, Debug)] @@ -147,8 +147,17 @@ pub struct TupleTypeNode { } impl TupleType { + // todo add coercion + pub fn new(fields: Vec, span: Span) -> Self { + let node = TupleTypeNode { + base: TypeNode::base::(span), + fields: Array::from_vec(fields).unwrap(), + }; + ObjectPtr::new(node).into() + } + pub fn empty() -> TupleType { - todo!() + TupleType::new(vec![], Span::null()) } } @@ -236,7 +245,13 @@ impl TensorType { }; ObjectPtr::new(node).into() } + + pub fn static_sh(shape: Vec, dtype: DataType, span: Span) -> TensorType { + let sh = Array::from_vec(shape.into_iter().map(Into::into).collect()).unwrap(); + Self::new(sh, dtype, span) + } } + // TODO(@jroesch): implement these in future. // // using TypeCall = tvm::TypeCall; diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index e86420eb70c9..caae07775d21 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -39,7 +39,9 @@ pub use tvm_rt::errors; pub use tvm_rt::function; pub use tvm_rt::module; pub use tvm_rt::ndarray; -pub use tvm_rt::value; + +#[cfg(feature = "python")] +pub mod compiler; pub mod ir; #[cfg(feature = "python")] pub mod python; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs index 89558af733b3..c224fb4db372 100644 --- a/rust/tvm/src/python.rs +++ b/rust/tvm/src/python.rs @@ -29,6 +29,8 @@ use pyo3::prelude::*; pub fn load() -> Result { let gil = Python::acquire_gil(); let py = gil.python(); + // let main_mod = initialize(); + //let main_mod = main_mod.as_ref(py); load_python_tvm_(py).map_err(|e| { // We can't display Python exceptions via std::fmt::Display, // so print the error here manually. 
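The `GraphRt::from_module` constructor added below obtains the module's "default" packed function to build a runtime; for comparison, a hedged sketch of the equivalent Python-side pattern (the library path is illustrative):

```python
import tvm
from tvm.contrib import graph_runtime

# A factory module built by relay.build exposes a "default" function that
# instantiates the graph runtime for a given context.
lib = tvm.runtime.load_module("deploy_lib.so")
gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
gmod.run()  # assumes inputs were set beforehand via gmod.set_input(...)
```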
@@ -36,25 +38,33 @@ pub fn load() -> Result { }) } -// const TVMC_CODE: &'static str = include_str!("tvmc.py"); +pub fn import(mod_to_import: &str) -> PyResult<()> { + let gil = Python::acquire_gil(); + let py = gil.python(); + import_python(py, mod_to_import)?; + Ok(()) +} + +fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> { + let imported_mod = py.import(to_import)?; + Ok(imported_mod) +} fn load_python_tvm_(py: Python) -> PyResult { - let sys = py.import("tvm")?; - let version: String = sys.get("__version__")?.extract()?; - // py.run(TVMC_CODE, None, None)?; + let imported_mod = import_python(py, "tvm")?; + let version: String = imported_mod.get("__version__")?.extract()?; Ok(version) } #[cfg(test)] mod tests { - use super::load_python_tvm_; + use super::*; use anyhow::Result; - use pyo3::prelude::*; #[ignore] #[test] fn test_run() -> Result<()> { - load_python_tvm_(Python::acquire_gil().python()).unwrap(); + load().unwrap(); Ok(()) } } diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm/src/runtime/graph_rt.rs index 8b26ebb4ca22..fcc41aca560f 100644 --- a/rust/tvm/src/runtime/graph_rt.rs +++ b/rust/tvm/src/runtime/graph_rt.rs @@ -34,13 +34,23 @@ pub struct GraphRt { } impl GraphRt { + /// Create a graph runtime directly from a runtime module. + pub fn from_module(module: Module, ctx: Context) -> Result { + let default: Box Result> = + module.get_function("default", false)?.into(); + + Ok(Self { + module: default(ctx)?, + }) + } + /// Create a graph runtime from the deprecated graph, lib, ctx triple. pub fn create_from_parts(graph: &str, lib: Module, ctx: Context) -> Result { let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - (&lib).into(), + lib.into(), (&ctx.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (ctx.device_id as i32).into(), diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index e4249a491746..450ab48dc1b2 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -30,6 +30,7 @@ fn main() { } else { (Context::gpu(0), "gpu") }; + let dtype = DataType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); @@ -38,11 +39,13 @@ fn main() { if !fadd.enabled(ctx_name) { return; } + if cfg!(feature = "gpu") { fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); } - fadd.entry() + // todo(@jroesch): fix the entry_name + fadd.get_function("__tvm_main__", false) .expect("module must have entry point") .invoke(vec![(&arr).into(), (&arr).into(), (&ret).into()]) .unwrap(); diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py index b9672fbf4aaf..3c1fc64d3e36 100755 --- a/rust/tvm/tests/basics/src/tvm_add.py +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -37,7 +37,6 @@ def main(target, out_dir): s[C].bind(tx, te.thread_axis("threadIdx.x")) fadd = tvm.build(s, [A, B, C], target, target_host="llvm", name="myadd") - fadd.save(osp.join(out_dir, "test_add.o")) if target == "cuda": fadd.imported_modules[0].save(osp.join(out_dir, "test_add.ptx")) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index cf516d8452e2..d93218c0208c 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1399,7 +1399,7 @@ void GetPerStoreFeaturesFromFile(const std::string& 
filename, int max_lines, int Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, cur_inp->task->target_host, cur_inp->task->hardware_params, - cur_inp->task->layout_rewrite_option); + cur_inp->task->layout_rewrite_option, cur_inp->task->task_input_names); task_id = task_cache.size(); // compute min cost for each task @@ -1466,9 +1466,10 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, // The measure input is incomplete, rebuild task for incomplete measure pairs read from file try { Array tensors = (*workload_key_to_tensors)(workload_key); - task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, - inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option); + task = + SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, + inputs[i]->task->target_host, inputs[i]->task->hardware_params, + inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 1120f437b176..5dafa8d98702 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -169,6 +169,12 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(std::string("")); } writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto& i : data.task_input_names) { + writer->WriteArrayItem(std::string(i)); + } + writer->EndArray(); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { @@ -200,6 +206,17 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&int_value); data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); s = reader->NextArrayItem(); + if (s) { + reader->BeginArray(); + s = reader->NextArrayItem(); + while (s) { + reader->Read(&str_value); + data->task_input_names.push_back(str_value); + s = reader->NextArrayItem(); + } + // Process the end of array + s = reader->NextArrayItem(); + } ICHECK(!s); } } @@ -444,5 +461,22 @@ TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeMeasureInput").set_body_typed([]( reader.Read(inp.get()); return ObjectRef(inp); }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SerializeSearchTask") + .set_body_typed([](const SearchTask& search_task) { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*search_task.get()); + return os.str(); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeSearchTask").set_body_typed([](String json) { + std::istringstream ss(json); + dmlc::JSONReader reader(&ss); + auto search_task = make_object(); + reader.Read(search_task.get()); + return ObjectRef(search_task); +}); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index d59df6965776..ce8dc39922e0 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -465,6 +465,22 @@ const std::vector& SplitFactorizationMemo::GetFactors(int n) { /********** Utils interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers") + .set_body_typed([](const 
SearchTask& task, const State& state, int stage_id) { + const std::set& consumers = GetConsumers(task, state, stage_id); + tvm::Map ret; + for (const auto& i : consumers) { + ret.Set(Integer(i), Integer(i)); + } + return ret; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch") + .set_body_typed([](const SearchTask& task, const State& state, int stage_id, + int target_stage_id) { + return ElementwiseMatch(task, state, stage_id, target_stage_id); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled") .set_body_typed([](const Stage& stage) { return IsTiled(stage); }); diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 0abee16fceab..22c2893141cf 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,7 +114,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - LayoutRewriteOption layout_rewrite_option) { + LayoutRewriteOption layout_rewrite_option, Array task_input_names) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -127,6 +127,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } node->layout_rewrite_option = layout_rewrite_option; + node->task_input_names = std::move(task_input_names); data_ = std::move(node); } @@ -142,9 +143,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option) { + int layout_rewrite_option, Array task_input_names) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, - LayoutRewriteOption(layout_rewrite_option)); + LayoutRewriteOption(layout_rewrite_option), task_input_names); }); } // namespace auto_scheduler diff --git a/src/node/container.cc b/src/node/container.cc deleted file mode 100644 index b72d5a4cd736..000000000000 --- a/src/node/container.cc +++ /dev/null @@ -1,363 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * Expose container API to frontend. - * \file src/node/container.cc - */ -#include -#include -#include -#include - -#include "../support/str_escape.h" - -namespace tvm { - -// SEQualReduce traits for runtime containers. 
-struct StringObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); - } - - static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - if (lhs->size != rhs->size) return false; - if (lhs->data == rhs->data) return true; - return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; - } -}; - -struct RefToObjectPtr : public ObjectRef { - static ObjectPtr Get(const ObjectRef& ref) { return GetDataPtr(ref); } -}; - -TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) - .set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); - }) - .set_repr_bytes([](const Object* n) -> std::string { - return GetRef(static_cast(n)) - . - operator std::string(); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; - }); - -struct ADTObjTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { - hash_reduce(key->tag); - hash_reduce(static_cast(key->size)); - for (uint32_t i = 0; i < key->size; ++i) { - hash_reduce((*key)[i]); - } - } - - static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - if (lhs->tag != rhs->tag) return false; - if (lhs->size != rhs->size) return false; - - for (uint32_t i = 0; i < lhs->size; ++i) { - if (!equal((*lhs)[i], (*rhs)[i])) return false; - } - return true; - } -}; - -TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); - -struct NDArrayContainerTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { - ICHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(key->dl_tensor.dtype)); - hash_reduce(key->dl_tensor.ndim); - for (int i = 0; i < key->dl_tensor.ndim; ++i) { - hash_reduce(key->dl_tensor.shape[i]); - } - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); - } - - static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, SEqualReducer equal) { - if (lhs == rhs) return true; - - auto ldt = lhs->dl_tensor.dtype; - auto rdt = rhs->dl_tensor.dtype; - ICHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; - - if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; - for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { - if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = 
runtime::GetDataSize(lhs->dl_tensor); - return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; - } else { - return false; - } - } -}; - -TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); - -struct ArrayNodeTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { - hash_reduce(static_cast(key->size())); - for (size_t i = 0; i < key->size(); ++i) { - hash_reduce(key->at(i)); - } - } - - static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { - if (lhs->size() != rhs->size()) return false; - for (size_t i = 0; i < lhs->size(); ++i) { - if (!equal(lhs->at(i), rhs->at(i))) return false; - } - return true; - } -}; - -TVM_REGISTER_OBJECT_TYPE(ArrayNode); -TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) - .set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - -TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } - } - *ret = Array(data); -}); - -TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; - *ret = n->at(i); -}); - -TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - *ret = static_cast(static_cast(ptr)->size()); -}); - -struct MapNodeTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) { - // SHash's var handling depends on the determinism of traversal. - // NOTE: only book-keep the mapped hash keys. - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - using KV = std::pair; - std::vector temp; - for (const auto& kv : *key) { - size_t hashed_value; - if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) { - temp.emplace_back(hashed_value, kv.second); - } - } - // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); - // add size to the hash - hash_reduce(static_cast(key->size())); - // hash the content - for (size_t i = 0; i < temp.size();) { - size_t k = i + 1; - for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { - } - // ties are rare, but we need to skip them to make the hash determinsitic - if (k == i + 1) { - hash_reduce->SHashReduceHashedValue(temp[i].first); - hash_reduce(temp[i].second); - } - i = k; - } - } - - static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) { - // NOTE: only book-keep the mapped hash keys. - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. 
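Note, not part of the patch: the traits deleted from this file are re-added verbatim to src/node/structural_hash.cc later in this diff, so observable behaviour is unchanged; for example, two separately allocated strings with identical bytes remain structurally equal and hash alike. A sketch of that guarantee (header paths assumed from this tree):

#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/container.h>

bool StringTraitDemo() {
  tvm::runtime::String a("conv2d");
  tvm::runtime::String b("conv2d");        // distinct object, same bytes
  bool eq = tvm::StructuralEqual()(a, b);  // true: StringObjTrait memcmps the payload
  bool hash_match = tvm::StructuralHash()(a) == tvm::StructuralHash()(b);
  return eq && hash_match;
}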
- using KV = std::pair; - std::vector temp; - for (const auto& kv : *key) { - temp.push_back(std::make_pair(Downcast(kv.first), kv.second)); - } - // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); - // NOTE: we won't have ties - // add size to the hash after sorting. - hash_reduce(static_cast(key->size())); - // hash the content - for (size_t i = 0; i < temp.size(); ++i) { - hash_reduce(temp[i].first); - hash_reduce(temp[i].second); - } - } - - static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { - bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { - return v.first->template IsInstance(); - }); - if (is_str_map) { - SHashReduceForSMap(key, hash_reduce); - } else { - SHashReduceForOMap(key, hash_reduce); - } - } - - static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - for (const auto& kv : *lhs) { - // Only allow equal checking if the keys are already mapped - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); - if (!rhs_key.defined()) return false; - auto it = rhs->find(rhs_key); - if (it == rhs->end()) return false; - if (!equal(kv.second, it->second)) return false; - } - return true; - } - - static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - for (const auto& kv : *lhs) { - auto it = rhs->find(kv.first); - if (it == rhs->end()) return false; - if (!equal(kv.second, it->second)) return false; - } - return true; - } - - static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { - if (rhs->size() != lhs->size()) return false; - if (rhs->size() == 0) return true; - bool ls = std::all_of(lhs->begin(), lhs->end(), - [](const auto& v) { return v.first->template IsInstance(); }); - bool rs = std::all_of(rhs->begin(), rhs->end(), - [](const auto& v) { return v.first->template IsInstance(); }); - if (ls != rs) { - return false; - } - return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal); - } -}; - -TVM_REGISTER_OBJECT_TYPE(MapNode); -TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) - .set_creator([](const std::string&) -> ObjectPtr { return MapNode::Empty(); }); - -TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size() % 2, 0); - std::unordered_map data; - for (int i = 0; i < args.num_args; i += 2) { - ObjectRef k = - String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef(); - ObjectRef v = args[i + 1]; - data.emplace(std::move(k), std::move(v)); - } - *ret = Map(std::move(data)); -}); - -TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast(n->size()); -}); - -TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - - auto* n = static_cast(ptr); - auto it = n->find(String::CanConvertFrom(args[1]) ? 
args[1].operator String() - : args[1].operator ObjectRef()); - ICHECK(it != n->end()) << "cannot find the corresponding key in the Map"; - *ret = (*it).second; -}); - -TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - const MapNode* n = static_cast(ptr); - int64_t cnt = n->count(String::CanConvertFrom(args[1]) ? args[1].operator String() - : args[1].operator ObjectRef()); - *ret = cnt; -}); - -TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - auto* n = static_cast(ptr); - Array rkvs; - for (const auto& kv : *n) { - if (kv.first->IsInstance()) { - rkvs.push_back(Downcast(kv.first)); - } else { - rkvs.push_back(kv.first); - } - rkvs.push_back(kv.second); - } - *ret = std::move(rkvs); -}); - -#if (USE_FALLBACK_STL_MAP == 0) -TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; -#endif -} // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 9dc9d330bb77..79a53aa26440 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -22,9 +22,9 @@ * \file node/reflection.cc */ #include -#include #include #include +#include #include namespace tvm { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index c7e4d27c8b2c..ad42799b55e5 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -24,9 +24,9 @@ #include #include #include -#include #include #include +#include #include #include #include diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index e0b729d3f103..efedd1b99d6d 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -28,6 +28,7 @@ #include #include +#include "../support/str_escape.h" #include "../support/utils.h" namespace tvm { @@ -260,4 +261,241 @@ size_t StructuralHash::operator()(const ObjectRef& object) const { return VarCountingSHashHandler().Hash(object, false); } +// SEQualReduce traits for runtime containers. +struct StringObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); + } + + static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + if (lhs->size != rhs->size) return false; + if (lhs->data == rhs->data) return true; + return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; + } +}; + +struct RefToObjectPtr : public ObjectRef { + static ObjectPtr Get(const ObjectRef& ref) { return GetDataPtr(ref); } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) + .set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return GetRef(static_cast(n)) + . 
+ operator std::string(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; + }); + +struct ADTObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { + hash_reduce(key->tag); + hash_reduce(static_cast(key->size)); + for (uint32_t i = 0; i < key->size; ++i) { + hash_reduce((*key)[i]); + } + } + + static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + if (lhs->tag != rhs->tag) return false; + if (lhs->size != rhs->size) return false; + + for (uint32_t i = 0; i < lhs->size; ++i) { + if (!equal((*lhs)[i], (*rhs)[i])) return false; + } + return true; + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); + +struct NDArrayContainerTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { + ICHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; + hash_reduce(runtime::DataType(key->dl_tensor.dtype)); + hash_reduce(key->dl_tensor.ndim); + for (int i = 0; i < key->dl_tensor.ndim; ++i) { + hash_reduce(key->dl_tensor.shape[i]); + } + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); + } + + static bool SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, SEqualReducer equal) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + ICHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; + + if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; + for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { + if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; + } + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; + } else { + return false; + } + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); + +struct ArrayNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { + hash_reduce(static_cast(key->size())); + for (size_t i = 0; i < key->size(); ++i) { + hash_reduce(key->at(i)); + } + } + + static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { + if (lhs->size() != rhs->size()) return false; + for (size_t i = 0; i < lhs->size(); ++i) { + if (!equal(lhs->at(i), rhs->at(i))) return false; + } + return true; + } +}; +TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +struct MapNodeTrait { + static 
constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) { + // SHash's var handling depends on the determinism of traversal. + // NOTE: only book-keep the mapped hash keys. + // This resolves common use cases where we want to store + // Map<Var, Value> where Var is defined in the function + // parameters. + using KV = std::pair<size_t, ObjectRef>; + std::vector<KV> temp; + for (const auto& kv : *key) { + size_t hashed_value; + if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) { + temp.emplace_back(hashed_value, kv.second); + } + } + // sort by the hash key of the keys. + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); + // add size to the hash + hash_reduce(static_cast<uint64_t>(key->size())); + // hash the content + for (size_t i = 0; i < temp.size();) { + size_t k = i + 1; + for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { + } + // ties are rare, but we need to skip them to make the hash deterministic + if (k == i + 1) { + hash_reduce->SHashReduceHashedValue(temp[i].first); + hash_reduce(temp[i].second); + } + i = k; + } + } + + static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) { + // NOTE: only book-keep the mapped hash keys. + // This resolves common use cases where we want to store + // Map<Var, Value> where Var is defined in the function + // parameters. + using KV = std::pair<String, ObjectRef>; + std::vector<KV> temp; + for (const auto& kv : *key) { + temp.push_back(std::make_pair(Downcast<String>(kv.first), kv.second)); + } + // sort by the hash key of the keys. + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); + // NOTE: we won't have ties + // add size to the hash after sorting. + hash_reduce(static_cast<uint64_t>(key->size())); + // hash the content + for (size_t i = 0; i < temp.size(); ++i) { + hash_reduce(temp[i].first); + hash_reduce(temp[i].second); + } + } + + static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { + bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { + return v.first->template IsInstance<StringObj>(); + }); + if (is_str_map) { + SHashReduceForSMap(key, hash_reduce); + } else { + SHashReduceForOMap(key, hash_reduce); + } + } + + static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + for (const auto& kv : *lhs) { + // Only allow equal checking if the keys are already mapped + // This resolves common use cases where we want to store + // Map<Var, Value> where Var is defined in the function + // parameters.
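Note, not part of the patch: the two hash paths above differ deliberately. String-keyed maps sort by the keys themselves; object-keyed maps sort by the hashes of already-mapped keys and skip tied hash groups entirely, so iteration order can never leak into the result. The same idea in a self-contained sketch (plain C++, no TVM types):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

size_t OrderInsensitiveHash(std::vector<std::pair<size_t, size_t>> kv_hashes) {
  // sort (key_hash, value_hash) pairs by key hash so traversal order is irrelevant
  std::sort(kv_hashes.begin(), kv_hashes.end(),
            [](const auto& l, const auto& r) { return l.first < r.first; });
  size_t h = kv_hashes.size();
  for (size_t i = 0; i < kv_hashes.size();) {
    size_t k = i + 1;
    while (k < kv_hashes.size() && kv_hashes[k].first == kv_hashes[i].first) ++k;
    if (k == i + 1) {  // unique key hash: fold the pair in; tied groups are dropped
      h = h * 1000003 + kv_hashes[i].first;
      h = h * 1000003 + kv_hashes[i].second;
    }
    i = k;
  }
  return h;
}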
+ ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); + if (!rhs_key.defined()) return false; + auto it = rhs->find(rhs_key); + if (it == rhs->end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } + + static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + for (const auto& kv : *lhs) { + auto it = rhs->find(kv.first); + if (it == rhs->end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } + + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + if (rhs->size() != lhs->size()) return false; + if (rhs->size() == 0) return true; + bool ls = std::all_of(lhs->begin(), lhs->end(), + [](const auto& v) { return v.first->template IsInstance(); }); + bool rs = std::all_of(rhs->begin(), rhs->end(), + [](const auto& v) { return v.first->template IsInstance(); }); + if (ls != rs) { + return false; + } + return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal); + } +}; +TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) + .set_creator([](const std::string&) -> ObjectPtr { return MapNode::Empty(); }); + } // namespace tvm diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index 233da1baffd8..f76c32d353cf 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -24,8 +24,8 @@ #ifndef TVM_PRINTER_META_DATA_H_ #define TVM_PRINTER_META_DATA_H_ -#include #include +#include #include #include diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ed09e4f6eb32..ae975a5f3240 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -692,6 +692,17 @@ class CompileEngineImpl : public CompileEngineNode { return items; } + // List all items in the shape_func_cache. + Array ListShapeFuncItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : shape_func_cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + /*! 
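Note, not part of the patch: ListShapeFuncItems mirrors the existing ListItems convention of flattening the cache into one array that alternates key and value, so a consumer walks it in steps of two. A sketch against the FFI wrapper registered below (types hedged to plain ObjectRef):

#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>

void DumpShapeFuncCache(const tvm::runtime::ObjectRef& engine) {
  const auto* f =
      tvm::runtime::Registry::Get("relay.backend._CompileEngineListShapeFuncItems");
  ICHECK(f != nullptr);
  tvm::runtime::Array<tvm::runtime::ObjectRef> items = (*f)(engine);
  for (size_t i = 0; i + 1 < items.size(); i += 2) {
    const auto& key = items[i];        // CCacheKey
    const auto& value = items[i + 1];  // CCacheValue
    (void)key;
    (void)value;  // inspect or log as needed
  }
}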
* \brief Get the cache key of the function that is being lowered currently * \return the cache key @@ -882,6 +893,13 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](C return ptr->ListItems(); }); +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") + .set_body_typed([](CompileEngine self) { + CompileEngineImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListShapeFuncItems(); + }); + TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") .set_body_typed([](CompileEngine self) { CompileEngineImpl* ptr = dynamic_cast(self.operator->()); diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 859ef8c9bdb2..192e09140375 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -26,7 +26,6 @@ #include #include -#include #include #include #include diff --git a/src/relay/backend/contrib/ethosn/capabilities.h b/src/relay/backend/contrib/ethosn/capabilities.h index 8c7ee6a0d009..cc14ca101da6 100644 --- a/src/relay/backend/contrib/ethosn/capabilities.h +++ b/src/relay/backend/contrib/ethosn/capabilities.h @@ -45,7 +45,7 @@ namespace ethosn { * variant[2] - Ethos-N37 * variant[3] - Ethos-N78 */ -#if _ETHOSN_API_VERSION_ == 2008 +#if _ETHOSN_API_VERSION_ == 2011 static std::vector variants[4] = { { 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -84,40 +84,50 @@ static std::vector variants[4] = { 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, }}; #else -static std::vector variants[3] = { +static std::vector variants[4] = { { - 0x02, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, { - 0x02, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, { - 0x02, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + { + 0x03, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, }}; #endif } // namespace ethosn diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 3097a300a0d9..5e052b3e4fd6 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -198,8 +198,19 @@ sl::TensorsAndId MakeOps(const sl::TensorAndId& op) { NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { // Initialise everything +#if _ETHOSN_API_VERSION_ == 2011 + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethos-n.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } +#endif NetworkWithIDs network_with_ids; +#if _ETHOSN_API_VERSION_ == 2011 + network_ = sl::CreateNetwork(variants[cfg.value()->variant]); +#else network_ = sl::CreateNetwork(); +#endif network_with_ids.network = network_; operand_table_.clear(); @@ -561,7 +572,11 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { cfg = AttrsWithDefaultValues(); } +#if _ETHOSN_API_VERSION_ == 2011 + sl::CompilationOptions options; +#else sl::CompilationOptions options(variants[cfg.value()->variant]); +#endif options.m_Strategy0 = cfg.value()->strategy0; options.m_Strategy1 = cfg.value()->strategy1; options.m_Strategy3 = cfg.value()->strategy3; @@ -575,15 +590,13 @@ sl::CompilationOptions EthosnCompiler::CreateOptions() { options.m_BlockConfig8x32 = cfg.value()->block_config_8x32; options.m_BlockConfig8x8 = 
cfg.value()->block_config_8x8; options.m_EnableIntermediateCompression = cfg.value()->enable_intermediate_compression; - options.m_DisableWinograd = cfg.value()->disable_winograd; +#if _ETHOSN_API_VERSION_ == 2008 options.m_DebugInfo.m_DumpDebugFiles = cfg.value()->dump_debug_files; +#endif + options.m_DisableWinograd = cfg.value()->disable_winograd; options.m_DebugInfo.m_DebugDir = cfg.value()->debug_dir; -#if _ETHOSN_API_VERSION_ == 2008 options.m_CompilerAlgorithm = sl::EthosNCompilerAlgorithmFromString(cfg.value()->compiler_algorithm.c_str()); -#else - options.m_EnableCascading = cfg.value()->enable_cascading; -#endif return options; } @@ -606,6 +619,175 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } +#if _ETHOSN_API_VERSION_ == 2011 +auto ctx = transform::PassContext::Current(); +auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() + ? ctx -> GetConfig("relay.ext.ethos-n.options") + : AttrsWithDefaultValues(); +auto m_Queries = sl::SupportQueries(variants[cfg.value()->variant]); +#endif + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + ConvolutionParams params; + auto err = EthosnAPI::QnnConv2d(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + if (params.is_depthwise) { + *rv = !err && + m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } else { + *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } +#else + if (params.is_depthwise) { + *rv = !err && sl::IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } else { + *rv = !err && sl::IsConvolutionSupported(params.bias_info, params.weights_info, + params.conv_info, params.activation_info); + } +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + FullyConnectedParams params; + auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, + params.fc_info, params.input_info); +#else + *rv = !err && sl::IsFullyConnectedSupported(params.bias_info, params.weights_info, + params.fc_info, params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + MaxPool2DParams params; + auto err = EthosnAPI::MaxPool2D(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); +#else + *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + AvgPool2DParams params; + auto err = EthosnAPI::AvgPool2D(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); +#else + *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + ReshapeParams params; + auto err = EthosnAPI::Reshape(call, ¶ms); 
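Note, not part of the patch: every relay.ethos-n.support.* registration here follows the same split. Ethos-N API 20.11 exposes the support checks as member functions of an sl::SupportQueries object built once for the configured variant (the file-scope m_Queries above), while 20.08 provides free functions in namespace sl. The shape of each check, condensed into one excerpt-style sketch using the surrounding file's symbols (not standalone code):

bool IsOpSupportedCompat(const ReluParams& params) {
#if _ETHOSN_API_VERSION_ == 2011
  // 20.11: query object shared by all checks, constructed per variant
  return m_Queries.IsReluSupported(params.relu_info, params.input_info);
#else
  // 20.08: the equivalent free function from the support library
  return sl::IsReluSupported(params.relu_info, params.input_info);
#endif
}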
+#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); +#else + *rv = !err && sl::IsReshapeSupported(params.new_shape, params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + AdditionParams params; + auto err = EthosnAPI::Addition(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, + params.output_quantization_info); +#else + *rv = !err && sl::IsAdditionSupported(params.lhs_info, params.rhs_info, + params.output_quantization_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + SigmoidParams params; + auto err = EthosnAPI::Sigmoid(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); +#else + *rv = !err && sl::IsSigmoidSupported(params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + ConcatenateParams params; + auto err = EthosnAPI::Concatenate(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); +#else + *rv = !err && sl::IsConcatenationSupported(params.input_infos, params.concat_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + SplitParams params; + auto err = EthosnAPI::Split(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); +#else + *rv = !err && sl::IsSplitSupported(params.input_info, params.split_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + DepthToSpaceParams params; + auto err = EthosnAPI::DepthToSpace(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); +#else + *rv = !err && sl::IsDepthToSpaceSupported(params.input_info, params.depth_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + Call call = args[0]; + ReluParams params; + auto err = EthosnAPI::Relu(call, ¶ms); +#if _ETHOSN_API_VERSION_ == 2011 + *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); +#else + *rv = !err && sl::IsReluSupported(params.relu_info, params.input_info); +#endif + }); + +TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { +#if defined ETHOSN_HW + *rv = true; +#else + *rv = false; +#endif +}); + +TVM_REGISTER_GLOBAL("relay.ethos-n.api.version").set_body_typed([]() -> int { + return _ETHOSN_API_VERSION_; +}); + } // namespace ethosn } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 9887a2b3ad78..e44aa31d6b13 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -240,14 +240,12 @@ struct EthosnCompilerConfigNode : public tvm::AttrsNode int { - return _ETHOSN_API_VERSION_; -}); - } 
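Note, not part of the patch: the hardware-query and API-version globals removed from codegen_ethosn.h here are the same ones re-registered in codegen.cc above, so only the defining translation unit moves and callers keep the registry names. A minimal probe:

#include <tvm/runtime/registry.h>

int EthosnApiVersionOrZero() {
  const auto* f = tvm::runtime::Registry::Get("relay.ethos-n.api.version");
  if (f == nullptr) return 0;  // Ethos-N codegen not compiled into this build
  int version = (*f)();        // e.g. 2008 or 2011
  return version;
}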
// namespace ethosn } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/ethosn/ethosn_api_version.h b/src/relay/backend/contrib/ethosn/ethosn_api_version.h index 618b702da333..78f08950bb48 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api_version.h +++ b/src/relay/backend/contrib/ethosn/ethosn_api_version.h @@ -29,10 +29,12 @@ * along with associated compatibility measures when no * longer necessary. */ +#ifndef ETHOSN_API_VERSION #define _ETHOSN_API_VERSION_ 2008 -#ifndef COMPILER_ALGORITHM_MODE -#undef _ETHOSN_API_VERSION_ -#define _ETHOSN_API_VERSION_ 2005 +#elif ~(~ETHOSN_API_VERSION + 0) == 0 && ~(~ETHOSN_API_VERSION + 1) == 1 +#define _ETHOSN_API_VERSION_ 2008 +#else +#define _ETHOSN_API_VERSION_ ETHOSN_API_VERSION #endif #endif // TVM_RELAY_BACKEND_CONTRIB_ETHOSN_ETHOSN_API_VERSION_H_ diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index cb648333df8d..059dbc192a04 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -156,6 +156,9 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer { // with slice_mode = "size", attrs->end_value mean the size of the slice int end_value = attrs->end.value()[i].as()->value; size_value = (end_value == -1) ? ishape[i] - begin_value : end_value; + } else { + LOG(FATAL) << "Unexpected slice_mode " << attrs->slice_mode << ", expected end or size"; + throw; } ICHECK_GT(size_value, 0); size.push_back(std::to_string(size_value)); diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index 1d7e08abcdde..bb0fad9142c1 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -31,70 +31,24 @@ #include #include +#include "../../runtime/file_utils.h" + namespace tvm { namespace relay { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.size() % 2, 0u); - // `args` is in the form "key, value, key, value, ..." 
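Note, not part of the patch: the new guard in ethosn_api_version.h above uses a preprocessor idiom worth spelling out. `~(~ETHOSN_API_VERSION + 0) == 0 && ~(~ETHOSN_API_VERSION + 1) == 1` detects a macro that is defined but empty: with an empty expansion the first test reduces to `~(~(+0)) == ~(-1) == 0` and the second to `~(~(+1)) == ~(-2) == 1`, both true, whereas a numeric value v would require v == 0 and v - 1 == 1 simultaneously, which is impossible. So `-DETHOSN_API_VERSION` passed with no value falls back to 2008, while an explicit value passes through. A compilable demo of the idiom (EMPTY_MACRO is illustrative only):

#define EMPTY_MACRO
#if ~(~EMPTY_MACRO + 0) == 0 && ~(~EMPTY_MACRO + 1) == 1
// Reached only when EMPTY_MACRO is defined but carries no value,
// which is where the header above selects its 2008 default.
#endif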
- size_t num_params = args.size() / 2; - std::vector names; - names.reserve(num_params); - std::vector arrays; - arrays.reserve(num_params); - for (size_t i = 0; i < num_params * 2; i += 2) { - names.emplace_back(args[i].operator String()); - arrays.emplace_back(args[i + 1].operator DLTensor*()); - } - std::string bytes; - dmlc::MemoryStringStream strm(&bytes); - dmlc::Stream* fo = &strm; - uint64_t header = kTVMNDArrayListMagic, reserved = 0; - fo->Write(header); - fo->Write(reserved); - fo->Write(names); - { - uint64_t sz = static_cast(arrays.size()); - fo->Write(sz); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(fo, arrays[i]); - } - } - TVMByteArray arr; - arr.data = bytes.c_str(); - arr.size = bytes.length(); - *rv = arr; -}); - -TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { - std::string bytes = args[0]; - std::vector names; - dmlc::MemoryStringStream memstrm(&bytes); - dmlc::Stream* strm = &memstrm; - uint64_t header, reserved; - ICHECK(strm->Read(&header)) << "Invalid parameters file format"; - ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; - ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; - ICHECK(strm->Read(&names)) << "Invalid parameters file format"; - uint64_t sz; - strm->Read(&sz, sizeof(sz)); - size_t size = static_cast(sz); - ICHECK(size == names.size()) << "Invalid parameters file format"; - tvm::Array ret; - for (size_t i = 0; i < size; ++i) { - tvm::runtime::NDArray temp; - temp.Load(strm); - auto n = tvm::make_object(); - n->name = std::move(names[i]); - n->array = temp; - ret.push_back(NamedNDArray(n)); - } - *rv = ret; +TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") + .set_body_typed([](const Map& params) { + std::string s = ::tvm::runtime::SaveParams(params); + // copy return array so it is owned by the ret value + TVMRetValue rv; + rv = TVMByteArray{s.data(), s.size()}; + return rv; + }); +TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body_typed([](const String& s) { + return ::tvm::runtime::LoadParams(s); }); -TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); - } // namespace relay } // namespace tvm diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index 384201f94648..96e17a9da07b 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -32,32 +32,7 @@ #include namespace tvm { -namespace relay { - -/*! \brief Magic number for NDArray list file */ -constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; - -/*! - * \brief Wrapper node for naming `NDArray`s. 
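Note, not part of the patch: the hand-rolled dmlc::Stream code deleted above (magic header, name list, tensor blobs) now lives behind shared runtime helpers that the rewritten globals delegate to, keeping the byte format in one place. A sketch of the round trip; SaveParams/LoadParams are declared in src/runtime/file_utils.h per this patch, with signatures inferred from the usage above:

// include "src/runtime/file_utils.h" for the SaveParams/LoadParams declarations
#include <tvm/runtime/container.h>
#include <tvm/runtime/ndarray.h>

#include <string>

using tvm::runtime::Map;
using tvm::runtime::NDArray;
using tvm::runtime::String;

std::string SaveThenReload(const Map<String, NDArray>& params) {
  std::string blob = tvm::runtime::SaveParams(params);  // same byte format as before
  Map<String, NDArray> restored = tvm::runtime::LoadParams(blob);
  ICHECK_EQ(restored.size(), params.size());
  return blob;
}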
- */ -struct NamedNDArrayNode : public ::tvm::Object { - std::string name; - tvm::runtime::NDArray array; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("array", &array); - } - - static constexpr const char* _type_key = "NamedNDArray"; - TVM_DECLARE_FINAL_OBJECT_INFO(NamedNDArrayNode, Object); -}; - -class NamedNDArray : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(NamedNDArray, ObjectRef, NamedNDArrayNode); -}; -} // namespace relay +namespace relay {} // namespace relay } // namespace tvm #endif // TVM_RELAY_BACKEND_PARAM_DICT_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0718191a2ff6..251a55f10b72 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -376,11 +376,16 @@ class VMFunctionCompiler : ExprFunctor { CompileMatch(match); } - void VisitExpr_(const LetNode* let_node) { - DLOG(INFO) << PrettyPrint(let_node->value); - this->VisitExpr(let_node->value); - var_register_map_.insert({let_node->var, this->last_register_}); - this->VisitExpr(let_node->body); + void VisitExpr_(const LetNode* l) final { + Expr let_binding = GetRef(l); + const LetNode* let; + while ((let = let_binding.as())) { + VisitExpr(let->value); + var_register_map_.insert({let->var, this->last_register_}); + let_binding = let->body; + } + + VisitExpr(let_binding); } void VisitExpr_(const TupleGetItemNode* get_node) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 650df99645e7..eb848eb7a828 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -58,8 +58,19 @@ struct PrimitiveInliner : ExprMutator { explicit PrimitiveInliner(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) { - var_map.insert({let_node->var, VisitExpr(let_node->value)}); - return ExprMutator::VisitExpr_(let_node); + auto pre_visit = [this](const LetNode* op) { + var_map.insert({op->var, this->VisitExpr(op->value)}); + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + // Visit body and cache the op + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + this->memo_[expr] = Let(op->var, value, body); + }; + ExpandANormalForm(let_node, pre_visit, post_visit); + return memo_[GetRef(let_node)]; } Expr VisitExpr_(const CallNode* call) { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index fe9a544a719e..cc530a10188e 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -61,19 +61,30 @@ class LambdaLifter : public ExprMutator { explicit LambdaLifter(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) final { - bool is_lambda = false; - if (auto func = let_node->value.as()) { - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - is_lambda = true; - letrec_.push_back(let_node->var); + auto pre_visit = [this](const LetNode* op) { + bool is_lambda = false; + if (auto func = op->value.as()) { + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + is_lambda = true; + this->letrec_.push_back(op->var); + } } - } - auto value = VisitExpr(let_node->value); - if (is_lambda) { - letrec_.pop_back(); - } - auto body = VisitExpr(let_node->body); - return Let(let_node->var, value, body); + Expr value = this->VisitExpr(op->value); + + if (is_lambda) { + this->letrec_.pop_back(); + } + }; + auto 
post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + // Visit body and cache the op + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + this->memo_[expr] = Let(op->var, value, body); + }; + ExpandANormalForm(let_node, pre_visit, post_visit); + return memo_[GetRef(let_node)]; } Expr VisitExpr_(const CallNode* call_node) final { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index d70c6fe2dd1f..5984a208efe0 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -103,11 +103,41 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) { class PostOrderRewriter : public MixedModeMutator { public: explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {} + Expr DispatchVisitExpr(const Expr& expr) final { auto post = ExprFunctor::VisitExpr(expr); return rewriter_->Rewrite(expr, post); } + using MixedModeMutator::VisitExpr_; + + Expr VisitExpr_(const LetNode* node) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->Mutate(op->var); + Expr value = this->Mutate(op->value); + }; + auto post_visit = [this, node](const LetNode* op) { + Var var = Downcast(this->Mutate(op->var)); + Expr value = this->Mutate(op->value); + Expr body = this->Mutate(op->body); + Expr expr = GetRef(op); + Expr post; + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + post = expr; + } else { + post = Let(var, value, body); + } + // avoid rewriting the first LetNode twice + if (op == node) { + this->memo_[expr] = post; + } else { + this->memo_[expr] = this->rewriter_->Rewrite(expr, post); + } + }; + ExpandANormalForm(node, pre_visit, post_visit); + return memo_[GetRef(node)]; + } + protected: ExprRewriter* rewriter_; }; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 38c33b45936e..0ea71de367fa 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -61,10 +61,10 @@ bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, if (axis < 0) { axis = data->shape.size() + axis; } - if (axis >= static_cast(data->shape.size())) { + if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "The axis in bias_add must be in range for the shape; " - << "attempted to access index " << axis << " of " + << "attempted to access index " << param->axis << " of " << PrettyPrint(data->shape)); return false; } diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index c00e2e02b369..8802cd903b01 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -26,8 +26,8 @@ #include #include -#include #include +#include #include diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 941f43a5a2c4..e3929bf8b77e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3179,6 +3179,9 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, const auto ndim_indices = indices->shape.size(); int axis = param->axis->value; ICHECK_EQ(ndim_data, ndim_indices); + if (axis < 0) { + axis += ndim_data; + } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_data); diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 2716c6e65f65..d77ede3acbf9 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -165,7 +165,7 @@ Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const Ob 
MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } else { data = LeftShift(data, - MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); + MakeConstantScalar(cfg->dtype_activation, static_cast(-shift_nbit))); } data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExpr(data, dom_scale, n->dtype); diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index 2e7c08a684dc..26624e438b8a 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -46,10 +46,16 @@ class FindDef : private ExprVisitor { VarMap expr_map_; void VisitExpr_(const LetNode* l) final { - ICHECK_EQ(expr_map_.count(l->var), 0); - expr_map_[l->var] = l->value; - VisitExpr(l->value); - VisitExpr(l->body); + auto pre_visit = [this](const LetNode* op) { + ICHECK_EQ(expr_map_.count(op->var), 0); + expr_map_[op->var] = op->value; + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(l, pre_visit, post_visit); } friend CalcDep; @@ -81,12 +87,24 @@ class Eliminator : private ExprMutator { } Expr VisitExpr_(const LetNode* op) final { - Var v = op->var; - if (HasLet(v)) { - return Let(v, VisitExpr(op->value), VisitExpr(op->body)); - } else { - return VisitExpr(op->body); - } + auto pre_visit = [this](const LetNode* op) { + if (HasLet(op->var)) { + Expr value = this->VisitExpr(op->value); + } + }; + auto post_visit = [this](const LetNode* op) { + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + Var v = op->var; + if (HasLet(v)) { + Expr value = this->VisitExpr(op->value); + this->memo_[expr] = Let(v, value, body); + } else { + this->memo_[expr] = body; + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; } }; @@ -121,7 +139,15 @@ class CalcDep : protected MixedModeVisitor { } } - void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); } + void VisitExpr_(const LetNode* l) final { + Expr let_binding = GetRef(l); + const LetNode* let; + while ((let = let_binding.as())) { + let_binding = let->body; + visit_counter_[l] += 1; + } + VisitExpr(let_binding); + } void VisitExpr_(const VarNode* v) final { Var var = GetRef(v); diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc new file mode 100644 index 000000000000..bab8b814df05 --- /dev/null +++ b/src/relay/transforms/fold_explicit_padding.cc @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/fold_explicit_padding.cc + * \brief A pass for folding explicit pads into other ops. 
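Note, not part of the patch: the dead-code and VM rewrites above share one motif. A-normal-form Relay programs can chain thousands of Let bindings, and the old recursive visitors overflowed the C++ stack on such chains; ExpandANormalForm walks the chain iteratively, running pre_visit on each binding outermost-first and post_visit innermost-first once the chain bottoms out. The generic mutator shape used at each call site, as a sketch assuming the usual ExprMutator members memo_ and VisitExpr:

Expr VisitExpr_(const LetNode* op) final {
  auto pre_visit = [this](const LetNode* let) {
    this->VisitExpr(let->value);  // visit bindings eagerly; never recurse into the body
  };
  auto post_visit = [this](const LetNode* let) {
    Expr value = this->VisitExpr(let->value);  // memoized by the pre-visit
    Expr body = this->VisitExpr(let->body);    // inner lets already rewritten
    this->memo_[GetRef<Expr>(let)] = Let(let->var, value, body);
  };
  ExpandANormalForm(op, pre_visit, post_visit);
  return memo_[GetRef<Expr>(op)];
}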
+ */ + +#include +#include +#include +#include +#include + +#include "../op/tensor/transform.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +/*! + * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc + * with a pad attribute and merges the padding into the kernel. + */ +class SimplifyConvPad { + public: + DFPattern pattern() const { return pattern_; } + + SimplifyConvPad() { + x_ = IsWildcard(); + w_ = IsWildcard(); + pad_ = IsOp("nn.pad")({x_}); + conv1d_ = IsOp("nn.conv1d"); + conv2d_ = IsOp("nn.conv2d"); + conv3d_ = IsOp("nn.conv3d"); + conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_}); + pattern_ = conv_; + } + + template + Attrs MakeConvAttrs(const T* old_attrs, const Array padding) const { + ICHECK(old_attrs); + ICHECK(padding.size() == old_attrs->padding.size()) + << "Number of dimensions to pad and convolution padding attributes should have the same " + "extent"; + + auto new_attrs = make_object(); + Array combined_padding; + for (size_t i = 0; i < padding.size(); ++i) { + combined_padding.push_back(padding[i] + old_attrs->padding[i]); + } + new_attrs->strides = old_attrs->strides; + new_attrs->padding = combined_padding; + new_attrs->dilation = old_attrs->dilation; + new_attrs->groups = old_attrs->groups; + new_attrs->channels = old_attrs->channels; + new_attrs->kernel_size = old_attrs->kernel_size; + new_attrs->data_layout = old_attrs->data_layout; + new_attrs->kernel_layout = old_attrs->kernel_layout; + new_attrs->out_layout = old_attrs->out_layout; + new_attrs->out_dtype = old_attrs->out_dtype; + return Attrs(new_attrs); + } + + template + Attrs GetAttrs(const PadAttrs* param, const T* attrs) const { + ICHECK(param); + ICHECK(attrs); + ICHECK(attrs->data_layout.size() == param->pad_width.size()) + << "Data Layout and padding attributes should have the same extent"; + + std::string data_layout = attrs->data_layout; + std::set image_dims({'H', 'W', 'D'}); + Array padding; + // If we're padding a non-spatial dimension, don't simplify + // Convolution can only pad on spatial axes + for (size_t i = 0; i < param->pad_width.size(); ++i) { + if (!image_dims.count(data_layout[i])) { + for (size_t j = 0; j < param->pad_width[i].size(); ++j) { + if (param->pad_width[i][j] != 0) { + return Attrs(); + } + } + } + } + for (size_t j = 0; j < param->pad_width[0].size(); ++j) { + for (size_t i = 0; i < param->pad_width.size(); ++i) { + if (image_dims.count(data_layout[i])) { + padding.push_back(param->pad_width[i][j]); + } + } + } + + return MakeConvAttrs(attrs, padding); + } + + Expr callback(const Expr& pre, const Expr& post, + const Map>& node_map) const { + const CallNode* call_node = post.as(); + ICHECK(call_node); + auto pad = node_map[pad_][0]; + const CallNode* pad_node = pad.as(); + ICHECK(pad_node); + const PadAttrs* param = pad_node->attrs.as(); + ICHECK(param); + if (param->pad_mode == "constant" && param->pad_value == 0.0) { + Attrs attrs; + if (node_map.count(conv1d_)) { + attrs = GetAttrs(param, call_node->attrs.as()); + } else if (node_map.count(conv2d_)) { + attrs = GetAttrs(param, call_node->attrs.as()); + } else if (node_map.count(conv3d_)) { + attrs = GetAttrs(param, call_node->attrs.as()); + } else { + return post; + } + if (!attrs.defined()) { + return post; + } + auto x = node_map[x_][0]; + auto w = node_map[w_][0]; + return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span); + } + return post; + } + + private: + /*! \brief Pattern for rewriting */ + DFPattern pattern_; + /*! 
\brief Pattern input */ + DFPattern x_; + /*! \brief Pattern input weight */ + DFPattern w_; + /*! \brief Pattern pad */ + DFPattern pad_; + /*! \brief Pattern conv */ + DFPattern conv_; + DFPattern conv1d_; + DFPattern conv2d_; + DFPattern conv3d_; +}; + +class SimplifyExplicitPadding { + public: + explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) { + CreateCallback(SimplifyConvPad()); + // TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x)) + } + template <typename T> + void CreateCallback(const T& pattern) { + auto func = [pattern](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map<DFPattern, Array<Expr>> node_map = args[2]; + *rv = pattern.callback(pre, post, node_map); + }; + callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true)); + } + + Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); } + + private: + IRModule mod_; + /*! \brief Callbacks for expr simplification */ + Array<DFPatternCallback> callbacks_; +}; + +/*! + * \brief FoldExplicitPadding finds explicit padding before an op that can + * support implicit padding and fuses them. + */ +Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) { + return SimplifyExplicitPadding(mod).Simplify(expr); +} + +namespace transform { + +Pass FoldExplicitPadding() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast<Function>(FoldExplicitPadding(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "FoldExplicitPadding", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FoldExplicitPadding").set_body_typed(FoldExplicitPadding); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index bfe04e10a9d0..74e48dc4bc54 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -82,121 +82,6 @@ class SimplifyReshape : public SimplifyPattern { DFPattern x_; }; -/*! - * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc - * with a pad attribute and merges the padding into the kernel.
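Note, not part of the patch: a worked example of the arithmetic in GetAttrs/MakeConvAttrs above. For NCHW data, an nn.pad with pad_width {{0,0},{0,0},{2,2},{1,1}} in front of a conv2d carrying padding {1,1,1,1} folds into a single conv2d with padding {3,2,3,2}; any nonzero pad on a non-spatial axis rejects the match instead. The same computation as a self-contained sketch:

#include <array>
#include <cstddef>
#include <string>
#include <vector>

std::vector<int> FoldPadding(const std::string& layout,  // e.g. "NCHW"
                             const std::vector<std::array<int, 2>>& pad_width,
                             const std::vector<int>& conv_padding) {
  std::vector<int> combined;  // spatial pads, j-major: {H0, W0, H1, W1} for 2D
  for (int j = 0; j < 2; ++j)
    for (size_t i = 0; i < pad_width.size(); ++i)
      if (layout[i] == 'H' || layout[i] == 'W' || layout[i] == 'D')
        combined.push_back(pad_width[i][j]);
  for (size_t k = 0; k < combined.size(); ++k) combined[k] += conv_padding[k];
  return combined;  // {3, 2, 3, 2} for the example above
}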
- */ -class SimplifyConvPad : public SimplifyPattern { - public: - SimplifyConvPad() { - x_ = IsWildcard(); - w_ = IsWildcard(); - pad_ = IsOp("nn.pad")({x_}); - conv1d_ = IsOp("nn.conv1d"); - conv2d_ = IsOp("nn.conv2d"); - conv3d_ = IsOp("nn.conv3d"); - conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_}); - pattern_ = conv_; - } - template - Attrs MakeConvAttrs(const T* old_attrs, const Array padding) const { - ICHECK(old_attrs); - ICHECK(padding.size() == old_attrs->padding.size()) - << "Number of dimensions to pad and convolution padding attributes should have the same " - "extent"; - - auto new_attrs = make_object(); - Array combined_padding; - for (size_t i = 0; i < padding.size(); ++i) { - combined_padding.push_back(padding[i] + old_attrs->padding[i]); - } - new_attrs->strides = old_attrs->strides; - new_attrs->padding = combined_padding; - new_attrs->dilation = old_attrs->dilation; - new_attrs->groups = old_attrs->groups; - new_attrs->channels = old_attrs->channels; - new_attrs->kernel_size = old_attrs->kernel_size; - new_attrs->data_layout = old_attrs->data_layout; - new_attrs->kernel_layout = old_attrs->kernel_layout; - new_attrs->out_layout = old_attrs->out_layout; - new_attrs->out_dtype = old_attrs->out_dtype; - return Attrs(new_attrs); - } - template - Attrs GetAttrs(const PadAttrs* param, const T* attrs) const { - ICHECK(param); - ICHECK(attrs); - ICHECK(attrs->data_layout.size() == param->pad_width.size()) - << "Data Layout and padding attributes should have the same extent"; - - std::string data_layout = attrs->data_layout; - std::set image_dims({'H', 'W', 'D'}); - Array padding; - // If we're padding a non-spatial dimension, don't simplify - // Convolution can only pad on spatial axes - for (size_t i = 0; i < param->pad_width.size(); ++i) { - if (!image_dims.count(data_layout[i])) { - for (size_t j = 0; j < param->pad_width[i].size(); ++j) { - if (param->pad_width[i][j] != 0) { - return Attrs(); - } - } - } - } - for (size_t j = 0; j < param->pad_width[0].size(); ++j) { - for (size_t i = 0; i < param->pad_width.size(); ++i) { - if (image_dims.count(data_layout[i])) { - padding.push_back(param->pad_width[i][j]); - } - } - } - - return MakeConvAttrs(attrs, padding); - } - Expr callback(const Expr& pre, const Expr& post, - const Map>& node_map) const override { - const CallNode* call_node = post.as(); - ICHECK(call_node); - auto pad = node_map[pad_][0]; - const CallNode* pad_node = pad.as(); - ICHECK(pad_node); - const PadAttrs* param = pad_node->attrs.as(); - ICHECK(param); - if (param->pad_mode == "constant" && param->pad_value == 0.0) { - Attrs attrs; - if (node_map.count(conv1d_)) { - attrs = GetAttrs(param, call_node->attrs.as()); - } else if (node_map.count(conv2d_)) { - attrs = GetAttrs(param, call_node->attrs.as()); - } else if (node_map.count(conv3d_)) { - attrs = GetAttrs(param, call_node->attrs.as()); - } else { - return post; - } - if (!attrs.defined()) { - return post; - } - auto x = node_map[x_][0]; - auto w = node_map[w_][0]; - return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span); - } - return post; - } - - private: - /*! \brief Pattern input */ - DFPattern x_; - /*! \brief Pattern input weight */ - DFPattern w_; - /*! \brief Pattern pad */ - DFPattern pad_; - /*! \brief Pattern conv */ - DFPattern conv_; - DFPattern conv1d_; - DFPattern conv2d_; - DFPattern conv3d_; -}; - /*! 
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op */ @@ -278,7 +163,6 @@ class ExprSimplifier { explicit ExprSimplifier(IRModule mod) : mod_(mod) { CreateCallback(SimplifyReshape()); CreateCallback(FullElementwise()); - CreateCallback(SimplifyConvPad()); } template void CreateCallback(const T& pattern) { diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 916a912b3c5e..3d9b1481f6e6 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -79,5 +79,100 @@ TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); +TVM_REGISTER_OBJECT_TYPE(ArrayNode); + +TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); + } + } + *ret = Array(data); +}); + +TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); +}); + +TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + *ret = static_cast(static_cast(ptr)->size()); +}); + +TVM_REGISTER_OBJECT_TYPE(MapNode); + +TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size() % 2, 0); + std::unordered_map data; + for (int i = 0; i < args.num_args; i += 2) { + ObjectRef k = + String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef(); + ObjectRef v = args[i + 1]; + data.emplace(std::move(k), std::move(v)); + } + *ret = Map(std::move(data)); +}); + +TVM_REGISTER_GLOBAL("runtime.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + *ret = static_cast(n->size()); +}); + +TVM_REGISTER_GLOBAL("runtime.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + + auto* n = static_cast(ptr); + auto it = n->find(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); + ICHECK(it != n->end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; +}); + +TVM_REGISTER_GLOBAL("runtime.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + const MapNode* n = static_cast(ptr); + int64_t cnt = n->count(String::CanConvertFrom(args[1]) ? 
args[1].operator String() + : args[1].operator ObjectRef()); + *ret = cnt; +}); + +TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + auto* n = static_cast(ptr); + Array rkvs; + for (const auto& kv : *n) { + if (kv.first->IsInstance()) { + rkvs.push_back(Downcast(kv.first)); + } else { + rkvs.push_back(kv.first); + } + rkvs.push_back(kv.second); + } + *ret = std::move(rkvs); +}); + +#if (USE_FALLBACK_STL_MAP == 0) +TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; +#endif + } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index ce69d4ca7bde..b12992f57159 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -167,7 +167,7 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; int32_t alpha = args.size() > 5 ? args[5] : 1; int32_t beta = args.size() > 6 ? args[6] : 0; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr; auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); @@ -204,7 +204,7 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { &order_COL32, sizeof(order_COL32))); CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, - C_data, Cdesc, C_data, Cdesc, NULL, NULL, 0, 0)); + C_data, Cdesc, C_data, Cdesc, nullptr, nullptr, 0, nullptr)); } #endif diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index d4ec08770723..4b4a1b755e66 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -35,7 +35,7 @@ CuBlasThreadEntry::CuBlasThreadEntry() { CHECK_CUBLAS_ERROR(cublasCreate(&handle CuBlasThreadEntry::~CuBlasThreadEntry() { if (handle) { cublasDestroy(handle); - handle = 0; + handle = nullptr; } } diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 3ae652ccaf24..55f16635b9e6 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -55,7 +55,7 @@ class JSONRuntimeBase : public ModuleNode { LoadGraph(graph_json_); } - const char* type_key() const { return "json"; } + const char* type_key() const override { return "json"; } /*! \brief Initialize a specific json runtime. */ virtual void Init(const Array& consts) = 0; @@ -69,7 +69,7 @@ class JSONRuntimeBase : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. 
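 * \note The returned function is selected by \p name (for example "get_symbol").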
*/ - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { if (name == "get_symbol") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); @@ -98,7 +98,7 @@ class JSONRuntimeBase : public ModuleNode { } } - virtual void SaveToBinary(dmlc::Stream* stream) { + void SaveToBinary(dmlc::Stream* stream) override { // Save the symbol stream->Write(symbol_name_); // Save the graph diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index ee47e67001f3..09b36d720877 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -99,6 +99,14 @@ void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_i ICHECK(it != node_output_map_.end()) << "Output was not found."; auto out_tensor = it->second[node.index_].tensor; std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size()); + // If the network is already marked as an input or output, make a copy to avoid TRT crash. + if (out_tensor->isNetworkOutput()) { + LOG(WARNING) << name << " is a duplicate output."; + out_tensor = network_->addIdentity(*out_tensor)->getOutput(0); + } else if (out_tensor->isNetworkInput()) { + LOG(WARNING) << name << " is both an input and an output."; + out_tensor = network_->addIdentity(*out_tensor)->getOutput(0); + } out_tensor->setName(name.c_str()); network_->markOutput(*out_tensor); network_output_names_.push_back(name); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 824178eaa619..04b1e838ee8e 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -309,8 +309,8 @@ class Conv3DOpConverter : public TensorRTOpConverter { bool use_asymmetric_padding; GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); - // Could use attrs->channels.as()->value - const int num_outputs = weight_shape[0]; + const int num_outputs = + std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; auto conv_layer = params->network->addConvolutionNd(*input_tensor, num_outputs, kernel_size, @@ -788,8 +788,8 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { } #endif - // Could use conv2d_attr->channels.as()->value - const int num_outputs = weight_shape[1]; + const int num_outputs = + std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, @@ -846,8 +846,8 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { bool use_asymmetric_padding; GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); - // Could use attrs->channels.as()->value - const int num_outputs = weight_shape[1]; + const int num_outputs = + std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; auto deconv_layer = 
params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index c77395422e87..f156d68d283e 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -243,5 +244,40 @@ TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetVal *rv = static_cast(ptr); }); +class GPUTimerNode : public TimerNode { + public: + virtual void Start() { + CUDA_CALL(cudaEventRecord(start_, CUDAThreadEntry::ThreadLocal()->stream)); + } + virtual void Stop() { CUDA_CALL(cudaEventRecord(stop_, CUDAThreadEntry::ThreadLocal()->stream)); } + virtual int64_t SyncAndGetElapsedNanos() { + CUDA_CALL(cudaEventSynchronize(stop_)); + float milliseconds = 0; + CUDA_CALL(cudaEventElapsedTime(&milliseconds, start_, stop_)); + return milliseconds * 1e6; + } + virtual ~GPUTimerNode() { + CUDA_CALL(cudaEventDestroy(start_)); + CUDA_CALL(cudaEventDestroy(stop_)); + } + GPUTimerNode() { + CUDA_CALL(cudaEventCreate(&start_)); + CUDA_CALL(cudaEventCreate(&stop_)); + } + + static constexpr const char* _type_key = "GPUTimerNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(GPUTimerNode, TimerNode); + + private: + cudaEvent_t start_; + cudaEvent_t stop_; +}; + +TVM_REGISTER_OBJECT_TYPE(GPUTimerNode); + +TVM_REGISTER_GLOBAL("profiling.timer.gpu").set_body_typed([](TVMContext ctx) { + return Timer(make_object()); +}); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 42cbfdc3b1ed..92c398b559d2 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -23,6 +23,8 @@ #include "file_utils.h" #include +#include +#include #include #include @@ -157,5 +159,71 @@ void LoadMetaDataFromFile(const std::string& file_name, void RemoveFile(const std::string& file_name) { std::remove(file_name.c_str()); } +Map LoadParams(const std::string& param_blob) { + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + return LoadParams(&strm); +} +Map LoadParams(dmlc::Stream* strm) { + Map params; + uint64_t header, reserved; + ICHECK(strm->Read(&header)) << "Invalid parameters file format"; + ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + + std::vector names; + ICHECK(strm->Read(&names)) << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz); + size_t size = static_cast(sz); + ICHECK(size == names.size()) << "Invalid parameters file format"; + for (size_t i = 0; i < size; ++i) { + // The data_entry is allocated on device, NDArray.load always load the array into CPU. 
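+      // The host copy is transient: consumers that need the data on device
+      // (for example GraphRuntime::LoadParams below) copy it into their
+      // device-resident entries afterwards.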
+ NDArray temp; + temp.Load(strm); + params.Set(names[i], temp); + } + return params; +} + +void SaveParams(dmlc::Stream* strm, const Map& params) { + std::vector names; + std::vector arrays; + for (auto& p : params) { + names.push_back(p.first); + arrays.push_back(p.second.operator->()); + } + + uint64_t header = kTVMNDArrayListMagic, reserved = 0; + strm->Write(header); + strm->Write(reserved); + strm->Write(names); + { + uint64_t sz = static_cast(arrays.size()); + strm->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(strm, arrays[i]); + } + } +} + +std::string SaveParams(const Map& params) { + std::string bytes; + dmlc::MemoryStringStream strm(&bytes); + dmlc::Stream* fo = &strm; + SaveParams(fo, params); + return bytes; +} + +TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map& params) { + std::string s = ::tvm::runtime::SaveParams(params); + // copy return array so it is owned by the ret value + TVMRetValue rv; + rv = TVMByteArray{s.data(), s.size()}; + return rv; +}); +TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) { + return ::tvm::runtime::LoadParams(s); +}); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 696a9760c2e1..718d10d5df70 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -24,6 +24,8 @@ #ifndef TVM_RUNTIME_FILE_UTILS_H_ #define TVM_RUNTIME_FILE_UTILS_H_ +#include + #include #include @@ -92,6 +94,32 @@ void LoadMetaDataFromFile(const std::string& file_name, * \param file_name The file name. */ void RemoveFile(const std::string& file_name); + +constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; +/*! + * \brief Load parameters from a string. + * \param param_blob Serialized string of parameters. + * \return Map of parameter name to parameter value. + */ +Map LoadParams(const std::string& param_blob); +/*! + * \brief Load parameters from a stream. + * \param strm Stream to load parameters from. + * \return Map of parameter name to parameter value. + */ +Map LoadParams(dmlc::Stream* strm); +/*! + * \brief Serialize parameters to a byte array. + * \param params Parameters to save. + * \return String containing binary parameter data. + */ +std::string SaveParams(const Map& params); +/*! + * \brief Serialize parameters to a stream. + * \param strm Stream to write to. + * \param params Parameters to save. + */ +void SaveParams(dmlc::Stream* strm, const Map& params); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_FILE_UTILS_H_ diff --git a/src/runtime/graph/cugraph/graph_runtime_cugraph.cc b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc new file mode 100644 index 000000000000..eccd69130c17 --- /dev/null +++ b/src/runtime/graph/cugraph/graph_runtime_cugraph.cc @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime_cugraph.cc + */ + +#include + +#include "../../cuda/cuda_common.h" +#include "../graph_runtime.h" + +namespace tvm { +namespace runtime { + +class GraphRuntimeCuGraph : public GraphRuntime { + public: + int StartCapture() { + const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx; + + TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_); + TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_); + + CUDA_CALL(cudaStreamBeginCapture(static_cast(capture_stream_), + cudaStreamCaptureModeGlobal)); + return 0; + } + + int RunCudaGraph() { + cudaStream_t cuStream = static_cast(capture_stream_); + CUDA_CALL(cudaGraphLaunch(cu_graph_exec_, cuStream)); + CUDA_CALL(cudaStreamSynchronize(cuStream)); + return 0; + } + + int EndCapture() { + cudaGraph_t graph; + CUDA_CALL(cudaStreamEndCapture(static_cast(capture_stream_), &graph)); + + cudaGraphNode_t* nodes = NULL; + size_t numNodes = 0; + CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes)); + LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes; + + CUDA_CALL(cudaGraphInstantiate(&cu_graph_exec_, graph, NULL, NULL, 0)); + return 0; + } + + /*! + * \brief GetFunction Get the function based on input. + * \param name The function which needs to be invoked. + * \param sptr_to_self Packed function pointer. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + private: + TVMStreamHandle capture_stream_; + cudaGraphExec_t cu_graph_exec_; +}; + +PackedFunc GraphRuntimeCuGraph::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "run_cuda_graph") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->RunCudaGraph(); }); + } else if (name == "start_capture") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->StartCapture(); }); + } else if (name == "end_capture") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->EndCapture(); }); + } else { + return GraphRuntime::GetFunction(name, sptr_to_self); + } +} + +Module GraphRuntimeCuGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m, + const std::vector& ctxs, + PackedFunc lookup_linked_param_func) { + auto exec = make_object(); + exec->Init(sym_json, m, ctxs, lookup_linked_param_func); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_cugraph.create").set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + PackedFunc lookup_linked_param_func; + int ctx_start_arg = 2; + if (args[2].type_code() == kTVMPackedFuncHandle) { + lookup_linked_param_func = args[2]; + ctx_start_arg++; + } + + *rv = GraphRuntimeCuGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg), + lookup_linked_param_func); +}); + +} // namespace runtime +} // namespace tvm + diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 93bdd065c9d9..0e3003aa42c3 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -77,16 +78,25 @@ class GraphRuntimeDebug : public GraphRuntime { number * 1.618)); // 
1.618 is chosen by random } tbegin = std::chrono::high_resolution_clock::now(); + std::vector> op_timers; + for (size_t index = 0; index < op_execs_.size(); index++) { + op_timers.push_back({}); + } for (int k = 0; k < number; k++) { for (size_t index = 0; index < op_execs_.size(); ++index) { if (op_execs_[index]) { - time_sec_per_op[index] += RunOpHost(index); + op_timers[index].push_back(RunOpHost(index)); } } } + for (size_t index = 0; index < op_execs_.size(); ++index) { + for (auto t : op_timers[index]) { + time_sec_per_op[index] += t->SyncAndGetElapsedNanos() / 1e9; + } + } tend = std::chrono::high_resolution_clock::now(); duration_ms = - std::chrono::duration_cast >(tend - tbegin).count() * + std::chrono::duration_cast>(tend - tbegin).count() * 1000; } while (duration_ms < min_repeat_ms); @@ -160,15 +170,12 @@ class GraphRuntimeDebug : public GraphRuntime { return results_arr[0]; } - double RunOpHost(int index) { - auto op_tbegin = std::chrono::high_resolution_clock::now(); - op_execs_[index](); + Timer RunOpHost(int index) { const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx; - TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); - auto op_tend = std::chrono::high_resolution_clock::now(); - double op_duration = - std::chrono::duration_cast >(op_tend - op_tbegin).count(); - return op_duration; + Timer t = Timer::Start(ctx); + op_execs_[index](); + t->Stop(); + return t; } /*! diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 6d586cfdd042..6c51e711aef1 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -38,6 +38,8 @@ #include #include +#include "../file_utils.h" + namespace tvm { namespace runtime { namespace details { @@ -196,31 +198,10 @@ void GraphRuntime::LoadParams(const std::string& param_blob) { } void GraphRuntime::LoadParams(dmlc::Stream* strm) { - uint64_t header, reserved; - ICHECK(strm->Read(&header)) << "Invalid parameters file format"; - ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; - ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; - - std::vector names; - ICHECK(strm->Read(&names)) << "Invalid parameters file format"; - uint64_t sz; - strm->Read(&sz); - size_t size = static_cast(sz); - ICHECK(size == names.size()) << "Invalid parameters file format"; - for (size_t i = 0; i < size; ++i) { - int in_idx = GetInputIndex(names[i]); - if (in_idx < 0) { - NDArray temp; - temp.Load(strm); - continue; - } - uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); - ICHECK_LT(eid, data_entry_.size()); - - // The data_entry is allocated on device, NDArray.load always load the array into CPU. - NDArray temp; - temp.Load(strm); - data_entry_[eid].CopyFrom(temp); + Map params = ::tvm::runtime::LoadParams(strm); + for (auto& p : params) { + uint32_t eid = this->entry_id(input_nodes_[GetInputIndex(p.first)], 0); + data_entry_[eid].CopyFrom(p.second); } } diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 627911883dfb..a1e2ee3b5d74 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -47,9 +47,6 @@ namespace runtime { ICHECK_EQ(ret, 0) << TVMGetLastError(); \ } -/*! \brief Magic number for NDArray list file */ -constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; - /*! 
\brief operator attributes about tvm op */ struct TVMOpParam { std::string func_name; diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 2c055e16cc9f..4d3993a9a36f 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -24,7 +24,7 @@ #include "./graph_runtime_factory.h" -#include +#include #include #include diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc index acef9d4736fd..665c72cc5e0d 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/metadata_module.cc @@ -27,7 +27,7 @@ * code and metadata significantly reduces the efforts for handling external * codegen and runtimes. */ -#include +#include #include #include #include diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 981dd6129f9e..8f1fde86f074 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -180,7 +180,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->context.device_id; if (scache_[device_id] == nil) { @@ -197,7 +197,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const } if (num_pack_args_ != 0) { [encoder setBytes:pack_args - length:num_pack_args_ * sizeof(ArgUnion) + length:num_pack_args_ * sizeof(ArgUnion64) atIndex:num_buffer_args_]; } // launch diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index f26a717dae33..6c0d0c4c40fe 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -172,7 +172,7 @@ class MicroTransportChannel : public RPCChannel { // confusion. unsigned int seed = random_seed.load(); if (seed == 0) { - seed = (unsigned int)time(NULL); + seed = (unsigned int)time(nullptr); } uint8_t initial_nonce = 0; for (int i = 0; i < kNumRandRetries && initial_nonce == 0; ++i) { diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 4cec5e3643c1..d84a8215421f 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -178,7 +178,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") - .set_body_typed([](Module mod, std::string name, std::string fmt) { + .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) { mod->SaveToFile(name, fmt); }); diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 45cde22bda08..7c852da77df6 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -40,13 +40,24 @@ namespace tvm { namespace runtime { /*! * \brief argument union type of 32bit. - * Choose 32 bit because most GPU API do not work well with 64 bit. */ -union ArgUnion { +union ArgUnion32 { int32_t v_int32; uint32_t v_uint32; float v_float32; }; + +/*! + * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime. + */ +union ArgUnion64 { + int32_t v_int32[2]; + uint32_t v_uint32[2]; + float v_float32[2]; + int64_t v_int64; + uint64_t v_uint64; + double v_float64; +}; /*! 
* \brief Create a packed function from void addr types. * @@ -140,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { TempArray addr_(num_args); - TempArray holder_(num_args); + TempArray holder_(num_args); void** addr = addr_.data(); - ArgUnion* holder = holder_.data(); + ArgUnion32* holder = holder_.data(); for (int i = 0; i < num_args; ++i) { switch (codes[i]) { case INT64_TO_INT64: @@ -177,25 +188,28 @@ template inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { - TempArray holder_(num_args); - ArgUnion* holder = holder_.data(); + TempArray holder_(num_args); + ArgUnion64* holder = holder_.data(); for (int i = 0; i < num_args; ++i) { switch (codes[i]) { - case INT64_TO_INT64: + case INT64_TO_INT64: { + holder[i].v_int64 = args.values[base + i].v_int64; + break; + } case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; + holder[i].v_float64 = args.values[base + i].v_float64; break; } case INT64_TO_INT32: { - holder[i].v_int32 = static_cast(args.values[base + i].v_int64); + holder[i].v_int32[0] = static_cast(args.values[base + i].v_int64); break; } case INT64_TO_UINT32: { - holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); + holder[i].v_uint32[0] = static_cast(args.values[base + i].v_int64); break; } case FLOAT64_TO_FLOAT32: { - holder[i].v_float32 = static_cast(args.values[base + i].v_float64); + holder[i].v_float32[0] = static_cast(args.values[base + i].v_float64); break; } case HANDLE_TO_HANDLE: { diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc new file mode 100644 index 000000000000..3d204166986d --- /dev/null +++ b/src/runtime/profiling.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/profiling.cc + * \brief Runtime profiling including timers. 
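+ *
+ * Device backends register their timers under "profiling.timer.<device>";
+ * Timer::Start falls back to a synchronizing wall-clock timer when no such
+ * registration exists.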
+ */ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { + +class DefaultTimerNode : public TimerNode { + public: + virtual void Start() { + TVMSynchronize(ctx_.device_type, ctx_.device_id, nullptr); + start_ = std::chrono::high_resolution_clock::now(); + } + virtual void Stop() { + TVMSynchronize(ctx_.device_type, ctx_.device_id, nullptr); + duration_ = std::chrono::high_resolution_clock::now() - start_; + } + virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } + virtual ~DefaultTimerNode() {} + + explicit DefaultTimerNode(TVMContext ctx) : ctx_(ctx) {} + static constexpr const char* _type_key = "DefaultTimerNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(DefaultTimerNode, TimerNode); + + private: + std::chrono::high_resolution_clock::time_point start_; + std::chrono::duration duration_; + TVMContext ctx_; +}; + +TVM_REGISTER_OBJECT_TYPE(DefaultTimerNode); +TVM_REGISTER_OBJECT_TYPE(TimerNode); + +Timer DefaultTimer(TVMContext ctx) { return Timer(make_object(ctx)); } + +class CPUTimerNode : public TimerNode { + public: + virtual void Start() { start_ = std::chrono::high_resolution_clock::now(); } + virtual void Stop() { duration_ = std::chrono::high_resolution_clock::now() - start_; } + virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } + virtual ~CPUTimerNode() {} + + static constexpr const char* _type_key = "CPUTimerNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode); + + private: + std::chrono::high_resolution_clock::time_point start_; + std::chrono::duration duration_; +}; +TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); + +TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](TVMContext ctx) { + return Timer(make_object()); +}); + +Timer Timer::Start(TVMContext ctx) { + auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(ctx.device_type)); + if (f == nullptr) { + Timer t = DefaultTimer(ctx); + t->Start(); + return t; + } else { + Timer t = f->operator()(ctx); + t->Start(); + return t; + } +} + +TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 26e44eca0d12..5f24ce0eec48 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -200,5 +201,41 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); + +class ROCMTimerNode : public TimerNode { + public: + virtual void Start() { + ROCM_CALL(hipEventRecord(start_, ROCMThreadEntry::ThreadLocal()->stream)); + } + virtual void Stop() { ROCM_CALL(hipEventRecord(stop_, ROCMThreadEntry::ThreadLocal()->stream)); } + virtual int64_t SyncAndGetElapsedNanos() { + ROCM_CALL(hipEventSynchronize(stop_)); + float milliseconds = 0; + ROCM_CALL(hipEventElapsedTime(&milliseconds, start_, stop_)); + return milliseconds * 1e6; + } + virtual ~ROCMTimerNode() { + ROCM_CALL(hipEventDestroy(start_)); + ROCM_CALL(hipEventDestroy(stop_)); + } + ROCMTimerNode() { + ROCM_CALL(hipEventCreate(&start_)); + ROCM_CALL(hipEventCreate(&stop_)); + } + + static constexpr const char* _type_key = "ROCMTimerNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(ROCMTimerNode, TimerNode); + + private: + hipEvent_t start_; + hipEvent_t stop_; +}; + +TVM_REGISTER_OBJECT_TYPE(ROCMTimerNode); + 
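For orientation, a minimal usage sketch of the Timer interface these per-device nodes implement (illustrative only; it assumes the tvm/runtime/profiling.h header introduced by this change and a valid device context, and TimeOnce is a hypothetical helper, not part of the patch):

    #include <tvm/runtime/c_runtime_api.h>
    #include <tvm/runtime/profiling.h>

    #include <cstdint>

    // Times a single invocation of `workload` on the device described by `ctx`.
    int64_t TimeOnce(TVMContext ctx, void (*workload)()) {
      tvm::runtime::Timer t = tvm::runtime::Timer::Start(ctx);  // resolves "profiling.timer.<device>"
      workload();                                               // the code under measurement
      t->Stop();
      return t->SyncAndGetElapsedNanos();  // synchronizes, then reads the elapsed time
    }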
+TVM_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](TVMContext ctx) { + return Timer(make_object()); +}); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 94d827893b92..fc01a754ca50 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -45,7 +45,15 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 1U); std::vector> op_acc_time; - for (auto kv : op_durations_) { + std::unordered_map> op_durations; + for (auto kv : op_timers_) { + std::vector durations_us; + for (auto t : kv.second) { + durations_us.push_back(t->SyncAndGetElapsedNanos() / 1e3); + } + op_durations[kv.first] = durations_us; + } + for (auto kv : op_durations) { auto val = std::make_pair(kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); op_acc_time.push_back(val); @@ -66,7 +74,7 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, << "#Duration(us): Sum/Mean/Min/Max" << std::endl; for (auto kv : op_acc_time) { - auto vals = op_durations_[kv.first]; + auto vals = op_durations[kv.first]; auto sum = kv.second; auto mean = sum / static_cast(vals.size()); auto min_value = *std::min_element(vals.begin(), vals.end()); @@ -85,7 +93,7 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, }); } else if (name == "reset") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - op_durations_.clear(); + op_timers_.clear(); op_invokes_.clear(); }); } else { @@ -118,16 +126,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun auto nd_array = Downcast(arg); auto ctx = nd_array->ctx; - TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); - - auto op_begin = std::chrono::high_resolution_clock::now(); + Timer t = Timer::Start(ctx); VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); - TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); - auto op_end = std::chrono::high_resolution_clock::now(); - double op_duration = - std::chrono::duration_cast>(op_end - op_begin).count(); + t->Stop(); - op_durations_[packed_index].push_back(op_duration * 1e6); + op_timers_[packed_index].push_back(t); op_invokes_[packed_index] += 1; } diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 797d414fe8f3..9f5ce87bcf47 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_VM_PROFILER_VM_H_ #define TVM_RUNTIME_VM_PROFILER_VM_H_ +#include #include #include @@ -51,7 +52,7 @@ class VirtualMachineDebug : public VirtualMachine { const std::vector& args) final; std::unordered_map packed_index_map_; - std::unordered_map> op_durations_; + std::unordered_map> op_timers_; std::unordered_map op_invokes_; }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 3f890baf52c0..6d121aa67733 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -35,6 +35,8 @@ #include #include +#include "../file_utils.h" + using namespace tvm::runtime; namespace tvm { diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index f40fd80f38b5..794f3c570f96 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -711,7 +711,7 @@ class VulkanWrappedFunc { thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); } - void operator()(TVMArgs args, TVMRetValue* rv, const 
ArgUnion* pack_args) const; + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; private: // internal module @@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { VkPushConstantRange crange; crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; crange.offset = 0; - crange.size = sizeof(ArgUnion) * num_pack_args; + crange.size = sizeof(ArgUnion64) * num_pack_args; VkPipelineLayoutCreateInfo playout_cinfo; playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; @@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, + const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); @@ -1075,7 +1076,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion descriptor_buffers.data()); if (num_pack_args_ != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, - VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion), + VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); @@ -1093,7 +1094,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion } // Otherwise, the more expensive deferred path. - std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); + std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); @@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion nullptr); if (pack_args_storage.size() != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data()); + 0, pack_args_storage.size() * sizeof(ArgUnion64), + pack_args_storage.data()); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); VkMemoryBarrier barrier_info; diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 0f394f50fe71..d6c8f1799596 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include #include diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 2e9babacc441..e54acd2221d1 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -61,6 +61,18 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_half_util; } + if (enable_bf16_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n"; + decl_stream << "#include \n"; + decl_stream << "__device__ nv_bfloat16 max" + << "(nv_bfloat16 a, nv_bfloat16 b)\n" + << "{\n return __hgt(a, b) ? a : b;\n}\n"; + decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n" + << "{\n return __hlt(a, b) ? 
a : b;\n}\n"; + decl_stream << "#endif\n\n"; + decl_stream << _cuda_bfloat16_util; + } + if (enable_warp_shuffle_) { decl_stream << _cuda_warp_intrinsic_util; } @@ -170,6 +182,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } + } else if (t.is_bfloat16()) { + enable_bf16_ = true; + if (t.is_scalar()) { + os << "nv_bfloat16"; + } else if (lanes <= 8) { + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else { + fail = true; + } + if (!fail) return; } else if (t == DataType::Bool()) { os << "bool"; return; @@ -382,6 +405,8 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } else if (t.is_float16()) { os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } else if (t.is_bfloat16()) { + os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -427,6 +452,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, } else if (t.is_float16()) { stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; + } else if (t.is_bfloat16()) { + stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] + << " = " << value << ";\n"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -687,7 +715,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || + op->dtype == DataType::BFloat(16)) << "Matrix_a and matrix_b only support half or char or unsigned char " << "or uint4 or int4 or int1 type for now"; } else { @@ -767,6 +796,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } + if (op->dtype.is_bfloat16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < op->lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + std::string v = PrintExpr(op->value); os << "make_"; PrintType(op->dtype, os); @@ -836,6 +878,13 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { } inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) + // Type code is kBFloat + if (op->dtype.is_bfloat16()) { + os << "__float2bfloat16_rn"; + os << '(' << std::scientific << op->value << 'f' << ')'; + return; + } + // Type code is kFloat switch (op->dtype.bits()) { case 64: case 32: { @@ -938,7 +987,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
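// The same handling applies to bfloat16, which CUDA likewise implements as a
// class type rather than a builtin scalar.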
// - if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) { + if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) { os << "("; PrintType(op->dtype, os); os << ")(" << value << ")"; @@ -979,6 +1028,25 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val return; } + if (t.is_bfloat16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_bfloat162(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + if (i == 0) { os << "make_"; PrintType(t, os); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 3cde8e379eb4..2098b8ac8344 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -42,7 +42,7 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); + return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior void PrintFuncPrefix() final; @@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC { std::string vid_global_barrier_expect_; // whether enable fp16 bool enable_fp16_{false}; + // whether enable bf16 + bool enable_bf16_{false}; // whether enable int8 bool enable_int8_{false}; // whether enable warp shuffle intrinsics diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index baa30065a7f9..c95d578df686 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() { decl_stream << "#include \n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" - << " int v_int;\n" + << " int v_int[2];\n" << "};\n\n"; } @@ -102,6 +102,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { std::string vid = AllocVarID(v.get()); std::ostringstream vref; if (v.dtype().bits() == 32) { + decl_stream << " "; + PrintType(v.dtype(), decl_stream); + decl_stream << " " << vid << "[2];\n"; + vref << varg << "." << vid << "[0]"; + } else if (v.dtype().bits() == 64) { decl_stream << " "; PrintType(v.dtype(), decl_stream); decl_stream << " " << vid << ";\n"; diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 5c562f7b1643..965b86c24d9e 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -43,6 +43,8 @@ struct CUDAMath { default: return ""; } + } else if (t.is_bfloat16()) { + return 'h' + name; } return ""; } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index f8e92d508d88..3888f3a4fb07 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -311,6 +311,30 @@ static inline __device__ __host__ half htanh(half x) { #endif )"; +static constexpr const char* _cuda_bfloat16_util = R"( +// Pack two bfloat16 values. 
+static inline __device__ __host__ unsigned +__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +// fix undefined fp16 match function +static inline __device__ __host__ nv_bfloat16 hpow(nv_bfloat16 x, nv_bfloat16 y) { + float tmp_x = __bfloat162float(x); + float tmp_y = __bfloat162float(y); + float result = powf(tmp_x, tmp_y); + return __float2bfloat16(result); +} + +static inline __device__ __host__ nv_bfloat16 htanh(nv_bfloat16 x) { + float tmp_x = __bfloat162float(x); + float result = tanhf(tmp_x); + return __float2bfloat16(result); +} +)"; + static constexpr const char* _cuda_warp_intrinsic_util = R"( #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700) #define __shfl_sync(mask, var, lane, width) \ diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 51d136d5510e..24608ebc93f4 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -45,10 +45,15 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: if (auto* ptr = arg->type_annotation.as()) { auto* prim = ptr->element_type.as(); ICHECK(prim); - DataType value_type = prim->dtype; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::UInt(1)) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = DataType::UInt(8); + } spirv::Value arg_value = - builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer); - storage_info_[arg.get()].UpdateContentType(value_type); + builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer); + storage_info_[arg.get()].UpdateContentType(value_storage_type); var_map_[arg.get()] = arg_value; } else { LOG(FATAL) << "require all handles to be typed"; @@ -369,11 +374,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { mask |= spv::MemoryAccessVolatileMask; } if (op->dtype.lanes() == 1) { - ICHECK_EQ(info.content_type, op->dtype) - << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); + spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); + if (op->dtype == DataType::UInt(1)) { + // A bool tensor is backed by a byte buffer, we cast to bool here. + auto bool_ty = builder_->GetSType(DataType::UInt(1)); + return builder_->Cast(bool_ty, loaded); + } else { + ICHECK_EQ(info.content_type, op->dtype) + << "Vulkan only allow one type access to the same buffer"; + return loaded; + } } else { if (op->dtype.element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. 
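As a side note on the boolean-buffer convention established in the hunk above: a host-side mirror of the same idea, with hypothetical names used purely for illustration, stores one byte per element and casts each loaded byte back to a boolean:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // A bool tensor is backed by one uint8 per element, since the physical
    // buffer cannot address individual bits.
    std::vector<uint8_t> PackBoolBuffer(const std::vector<bool>& values) {
      std::vector<uint8_t> buffer(values.size());
      for (size_t i = 0; i < values.size(); ++i) buffer[i] = values[i] ? 1 : 0;
      return buffer;
    }

    // Each load casts the byte back to bool, mirroring the OpLoad-then-cast
    // sequence emitted for DataType::UInt(1).
    bool LoadBoolElement(const std::vector<uint8_t>& buffer, size_t index) {
      return buffer[index] != 0;
    }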
@@ -514,6 +526,34 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->StartLabel(merge_label); } +void CodeGenSPIRV::VisitStmt_(const WhileNode* op) { + spirv::Label head_label = builder_->NewLabel(); + spirv::Label body_label = builder_->NewLabel(); + spirv::Label continue_label = builder_->NewLabel(); + spirv::Label merge_label = builder_->NewLabel(); + builder_->MakeInst(spv::OpBranch, head_label); + + // Loop head + builder_->StartLabel(head_label); + spirv::Value loop_cond = MakeValue(op->condition); + uint32_t control = spv::LoopControlMaskNone; + builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, + weight_likely_branch_, 1); + + // loop body + builder_->StartLabel(body_label); + this->VisitStmt(op->body); + builder_->MakeInst(spv::OpBranch, continue_label); + + // loop continue + builder_->StartLabel(continue_label); + builder_->MakeInst(spv::OpBranch, head_label); + + // loop merge + builder_->StartLabel(merge_label); +} + void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Value cond = MakeValue(op->condition); spirv::Label then_label = builder_->NewLabel(); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index be755641c8a5..1e80fcc4a931 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -93,6 +93,7 @@ class CodeGenSPIRV : public ExprFunctor, // stmt void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 90b2eb2a671f..b75fb53b150d 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntr TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index 32ffccbbec1f..b3897e142545 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -243,6 +243,14 @@ Stmt Substitute(Stmt s, const std::unordered_map& value_map) return tir::Substitute(s, init); } +PrimExpr Substitute(PrimExpr s, const std::unordered_map& value_map) { + std::unordered_map init; + for (const auto& kv : value_map) { + init[kv.first->var.get()] = kv.second; + } + return tir::Substitute(s, init); +} + IterVarType ForKindToIterVarType(tir::ForKind kind) { switch (kind) { case ForKind::kSerial: diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h index e6bf2caae6e0..02f4a860a01d 100644 --- a/src/te/operation/op_utils.h +++ b/src/te/operation/op_utils.h @@ -73,7 +73,7 @@ std::vector MakeIfNest(const std::vector& 
predicates); */ Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace); /*! - * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. + * \brief Replace the tensor reference (especially in Call's) in primExpr by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ @@ -87,6 +87,14 @@ PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& */ Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); +/*! + * \brief Substitute the variables of primExpr by value map. + * \param expr the expression to be processed. + * \param value_map The value map. + * \return Substituted result. + */ +PrimExpr Substitute(PrimExpr expr, const std::unordered_map& value_map); + /*! * \brief Converts Halide ForKind to its corresponding IterVarType * \param kind The ForKind to be converted diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index bfd1ec579818..ea713220eddd 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -311,6 +311,7 @@ Array MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage } void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& value_map, const std::unordered_map& dom_map, const std::unordered_map& out_dom, const std::unordered_map >& in_region, @@ -327,7 +328,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < body.size(); ++i) { PrimExpr lhs = ana.Simplify(body[i]); - PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); + // run substitution because the intrin body could depend on outer loop vars. + PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name << "'s declaration " @@ -349,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, ICHECK(intrin.defined()); ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); - VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); + VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, intrin); // Start bind data. 
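// Bind the intrinsic's buffers to the tensors this stage reads and writes.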
Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; diff --git a/src/te/schedule/auto_inline_elem_wise.cc b/src/te/schedule/auto_inline_elem_wise.cc index e2b7215158b2..bf584df25825 100644 --- a/src/te/schedule/auto_inline_elem_wise.cc +++ b/src/te/schedule/auto_inline_elem_wise.cc @@ -39,15 +39,15 @@ class ElemWiseDetector : public tir::ExprVisitor { ExprVisitor::VisitExpr(e); } - void VisitExpr_(const CallNode* op) final { - Array axis = op->args; - if (axis_.size() != axis.size()) { + void VisitExpr_(const ProducerLoadNode* op) final { + Array indices = op->indices; + if (axis_.size() != indices.size()) { is_elem_wise_ = false; return; } for (size_t i = 0; i < axis_.size(); ++i) { - if (!axis[i].same_as(axis_[i]->var)) { + if (!indices[i].same_as(axis_[i]->var)) { is_elem_wise_ = false; return; } @@ -83,7 +83,11 @@ bool IsBroadcast(const Operation& op) { if (compute->reduce_axis.size()) { return false; } - // TODO(nicolasvasilache): Implement Me + constexpr auto kBroadcast = "broadcast"; + // broadcast op in topi has tag `broadcast` + if (op->tag == kBroadcast) { + return true; + } } return false; } @@ -113,6 +117,8 @@ void AutoInlineInjective(Schedule sch) { TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise); +TVM_REGISTER_GLOBAL("schedule.AutoInlineBroadcast").set_body_typed(AutoInlineBroadcast); + TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective); } // namespace te diff --git a/tests/cpp/profiling.cc b/tests/cpp/profiling.cc new file mode 100644 index 000000000000..6ec2fc060f9f --- /dev/null +++ b/tests/cpp/profiling.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include <gtest/gtest.h> +#include <tvm/runtime/profiling.h> + +#include <chrono> +#include <thread> + +namespace tvm { +namespace runtime { +TEST(DefaultTimer, Basic) { + using namespace tvm::runtime; + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + + Timer t = Timer::Start(ctx); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + t->Stop(); + int64_t elapsed = t->SyncAndGetElapsedNanos(); + CHECK_GT(elapsed, 9 * 1e6); +} +} // namespace runtime +} // namespace tvm + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index e0eccdfb30f5..06ce93b2aba5 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -127,6 +127,10 @@ def test_mobilenet_v1(): _compile_hash = {"47e216d8ab2bf491708ccf5620bc0d02"} if tei.get_ethosn_variant() == 3: _compile_hash = {"2436f523e263f66a063cef902f2f43d7"} + if tei.get_ethosn_api_version() == 2011: + _compile_hash = {"9298b6c51e2a82f70e91dd11dd6af412"} + if tei.get_ethosn_variant() == 3: + _compile_hash = {"407eb47346c8afea2d15e8f0d1c079f2"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", @@ -151,6 +155,10 @@ def test_inception_v3(): _compile_hash = {"8c9d75659cd7bc9ff6dd6d490d28f9b2"} if tei.get_ethosn_variant() == 3: _compile_hash = {"cdd4d7f6453d722ea73224ff9d6a115a"} + if tei.get_ethosn_api_version() == 2011: + _compile_hash = {"d44eece5027ff56e5e7fcf014367378d"} + if tei.get_ethosn_variant() == 3: + _compile_hash = {"1ba555b4bc60c428018a0f2de9d90532"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite_11_05_08/inception_v3_quant.tgz", @@ -169,11 +177,17 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - if not tei.get_ethosn_variant() == 0: - pytest.skip("Ethos-N78 20.08 does not support inception_v4 in the default configuration.") _compile_hash = {"06bf6cb56344f3904bcb108e54edfe87"} if tei.get_ethosn_api_version() == 2008: + if not tei.get_ethosn_variant() == 0: + pytest.skip( + "Ethos-N78 20.08 does not support inception_v4 in the default configuration."
+ ) _compile_hash = {"798292bfa596ca7c32086396b494b46c"} + if tei.get_ethosn_api_version() == 2011: + _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} + if tei.get_ethosn_variant() == 3: + _compile_hash = {"851665c060cf4719248919d17325ae02"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/inception_v4_299_quant_20181026.tgz", @@ -197,6 +211,10 @@ def test_ssd_mobilenet_v1(): _compile_hash = {"5999f26e140dee0d7866491997ef78c5", "24e3a690a7e95780052792d5626c85be"} if tei.get_ethosn_variant() == 3: _compile_hash = {"da871b3f03a93df69d704ed44584d6cd", "9f52411d301f3cba3f6e4c0f1c558e87"} + if tei.get_ethosn_api_version() == 2011: + _compile_hash = {"6e8c4586bdd26527c642a4f016f52284", "057c5efb094c79fbe4483b561147f1d2"} + if tei.get_ethosn_variant() == 3: + _compile_hash = {"dc687e60a4b6750fe740853f22aeb2dc", "1949d86100004eca41099c8e6fa919ab"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", diff --git a/tests/python/contrib/test_ethosn/test_reshape.py b/tests/python/contrib/test_ethosn/test_reshape.py index 4afec557e569..20df5f9bd288 100644 --- a/tests/python/contrib/test_ethosn/test_reshape.py +++ b/tests/python/contrib/test_ethosn/test_reshape.py @@ -37,8 +37,8 @@ def test_reshape(): return trials = [ - ((1, 15, 4, 1), (60,)), - ((1, 15, 4, 1), (30, 2)), + ((1, 15, 4, 1), (1, 60)), + ((1, 15, 4, 1), (1, 30, 2)), ((1, 15, 4, 1), (1, 4, 15, 1)), ((1, 15, 4, 1), (1, 12, 5, 1)), ((1, 15, 4, 1), (1, -1, 2, 1)), diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 7ddc4e762cfd..ae8214d6463c 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -22,7 +22,7 @@ import tvm import tvm.relay.testing -from tvm import relay +from tvm import relay, runtime from tvm.relay.op.contrib import tensorrt from tvm.contrib import graph_runtime, utils from tvm.runtime.vm import VirtualMachine @@ -71,6 +71,14 @@ def assert_result_dict_holds(result_dict): tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3) +def set_func_attr(func, compile_name, symbol_name): + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", compile_name) + func = func.with_attr("global_symbol", symbol_name) + return func + + def run_and_verify_func(config, target="cuda"): """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. 
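Note that set_func_attr, previously defined inline in test_dynamic_offload, is hoisted to module scope above so the new test_empty_subgraph below can reuse it; a representative call from that test is func = set_func_attr(relay.Function([var1], var1), "tensorrt", "tensorrt_0"), which marks the function as a "tensorrt"-compiled primitive under the given global symbol.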
@@ -257,7 +265,7 @@ def test_tensorrt_serialize_graph_runtime(): def compile_graph(mod, params): with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): graph, lib, params = relay.build(mod, params=params, target="cuda") - params = relay.save_param_dict(params) + params = runtime.save_param_dict(params) return graph, lib, params def run_graph(graph, lib, params): @@ -1109,13 +1117,6 @@ def test_dynamic_offload(): kernel = relay.var("kernel", shape=(k_shape), dtype="float32") def get_expected(): - def set_func_attr(func, compile_name, symbol_name): - func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", compile_name) - func = func.with_attr("global_symbol", symbol_name) - return func - # Create a nested TRT function that matches the expected output mod = tvm.IRModule() var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32") @@ -1331,5 +1332,32 @@ def get_maskrcnn_input(in_size: int) -> np.ndarray: ) +def test_empty_subgraph(): + if skip_codegen_test(): + return + x_shape = (1, 3, 5) + mod = tvm.IRModule() + # Empty tensorrt subgraph. + var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32") + f1 = GlobalVar("tensorrt_0") + func = relay.Function([var1], var1) + func = set_func_attr(func, "tensorrt", "tensorrt_0") + mod[f1] = func + mod = relay.transform.InferType()(mod) + + # Create the main function + x = relay.var("x", shape=x_shape, dtype="float32") + out = f1(relay.nn.relu(x)) + f = relay.Function([x], out) + mod["main"] = f + + x_data = np.random.uniform(-1, 1, x_shape).astype("float32") + for mode in ["graph", "vm"]: + with tvm.transform.PassContext(opt_level=3): + exec = relay.create_executor(mode, mod=mod, ctx=tvm.gpu(0), target="cuda") + if not skip_runtime_test(): + results = exec.evaluate()(x_data) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 3e652cfc69e3..4eb7f6139e8f 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1064,14 +1064,23 @@ def verify(shape, axis, is_ascend, dtype="float32"): @tvm.testing.uses_gpu def test_forward_topk(): - def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): + def verify(shape, k, axis, ret_type, is_ascend=None, dtype="float32"): x_np = np.random.uniform(size=shape).astype("float32") - ref_res = mx.nd.topk( - mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype - ) - mx_sym = mx.sym.topk( - mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype - ) + if is_ascend is None: + ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, dtype=dtype) + mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, dtype=dtype) + else: + ref_res = mx.nd.topk( + mx.nd.array(x_np), + k=k, + axis=axis, + ret_typ=ret_type, + is_ascend=is_ascend, + dtype=dtype, + ) + mx_sym = mx.sym.topk( + mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype + ) mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -1086,7 +1095,7 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): verify((3, 4), k=1, axis=0, ret_type="both") verify((3, 4), k=1, axis=-1, ret_type="indices") - verify((3, 5, 
6), k=2, axis=2, ret_type="value") + verify((3, 5, 6), k=2, axis=2, ret_type="value", is_ascend=False) verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True) verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32") diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 07e52b7079e8..29c69abba542 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -41,7 +41,6 @@ def torch_version_check(): def get_tvm_runtime(script_module, input_name, ishape): - input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) @@ -125,43 +124,40 @@ def fuse_model(self): # Mobilenet V3 related modules class Hsigmoid(nn.Module): - def __init__(self, inplace=True, add_stub=False): + def __init__(self, add_stub=False): super().__init__() - self.float_op = nn.quantized.FloatFunctional() - self.relu6 = nn.ReLU6(inplace=inplace) self.quant = QuantStub() self.dequant = DeQuantStub() self.add_stub = add_stub + self.hsigmoid = nn.Hardsigmoid() def forward(self, x): if self.add_stub: x = self.quant(x) - relu6 = self.relu6(self.float_op.add_scalar(x, 3.0)) - mul = self.float_op.mul_scalar(relu6, 1 / 6.0) + x = self.hsigmoid(x) if self.add_stub: - mul = self.dequant(mul) - return mul + x = self.dequant(x) + return x def fuse_model(self): pass class Hswish(nn.Module): - def __init__(self, inplace=True, add_stub=False): - super(Hswish, self).__init__() - self.float_op = nn.quantized.FloatFunctional() - self.hsigmoid = Hsigmoid(inplace, add_stub=False) + def __init__(self, add_stub=False): + super().__init__() self.quant = QuantStub() self.dequant = DeQuantStub() self.add_stub = add_stub + self.hswish = nn.Hardswish() def forward(self, x): if self.add_stub: x = self.quant(x) - mul = self.float_op.mul(x, self.hsigmoid(x)) + x = self.hswish(x) if self.add_stub: - mul = self.dequant(mul) - return mul + x = self.dequant(x) + return x def fuse_model(self): pass @@ -274,18 +270,12 @@ def test_quantized_modules(): ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel), ("linear" + postfix, (16, 16), Linear(), per_channel), ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel), - ] - - if torch_version_check(): - qmodules += [ ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False), ("hswish", imagenet_ishape, Hswish(add_stub=True), False), ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False), ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True), ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False), ] - else: - print("Skipping tests that require torch > 1.4") for (module_name, ishape, raw_module, per_channel) in qmodules: raw_module.eval() @@ -372,6 +362,13 @@ def get_imagenet_input(): # ("googlenet", qgooglenet(pretrained=True), per_channel), ] + if is_version_greater_than("1.7.1"): + from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large + + qmodels.append( + ("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True) + ) + results = [] for (model_name, raw_model, per_channel) in qmodels: @@ -385,7 +382,10 @@ def get_imagenet_input(): inp = get_imagenet_input() pt_inp = torch.from_numpy(inp) - quantize_model(raw_model, pt_inp, per_channel=per_channel) + if "mobilenet_v3_large" not in model_name: + # mv3 was qat-ed, quantize=True option above makes it already quantized + quantize_model(raw_model, 
pt_inp, per_channel=per_channel) + script_module = torch.jit.trace(raw_model, pt_inp).eval() with torch.no_grad(): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9f035ade7a21..83c1698799c7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -24,6 +24,7 @@ import torch import torchvision from torch.nn import Module +from torch.nn import functional as F import tvm from tvm import relay from tvm.contrib import graph_runtime @@ -200,6 +201,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + for arg in mod["main"].params[: len(input_names)]: + assert arg.name_hint in input_names compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): @@ -1459,6 +1462,39 @@ def forward(self, *args): assert not any([op.name == "multiply" for op in list_ops(mod["main"])]) +@tvm.testing.uses_gpu +def test_forward_linear(): + torch.set_grad_enabled(False) + + class Linear(Module): + def forward(self, input, weight, bias): + return F.linear(input, weight, bias) + + class LinearNoBias(Module): + def forward(self, input, weight): + return F.linear(input, weight) + + input2d = torch.rand([2, 2]).float() + weight1d = torch.rand([2]).float() + weight2d = torch.rand([2, 2]).float() + bias1d = torch.rand([2]).float() + bias2d = torch.rand([2, 2]).float() + # 2D input, 2D weight, 1D bias + verify_model(Linear(), input_data=[input2d, weight2d, bias1d]) + # 2D input, 2D weight, 2D bias + verify_model(Linear(), input_data=[input2d, weight2d, bias2d]) + # 2D input, 2D weight, no bias + verify_model(LinearNoBias(), input_data=[input2d, weight2d]) + # 2D input, 1D weight, 1D bias is not supported by torch.linear() + # 2D input, 1D weight, no bias + verify_model(LinearNoBias(), input_data=[input2d, weight1d]) + # TODO: Add the following cases when matmul(1D, _) is supported by TVM + # 1D input, 2D weight, 1D bias + # 1D input, 2D weight, no bias + # 1D input, 1D weight, scalar bias + # 1D input, 1D weight, no bias + + @tvm.testing.uses_gpu def test_forward_dropout(): torch.set_grad_enabled(False) @@ -3615,6 +3651,13 @@ def test_hard_swish(): verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input) +def test_hard_sigmoid(): + examples = [torch.rand(8).float(), torch.rand(8, 10).float(), torch.rand(1, 1, 10).float()] + for input in examples: + verify_model(torch.nn.Hardsigmoid().eval(), input_data=input) + verify_model(torch.nn.Hardsigmoid(inplace=True).eval(), input_data=input) + + def test_cumsum(): def test_fn(dim, dtype=None): return lambda x: torch.cumsum(x, dim=dim, dtype=dtype) @@ -3670,13 +3713,13 @@ def test_fn(dim, descending): inp = torch.randn(100) verify_model(test_fn(0, True), [inp]) - verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(-1, False), [inp]) inp = torch.randn(100, 100) verify_model(test_fn(0, True), [inp]) - verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(-2, False), [inp]) verify_model(test_fn(1, True), [inp]) - verify_model(test_fn(1, False), [inp]) + verify_model(test_fn(-1, False), [inp]) def test_logical_and(): @@ -3857,6 +3900,8 @@ def test_fn(is_sorted, return_inverse, 
return_counts): test_logical_and() test_masked_select() test_unique() + test_hard_swish() + test_hard_sigmoid() # Model tests test_resnet18() @@ -3895,4 +3940,3 @@ def test_fn(is_sorted, return_inverse, return_counts): # Test convert torch script(jit) with specific inputs' types test_convert_torch_script_with_input_types() - test_hard_swish() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 41145bf77218..81aeb5ef886c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2080,6 +2080,181 @@ def test_forward_sparse_reshape( _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn) +####################################################################### +# Sparse Segment Variants +# ------------ + + +def _test_sparse_segment_variant( + tf_op, data_np, indices_np, segment_ids_np, num_segments, use_dyn=False +): + with tf.Graph().as_default(): + if use_dyn: + data = tf.placeholder( + shape=[None for _ in data_np.shape], dtype=data_np.dtype, name="data" + ) + indices = tf.placeholder(shape=[None], dtype=indices_np.dtype, name="indices") + segment_ids = tf.placeholder( + shape=(None), dtype=segment_ids_np.dtype, name="segment_ids" + ) + else: + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data") + indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices") + segment_ids = tf.placeholder( + shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, name="segment_ids" + ) + + _ = tf_op( + data, indices, segment_ids, num_segments=num_segments, name="sparse_segment_variant" + ) + compare_tf_with_tvm( + [data_np, indices_np, segment_ids_np], + [data.name, indices.name, segment_ids.name], + ["sparse_segment_variant:0"], + mode="vm", + ) + + +@pytest.mark.parametrize( + "data_np, indices_np, segment_ids_np, num_segments", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 3, 4], dtype=np.int32), + np.array([0, 1, 1], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 1], dtype=np.int32), + np.array([0, 2], dtype=np.int32), + 4, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 2, 4, 3, 1], dtype=np.int32), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + 100, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 2, 4, 3, 1], dtype=np.int32), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + None, + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float64), + np.array([0, 1, 2], dtype=np.int32), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 6, 7, 7, 8], dtype=np.int32), + 9, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 6, 7, 7, 8], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 1], dtype=np.int32), + np.array([0, 2], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 5, 5, 5, 5], dtype=np.int32), + 6, + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +@pytest.mark.parametrize( + "tf_op", + [ + tf.sparse.segment_sum, + tf.sparse.segment_sqrt_n, + tf.sparse.segment_mean, + ], +) 
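+# Each parametrized case pairs data/indices/segment_ids with an optional num_segments: +# None exercises the path where the segment count is inferred from the ids, while an +# explicit value larger than max(segment_ids) + 1 checks that trailing output rows are zero-filled.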
+def test_forward_sparse_segment_sum_variants( + tf_op, + data_np, + indices_np, + segment_ids_np, + num_segments, + use_dyn, +): + """sparse segment sum variants tests""" + _test_sparse_segment_variant(tf_op, data_np, indices_np, segment_ids_np, num_segments, use_dyn) + + +####################################################################### +# Math SegmentSum +# ------------ + + +def _test_math_segment_sum(data_np, segment_ids_np, use_dyn=False): + with tf.Graph().as_default(): + if use_dyn: + data = tf.placeholder( + shape=[None for _ in data_np.shape], dtype=data_np.dtype, name="data" + ) + segment_ids = tf.placeholder( + shape=(None), dtype=segment_ids_np.dtype, name="segment_ids" + ) + else: + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data") + segment_ids = tf.placeholder( + shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, name="segment_ids" + ) + + _ = tf.math.segment_sum(data, segment_ids, name="segment_sum") + compare_tf_with_tvm( + [data_np, segment_ids_np], + [data.name, segment_ids.name], + ["segment_sum:0"], + mode="vm", + ) + + +@pytest.mark.parametrize( + "data_np, segment_ids_np", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 0, 0, 1, 1, 1], dtype=np.int32), + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 0, 1], dtype=np.int32), + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 0, 1, 2, 2, 3], dtype=np.int64), + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32), + np.array([0, 0, 1], dtype=np.int32), + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 0, 0, 1, 2, 3, 4, 4, 5], dtype=np.int64), + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_forward_math_segment_sum(data_np, segment_ids_np, use_dyn): + """math segment sum test""" + _test_math_segment_sum(data_np, segment_ids_np, use_dyn) + + # tensorflow.compat.v1.sparse_to_dense # --------------- def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 64b2c16e155e..813352c52096 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -18,9 +18,14 @@ Test the tuner """ import logging +import sys +import textwrap import time +import pytest + import tvm +import tvm.relay from tvm import te from tvm import autotvm @@ -29,94 +34,100 @@ import tvm.testing -@autotvm.template("testing/conv2d_no_batching") -def conv2d_no_batching(N, H, W, CI, CO, KH, KW): - """An example template for testing""" - assert N == 1, "Only consider batch_size = 1 in this template" - - data = te.placeholder((N, CI, H, W), name="data") - kernel = te.placeholder((CO, CI, KH, KW), name="kernel") - - rc = te.reduce_axis((0, CI), name="rc") - ry = te.reduce_axis((0, KH), name="ry") - rx = te.reduce_axis((0, KW), name="rx") - - conv = te.compute( - (N, CO, H - KH + 1, W - KW + 1), - lambda nn, ff, yy, xx: te.sum( - data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx] - ), - tag="conv2d_nchw", - ) - - s = te.create_schedule([conv.op]) - - output = conv - OL = s.cache_write(conv, "local") - - # create cache stage - AA = s.cache_read(data, "shared", [OL]) - WW = s.cache_read(kernel, "shared", [OL]) - AL = s.cache_read(AA, "local", [OL]) - WL = s.cache_read(WW, "local", [OL]) - - # tile and bind spatial axes - n, f, y, x = s[output].op.axis - cfg = autotvm.get_config() - cfg.define_split("tile_f", 
cfg.axis(f), num_outputs=4) - cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) - cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) - bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) - by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) - bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) - kernel_scope = n # this is the scope to attach global config inside this kernel - - s[output].bind(bf, te.thread_axis("blockIdx.z")) - s[output].bind(by, te.thread_axis("blockIdx.y")) - s[output].bind(bx, te.thread_axis("blockIdx.x")) - s[output].bind(vf, te.thread_axis("vthread")) - s[output].bind(vy, te.thread_axis("vthread")) - s[output].bind(vx, te.thread_axis("vthread")) - s[output].bind(tf, te.thread_axis("threadIdx.z")) - s[output].bind(ty, te.thread_axis("threadIdx.y")) - s[output].bind(tx, te.thread_axis("threadIdx.x")) - s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) - s[OL].compute_at(s[output], tx) - - # tile and bind reduction axes - n, f, y, x = s[OL].op.axis - rc, ry, rx = s[OL].op.reduce_axis - cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) - cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3) - cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3) - rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc) - ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry) - rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx) - s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x) - - s[AA].compute_at(s[OL], rxo) - s[WW].compute_at(s[OL], rxo) - s[AL].compute_at(s[OL], rxm) - s[WL].compute_at(s[OL], rxm) - - # cooperative fetching - for load in [AA, WW]: - n, f, y, x = s[load].op.axis - fused = s[load].fuse(n, f, y, x) - tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2]) - ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) - tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) - s[load].bind(tz, te.thread_axis("threadIdx.z")) - s[load].bind(ty, te.thread_axis("threadIdx.y")) - s[load].bind(tx, te.thread_axis("threadIdx.x")) - - # tune unroll - cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - cfg.define_knob("unroll_explicit", [0, 1]) - s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) - s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) - - return s, [data, kernel, conv] +def setup_module(): + @autotvm.template("testing/conv2d_no_batching") + def conv2d_no_batching(N, H, W, CI, CO, KH, KW): + """An example template for testing""" + assert N == 1, "Only consider batch_size = 1 in this template" + + data = te.placeholder((N, CI, H, W), name="data") + kernel = te.placeholder((CO, CI, KH, KW), name="kernel") + + rc = te.reduce_axis((0, CI), name="rc") + ry = te.reduce_axis((0, KH), name="ry") + rx = te.reduce_axis((0, KW), name="rx") + + conv = te.compute( + (N, CO, H - KH + 1, W - KW + 1), + lambda nn, ff, yy, xx: te.sum( + data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx], axis=[rc, ry, rx] + ), + tag="conv2d_nchw", + ) + + s = te.create_schedule([conv.op]) + + output = conv + OL = s.cache_write(conv, "local") + + # create cache stage + AA = s.cache_read(data, "shared", [OL]) + WW = s.cache_read(kernel, "shared", [OL]) + AL = s.cache_read(AA, "local", [OL]) + WL = s.cache_read(WW, "local", [OL]) + + # tile and bind spatial axes + n, f, y, x = s[output].op.axis + cfg = autotvm.get_config() + cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) + cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) + cfg.define_split("tile_x", cfg.axis(x), 
num_outputs=4) + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + kernel_scope = n # this is the scope to attach global config inside this kernel + + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) + s[OL].compute_at(s[output], tx) + + # tile and bind reduction axes + n, f, y, x = s[OL].op.axis + rc, ry, rx = s[OL].op.reduce_axis + cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) + cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3) + cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3) + rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc) + ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry) + rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx) + s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + s[AL].compute_at(s[OL], rxm) + s[WL].compute_at(s[OL], rxm) + + # cooperative fetching + for load in [AA, WW]: + n, f, y, x = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + + # tune unroll + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + cfg.define_knob("unroll_explicit", [0, 1]) + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + return s, [data, kernel, conv] + + +def teardown_module(): + # TODO(areusch): Tasks should not be registered into a global. 
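+ # Until then, explicitly unregister the template added in setup_module() so that + # later tests in this process do not pick up a stale "testing/conv2d_no_batching" + # entry from autotvm's global task table.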
+ del autotvm.task.task.TASK_TABLE["testing/conv2d_no_batching"] def get_sample_task(target=tvm.target.cuda(), target_host=None): @@ -131,19 +142,62 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None): @tvm.testing.parametrize_targets("cuda", "opencl") -def test_tuning(target, ctx): +def test_tuning_gpu(target, ctx): # init task task, target = get_sample_task(target, None) - logging.info("%s", task.config_space) + logging.info("task config space: %s", task.config_space) measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) + results = [] + tuner = RandomTuner(task) - tuner.tune(n_trial=20, measure_option=measure_option) + tuner.tune( + n_trial=20, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + assert len(results) == 20 -if __name__ == "__main__": - # only print log when invoked from main - logging.basicConfig(level=logging.DEBUG) + successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR] + assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" + + +def test_tuning_cpu(): + ir_mod = tvm.parser.fromtext( + textwrap.dedent( + """ + #[version = "0.0.5"] + def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) { + nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW") + } + """ + ) + ) + tasks = autotvm.task.relay_integration.extract_from_program( + ir_mod, {}, tvm.target.create("llvm") + ) + assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}" + + task = tasks[0] + + measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) + + results = [] + + tuner = RandomTuner(task) + tuner.tune( + n_trial=20, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + + assert len(results) == 20 - test_tuning() + successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR] + assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index b75cc5f5e750..32292de4c8ea 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -260,6 +260,28 @@ def test_any_reshape(): verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) +def verify_any_one_hot(indices_shape, indices_np_shape, depth, on_value, off_value, axis, dtype): + indices = relay.var("indices", shape=indices_shape, dtype="int32") + on_value_const = relay.const(on_value, dtype) + off_value_const = relay.const(off_value, dtype) + y = relay.one_hot(indices, on_value_const, off_value_const, depth, axis=axis, dtype=dtype) + params = [indices] + mod = tvm.IRModule() + mod["main"] = relay.Function(params, y) + + indices_npy = np.random.randint(0, depth, size=indices_np_shape).astype("int32") + out_npy = tvm.topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype) + args = [indices_npy] + check_result(args, mod, out_npy) + + +@tvm.testing.uses_gpu +def test_any_one_hot(): + verify_any_one_hot(any_dims(1), (3,), 3, 1, 0, -1, "int32") + verify_any_one_hot(any_dims(2), (2, 2), 5, 0.5, -0.5, 1, "float32") + verify_any_one_hot(any_dims(4), (3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + + def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): x = relay.var("x", shape=x_shape, dtype=dtype) y = relay.argwhere(x) diff --git 
a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py index d6bfd8d0ec11..b3f1868969cc 100644 --- a/tests/python/relay/test_autotvm_task_extraction.py +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -102,5 +102,26 @@ def test_task_extraction(): assert len(tasks) == 31 +def test_task_extraction_for_dense_int8_cuda(): + target = "cuda" + dense = relay.op.get("nn.dense") + + def get_net(batch, in_dim, out_dim, dtype, out_dtype): + data = tvm.relay.var("data", shape=[batch, in_dim], dtype=dtype) + weight = tvm.relay.var("weight", shape=[out_dim, in_dim], dtype=dtype) + out = relay.nn.dense(data, weight, out_dtype=out_dtype) + mod, params = relay.testing.create_workload(out) + return mod, params + + mod, params = get_net(1, 16, 32, "float32", "float32") + tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,)) + assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.cuda" + + mod, params = get_net(1, 16, 32, "int8", "int32") + tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,)) + assert len(tasks) == 1 and tasks[0].name == "dense_int8.cuda" + + if __name__ == "__main__": test_task_extraction() + test_task_extraction_for_dense_int8_cuda() diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 3c42b7b4196f..68708aaeb413 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -209,6 +209,27 @@ def test_compile_nested_tuples(): ref = ref + 1 +def test_graph_executor_nested_tuples(): + x, y, z, w = [relay.var(c, shape=(2, 3), dtype="float32") for c in "xyzw"] + out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])]) + func = relay.Function([x, y, z, w], out) + + exe = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(func), ctx=tvm.cpu(0), target="llvm" + ) + f = exe.evaluate() + + data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"] + out = f(*data) + assert len(out) == 2 + tvm.testing.assert_allclose(out[0].asnumpy(), data[0]) + assert len(out[1]) == 2 + tvm.testing.assert_allclose(out[1][0].asnumpy(), data[1]) + assert len(out[1][1]) == 2 + tvm.testing.assert_allclose(out[1][1][0].asnumpy(), data[2]) + tvm.testing.assert_allclose(out[1][1][1].asnumpy(), data[3]) + + if __name__ == "__main__": test_plan_memory() test_with_params() diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 67f0621ef273..60f3dfa76e38 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -18,7 +18,7 @@ import tvm from tvm import te -from tvm import relay +from tvm import relay, runtime from tvm.contrib.nvcc import have_fp16 import tvm.testing @@ -86,7 +86,7 @@ def test_fp16_build(): # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) - rt.load_params(relay.save_param_dict(params)) + rt.load_params(runtime.save_param_dict(params)) rt.run() out = rt.get_output(0) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index ea5dd6948b11..dfd350486c3b 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -202,14 +202,16 @@ def test_bias_add(): def test_bias_add_type_failure(): - # the axis is out of range - try: - b_add = relay.nn.bias_add(relay.const(1), relay.const(2), axis=0) - run_infer_type(b_add) - except 
tvm._ffi.base.TVMError: - pass - else: - assert False + def assert_failure(expr): + try: + run_infer_type(expr) + except tvm._ffi.base.TVMError: + return + else: + assert False + + for axis in (0, -1, -3, 1): + assert_failure(relay.nn.bias_add(relay.const(1), relay.const(2), axis=axis)) def test_expand_dims_infer_type(): diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index c9ed975c3b9b..d2a5090943c3 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -24,6 +24,7 @@ from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type +from typing import Optional import tvm.testing @@ -1023,7 +1024,25 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): @tvm.testing.uses_gpu -def test_scatter_add(): +@pytest.mark.parametrize( + "dshape, ishape, axis, dtype", + [ + ((10,), (10,), 0, "int32"), + ((1000,), (1000,), 0, "int32"), + ((10, 5), (10, 5), -2, "float32"), + ((10, 5), (10, 5), -1, "float32"), + ((10, 5), (3, 5), 0, "float32"), + ((12, 4), (7, 2), 1, "float32"), + ((2, 3, 4), (1, 3, 4), 0, "float32"), + ((2, 3, 4), (2, 1, 4), 1, "float32"), + ((2, 3, 4), (2, 3, 1), 2, "float32"), + ((2, 3, 4, 5), (1, 3, 4, 5), 0, "float32"), + ((6, 3, 4, 5), (2, 3, 4, 5), 1, "float32"), + ((2, 3, 8, 5), (2, 3, 1, 1), 2, "float32"), + ((16, 16, 4, 5), (16, 16, 4, 5), 3, "float32"), + ], +) +def test_scatter_add(dshape, ishape, axis, dtype): def ref_scatter_add(data, indices, updates, axis=0): output = np.copy(data) for index in np.ndindex(*indices.shape): @@ -1033,9 +1052,9 @@ def ref_scatter_add(data, indices, updates, axis=0): return output def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): - d = relay.var("d", relay.TensorType(dshape, dtype)) - i = relay.var("i", relay.TensorType(ishape, "int64")) - u = relay.var("u", relay.TensorType(ishape, dtype)) + d = relay.var("d", relay.TensorType(shape=[relay.Any() for _ in dshape], dtype=dtype)) + i = relay.var("i", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype="int64")) + u = relay.var("u", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=dtype)) z = relay.op.scatter_add(d, i, u, axis) func = relay.Function([d, i, u], z) @@ -1045,40 +1064,177 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["graph", "debug"]: - if target == "nvptx" and dtype == "float32" and len(dshape) == 1: - # scatter_add 1D on GPU is implemented via atomic. - # Floating point atomic requires LLVM 9 or newer for nvptx backend. - # But LLVM on CI is LLVM 8. 
- continue - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_scatter_add((10,), (10,), 0, dtype="int32") - verify_scatter_add((1000,), (1000,)) - verify_scatter_add((1000,), (1000,), 0, dtype="int32") - verify_scatter_add((10, 5), (10, 5), -2) - verify_scatter_add((10, 5), (10, 5), -1) - verify_scatter_add((10, 5), (3, 5), 0) - verify_scatter_add((12, 4), (7, 2), 1) - verify_scatter_add((2, 3, 4), (1, 3, 4), 0) - verify_scatter_add((2, 3, 4), (2, 1, 4), 1) - verify_scatter_add((2, 3, 4), (2, 3, 1), 2) - verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0) - verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1) - verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2) - verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3) + verify_func( + func, + [data_np, indices_np, updates_np], + ref_res, + ) + + verify_scatter_add(dshape, ishape, axis, dtype) @tvm.testing.uses_gpu -def test_gather(): +@pytest.mark.parametrize( + "data, axis, indices, ref_res", + [ + ([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]), + ([[1, 2], [3, 4]], -1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]), + ( + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + 0, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]], + ), + ( + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + -3, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + 1, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + -2, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + -2, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [0.3050, 1.6986, 1.1034], + [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], + [0.0835, -1.3915, -1.0720], + ], + [ + [0.1694, -0.6091, -0.6539], + [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], + [-0.5700, -1.0302, 0.1558], + ], + ], + 2, + [ + [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], + ], + [ + [ + [1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835], + ], + [ + [0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 
0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558], + ], + ], + ), + ( + [ + [ + [0.3050, 1.6986, 1.1034], + [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], + [0.0835, -1.3915, -1.0720], + ], + [ + [0.1694, -0.6091, -0.6539], + [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], + [-0.5700, -1.0302, 0.1558], + ], + ], + -1, + [ + [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], + ], + [ + [ + [1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835], + ], + [ + [0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558], + ], + ], + ), + ], +) +def test_gather(data, axis, indices, ref_res): def verify_gather(data, axis, indices, ref_res): data = np.asarray(data, dtype="float32") indices = np.asarray(indices, dtype="int32") ref_res = np.asarray(ref_res) - d = relay.var("x", relay.TensorType(data.shape, "float32")) i = relay.var("y", relay.TensorType(indices.shape, "int32")) z = relay.gather(d, axis, i) @@ -1091,70 +1247,7 @@ def verify_gather(data, axis, indices, ref_res): op_res = intrp.evaluate(func)(data, indices) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]) - verify_gather( - [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], - 0, - [[[1, 0, 1], [1, 1, 0]]], - [[[6, 1, 8], [9, 10, 5]]], - ) - verify_gather( - [ - [ - [-0.2321, -0.2024, -1.7624], - [-0.3829, -0.4246, 0.2448], - [0.1822, 0.2360, -0.8965], - [0.4497, -0.2224, 0.6103], - ], - [ - [0.0408, -0.7667, -0.4303], - [-0.3216, 0.7489, -0.1502], - [0.0144, -0.4699, -0.0064], - [-0.0768, -1.6064, 1.3390], - ], - ], - 1, - [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], - [ - [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], - [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], - ], - ) - verify_gather( - [ - [ - [0.3050, 1.6986, 1.1034], - [0.7020, -0.6960, -2.1818], - [0.3116, -0.5773, -0.9912], - [0.0835, -1.3915, -1.0720], - ], - [ - [0.1694, -0.6091, -0.6539], - [-0.5234, -0.1218, 0.5084], - [0.2374, -1.9537, -2.0078], - [-0.5700, -1.0302, 0.1558], - ], - ], - 2, - [ - [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], - [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], - ], - [ - [ - [1.6986, 1.6986, 0.3050, 1.6986], - [0.7020, 0.7020, -2.1818, -2.1818], - [-0.5773, -0.9912, -0.5773, -0.9912], - [-1.0720, -1.0720, -1.3915, 0.0835], - ], - [ - [0.1694, 0.1694, -0.6091, -0.6539], - [0.5084, 0.5084, -0.1218, -0.5234], - [-1.9537, -2.0078, 0.2374, 0.2374], - [-0.5700, 0.1558, -0.5700, 0.1558], - ], - ], - ) + verify_gather(data, axis, indices, ref_res) @tvm.testing.uses_gpu @@ -1515,6 +1608,105 @@ def verify_sparse_reshape( ) +@tvm.testing.uses_gpu +@pytest.mark.parametrize( + "data_np, segment_ids_np, num_segments", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 0, 1, 1, 0, 1], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((6, 4, 5)), + np.array([2, 0, 1, 0, 3, 2], dtype=np.int64), + None, + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([5, 0, 1, 0, 3, 6, 8, 7, 7], 
dtype=np.int64), + 9, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 2], dtype=np.int32), + 4, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + 100, + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_segment_sum(data_np, segment_ids_np, num_segments, use_dyn): + def ref_segment_sum( + data: np.ndarray, + segment_ids: np.ndarray, + num_segments: Optional[int] = None, + ): + """ + This function calculates the expected output of the segment_sum operator given the inputs. + """ + if not num_segments: + num_segments = np.unique(segment_ids).shape[0] + + result = np.zeros((num_segments,) + data.shape[1:], data.dtype) + for i, index in enumerate(segment_ids): + result[index] += data[i] + return result + + def verify_segment_sum( + data_np: np.ndarray, segment_ids_np: np.ndarray, num_segments: Optional[int] + ): + """ + This function verifies the Relay output of segment_sum against its expected output. + """ + if use_dyn: + data = relay.var( + "data", + shape=[relay.Any() for _ in data_np.shape], + dtype=str(data_np.dtype), + ) + segment_ids = relay.var( + "segment_ids", + shape=[relay.Any()], + dtype=str(segment_ids_np.dtype), + ) + else: + data = relay.var( + "data", + relay.TensorType(data_np.shape, str(data_np.dtype)), + ) + segment_ids = relay.var( + "segment_ids", relay.TensorType(segment_ids_np.shape, str(segment_ids_np.dtype)) + ) + z = relay.op.segment_sum(data, segment_ids, num_segments) + + func = relay.Function([data, segment_ids], z) + ref_res = ref_segment_sum(data_np, segment_ids_np, num_segments=num_segments) + segment_sum_result = run_infer_type(z) + assert segment_sum_result.checked_type.dtype == data_np.dtype + verify_func( + func, + [data_np, segment_ids_np], + ref_res, + ) + + verify_segment_sum(data_np, segment_ids_np, num_segments) + + def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()): assert isinstance(data, list) for target, ctx in target_ctx: diff --git a/tests/python/relay/test_param_dict.py b/tests/python/relay/test_param_dict.py index 74c9ebcaa355..29e0b5c0463b 100644 --- a/tests/python/relay/test_param_dict.py +++ b/tests/python/relay/test_param_dict.py @@ -17,7 +17,7 @@ import os import numpy as np import tvm -from tvm import te +from tvm import te, runtime import json import base64 from tvm._ffi.base import py_str @@ -31,7 +31,7 @@ def test_save_load(): x = np.ones((10, 2)).astype("float32") y = np.ones((1, 2, 3)).astype("float32") params = {"x": x, "y": y} - param_bytes = relay.save_param_dict(params) + param_bytes = runtime.save_param_dict(params) assert isinstance(param_bytes, bytearray) param2 = relay.load_param_dict(param_bytes) assert len(param2) == 2 @@ -46,7 +46,7 @@ def test_ndarray_reflection(): param_dict = {"x": tvm_array, "y": tvm_array} assert param_dict["x"].same_as(param_dict["y"]) # Serialize then deserialize `param_dict`. - deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict)) + deser_param_dict = relay.load_param_dict(runtime.save_param_dict(param_dict)) # Make sure the data matches the original data and `x` and `y` contain the same data. np.testing.assert_equal(deser_param_dict["x"].asnumpy(), tvm_array.asnumpy()) # Make sure `x` and `y` contain the same data.
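The two hunks above capture the API move exercised throughout this patch: save_param_dict now lives in tvm.runtime, while relay.load_param_dict keeps reading the serialized bytes unchanged. As a standalone sketch (not part of the diff; it simply mirrors test_save_load):

import numpy as np
import tvm
from tvm import relay, runtime

params = {"x": tvm.nd.array(np.ones((10, 2), dtype="float32"))}
param_bytes = runtime.save_param_dict(params)  # serializes the dict to a bytearray
restored = relay.load_param_dict(param_bytes)  # loading is untouched by this patch
np.testing.assert_equal(restored["x"].asnumpy(), params["x"].asnumpy())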
@@ -77,7 +77,7 @@ def verify_graph_runtime(remote, target, shape, dtype): lib = remote.load_module("dev_lib.o") ctx = remote.cpu(0) mod = graph_runtime.create(graph, lib, ctx) - mod.load_params(relay.save_param_dict(params)) + mod.load_params(runtime.save_param_dict(params)) mod.run() out = mod.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx)) tvm.testing.assert_allclose(x_in + 1, out.asnumpy()) diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py new file mode 100644 index 000000000000..302a2b91bb8f --- /dev/null +++ b/tests/python/relay/test_pass_fold_explicit_padding.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_opt_pass + +import numpy as np + + +def test_simplify_conv_pad(): + convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d] + + def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout): + if layout[1] == "C": + shape = [1, 3] + [10] * ndim + wshape = [8, 3] + [3] * ndim + elif layout[-1] == "C": + shape = [1] + [10] * ndim + [3] + wshape = [8] + [3] * ndim + [3] + else: + raise ValueError("This test only supports NC* and N*C") + + x = relay.var("x", shape=shape, dtype="float32") + w = relay.var("w", shape=wshape, dtype="float32") + pad = relay.nn.pad(x, pad_width, pad_value, pad_mode) + if layout[1] == "C": + conv = convs[ndim - 1](pad, w, padding=orig_padding) + else: + conv = convs[ndim - 1]( + pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :] + ) + + if pad_mode == "constant" and pad_value == 0: + new_padding = [] + for j in range(2): + for i in range(len(pad_width)): + if layout[i] in ["D", "H", "W"]: + new_padding.append(pad_width[i][j]) + for i in range(len(new_padding)): + new_padding[i] += orig_padding[i] + if layout[1] == "C": + after = convs[ndim - 1](x, w, padding=new_padding) + else: + after = convs[ndim - 1]( + x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :] + ) + else: + after = conv + + zz = run_opt_pass(conv, transform.FoldExplicitPadding()) + expected = run_opt_pass(after, transform.InferType()) + assert tvm.ir.structural_equal(zz, expected) + + mod1 = tvm.IRModule.from_expr(conv) + mod2 = tvm.IRModule.from_expr(zz) + + with tvm.transform.PassContext(): + ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm") + ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm") + x_np = np.random.rand(*shape).astype("float32") + w_np = np.random.rand(*wshape).astype("float32") + result1 = ex1.evaluate()(x_np, w_np) + result2 = ex2.evaluate()(x_np, w_np) + + tvm.testing.assert_allclose(result1.asnumpy(), 
result2.asnumpy(), rtol=1e-5, atol=1e-5) + + for orig_pad in [[0, 0], [2, 0], [0, 2]]: + for i_pad in [[0, 0], [1, 1], [1, 0]]: + for ndim in [1, 2, 3]: + for channels_last in [0, 1]: + if channels_last: + layout = "NDHWC" + layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:] + padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]] + else: + layout = "NCDHW" + layout = layout[0:2] + layout[5 - ndim :] + padding = [[0, 0]] * 2 + [i_pad] * ndim + + validate(ndim, padding, 0, "constant", orig_pad * ndim, layout) + ndim = 2 + validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW") + validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW") + + +if __name__ == "__main__": + test_simplify_conv_pad() diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index e3e497e930f9..9531d896b2ed 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -124,82 +124,6 @@ def after_right(x, elem_op, value): validate(shape, value, dtype) -def test_simplify_conv_pad(): - convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d] - - def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout): - if layout[1] == "C": - shape = [1, 3] + [10] * ndim - wshape = [8, 3] + [3] * ndim - elif layout[-1] == "C": - shape = [1] + [10] * ndim + [3] - wshape = [8] + [3] * ndim + [3] - else: - raise ValueError("This test only supports NC* and N*C") - - x = relay.var("x", shape=shape, dtype="float32") - w = relay.var("w", shape=wshape, dtype="float32") - pad = relay.nn.pad(x, pad_width, pad_value, pad_mode) - if layout[1] == "C": - conv = convs[ndim - 1](pad, w, padding=orig_padding) - else: - conv = convs[ndim - 1]( - pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :] - ) - - if pad_mode == "constant" and pad_value == 0: - new_padding = [] - for j in range(2): - for i in range(len(pad_width)): - if layout[i] in ["D", "H", "W"]: - new_padding.append(pad_width[i][j]) - for i in range(len(new_padding)): - new_padding[i] += orig_padding[i] - if layout[1] == "C": - after = convs[ndim - 1](x, w, padding=new_padding) - else: - after = convs[ndim - 1]( - x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :] - ) - else: - after = conv - - zz = run_opt_pass(conv, transform.SimplifyExpr()) - expected = run_opt_pass(after, transform.InferType()) - assert tvm.ir.structural_equal(zz, expected) - - mod1 = tvm.IRModule.from_expr(conv) - mod2 = tvm.IRModule.from_expr(zz) - - with tvm.transform.PassContext(disabled_pass="SimplifyExpr"): - ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm") - ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm") - x_np = np.random.rand(*shape).astype("float32") - w_np = np.random.rand(*wshape).astype("float32") - result1 = ex1.evaluate()(x_np, w_np) - result2 = ex2.evaluate()(x_np, w_np) - - tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy()) - - for orig_pad in [[0, 0], [2, 0], [0, 2]]: - for i_pad in [[0, 0], [1, 1], [1, 0]]: - for ndim in [1, 2, 3]: - for channels_last in [0, 1]: - if channels_last: - layout = "NDHWC" - layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:] - padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]] - else: - layout = "NCDHW" - layout = layout[0:2] + layout[5 - ndim :] - padding = [[0, 0]] * 2 + [i_pad] * ndim - - validate(ndim, padding, 0, "constant", orig_pad * ndim, layout) - ndim = 2 - 
validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW") - validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW") - - if __name__ == "__main__": test_simplify_reshape() test_simplify_full_elementwise() diff --git a/tests/python/topi/python/test_topi_broadcast.py b/tests/python/topi/python/test_topi_broadcast.py index 44be28c318e4..ada03ea5377b 100644 --- a/tests/python/topi/python/test_topi_broadcast.py +++ b/tests/python/topi/python/test_topi_broadcast.py @@ -284,7 +284,7 @@ def test_shift(): ) verify_broadcast_binary_ele( - (1, 2, 2), (2,), topi.left_shift, np.left_shift, dtype="int8", rhs_min=0, rhs_max=32 + (1, 2, 2), (2,), topi.left_shift, np.left_shift, dtype="int32", rhs_min=0, rhs_max=32 ) diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py index a01a496f92e9..cfe5130643c5 100644 --- a/tests/python/topi/python/test_topi_cumsum.py +++ b/tests/python/topi/python/test_topi_cumsum.py @@ -28,6 +28,8 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern), "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), + "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule) @@ -44,6 +46,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32") for in_dtype in ["float32", "float64"]: + if target == "metal" and in_dtype == "float64": + # float64 is not supported in metal + continue data = np.random.randn(10, 10).astype(in_dtype) check_cumsum(np.cumsum(data), data) check_cumsum(np.cumsum(data, axis=0), data, axis=0) @@ -70,3 +75,5 @@ def check_cumsum(np_ref, data, axis=None, dtype=None): test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm")) test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda")) test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx")) + test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan")) + test_cumsum(tvm.context("metal"), tvm.target.Target("metal")) diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 839356892ab1..2fdf3cf4b170 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -112,7 +112,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device) diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index 2fae7b838143..795c3cb3b0a2 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -66,7 +66,7 @@ def test_apply_steps_with_layout_rewrite_corner_case(): @tvm.testing.requires_llvm def test_correctness_layout_rewrite_rewrite_for_preTransformed(): - N = 128 + N = 16 target = tvm.target.Target("llvm") task = 
auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target) dag = task.compute_dag @@ -78,9 +78,10 @@ def test_correctness_layout_rewrite_rewrite_for_preTransformed(): measure_ctx = auto_scheduler.LocalRPCMeasureContext() tuning_options = auto_scheduler.TuningOptions( - num_measure_trials=2, + num_measure_trials=100, runner=measure_ctx.runner, verbose=2, + early_stopping=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) task.tune(tuning_options, search_policy=search_policy) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index cc9d7a41548d..116981028cc9 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -19,6 +19,7 @@ import json import multiprocessing +import numpy as np import tvm from tvm import topi from tvm import te, auto_scheduler @@ -26,7 +27,7 @@ import tvm.testing import pickle -from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul +from test_auto_scheduler_common import matmul_auto_scheduler_test from tvm.auto_scheduler import workload_registry @@ -355,6 +356,34 @@ def test_measure_target_host(): assert str(recovered_inp.task.target_host) == str(inp.task.target_host) +@tvm.testing.requires_llvm +def test_measure_special_inputs_map_by_name(): + @auto_scheduler.register_workload + def foo(): + X = te.placeholder(shape=[10], dtype="int32") + Index = te.placeholder(shape=[1], dtype="int32", name="Index") + Y = te.compute((1,), lambda i: X[Index[i]]) + return [X, Index, Y] + + # This workload cannot use random input for the `Index` input + task = auto_scheduler.SearchTask( + func=foo, + target="llvm", + task_inputs={ + "Index": tvm.nd.array(np.array([5], dtype="int32")), + }, + ) + + minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state) + local_builder = auto_scheduler.LocalBuilder() + local_runner = auto_scheduler.LocalRunner(timeout=10) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + if __name__ == "__main__": test_record_split_reorder_fuse_annotation() test_record_compute_at_root_inline_cache_read_write() @@ -366,3 +395,4 @@ def test_measure_target_host(): test_dag_measure_local_builder_runner() test_measure_local_builder_rpc_runner() test_measure_target_host() + test_measure_special_inputs_map_by_name() diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py new file mode 100644 index 000000000000..78e85dc213e0 --- /dev/null +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test search task"""
+
+import numpy as np
+import tempfile
+
+import tvm
+import tvm.testing
+from tvm import auto_scheduler
+from tvm.auto_scheduler.utils import get_const_tuple
+from test_auto_scheduler_common import (
+    matmul_auto_scheduler_test,
+    zero_rank_compute_auto_scheduler_test,
+    zero_rank_reduce_auto_scheduler_test,
+)
+
+
+def test_search_task_add_task_input():
+    auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear()
+    N = 64
+    target = "llvm"
+    test_input_0 = tvm.runtime.ndarray.empty((64, 64))
+    test_input_1 = tvm.runtime.ndarray.empty((10, 20))
+    test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50))
+    task = auto_scheduler.SearchTask(
+        func="matmul_auto_scheduler_test",
+        args=(N, N, N),
+        target=target,
+        task_inputs={
+            "test_input_0": test_input_0,
+            "test_input_1": test_input_1,
+            "test_input_2": test_input_2,
+        },
+        task_inputs_overwrite=True,
+    )
+
+    assert len(task.task_input_names) == 3
+    assert task.task_input_names[0] == "test_input_0"
+    assert task.task_input_names[1] == "test_input_1"
+    assert task.task_input_names[2] == "test_input_2"
+
+
+def test_search_task_record():
+    auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear()
+    N = 64
+    target = "llvm"
+
+    # Log with no task input
+    task = auto_scheduler.SearchTask(
+        func="matmul_auto_scheduler_test", args=(N, N, N), target=target
+    )
+    task_record = auto_scheduler._ffi_api.SerializeSearchTask(task)
+    new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record)
+    # TODO(jcf94): Check the compute dag & hardware parameter
+    assert task.workload_key == new_task.workload_key
+    assert str(task.target) == str(new_task.target)
+    assert str(task.target_host) == str(new_task.target_host)
+    assert task.layout_rewrite_option == new_task.layout_rewrite_option
+
+    # Log with 1 task input
+    test_input_0 = tvm.runtime.ndarray.empty((64, 64))
+    task = auto_scheduler.SearchTask(
+        func="matmul_auto_scheduler_test",
+        args=(N, N, N),
+        target=target,
+        task_inputs={"test_input_0": test_input_0},
+        task_inputs_overwrite=True,
+    )
+    task_record = auto_scheduler._ffi_api.SerializeSearchTask(task)
+    new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record)
+    assert task.workload_key == new_task.workload_key
+    assert str(task.target) == str(new_task.target)
+    assert str(task.target_host) == str(new_task.target_host)
+    assert task.layout_rewrite_option == new_task.layout_rewrite_option
+    assert len(new_task.task_input_names) == 1
+    assert new_task.task_input_names[0] == "test_input_0"
+
+    # Log with multiple task inputs
+    test_input_1 = tvm.runtime.ndarray.empty((64, 64))
+    task = auto_scheduler.SearchTask(
+        func="matmul_auto_scheduler_test",
+        args=(N, N, N),
+        target=target,
+        task_inputs={
+            "test_input_0": test_input_0,
+            "test_input_1": test_input_1,
+        },
+        task_inputs_overwrite=True,
+    )
+    task_record = auto_scheduler._ffi_api.SerializeSearchTask(task)
+    new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record)
+    assert task.workload_key == new_task.workload_key
+    assert str(task.target) == str(new_task.target)
+    assert str(task.target_host) == str(new_task.target_host)
+    assert task.layout_rewrite_option == new_task.layout_rewrite_option
+    assert len(new_task.task_input_names) == 2
+    assert new_task.task_input_names[0] == "test_input_0"
+    assert new_task.task_input_names[1] == "test_input_1"
+
+    # Log with version 0.5
+    v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 
64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 0 + + +def test_recover_measure_input_with_task_input(): + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() + + # Since this file is tests for search_task, we only check the search_task here + + # Log with no task input + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + + # Log with 1 task input + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", + task_inputs={ + "test_input_0": test_input_0, + }, + task_inputs_overwrite=True, + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 1 + assert new_task.task_input_names[0] == "test_input_0" + + # Log with multiple task inputs + test_input_1 = tvm.runtime.ndarray.empty((64, 64)) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + }, + task_inputs_overwrite=True, + ) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 2 + assert new_task.task_input_names[0] == "test_input_0" + assert new_task.task_input_names[1] == "test_input_1" + + # Log with version 0.5 + v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": 
[[0.1], 0, 0.2, 1], "v": "v0.6"}""" + measure_log = auto_scheduler.measure_record.load_record_from_string(v5_log) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_input_names) == 0 + + +if __name__ == "__main__": + test_search_task_add_task_input() + test_search_task_record() + test_recover_measure_input_with_task_input() diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index c43a35924420..16e9db42cba3 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -16,7 +16,7 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, runtime import numpy as np import json from tvm import rpc @@ -94,12 +94,12 @@ def check_sharing(): graph, lib, params = relay.build(func, target="llvm", params=params) mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0)) - mod_shared.load_params(relay.save_param_dict(params)) + mod_shared.load_params(runtime.save_param_dict(params)) num_mods = 10 mods = [graph_runtime.create(graph, lib, tvm.cpu(0)) for _ in range(num_mods)] for mod in mods: - mod.share_params(mod_shared, relay.save_param_dict(params)) + mod.share_params(mod_shared, runtime.save_param_dict(params)) a = np.random.uniform(size=(1, 10)).astype("float32") for mod in mods: diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 51a587242ae3..a34fe4a062cb 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import numpy as np -from tvm import relay +from tvm import relay, runtime from tvm.relay import testing import tvm from tvm.contrib import graph_runtime @@ -314,7 +314,7 @@ def verify_cpu_remove_package_params(obj_format): complied_graph_lib_no_params = complied_graph_lib["remove_params"]() complied_graph_lib_no_params.export_library(path_lib) with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + fo.write(runtime.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.cpu(0) @@ -361,7 +361,7 @@ def verify_gpu_remove_package_params(obj_format): complied_graph_lib_no_params = complied_graph_lib["remove_params"]() complied_graph_lib_no_params.export_library(path_lib) with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + fo.write(runtime.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") ctx = tvm.gpu(0) @@ -409,7 +409,7 @@ def verify_rpc_cpu_remove_package_params(obj_format): complied_graph_lib_no_params.export_library(path_lib) path_params = temp.relpath("deploy_param.params") with open(path_params, "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + fo.write(runtime.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc @@ -462,7 +462,7 @@ def verify_rpc_gpu_remove_package_params(obj_format): complied_graph_lib_no_params.export_library(path_lib) path_params = temp.relpath("deploy_param.params") with open(path_params, "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib.get_params())) + fo.write(runtime.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index a228a640f108..06d7cb4bb7bb 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -19,7 +19,7 @@ import numpy as np from tvm import topi import unittest -from tvm.contrib.nvcc import have_fp16, have_int8 +from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 from tvm.contrib import nvcc import tvm.testing @@ -67,6 +67,53 @@ def check_cuda(dtype, n, lanes): check_cuda("float16", 64, 8) +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_cuda_bf16_vectorize_add(): + if not have_bf16(tvm.gpu(0).compute_version): + print("skip because gpu does not support bf16") + return + num_thread = 8 + + def np_float2np_bf16(arr): + """Convert a numpy array of float to a numpy array + of bf16 in uint16""" + orig = arr.view(" 0.5 + b_np = np.zeros((n,), dtype="int32") + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + func(a, b) + ref = a_np.astype(np.int32) + tvm.testing.assert_allclose(b.asnumpy(), ref) + + +if __name__ == "__main__": + test_bool_load() diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 1555974169fc..255e0cdb1f21 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -110,19 +110,53 @@ def argmax_init(idx_typ, val_typ): def test_auto_inline(): - m = te.var("m") - n = te.var("n") - A = te.placeholder((m, n), name="A") - B = te.placeholder((m, n), 
name="B") - C = te.placeholder((m, n), name="C") - T1 = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="T1") - T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name="T2") - - s = te.create_schedule(T2.op) - tvm.te.schedule.AutoInlineElemWise(s) - s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) + def elemwise(): + m = te.var("m") + n = te.var("n") + A = te.placeholder((m, n), name="A") + B = te.placeholder((m, n), name="B") + C = te.placeholder((m, n), name="C") + T1 = te.compute((m, n), lambda i, j: A(i, j) * B(i, j), name="T1") + T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name="T2") + + return te.create_schedule(T2.op), T1 + + def broadcast(): + m = te.var("m") + n = te.var("n") + A = te.placeholder((1,), name="A") + B = te.placeholder((m, n), name="B") + C = te.placeholder((m, n), name="C") + T1 = te.compute((m, n), lambda i, j: A(0) * B(i, j), name="T1", tag="broadcast") + T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name="T2") + + return te.create_schedule(T2.op), T1 + + def injective(): + m = te.var("m") + n = te.var("n") + A = te.placeholder((m,), name="A") + B = te.placeholder((m, n), name="B") + C = te.placeholder((m, n), name="C") + T1 = te.compute((m, n), lambda i, j: A(i) * B(i, j), name="T1") + T2 = te.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name="T2") + + return te.create_schedule(T2.op), T1 + + def check_auto_inline(schedule_func, auto_inline_func): + s, T1 = schedule_func() + # before auto inline the attach type is AttachType.kGroupRoot + assert s[T1].attach_type == 1 + auto_inline_func(s) + # after auto inline the attach type is AttachType.kInline + assert s[T1].attach_type == 2 + s = s.normalize() + bounds = tvm.te.schedule.InferBound(s) + stmt = tvm.te.schedule.ScheduleOps(s, bounds) + + check_auto_inline(elemwise, tvm.te.schedule.AutoInlineElemWise) + check_auto_inline(broadcast, tvm.te.schedule.AutoInlineBroadcast) + check_auto_inline(injective, tvm.te.schedule.AutoInlineInjective) def test_schedule_const_bound(): diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index 83a5d30bb90d..fdafdb74fc0b 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -18,14 +18,22 @@ from tvm import te -def intrin_vadd(n): +def intrin_vadd(xo, m, n): x = te.placeholder((n,), name="vx") y = te.placeholder((n,), name="vy") - z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") + if m % n == 0: + body = lambda i: x[i] + y[i] + else: + body = lambda i: tvm.tir.Select( + xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype) + ) + z = te.compute(x.shape, body, name="z") def intrin_func(ins, outs): xx, yy = ins zz = outs[0] + # special handle needed to tackle tail loop part when m % n != 0 + # here is tvm.min(n, m - xo * n) return tvm.tir.call_packed("vadd", xx, yy, zz) buffer_params = {"offset_factor": 16} @@ -84,15 +92,17 @@ def intrin_func(ins, outs): def test_tensorize_vadd(): - m = 128 - x = te.placeholder((m,), name="x") - y = te.placeholder((m,), name="y") - z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") + def add(m): + x = te.placeholder((m,), name="x") + y = te.placeholder((m,), name="y") + z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") + return x, y, z - def check(factor): + def check(m, factor): + x, y, z = add(m) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) - vadd = 
intrin_vadd(factor)
+        vadd = intrin_vadd(xo, m, factor)
         s[z].tensorize(xi, vadd)
         s = s.normalize()
         dom_map = tvm.te.schedule.InferBound(s)
@@ -108,7 +118,36 @@ def check(factor):
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [x, y, z])
 
-    check(16)
+    def check_cache_write(m, factor):
+        x, y, z = add(m)
+        s = te.create_schedule(z.op)
+        _, _ = s[z].split(z.op.axis[0], factor=factor)
+
+        z_global = s.cache_write(z, "global")
+        xo, xi = z_global.op.axis
+
+        vadd = intrin_vadd(xo, m, factor)
+        s[z_global].tensorize(xi, vadd)
+        s = s.normalize()
+        dom_map = tvm.te.schedule.InferBound(s)
+        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
+        out_dom, in_dom = finfer(s[z_global], dom_map)
+        # outer loop var will be rebased, so min value is the new loop var and extent is 1
+        assert tvm.ir.structural_equal(out_dom[xo].extent, 1)
+        assert isinstance(out_dom[xo].min, tvm.tir.Var)
+        assert xo.var.name == out_dom[xo].min.name
+
+        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
+        body = fmatch(s[z_global], out_dom, in_dom, vadd)[0]
+        ana = tvm.arith.Analyzer()
+        vars = tvm.runtime.convert({xo.var: out_dom[xo].min})
+        vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars)
+        assert tvm.ir.structural_equal(ana.simplify(body), ana.simplify(vadd_body))
+        stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
+        tvm.lower(s, [x, y, z])
+
+    check(128, 16)
+    check_cache_write(129, 16)
 
 
 def test_tensorize_matmul():
diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py
index 46bc500fc503..8ad5cb63924e 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -405,6 +405,7 @@ def check_target(target, ir):
     check_target("llvm", mandel_ir_cpu)
     check_target("npvtx", mandel_ir_gpu)
     check_target("cuda", mandel_ir_gpu)
+    check_target("vulkan", mandel_ir_gpu)
 
 
 def test_while_binary_search():
@@ -493,6 +494,7 @@ def check_target(target, ir):
     check_target("llvm", searchsorted_ir_cpu)
     check_target("cuda", searchsorted_ir_gpu)
     check_target("nvptx", searchsorted_ir_gpu)
+    check_target("vulkan", searchsorted_ir_gpu)
 
 
 if __name__ == "__main__":
diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh
index ba8018667895..2e06932ba536 100755
--- a/tests/scripts/task_python_microtvm.sh
+++ b/tests/scripts/task_python_microtvm.sh
@@ -18,6 +18,7 @@
 
 set -e
 set -u
+set -x # NOTE(areusch): Adding to diagnose flaky timeouts
 
 source tests/scripts/setup-pytest-env.sh
diff --git a/tutorials/auto_scheduler/ci_logs/sparse_dense.json b/tutorials/auto_scheduler/ci_logs/sparse_dense.json
new file mode 100644
index 000000000000..7c1c100124dc
--- /dev/null
+++ b/tutorials/auto_scheduler/ci_logs/sparse_dense.json
@@ -0,0 +1,2 @@
+# Keep a valid schedule for demonstration. This is used to prevent flaky errors in CI.
+{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"}
diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py
new file mode 100644
index 000000000000..ced416f6c500
--- /dev/null
+++ b/tutorials/auto_scheduler/tune_sparse_x86.py
@@ -0,0 +1,339 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling Sparse Matrix Multiplication on CPU with Custom Sketch Rule
+===========================================================================
+**Author**: `Chengfan Jia `_
+
+This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for
+CPUs.
+
+The auto-scheduler is designed to automatically explore the best-performing schedule for a given
+computation declaration. Sometimes, however, we may want to tune special ops that are not well
+supported by the auto-scheduler's default sketch rules and therefore show poor performance.
+Fortunately, the auto-scheduler currently allows users to provide a custom sketch rule to cover
+these cases.
+
+We use sparse matrix multiplication as an example in this tutorial to demonstrate how to implement
+and plug a custom sketch rule into the auto-scheduler's search policy.
+
+Note that this tutorial will not run on Windows or recent versions of macOS. To
+get it to run, you will need to wrap the body of this tutorial in a :code:`if
+__name__ == "__main__":` block.
+"""
+
+import os
+import itertools
+
+import numpy as np
+import tvm
+from tvm import te, auto_scheduler, runtime, topi
+from tvm.auto_scheduler import _ffi_api
+from tvm.topi.utils import get_const_tuple
+
+import scipy.sparse as sp
+
+######################################################################
+# Define the computation
+# ^^^^^^^^^^^^^^^^^^^^^^
+# To begin with, let us define the computation of a sparse matmul with several relu ops and a
+# bias add. The function should return the list of input/output tensors.
+# From these tensors, the auto-scheduler can get the whole computational graph.
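For reference, the registration pattern in its simplest form looks like the sketch below: declare the inputs as placeholders, describe the computation, and return every input/output tensor. This only restates the convention from the paragraph above; the names `toy_matmul`, `A`, `B`, and `C` are illustrative and not part of this PR, and the snippet assumes the `te` and `auto_scheduler` imports already present in this tutorial.

@auto_scheduler.register_workload
def toy_matmul(M, N, K, dtype):
    # Inputs are declared as placeholders
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)
    # The computation itself is described with te.compute
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    # Returning all input/output tensors lets the auto-scheduler
    # recover the whole computational graph
    return [A, B, C]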
+
+# We use this function to generate a random BSR matrix
+def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
+    import itertools
+
+    Y = np.zeros((M, N), dtype=dtype)
+    assert M % BS_R == 0
+    assert N % BS_C == 0
+    nnz = int(density * M * N)
+    num_blocks = int(nnz / (BS_R * BS_C)) + 1
+    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
+    assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
+    chosen_blocks = candidate_blocks[
+        np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
+    ]
+    for i in range(len(chosen_blocks)):
+        r, c = chosen_blocks[i]
+        Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C)
+    s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
+    assert s.data.shape == (num_blocks, BS_R, BS_C)
+    assert s.indices.shape == (num_blocks,)
+    assert s.indptr.shape == (M // BS_R + 1,)
+    return s
+
+
+@auto_scheduler.register_workload
+def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype):
+    X = te.placeholder(shape=(M, K), dtype=dtype)
+    W_data = te.placeholder(shape=w_data_shape, dtype=dtype)
+    W_indices = te.placeholder(shape=w_indices_shape, dtype="int32")
+    W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32")
+    B = te.placeholder(shape=(M, N), dtype=dtype)
+
+    out = topi.nn.sparse_dense(topi.nn.relu(X), W_data, W_indices, W_indptr)
+    out = te.compute((M, N), lambda i, j: out[i, j] + B[i, j], name="BiasAdd")
+    out = topi.nn.relu(out)
+
+    return [X, W_data, W_indices, W_indptr, B, out]
+
+
+######################################################################
+# Special step for sparse workload
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# During schedule tuning, the auto-scheduler uses random inputs to measure the performance of a
+# generated schedule. However, we cannot feed a sparse op with random arrays, because the
+# "indices" and "indptr" arrays are meaningful for the computation.
+#
+# To solve this problem, we register these arrays as special task input buffers and load them
+# when the program is measured; the name-based matching is sketched right below.
+# See `tvm.auto_scheduler.measure.py` for more details.
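The same mechanism appears in the `test_measure_special_inputs_map_by_name` test added earlier in this diff: the key in `task_inputs` must match the placeholder's `name`, so the measurer substitutes the registered buffer instead of a random array. A minimal sketch mirroring that test (the workload name `lookup` and variable `lookup_task` are illustrative):

@auto_scheduler.register_workload
def lookup():
    X = te.placeholder(shape=[10], dtype="int32")
    # `Index` holds offsets into X, so a random value could read out of bounds
    Index = te.placeholder(shape=[1], dtype="int32", name="Index")
    Y = te.compute((1,), lambda i: X[Index[i]])
    return [X, Index, Y]

# The "Index" key matches the placeholder's name, so measurement uses this
# buffer instead of random data for that argument.
lookup_task = auto_scheduler.SearchTask(
    func=lookup,
    target="llvm",
    task_inputs={"Index": tvm.nd.array(np.array([5], dtype="int32"))},
)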
+
+# Define the basic shapes of this sparse computation
+M = K = N = 512
+BS_R = 16
+BS_C = 1
+density = 0.6
+
+# Generate the test data with numpy
+X_np = np.random.randn(M, K).astype("float32")
+X_np = np.maximum(np.zeros((M, K), dtype="float32"), X_np)  # Relu
+W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
+W_np = W_sp_np.todense()
+Y_np = X_np @ W_np.T  # Perform the matrix multiplication
+B_np = np.random.randn(M, N).astype("float32")
+Y_np = Y_np + B_np  # Bias add
+Y_np = np.maximum(np.zeros((M, N), dtype="float32"), Y_np)  # Relu
+
+######################################################################
+# Create the search task
+# ^^^^^^^^^^^^^^^^^^^^^^
+# We then create a search task with M=N=K=512 and dtype="float32".
+# If your machine supports AVX instructions, you can
+#
+# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2
+# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512
+
+target = tvm.target.Target("llvm")
+
+# Register the sparse data to task inputs
+prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density)
+task = tvm.auto_scheduler.SearchTask(
+    func=sparse_dense,
+    args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"),
+    target=target,
+    task_inputs={
+        prefix + "W_data": runtime.ndarray.array(W_sp_np.data),
+        prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices),
+        prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr),
+    },
+    task_inputs_save_to_file=True,
+)
+
+# Inspect the computational graph
+print("Computational DAG:")
+print(task.compute_dag)
+
+######################################################################
+# Write the custom sketch for sparse dense op
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Before tuning, we need to define a CustomSketchRule for the sparse dense op.
+#
+# A CustomSketchRule consists of two parts: the condition function and the apply function.
+#
+# - condition function: describes when to apply this sketch rule. For example, we can apply the
+#   rule only to the sparse ops by matching their name and tag.
+# - apply function: describes how to generate the initial sketch. You can implement it using the
+#   loop state APIs provided by the auto-scheduler. The bare protocol is sketched right below,
+#   followed by the full implementation for the sparse dense op.
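In miniature, the rule protocol works as follows; the tag "my_special_op" and both function names are placeholders for this sketch only, while the `State` wrapper and the `PreloadCustomSketchRule` return values are the same ones used by the real implementation that follows.

def my_condition_func(search_policy, state, stage_id):
    # Wrap the raw state so that stage ops can be inspected
    s = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    if s.stages[stage_id].op.tag == "my_special_op":
        # Apply this rule and skip the remaining sketch rules for this stage
        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
    return auto_scheduler.PreloadCustomSketchRule.PASS

def my_apply_func(search_policy, state, stage_id):
    s = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    # ... transform `s` here with the loop state APIs (split, reorder, compute_at, ...)
    # Each result is a [state, next_stage_id] pair; next_stage_id tells the search
    # policy from which stage to continue sketch generation.
    return [[s.state_object, stage_id - 1]]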
+
+
+def meet_condition_func(search_policy, state, stage_id):
+    state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+    if state.stages[stage_id].op.tag in [
+        "sparse_dense_sp_rhs_bsrmm",
+        "sparse_dense_sp_rhs_bsrmm_block",
+    ]:
+        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+    else:
+        return auto_scheduler.PreloadCustomSketchRule.PASS
+
+
+def apply_func(search_policy, state, stage_id):
+    ret = []
+    s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+    if s0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block":
+        return [s0.state_object, stage_id - 1]
+
+    sparse_dense = s0.stages[stage_id].op
+    sparse_dense_block = s0.stages[stage_id - 1].op
+    assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm"
+    assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block"
+
+    # Set the default consumer of the compute block
+    consumer = sparse_dense
+
+    # If sparse dense has a single elementwise consumer,
+    # we can inline the sparse_dense output stage into it
+    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
+        search_policy.search_task, s0.state_object, stage_id
+    )
+    if len(consumers) == 1:
+        consumer_id = int(consumers.items()[0][0])
+        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
+            search_policy.search_task, s0.state_object, stage_id, consumer_id
+        ):
+            consumer = s0.stages[consumer_id].op
+            s0.compute_inline(sparse_dense)
+
+    i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters
+    m, n = s0[consumer].iters
+    i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None])
+    m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1)
+    j0, j1 = s0.split(sparse_dense_block, nb_j, [None])
+    n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1)
+    s0.reorder(sparse_dense_block, [i0, j0, i1, j1, row_offset, i2, j, c])
+    s0.reorder(consumer, [m0, n0, m1, n1])
+    s0.compute_at(sparse_dense_block, consumer, n0)
+
+    ret.append([s0.state_object, stage_id - 2])
+
+    return ret
+
+
+######################################################################
+# Next, we set parameters for the auto-scheduler with the custom sketch plugged in.
+#
+# * :code:`num_measure_trials` is the number of measurement trials we can use during the search.
+#   We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a
+#   good value for the search to converge. You can do more trials according to your time budget.
+# * In addition, we use :code:`RecordToFile` to dump measurement records into a file
+#   `sparse_dense.json`.
+#   The measurement records can be used to query the history best, resume the search,
+#   and do more analyses later.
+# * See :any:`auto_scheduler.TuningOptions` for more parameters.
+# * Here, we need to create an :code:`auto_scheduler.SketchPolicy` object and add the custom
+#   sketch rule to its `init_search_callbacks`.
+
+log_file = "sparse_dense.json"
+tune_option = auto_scheduler.TuningOptions(
+    num_measure_trials=10,
+    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+    verbose=2,
+)
+
+search_policy = auto_scheduler.SketchPolicy(
+    task,
+    program_cost_model=auto_scheduler.XGBModel(),
+    init_search_callbacks=[
+        auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense")
+    ],
+)
+
+######################################################################
+# Run the search
+# ^^^^^^^^^^^^^^
+# Now all inputs are ready.
+# We can kick off the search and let the auto-scheduler do its magic.
+# After some measurement trials, we can load the best schedule from the log
+# file and apply it.
+
+# Run auto-tuning (search)
+# Notice: We do not run a full tuning on our webpage server since it takes too long;
+# CI instead relies on the pre-tuned record in `ci_logs/sparse_dense.json` to prevent
+# flaky errors. The line below runs the search when you try this tutorial yourself.
+task.tune(tune_option, search_policy)
+
+# Apply the best schedule
+sch, args = task.apply_best(log_file)
+
+######################################################################
+# We can lower the schedule to see the IR after auto-scheduling.
+# The auto-scheduler correctly performs optimizations including multi-level tiling,
+# layout transformation, parallelization, vectorization, unrolling, and operator fusion.
+
+print("Lowered TIR:")
+print(tvm.lower(sch, args, simple_mode=True))
+
+######################################################################
+# Check correctness and evaluate performance
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# We build the binary and check its correctness and performance.
+
+func = tvm.build(sch, args, target)
+
+ctx = tvm.cpu()
+
+X_tvm = tvm.nd.array(X_np, ctx=ctx)
+W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx)
+W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx)
+W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx)
+B_tvm = tvm.nd.array(B_np, ctx=ctx)
+Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx)
+
+func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm)
+
+# Check results
+tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4)
+
+# Evaluate execution time.
+evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
+print(
+    "Execution time of this operator: %.3f ms"
+    % (
+        np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results)
+        * 1000
+    )
+)
+
+######################################################################
+# .. note:: Tuning result example
+#
+#   ..
code-block:: c +# +# ---------------------------------------------------------------------- +# Lowered TIR: +# primfn(placeholder_5: handle, placeholder_6: handle, placeholder_7: handle, placeholder_8: handle, placeholder_9: handle, compute_1: handle) -> () +# attr = {"global_symbol": "main", "tir.noalias": True} +# buffers = {placeholder_2: Buffer(placeholder_10: Pointer(float32), float32, [9831, 16, 1], []), +# placeholder_4: Buffer(placeholder_11: Pointer(int32), int32, [33], []), +# placeholder_3: Buffer(placeholder_12: Pointer(float32), float32, [512, 512], []), +# compute: Buffer(compute_2: Pointer(float32), float32, [512, 512], []), +# placeholder_1: Buffer(placeholder_13: Pointer(float32), float32, [512, 512], []), +# placeholder: Buffer(placeholder_14: Pointer(int32), int32, [9831], [])} +# buffer_map = {placeholder_7: placeholder, placeholder_9: placeholder_1, placeholder_6: placeholder_2, compute_1: compute, placeholder_5: placeholder_3, placeholder_8: placeholder_4} { +# for (i0.outer.i1.outer.fused: int32, 0, 1024) "parallel" { +# attr [compute_3: Pointer(float32)] "storage_scope" = "global"; +# allocate(compute_3, float32, [256]) { +# for (nb_j.inner: int32, 0, 2) { +# for (i.inner.init: int32, 0, 8) { +# for (j.init: int32, 0, 16) { +# compute_3[(((i.inner.init*32) + (nb_j.inner*16)) + j.init)] = 0f32 +# } +# } +# for (elem_idx: int32, 0, ((int32*)placeholder_11[(((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner) + 1)] - (int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)])) { +# for (i.inner: int32, 0, 8) { +# for (j: int32, 0, 16) { +# compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] = ((float32*)compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] + ((float32*)placeholder_10[((((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)]*16) + (elem_idx*16)) + j)]*max((float32*)placeholder_12[(((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i.inner*512)) + (int32*)placeholder_14[((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)] + elem_idx)])], 0f32))) +# } +# } +# } +# } +# for (i0.inner: int32, 0, 8) { +# compute_2[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)] = max(((float32x32*)compute_3[ramp((i0.inner*32), 1, 32)] + (float32x32*)placeholder_13[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)]), broadcast(0f32, 32)) +# } +# } +# } +# } diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py index 9641fb8fd14c..98004a93c74f 100644 --- a/tutorials/frontend/deploy_sparse.py +++ b/tutorials/frontend/deploy_sparse.py @@ -81,7 +81,7 @@ import itertools import numpy as np import tensorflow as tf -from tvm import relay +from tvm import relay, runtime from tvm.contrib import graph_runtime from tvm.relay import data_dep_optimization as ddo from tensorflow.python.framework.convert_to_constants import ( @@ -196,7 +196,7 @@ def import_graphdef( with open(os.path.join(abs_path, relay_file), "w") as fo: fo.write(tvm.ir.save_json(mod)) with open(os.path.join(abs_path, relay_params), "wb") as fo: - fo.write(relay.save_param_dict(params)) + fo.write(runtime.save_param_dict(params)) return mod, params, shape_dict
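A recurring change throughout this diff is the relocation of `save_param_dict` from `tvm.relay` to `tvm.runtime`. The following is a minimal round-trip sketch of the relocated API, assuming `load_param_dict` is exported from `tvm.runtime` alongside it; the `weight` key and the array shape are illustrative only.

import numpy as np
import tvm
from tvm import runtime

# Serialize a dict of NDArrays to bytes, then restore it
params = {"weight": tvm.nd.array(np.zeros((2, 2), dtype="float32"))}
blob = runtime.save_param_dict(params)
restored = runtime.load_param_dict(blob)
assert restored["weight"].asnumpy().shape == (2, 2)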