This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
MXNet FFI for Operator Imperative Invocation #17510
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
b9c6ea1
Init
meta-project-ci 6aafabd
Add nop
meta-project-ci 78b885d
Add utility function SetInOut and Invoke
meta-project-ci 939863a
Init ctypes
meta-project-ci 06e75f6
Dispatch for default/CSR array
meta-project-ci 08ba4c7
Refactor, register the funcs where they are used, except for _api_int…
meta-project-ci ee75e2f
Seperate tvm ffi api and legacy api
meta-project-ci f3922ff
Replace legacy zeros with new
meta-project-ci 5004c9b
Fix numpy.int64 in shape
meta-project-ci 5dd4799
Fix sanity
meta-project-ci 8dd68bf
Fix
meta-project-ci 490ce05
Remove python2 support
meta-project-ci b748349
Cleanup
meta-project-ci e3acc91
Fix ci
meta-project-ci e424a1f
Fix lint
meta-project-ci 9848108
Revert rand_shape_nd
meta-project-ci 6f9906d
Fix clang-tidy
meta-project-ci 3aaad87
Support NDArray in ctypes
meta-project-ci a65380e
Using runtime
meta-project-ci e84b494
Conversion ctor
meta-project-ci 167a825
Tensordot
meta-project-ci 976fbd3
Tensordot backward
meta-project-ci 1817c2a
Fix nop regression
meta-project-ci 92f27b3
Deprecate Array
meta-project-ci f4fa30d
Fix comments
meta-project-ci ae194db
Fix comments
meta-project-ci be7db15
Add acknowledgement to incubator-tvm
meta-project-ci d29bb7d
Refactor according to comments
meta-project-ci File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \file api_registry.h | ||
| * \brief This file contains utilities related to | ||
| * the MXNet's global function registry. | ||
| */ | ||
| // Acknowledgement: This file originates from incubator-tvm | ||
| #ifndef MXNET_API_REGISTRY_H_ | ||
| #define MXNET_API_REGISTRY_H_ | ||
|
|
||
| #include <string> | ||
| #include <utility> | ||
| #include "runtime/registry.h" | ||
|
|
||
| namespace mxnet { | ||
| /*! | ||
| * \brief Register an API function globally. | ||
| * It simply redirects to MXNET_REGISTER_GLOBAL | ||
| * | ||
| * \code | ||
| * MXNET_REGISTER_API(MyPrint) | ||
| * .set_body([](MXNetArgs args, MXNetRetValue* rv) { | ||
| * // my code. | ||
| * }); | ||
| * \endcode | ||
| */ | ||
| #define MXNET_REGISTER_API(OpName) MXNET_REGISTER_GLOBAL(OpName) | ||
|
|
||
| } // namespace mxnet | ||
| #endif // MXNET_API_REGISTRY_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \file expr_operator.h | ||
| * \brief Common operators defined for Expr. | ||
| * | ||
| * \note Most of the operator defined here perform simple constant folding | ||
| * when the type is int32 or int64 for simplifying the index expressions. | ||
| */ | ||
| // Acknowledgement: This file originates from incubator-tvm | ||
| // Acknowledgement: Most operator APIs originate from Halide. | ||
| #ifndef MXNET_EXPR_OPERATOR_H_ | ||
| #define MXNET_EXPR_OPERATOR_H_ | ||
|
|
||
| #include <mxnet/ir/expr.h> | ||
|
|
||
| namespace mxnet { | ||
|
|
||
| template<typename ValueType> | ||
| inline PrimExpr MakeConstScalar(MXNetDataType t, ValueType value) { | ||
| if (t.is_int()) return IntImm(t, static_cast<int64_t>(value)); | ||
| if (t.is_float()) return FloatImm(t, static_cast<double>(value)); | ||
| // customized type and uint is not supported for MXNet for now | ||
| LOG(FATAL) << "cannot make const for type " << t; | ||
| return PrimExpr(); | ||
| } | ||
|
|
||
|
|
||
| template<typename ValueType> | ||
| inline PrimExpr make_const(MXNetDataType t, ValueType value) { | ||
| if (t.lanes() == 1) { | ||
| return MakeConstScalar(t, value); | ||
| } else { | ||
| LOG(FATAL) << "MXNetDataType::lanes() != 1 is not supported "; | ||
| } | ||
| return PrimExpr(); | ||
| } | ||
|
|
||
| } // namespace mxnet | ||
|
|
||
| #endif // MXNET_EXPR_OPERATOR_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,225 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \file expr.h | ||
| * \brief Base expr nodes in MXNet. | ||
| */ | ||
| // Acknowledgement: This file originates from incubator-tvm | ||
| #ifndef MXNET_IR_EXPR_H_ | ||
| #define MXNET_IR_EXPR_H_ | ||
|
|
||
| #include <mxnet/runtime/object.h> | ||
| #include <mxnet/node/node.h> | ||
| #include <mxnet/node/container.h> | ||
| #include <mxnet/runtime/data_type.h> | ||
| #include <string> | ||
|
|
||
| namespace mxnet { | ||
|
|
||
| /*! | ||
| * \brief Base type of all the expressions. | ||
| * \sa Expr | ||
| */ | ||
| class BaseExprNode : public Object { | ||
| public: | ||
| static constexpr const char* _type_key = "Expr"; | ||
| MXNET_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Managed reference to BaseExprNode. | ||
| * \sa BaseExprNode | ||
| */ | ||
| class BaseExpr : public ObjectRef { | ||
| public: | ||
| /*! \brief Cosntructor */ | ||
| BaseExpr() {} | ||
| /*! | ||
| * \brief Cosntructor from object ptr. | ||
| * \param ptr The object pointer. | ||
| */ | ||
| explicit BaseExpr(runtime::ObjectPtr<Object> ptr) : ObjectRef(ptr) {} | ||
| /*! \brief The container type. */ | ||
| using ContainerType = BaseExprNode; | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Base node of all primitive expressions. | ||
| * | ||
| * A primitive expression deals with low-level | ||
| * POD data types and handles without | ||
| * doing life-cycle management for objects. | ||
| * | ||
| * PrimExpr is used in the low-level code | ||
| * optimizations and integer analysis. | ||
| * | ||
| * \sa PrimExpr | ||
| */ | ||
| class PrimExprNode : public BaseExprNode { | ||
| public: | ||
| /*! | ||
| * \brief The runtime data type of the primitive expression. | ||
| * | ||
| * MXNetDataType(dtype) provides coarse grained type information | ||
| * during compile time and runtime. It is eagerly built in | ||
| * PrimExpr expression construction and can be used for | ||
| * quick type checking. | ||
| * | ||
| * dtype is sufficient to decide the Type of the PrimExpr | ||
| * when it corresponds to POD value types such as i32. | ||
| * | ||
| * When dtype is MXNetDataType::Handle(), the expression could corresponds to | ||
| * a more fine-grained Type, and we can get the type by running lazy type inference. | ||
| */ | ||
| MXNetDataType dtype; | ||
|
|
||
| static constexpr const char* _type_key = "PrimExpr"; | ||
| MXNET_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Reference to PrimExprNode. | ||
| * \sa PrimExprNode | ||
| */ | ||
| class PrimExpr : public BaseExpr { | ||
| public: | ||
| /*! \brief Cosntructor */ | ||
| PrimExpr() {} | ||
| /*! | ||
| * \brief Cosntructor from object ptr. | ||
| * \param ptr The object pointer. | ||
| */ | ||
| explicit PrimExpr(runtime::ObjectPtr<Object> ptr) : BaseExpr(ptr) {} | ||
| /*! | ||
| * \brief construct from integer. | ||
| * \param value The value to be constructed. | ||
| */ | ||
| MXNET_DLL PrimExpr(int32_t value); // NOLINT(*) | ||
| /*! | ||
| * \brief construct from float. | ||
| * \param value The value to be constructed. | ||
| */ | ||
| MXNET_DLL PrimExpr(float value); // NOLINT(*) | ||
| /*! | ||
| * \brief construct from string. | ||
| * \param str The value to be constructed. | ||
| */ | ||
| MXNET_DLL PrimExpr(std::string str); // NOLINT(*) | ||
|
|
||
| /*! \return the data type of this expression. */ | ||
| MXNetDataType dtype() const { | ||
| return static_cast<const PrimExprNode*>(get())->dtype; | ||
| } | ||
| /*! \brief The container type. */ | ||
| using ContainerType = PrimExprNode; | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Constant integer literals in the program. | ||
| * \sa IntImm | ||
| */ | ||
| class IntImmNode : public PrimExprNode { | ||
| public: | ||
| /*! \brief the Internal value. */ | ||
| int64_t value; | ||
|
|
||
| static constexpr const char* _type_key = "IntImm"; | ||
| MXNET_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Managed reference class to IntImmNode. | ||
| * | ||
| * \sa IntImmNode | ||
| */ | ||
| class IntImm : public PrimExpr { | ||
| public: | ||
| /*! | ||
| * \brief Constructor | ||
| */ | ||
| IntImm() {} | ||
| /*! | ||
| * \brief constructor from node. | ||
| */ | ||
| explicit IntImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {} | ||
| /*! | ||
| * \brief Constructor. | ||
| * \param dtype The data type of the value. | ||
| * \param value The internal value. | ||
| */ | ||
| MXNET_DLL IntImm(MXNetDataType dtype, int64_t value); | ||
| /*! | ||
| * \brief Get pointer to the internal value. | ||
| * \return the content of the integer. | ||
| */ | ||
| const IntImmNode* operator->() const { | ||
| return static_cast<const IntImmNode*>(get()); | ||
| } | ||
| /*! \brief type indicate the container type */ | ||
| using ContainerType = IntImmNode; | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Constant floating point literals in the program. | ||
| * \sa FloatImm | ||
| */ | ||
| class FloatImmNode : public PrimExprNode { | ||
| public: | ||
| /*! \brief The constant value content. */ | ||
| double value; | ||
|
|
||
| static constexpr const char* _type_key = "FloatImm"; | ||
| MXNET_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Managed reference class to FloatImmNode. | ||
| * | ||
| * \sa FloatImmNode | ||
| */ | ||
| class FloatImm : public PrimExpr { | ||
| public: | ||
| /*! | ||
| * \brief Constructor | ||
| */ | ||
| FloatImm() {} | ||
| /*! | ||
| * \brief constructor from node. | ||
| */ | ||
| explicit FloatImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {} | ||
| /*! | ||
| * \brief Constructor. | ||
| * \param dtype The data type of the value. | ||
| * \param value The internal value. | ||
| */ | ||
| MXNET_DLL FloatImm(MXNetDataType dtype, double value); | ||
| /*! | ||
| * \brief Get pointer to the container. | ||
| * \return The pointer. | ||
| */ | ||
| const FloatImmNode* operator->() const { | ||
| return static_cast<const FloatImmNode*>(get()); | ||
| } | ||
| /*! \brief type indicate the container type */ | ||
| using ContainerType = FloatImmNode; | ||
| }; | ||
|
|
||
| } // namespace mxnet | ||
| #endif // MXNET_IR_EXPR_H_ |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.