diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 0e8bc36bfe89..3f5fb2503b56 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -24,7 +24,7 @@ utils = load('ci/Jenkinsfile_utils.groovy') // mxnet libraries mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy3/*.so' +mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' // Python wheels mx_pip = 'build/*.whl' @@ -32,15 +32,15 @@ mx_pip = 'build/*.whl' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy3/*.so' +mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so' -mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, build/tests/cpp/mxnet_unit_tests' -mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so' +mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' +mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so, build/tests/cpp/mxnet_unit_tests' +mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*' // Python unittest for CPU diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h new file mode 100644 index 000000000000..168887a42f10 --- /dev/null +++ b/include/mxnet/api_registry.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file api_registry.h + * \brief This file contains utilities related to + * the MXNet's global function registry. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_API_REGISTRY_H_ +#define MXNET_API_REGISTRY_H_ + +#include +#include +#include "runtime/registry.h" + +namespace mxnet { +/*! + * \brief Register an API function globally. + * It simply redirects to MXNET_REGISTER_GLOBAL + * + * \code + * MXNET_REGISTER_API(MyPrint) + * .set_body([](MXNetArgs args, MXNetRetValue* rv) { + * // my code. + * }); + * \endcode + */ +#define MXNET_REGISTER_API(OpName) MXNET_REGISTER_GLOBAL(OpName) + +} // namespace mxnet +#endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/expr_operator.h b/include/mxnet/expr_operator.h new file mode 100644 index 000000000000..c28761c0d1b9 --- /dev/null +++ b/include/mxnet/expr_operator.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file expr_operator.h + * \brief Common operators defined for Expr. + * + * \note Most of the operator defined here perform simple constant folding + * when the type is int32 or int64 for simplifying the index expressions. + */ +// Acknowledgement: This file originates from incubator-tvm +// Acknowledgement: Most operator APIs originate from Halide. +#ifndef MXNET_EXPR_OPERATOR_H_ +#define MXNET_EXPR_OPERATOR_H_ + +#include + +namespace mxnet { + +template +inline PrimExpr MakeConstScalar(MXNetDataType t, ValueType value) { + if (t.is_int()) return IntImm(t, static_cast(value)); + if (t.is_float()) return FloatImm(t, static_cast(value)); + // customized type and uint is not supported for MXNet for now + LOG(FATAL) << "cannot make const for type " << t; + return PrimExpr(); +} + + +template +inline PrimExpr make_const(MXNetDataType t, ValueType value) { + if (t.lanes() == 1) { + return MakeConstScalar(t, value); + } else { + LOG(FATAL) << "MXNetDataType::lanes() != 1 is not supported "; + } + return PrimExpr(); +} + +} // namespace mxnet + +#endif // MXNET_EXPR_OPERATOR_H_ diff --git a/include/mxnet/ir/expr.h b/include/mxnet/ir/expr.h new file mode 100644 index 000000000000..b9483c74320a --- /dev/null +++ b/include/mxnet/ir/expr.h @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file expr.h + * \brief Base expr nodes in MXNet. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_IR_EXPR_H_ +#define MXNET_IR_EXPR_H_ + +#include +#include +#include +#include +#include + +namespace mxnet { + +/*! + * \brief Base type of all the expressions. + * \sa Expr + */ +class BaseExprNode : public Object { + public: + static constexpr const char* _type_key = "Expr"; + MXNET_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa BaseExprNode + */ +class BaseExpr : public ObjectRef { + public: + /*! \brief Cosntructor */ + BaseExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit BaseExpr(runtime::ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief The container type. */ + using ContainerType = BaseExprNode; +}; + +/*! + * \brief Base node of all primitive expressions. + * + * A primitive expression deals with low-level + * POD data types and handles without + * doing life-cycle management for objects. + * + * PrimExpr is used in the low-level code + * optimizations and integer analysis. + * + * \sa PrimExpr + */ +class PrimExprNode : public BaseExprNode { + public: + /*! + * \brief The runtime data type of the primitive expression. + * + * MXNetDataType(dtype) provides coarse grained type information + * during compile time and runtime. It is eagerly built in + * PrimExpr expression construction and can be used for + * quick type checking. + * + * dtype is sufficient to decide the Type of the PrimExpr + * when it corresponds to POD value types such as i32. + * + * When dtype is MXNetDataType::Handle(), the expression could corresponds to + * a more fine-grained Type, and we can get the type by running lazy type inference. + */ + MXNetDataType dtype; + + static constexpr const char* _type_key = "PrimExpr"; + MXNET_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); +}; + +/*! + * \brief Reference to PrimExprNode. + * \sa PrimExprNode + */ +class PrimExpr : public BaseExpr { + public: + /*! \brief Cosntructor */ + PrimExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit PrimExpr(runtime::ObjectPtr ptr) : BaseExpr(ptr) {} + /*! + * \brief construct from integer. + * \param value The value to be constructed. + */ + MXNET_DLL PrimExpr(int32_t value); // NOLINT(*) + /*! + * \brief construct from float. + * \param value The value to be constructed. + */ + MXNET_DLL PrimExpr(float value); // NOLINT(*) + /*! + * \brief construct from string. + * \param str The value to be constructed. + */ + MXNET_DLL PrimExpr(std::string str); // NOLINT(*) + + /*! \return the data type of this expression. */ + MXNetDataType dtype() const { + return static_cast(get())->dtype; + } + /*! \brief The container type. */ + using ContainerType = PrimExprNode; +}; + +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + static constexpr const char* _type_key = "IntImm"; + MXNET_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + IntImm() {} + /*! + * \brief constructor from node. + */ + explicit IntImm(runtime::ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + MXNET_DLL IntImm(MXNetDataType dtype, int64_t value); + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = IntImmNode; +}; + +/*! + * \brief Constant floating point literals in the program. + * \sa FloatImm + */ +class FloatImmNode : public PrimExprNode { + public: + /*! \brief The constant value content. */ + double value; + + static constexpr const char* _type_key = "FloatImm"; + MXNET_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to FloatImmNode. + * + * \sa FloatImmNode + */ +class FloatImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + FloatImm() {} + /*! + * \brief constructor from node. + */ + explicit FloatImm(runtime::ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + MXNET_DLL FloatImm(MXNetDataType dtype, double value); + /*! + * \brief Get pointer to the container. + * \return The pointer. + */ + const FloatImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = FloatImmNode; +}; + +} // namespace mxnet +#endif // MXNET_IR_EXPR_H_ diff --git a/include/mxnet/node/container.h b/include/mxnet/node/container.h new file mode 100644 index 000000000000..27b9853a74b7 --- /dev/null +++ b/include/mxnet/node/container.h @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file container.h + * \brief Array container + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_NODE_CONTAINER_H_ +#define MXNET_NODE_CONTAINER_H_ + +#include + +#include +#include +#include +#include +#include +#include + +namespace mxnet { + +/*! \brief array node content in array */ +class ArrayNode : public Object { + public: + /*! \brief the data content */ + std::vector data; + + static constexpr const char* _type_key = "Array"; + MXNET_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + inline IterAdapter& operator++() { + ++iter_; + return *this; + } + inline IterAdapter operator+(difference_type offset) const { + return IterAdapter(iter_ + offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type + inline operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + inline bool operator==(IterAdapter other) const { + return iter_ == other.iter_; + } + inline bool operator!=(IterAdapter other) const { + return !(*this == other); + } + inline const value_type operator*() const { + return Converter::convert(*iter_); + } + + private: + TIter iter_; +}; + +/*! + * \brief Array container of NodeRef in DSL graph. + * Array implements copy on write semantics, which means array is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam T The content NodeRef type. + */ +template::value>::type > +class Array : public ObjectRef { + public: + /*! + * \brief default constructor + */ + Array() { + data_ = make_object(); + } + /*! + * \brief move constructor + * \param other source + */ + Array(Array && other) { // NOLINT(*) + data_ = std::move(other.data_); + } + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array &other) { // NOLINT(*) + data_ = std::move(other.data_); + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(runtime::ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType begin, IterType end) { + assign(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Array(std::initializer_list init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(size_t n, const T& val) { + auto tmp_node = make_object(); + for (size_t i = 0; i < n; ++i) { + tmp_node->data.push_back(val); + } + data_ = std::move(tmp_node); + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array && other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array & other) { + data_ = other.data_; + return *this; + } + /*! + * \brief reset the array to content from iterator. + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + void assign(IterType begin, IterType end) { + auto n = make_object(); + for (IterType it = begin; it != end; ++it) { + n->data.push_back(T(*it)); + } + data_ = std::move(n); + } + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + inline const T operator[](size_t i) const { + return DowncastNoCheck( + static_cast(data_.get())->data[i]); + } + /*! \return The size of the array */ + inline size_t size() const { + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); + } + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + inline ArrayNode* CopyOnWrite() { + if (data_.get() == nullptr || !data_.unique()) { + runtime::ObjectPtr n = make_object(); + n->data = static_cast(data_.get())->data; + runtime::ObjectPtr(std::move(n)).swap(data_); + } + return static_cast(data_.get()); + } + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + inline void push_back(const T& item) { + ArrayNode* n = this->CopyOnWrite(); + n->data.push_back(item); + } + /*! + * \brief Resize the array. + * \param size The new size. + */ + inline void resize(size_t size) { + ArrayNode* n = this->CopyOnWrite(); + n->data.resize(size); + } + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + inline void Set(size_t i, const T& value) { + ArrayNode* n = this->CopyOnWrite(); + n->data[i] = value; + } + /*! \return whether array is empty */ + inline bool empty() const { + return size() == 0; + } + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + inline void MutateByApply(F fmutate) { + ArrayNode* ptr = static_cast(data_.get()); + if (ptr == nullptr) return; + if (data_.unique()) { + // Copy on write optimization. + // Perform inplace update because this is an unique copy. + for (size_t i = 0; i < ptr->data.size(); ++i) { + // It is important to use move here + // to make prevent the element's ref count from increasing + // so fmutate itself can perform copy-on-write optimization + T old_elem = DowncastNoCheck(std::move(ptr->data[i])); + T new_elem = fmutate(std::move(old_elem)); + ptr->data[i] = std::move(new_elem); + } + } else { + // lazily trigger copy if there is element change. + runtime::ObjectPtr copy; + for (size_t i = 0; i < ptr->data.size(); ++i) { + T old_elem = DowncastNoCheck(ptr->data[i]); + T new_elem = fmutate(old_elem); + if (!new_elem.same_as(ptr->data[i])) { + // copy the old array + if (copy == nullptr) { + copy = runtime::make_object(*ptr); + } + copy->data[i] = std::move(new_elem); + } + } + // replace the data with the new copy. + if (copy != nullptr) { + data_ = std::move(copy); + } + } + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + struct ValueConverter { + using ResultType = T; + static inline T convert(const ObjectRef& n) { + return DowncastNoCheck(n); + } + }; + using iterator = IterAdapter::const_iterator>; + + using reverse_iterator = IterAdapter< + ValueConverter, + std::vector::const_reverse_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast(data_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast(data_.get())->data.end()); + } + /*! \return rbegin iterator */ + inline reverse_iterator rbegin() const { + return reverse_iterator(static_cast(data_.get())->data.rbegin()); + } + /*! \return rend iterator */ + inline reverse_iterator rend() const { + return reverse_iterator(static_cast(data_.get())->data.rend()); + } +}; + +} // namespace mxnet +#endif // MXNET_NODE_CONTAINER_H_ diff --git a/include/mxnet/node/node.h b/include/mxnet/node/node.h new file mode 100644 index 000000000000..76bf0e67fad0 --- /dev/null +++ b/include/mxnet/node/node.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file node.h + * \brief Definitions and helper macros for IR/AST nodes. + * + * The node folder contains base utilities for IR/AST nodes, + * invariant of which specific language dialect. + * + * We implement AST/IR nodes as sub-classes of runtime::Object. + * The base class Node is just an alias of runtime::Object. + * + * Besides the runtime type checking provided by Object, + * node folder contains additional functionalities such as + * reflection and serialization, which are important features + * for building a compiler infra. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_NODE_NODE_H_ +#define MXNET_NODE_NODE_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace mxnet { + +using runtime::TypeIndex; +using runtime::Object; +// We strictly restrict ObjectPtr to ::mxnet::runtime +// as it may conflict with ::nnvm::ObjectPtr +// using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::GetRef; +using runtime::Downcast; +using runtime::ObjectHash; +using runtime::ObjectEqual; +using runtime::make_object; + +} // namespace mxnet + +#endif // MXNET_NODE_NODE_H_ diff --git a/include/mxnet/runtime/c_runtime_api.h b/include/mxnet/runtime/c_runtime_api.h new file mode 100644 index 000000000000..208a64326ac4 --- /dev/null +++ b/include/mxnet/runtime/c_runtime_api.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file c_runtime_api.h + * \brief MXNet runtime library. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_C_RUNTIME_API_H_ +#define MXNET_RUNTIME_C_RUNTIME_API_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif +#include +#include +#include + + +/*! + * \brief The type code in MXNetType + * \note MXNetType is used in two places. + */ +typedef enum { + // The type code of other types are compatible with DLPack. + // The next few fields are extension types + // that is used by MXNet API calls. + kHandle = 3U, + kNull = 4U, + kMXNetType = 5U, + kMXNetContext = 6U, + kArrayHandle = 7U, + kObjectHandle = 8U, + kModuleHandle = 9U, + kFuncHandle = 10U, + kStr = 11U, + kBytes = 12U, + kNDArrayContainer = 13U, + kNDArrayHandle = 14U, + // Extension codes for other frameworks to integrate MXNet PackedFunc. + // To make sure each framework's id do not conflict, use first and + // last sections to mark ranges. + // Open an issue at the repo if you need a section of code. + kExtBegin = 15U, + kNNVMFirst = 16U, + kNNVMLast = 20U, + // The following section of code is used for non-reserved types. + kExtReserveEnd = 64U, + kExtEnd = 128U, + // The rest of the space is used for custom, user-supplied datatypes + kCustomBegin = 129U, +} MXNetTypeCode; + +/*! + * \brief Union type of values + * being passed through API and function calls. + */ +typedef union { + int64_t v_int64; + double v_float64; + void* v_handle; + const char* v_str; + DLDataType v_type; +} MXNetValue; + +/*! + * \brief Byte array type used to pass in byte array + * When kBytes is used as data type. + */ +typedef struct { + const char* data; + size_t size; +} MXNetByteArray; + +/*! \brief Handle to packed function handle. */ +typedef void* MXNetFunctionHandle; +/*! \brief Handle to Object. */ +typedef void* MXNetObjectHandle; + +/*! + * \brief Free the function when it is no longer needed. + * \param func The function handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNetFuncFree(MXNetFunctionHandle func); + +/*! + * \brief Call a Packed MXNet Function. + * + * \param func node handle of the function. + * \param arg_values The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * + * \param ret_val The return value. + * \param ret_type_code the type code of return value. + * + * \return 0 when success, -1 when failure happens + * \note MXNet calls always exchanges with type bits=64, lanes=1 + * + * \note API calls always exchanges with type bits=64, lanes=1 + * If API call returns container handles (e.g. FunctionHandle) + * these handles should be managed by the front-end. + * The front-end need to call free function (e.g. MXNetFuncFree) + * to free these handles. + */ +MXNET_DLL int MXNetFuncCall(MXNetFunctionHandle func, + MXNetValue* arg_values, + int* type_codes, + int num_args, + MXNetValue* ret_val, + int* ret_type_code); + +/*! + * \brief Get a global function. + * + * \param name The name of the function. + * \param out the result function pointer, NULL if it does not exist. + * + * \note The function handle of global function is managed by MXNet runtime, + * So MXNetFuncFree is should not be called when it get deleted. + */ +MXNET_DLL int MXNetFuncGetGlobal(const char* name, MXNetFunctionHandle* out); + +/*! + * \brief List all the globally registered function name + * \param out_size The number of functions + * \param out_array The array of function names. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNetFuncListGlobalNames(int* out_size, + const char*** out_array); + +/*! + * \brief Free the object. + * + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNetObjectFree(MXNetObjectHandle obj); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // MXNET_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/mxnet/runtime/container.h b/include/mxnet/runtime/container.h new file mode 100644 index 000000000000..3dd7e0fc9c79 --- /dev/null +++ b/include/mxnet/runtime/container.h @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file container.h + * \brief Common POD(plain old data) container types. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_CONTAINER_H_ +#define MXNET_RUNTIME_CONTAINER_H_ +#include +#include +#include + +#include +#include +#include +#include + +namespace mxnet { +namespace runtime { + +class ADTBuilder; +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. + * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { + public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if (!(std::is_standard_layout::value && + std::is_trivial::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + friend class ADTBuilder; + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + private: + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! \brief An object representing a structure or enumeration. */ +class ADTObj : public Object, public InplaceArrayBase { + public: + /*! \brief The tag representing the constructor used. */ + uint32_t tag; + /*! \brief Number of fields in the ADT object. */ + uint32_t size{0}; + // The fields of the structure follows directly in memory. + + static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT; + static constexpr const char* _type_key = "MXNet.ADT"; + MXNET_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); + + private: + /*! + * \return The number of elements in the array. + */ + size_t GetSize() const { return size; } + + /*! + * \brief Initialize the elements in the array. + * + * \tparam Iterator Iterator type of the array. + * \param begin The begin iterator. + * \param end The end iterator. + */ + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + // Only increment size after the initialization succeeds + this->size++; + } + } + + friend class ADT; + friend InplaceArrayBase; +}; + +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { + public: + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param fields The fields of the ADT object. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::vector fields) + : ADT(tag, fields.begin(), fields.end()){}; + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param begin The begin iterator to the start of the fields array. + * \param end The end iterator to the end of the fields array. + * \return The constructed ADT object reference. + */ + template + ADT(uint32_t tag, Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + auto ptr = make_inplace_array_object(num_elems); + ptr->tag = tag; + ptr->Init(begin, end); + data_ = std::move(ptr); + } + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param init The initializer list of fields. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::initializer_list init) + : ADT(tag, init.begin(), init.end()){}; + + /*! + * \brief Access element at index. + * + * \param idx The array index + * \return const ObjectRef + */ + const ObjectRef& operator[](size_t idx) const { + return operator->()->operator[](idx); + } + + /*! + * \brief Return the ADT tag. + */ + size_t tag() const { return operator->()->tag; } + + /*! + * \brief Return the number of fields. + */ + size_t size() const { return operator->()->size; } + + /*! + * \brief Construct a tuple object. + * + * \tparam Args Type params of tuple feilds. + * \param args Tuple fields. + * \return ADT The tuple object reference. + */ + template + static ADT Tuple(Args&&... args) { + return ADT(0, std::forward(args)...); + } + + MXNET_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); +}; + +} // namespace runtime +} // namespace mxnet + +#endif // MXNET_RUNTIME_CONTAINER_H_ diff --git a/include/mxnet/runtime/data_type.h b/include/mxnet/runtime/data_type.h new file mode 100644 index 000000000000..01d776322e68 --- /dev/null +++ b/include/mxnet/runtime/data_type.h @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file data_type.h + * \brief Primitive runtime data type. + */ +// Acknowledgement: This file originates from incubator-tvm +// Acknowledgement: MXNetDataType structure design originates from Halide. +#ifndef MXNET_RUNTIME_DATA_TYPE_H_ +#define MXNET_RUNTIME_DATA_TYPE_H_ + +#include +#include +#include + + +namespace mxnet { +namespace runtime { +/*! + * \brief Runtime primitive data type. + * + * This class is a thin wrapper of DLDataType. + * We also make use of MXNetDataType in compiler to store quick hint + */ +class MXNetDataType { + public: + /*! \brief Type code for the MXNetDataType. */ + enum TypeCode { + kInt = kDLInt, + kUInt = kDLUInt, + kFloat = kDLFloat, + kHandle = MXNetTypeCode::kHandle, + }; + /*! \brief default constructor */ + MXNetDataType() {} + /*! + * \brief Constructor + * \param dtype The DLDataType + */ + explicit MXNetDataType(DLDataType dtype) + : data_(dtype) {} + /*! + * \brief Constructor + * \param code The type code. + * \param bits The number of bits in the type. + * \param lanes The number of lanes. + */ + MXNetDataType(int code, int bits, int lanes) { + data_.code = static_cast(code); + data_.bits = static_cast(bits); + data_.lanes = static_cast(lanes); + } + /*! \return The type code. */ + int code() const { + return static_cast(data_.code); + } + /*! \return number of bits in the data. */ + int bits() const { + return static_cast(data_.bits); + } + /*! \return number of bytes to store each scalar. */ + int bytes() const { + return (bits() + 7) / 8; + } + /*! \return number of lanes in the data. */ + int lanes() const { + return static_cast(data_.lanes); + } + /*! \return whether type is a scalar type. */ + bool is_scalar() const { + return lanes() == 1; + } + /*! \return whether type is a scalar type. */ + bool is_bool() const { + return code() == MXNetDataType::kUInt && bits() == 1; + } + /*! \return whether type is a float type. */ + bool is_float() const { + return code() == MXNetDataType::kFloat; + } + /*! \return whether type is an int type. */ + bool is_int() const { + return code() == MXNetDataType::kInt; + } + /*! \return whether type is an uint type. */ + bool is_uint() const { + return code() == MXNetDataType::kUInt; + } + /*! \return whether type is a handle type. */ + bool is_handle() const { + return code() == MXNetDataType::kHandle; + } + /*! \return whether type is a vector type. */ + bool is_vector() const { + return lanes() > 1; + } + /*! + * \brief Create a new data type by change lanes to a specified value. + * \param lanes The target number of lanes. + * \return the result type. + */ + MXNetDataType with_lanes(int lanes) const { + return MXNetDataType(data_.code, data_.bits, lanes); + } + /*! + * \brief Create a new data type by change bits to a specified value. + * \param bits The target number of bits. + * \return the result type. + */ + MXNetDataType with_bits(int bits) const { + return MXNetDataType(data_.code, bits, data_.lanes); + } + /*! + * \brief Get the scalar version of the type. + * \return the result type. + */ + MXNetDataType element_of() const { + return with_lanes(1); + } + /*! + * \brief Equal comparator. + * \param other The data type to compre against. + * \return The comparison resilt. + */ + bool operator==(const MXNetDataType& other) const { + return + data_.code == other.data_.code && + data_.bits == other.data_.bits && + data_.lanes == other.data_.lanes; + } + /*! + * \brief NotEqual comparator. + * \param other The data type to compre against. + * \return The comparison resilt. + */ + bool operator!=(const MXNetDataType& other) const { + return !operator==(other); + } + /*! + * \brief Converter to DLDataType + * \return the result. + */ + operator DLDataType () const { + return data_; + } + + /*! + * \brief Construct an int type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes. + * \return The constructed data type. + */ + static MXNetDataType Int(int bits, int lanes = 1) { + return MXNetDataType(kDLInt, bits, lanes); + } + /*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static MXNetDataType UInt(int bits, int lanes = 1) { + return MXNetDataType(kDLUInt, bits, lanes); + } + /*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static MXNetDataType Float(int bits, int lanes = 1) { + return MXNetDataType(kDLFloat, bits, lanes); + } + /*! + * \brief Construct a bool type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static MXNetDataType Bool(int lanes = 1) { + return MXNetDataType::UInt(1, lanes); + } + /*! + * \brief Construct a handle type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static MXNetDataType Handle(int bits = 64, int lanes = 1) { + return MXNetDataType(kHandle, bits, lanes); + } + + private: + DLDataType data_; +}; + +} // namespace runtime + +using MXNetDataType = runtime::MXNetDataType; + +} // namespace mxnet +#endif // MXNET_RUNTIME_DATA_TYPE_H_ diff --git a/include/mxnet/runtime/ffi_helper.h b/include/mxnet/runtime/ffi_helper.h new file mode 100644 index 000000000000..b539524dfd05 --- /dev/null +++ b/include/mxnet/runtime/ffi_helper.h @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ffi_helper + * \brief Helper class to support additional objects in FFI. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_FFI_HELPER_H_ +#define MXNET_RUNTIME_FFI_HELPER_H_ + +#include +#include +#include +#include + +namespace mxnet { +namespace runtime { + +/*! \brief Ellipsis. */ +class EllipsisObj : public Object { + public: + static constexpr const uint32_t _type_index = TypeIndex::kEllipsis; + static constexpr const char* _type_key = "MXNet.Ellipsis"; + MXNET_DECLARE_FINAL_OBJECT_INFO(EllipsisObj, Object); +}; + +inline ObjectRef CreateEllipsis() { + return ObjectRef(make_object()); +} + +/*! \brief Slice. */ +class SliceObj : public Object { + public: + int64_t start; + int64_t stop; + int64_t step; + + static constexpr const uint32_t _type_index = TypeIndex::kSlice; + static constexpr const char* _type_key = "MXNet.Slice"; + MXNET_DECLARE_FINAL_OBJECT_INFO(SliceObj, Object); +}; + +class Slice : public ObjectRef { + public: + explicit inline Slice(int64_t start, int64_t stop, int64_t step, + ObjectPtr&& data = make_object()) { + data->start = start; + data->stop = stop; + data->step = step; + data_ = std::move(data); + } + + explicit inline Slice(int64_t stop) + : Slice(kNoneValue, stop, kNoneValue) { + } + + // constant to represent None. + static constexpr int64_t kNoneValue = std::numeric_limits::min(); + + MXNET_DEFINE_OBJECT_REF_METHODS(Slice, ObjectRef, SliceObj); +}; + +int64_t inline SliceNoneValue() { + return Slice::kNoneValue; +} + +class IntegerObj: public Object { + public: + int64_t value; + static constexpr const uint32_t _type_index = TypeIndex::kInteger; + static constexpr const char* _type_key = "MXNet.Integer"; + MXNET_DECLARE_FINAL_OBJECT_INFO(IntegerObj, Object); +}; + +class Integer: public ObjectRef { + public: + explicit Integer(int64_t value, + ObjectPtr&& data = make_object()) { + data->value = value; + data_ = std::move(data); + } + MXNET_DEFINE_OBJECT_REF_METHODS(Integer, ObjectRef, IntegerObj); +}; + +// Helper functions for fast FFI implementations +/*! + * \brief A builder class that helps to incrementally build ADT. + */ +class ADTBuilder { + public: + /*! \brief default constructor */ + ADTBuilder() = default; + + explicit inline ADTBuilder(uint32_t tag, uint32_t size) + : data_(make_inplace_array_object(size)) { + data_->size = size; + } + + template + void inline EmplaceInit(size_t idx, Args&&... args) { + data_->EmplaceInit(idx, std::forward(args)...); + } + + ADT inline Get() { + return ADT(std::move(data_)); + } + + private: + friend class ADT; + ObjectPtr data_; +}; +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_FFI_HELPER_H_ diff --git a/include/mxnet/runtime/memory.h b/include/mxnet/runtime/memory.h new file mode 100644 index 000000000000..ea4b5a409d1e --- /dev/null +++ b/include/mxnet/runtime/memory.h @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file runtime/memory.h + * \brief Runtime memory management. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_MEMORY_H_ +#define MXNET_RUNTIME_MEMORY_H_ + +#include +#include +#include +#include "object.h" + +namespace mxnet { +namespace runtime { +/*! + * \brief Allocate an object using default allocator. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The ObjectPtr to the allocated object. + */ +template +inline ObjectPtr make_object(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. + +/*! + * \brief Base class of object allocators that implements make. + * Use curiously recurring template pattern. + * + * \tparam Derived The derived class. + */ +template +class ObjAllocatorBase { + public: + /*! + * \brief Make a new object using the allocator. + * \tparam T The type to be allocated. + * \tparam Args The constructor signature. + * \param args The arguments. + */ + template + inline ObjectPtr make_object(Args&&... args) { + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, + "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this), + std::forward(args)...); + ptr->type_index_ = T::RuntimeTypeIndex(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr(ptr); + } + + /*! + * \tparam ArrayType The type to be allocated. + * \tparam ElemType The type of array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. + */ + template + inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { + using Handler = typename Derived::template ArrayHandler; + static_assert(std::is_base_of::value, + "make_inplace_array can only be used to create Object"); + ArrayType* ptr = Handler::New(static_cast(this), + num_elems, + std::forward(args)...); + ptr->type_index_ = ArrayType::RuntimeTypeIndex(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr(ptr); + } +}; + +// Simple allocator that uses new/delete. +class SimpleObjAllocator : + public ObjAllocatorBase { + public: + template + class Handler { + public: + using StorageType = typename std::aligned_storage::type; + + template + static T* New(SimpleObjAllocator*, Args&&... args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + StorageType* data = new StorageType(); + new (data) T(std::forward(args)...); + return reinterpret_cast(data); + } + + static Object::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to T* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + T* tptr = static_cast(objptr); + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + delete reinterpret_cast(tptr); + } + }; + + // Array handler that uses new/delete. + template + class ArrayHandler { + public: + using StorageType = typename std::aligned_storage::type; + // for now only support elements that aligns with array header. + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "element alignment constraint"); + + template + static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { + // NOTE: the first argument is not needed for ArrayObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + size_t unit = sizeof(StorageType); + size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); + size_t num_storage_slots = (requested_size + unit - 1) / unit; + StorageType* data = new StorageType[num_storage_slots]; + new (data) ArrayType(std::forward(args)...); + return reinterpret_cast(data); + } + + static Object::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to ArrayType* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + ArrayType* tptr = static_cast(objptr); + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + StorageType* p = reinterpret_cast(tptr); + delete []p; + } + }; +}; + +template +inline ObjectPtr make_object(Args&&... args) { + return SimpleObjAllocator().make_object(std::forward(args)...); +} + +template +inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { + return SimpleObjAllocator().make_inplace_array( + num_elems, std::forward(args)...); +} + +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_MEMORY_H_ diff --git a/include/mxnet/runtime/ndarray.h b/include/mxnet/runtime/ndarray.h new file mode 100644 index 000000000000..317c3239092d --- /dev/null +++ b/include/mxnet/runtime/ndarray.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/ndarray.h + * \brief A device-independent managed NDArray abstraction. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_NDARRAY_H_ +#define MXNET_RUNTIME_NDARRAY_H_ + +namespace mxnet { +namespace runtime { + +/*! + * \brief The type trait indicates subclass of TVM's NDArray. + * For irrelavant classes, code = -1. + * For TVM NDArray itself, code = 0. + * All subclasses of NDArray should override code > 0. + */ +template +struct array_type_info { + /*! \brief the value of the traits */ + static const int code = -1; +}; + +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_NDARRAY_H_ diff --git a/include/mxnet/runtime/object.h b/include/mxnet/runtime/object.h new file mode 100644 index 000000000000..e2fb067f1067 --- /dev/null +++ b/include/mxnet/runtime/object.h @@ -0,0 +1,823 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file object.h + * \brief A managed object in MXNet runtime. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_OBJECT_H_ +#define MXNET_RUNTIME_OBJECT_H_ + +#include +#include +#include +#include +#include "c_runtime_api.h" + +/*! + * \brief Whether or not use atomic reference counter. + * If the reference counter is not atomic, + * an object cannot be owned by multiple threads. + * We can, however, move an object across threads + */ +#ifndef MXNET_OBJECT_ATOMIC_REF_COUNTER +#define MXNET_OBJECT_ATOMIC_REF_COUNTER 1 +#endif + +#if MXNET_OBJECT_ATOMIC_REF_COUNTER +#include +#endif // MXNET_OBJECT_ATOMIC_REF_COUNTER + +namespace mxnet { +namespace runtime { + +/*! \brief list of the type index. */ +enum TypeIndex { + /*! \brief Root object type. */ + kRoot = 0, + kMXNetTensor = 1, + kMXNetClosure = 2, + kMXNetADT = 3, + kRuntimeModule = 4, + kEllipsis = 5, + kSlice = 6, + kInteger = 7, + kStaticIndexEnd, + /*! \brief Type index is allocated during runtime. */ + kDynamic = kStaticIndexEnd +}; + +/*! + * \brief base class of all object containers. + * + * Sub-class of objects should declare the following static constexpr fields: + * + * - _type_index: + * Static type index of the object, if assigned to TypeIndex::kDynamic + * the type index will be assigned during runtime. + * Runtime type index can be accessed by ObjectType::TypeIndex(); + * - _type_key: + * The unique string identifier of tyep type. + * - _type_final: + * Whether the type is terminal type(there is no subclass of the type in the object system). + * This field is automatically set by marco MXNET_DECLARE_FINAL_OBJECT_INFO + * It is still OK to sub-class a terminal object type T and construct it using make_object. + * But IsInstance check will only show that the object type is T(instead of the sub-class). + * + * The following two fields are necessary for base classes that can be sub-classed. + * + * - _type_child_slots: + * Number of reserved type index slots for child classes. + * Used for runtime optimization for type checking in IsInstance. + * If an object's type_index is within range of [type_index, type_index + _type_child_slots] + * Then the object can be quickly decided as sub-class of the current object class. + * If not, a fallback mechanism is used to check the global type table. + * Recommendation: set to estimate number of children needed. + * - _type_child_slots_can_overflow: + * Whether we can add additional child classes even if the number of child classes + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. + * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * + * Two macros are used to declare helper functions in the object: + * - Use MXNET_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. + * - Use MXNET_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. + * + * New objects can be created using make_object function. + * Which will automatically populate the type_index and deleter of the object. + * + * \sa make_object + * \sa ObjectPtr + * \sa ObjectRef + * + * \code + * + * // Create a base object + * class BaseObj : public Object { + * public: + * // object fields + * int field0; + * + * // object properties + * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + * static constexpr const char* _type_key = "test.BaseObj"; + * MXNET_DECLARE_BASE_OBJECT_INFO(BaseObj, Object); + * }; + * + * class ObjLeaf : public ObjBase { + * public: + * // fields + * int child_field0; + * // object properties + * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + * static constexpr const char* _type_key = "test.LeafObj"; + * MXNET_DECLARE_BASE_OBJECT_INFO(LeaffObj, Object); + * }; + * + * // The following code should be put into a cc file. + * MXNET_REGISTER_OBJECT_TYPE(ObjBase); + * MXNET_REGISTER_OBJECT_TYPE(ObjLeaf); + * + * // Usage example. + * void TestObjects() { + * // create an object + * ObjectRef leaf_ref(make_object()); + * // cast to a specific instance + * const LeafObj* leaf_ptr = leaf_ref.as(); + * CHECK(leaf_ptr != nullptr); + * // can also cast to the base class. + * CHECK(leaf_ref.as() != nullptr); + * } + * + * \endcode + */ +class Object { + public: + /*! + * \brief Object deleter + * \param self pointer to the Object. + */ + typedef void (*FDeleter)(Object* self); + /*! \return The internal runtime type index of the object. */ + uint32_t type_index() const { + return type_index_; + } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + return TypeIndex2Key(type_index_); + } + /*! + * \return A hash value of the return of GetTypeKey. + */ + size_t GetTypeKeyHash() const { + return TypeIndex2KeyHash(type_index_); + } + /*! + * Check if the object is an instance of TargetType. + * \tparam TargetType The target type to be checked. + * \return Whether the target type is true. + */ + template + inline bool IsInstance() const; + + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + MXNET_DLL static std::string TypeIndex2Key(uint32_t tindex); + /*! + * \brief Get the type key hash of the corresponding index from runtime. + * \param tindex The type index. + * \return the related key-hash. + */ + MXNET_DLL static size_t TypeIndex2KeyHash(uint32_t tindex); + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + * \return the result. + */ + MXNET_DLL static uint32_t TypeKey2Index(const std::string& key); + +#if MXNET_OBJECT_ATOMIC_REF_COUNTER + using RefCounterType = std::atomic; +#else + using RefCounterType = int32_t; +#endif + + static constexpr const char* _type_key = "Object"; + + static uint32_t _GetOrAllocRuntimeTypeIndex() { + return TypeIndex::kRoot; + } + static uint32_t RuntimeTypeIndex() { + return TypeIndex::kRoot; + } + + // Default object type properties for sub-classes + static constexpr bool _type_final = false; + static constexpr uint32_t _type_child_slots = 0; + static constexpr bool _type_child_slots_can_overflow = true; + // NOTE: the following field is not type index of Object + // but was intended to be used by sub-classes as default value. + // The type index of Object is TypeIndex::kRoot + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + + // Default constructor and copy constructor + Object() {} + // Override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + // This will enable us to use make_object(*obj_ptr) + // to copy an existing object. + Object(const Object& other) { // NOLINT(*) + } + Object(Object&& other) { // NOLINT(*) + } + Object& operator=(const Object& other) { //NOLINT(*) + return *this; + } + Object& operator=(Object&& other) { //NOLINT(*) + return *this; + } + + protected: + // The fields of the base object cell. + /*! \brief Type index(tag) that indicates the type of the object. */ + uint32_t type_index_{0}; + /*! \brief The internal reference counter */ + RefCounterType ref_counter_{0}; + /*! + * \brief deleter of this object to enable customized allocation. + * If the deleter is nullptr, no deletion will be performed. + * The creator of the object must always set the deleter field properly. + */ + FDeleter deleter_ = nullptr; + // Invariant checks. + static_assert(sizeof(int32_t) == sizeof(RefCounterType) && + alignof(int32_t) == sizeof(RefCounterType), + "RefCounter ABI check."); + + /*! + * \brief Get the type index using type key. + * + * When the function is first time called for a type, + * it will register the type to the type table in the runtime. + * If the static_tindex is TypeIndex::kDynamic, the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. + * + * \param key the type key. + * \param static_tindex The current _type_index field. + * can be TypeIndex::kDynamic. + * \param parent_tindex The index of the parent. + * \param type_child_slots Number of slots reserved for its children. + * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. + * \return The allocated type index. + */ + MXNET_DLL static uint32_t GetOrAllocRuntimeTypeIndex( + const std::string& key, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t type_child_slots, + bool type_child_slots_can_overflow); + + // reference counter related operations + /*! \brief developer function, increases reference counter. */ + inline void IncRef(); + /*! + * \brief developer function, decrease reference counter. + * \note The deleter will be called when ref_counter_ becomes zero. + */ + inline void DecRef(); + + private: + /*! + * \return The usage count of the cell. + * \note We use stl style naming to be consistent with known API in shared_ptr. + */ + inline int use_count() const; + /*! + * \brief Check of this object is derived from the parent. + * \param parent_tindex The parent type index. + * \return The derivation results. + */ + MXNET_DLL bool DerivedFrom(uint32_t parent_tindex) const; + // friend classes + template + friend class ObjAllocatorBase; + template + friend class ObjectPtr; + friend class MXNetRetValue; + friend class ObjectInternal; +}; + +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the object alive beyond the scope of the function. + * + * \param ptr The object pointer + * \tparam RefType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + +/*! + * \brief A custom smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class ObjectPtr { + public: + /*! \brief default constructor */ + ObjectPtr() {} + /*! \brief default constructor */ + ObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other The value to be moved + */ + ObjectPtr(const ObjectPtr& other) // NOLINT(*) + : ObjectPtr(other.data_) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + ObjectPtr(const ObjectPtr& other) // NOLINT(*) + : ObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + ObjectPtr(ObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + ObjectPtr(ObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~ObjectPtr() { + this->reset(); + } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(ObjectPtr& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + /*! + * \return Get the content of the pointer + */ + T* get() const { + return static_cast(data_); + } + /*! + * \return The pointer + */ + T* operator->() const { + return get(); + } + /*! + * \return The reference + */ + T& operator*() const { // NOLINT(*) + return *get(); + } + /*! + * \brief copy assignmemt + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr& operator=(const ObjectPtr& other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + ObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignmemt + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr& operator=(ObjectPtr&& other) { // NOLINT(*) + // copy-and-swap idiom + ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! \brief reset the content of ptr to be nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecRef(); + data_ = nullptr; + } + } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { + return data_ != nullptr ? data_->use_count() : 0; + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_ != nullptr && data_->use_count() == 1; + } + /*! \return Whether two ObjectPtr do not equal each other */ + bool operator==(const ObjectPtr& other) const { + return data_ == other.data_; + } + /*! \return Whether two ObjectPtr equals each other */ + bool operator!=(const ObjectPtr& other) const { + return data_ != other.data_; + } + /*! \return Whether the pointer is nullptr */ + bool operator==(std::nullptr_t null) const { + return data_ == nullptr; + } + /*! \return Whether the pointer is not nullptr */ + bool operator!=(std::nullptr_t null) const { + return data_ != nullptr; + } + + private: + /*! \brief internal pointer field */ + Object* data_{nullptr}; + /*! + * \brief constructor from Object + * \param data The data pointer + */ + explicit ObjectPtr(Object* data) : data_(data) { + if (data != nullptr) { + data_->IncRef(); + } + } + // friend classes + friend class Object; + friend class ObjectRef; + friend struct ObjectHash; + template + friend class ObjectPtr; + template + friend class ObjAllocatorBase; + friend class MXNetPODValue_; + friend class MXNetArgsSetter; + friend class MXNetRetValue; + friend class MXNetArgValue; + template + friend RefType GetRef(const ObjType* ptr); + template + friend ObjectPtr GetObjectPtr(ObjType* ptr); +}; + +/*! \brief Base class of all object reference */ +class ObjectRef { + public: + /*! \brief default constructor */ + ObjectRef() = default; + /*! \brief Constructor from existing object ptr */ + explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { + return data_ != other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { + return data_.get() < other.data_.get(); + } + /*! \return whether the expression is null */ + bool defined() const { + return data_ != nullptr; + } + /*! \return the internal object pointer */ + const Object* get() const { + return data_.get(); + } + /*! \return the internal object pointer */ + const Object* operator->() const { + return get(); + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_.unique(); + } + /*! + * \brief Try to downcast the internal Object to a + * raw pointer of a corresponding type. + * + * The function will return a nullptr if the cast failed. + * + * if (const Add *add = node_ref.As()) { + * // This is an add node + * } + * \tparam ObjectType the target type, must be a subtype of Object/ + */ + template + inline const ObjectType* as() const; + + /*! \brief type indicate the container type. */ + using ContainerType = Object; + + protected: + /*! \brief Internal pointer that backs the reference. */ + ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { + return data_.get(); + } + /*! + * \brief Internal helper function downcast a ref without check. + * \note Only used for internal dev purposes. + * \tparam T The target reference type. + * \return The casted result. + */ + template + static T DowncastNoCheck(ObjectRef ref) { + return T(std::move(ref.data_)); + } + /*! + * \brief Internal helper function get data_ as ObjectPtr of ObjectType. + * \note only used for internal dev purpose. + * \tparam ObjectType The corresponding object type. + * \return the corresponding type. + */ + template + static ObjectPtr GetDataPtr(const ObjectRef& ref) { + return ObjectPtr(ref.data_.data_); + } + // friend classes. + friend struct ObjectHash; + friend class MXNetRetValue; + friend class MXNetArgsSetter; + template + friend SubRef Downcast(BaseRef ref); +}; + +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline ObjectPtr GetObjectPtr(ObjectType* ptr); + +/*! \brief ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + return operator()(a.data_); + } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } +}; + + +/*! \brief ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + return a.same_as(b); + } + + template + size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } +}; + + +/*! + * \brief helper macro to declare a base object type that can be inheritated. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static const uint32_t RuntimeTypeIndex() { \ + if (TypeName::_type_index != ::mxnet::runtime::TypeIndex::kDynamic) { \ + return TypeName::_type_index; \ + } \ + return _GetOrAllocRuntimeTypeIndex(); \ + } \ + static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ + static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ + TypeName::_type_key, \ + TypeName::_type_index, \ + ParentType::_GetOrAllocRuntimeTypeIndex(), \ + TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow); \ + return tidx; \ + } \ + +/*! + * \brief helper macro to declare type information in a final class. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr bool _type_final = true; \ + static const constexpr int _type_child_slots = 0; \ + MXNET_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. + * + * Use this macro in the cc file for each terminal class. + */ +#define MXNET_REGISTER_OBJECT_TYPE(TypeName) \ + static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ + TypeName::_GetOrAllocRuntimeTypeIndex() + + +#define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() {} \ + explicit TypeName( \ + ::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \ + : ParentType(n) {} \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return data_ != nullptr; } \ + using ContainerType = ObjectName; + +#define MXNET_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ + TypeName() {} \ + explicit TypeName( \ + ::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) \ + : ParentType(n) {} \ + ObjectName* operator->() { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return data_ != nullptr; } \ + using ContainerType = ObjectName; + +// Implementations details below +// Object reference counting. +#if MXNET_OBJECT_ATOMIC_REF_COUNTER + +inline void Object::IncRef() { + ref_counter_.fetch_add(1, std::memory_order_relaxed); +} + +inline void Object::DecRef() { + if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { + std::atomic_thread_fence(std::memory_order_acquire); + if (this->deleter_ != nullptr) { + (*this->deleter_)(this); + } + } +} + +inline int Object::use_count() const { + return ref_counter_.load(std::memory_order_relaxed); +} + +#else + +inline void Object::IncRef() { + ++ref_counter_; +} + +inline void Object::DecRef() { + if (--ref_counter == 0) { + if (this->deleter_ != nullptr) { + (*this->deleter_)(this); + } + } +} + +inline int Object::use_count() const { + return ref_counter_; +} + +#endif // MXNET_OBJECT_ATOMIC_REF_COUNTER + +template +inline bool Object::IsInstance() const { + const Object* self = this; + // NOTE: the following code can be optimized by + // compiler dead-code elimination for already known constants. + if (self != nullptr) { + // Everything is a subclass of object. + if (std::is_same::value) return true; + if (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return self->type_index_ == TargetType::RuntimeTypeIndex(); + } else { + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + uint32_t begin = TargetType::RuntimeTypeIndex(); + // The condition will be optimized by constant-folding. + if (TargetType::_type_child_slots != 0) { + uint32_t end = begin + TargetType::_type_child_slots; + if (self->type_index_ >= begin && self->type_index_ < end) return true; + } else { + if (self->type_index_ == begin) return true; + } + if (!TargetType::_type_child_slots_can_overflow) return false; + // Invariance: parent index is always smaller than the child. + if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; + // The rare slower-path, check type hierachy. + return self->DerivedFrom(TargetType::RuntimeTypeIndex()); + } + } else { + return false; + } +} + + +template +inline const ObjectType* ObjectRef::as() const { + if (data_ != nullptr && + data_->IsInstance()) { + return static_cast(data_.get()); + } else { + return nullptr; + } +} + +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); +} + +template +inline ObjectPtr GetObjectPtr(ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return ObjectPtr(static_cast(ptr)); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.data_)); +} + +} // namespace runtime + +template +using NodePtr = runtime::ObjectPtr; + +} // namespace mxnet + +#endif // MXNET_RUNTIME_OBJECT_H_ diff --git a/include/mxnet/runtime/packed_func.h b/include/mxnet/runtime/packed_func.h new file mode 100644 index 000000000000..16351a7604dc --- /dev/null +++ b/include/mxnet/runtime/packed_func.h @@ -0,0 +1,1201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/packed_func.h + * \brief Type-erased function used across MXNET API. + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_PACKED_FUNC_H_ +#define MXNET_RUNTIME_PACKED_FUNC_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { +// forward declarations +// class Integer; +// class Expr; + +namespace runtime { + +/*! + * \brief convert a string to TVM type. + * \param s The string to be converted. + * \return The corresponding tvm type. + */ +inline DLDataType String2DLDataType(std::string s); + +// forward declarations +class MXNetArgs; +class MXNetArgValue; +class MXNetRetValue; +class MXNetArgsSetter; + +/*! + * \brief Packed function is a type-erased function. + * The arguments are passed by packed format. + * + * This is an useful unified interface to call generated functions, + * It is the unified function function type of TVM. + * It corresponds to TVMFunctionHandle in C runtime API. + */ +class PackedFunc { + public: + /*! + * \brief The internal std::function + * \param args The arguments to the function. + * \param rv The return value. + * + * \code + * // Example code on how to implemented FType + * void MyPackedFunc(MXNetArgs args, MXNetRetValue* rv) { + * // automatically convert arguments to desired type. + * int a0 = args[0]; + * float a1 = args[1]; + * ... + * // automatically assign values to rv + * std::string my_return_value = "x"; + * *rv = my_return_value; + * } + * \endcode + */ + using FType = std::function; + /*! \brief default constructor */ + PackedFunc() {} + /*! \brief constructor from null */ + PackedFunc(std::nullptr_t null) {} // NOLINT(*) + /*! + * \brief constructing a packed function from a std::function. + * \param body the internal container of packed function. + */ + explicit PackedFunc(FType body) : body_(body) {} + /*! + * \brief Call packed function by directly passing in unpacked format. + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * + * \code + * // Example code on how to call packed function + * void CallPacked(PackedFunc f) { + * // call like normal functions by pass in arguments + * // return value is automatically converted back + * int rvalue = f(1, 2.0); + * } + * \endcode + */ + template + inline MXNetRetValue operator()(Args&& ...args) const; + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param rv The return value. + */ + inline void CallPacked(MXNetArgs args, MXNetRetValue* rv) const; + /*! \return the internal body function */ + inline FType body() const; + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t null) const { + return body_ == nullptr; + } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t null) const { + return body_ != nullptr; + } + + private: + /*! \brief internal container of packed function */ + FType body_; +}; + +/*! + * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc" + */ +template +class TypedPackedFunc; + +/*! + * \anchor TypedPackedFuncAnchor + * \brief A PackedFunc wrapper to provide typed function signature. + * It is backed by a PackedFunc internally. + * + * TypedPackedFunc enables compile time type checking. + * TypedPackedFunc works with the runtime system: + * - It can be passed as an argument of PackedFunc. + * - It can be assigned to MXNetRetValue. + * - It can be directly converted to a type-erased PackedFunc. + * + * Developers should prefer TypedPackedFunc over PackedFunc in C++ code + * as it enables compile time checking. + * We can construct a TypedPackedFunc from a lambda function + * with the same signature. + * + * \code + * // user defined lambda function. + * auto addone = [](int x)->int { + * return x + 1; + * }; + * // We can directly convert + * // lambda function to TypedPackedFunc + * TypedPackedFunc ftyped(addone); + * // invoke the function. + * int y = ftyped(1); + * // Can be directly converted to PackedFunc + * PackedFunc packed = ftype; + * \endcode + * \tparam R The return value of the function. + * \tparam Args The argument signature of the function. + */ +template +class TypedPackedFunc { + public: + /*! \brief short hand for this function type */ + using TSelf = TypedPackedFunc; + /*! \brief default constructor */ + TypedPackedFunc() {} + /*! \brief constructor from null */ + TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*) + /*! + * \brief construct by wrap a PackedFunc + * + * Example usage: + * \code + * PackedFunc packed([](MXNetArgs args, MXNetRetValue *rv) { + * int x = args[0]; + * *rv = x + 1; + * }); + * // construct from packed function + * TypedPackedFunc ftyped(packed); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param packed The packed function + */ + inline TypedPackedFunc(PackedFunc packed); // NOLINT(*) + /*! + * \brief constructor from MXNetRetValue + * \param value The MXNetRetValue + */ + inline TypedPackedFunc(const MXNetRetValue& value); // NOLINT(*) + /*! + * \brief constructor from MXNetArgValue + * \param value The MXNetArgValue + */ + inline TypedPackedFunc(const MXNetArgValue& value); // NOLINT(*) + /*! + * \brief construct from a lambda function with the same signature. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedPackedFunc ftyped(typed_lambda); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + */ + template + >::value>::type> + TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) + this->AssignTypedLambda(typed_lambda); + } + /*! + * \brief copy assignment operator from typed lambda + * + * Example usage: + * \code + * // construct from packed function + * TypedPackedFunc ftyped; + * ftyped = [](int x) { return x + 1; } + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + * \returns reference to self. + */ + template + >::value>::type> + TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) + this->AssignTypedLambda(typed_lambda); + return *this; + } + /*! + * \brief copy assignment operator from PackedFunc. + * \param packed The packed function. + * \returns reference to self. + */ + TSelf& operator=(PackedFunc packed) { + packed_ = packed; + return *this; + } + /*! + * \brief Invoke the operator. + * \param args The arguments + * \returns The return value. + */ + inline R operator()(Args ...args) const; + /*! + * \brief convert to PackedFunc + * \return the internal PackedFunc + */ + operator PackedFunc() const { + return packed(); + } + /*! + * \return reference the internal PackedFunc + */ + const PackedFunc& packed() const { + return packed_; + } + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t null) const { + return packed_ == nullptr; + } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t null) const { + return packed_ != nullptr; + } + + private: + friend class MXNetRetValue; + /*! \brief The internal packed function */ + PackedFunc packed_; + /*! + * \brief Assign the packed field using a typed lambda function. + * + * \param flambda The lambda function. + * \tparam FLambda The lambda function type. + * \note We capture the lambda when possible for maximum efficiency. + */ + template + inline void AssignTypedLambda(FLambda flambda); +}; + +/*! \brief Arguments into TVM functions. */ +class MXNetArgs { + public: + const MXNetValue* values; + const int* type_codes; + int num_args; + /*! + * \brief constructor + * \param values The argument values + * \param type_codes The argument type codes + * \param num_args number of arguments. + */ + MXNetArgs(const MXNetValue* values, + const int* type_codes, + int num_args) + : values(values), + type_codes(type_codes), + num_args(num_args) { } + /*! \return size of the arguments */ + inline int size() const; + /*! + * \brief Get i-th argument + * \param i the index. + * \return the ith argument. + */ + inline MXNetArgValue operator[](int i) const; +}; + +/*! + * \brief Convert type code to its name + * \param type_code The type code . + * \return The name of type code. + */ +inline const char* TypeCode2Str(int type_code); + +/*! + * \brief convert a string to TVM type. + * \param s The string to be converted. + * \return The corresponding tvm type. + */ +// inline TVMType String2TVMType(std::string s); + +// macro to check type code. +#define MXNET_CHECK_TYPE_CODE(CODE, T) \ + CHECK_EQ(CODE, T) << " expected " \ + << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ + +/*! + * \brief Type traits to mark if a class is tvm extension type. + * + * To enable extension type in C++ must be registered via marco. + * TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits. + * + * Extension class can be passed and returned via PackedFunc in all tvm runtime. + * Internally extension class is stored as T*. + * + * \tparam T the typename + */ +template +struct extension_type_info { + static const int code = 0; +}; + +/*! + * \brief Internal base class to + * handle conversion to POD values. + */ +class MXNetPODValue_ { + public: + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (type_code_ == kDLInt) { + return static_cast(value_.v_int64); + } + MXNET_CHECK_TYPE_CODE(type_code_, kDLFloat); + return value_.v_float64; + } + operator int64_t() const { + MXNET_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64; + } + operator uint64_t() const { + MXNET_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64; + } + operator int() const { + MXNET_CHECK_TYPE_CODE(type_code_, kDLInt); + CHECK_LE(value_.v_int64, + std::numeric_limits::max()); + return static_cast(value_.v_int64); + } + operator bool() const { + MXNET_CHECK_TYPE_CODE(type_code_, kDLInt); + return value_.v_int64 != 0; + } + operator void*() const { + if (type_code_ == kNull) return nullptr; + if (type_code_ == kArrayHandle) return value_.v_handle; + MXNET_CHECK_TYPE_CODE(type_code_, kHandle); + return value_.v_handle; + } + operator ObjectRef() const { + if (type_code_ == kNull) { + return ObjectRef(ObjectPtr(nullptr)); + } + MXNET_CHECK_TYPE_CODE(type_code_, kObjectHandle); + return ObjectRef( + ObjectPtr(static_cast(value_.v_handle))); + } + template::value>::type> + inline bool IsObjectRef() const; + int type_code() const { + return type_code_; + } + + /*! + * \brief return handle as specific pointer type. + * \tparam T the data type. + * \return The pointer type. + */ + template + T* ptr() const { + return static_cast(value_.v_handle); + } + + protected: + friend class MXNetArgsSetter; + friend class MXNetRetValue; + MXNetPODValue_() : type_code_(kNull) {} + MXNetPODValue_(MXNetValue value, int type_code) + : value_(value), type_code_(type_code) {} + + /*! \brief The value */ + MXNetValue value_; + /*! \brief the type code */ + int type_code_; +}; + +/*! + * \brief A single argument value to PackedFunc. + * Containing both type_code and MXNetValue + * + * Provides utilities to do type cast into other types. + */ +class MXNetArgValue : public MXNetPODValue_ { + public: + /*! \brief default constructor */ + MXNetArgValue() {} + /*! + * \brief constructor + * \param value of the function + * \param type_code The type code. + */ + MXNetArgValue(MXNetValue value, int type_code) + : MXNetPODValue_(value, type_code) { + } + // reuse converter from parent + using MXNetPODValue_::operator double; + using MXNetPODValue_::operator int64_t; + using MXNetPODValue_::operator uint64_t; + using MXNetPODValue_::operator int; + using MXNetPODValue_::operator bool; + using MXNetPODValue_::operator void*; + using MXNetPODValue_::operator ObjectRef; + using MXNetPODValue_::IsObjectRef; + + // conversion operator. + operator std::string() const { + if (type_code_ == kBytes) { + MXNetByteArray* arr = static_cast(value_.v_handle); + return std::string(arr->data, arr->size); + } else { + MXNET_CHECK_TYPE_CODE(type_code_, kStr); + return std::string(value_.v_str); + } + } + operator DLDataType() const { + if (type_code_ == kStr) { + return String2DLDataType(operator std::string()); + } + // None type + if (type_code_ == kNull) { + DLDataType t; + t.code = kHandle; t.bits = 0; t.lanes = 0; + return t; + } + MXNET_CHECK_TYPE_CODE(type_code_, kMXNetType); + return value_.v_type; + } + operator MXNetDataType() const { + return MXNetDataType(operator DLDataType()); + } + operator ::mxnet::NDArray*() const { + if (type_code_ == kNull) { + return nullptr; + } + MXNET_CHECK_TYPE_CODE(type_code_, kNDArrayHandle); + return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle); + } + operator PackedFunc() const { + if (type_code_ == kNull) return PackedFunc(); + MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle); + return *ptr(); + } + template + operator TypedPackedFunc() const { + return TypedPackedFunc(operator PackedFunc()); + } + const MXNetValue& value() const { + return value_; + } + // Deferred extension handler. + template + inline TObjectRef AsObjectRef() const; + template::value>::type> + inline operator T() const; +}; + +/*! + * \brief Return Value container, + * Unlike MXNetArgValue, which only holds reference and do not delete + * the underlying container during destruction. + * + * MXNetRetValue holds value and will manage the underlying containers + * when it stores a complicated data type. + */ +class MXNetRetValue : public MXNetPODValue_ { + public: + /*! \brief default constructor */ + MXNetRetValue() {} + /*! + * \brief move constructor from anoter return value. + * \param other The other return value. + */ + MXNetRetValue(MXNetRetValue&& other) + : MXNetPODValue_(other.value_, other.type_code_) { + other.value_.v_handle = nullptr; + other.type_code_ = kNull; + } + /*! \brief destructor */ + ~MXNetRetValue() { + this->Clear(); + } + // reuse converter from parent + using MXNetPODValue_::operator double; + using MXNetPODValue_::operator int64_t; + using MXNetPODValue_::operator uint64_t; + using MXNetPODValue_::operator int; + using MXNetPODValue_::operator bool; + using MXNetPODValue_::operator void*; + using MXNetPODValue_::operator ObjectRef; + using MXNetPODValue_::IsObjectRef; + + MXNetRetValue(const MXNetRetValue& other) : MXNetPODValue_() { + this->Assign(other); + } + // conversion operators + operator std::string() const { + if (type_code_ == kBytes) { + return *ptr(); + } + MXNET_CHECK_TYPE_CODE(type_code_, kStr); + return *ptr(); + } + operator DLDataType() const { + if (type_code_ == kStr) { + return String2DLDataType(operator std::string()); + } + MXNET_CHECK_TYPE_CODE(type_code_, kMXNetType); + return value_.v_type; + } + operator MXNetDataType() const { + return MXNetDataType(operator DLDataType()); + } + operator PackedFunc() const { + if (type_code_ == kNull) return PackedFunc(); + MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle); + return *ptr(); + } + template + operator TypedPackedFunc() const { + return TypedPackedFunc(operator PackedFunc()); + } + // Assign operators + MXNetRetValue& operator=(MXNetRetValue&& other) { + this->Clear(); + value_ = other.value_; + type_code_ = other.type_code_; + other.type_code_ = kNull; + return *this; + } + MXNetRetValue& operator=(double value) { + this->SwitchToPOD(kDLFloat); + value_.v_float64 = value; + return *this; + } + MXNetRetValue& operator=(std::nullptr_t value) { + this->SwitchToPOD(kNull); + value_.v_handle = value; + return *this; + } + MXNetRetValue& operator=(void* value) { + this->SwitchToPOD(kHandle); + value_.v_handle = value; + return *this; + } + MXNetRetValue& operator=(int64_t value) { + this->SwitchToPOD(kDLInt); + value_.v_int64 = value; + return *this; + } + MXNetRetValue& operator=(int value) { + this->SwitchToPOD(kDLInt); + value_.v_int64 = value; + return *this; + } + MXNetRetValue& operator=(bool value) { + this->SwitchToPOD(kDLInt); + value_.v_int64 = value; + return *this; + } + MXNetRetValue& operator=(std::string value) { + this->SwitchToClass(kStr, value); + return *this; + } + MXNetRetValue& operator=(DLDataType t) { + this->SwitchToPOD(kMXNetType); + value_.v_type = t; + return *this; + } + MXNetRetValue& operator=(const MXNetDataType& other) { + return operator=(other.operator DLDataType()); + } + MXNetRetValue& operator=(MXNetByteArray value) { + this->SwitchToClass(kBytes, std::string(value.data, value.size)); + return *this; + } + MXNetRetValue& operator=(ObjectRef other) { + return operator=(std::move(other.data_)); + } + template + MXNetRetValue& operator=(ObjectPtr other) { + SwitchToObject(kObjectHandle, std::move(other)); + return *this; + } + MXNetRetValue& operator=(PackedFunc f) { + this->SwitchToClass(kFuncHandle, f); + return *this; + } + template + MXNetRetValue& operator=(const TypedPackedFunc& f) { + return operator=(f.packed()); + } + MXNetRetValue& operator=(const MXNetRetValue& other) { // NOLINT(*0 + this->Assign(other); + return *this; + } + MXNetRetValue& operator=(const MXNetArgValue& other) { + this->Assign(other); + return *this; + } + MXNetRetValue& operator=(::mxnet::NDArray* value) { + this->SwitchToPOD(kNDArrayHandle); + value_.v_handle = reinterpret_cast(value); + return *this; + } + template::code != 0>::type> + MXNetRetValue& operator=(const T& other) { + this->SwitchToClass( + extension_type_info::code, other); + return *this; + } + /*! + * \brief Move the value back to front-end via C API. + * This marks the current container as null. + * The managed resources is moved to front-end and + * the front end should take charge in managing them. + * + * \param ret_value The return value. + * \param ret_type_code The return type code. + */ + void MoveToCHost(MXNetValue* ret_value, + int* ret_type_code) { + // cannot move str; need specially handle. + CHECK(type_code_ != kStr && type_code_ != kBytes); + *ret_value = value_; + *ret_type_code = type_code_; + type_code_ = kNull; + } + /*! \return The value field, if the data is POD */ + const MXNetValue& value() const { + CHECK(type_code_ != kObjectHandle && + type_code_ != kFuncHandle && + type_code_ != kStr) << "MXNetRetValue.value can only be used for POD data"; + return value_; + } + // ObjectRef related extenstions: in tvm/packed_func_ext.h + template::value>::type> + inline operator T() const; + template + inline TObjectRef AsObjectRef() const; + + private: + template + void Assign(const T& other) { + switch (other.type_code()) { + case kStr: { + SwitchToClass(kStr, other); + break; + } + case kBytes: { + SwitchToClass(kBytes, other); + break; + } + case kFuncHandle: { + SwitchToClass(kFuncHandle, other); + break; + } + case kObjectHandle: { + *this = other.operator ObjectRef(); + break; + } + default: { + if (other.type_code() < kExtBegin) { + SwitchToPOD(other.type_code()); + value_ = other.value_; + } else { + LOG(FATAL) << "Does not support ext type"; + } + break; + } + } + } + // get the internal container. + void SwitchToPOD(int type_code) { + if (type_code_ != type_code) { + this->Clear(); + type_code_ = type_code; + } + } + template + void SwitchToClass(int type_code, T v) { + if (type_code_ != type_code) { + this->Clear(); + type_code_ = type_code; + value_.v_handle = new T(v); + } else { + *static_cast(value_.v_handle) = v; + } + } + void SwitchToObject(int type_code, ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = type_code; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } + } + void Clear() { + if (type_code_ == kNull) return; + switch (type_code_) { + case kStr: delete ptr(); break; + case kFuncHandle: delete ptr(); break; + case kObjectHandle: { + static_cast(value_.v_handle)->DecRef(); + break; + } + } + if (type_code_ > kExtBegin) { + LOG(FATAL) << "Does not support ext type"; + } + type_code_ = kNull; + } +}; + +inline DLDataType String2DLDataType(std::string s) { + DLDataType t; + // handle None type + if (s.length() == 0) { + t.bits = 0; t.lanes = 0; t.code = kHandle; + return t; + } + t.bits = 32; t.lanes = 1; + const char* scan = nullptr; + if (s.substr(0, 3) == "int") { + t.code = kDLInt; scan = s.c_str() + 3; + } else if (s.substr(0, 4) == "uint") { + t.code = kDLUInt; scan = s.c_str() + 4; + } else if (s.substr(0, 5) == "float") { + t.code = kDLFloat; scan = s.c_str() + 5; + } else if (s.substr(0, 6) == "handle") { + t.code = kHandle; + t.bits = 64; // handle uses 64 bit by default. + scan = s.c_str() + 6; + } else if (s == "bool") { + t.code = kDLUInt; + t.bits = 1; + t.lanes = 1; + return t; + } else if (s.substr(0, 6) == "custom") { + LOG(FATAL) << "custom MXNetDataType is not supported"; + // t.code = ParseCustomDatatype(s, &scan); + } else { + scan = s.c_str(); + LOG(FATAL) << "unknown type " << s; + } + char* xdelim; // emulate sscanf("%ux%u", bits, lanes) + uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); + if (bits != 0) t.bits = bits; + char* endpt = xdelim; + if (*xdelim == 'x') { + t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); + } + CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; + return t; +} + +// implementation details +inline const char* TypeCode2Str(int type_code) { + switch (type_code) { + case kDLInt: return "int"; + case kDLUInt: return "uint"; + case kDLFloat: return "float"; + case kStr: return "str"; + case kBytes: return "bytes"; + case kHandle: return "handle"; + case kNull: return "NULL"; + case kFuncHandle: return "FunctionHandle"; + case kObjectHandle: return "ObjectCell"; + default: LOG(FATAL) << "unknown type_code=" + << static_cast(type_code); return ""; + } +} + +inline int String2MXNetTypeWithBool(const std::string& s) { + if (s == "float32") { + return mshadow::kFloat32; + } else if (s == "float64") { + return mshadow::kFloat64; + } else if (s == "float16") { + return mshadow::kFloat16; + } else if (s == "uint8") { + return mshadow::kUint8; + } else if (s == "int8") { + return mshadow::kInt8; + } else if (s == "int32") { + return mshadow::kInt32; + } else if (s == "int64") { + return mshadow::kInt64; + } else if (s == "bool") { + return mshadow::kBool; + } else { + LOG(FATAL) << "unknown type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +inline int String2MXNetType(const std::string& s) { + if (s == "float32") { + return mshadow::kFloat32; + } else if (s == "float64") { + return mshadow::kFloat64; + } else if (s == "float16") { + return mshadow::kFloat16; + } else if (s == "uint8") { + return mshadow::kUint8; + } else if (s == "int8") { + return mshadow::kInt8; + } else if (s == "int32") { + return mshadow::kInt32; + } else if (s == "int64") { + return mshadow::kInt64; + } else { + LOG(FATAL) << "unknown type " << s; + } + LOG(FATAL) << "should not reach here "; + return 0; +} + +inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) + if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { + os << "bool"; return os; + } + if (t.code < kCustomBegin) { + os << TypeCode2Str(t.code); + } else { + LOG(FATAL) << "custom MXNetDataType is not supported"; + // os << "custom[" << GetCustomTypeName(t.code) << "]"; + } + if (t.code == kHandle) return os; + os << static_cast(t.bits); + if (t.lanes != 1) { + os << 'x' << static_cast(t.lanes); + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const MXNetDataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} + +inline MXNetArgValue MXNetArgs::operator[](int i) const { + CHECK_LT(i, num_args) + << "not enough argument passed, " + << num_args << " passed" + << " but request arg[" << i << "]."; + return MXNetArgValue(values[i], type_codes[i]); +} + +inline int MXNetArgs::size() const { + return num_args; +} + +inline void PackedFunc::CallPacked(MXNetArgs args, MXNetRetValue* rv) const { + body_(args, rv); +} + +inline PackedFunc::FType PackedFunc::body() const { + return body_; +} + +// internal namespace +namespace detail { + +template +struct for_each_dispatcher { + template + static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) + f(I, std::forward(value)); + for_each_dispatcher + ::run(f, std::forward(args)...); + } +}; + +template +struct for_each_dispatcher { + static void run(const F& f) {} // NOLINT(*) +}; + +template +inline void for_each(const F& f, Args&&... args) { // NOLINT(*) + for_each_dispatcher + ::run(f, std::forward(args)...); +} +} // namespace detail + +/* \brief argument settter to PackedFunc */ +class MXNetArgsSetter { + public: + MXNetArgsSetter(MXNetValue* values, int* type_codes) + : values_(values), type_codes_(type_codes) {} + // setters for POD types + template::value>::type> + void operator()(size_t i, T value) const { + values_[i].v_int64 = static_cast(value); + type_codes_[i] = kDLInt; + } + void operator()(size_t i, uint64_t value) const { + values_[i].v_int64 = static_cast(value); + CHECK_LE(value, + static_cast(std::numeric_limits::max())); + type_codes_[i] = kDLInt; + } + void operator()(size_t i, double value) const { + values_[i].v_float64 = value; + type_codes_[i] = kDLFloat; + } + void operator()(size_t i, std::nullptr_t value) const { + values_[i].v_handle = value; + type_codes_[i] = kNull; + } + void operator()(size_t i, const MXNetArgValue& value) const { + values_[i] = value.value_; + type_codes_[i] = value.type_code_; + } + void operator()(size_t i, void* value) const { + values_[i].v_handle = value; + type_codes_[i] = kHandle; + } + void operator()(size_t i, DLTensor* value) const { + values_[i].v_handle = value; + type_codes_[i] = kArrayHandle; + } + void operator()(size_t i, const char* value) const { + values_[i].v_str = value; + type_codes_[i] = kStr; + } + // setters for container type + // They must be reference(instead of const ref) + // to make sure they are alive in the tuple(instead of getting converted) + void operator()(size_t i, const std::string& value) const { // NOLINT(*) + values_[i].v_str = value.c_str(); + type_codes_[i] = kStr; + } + void operator()(size_t i, DLDataType value) const { + values_[i].v_type = value; + type_codes_[i] = kMXNetType; + } + void operator()(size_t i, MXNetDataType dtype) const { + operator()(i, dtype.operator DLDataType()); + } + void operator()(size_t i, const MXNetByteArray& value) const { // NOLINT(*) + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kBytes; + } + void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kFuncHandle; + } + template + void operator()(size_t i, const TypedPackedFunc& value) const { // NOLINT(*) + operator()(i, value.packed()); + } + void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) + if (value.defined()) { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kObjectHandle; + } else { + type_codes_[i] = kNull; + } + } + void operator()(size_t i, const MXNetRetValue& value) const { // NOLINT(*) + if (value.type_code() == kStr) { + values_[i].v_str = value.ptr()->c_str(); + type_codes_[i] = kStr; + } else { + CHECK_NE(value.type_code(), kBytes) << "not handled."; + values_[i] = value.value_; + type_codes_[i] = value.type_code(); + } + } + + private: + /*! \brief The values fields */ + MXNetValue* values_; + /*! \brief The type code fields */ + int* type_codes_; +}; + +template +inline MXNetRetValue PackedFunc::operator()(Args&& ...args) const { + const int kNumArgs = sizeof...(Args); + const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; + MXNetValue values[kArraySize]; + int type_codes[kArraySize]; + detail::for_each(MXNetArgsSetter(values, type_codes), + std::forward(args)...); + MXNetRetValue rv; + body_(MXNetArgs(values, type_codes, kNumArgs), &rv); + return rv; +} + +namespace detail { +template +struct unpack_call_dispatcher { + template + static void run(const F& f, + const MXNetArgs& args_pack, + MXNetRetValue* rv, + Args&&... unpacked_args) { + unpack_call_dispatcher + ::run(f, args_pack, rv, + std::forward(unpacked_args)..., + args_pack[index]); + } +}; + +template +struct unpack_call_dispatcher { + template + static void run(const F& f, + const MXNetArgs& args_pack, + MXNetRetValue* rv, + Args&&... unpacked_args) { + *rv = R(f(std::forward(unpacked_args)...)); + } +}; + +template +struct unpack_call_dispatcher { + template + static void run(const F& f, + const MXNetArgs& args_pack, + MXNetRetValue* rv, + Args&&... unpacked_args) { + f(std::forward(unpacked_args)...); + } +}; + +template +inline void unpack_call(const F& f, const MXNetArgs& args, MXNetRetValue* rv) { + unpack_call_dispatcher::run(f, args, rv); +} + +template +inline R call_packed(const PackedFunc& pf, Args&& ...args) { + return R(pf(std::forward(args)...)); +} + +template +struct typed_packed_call_dispatcher { + template + static inline R run(const PackedFunc& pf, Args&& ...args) { + return pf(std::forward(args)...); + } +}; + +template<> +struct typed_packed_call_dispatcher { + template + static inline void run(const PackedFunc& pf, Args&& ...args) { + pf(std::forward(args)...); + } +}; +} // namespace detail + +template +TypedPackedFunc::TypedPackedFunc(PackedFunc packed) + : packed_(packed) {} + +template +TypedPackedFunc::TypedPackedFunc(const MXNetRetValue& value) + : packed_(value.operator PackedFunc()) {} + +template +TypedPackedFunc::TypedPackedFunc(const MXNetArgValue& value) + : packed_(value.operator PackedFunc()) {} + +template +template +inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { + packed_ = PackedFunc([flambda](const MXNetArgs& args, MXNetRetValue* rv) { + detail::unpack_call(flambda, args, rv); + }); +} + +template +inline R TypedPackedFunc::operator()(Args... args) const { + return detail::typed_packed_call_dispatcher + ::run(packed_, std::forward(args)...); +} + +// extension and node type handling +namespace detail { +template +struct MXNetValueCast { + static T Apply(const TSrc* self) { + static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); + return self->template AsObjectRef(); + } +}; + +} // namespace detail + +template +inline MXNetRetValue::operator T() const { + return detail:: + MXNetValueCast::code != 0), + (array_type_info::code > 0)> + ::Apply(this); +} + +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_PACKED_FUNC_H_ diff --git a/include/mxnet/runtime/registry.h b/include/mxnet/runtime/registry.h new file mode 100644 index 000000000000..70782b47254d --- /dev/null +++ b/include/mxnet/runtime/registry.h @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file registry.h + * \brief This file defines the TVM global function registry. + * + * The registered functions will be made available to front-end + * as well as backend users. + * + * The registry stores type-erased functions. + * Each registered function is automatically exposed + * to front-end language(e.g. python). + * + * Front-end can also pass callbacks as PackedFunc, or register + * then into the same global registry in C++. + * The goal is to mix the front-end language and the TVM back-end. + * + * \code + * // register the function as MyAPIFuncName + * TVM_REGISTER_GLOBAL(MyAPIFuncName) + * .set_body([](TVMArgs args, TVMRetValue* rv) { + * // my code. + * }); + * \endcode + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_REGISTRY_H_ +#define MXNET_RUNTIME_REGISTRY_H_ + +#include +#include +#include "packed_func.h" + +namespace mxnet { +namespace runtime { + +/*! \brief Registry for global function */ +class Registry { + public: + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + MXNET_DLL Registry& set_body(PackedFunc f); // NOLINT(*) + /*! + * \brief set the body of the function to be f + * \param f The body of the function. + */ + Registry& set_body(PackedFunc::FType f) { // NOLINT(*) + return set_body(PackedFunc(f)); + } + /*! + * \brief set the body of the function to be TypedPackedFunc. + * + * \code + * + * TVM_REGISTER_API("addone") + * .set_body_typed([](int x) { return x + 1; }); + * + * \endcode + * + * \param f The body of the function. + * \tparam FType the signature of the function. + * \tparam FLambda The type of f. + */ + template + Registry& set_body_typed(FLambda f) { + return set_body(TypedPackedFunc(f).packed()); + } + + /*! + * \brief set the body of the function to the given function pointer. + * Note that this doesn't work with lambdas, you need to + * explicitly give a type for those. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * int multiply(int x, int y) { + * return x * y; + * } + * + * TVM_REGISTER_API("multiply") + * .set_body_typed(multiply); // will have type int(int, int) + * + * \endcode + * + * \param f The function to forward to. + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_typed(R (*f)(Args...)) { + return set_body(TypedPackedFunc(f)); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct Example { + * int doThing(int x); + * } + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&Example::doThing); // will have type int(Example, int) + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam T the type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_method(R (T::*f)(Args...)) { + return set_body_typed([f](T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct Example { + * int doThing(int x); + * } + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&Example::doThing); // will have type int(Example, int) + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam T the type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_method(R (T::*f)(Args...) const) { + return set_body_typed([f](const T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Used when calling a method on a Node subclass through a ObjectRef subclass. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct ExampleNode: BaseNode { + * int doThing(int x); + * } + * + * // noderef subclass + * struct Example; + * + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) + * + * // note that just doing: + * // .set_body_method(&ExampleNode::doThing); + * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam TObjectRef the node reference type to call the method on + * \tparam TNode the node type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template::value>::type> + Registry& set_body_method(R (TNode::*f)(Args...)) { + return set_body_typed([f](TObjectRef ref, Args... params) { + TNode* target = ref.operator->(); + // call method pointer + return (target->*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Used when calling a method on a Node subclass through a ObjectRef subclass. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct ExampleNode: BaseNode { + * int doThing(int x); + * } + * + * // noderef subclass + * struct Example; + * + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) + * + * // note that just doing: + * // .set_body_method(&ExampleNode::doThing); + * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam TObjectRef the node reference type to call the method on + * \tparam TNode the node type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template::value>::type> + Registry& set_body_method(R (TNode::*f)(Args...) const) { + return set_body_typed([f](TObjectRef ref, Args... params) { + const TNode* target = ref.operator->(); + // call method pointer + return (target->*f)(params...); + }); + } + + /*! + * \brief Register a function with given name + * \param name The name of the function. + * \param override Whether allow oveeride existing function. + * \return Reference to theregistry. + */ + MXNET_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*) + /*! + * \brief Erase global function from registry, if exist. + * \param name The name of the function. + * \return Whether function exist. + */ + MXNET_DLL static bool Remove(const std::string& name); + /*! + * \brief Get the global function by name. + * \param name The name of the function. + * \return pointer to the registered function, + * nullptr if it does not exist. + */ + MXNET_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*) + /*! + * \brief Get the names of currently registered global function. + * \return The names + */ + MXNET_DLL static std::vector ListNames(); + + // Internal class. + struct Manager; + + protected: + /*! \brief name of the function */ + std::string name_; + /*! \brief internal packed function */ + PackedFunc func_; + friend struct Manager; +}; + +/*! \brief helper macro to supress unused warning */ +#if defined(__GNUC__) +#define MXNET_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define MXNET_ATTRIBUTE_UNUSED +#endif + +#define MXNET_STR_CONCAT_(__x, __y) __x##__y +#define MXNET_STR_CONCAT(__x, __y) MXNET_STR_CONCAT_(__x, __y) + +#define MXNET_FUNC_REG_VAR_DEF \ + static MXNET_ATTRIBUTE_UNUSED ::mxnet::runtime::Registry& __mk_ ## MXNET + +/*! + * \brief Register a function globally. + * \code + * TVM_REGISTER_GLOBAL("MyPrint") + * .set_body([](TVMArgs args, TVMRetValue* rv) { + * }); + * \endcode + */ +#define MXNET_REGISTER_GLOBAL(OpName) \ + MXNET_STR_CONCAT(MXNET_FUNC_REG_VAR_DEF, __COUNTER__) = \ + ::mxnet::runtime::Registry::Register(OpName) + +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_REGISTRY_H_ diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index e4ca75ffa6d3..e1d9b66754d1 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -34,6 +34,10 @@ #include "nnvm/graph_attr_types.h" #include "nnvm/graph.h" #include "nnvm/pass.h" +#include "runtime/object.h" +#include "runtime/ffi_helper.h" +#include "node/container.h" +#include "ir/expr.h" namespace mxnet { @@ -59,6 +63,17 @@ class Tuple { inline ~Tuple() { delete [] data_heap_; } + /*! + * constructor to construct a tuple with all `value`. + * \param ndim the number of dimension + * \param value the dimension size for all dims + */ + inline Tuple(const int ndim, const dim_t value) { // NOLINT(*) + this->SetDim(ndim); + if (ndim > 0) { + std::fill_n(begin(), ndim, value); + } + } /*! * \brief copy constructor from another tuple * \param s the source tuple @@ -103,6 +118,16 @@ class Tuple { RandomAccessIterator end) { this->assign(begin, end); } + + inline explicit Tuple(const runtime::ObjectRef& src) { + using namespace runtime; + ADT adt = Downcast(src); + this->SetDim(adt.size()); + for (int i = 0; i < ndim_; ++i) { + this->begin()[i] = Downcast(adt[i])->value; + } + } + /*! * \brief Assign content to tuple from iterator. * \param begin the beginning of iterator @@ -468,6 +493,8 @@ class TShape : public Tuple { RandomAccessIterator end) { this->assign(begin, end); } + + inline explicit TShape(const ObjectRef& src): Tuple(src) {} /*! * \brief assignment function from tshape * \param src source shape. diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 95c5bb8521ba..d6d6a1f49e8e 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -103,3 +103,8 @@ from . import numpy_op_signature from . import numpy_dispatch_protocol from . import numpy_op_fallback + +from . import _global_var + +from . import _api_internal +from . import api diff --git a/python/mxnet/_api_internal.py b/python/mxnet/_api_internal.py new file mode 100644 index 000000000000..1a60d435b731 --- /dev/null +++ b/python/mxnet/_api_internal.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace of internal API + +The functions in this namespace are automatically exported from C++ side via PackedFunc +that is registered by "MXNET_REGISTER_*" macro. This way makes calling Python functions from C++ +side very easily. + +Each string starts with "_" in the "MXNET_REGISTER_*" macro is an internal API. + +Acknowledgement: This file originates from incubator-tvm +""" diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 8dce5d869254..56f62f3d5377 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -26,6 +26,7 @@ from ..base import c_str_array, c_handle_array from ..base import NDArrayHandle, CachedOpHandle from ..base import check_call +from .. import _global_var def _monitor_callback_wrapper(callback): @@ -57,23 +58,7 @@ def __del__(self): check_call(_LIB.MXNDArrayFree(self.handle)) def __reduce__(self): - return (_ndarray_cls, (None,), self.__getstate__()) - - -_ndarray_cls = None -_np_ndarray_cls = None - - -def _set_ndarray_class(cls): - """Set the symbolic class to be cls""" - global _ndarray_cls - _ndarray_cls = cls - - -def _set_np_ndarray_class(cls): - """Set the symbolic class to be cls""" - global _np_ndarray_cls - _np_ndarray_cls = cls + return (_global_var._ndarray_cls, (None,), self.__getstate__()) def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list): @@ -105,7 +90,7 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list c_str_array([str(s) for s in vals]), ctypes.byref(out_stypes))) - create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls + create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls if original_output is not None: return original_output if num_output.value == 1 and not output_is_list: @@ -170,7 +155,7 @@ def __call__(self, *args, **kwargs): if original_output is not None: return original_output - create_ndarray_fn = _np_ndarray_cls if self.is_np_sym else _ndarray_cls + create_ndarray_fn = _global_var._np_ndarray_cls if self.is_np_sym else _global_var._ndarray_cls if num_output.value == 1: return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle), stype=out_stypes[0]) diff --git a/python/mxnet/_ffi/__init__.py b/python/mxnet/_ffi/__init__.py new file mode 100644 index 000000000000..9fcc4e1f0de7 --- /dev/null +++ b/python/mxnet/_ffi/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Acknowledgement: This file originates from incubator-tvm +""" diff --git a/python/mxnet/_ffi/_ctypes/__init__.py b/python/mxnet/_ffi/_ctypes/__init__.py new file mode 100644 index 000000000000..5072d118d413 --- /dev/null +++ b/python/mxnet/_ffi/_ctypes/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +ctypes specific implementation of FFI +Acknowledgement: This file originates from incubator-tvm +""" diff --git a/python/mxnet/_ffi/_ctypes/function.py b/python/mxnet/_ffi/_ctypes/function.py new file mode 100644 index 000000000000..5b126913b998 --- /dev/null +++ b/python/mxnet/_ffi/_ctypes/function.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# coding: utf-8 +# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import +""" +Function configuration API. +Acknowledgement: This file originates from incubator-tvm +""" +import ctypes +from numbers import Number, Integral + +from ...base import get_last_ffi_error, _LIB +from ..base import c_str +from .types import MXNetValue, TypeCode +from .types import RETURN_SWITCH +from .object import ObjectBase +from ..node_generic import convert_to_node +from ..._ctypes.ndarray import NDArrayBase + +ObjectHandle = ctypes.c_void_p + + +def _make_mxnet_args(args, temp_args): + """Pack arguments into c args mxnet call accept""" + num_args = len(args) + values = (MXNetValue * num_args)() + type_codes = (ctypes.c_int * num_args)() + for i, arg in enumerate(args): + if isinstance(arg, ObjectBase): + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.OBJECT_HANDLE + elif arg is None: + values[i].v_handle = None + type_codes[i] = TypeCode.NULL + elif isinstance(arg, Integral): + values[i].v_int64 = arg + type_codes[i] = TypeCode.INT + elif isinstance(arg, Number): + values[i].v_float64 = arg + type_codes[i] = TypeCode.FLOAT + elif isinstance(arg, str): + values[i].v_str = c_str(arg) + type_codes[i] = TypeCode.STR + elif isinstance(arg, (list, tuple)): + arg = convert_to_node(arg) + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.OBJECT_HANDLE + temp_args.append(arg) + elif isinstance(arg, NDArrayBase): + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.NDARRAYHANDLE + elif isinstance(arg, ctypes.c_void_p): + values[i].v_handle = arg + type_codes[i] = TypeCode.HANDLE + else: + raise TypeError("Don't know how to handle type %s" % type(arg)) + return values, type_codes, num_args + + +class FunctionBase(object): + """Function base.""" + __slots__ = ["handle", "is_global"] + # pylint: disable=no-member + def __init__(self, handle, is_global): + """Initialize the function with handle + + Parameters + ---------- + handle : FunctionHandle + the handle to the underlying function. + + is_global : bool + Whether this is a global function in python + """ + self.handle = handle + self.is_global = is_global + + def __del__(self): + if not self.is_global and _LIB is not None: + if _LIB.MXNetFuncFree(self.handle) != 0: + raise get_last_ffi_error() + + def __call__(self, *args): + """Call the function with positional arguments + + args : list + The positional arguments to the function call. + """ + temp_args = [] + values, tcodes, num_args = _make_mxnet_args(args, temp_args) + ret_val = MXNetValue() + ret_tcode = ctypes.c_int() + if _LIB.MXNetFuncCall( + self.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0: + raise get_last_ffi_error() + _ = temp_args + _ = args + return RETURN_SWITCH[ret_tcode.value](ret_val) + + +_CLASS_OBJECT = None + +def _set_class_object(obj_class): + global _CLASS_OBJECT + _CLASS_OBJECT = obj_class diff --git a/python/mxnet/_ffi/_ctypes/object.py b/python/mxnet/_ffi/_ctypes/object.py new file mode 100644 index 000000000000..85ab415692f6 --- /dev/null +++ b/python/mxnet/_ffi/_ctypes/object.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" +Runtime Object api +Acknowledgement: This file originates from incubator-tvm +""" +import ctypes +from ...base import _LIB, check_call +from . import function +from .types import RETURN_SWITCH, TypeCode + +ObjectHandle = ctypes.c_void_p + + +def _return_object(x): + handle = x.v_handle + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + # Does not support specific cpp node class for now + cls = function._CLASS_OBJECT + # Avoid calling __init__ of cls, instead directly call __new__ + # This allows child class to implement their own __init__ + obj = cls.__new__(cls) + obj.handle = handle + return obj + +RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object + + +class ObjectBase(object): + """Base object for all object types""" + __slots__ = ["handle"] + + def __del__(self): + if _LIB is not None: + check_call(_LIB.MXNetObjectFree(self.handle)) + + # Does not support creation of cpp node class via python class diff --git a/python/mxnet/_ffi/_ctypes/types.py b/python/mxnet/_ffi/_ctypes/types.py new file mode 100644 index 000000000000..265408e5ba93 --- /dev/null +++ b/python/mxnet/_ffi/_ctypes/types.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The C Types used in API. +Acknowledgement: This file originates from incubator-tvm +""" +# pylint: disable=invalid-name +import ctypes +from ...base import NDArrayHandle +from ... import _global_var + + +class TypeCode(object): + """Type code used in API calls""" + INT = 0 + UINT = 1 + FLOAT = 2 + HANDLE = 3 + NULL = 4 + MXNET_TYPE = 5 + MXNET_CONTEXT = 6 + ARRAY_HANDLE = 7 + OBJECT_HANDLE = 8 + MODULE_HANDLE = 9 + FUNC_HANDLE = 10 + STR = 11 + BYTES = 12 + NDARRAY_CONTAINER = 13 + NDARRAYHANDLE = 14 + EXT_BEGIN = 15 + + +class MXNetValue(ctypes.Union): + """MXNetValue in C API""" + _fields_ = [("v_int64", ctypes.c_int64), + ("v_float64", ctypes.c_double), + ("v_handle", ctypes.c_void_p), + ("v_str", ctypes.c_char_p)] + +RETURN_SWITCH = { + TypeCode.INT: lambda x: x.v_int64, + TypeCode.FLOAT: lambda x: x.v_float64, + TypeCode.NULL: lambda x: None, + TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle)) +} diff --git a/python/mxnet/_ffi/_cy3/__init__.py b/python/mxnet/_ffi/_cy3/__init__.py new file mode 100644 index 000000000000..bb5ed444987b --- /dev/null +++ b/python/mxnet/_ffi/_cy3/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""cython3 namespace +Acknowledgement: This file originates from incubator-tvm +""" diff --git a/python/mxnet/_ffi/_cython/base.pxi b/python/mxnet/_ffi/_cython/base.pxi new file mode 100644 index 000000000000..1c393e8e241c --- /dev/null +++ b/python/mxnet/_ffi/_cython/base.pxi @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Acknowledgement: This file originates from incubator-tvm""" + +from libcpp.vector cimport vector +from cpython.version cimport PY_MAJOR_VERSION +from cpython cimport pycapsule +from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t, uint32_t +import ctypes +from ...base import get_last_ffi_error + +cdef enum MXNetTypeCode: + kInt = 0 + kUInt = 1 + kFloat = 2 + kHandle = 3 + kNull = 4 + kMXNetType = 5 + kMXNetContext = 6 + kArrayHandle = 7 + kObjectHandle = 8 + kModuleHandle = 9 + kFuncHandle = 10 + kStr = 11 + kBytes = 12 + kNDArrayContainer = 13 + kNDArrayHandle = 14 + kExtBegin = 15 + +cdef extern from "mxnet/runtime/c_runtime_api.h": + ctypedef struct MXNetValue: + int64_t v_int64 + double v_float64 + void* v_handle + const char* v_str + +ctypedef void* MXNetRetValueHandle +ctypedef void* MXNetFunctionHandle +ctypedef void* ObjectHandle + + +cdef extern from "mxnet/runtime/c_runtime_api.h": + int MXNetFuncCall(MXNetFunctionHandle func, + MXNetValue* arg_values, + int* type_codes, + int num_args, + MXNetValue* ret_val, + int* ret_type_code) + int MXNetFuncFree(MXNetFunctionHandle func) + + +cdef inline py_str(const char* x): + if PY_MAJOR_VERSION < 3: + return x + else: + return x.decode("utf-8") + + +cdef inline c_str(pystr): + """Create ctypes char * from a python string + Parameters + ---------- + string : string type + python string + + Returns + ------- + str : c_char_p + A char pointer that can be passed to C API + """ + return pystr.encode("utf-8") + + +cdef inline CALL(int ret): + if ret != 0: + raise get_last_ffi_error() + + +cdef inline object ctypes_handle(void* chandle): + """Cast C handle to ctypes handle.""" + return ctypes.cast(chandle, ctypes.c_void_p) + + +cdef inline void* c_handle(object handle): + """Cast C types handle to c handle.""" + cdef unsigned long long v_ptr + v_ptr = handle.value + return (v_ptr) diff --git a/python/mxnet/_ffi/_cython/convert.pxi b/python/mxnet/_ffi/_cython/convert.pxi new file mode 100644 index 000000000000..2cbdc48b49a8 --- /dev/null +++ b/python/mxnet/_ffi/_cython/convert.pxi @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Acknowledgement: This file originates from incubator-tvm""" + +from libc.stdint cimport * +from numbers import Integral + +cdef extern from "mxnet/runtime/ffi_helper.h" namespace "mxnet::runtime": + cdef cppclass Object: + pass + + cdef cppclass ObjectPtr[T]: + pass + + cdef cppclass ObjectRef: + const Object* get() const + + cdef cppclass ADT(ObjectRef): + ADT() + + cdef cppclass ADTBuilder: + ADTBuilder() + ADTBuilder(uint32_t tag, uint32_t size) + void EmplaceInit(size_t idx, ObjectRef) + ADT Get() + + cdef cppclass Integer(ObjectRef): + Integer() + Integer(int64_t) + + +cdef inline ADT convert_tuple(tuple src_tuple) except *: + cdef uint32_t size = len(src_tuple) + cdef ADTBuilder builder = ADTBuilder(0, size) + for i in range(size): + builder.EmplaceInit(i, convert_object(src_tuple[i])) + return builder.Get() + + +cdef inline ADT convert_list(list src) except *: + cdef uint32_t size = len(src) + cdef ADTBuilder builder = ADTBuilder(0, size) + for i in range(size): + builder.EmplaceInit(i, convert_object(src[i])) + return builder.Get() + + +cdef inline ObjectRef convert_object(object src_obj) except *: + # We use this branch as a fast check for int. + # The Integral branch is slow, and it only captures numpy.int64, etc. + if isinstance(src_obj, int): + return Integer(src_obj) + elif isinstance(src_obj, tuple): + return convert_tuple(src_obj) + elif isinstance(src_obj, list): + return convert_list(src_obj) + elif isinstance(src_obj, Integral): + return Integer(src_obj) + else: + raise TypeError("Don't know how to convert type %s" % type(src_obj)) diff --git a/python/mxnet/_ffi/_cython/core.pyx b/python/mxnet/_ffi/_cython/core.pyx new file mode 100644 index 000000000000..482f494b6e5e --- /dev/null +++ b/python/mxnet/_ffi/_cython/core.pyx @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Acknowledgement: This file originates from incubator-tvm""" + +include "./base.pxi" +include "./ndarray.pxi" +include "./convert.pxi" +include "./function.pxi" diff --git a/python/mxnet/_ffi/_cython/function.pxi b/python/mxnet/_ffi/_cython/function.pxi new file mode 100644 index 000000000000..2683868cba03 --- /dev/null +++ b/python/mxnet/_ffi/_cython/function.pxi @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Acknowledgement: This file originates from incubator-tvm""" + +import ctypes +import traceback +from ...ndarray._internal import NDArrayBase +from numbers import Number, Integral + + +cdef inline int make_arg(object arg, + MXNetValue* value, + int* tcode, + ObjectRef* temp_objs, + list temp_args) except -1: + """Pack arguments into c args mxnet call accept""" + cdef unsigned long long ptr + + if isinstance(arg, (list, tuple)): + temp_objs[0] = convert_object(arg) + value[0].v_handle = ((temp_objs[0].get())) + tcode[0] = kObjectHandle + elif isinstance(arg, NDArrayBase): + value[0].v_handle = (arg._get_handle()) + tcode[0] = kNDArrayHandle + elif isinstance(arg, (int, long)): + value[0].v_int64 = arg + tcode[0] = kInt + elif isinstance(arg, float): + value[0].v_float64 = arg + tcode[0] = kFloat + elif isinstance(arg, str): + tstr = c_str(arg) + value[0].v_str = tstr + tcode[0] = kStr + temp_args.append(tstr) + elif arg is None: + value[0].v_handle = NULL + tcode[0] = kNull + elif isinstance(arg, Number): + value[0].v_float64 = arg + tcode[0] = kFloat + elif isinstance(arg, ctypes.c_void_p): + value[0].v_handle = c_handle(arg) + tcode[0] = kHandle + else: + raise TypeError("Don't know how to handle type %s" % type(arg)) + return 0 + + +cdef inline object make_ret(MXNetValue value, int tcode): + """convert result to return value.""" + if tcode == kNull: + return None + elif tcode == kInt: + return value.v_int64 + elif tcode == kFloat: + return value.v_float64 + elif tcode == kStr: + return py_str(value.v_str) + elif tcode == kHandle: + return ctypes_handle(value.v_handle) + elif tcode == kNDArrayHandle: + return c_make_array(value.v_handle) + raise ValueError("Unhandled type code %d" % tcode) + + +cdef inline int FuncCall3(void* chandle, + tuple args, + int nargs, + MXNetValue* ret_val, + int* ret_tcode) except -1: + cdef MXNetValue[3] values + cdef int[3] tcodes + cdef ObjectRef[3] temp_objs + nargs = len(args) + temp_args = [] + for i in range(nargs): + make_arg(args[i], &values[i], &tcodes[i], &temp_objs[i], temp_args) + CALL(MXNetFuncCall(chandle, &values[0], &tcodes[0], + nargs, ret_val, ret_tcode)) + return 0 + + +cdef inline int FuncCall(void* chandle, + tuple args, + MXNetValue* ret_val, + int* ret_tcode) except -1: + cdef int nargs + nargs = len(args) + if nargs <= 3: + FuncCall3(chandle, args, nargs, ret_val, ret_tcode) + return 0 + + cdef vector[MXNetValue] values + cdef vector[int] tcodes + cdef vector[ObjectRef] temp_objs + values.resize(nargs) + tcodes.resize(nargs) + temp_objs.resize(nargs) + + temp_args = [] + for i in range(nargs): + make_arg(args[i], &values[i], &tcodes[i], &temp_objs[i], temp_args) + CALL(MXNetFuncCall(chandle, &values[0], &tcodes[0], + nargs, ret_val, ret_tcode)) + return 0 + + +cdef class FunctionBase: + cdef MXNetFunctionHandle chandle + cdef int is_global + + cdef inline _set_handle(self, handle): + if handle is None: + self.chandle = NULL + else: + self.chandle = c_handle(handle) + + property is_global: + def __get__(self): + return self.c_is_global != 0 + + def __set__(self, value): + self.c_is_global = value + + property handle: + def __get__(self): + if self.chandle == NULL: + return None + else: + return ctypes.cast(self.chandle, ctypes.c_void_p) + def __set__(self, value): + self._set_handle(value) + + def __init__(self, handle, is_global): + self._set_handle(handle) + self.c_is_global = is_global + + def __dealloc__(self): + if self.is_global == 0: + CALL(MXNetFuncFree(self.chandle)) + + def __call__(self, *args): + cdef MXNetValue ret_val + cdef int ret_tcode + FuncCall(self.chandle, args, &ret_val, &ret_tcode) + return make_ret(ret_val, ret_tcode) diff --git a/python/mxnet/_ffi/_cython/ndarray.pxi b/python/mxnet/_ffi/_cython/ndarray.pxi new file mode 100644 index 000000000000..9bca64d8dd1a --- /dev/null +++ b/python/mxnet/_ffi/_cython/ndarray.pxi @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Acknowledgement: This file originates from incubator-tvm""" + +import ctypes +from ... import _global_var + +cdef c_make_array(void* handle): + return _global_var._np_ndarray_cls(handle=handle) diff --git a/python/mxnet/_ffi/base.py b/python/mxnet/_ffi/base.py new file mode 100644 index 000000000000..be68d20c53e2 --- /dev/null +++ b/python/mxnet/_ffi/base.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# coding: utf-8 +# pylint: disable=invalid-name +"""Base library for MXNet FFI. +Acknowledgement: This file originates from incubator-tvm +""" +import sys +import ctypes +import numpy as np + +string_types = (str,) +integer_types = (int, np.int32) +numeric_types = integer_types + (float, np.float32) +# this function is needed for python3 +# to convert ctypes.char_p .value back to python str +if sys.platform == "win32": + encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP()) + py_str = lambda x: x.decode(encoding) +else: + py_str = lambda x: x.decode('utf-8') + +#---------------------------- +# helper function in ctypes. +#---------------------------- +def c_str(string): + """Create ctypes char * from a python string + Parameters + ---------- + string : string type + python string + + Returns + ------- + str : c_char_p + A char pointer that can be passed to C API + """ + return ctypes.c_char_p(string.encode('utf-8')) + + +def c_array(ctype, values): + """Create ctypes array from a python array + + Parameters + ---------- + ctype : ctypes data type + data type of the array we want to convert to + + values : tuple or list + data content + + Returns + ------- + out : ctypes array + Created ctypes array + """ + return (ctype * len(values))(*values) + + +def decorate(func, fwrapped): + """A wrapper call of decorator package, differs to call time + + Parameters + ---------- + func : function + The original function + + fwrapped : function + The wrapped function + """ + import decorator + return decorator.decorate(func, fwrapped) diff --git a/python/mxnet/_ffi/function.py b/python/mxnet/_ffi/function.py new file mode 100644 index 000000000000..b1f888fff030 --- /dev/null +++ b/python/mxnet/_ffi/function.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, unused-import +""" +Function namespace. +Acknowledgement: This file originates from incubator-tvm +""" +import os +import sys +import ctypes +from ..base import _LIB, check_call +from .base import py_str, c_str + +try: + if int(os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from ._ctypes.function import FunctionBase as _FunctionBase + # To set RETURN_SWITCH for OBJECT_HANDLE + from . import object + else: + from ._cy3.core import FunctionBase as _FunctionBase +except ImportError: + if int(os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from ._ctypes.function import FunctionBase as _FunctionBase + # To set RETURN_SWITCH for OBJECT_HANDLE + from . import object + +FunctionHandle = ctypes.c_void_p + + +class Function(_FunctionBase): + """The PackedFunc object used in TVM. + + Function plays an key role to bridge front and backend in TVM. + Function provide a type-erased interface, you can call function with positional arguments. + + The compiled module returns Function. + TVM backend also registers and exposes its API as Functions. + For example, the developer function exposed in tvm.ir_pass are actually + C++ functions that are registered as PackedFunc + + The following are list of common usage scenario of tvm.Function. + + - Automatic exposure of C++ API into python + - To call PackedFunc from python side + - To call python callbacks to inspect results in generated code + - Bring python hook into C++ backend + + See Also + -------- + tvm.register_func: How to register global function. + tvm.get_global_func: How to get global function. + """ + + +def get_global_func(name, allow_missing=False): + """Get a global function by name + + Parameters + ---------- + name : str + The name of the global function + + allow_missing : bool + Whether allow missing function or raise an error. + + Returns + ------- + func : tvm.Function + The function to be returned, None if function is missing. + """ + handle = FunctionHandle() + check_call(_LIB.MXNetFuncGetGlobal(c_str(name), ctypes.byref(handle))) + if handle.value: + return Function(handle, False) + + if allow_missing: + return None + + raise ValueError("Cannot find global function %s" % name) + + +def list_global_func_names(): + """Get list of global functions registered. + + Returns + ------- + names : list + List of global functions names. + """ + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXNetFuncListGlobalNames(ctypes.byref(size), + ctypes.byref(plist))) + fnames = [] + for i in range(size.value): + fnames.append(py_str(plist[i])) + return fnames + + +def _get_api(f): + flocal = f + flocal.is_global = True + return flocal + + +def _init_api(namespace, target_module_name=None): + """Initialize api for a given module name + + namespace : str + The namespace of the source registry + + target_module_name : str + The target module name if different from namespace + """ + target_module_name = ( + target_module_name if target_module_name else namespace) + if namespace.startswith("mxnet."): + _init_api_prefix(target_module_name, namespace[6:]) + else: + _init_api_prefix(target_module_name, namespace) + + +def _init_api_prefix(module_name, prefix): + module = sys.modules[module_name] + + for name in list_global_func_names(): + if prefix == "api": + fname = name + if name.startswith("_"): + target_module = sys.modules["mxnet._api_internal"] + else: + target_module = module + else: + if not name.startswith(prefix): + continue + fname = name[len(prefix)+1:] + target_module = module + + if fname.find(".") != -1: + continue + f = get_global_func(name) + ff = _get_api(f) + ff.__name__ = fname + ff.__doc__ = ("MXNet PackedFunc %s. " % fname) + setattr(target_module, ff.__name__, ff) diff --git a/python/mxnet/_ffi/node_generic.py b/python/mxnet/_ffi/node_generic.py new file mode 100644 index 000000000000..c7f332390ce7 --- /dev/null +++ b/python/mxnet/_ffi/node_generic.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common implementation of Node generic related logic +Acknowledgement: This file originates from incubator-tvm""" +# pylint: disable=unused-import +from numbers import Number, Integral +from .. import _api_internal + +def _scalar_type_inference(value): + if hasattr(value, 'dtype'): + dtype = str(value.dtype) + elif isinstance(value, bool): + dtype = 'bool' + elif isinstance(value, float): + # We intentionally convert the float to float32 since it's more common in DL. + dtype = 'float32' + elif isinstance(value, int): + # We intentionally convert the python int to int32 since it's more common in DL. + dtype = 'int32' + else: + raise NotImplementedError('Cannot automatically inference the type.' + ' value={}'.format(value)) + return dtype + + +def convert_to_node(value): + """Convert a python value to corresponding node type. + + Parameters + ---------- + value : str + The value to be inspected. + + Returns + ------- + node : Node + The corresponding node value. + """ + if isinstance(value, Integral): + return _api_internal._Integer(value) + elif isinstance(value, (list, tuple)): + value = [convert_to_node(x) for x in value] + return _api_internal._ADT(*value) + raise ValueError("don't know how to convert type %s to node" % type(value)) + + +def const(value, dtype=None): + """Construct a constant value for a given type. + + Parameters + ---------- + value : int or float + The input value + + dtype : str or None, optional + The data type. + + Returns + ------- + expr : Expr + Constant expression corresponds to the value. + """ + if dtype is None: + dtype = _scalar_type_inference(value) + return _api_internal._const(value, dtype) diff --git a/python/mxnet/_ffi/object.py b/python/mxnet/_ffi/object.py new file mode 100644 index 000000000000..e0a4aa600f25 --- /dev/null +++ b/python/mxnet/_ffi/object.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Runtime Object API +Acknowledgement: This file originates from incubator-tvm""" +from ._ctypes.function import _set_class_object +from ._ctypes.object import ObjectBase as _ObjectBase + +class Object(_ObjectBase): + """Base class for all mxnet's runtime objects.""" + +_set_class_object(Object) diff --git a/python/mxnet/_ffi/runtime_ctypes.py b/python/mxnet/_ffi/runtime_ctypes.py new file mode 100644 index 000000000000..05e3cf3f3152 --- /dev/null +++ b/python/mxnet/_ffi/runtime_ctypes.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common runtime ctypes. +Acknowledgement: This file originates from incubator-tvm +""" +# pylint: disable=invalid-name +import ctypes + + +class TVMByteArray(ctypes.Structure): + """Temp data structure for byte array.""" + _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), + ("size", ctypes.c_size_t)] diff --git a/python/mxnet/_global_var.py b/python/mxnet/_global_var.py new file mode 100644 index 000000000000..8d2de25b735d --- /dev/null +++ b/python/mxnet/_global_var.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""global variables for ffi""" + +_ndarray_cls = None +_np_ndarray_cls = None + + +def _set_ndarray_class(cls): + global _ndarray_cls + _ndarray_cls = cls + + +def _set_np_ndarray_class(cls): + global _np_ndarray_cls + _np_ndarray_cls = cls diff --git a/python/mxnet/api.py b/python/mxnet/api.py new file mode 100644 index 000000000000..56eec942e820 --- /dev/null +++ b/python/mxnet/api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Functions defined in MXNet. +Acknowledgement: This file originates from incubator-tvm""" + +from ._ffi.function import _init_api + +_init_api("mxnet.api") diff --git a/python/mxnet/cython/__init__.py b/python/mxnet/cython/__init__.py new file mode 100644 index 000000000000..135c600fca77 --- /dev/null +++ b/python/mxnet/cython/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""cython""" diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index ade6cf42585d..9e0504d306de 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -22,6 +22,7 @@ import numpy as np from ..ndarray_doc import _build_doc from libc.stdint cimport uint32_t, int64_t from ..base import _LIB +from .. import _global_var include "./base.pyi" @@ -36,7 +37,10 @@ cdef class NDArrayBase: if handle is None: self.chandle = NULL else: - ptr = handle.value + if isinstance(handle, (int, long)): + ptr = handle + else: + ptr = handle.value self.chandle = (ptr) property handle: @@ -59,20 +63,11 @@ cdef class NDArrayBase: CALL(MXNDArrayFree(self.chandle)) def __reduce__(self): - return (_ndarray_cls, (None,), self.__getstate__()) - - -_ndarray_cls = None -_np_ndarray_cls = None - -def _set_ndarray_class(cls): - global _ndarray_cls - _ndarray_cls = cls + return (_global_var._ndarray_cls, (None,), self.__getstate__()) + def _get_handle(self): + return self.chandle -def _set_np_ndarray_class(cls): - global _np_ndarray_cls - _np_ndarray_cls = cls def _monitor_callback_wrapper(callback): def callback_handle(name, opr_name, arr, _): @@ -81,7 +76,7 @@ def _monitor_callback_wrapper(callback): cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0): """Create a new array given handle""" - create_array_fn = _np_ndarray_cls if is_np_array else _ndarray_cls + create_array_fn = _global_var._np_ndarray_cls if is_np_array else _global_var._ndarray_cls return create_array_fn(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) diff --git a/python/mxnet/ndarray/_internal.py b/python/mxnet/ndarray/_internal.py index 14142c407fde..716ef4dedaef 100644 --- a/python/mxnet/ndarray/_internal.py +++ b/python/mxnet/ndarray/_internal.py @@ -23,15 +23,18 @@ try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: from .._ctypes.ndarray import NDArrayBase, CachedOp - from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class + from .._ctypes.ndarray import _imperative_invoke + from .._global_var import _set_ndarray_class, _set_np_ndarray_class else: from .._cy3.ndarray import NDArrayBase, CachedOp - from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class + from .._cy3.ndarray import _imperative_invoke + from .._global_var import _set_ndarray_class, _set_np_ndarray_class except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") from .._ctypes.ndarray import NDArrayBase, CachedOp - from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class + from .._ctypes.ndarray import _imperative_invoke + from .._global_var import _set_ndarray_class, _set_np_ndarray_class from ..base import _Null try: diff --git a/python/mxnet/ndarray/numpy/_api_internal.py b/python/mxnet/ndarray/numpy/_api_internal.py new file mode 100644 index 000000000000..c56b85b8322f --- /dev/null +++ b/python/mxnet/ndarray/numpy/_api_internal.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for numpy internal api.""" + +from ..._ffi.function import _init_api + +__all__ = [] + +_init_api("_npi", "mxnet.ndarray.numpy._api_internal") diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index feb2caa67b2f..242931862988 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -25,6 +25,7 @@ from ...util import wrap_np_unary_func, wrap_np_binary_func from ...context import current_context from . import _internal as _npi +from . import _api_internal from ..ndarray import NDArray @@ -83,7 +84,7 @@ def shape(a): @set_module('mxnet.ndarray.numpy') -def zeros(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name +def zeros(shape, dtype=None, order='C', ctx=None): # pylint: disable=redefined-outer-name """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data in row-major (C-style). @@ -110,10 +111,15 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=red """ if order != 'C': raise NotImplementedError + # If the following code (4 lines) regarding ctx is removed + # np.zeros((3, 4)) can be as fast as 4.96 us if ctx is None: - ctx = current_context() - dtype = _np.float32 if dtype is None else dtype - return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype) + ctx = str(current_context()) + else: + ctx = str(ctx) + if dtype is not None and not isinstance(dtype, str): + dtype = _np.dtype(dtype).name + return _api_internal.zeros(shape, dtype, ctx) @set_module('mxnet.ndarray.numpy') @@ -1541,21 +1547,7 @@ def tensordot(a, b, axes=2): [ 4796., 5162.], [ 4928., 5306.]]) """ - if _np.isscalar(axes): - return _npi.tensordot_int_axes(a, b, axes) - - if len(axes) != 2: - raise ValueError('Axes must consist of two arrays.') - a_axes_summed, b_axes_summed = axes - if _np.isscalar(a_axes_summed): - a_axes_summed = (a_axes_summed,) - if _np.isscalar(b_axes_summed): - b_axes_summed = (b_axes_summed,) - - if len(a_axes_summed) != len(b_axes_summed): - raise ValueError('Axes length mismatch') - - return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) + return _api_internal.tensordot(a, b, axes) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy/_register.py index fbf53988eb0b..394ec5c75dd4 100644 --- a/python/mxnet/numpy/_register.py +++ b/python/mxnet/numpy/_register.py @@ -21,6 +21,5 @@ from ..base import _init_np_op_module from ..ndarray.register import _make_ndarray_function - _init_np_op_module(root_module_name='mxnet', np_module_name='numpy', mx_module_name=None, make_op_func=_make_ndarray_function) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 382dbc0ea472..8e129ce35429 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -2312,7 +2312,7 @@ def shape(a): @set_module('mxnet.numpy') -def zeros(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name +def zeros(shape, dtype=None, order='C', ctx=None): # pylint: disable=redefined-outer-name """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data in row-major (C-style). diff --git a/python/setup.py b/python/setup.py index ab1ff6950277..dcd84cef1ea1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,6 +94,20 @@ def config_cython(): libraries=libraries, extra_link_args=extra_link_args, language="c++")) + + path = "mxnet/_ffi/_cython" + for fn in os.listdir(path): + if not fn.endswith(".pyx"): + continue + ret.append(Extension( + "mxnet._ffi.%s.%s" % (subdir, fn[:-4]), + ["mxnet/_ffi/_cython/%s" % fn], + include_dirs=["../include/", "../3rdparty/tvm/nnvm/include"], + library_dirs=library_dirs, + libraries=libraries, + extra_compile_args=["-std=c++11"], + extra_link_args=extra_link_args, + language="c++")) # If `force=True` is not used and you cythonize the modules for python2 and python3 # successively, you need to delete `mxnet/cython/ndarray.cpp` after the first cythonize. diff --git a/src/api/_api_internal/_api_internal.cc b/src/api/_api_internal/_api_internal.cc new file mode 100644 index 000000000000..586dce82f383 --- /dev/null +++ b/src/api/_api_internal/_api_internal.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file _api_internal.cc + * \brief Internal functions exposed to python for FFI use only + */ +// Acknowledgement: This file originates from incubator-tvm +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { + +MXNET_REGISTER_GLOBAL("_Integer") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + if (args[0].type_code() == kDLInt) { + *ret = Integer(args[0].operator int64_t()); + } else { + LOG(FATAL) << "only accept int"; + } +}); + +MXNET_REGISTER_GLOBAL("_ADT") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kNull) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.emplace_back(nullptr); + } + } + *ret = ADT(0, data.begin(), data.end()); +}); + +MXNET_REGISTER_API("_nop") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/np_init_op.cc b/src/api/operator/numpy/np_init_op.cc new file mode 100644 index 000000000000..746985c6e9f3 --- /dev/null +++ b/src/api/operator/numpy/np_init_op.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_init_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_init_op.cc + */ +#include "../utils.h" +#include "../../../operator/tensor/init_op.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.zeros") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_zeros"); + nnvm::NodeAttrs attrs; + op::InitOpParam param; + if (args[0].type_code() == kDLInt) { + param.shape = TShape(1, args[0].operator int64_t()); + } else { + param.shape = TShape(args[0].operator ObjectRef()); + } + if (args[1].type_code() == kNull) { + param.dtype = mshadow::kFloat32; + } else { + param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); + } + attrs.parsed = std::move(param); + attrs.op = op; + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/np_tensordot_op.cc b/src/api/operator/numpy/np_tensordot_op.cc new file mode 100644 index 000000000000..ade2a0314d01 --- /dev/null +++ b/src/api/operator/numpy/np_tensordot_op.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_tensordot_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_tensordot_op.cc + */ +#include "../utils.h" +#include "../../../operator/numpy/np_tensordot_op-inl.h" + +namespace mxnet { + +inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args, + runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensordot_int_axes"); + op::TensordotIntAxesParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + param.axes = args[2].operator int(); + // we directly copy TensordotIntAxesParam, which is trivially-copyable + attrs.parsed = param; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +} + +inline static void _npi_tensordot(runtime::MXNetArgs args, + runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensordot"); + op::TensordotParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + ADT adt = Downcast(args[2].operator ObjectRef()); + if (const IntegerObj* lop = adt[0].as()) { + param.a_axes_summed = Tuple(1, lop->value); + param.b_axes_summed = Tuple(1, Downcast(adt[1])->value); + } else { + param.a_axes_summed = Tuple(adt[0]); + param.b_axes_summed = Tuple(adt[1]); + } + attrs.parsed = std::move(param); + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +} + +MXNET_REGISTER_API("_npi.tensordot") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + if (args[2].type_code() == kDLInt) { + _npi_tensordot_int_axes(args, ret); + } else { + _npi_tensordot(args, ret); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/utils.cc b/src/api/operator/utils.cc new file mode 100644 index 000000000000..d8cd4c922603 --- /dev/null +++ b/src/api/operator/utils.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utils.cc + * \brief Utility functions for operator invoke + */ +#include "utils.h" + +namespace mxnet { + +void SetInOut(std::vector* ndinputs, + std::vector* ndoutputs, + int num_inputs, + NDArray** inputs, + int *num_outputs, + int infered_num_outputs, + int num_visible_outputs, + NDArray** out_array) { + ndinputs->clear(); + ndinputs->reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + NDArray* inp = reinterpret_cast(inputs[i]); + if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { + CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) << + "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than " + "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; + } + ndinputs->emplace_back(inp); + } + + ndoutputs->clear(); + ndoutputs->reserve(infered_num_outputs); + if (out_array == nullptr) { + for (int i = 0; i < infered_num_outputs; ++i) { + ndoutputs->emplace_back(new NDArray()); + } + *num_outputs = num_visible_outputs; + } else { + CHECK(*num_outputs == infered_num_outputs || *num_outputs == num_visible_outputs) + << "Operator expects " << infered_num_outputs << " (all) or " + << num_visible_outputs << " (visible only) outputs, but got " + << *num_outputs << " instead."; + for (int i = 0; i < *num_outputs; ++i) { + ndoutputs->emplace_back(out_array[i]); + } + for (int i = *num_outputs; i < infered_num_outputs; ++i) { + ndoutputs->emplace_back(new NDArray()); + } + } +} + +} // namespace mxnet diff --git a/src/api/operator/utils.h b/src/api/operator/utils.h new file mode 100644 index 000000000000..7a31e4537780 --- /dev/null +++ b/src/api/operator/utils.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utils.h + * \brief Utility functions for operator invoke + */ +#ifndef MXNET_API_OPERATOR_UTILS_H_ +#define MXNET_API_OPERATOR_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "../../imperative/imperative_utils.h" + +namespace mxnet { + +void SetInOut(std::vector* ndinputs, + std::vector* ndoutputs, + int num_inputs, + NDArray** inputs, + int *num_outputs, + int infered_num_outputs, + int num_visible_outputs, + NDArray** out_array); + +template +std::vector Invoke(const nnvm::Op* op, + nnvm::NodeAttrs* attrs, + int num_inputs, + NDArray** inputs, + int* num_outputs, + NDArray** outputs) { + int infered_num_outputs; + int num_visible_outputs; + imperative::SetNumOutputs(op, *attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); + + std::vector ndinputs, ndoutputs; + SetInOut(&ndinputs, &ndoutputs, num_inputs, inputs, + num_outputs, infered_num_outputs, num_visible_outputs, outputs); + + auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs); + if (Imperative::Get()->is_recording()) { + ::dmlc::get(attrs->parsed).SetAttrDict(&(attrs->dict)); + Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state); + } + for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i]; + return ndoutputs; +} + +} // namespace mxnet + +#endif // MXNET_API_OPERATOR_UTILS_H_ diff --git a/src/ir/expr.cc b/src/ir/expr.cc new file mode 100644 index 000000000000..75d76edfff5d --- /dev/null +++ b/src/ir/expr.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file expr.cc + * \brief The expression AST nodes for the common IR infra. + */ +// Acknowledgement: This file originates from incubator-tvm + +#include + +namespace mxnet { + +IntImm::IntImm(MXNetDataType dtype, int64_t value) { + CHECK(dtype.is_scalar()) + << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm can only take scalar."; + if (dtype.is_uint()) { + CHECK_GE(value, 0U); + } + runtime::ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +FloatImm::FloatImm(MXNetDataType dtype, double value) { + CHECK_EQ(dtype.lanes(), 1) + << "ValueError: FloatImm can only take scalar."; + runtime::ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +} // namespace mxnet diff --git a/src/lang/expr.cc b/src/lang/expr.cc new file mode 100644 index 000000000000..a2e44cedc574 --- /dev/null +++ b/src/lang/expr.cc @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file expr.cc + */ +// Acknowledgement: This file originates from incubator-tvm + +#include +#include + +namespace mxnet { + +MXNET_REGISTER_OBJECT_TYPE(ArrayNode); + +} // namespace mxnet diff --git a/src/lang/ir.cc b/src/lang/ir.cc new file mode 100644 index 000000000000..28a6af18734a --- /dev/null +++ b/src/lang/ir.cc @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ir.cc + */ +// Acknowledgement: This file originates from incubator-tvm + +#include +#include + +namespace mxnet { + +MXNET_REGISTER_OBJECT_TYPE(IntImmNode); +MXNET_REGISTER_OBJECT_TYPE(FloatImmNode); + +} // namespace mxnet diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h index 1c62527db43f..d025f1558535 100644 --- a/src/operator/numpy/np_tensordot_op-inl.h +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -25,6 +25,7 @@ #define MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_ #include +#include #include "../tensor/matrix_op-inl.h" namespace mxnet { @@ -38,6 +39,13 @@ struct TensordotParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(a_axes_summed); DMLC_DECLARE_FIELD(b_axes_summed); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream a_axes_summed_s, b_axes_summed_s; + a_axes_summed_s << a_axes_summed; + b_axes_summed_s << b_axes_summed; + (*dict)["a_axes_summed"] = a_axes_summed_s.str(); + (*dict)["b_axes_summed"] = b_axes_summed_s.str(); + } }; /** @@ -553,6 +561,11 @@ struct TensordotIntAxesParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(TensordotIntAxesParam) { DMLC_DECLARE_FIELD(axes); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axes_s; + axes_s << axes; + (*dict)["axes"] = axes_s.str(); + } }; /** diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 3ba4f2c0f5df..8e22bb743fcd 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -60,6 +60,15 @@ struct InitOpParam : public dmlc::Parameter { MXNET_ADD_ALL_TYPES_WITH_BOOL .describe("Target data type."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream shape_s, dtype_s; + shape_s << shape; + dtype_s << dtype; + (*dict)["shape"] = shape_s.str(); + (*dict)["dtype"] = dtype_s.str(); + // We do not set ctx, because ctx has been set in dict instead of InitOpParam. + // Setting ctx here results in an error. + } }; struct InitOpWithoutDTypeParam : public dmlc::Parameter { diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc new file mode 100644 index 000000000000..a78a5cdf5971 --- /dev/null +++ b/src/runtime/c_runtime_api.cc @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file c_runtime_api.cc + * \brief Device specific implementations + */ +// Acknowledgement: This file originates from incubator-tvm + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../c_api/c_api_common.h" + +using namespace mxnet::runtime; + +struct MXNetRuntimeEntry { + std::string ret_str; + std::string last_error; + MXNetByteArray ret_bytes; +}; + +typedef dmlc::ThreadLocalStore MXNetAPIRuntimeStore; + +int MXNetFuncFree(MXNetFunctionHandle func) { + API_BEGIN(); + delete static_cast(func); + API_END(); +} + +int MXNetFuncCall(MXNetFunctionHandle func, + MXNetValue* args, + int* arg_type_codes, + int num_args, + MXNetValue* ret_val, + int* ret_type_code) { + API_BEGIN(); + MXNetRetValue rv; + (*static_cast(func)).CallPacked( + MXNetArgs(args, arg_type_codes, num_args), &rv); + // handle return string. + if (rv.type_code() == kStr || + rv.type_code() == kBytes) { + MXNetRuntimeEntry* e = MXNetAPIRuntimeStore::Get(); + e->ret_str = *rv.ptr(); + if (rv.type_code() == kBytes) { + e->ret_bytes.data = e->ret_str.c_str(); + e->ret_bytes.size = e->ret_str.length(); + *ret_type_code = kBytes; + ret_val->v_handle = &(e->ret_bytes); + } else { + *ret_type_code = kStr; + ret_val->v_str = e->ret_str.c_str(); + } + } else { + rv.MoveToCHost(ret_val, ret_type_code); + } + API_END(); +} diff --git a/src/runtime/object.cc b/src/runtime/object.cc new file mode 100644 index 000000000000..76d8b9776d8a --- /dev/null +++ b/src/runtime/object.cc @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file object.cc + * \brief Object type management system. + */ +// Acknowledgement: This file originates from incubator-tvm + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../c_api/c_api_common.h" +#include "./object_internal.h" + +namespace mxnet { +namespace runtime { + +/*! \brief Type information */ +struct TypeInfo { + /*! \brief The current index. */ + uint32_t index{0}; + /*! \brief Index of the parent in the type hierachy */ + uint32_t parent_index{0}; + // NOTE: the indices in [index, index + num_reserved_slots) are + // reserved for the child-class of this type. + /*! \brief Total number of slots reserved for the type and its children. */ + uint32_t num_slots{0}; + /*! \brief number of allocated child slots. */ + uint32_t allocated_slots{0}; + /*! \brief Whether child can overflow. */ + bool child_slots_can_overflow{true}; + /*! \brief name of the type. */ + std::string name; + /*! \brief hash of the name */ + size_t name_hash{0}; +}; + +/*! + * \brief Type context that manages the type hierachy information. + */ +class TypeContext { + public: + // NOTE: this is a relatively slow path for child checking + // Most types are already checked by the fast-path via reserved slot checking. + bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) { + // invariance: child's type index is always bigger than its parent. + if (child_tindex < parent_tindex) return false; + if (child_tindex == parent_tindex) return true; + { + std::lock_guard lock(mutex_); + CHECK_LT(child_tindex, type_table_.size()); + while (child_tindex > parent_tindex) { + child_tindex = type_table_[child_tindex].parent_index; + } + } + return child_tindex == parent_tindex; + } + + uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t num_child_slots, + bool child_slots_can_overflow) { + std::lock_guard lock(mutex_); + auto it = type_key2index_.find(skey); + if (it != type_key2index_.end()) { + return it->second; + } + // try to allocate from parent's type table. + CHECK_LT(parent_tindex, type_table_.size()); + TypeInfo& pinfo = type_table_[parent_tindex]; + CHECK_EQ(pinfo.index, parent_tindex); + + // if parent cannot overflow, then this class cannot. + if (!pinfo.child_slots_can_overflow) { + child_slots_can_overflow = false; + } + + // total number of slots include the type itself. + uint32_t num_slots = num_child_slots + 1; + uint32_t allocated_tindex; + + if (static_tindex != TypeIndex::kDynamic) { + // statically assigned type + allocated_tindex = static_tindex; + CHECK_LT(static_tindex, type_table_.size()); + CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) + << "Conflicting static index " << static_tindex + << " between " << type_table_[allocated_tindex].name + << " and " + << skey; + } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { + // allocate the slot from parent's reserved pool + allocated_tindex = parent_tindex + pinfo.allocated_slots; + // update parent's state + pinfo.allocated_slots += num_slots; + } else { + CHECK(pinfo.child_slots_can_overflow) + << "Reach maximum number of sub-classes for " << pinfo.name; + // allocate new entries. + allocated_tindex = type_counter_; + type_counter_ += num_slots; + CHECK_LE(type_table_.size(), allocated_tindex); + type_table_.resize(allocated_tindex + 1, TypeInfo()); + } + CHECK_GT(allocated_tindex, parent_tindex); + // initialize the slot. + type_table_[allocated_tindex].index = allocated_tindex; + type_table_[allocated_tindex].parent_index = parent_tindex; + type_table_[allocated_tindex].num_slots = num_slots; + type_table_[allocated_tindex].allocated_slots = 1; + type_table_[allocated_tindex].child_slots_can_overflow = + child_slots_can_overflow; + type_table_[allocated_tindex].name = skey; + type_table_[allocated_tindex].name_hash = std::hash()(skey); + // update the key2index mapping. + type_key2index_[skey] = allocated_tindex; + return allocated_tindex; + } + + std::string TypeIndex2Key(uint32_t tindex) { + std::lock_guard lock(mutex_); + CHECK(tindex < type_table_.size() && + type_table_[tindex].allocated_slots != 0) + << "Unknown type index " << tindex; + return type_table_[tindex].name; + } + + size_t TypeIndex2KeyHash(uint32_t tindex) { + std::lock_guard lock(mutex_); + CHECK(tindex < type_table_.size() && + type_table_[tindex].allocated_slots != 0) + << "Unknown type index " << tindex; + return type_table_[tindex].name_hash; + } + + uint32_t TypeKey2Index(const std::string& skey) { + auto it = type_key2index_.find(skey); + CHECK(it != type_key2index_.end()) + << "Cannot find type " << skey; + return it->second; + } + + static TypeContext* Global() { + static TypeContext inst; + return &inst; + } + + private: + TypeContext() { + type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo()); + } + // mutex to avoid registration from multiple threads. + std::mutex mutex_; + std::atomic type_counter_{TypeIndex::kStaticIndexEnd}; + std::vector type_table_; + std::unordered_map type_key2index_; +}; + +uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t num_child_slots, + bool child_slots_can_overflow) { + return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( + key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); +} + +bool Object::DerivedFrom(uint32_t parent_tindex) const { + return TypeContext::Global()->DerivedFrom( + this->type_index_, parent_tindex); +} + +std::string Object::TypeIndex2Key(uint32_t tindex) { + return TypeContext::Global()->TypeIndex2Key(tindex); +} + +size_t Object::TypeIndex2KeyHash(uint32_t tindex) { + return TypeContext::Global()->TypeIndex2KeyHash(tindex); +} + +uint32_t Object::TypeKey2Index(const std::string& key) { + return TypeContext::Global()->TypeKey2Index(key); +} + +} // namespace runtime +} // namespace mxnet + +int MXNetObjectFree(MXNetObjectHandle obj) { + API_BEGIN(); + mxnet::runtime::ObjectInternal::ObjectFree(obj); + API_END(); +} diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h new file mode 100644 index 000000000000..002468456741 --- /dev/null +++ b/src/runtime/object_internal.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file object_internal.h + * \brief Expose a few functions for CFFI purposes. + * This file is not intended to be used + */ +// Acknowledgement: This file originates from incubator-tvm +#ifndef MXNET_RUNTIME_OBJECT_INTERNAL_H_ +#define MXNET_RUNTIME_OBJECT_INTERNAL_H_ + +#include +#include +#include + +namespace mxnet { +namespace runtime { + +/*! + * \brief Internal object namespace to expose + * certain util functions for FFI. + */ +class ObjectInternal { + public: + /*! + * \brief Free an object handle. + */ + static void ObjectFree(MXNetObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->DecRef(); + } + } +}; + +} // namespace runtime +} // namespace mxnet +#endif // MXNET_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc new file mode 100644 index 000000000000..276c1ba73d18 --- /dev/null +++ b/src/runtime/registry.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file registry.cc + * \brief The global registry of packed function. + */ +// Acknowledgement: This file originates from incubator-tvm +#include +#include +#include +#include +#include +#include +#include +#include "../c_api/c_api_common.h" + +namespace mxnet { +namespace runtime { + +struct Registry::Manager { + // map storing the functions. + // We delibrately used raw pointer + // This is because PackedFunc can contain callbacks into the host languge(python) + // and the resource can become invalid because of indeterminstic order of destruction. + // The resources will only be recycled during program exit. + std::unordered_map fmap; + std::mutex mutex; + + // vtable for extension type is not suported for now + Manager() {} + + static Manager* Global() { + // We deliberately leak the Manager instance, to avoid leak sanitizers + // complaining about the entries in Manager::fmap being leaked at program + // exit. + static Manager* inst = new Manager(); + return inst; + } +}; + +Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) + func_ = f; + return *this; +} + +Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) + Manager* m = Manager::Global(); + std::lock_guard lock(m->mutex); + auto it = m->fmap.find(name); + if (it == m->fmap.end()) { + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; + } else { + CHECK(override) + << "Global PackedFunc " << name << " is already registered"; + return *it->second; + } +} + +bool Registry::Remove(const std::string& name) { + Manager* m = Manager::Global(); + std::lock_guard lock(m->mutex); + auto it = m->fmap.find(name); + if (it == m->fmap.end()) return false; + m->fmap.erase(it); + return true; +} + +const PackedFunc* Registry::Get(const std::string& name) { + Manager* m = Manager::Global(); + std::lock_guard lock(m->mutex); + auto it = m->fmap.find(name); + if (it == m->fmap.end()) return nullptr; + return &(it->second->func_); +} + +std::vector Registry::ListNames() { + Manager* m = Manager::Global(); + std::lock_guard lock(m->mutex); + std::vector keys; + keys.reserve(m->fmap.size()); + for (const auto &kv : m->fmap) { + keys.push_back(kv.first); + } + return keys; +} + +} // namespace runtime +} // namespace mxnet + +/*! \brief entry to to easily hold returning information */ +struct MXNetFuncThreadLocalEntry { + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! \brief result holder for returning string pointers */ + std::vector ret_vec_charp; +}; + +/*! \brief Thread local store that can be used to hold return values. */ +typedef dmlc::ThreadLocalStore MXNetFuncThreadLocalStore; + +int MXNetFuncGetGlobal(const char* name, MXNetFunctionHandle* out) { + API_BEGIN(); + const mxnet::runtime::PackedFunc* fp = + mxnet::runtime::Registry::Get(name); + if (fp != nullptr) { + *out = new mxnet::runtime::PackedFunc(*fp); // NOLINT(*) + } else { + *out = nullptr; + } + API_END(); +} + +int MXNetFuncListGlobalNames(int *out_size, + const char*** out_array) { + API_BEGIN(); + MXNetFuncThreadLocalEntry *ret = MXNetFuncThreadLocalStore::Get(); + ret->ret_vec_str = mxnet::runtime::Registry::ListNames(); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_array = dmlc::BeginPtr(ret->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_str.size()); + API_END(); +}