diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 849e4606834e..59825d69d0d4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1 +1 @@ -Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/dmlc/tvm/blob/master/CONTRIBUTORS.md#reviewers). +Thanks for contributing to TVM! Please refer to the guidelines at https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/dmlc/tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-mentioning them in the pull request thread. diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 808f485387f9..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 808f485387f9a03f78fa9f1159f387d0d91b7a28 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f diff --git a/CMakeLists.txt b/CMakeLists.txt index 10730ac718b4..a5f5f1428859 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ include(cmake/util/FindCUDA.cmake) include(cmake/util/FindVulkan.cmake) include(cmake/util/FindLLVM.cmake) include(cmake/util/FindROCM.cmake) +include(cmake/util/FindANTLR.cmake) if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) @@ -33,6 +34,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) +tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..."
OFF) tvm_option(USE_SGX "Build with SGX" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) @@ -154,6 +156,7 @@ list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_IR_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_QNN_SRCS}) + if(USE_VM_PROFILER) message(STATUS "Build compiler with Relay VM profiler support...") file(GLOB BACKEND_VM_PROFILER_SRCS src/relay/backend/vm/profiler/*.cc) @@ -233,6 +236,7 @@ include(cmake/modules/VTA.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/OpenCL.cmake) include(cmake/modules/OpenGL.cmake) +include(cmake/modules/OpenMP.cmake) include(cmake/modules/Vulkan.cmake) include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) @@ -264,6 +268,7 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) + if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") diff --git a/Jenkinsfile b/Jenkinsfile index 6134023f9c21..a66c96f3396e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,9 +38,15 @@ // - Tag the new version as the latest // - Periodically cleanup the old versions on local workers // + +// Commit hash of the source used to build the current CI docker images +// +// - ci-cpu:v0.54: e7c88a99f830de30814df14eaa980547ecbd61c1 +// + ci_lint = "tvmai/ci-lint:v0.51" ci_gpu = "tvmai/ci-gpu:v0.54" -ci_cpu = "tvmai/ci-cpu:v0.52" +ci_cpu = "tvmai/ci-cpu:v0.54" ci_i386 = "tvmai/ci-i386:v0.52" // tvm libraries @@ -196,10 +202,10 @@ stage('Build') { make(ci_cpu, 'build', '-j2') pack_lib('cpu', tvm_lib) timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta.sh" + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" } } } diff --git a/cmake/config.cmake b/cmake/config.cmake index d92c2151d9c8..dbc2e80812fd 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -115,6 +115,10 @@ set(USE_BLAS none) # set(USE_MKL_PATH ) if using `pip install mkl` set(USE_MKL_PATH none) +# Whether to use OpenMP thread pool, choices: gnu, intel +# Note: "gnu" uses the gomp library, "intel" uses the iomp5 library +set(USE_OPENMP none) + # Whether use contrib.random in runtime set(USE_RANDOM OFF) @@ -140,6 +144,10 @@ set(USE_ROCBLAS OFF) set(USE_SORT ON) # Build ANTLR parser for Relay text format +# Possible values: +# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar) +# - OFF: disable ANTLR +# - /path/to/antlr-*-complete.jar: path to specific ANTLR jar file set(USE_ANTLR OFF) # Whether use Relay debug mode diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index 5842c819099d..d3c1b4218253 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -15,29 +15,7 @@ # specific language governing permissions and limitations # under the License. if(USE_ANTLR) - find_program(ANTLR4 antlr4) - - if (NOT ANTLR4) - file(GLOB_RECURSE ANTLR4JAR - /usr/local/lib/antlr-*-complete.jar - /usr/local/Cellar/*antlr-*-complete.jar) - - # Get the first element of the list of antlr jars. - # Sort and reverse the list so the item selected is the highest - # version in lib or else in Cellar if no lib installation exists.
- list(SORT ANTLR4JAR) - list(REVERSE ANTLR4JAR) - list(GET ANTLR4JAR 0 ANTLR4JAR) - - set(JAVA_HOME $ENV{JAVA_HOME}) - if (NOT DEFINED JAVA_HOME) - # Hack to get system to search for Java itself. - set(JAVA_HOME "/usr") - endif() - - set(ANTLR4 ${JAVA_HOME}/bin/java -jar ${ANTLR4JAR}) - endif() - + find_antlr(${USE_ANTLR}) if(ANTLR4) set(RELAY_PARSER_DIR diff --git a/cmake/modules/OpenMP.cmake b/cmake/modules/OpenMP.cmake new file mode 100644 index 000000000000..5dd9be508342 --- /dev/null +++ b/cmake/modules/OpenMP.cmake @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# OpenMP Module +if(USE_OPENMP STREQUAL "gnu") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenMP_CXX_LIBRARIES}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP ${OpenMP_CXX_LIBRARIES}") + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +elseif(USE_OPENMP STREQUAL "intel") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + if (MSVC) + find_library(OMP_LIBRARY NAMES libiomp5md) + else() + find_library(OMP_LIBRARY NAMES iomp5) + endif() + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OMP_LIBRARY}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP " ${OMP_LIBRARY}) + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) +endif() diff --git a/cmake/util/FindANTLR.cmake b/cmake/util/FindANTLR.cmake new file mode 100644 index 000000000000..b68f90ead131 --- /dev/null +++ b/cmake/util/FindANTLR.cmake @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####################################################### +# Enhanced version of find ANTLR. 
+# +# Usage: +# find_antlr(${USE_ANTLR}) +# +# - When USE_ANTLR=ON, use auto search by first trying to find antlr4 program, +# then trying to find antlr-*-complete.jar +# - When USE_ANTLR=/path/to/antlr-*-complete.jar, use provided jar +# +# Provide variables: +# - ANTLR4 +# +macro(find_antlr use_antlr) + set(JAVA_HOME $ENV{JAVA_HOME}) + if (NOT DEFINED JAVA_HOME) + # Hack to get system to search for Java itself. + message(STATUS "JAVA_HOME is not defined. Set it to ensure proper use") + set(JAVA_HOME "/usr") + endif() + if(MSVC) + set(JAVA_PROGRAM ${JAVA_HOME}/java.exe) + else() + set(JAVA_PROGRAM ${JAVA_HOME}/bin/java) + endif() + message(STATUS "Using Java at " ${JAVA_PROGRAM}) + + if (${use_antlr} STREQUAL "ON") + find_program(ANTLR4 antlr4) + if (NOT ANTLR4) + file(GLOB_RECURSE ANTLR4JAR + /usr/local/lib/antlr-*-complete.jar + /usr/local/Cellar/*antlr-*-complete.jar) + + # Get the first element of the list of antlr jars. + # Sort and reverse the list so the item selected is the highest + # version in lib or else in Cellar if no lib installation exists. + list(SORT ANTLR4JAR) + list(REVERSE ANTLR4JAR) + list(GET ANTLR4JAR 0 ANTLR4JAR) + + set(ANTLR4 ${JAVA_PROGRAM} -jar ${ANTLR4JAR}) + endif() + elseif(NOT ${use_antlr} STREQUAL "OFF") + set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr}) + endif() + message(STATUS "ANTLR4="${ANTLR4}) +endmacro(find_antlr) diff --git a/docker/install/ubuntu_install_nnpack.sh b/docker/install/ubuntu_install_nnpack.sh index 4f45f130e2e5..dc51fc28d492 100755 --- a/docker/install/ubuntu_install_nnpack.sh +++ b/docker/install/ubuntu_install_nnpack.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,14 @@ set -o pipefail apt-get update && apt-get install -y --no-install-recommends git cmake -# TODO: specific tag? git clone https://github.com/Maratyszcza/NNPACK NNPACK +git clone https://github.com/Maratyszcza/pthreadpool NNPACK/pthreadpool + +# Use specific versioning tag. (cd NNPACK && git checkout 1e005b0c2) +(cd NNPACK/pthreadpool && git checkout 13da0b4c) mkdir -p NNPACK/build cd NNPACK/build -cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && make -j4 && make install +cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DPTHREADPOOL_SOURCE_DIR=pthreadpool .. && make -j4 && make install cd - diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index f3e8d8e8f540..54210b83f4d6 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -27,5 +27,4 @@ pip3 install onnx==1.5.0 # not expose that in the wheel!!! 
pip3 install future -pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl -pip3 install torchvision +pip3 install torch==1.2.0 torchvision==0.4.0 diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst index 2791ee71177e..cb08cc14e56e 100644 --- a/docs/dev/virtual_machine.rst +++ b/docs/dev/virtual_machine.rst @@ -121,7 +121,7 @@ AllocTensor Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result is saved to register `dst`. -AllocDatatype +AllocADT ^^^^^^^^^^^^^ **Arguments**: :: RegName dst @@ -176,7 +176,7 @@ GetTagi RegName object RegName dst -Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`. +Get the object tag for the ADT object in register `object` and save the result to register `dst`. Fatal ^^^^^ @@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures. :: - VMObject VMTensor(const tvm::runtime::NDArray& data); - VMObject VMDatatype(size_t tag, const std::vector& fields); - VMObject VMClosure(size_t func_index, std::vector free_vars); + Object Tensor(const tvm::runtime::NDArray& data); + Object ADT(size_t tag, const std::vector& fields); + Object Closure(size_t func_index, std::vector free_vars); Stack and State diff --git a/docs/frontend/tensorflow.md b/docs/frontend/tensorflow.md deleted file mode 100644 index 06a6fcc32b4f..000000000000 --- a/docs/frontend/tensorflow.md +++ /dev/null @@ -1,53 +0,0 @@ - - - - - - - - - - - - - - - - - -# Tensorflow Frontend -Tensorflow frontend helps in importing tensorflow released model into TVM. - -This document helps few steps while importing various different models from -[tensorflow research/slim](https://github.com/tensorflow/models/tree/master/research/slim). - -Current frontend is tested with all versions of below models -- Inception (V1/V2/V3/V4) -- Resnet (All) -- Mobilenet (V1/V2 All) -- Vgg (16/19) - -Tensorflow frontend expects a freezed protobuf format as input. - -Not all models are released as freezed protobuf. Some of them are checkpoints (.ckpt). -Please refer to [export](https://github.com/tensorflow/models/tree/master/research/slim#exporting-the-inference-graph) -and [freeze](https://github.com/tensorflow/models/tree/master/research/slim#freezing-the-exported-graph) -instructions to generate protobuf from checkpoint. - -## General Instructions - -### Add Shapes: -While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph. -You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from nnvm for the same. -Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py). - -### Explicit Shape: -There might be situations where the add_shapes=True may not provide sufficient information about shape. -You may pass explicit dictionary of input shapes argument for ```from_tensorflow```. -Please refer to [test cases](https://github.com/dmlc/tvm/blob/master/nnvm/tests/python/frontend/tensorflow/test_forward.py#L36). - -### GPU: -Most of these tensorflow models are released for CPU with NHWC layout. -To compile for GPU we need to pass extra argument ```layout='NCHW'``` for from_tensorflow. -This option will do a layout conversion before and after for neural network ops. -Remaining nnvm build options for GPU compilation remain as it is.
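Both the markdown page removed above and its reStructuredText replacement below describe the same import flow. As a rough illustration, here is a minimal Python sketch of that flow; the path "model.pb", the input name "input", and the shape are hypothetical placeholders, not taken from this patch:

import tensorflow as tf
import tvm
from tvm import relay

# Load a frozen graph that was exported with add_shapes=True.
with tf.gfile.GFile("model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Shapes must be static; layout="NCHW" asks the frontend to convert the
# data layout, which the docs recommend when targeting GPU.
mod, params = relay.frontend.from_tensorflow(
    graph_def, layout="NCHW", shape={"input": (1, 224, 224, 3)})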
diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst new file mode 100644 index 000000000000..827f5d637988 --- /dev/null +++ b/docs/frontend/tensorflow.rst @@ -0,0 +1,241 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +TensorFlow Frontend +=================== + +The TensorFlow frontend helps in importing TensorFlow models into TVM. + +Supported versions: + +- 1.12 and below + +Tested models: + +- Inception (V1/V2/V3/V4) +- Resnet (All) +- Mobilenet (V1/V2 All) +- Vgg (16/19) +- BERT (Base/3-layer) + +Preparing a Model for Inference +------------------------------- + +Remove Unneeded Nodes +~~~~~~~~~~~~~~~~~~~~~ + +The export process will remove many nodes that are not needed for inference, but unfortunately will leave some remaining. The nodes that should be manually removed are: + +- Dropout, including `Dropout`_ and `DropoutWrapper`_ +- `Assert`_ + +.. _Dropout: https://www.tensorflow.org/api_docs/python/tf/nn/dropout +.. _DropoutWrapper: https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/nn/rnn_cell/DropoutWrapper?hl=hr +.. _Assert: https://www.tensorflow.org/api_docs/python/tf/debugging/Assert + +Convert None Dimensions to Constants +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +TVM has minimal support for dynamic tensor shapes. Dimensions that are ``None`` should be replaced with constants. For example, a model may accept an input with shape ``(None,20)``. This should be converted to a shape like ``(1,20)``. The model should be modified accordingly to ensure that these shapes match throughout the graph. + +Export +~~~~~~ + +TensorFlow frontend expects a frozen protobuf (.pb) or saved model as input. It currently does not support checkpoint (.ckpt). The graphdef needed by the TensorFlow frontend can be extracted from the active session, or by using the `TFParser`_ helper class. + +.. _TFParser: https://github.com/dmlc/tvm/blob/master/python/tvm/relay/frontend/tensorflow_parser.py + +The model should be exported with a number of transformations to prepare the model for inference. It is also important to set ```add_shapes=True```, as this will embed the output shapes of each node into the graph. Here is one function to export a model as a protobuf given a session: + +.. 
code:: python + + import tensorflow as tf + from tensorflow.tools.graph_transforms import TransformGraph + + def export_pb(session): + with tf.gfile.GFile("myexportedmodel.pb", "wb") as f: + inputs = ["myinput1", "myinput2"] # replace with your input names + outputs = ["myoutput1"] # replace with your output names + graph_def = session.graph.as_graph_def(add_shapes=True) + graph_def = tf.graph_util.convert_variables_to_constants(session, graph_def, outputs) + graph_def = TransformGraph( + graph_def, + inputs, + outputs, + [ + "remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)", + "sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering + "remove_device", + "sort_by_execution_order", + "fold_batch_norms", + "sort_by_execution_order", + "fold_old_batch_norms", + "sort_by_execution_order" + ] + ) + f.write(graph_def.SerializeToString()) + +Another method is to `export and freeze the graph <https://github.com/tensorflow/models/tree/master/research/slim#freezing-the-exported-graph>`_. + +Import the Model +---------------- + +Explicit Shape: +~~~~~~~~~~~~~~~ + +To ensure shapes can be known throughout the entire graph, pass the ```shape``` argument to ```from_tensorflow```. This dictionary maps input names to input shapes. Please refer to these `test cases <https://github.com/dmlc/tvm/blob/master/tests/python/frontend/tensorflow/test_forward.py>`_ for examples. + +Data Layout +~~~~~~~~~~~ + +Most TensorFlow models are released with NHWC layout. NCHW layout often provides better performance, especially on GPU. The TensorFlow frontend can automatically convert the model's data layout by passing the argument ```layout='NCHW'``` to ```from_tensorflow```. + +Best Practices +-------------- + +- Use static tensor shapes instead of dynamic shapes (remove ```None``` dimensions). +- Use static RNN instead of dynamic RNN, as ```TensorArray``` isn't supported yet. + +Supported Ops +------------- + +- Abs +- Add +- All +- ArgMax +- ArgMin +- AvgPool +- BatchMatMul +- BatchMatMulV2 +- BatchNormWithGlobalNormalization +- BatchToSpaceND +- BiasAdd +- BroadcastTo +- Cast +- Ceil +- CheckNumerics +- ClipByValue +- Concat +- ConcatV2 +- Conv2D +- Cos +- CropAndResize +- DecodeJpeg +- DepthwiseConv2dNative +- DepthToSpace +- Equal +- Elu +- Enter +- Erf +- Exit +- Exp +- ExpandDims +- Fill +- Floor +- FloorDiv +- FusedBatchNorm +- FusedBatchNormV2 +- Gather +- GatherNd +- GatherV2 +- Greater +- GreaterEqual +- Identity +- LeakyRelu +- LeftShift +- Less +- LessEqual +- Log +- Log1p +- LoopCond +- LogicalAnd +- LogicalOr +- LogicalNot +- LogSoftmax +- LRN +- LSTMBlockCell +- MatMul +- Max +- MaxPool +- Maximum +- Mean +- Merge +- Min +- Minimum +- MirrorPad +- Mod +- Mul +- Neg +- NextIteration +- NotEqual +- OneHot +- Pack +- Pad +- PadV2 +- Pow +- Prod +- Range +- Rank +- RealDiv +- Relu +- Relu6 +- Reshape +- ResizeBilinear +- ResizeBicubic +- ResizeNearestNeighbor +- ReverseV2 +- RightShift +- Round +- Rsqrt +- Select +- Selu +- Shape +- Sigmoid +- Sign +- Sin +- Size +- Slice +- Softmax +- Softplus +- SpaceToBatchND +- SpaceToDepth +- Split +- SplitV +- Sqrt +- Square +- SquaredDifference +- Squeeze +- StridedSlice +- Sub +- Sum +- Switch +- Tanh +- TensorArrayV3 +- TensorArrayScatterV3 +- TensorArrayGatherV3 +- TensorArraySizeV3 +- TensorArrayWriteV3 +- TensorArrayReadV3 +- TensorArraySplitV3 +- TensorArrayConcatV3 +- Tile +- TopKV2 +- Transpose +- TruncateMod +- Unpack +- Where +- ZerosLike diff --git a/docs/index.rst b/docs/index.rst index 9666fff0c5d3..f02dcc7c91e2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,12 @@ Developer Guide dev/index nnvm_top +Frontends +---------------- +..
toctree:: + :maxdepth: 1 + + frontend/tensorflow Index ----- diff --git a/docs/vta/install.md b/docs/vta/install.md index c43a167292b4..02c50fbba481 100644 --- a/docs/vta/install.md +++ b/docs/vta/install.md @@ -229,7 +229,7 @@ Now you can connect the power cable and serial port to boot the Angstrom Linux. > In this case, you might need to build the `zImage` file of your own from [socfpga-4.9.78-ltsi](https://github.com/altera-opensource/linux-socfpga/tree/socfpga-4.9.78-ltsi) branch of the [linux-socfpga](https://github.com/altera-opensource/linux-socfpga) repository. > For a quick fix, you can also download a prebuilt version of the `zImage` file [here](https://raw.githubusercontent.com/liangfu/de10-nano-supplement/master/zImage). -After connecting he usb cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using `minicom` on your host PC: +After connecting the usb cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using `minicom` on your host PC: ``` bash # NOTE: root privilege is typically required to run the following command. diff --git a/golang/src/value.go b/golang/src/value.go index 576331a8cfa0..5e0f78270eaa 100644 --- a/golang/src/value.go +++ b/golang/src/value.go @@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType) var KTVMContext = int32(C.kTVMContext) // KArrayHandle is golang type code for TVM kArrayHandle. var KArrayHandle = int32(C.kArrayHandle) -// KNodeHandle is golang type code for TVM kNodeHandle. -var KNodeHandle = int32(C.kNodeHandle) +// KObjectHandle is golang type code for TVM kObjectHandle. +var KObjectHandle = int32(C.kObjectHandle) // KModuleHandle is golang type code for TVM kModuleHandle. var KModuleHandle = int32(C.kModuleHandle) // KFuncHandle is golang type code for TVM kFuncHandle. diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index e12d841519ca..c41c3087f4ac 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -58,7 +58,7 @@ class EnvFuncNode : public Node { /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef { explicit EnvFunc(NodePtr n) : NodeRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. @@ -124,19 +124,19 @@ class TypedEnvFunc : public NodeRef { /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(NodePtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. * \return reference to self. */ TSelf& operator=(const EnvFunc& other) { - this->node_ = other.node_; + ObjectRef::operator=(other); return *this; } /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function.
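The header diffs above and below all apply one mechanical rewrite: VisitAttrs drops its final qualifier, reference constructors take an ObjectPtr, and operator-> casts the result of get()/get_mutable() instead of reading the removed node_ field. A minimal C++ sketch of the resulting pattern, using a hypothetical MyNode/MyRef pair rather than any type from this patch:

// Hypothetical node type following the post-refactor conventions.
class MyNode : public Node {
 public:
  int value{0};
  // No longer marked `final`, so derived nodes can extend visitation.
  void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); }
  static constexpr const char* _type_key = "example.MyNode";
  TVM_DECLARE_NODE_TYPE_INFO(MyNode, Node);
};

// Hypothetical reference type: constructed from an ObjectPtr and reading
// through get() rather than the old node_ member.
class MyRef : public NodeRef {
 public:
  MyRef() {}
  explicit MyRef(ObjectPtr<Object> n) : NodeRef(n) {}
  const MyNode* operator->() const {
    return static_cast<const MyNode*>(get());
  }
  using ContainerType = MyNode;
};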
diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 8be1c3604813..bda6ac647f55 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node { int64_t min_value; int64_t max_value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("min_value", &min_value); v->Visit("max_value", &max_value); } @@ -162,7 +162,7 @@ class ModularSetNode : public Node { /*! \brief The base */ int64_t base; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("coeff", &coeff); v->Visit("base", &base); } @@ -351,7 +351,7 @@ enum SignType { */ struct IntSetNode : public Node { static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object); }; /*! @@ -362,7 +362,7 @@ class IntSet : public NodeRef { /*! \brief constructor */ IntSet() {} // constructor from not container. - explicit IntSet(NodePtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -692,7 +692,7 @@ Array DetectClipBound(const Expr& e, // implementation inline const IntSetNode* IntSet::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace arith } // namespace tvm diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 3b64d1f961e2..2fbb9e6a866e 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node { /*! \brief detailed description of the type */ std::string description; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("type_info", &type_info); v->Visit("description", &description); @@ -163,7 +163,7 @@ class AttrsEqual { return lhs == rhs; } // node comparator - TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; protected: friend class AttrsEqualHandler; @@ -197,13 +197,13 @@ class AttrsHash { size_t operator()(const std::string& value) const { return std::hash()(value); } - size_t operator()(const Type& value) const { + size_t operator()(const DataType& value) const { return std::hash()( static_cast(value.code()) | (static_cast(value.bits()) << 8) | (static_cast(value.lanes()) << 16)); } - TVM_DLL size_t operator()(const NodeRef& value) const; + TVM_DLL size_t operator()(const ObjectRef& value) const; private: friend class AttrsHashHandler; @@ -221,6 +221,8 @@ class BaseAttrsNode : public Node { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; + // visit function + virtual void VisitAttrs(AttrVisitor* v) {} /*! * \brief Initialize the attributes by sequence of arguments * \param args The postional arguments in the form @@ -260,7 +262,7 @@ class BaseAttrsNode : public Node { * \return The comparison result. */ TVM_DLL virtual bool ContentEqual( - const Node* other, AttrsEqual equal) const = 0; + const Object* other, AttrsEqual equal) const = 0; /*! * \brief Content aware hash. * \param hasher The hasher to run the hash. @@ -290,7 +292,7 @@ class Attrs : public NodeRef { private: /*! 
\return the internal attribute node */ const BaseAttrsNode* ptr() const { - return static_cast(node_.get()); + return static_cast(get()); } }; @@ -315,7 +317,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - bool ContentEqual(const Node* other, AttrsEqual equal) const final; + bool ContentEqual(const Object* other, AttrsEqual equal) const final; size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; @@ -369,7 +371,7 @@ class AttrsEqualVisitor { public: bool result_{true}; // constructor - AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) + AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal) : lhs_(lhs), rhs_(rhs), equal_(equal) { } template @@ -387,8 +389,8 @@ class AttrsEqualVisitor { } private: - const Node* lhs_; - const Node* rhs_; + const Object* lhs_; + const Object* rhs_; const AttrsEqual& equal_; }; @@ -488,7 +490,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect int value, but get " << expr->type_key(); + LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } } } @@ -521,7 +523,7 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect float value, but get " << expr->type_key(); + LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } } } @@ -753,12 +755,12 @@ class AttrNonDefaultVisitor { template class AttrsNode : public BaseAttrsNode { public: - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { ::tvm::detail::AttrNormalVisitor vis(v); self()->__VisitAttrs__(vis); } - void VisitNonDefaultAttrs(AttrVisitor* v) final { + void VisitNonDefaultAttrs(AttrVisitor* v) { ::tvm::detail::AttrNonDefaultVisitor vis(v); self()->__VisitAttrs__(vis); } @@ -827,7 +829,7 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Node* other, AttrsEqual equal) const final { + bool ContentEqual(const Object* other, AttrsEqual equal) const final { DerivedType* pself = self(); if (pself == other) return true; if (other == nullptr) return false; @@ -839,7 +841,7 @@ class AttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final { ::tvm::detail::AttrsHashVisitor visitor(hasher); - visitor.result_ = std::hash()(this->type_key()); + visitor.result_ = this->GetTypeKeyHash(); self()->__VisitAttrs__(visitor); return visitor.result_; } diff --git a/include/tvm/base.h b/include/tvm/base.h index f358f7f5d447..9b3b4cd3e8df 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -19,88 +19,16 @@ /*! * \file tvm/base.h - * \brief Defines the base data structure + * \brief Base utilities */ #ifndef TVM_BASE_H_ #define TVM_BASE_H_ #include -#include -#include -#include -#include -#include #include -#include "runtime/registry.h" namespace tvm { -using ::tvm::Node; -using ::tvm::NodeRef; -using ::tvm::AttrVisitor; - -/*! - * \brief Macro to define common node ref methods. - * \param TypeName The name of the NodeRef. - * \param BaseTypeName The Base type. - * \param NodeName The node container type. 
- */ -#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - operator bool() const { return this->defined(); } \ - using ContainerType = NodeName; - -/*! - * \brief Macro to define CopyOnWrite function in a NodeRef. - * \param NodeName The Type of the Node. - * - * CopyOnWrite will generate a unique copy of the internal node. - * The node will be copied if it is referenced by multiple places. - * The function returns the raw pointer to the node to allow modification - * of the content. - * - * \code - * - * MyCOWNodeRef ref, ref2; - * ref2 = ref; - * ref.CopyOnWrite()->value = new_value; - * assert(ref2->value == old_value); - * assert(ref->value == new_value); - * - * \endcode - */ -#define TVM_DEFINE_NODE_REF_COW(NodeName) \ - NodeName* CopyOnWrite() { \ - CHECK(node_ != nullptr); \ - if (!node_.unique()) { \ - NodePtr n = make_node(*(operator->())); \ - NodePtr(std::move(n)).swap(node_); \ - } \ - return static_cast(node_.get()); \ - } - -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ - }; \ - -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_NODE_REF_COW(NodeName); \ - }; - /*! * \brief RAII wrapper function to enter and exit a context object * similar to python's with syntax. @@ -145,99 +73,6 @@ class With { ContextType ctx_; }; -/*! - * \brief save the node as well as all the node it depends on as json. - * This can be used to serialize any TVM object - * - * \return the string representation of the node. - */ -std::string SaveJSON(const NodeRef& node); - -/*! - * \brief Internal implementation of LoadJSON - * Load tvm Node object from json and return a shared_ptr of Node. - * \param json_str The json string to load from. - * - * \return The shared_ptr of the Node. - */ -NodePtr LoadJSON_(std::string json_str); - -/*! - * \brief Load the node from json string. - * This can be used to deserialize any TVM object. - * - * \param json_str The json string to load from. - * - * \tparam NodeType the nodetype - * - * \code - * Expr e = LoadJSON(json_str); - * \endcode - */ -template::value>::type > -inline NodeType LoadJSON(const std::string& json_str) { - return NodeType(LoadJSON_(json_str)); -} - -/*! - * \brief Registry entry for NodeFactory. - * - * There are two types of Nodes that can be serialized. - * The normal node requires a registration a creator function that - * constructs an empty Node of the corresponding type. - * - * The global singleton(e.g. global operator) where only global_key need to be serialized, - * in this case, FGlobalKey need to be defined. - */ -struct NodeFactoryReg { - /*! - * \brief creator function. - * \param global_key Key that identifies a global single object. - * If this is not empty then FGlobalKey - * \return The created function. - */ - using FCreate = std::function(const std::string& global_key)>; - /*! - * \brief Global key function, only needed by global objects. - * \param node The node pointer. 
- * \return node The global key to the node. - */ - using FGlobalKey = std::function; - /*! \brief registered name */ - std::string name; - /*! - * \brief The creator function - */ - FCreate fcreator = nullptr; - /*! - * \brief The global key function. - */ - FGlobalKey fglobal_key = nullptr; - // setter of creator - NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*) - this->fcreator = f; - return *this; - } - // setter of creator - NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*) - this->fglobal_key = f; - return *this; - } - // global registry singleton - TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry(); -}; - -/*! - * \brief Register a Node type - * \note This is necessary to enable serialization of the Node. - */ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ - ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ - .set_creator([](const std::string&) { return ::tvm::make_node(); }) - - #define TVM_STRINGIZE_DETAIL(x) #x #define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) #define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 1233e9b0b89b..d2c2b40661e2 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -51,7 +51,7 @@ enum BufferType : int { class Buffer : public NodeRef { public: Buffer() {} - explicit Buffer(NodePtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : NodeRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -135,7 +135,7 @@ class BufferNode : public Node { /*! \brief constructor */ BufferNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("data", &data); v->Visit("dtype", &dtype); v->Visit("shape", &shape); @@ -171,7 +171,7 @@ class BufferNode : public Node { }; inline const BufferNode* Buffer::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 1d57d82e66c6..7114a4550331 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -61,7 +61,7 @@ class TargetNode : public Node { /*! \return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("target_name", &target_name); v->Visit("device_name", &device_name); v->Visit("device_type", &device_type); @@ -93,7 +93,7 @@ class TargetNode : public Node { class Target : public NodeRef { public: Target() {} - explicit Target(NodePtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : NodeRef(n) {} /*! * \brief Create a Target given a string * \param target_str the string to parse @@ -110,7 +110,7 @@ class Target : public NodeRef { TVM_DLL static tvm::Target Current(bool allow_not_defined = true); const TargetNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = TargetNode; @@ -229,7 +229,7 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable loop vectorization. 
*/ bool disable_vectorize = false; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); v->Visit("double_buffer_split_loop", &double_buffer_split_loop); @@ -256,12 +256,12 @@ class BuildConfigNode : public Node { class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} - explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} const BuildConfigNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } BuildConfigNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing a empty build config node. @@ -371,7 +371,7 @@ class GenericFuncNode; class GenericFunc : public NodeRef { public: GenericFunc() {} - explicit GenericFunc(NodePtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Set the default function implementaiton. @@ -473,15 +473,17 @@ class GenericFuncNode : public Node { /* \brief map from keys to registered functions */ std::unordered_map dispatch_dict_; + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "GenericFunc"; TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); }; inline GenericFuncNode* GenericFunc::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ +#define TVM_GENERIC_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM /*! diff --git a/include/tvm/c_dsl_api.h b/include/tvm/c_dsl_api.h deleted file mode 100644 index bbbb84926e8e..000000000000 --- a/include/tvm/c_dsl_api.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/c_dsl_api.h - * - * \brief TVM DSL Node C API, used to interact to DSL compilation. - * - * These are only a few functions needed for DSL construction time. - * These function are only available when link libtvm. - * If only TVM runtime is linked, calling these function will trigger error. - * - * \note Most API functions are registerd as PackedFunc and - * can be grabbed via TVMFuncGetGlobal - */ -#ifndef TVM_C_DSL_API_H_ -#define TVM_C_DSL_API_H_ - -#include "runtime/c_runtime_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief handle to node */ -typedef void* NodeHandle; - -/*! - * \brief free the node handle - * \param handle The node handle to be freed. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeFree(NodeHandle handle); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. 
- * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeTypeKey2Index(const char* type_key, - int* out_index); - -/*! - * \brief Get runtime type index of the node. - * \param handle the node handle. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index); - -/*! - * \brief get attributes given key - * \param handle The node handle - * \param key The attribute name - * \param out_value The attribute value - * \param out_type_code The type code of the attribute. - * \param out_success Whether get is successful. - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success); - -/*! - * \brief get attributes names in the node. - * \param handle The node handle - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array); -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_C_DSL_API_H_ diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 143d4295f3e3..3a40a787d891 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -35,7 +35,7 @@ class Channel : public NodeRef { public: /*! \brief default constructor */ Channel() {} - explicit Channel(NodePtr n) : NodeRef(n) {} + explicit Channel(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -54,7 +54,7 @@ struct ChannelNode : public Node { /*! \brief default data type in read/write */ Type dtype; // visit all attributes - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("handle_var", &handle_var); v->Visit("dtype", &dtype); } @@ -67,7 +67,7 @@ struct ChannelNode : public Node { // Inline implementations inline const ChannelNode* Channel::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_CHANNEL_H_ diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index c2ae572de818..5e2cc08660db 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -104,7 +104,7 @@ class LayoutNode : public Node { */ Array axes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("axes", &axes); } @@ -127,7 +127,7 @@ class LayoutNode : public Node { */ class Layout : public NodeRef { public: - explicit Layout(NodePtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : NodeRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -152,7 +152,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ const LayoutNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! @@ -160,7 +160,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ LayoutNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! @@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node { /*! 
\brief The destination layout */ Layout dst_layout; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("src_layout", &src_layout); v->Visit("dst_layout", &dst_layout); v->Visit("forward_rule", &forward_rule); @@ -369,7 +369,7 @@ class BijectiveLayout : public NodeRef { }; inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 201a2b485aa6..ea578152899d 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -27,8 +27,10 @@ #include #include #include +#include #include "base.h" #include "dtype.h" +#include "node/node.h" #include "node/container.h" #include "node/ir_functor.h" #include "runtime/c_runtime_api.h" @@ -49,7 +51,7 @@ class ExprNode : public Node { class Expr : public NodeRef { public: Expr() {} - explicit Expr(NodePtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -110,7 +112,7 @@ class Variable : public ExprNode { static Var make(DataType dtype, std::string name_hint); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("name", &name_hint); } @@ -122,7 +124,7 @@ class Variable : public ExprNode { /*! \brief a named variable in TVM */ class Var : public Expr { public: - explicit Var(NodePtr n) : Expr(n) {} + explicit Var(ObjectPtr n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", Type t = Int(32)); /*! @@ -145,7 +147,7 @@ class Var : public Expr { * \return the corresponding Variable. */ const Variable* get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = Variable; @@ -164,7 +166,7 @@ class IntImm : public ExprNode { /*! \brief the Internal value. */ int64_t value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -187,7 +189,7 @@ class Integer : public Expr { /*! * \brief constructor from node. */ - explicit Integer(NodePtr node) : Expr(node) {} + explicit Integer(ObjectPtr node) : Expr(node) {} /*! * \brief Construct integer from int value. */ @@ -197,7 +199,7 @@ class Integer : public Expr { * \param other another expression. */ Integer& operator=(const Integer& other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -205,13 +207,13 @@ class Integer : public Expr { * \return the content of the integer. */ const IntImm* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief convert to int64_t */ operator int64_t() const { - CHECK(node_ != nullptr) + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } @@ -230,7 +232,7 @@ class RangeNode : public Node { RangeNode() {} RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("min", &min); v->Visit("extent", &extent); } @@ -346,7 +348,7 @@ class IterVar : public NodeRef { // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(NodePtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : NodeRef(n) {} /*! 
* \brief access the internal node container * \return the pointer to the internal node container @@ -406,7 +408,7 @@ class IterVarNode : public Node { */ std::string thread_tag; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); v->Visit("var", &var); v->Visit("iter_type", &iter_type); @@ -423,7 +425,7 @@ class IterVarNode : public Node { // inline implementations inline const IterVarNode* IterVar::operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline IterVar::operator Expr() const { @@ -481,16 +483,16 @@ class IRPrinter { : stream(stream) {} /*! \brief The node to be printed. */ - TVM_DLL void Print(const NodeRef& node); + TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. - using FType = IRFunctor; + using FType = IRFunctor; TVM_DLL static FType& vtable(); }; // default print function for all nodes -inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) IRPrinter(os).Print(n); return os; } @@ -498,10 +500,7 @@ inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT namespace std { template <> -struct hash<::tvm::IterVar> { - std::size_t operator()(const ::tvm::IterVar& k) const { - return k.hash(); - } +struct hash<::tvm::IterVar> : public ::tvm::NodeHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 079f05f5a7f2..b6c3028d892f 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -45,7 +45,7 @@ class UIntImm : public ExprNode { /*! \brief The constant value content. */ uint64_t value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -62,7 +62,7 @@ class FloatImm : public ExprNode { /*! \brief The constant value content. */ double value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -79,7 +79,7 @@ class StringImm : public ExprNode { /*! \brief The constant value content. */ std::string value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -99,7 +99,7 @@ class Cast : public ExprNode { /*! \brief Original data type. */ Expr value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -278,7 +278,7 @@ class And : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -298,7 +298,7 @@ class Or : public ExprNode { /*! \brief The right operand. 
*/ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("a", &a); v->Visit("b", &b); @@ -316,7 +316,7 @@ class Not : public ExprNode { /*! \brief The input operand. */ Expr a; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("a", &a); } @@ -343,7 +343,7 @@ class Select : public ExprNode { /*! \brief value to be returned when condition is false. */ Expr false_value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("condition", &condition); v->Visit("true_value", &true_value); @@ -380,7 +380,7 @@ class Load : public ExprNode { /*! \brief The predicate to mask which lanes would be loaded. */ Expr predicate; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("buffer_var", &buffer_var); v->Visit("index", &index); @@ -411,7 +411,7 @@ class Ramp : public ExprNode { /*! \brief Total number of lanes. */ int lanes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("base", &base); v->Visit("stride", &stride); @@ -432,7 +432,7 @@ class Broadcast : public ExprNode { /*! \brief The numerb of lanes. */ int lanes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); v->Visit("lanes", &lanes); @@ -456,7 +456,7 @@ class Let : public ExprNode { /*! \brief The result expression. */ Expr body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("var", &var); v->Visit("value", &value); @@ -522,7 +522,7 @@ class Call : public ExprNode { /*! \brief The output value index if func's value is a tuple. */ int value_index{0}; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("name", &name); v->Visit("args", &args); @@ -592,7 +592,7 @@ class Shuffle : public ExprNode { /*! \brief The indices of each element. */ Array indices; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("vectors", &vectors); v->Visit("indices", &indices); } @@ -652,7 +652,7 @@ class CommReducerNode : public Node { Array result, Array identity_element); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); v->Visit("rhs", &rhs); v->Visit("result", &result); @@ -664,10 +664,10 @@ class CommReducerNode : public Node { }; inline const CommReducerNode* CommReducer::get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline const CommReducerNode* CommReducer::operator->() const { - return static_cast(node_.get()); + return get(); } /*! \brief Reduction operator operator */ @@ -694,7 +694,7 @@ class Reduce : public ExprNode { Expr condition, int value_index); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("combiner", &combiner); v->Visit("source", &source); @@ -710,7 +710,7 @@ class Reduce : public ExprNode { /*! \brief Any shape. */ class Any : public ExprNode { public: - void VisitAttrs(AttrVisitor* v) final {} + void VisitAttrs(AttrVisitor* v) {} /*! \brief Convert to var. */ Var ToVar() const { return Variable::make(Int(32), "any_dim"); @@ -735,7 +735,7 @@ class LetStmt : public StmtNode { /*! \brief The body block. 
*/ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); @@ -768,7 +768,7 @@ class AttrStmt : public StmtNode { /*! \brief The body statement to be executed */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("node", &node); v->Visit("attr_key", &attr_key); v->Visit("value", &value); @@ -799,7 +799,7 @@ class AssertStmt : public StmtNode { */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("condition", &condition); v->Visit("message", &message); v->Visit("body", &body); @@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode { /*! \brief Body to be executed. */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("is_producer", &is_producer); v->Visit("body", &body); @@ -863,7 +863,7 @@ class Store : public StmtNode { /*! \brief The predicate to mask which lanes would be stored. */ Expr predicate; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); v->Visit("value", &value); v->Visit("index", &index); @@ -893,7 +893,7 @@ class Provide : public StmtNode { /*! \brief The index arguments of the function. */ Array args; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); v->Visit("value", &value); @@ -929,7 +929,7 @@ class Allocate : public StmtNode { Expr new_expr; std::string free_function; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); v->Visit("dtype", &type); v->Visit("extents", &extents); @@ -972,7 +972,7 @@ class Free : public StmtNode { /*! \brief The buffer variable. */ Var buffer_var; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); } @@ -1001,7 +1001,7 @@ class Realize : public StmtNode { /*! \brief The body of realization. */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); v->Visit("dtype", &type); @@ -1031,7 +1031,7 @@ class Block : public StmtNode { /*! \brief The restof statments. */ Stmt rest; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("first", &first); v->Visit("rest", &rest); } @@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode { /*! \brief The branch to be executed when condition is false, can be null. */ Stmt else_case; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("condition", &condition); v->Visit("then_case", &then_case); v->Visit("else_case", &else_case); @@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode { /*! \brief The expression to be evaluated. */ Expr value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } @@ -1142,7 +1142,7 @@ class For : public StmtNode { DeviceAPI device_api, Stmt body); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); v->Visit("min", &min); v->Visit("extent", &extent); @@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode { /*! \brief Bounds to be prefetched. 
@@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode {
   /*! \brief Bounds to be prefetched. */
   Region bounds;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("value_index", &value_index);
     v->Visit("type", &type);
@@ -1310,6 +1310,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
  */
 constexpr const char* device_scope = "device_scope";

+/*!
+ * \brief Mark the shape of a TensorCore fragment
+ */
+constexpr const char* fragment_shape = "fragment_shape";
+
+/*!
+ * \brief Mark the layout of a TensorCore fragment
+ */
+constexpr const char* fragment_layout = "fragment_layout";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
  * }
  */
 constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
+/*!
+ * \brief tvm intrinsic for tensor core load operators.
+ *
+ *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                            Expr index, Expr buffer_ptr, Expr stride,
+ *                            StringImm layout) {
+ *    // m, n, k are the shape of the wmma fragment.
+ *    // Determine the fragment layout (column-major or row-major) from layout.
+ *    // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
+ *    nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
+ *  }
+ */
+constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
+/*!
+ * \brief tvm intrinsic for tensor core mma_sync operators.
+ *
+ *  void tvm_mma_sync(Var fragment_d, Expr index_d,
+ *                    Var fragment_a, Expr index_a,
+ *                    Var fragment_b, Expr index_b,
+ *                    Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                           fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_mma_sync = "tvm_mma_sync";
+/*!
+ * \brief tvm intrinsic for tensor core fill_fragment operators.
+ *
+ *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                         Expr index, Expr value) {
+ *    // m, n, k are the shape of the wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::fill_fragment(fragment[index], value);
+ *  }
+ */
+constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
+/*!
+ * \brief tvm intrinsic for tensor core store operators.
+ *
+ *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                             Expr index, Expr buffer_ptr, Expr stride,
+ *                             StringImm layout) {
+ *    // m, n, k are the shape of the wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
+ *  }
+ */
+constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
 }   // namespace intrinsic
@@ -1576,7 +1634,7 @@ namespace std {
 template <>
 struct hash<::tvm::ir::TensorKey> {
   std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
-    size_t lhs = k.f.hash();
+    size_t lhs = ::tvm::NodeHash()(k.f);
     size_t rhs = static_cast<size_t>(k.value_index);
     lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
     return lhs;
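The rewritten TensorKey hash above mixes the two component hashes with the boost-style hash_combine step; 0x9e3779b9 is the usual 32-bit golden-ratio constant. Factored out for clarity (sketch):

#include <cstddef>

// hash_combine as used by the TensorKey hash above.
inline std::size_t HashCombine(std::size_t seed, std::size_t value) {
  seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  return seed;
}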
diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h
index a7d91eacf851..54a5eff6846b 100644
--- a/include/tvm/ir_functor_ext.h
+++ b/include/tvm/ir_functor_ext.h
@@ -84,19 +84,19 @@ class StmtFunctor;
 }

 #define STMT_FUNCTOR_DEFAULT {                                   \
     return VisitStmtDefault_(op, std::forward<Args>(args)...);  \
-}
+  }

 #define IR_EXPR_FUNCTOR_DISPATCH(OP)                             \
   vtable.template set_dispatch<OP>(                              \
-    [](const NodeRef& n, TSelf* self, Args... args) {            \
-      return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
+    [](const ObjectRef& n, TSelf* self, Args... args) {          \
+      return self->VisitExpr_(static_cast<const OP*>(n.get()),   \
                               std::forward<Args>(args)...);      \
     });                                                          \

 #define IR_STMT_FUNCTOR_DISPATCH(OP)                             \
   vtable.template set_dispatch<OP>(                              \
-    [](const NodeRef& n, TSelf* self, Args... args) {            \
-      return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
+    [](const ObjectRef& n, TSelf* self, Args... args) {          \
+      return self->VisitStmt_(static_cast<const OP*>(n.get()),   \
                               std::forward<Args>(args)...);      \
     });                                                          \

@@ -104,7 +104,7 @@ template
 class ExprFunctor {
  private:
   using TSelf = ExprFunctor;
-  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;
+  using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;

  public:
   /*! \brief the result type of this functor */
@@ -164,7 +164,7 @@ class ExprFunctor {
   virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args ...) {
-    LOG(FATAL) << "Do not have a default for " << op->type_key();
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
   }

@@ -213,7 +213,7 @@ template
 class StmtFunctor {
  private:
   using TSelf = StmtFunctor;
-  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;
+  using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;

  public:
   /*! \brief the result type of this functor */
@@ -255,7 +255,7 @@ class StmtFunctor {
   virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmtDefault_(const Node* op, Args ...) {
-    LOG(FATAL) << "Do not have a default for " << op->type_key();
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
   }

diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index b82a19d4689c..c910a48620c8 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -65,9 +65,9 @@ class TVM_DLL IRMutator {
   /*! \brief destructor */
   virtual ~IRMutator() {}
   /*! \brief functor type of expr mutation */
-  using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
+  using FMutateExpr = IRFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>;
   /*! \brief functor type of stmt mutation */
-  using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
+  using FMutateStmt = IRFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>;
   /*! \return internal vtable of expr */
   static FMutateExpr& vtable_expr();  // NOLINT(*)
   /*! \return internal vtable of stmt */
   static FMutateStmt& vtable_stmt();

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 5ac71fdce47b..842c6af8cf5d 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -377,6 +377,13 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
  */
 Stmt DecorateDeviceScope(Stmt stmt);

+/*!
+ * \brief Loop invariant code motion which locates and hoists if statements.
+ * \param stmt The stmt to do if statement hoisting.
+ * \return Transformed stmt.
+ */
+Stmt HoistIfThenElse(Stmt stmt);
+
 /*!
  * \brief Make a user callable API LoweredFunc.
  *
@@ -506,6 +513,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
  */
 LoweredFunc PointerValueTypeRewrite(LoweredFunc f);

+/*!
+ * \brief Lower attached storage access information on device.
+ *  Do this pass after all storage access analysis finishes.
+ *
+ * \param func The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
+
 /*!
  * \brief Lower intrinsic function calls.
  * \param f The device function to be lowered.
@@ -525,6 +541,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
  */
 LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);

+/*!
+ * \brief Infer the TensorCore fragment information using tensor intrinsics
+ *
+ * \param f The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc InferFragment(LoweredFunc f);
+
 /*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *

diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h
index f20b91368587..bebf94585ed6 100644
--- a/include/tvm/ir_visitor.h
+++ b/include/tvm/ir_visitor.h
@@ -49,7 +49,7 @@ namespace ir {
  *  // The use case is to count number of Variables in the ir tree.
  *  class MyCounter : public IRVisitor {
  *   public:
- *    int Count(const NodeRef& n) {
+ *    int Count(const ObjectRef& n) {
 *      ret_ = 0;
 *      this->Visit(n);
 *      return ret_;
@@ -94,7 +94,7 @@ class TVM_DLL IRVisitor {
   /*! \brief destructor */
   virtual ~IRVisitor() {}
   /*! \brief functor type of visitor */
-  using FVisit = IRFunctor<void(const NodeRef&)>;
+  using FVisit = IRFunctor<void(const ObjectRef&)>;
   /*! \return internal vtable*/
   static FVisit& vtable();
   // overloadable visit function.

diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h
index 4da93b80c2ab..6709f545cb39 100644
--- a/include/tvm/lowered_func.h
+++ b/include/tvm/lowered_func.h
@@ -44,7 +44,7 @@ class LoweredFuncNode;
 class LoweredFunc : public ir::FunctionRef {
  public:
   LoweredFunc() {}
-  explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
+  explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
   int num_outputs() const final {
     return 1;
   }
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("args", &args);
     v->Visit("thread_axis", &thread_axis);
@@ -136,17 +136,14 @@ class LoweredFuncNode : public ir::FunctionBaseNode {

 // Implementations of inline functions
 inline const LoweredFuncNode* LoweredFunc::operator->() const {
-  return static_cast<const LoweredFuncNode*>(node_.get());
+  return static_cast<const LoweredFuncNode*>(get());
 }
 }  // namespace tvm

 namespace std {
 template <>
-struct hash<::tvm::LoweredFunc> {
-  std::size_t operator()(const ::tvm::LoweredFunc& k) const {
-    return k.hash();
-  }
+struct hash<::tvm::LoweredFunc> : public tvm::NodeHash {
 };
 }

diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index c2c639e374f5..c36c6c141451 100644
--- a/include/tvm/node/container.h
+++ b/include/tvm/node/container.h
@@ -38,67 +38,49 @@ namespace tvm {
 class ArrayNode : public Node {
  public:
   /*! \brief the data content */
-  std::vector<NodePtr<Node> > data;
+  std::vector<ObjectRef> data;

-  void VisitAttrs(AttrVisitor* visitor) final {
-    // Visitor to array have no effect.
+  void VisitAttrs(AttrVisitor* visitor) {
   }

   static constexpr const char* _type_key = "Array";
-  TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node);
 };
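ArrayNode now stores plain ObjectRef elements, and the MapNode change just below swaps the hand-written Hash/Equal functors for the runtime's ObjectHash/ObjectEqual. Both keep pointer-identity semantics: keys compare by reference, not by structure. A sketch of those semantics (not the runtime's exact code):

#include <cstddef>
#include <functional>
#include <tvm/runtime/object.h>

struct RefHash {
  std::size_t operator()(const tvm::runtime::ObjectRef& r) const {
    // Hash the node's address, exactly like the deleted MapNode::Hash did.
    return std::hash<const void*>()(static_cast<const void*>(r.get()));
  }
};

struct RefEqual {
  bool operator()(const tvm::runtime::ObjectRef& a,
                  const tvm::runtime::ObjectRef& b) const {
    return a.get() == b.get();  // same underlying node, not structural equality
  }
};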
 /*! \brief map node content */
 class MapNode : public Node {
  public:
-  void VisitAttrs(AttrVisitor* visitor) final {
-    // Visitor to map have no effect.
+  void VisitAttrs(AttrVisitor* visitor) {
   }

-  // hash function
-  struct Hash {
-    size_t operator()(const NodePtr<Node>& n) const {
-      return std::hash<Node*>()(n.get());
-    }
-  };
-  // comparator
-  struct Equal {
-    bool operator()(
-        const NodePtr<Node>& a,
-        const NodePtr<Node>& b) const {
-      return a.get() == b.get();
-    }
-  };

   /*! \brief The corresponding container type */
   using ContainerType = std::unordered_map<
-      NodePtr<Node>,
-      NodePtr<Node>,
-      Hash, Equal>;
+      ObjectRef,
+      ObjectRef,
+      ObjectHash, ObjectEqual>;

   /*! \brief the data content */
   ContainerType data;

   static constexpr const char* _type_key = "Map";
-  TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
+  TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node);
 };

 /*! \brief specialized map node with string as key */
 class StrMapNode : public Node {
  public:
-  void VisitAttrs(AttrVisitor* visitor) final {
-    // Visitor to map have no effect.
-  }

   /*! \brief The corresponding container type */
-  using ContainerType = std::unordered_map<
-      std::string,
-      NodePtr<Node> >;
+  using ContainerType = std::unordered_map<std::string, ObjectRef>;
+
+  void VisitAttrs(AttrVisitor* visitor) {
+  }

   /*! \brief the data content */
   ContainerType data;

   static constexpr const char* _type_key = "StrMap";
-  TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
+  TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node);
 };

 /*!
@@ -111,9 +93,9 @@ template
   using difference_type = typename std::iterator_traits<TIter>::difference_type;
-  using value_type = typename std::iterator_traits<TIter>::value_type;
-  using pointer = typename std::iterator_traits<TIter>::pointer;
-  using reference = typename std::iterator_traits<TIter>::reference;
+  using value_type = typename Converter::ResultType;
+  using pointer = typename Converter::ResultType*;
+  using reference = typename Converter::ResultType&;  // NOLINT(*)
   using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

   explicit IterAdapter(TIter iter) : iter_(iter) {}
@@ -138,7 +120,7 @@ class IterAdapter {
   inline bool operator!=(IterAdapter other) const {
     return !(*this == other);
   }
-  inline const typename Converter::ResultType operator*() const {
+  inline const value_type operator*() const {
     return Converter::convert(*iter_);
   }
@@ -162,26 +144,27 @@ class Array : public NodeRef {
    * \brief default constructor
    */
   Array() {
-    node_ = make_node<ArrayNode>();
+    data_ = make_node<ArrayNode>();
   }
   /*!
    * \brief move constructor
    * \param other source
    */
   Array(Array && other) {  // NOLINT(*)
-    node_ = std::move(other.node_);
+    data_ = std::move(other.data_);
   }
   /*!
    * \brief copy constructor
    * \param other source
    */
-  Array(const Array &other) : NodeRef(other.node_) { // NOLINT(*)
+  Array(const Array &other) { // NOLINT(*)
+    data_ = std::move(other.data_);
   }
   /*!
    * \brief constructor from pointer
    * \param n the container pointer
    */
-  explicit Array(NodePtr<Node> n) : NodeRef(n) {}
+  explicit Array(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
    * \brief constructor from iterator
    * \param begin begin of iterator
@@ -214,9 +197,9 @@ class Array : public NodeRef {
   explicit Array(size_t n, const T& val) {
     auto tmp_node = make_node<ArrayNode>();
     for (size_t i = 0; i < n; ++i) {
-      tmp_node->data.push_back(val.node_);
+      tmp_node->data.push_back(val);
     }
-    node_ = std::move(tmp_node);
+    data_ = std::move(tmp_node);
   }
   /*!
    * \brief move assign operator
@@ -224,7 +207,7 @@
    * \return reference to self.
    */
   Array& operator=(Array && other) {
-    node_ = std::move(other.node_);
+    data_ = std::move(other.data_);
     return *this;
   }
   /*!
@@ -233,7 +216,7 @@
    * \return reference to self.
    */
   Array& operator=(const Array & other) {
-    node_ = other.node_;
+    data_ = other.data_;
     return *this;
   }
   /*!
@@ -246,9 +229,9 @@
   void assign(IterType begin, IterType end) {
     auto n = make_node<ArrayNode>();
     for (IterType it = begin; it != end; ++it) {
-      n->data.push_back((*it).node_);
+      n->data.push_back(T(*it));
     }
-    node_ = std::move(n);
+    data_ = std::move(n);
   }
   /*!
   * \brief Read i-th element from array.
   * \param i The index
   * \return the i-th element.
   */
  inline const T operator[](size_t i) const {
-    return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
+    return DowncastNoCheck<T>(
+        static_cast<const ArrayNode*>(data_.get())->data[i]);
  }
  /*! \return The size of the array */
  inline size_t size() const {
-    if (node_.get() == nullptr) return 0;
-    return static_cast<const ArrayNode*>(node_.get())->data.size();
+    if (data_.get() == nullptr) return 0;
+    return static_cast<const ArrayNode*>(data_.get())->data.size();
  }
  /*!
   * \brief copy on write semantics
@@ -272,12 +256,12 @@
   * \return Handle to the internal node container (which is guaranteed to be unique)
   */
  inline ArrayNode* CopyOnWrite() {
-    if (node_.get() == nullptr || !node_.unique()) {
+    if (data_.get() == nullptr || !data_.unique()) {
      NodePtr<ArrayNode> n = make_node<ArrayNode>();
-      n->data = static_cast<ArrayNode*>(node_.get())->data;
-      NodePtr<Node>(std::move(n)).swap(node_);
+      n->data = static_cast<ArrayNode*>(data_.get())->data;
+      ObjectPtr<Object>(std::move(n)).swap(data_);
    }
-    return static_cast<ArrayNode*>(node_.get());
+    return static_cast<ArrayNode*>(data_.get());
  }
  /*!
   * \brief push a new item to the back of the list
@@ -285,7 +269,7 @@
   */
  inline void push_back(const T& item) {
    ArrayNode* n = this->CopyOnWrite();
-    n->data.push_back(item.node_);
+    n->data.push_back(item);
  }
  /*!
   * \brief set i-th element of the array.
@@ -294,7 +278,7 @@
   */
  inline void Set(size_t i, const T& value) {
    ArrayNode* n = this->CopyOnWrite();
-    n->data[i] = value.node_;
+    n->data[i] = value;
  }
  /*! \return whether array is empty */
  inline bool empty() const {
@@ -303,34 +287,34 @@
    return size() == 0;
  }
  /*! \brief specify container node */
  using ContainerType = ArrayNode;

-  struct Ptr2NodeRef {
+  struct ValueConverter {
    using ResultType = T;
-    static inline T convert(const NodePtr<Node>& n) {
-      return T(n);
+    static inline T convert(const ObjectRef& n) {
+      return DowncastNoCheck<T>(n);
    }
  };

-  using iterator = IterAdapter<Ptr2NodeRef,
-      std::vector<NodePtr<Node> >::const_iterator>;
+  using iterator = IterAdapter<ValueConverter,
+      std::vector<ObjectRef>::const_iterator>;

  using reverse_iterator = IterAdapter<
-      Ptr2NodeRef,
-      std::vector<NodePtr<Node> >::const_reverse_iterator>;
+      ValueConverter,
+      std::vector<ObjectRef>::const_reverse_iterator>;

  /*! \return begin iterator */
  inline iterator begin() const {
-    return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
+    return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
  }
  /*! \return end iterator */
  inline iterator end() const {
-    return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
+    return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
  }
  /*! \return rbegin iterator */
  inline reverse_iterator rbegin() const {
-    return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
+    return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
  }
  /*! \return rend iterator */
  inline reverse_iterator rend() const {
-    return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
+    return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
  }
 };
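Array's CopyOnWrite gives the usual persistent-value behavior: copies share the underlying ArrayNode until one of them mutates. A quick illustration (hypothetical helper; x, y, z are assumed to be tvm::Expr values in scope):

void CopyOnWriteDemo(tvm::Expr x, tvm::Expr y, tvm::Expr z) {
  tvm::Array<tvm::Expr> a = {x, y};
  tvm::Array<tvm::Expr> b = a;  // b shares a's ArrayNode; no copy yet
  a.push_back(z);               // a is no longer unique, so CopyOnWrite() clones first
  // afterwards: a.size() == 3, b.size() == 2 -- b never observes the mutation
}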
@@ -355,26 +339,26 @@ class Map : public NodeRef {
   /*!
    * \brief default constructor
    */
   Map() {
-    node_ = make_node<MapNode>();
+    data_ = make_node<MapNode>();
   }
   /*!
    * \brief move constructor
    * \param other source
    */
   Map(Map && other) {  // NOLINT(*)
-    node_ = std::move(other.node_);
+    data_ = std::move(other.data_);
   }
   /*!
    * \brief copy constructor
    * \param other source
    */
-  Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*)
+  Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*)
   }
   /*!
    * \brief constructor from pointer
    * \param n the container pointer
    */
-  explicit Map(NodePtr<Node> n) : NodeRef(n) {}
+  explicit Map(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
   * \brief constructor from iterator
   * \param begin begin of iterator
@@ -406,7 +390,7 @@ class Map : public NodeRef {
   * \return reference to self.
   */
  Map& operator=(Map && other) {
-    node_ = std::move(other.node_);
+    data_ = std::move(other.data_);
    return *this;
  }
  /*!
@@ -415,7 +399,7 @@
   * \return reference to self.
   */
  Map& operator=(const Map & other) {
-    node_ = other.node_;
+    data_ = other.data_;
    return *this;
  }
  /*!
@@ -428,10 +412,9 @@
  void assign(IterType begin, IterType end) {
    NodePtr<MapNode> n = make_node<MapNode>();
    for (IterType i = begin; i != end; ++i) {
-      n->data.emplace(std::make_pair(i->first.node_,
-                                     i->second.node_));
+      n->data.emplace(std::make_pair(i->first, i->second));
    }
-    node_ = std::move(n);
+    data_ = std::move(n);
  }
  /*!
   * \brief Read element from map.
   * \param key The key
   * \return the corresponding element.
   */
  inline const V operator[](const K& key) const {
-    return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+    return DowncastNoCheck<V>(
+        static_cast<const MapNode*>(data_.get())->data.at(key));
  }
  /*!
   * \brief Read element from map.
@@ -447,17 +431,18 @@
   * \param key The key
   * \return the corresponding element.
   */
  inline const V at(const K& key) const {
-    return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+    return DowncastNoCheck<V>(
+        static_cast<const MapNode*>(data_.get())->data.at(key));
  }
  /*! \return The size of the array */
  inline size_t size() const {
-    if (node_.get() == nullptr) return 0;
-    return static_cast<const MapNode*>(node_.get())->data.size();
+    if (data_.get() == nullptr) return 0;
+    return static_cast<const MapNode*>(data_.get())->data.size();
  }
  /*! \return The number of elements of the key */
  inline size_t count(const K& key) const {
-    if (node_.get() == nullptr) return 0;
-    return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
+    if (data_.get() == nullptr) return 0;
+    return static_cast<const MapNode*>(data_.get())->data.count(key);
  }
  /*!
   * \brief copy on write semantics
@@ -468,12 +453,12 @@
   * \return Handle to the internal node container (which is guaranteed to be unique)
   */
  inline MapNode* CopyOnWrite() {
-    if (node_.get() == nullptr || !node_.unique()) {
+    if (data_.get() == nullptr || !data_.unique()) {
      NodePtr<MapNode> n = make_node<MapNode>();
-      n->data = static_cast<const MapNode*>(node_.get())->data;
-      NodePtr<Node>(std::move(n)).swap(node_);
+      n->data = static_cast<const MapNode*>(data_.get())->data;
+      ObjectPtr<Object>(std::move(n)).swap(data_);
    }
-    return static_cast<MapNode*>(node_.get());
+    return static_cast<MapNode*>(data_.get());
  }
  /*!
   * \brief set the Map.
   * \param key The index key.
   * \param value The value to be set.
   */
  inline void Set(const K& key, const V& value) {
    MapNode* n = this->CopyOnWrite();
-    n->data[key.node_] = value.node_;
+    n->data[key] = value;
  }

  /*! \return whether array is empty */
@@ -492,29 +477,31 @@
    return size() == 0;
  }
  /*! \brief specify container node */
  using ContainerType = MapNode;

-  struct Ptr2NodeRef {
+  struct ValueConverter {
    using ResultType = std::pair<K, V>;
    static inline ResultType convert(const std::pair<
-        NodePtr<Node>,
-        NodePtr<Node> >& n) {
-      return std::make_pair(K(n.first), V(n.second));
+        ObjectRef,
+        ObjectRef>& n) {
+      return std::make_pair(DowncastNoCheck<K>(n.first),
+                            DowncastNoCheck<V>(n.second));
    }
  };

  using iterator = IterAdapter<
-      Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
+      ValueConverter, MapNode::ContainerType::const_iterator>;

  /*! \return begin iterator */
  inline iterator begin() const {
-    return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
+    return iterator(static_cast<const MapNode*>(data_.get())->data.begin());
  }
  /*!
\return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return begin iterator */ inline iterator find(const K& key) const { - return iterator(static_cast(node_.get())->data.find(key.node_)); + return iterator( + static_cast(data_.get())->data.find(key)); } }; @@ -524,14 +511,14 @@ class Map : public NodeRef { public: // for code reuse Map() { - node_ = make_node(); + data_ = make_node(); } Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -545,76 +532,77 @@ class Map : public NodeRef { assign(init.begin(), init.end()); } Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } template void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } inline size_t count(const std::string& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } inline void Set(const std::string& key, const V& value) { StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value.node_; + n->data[key] = value; } inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - std::string, - NodePtr >& n) { - return std::make_pair(n.first, V(n.second)); + std::string, + ObjectRef>& n) { + return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + ValueConverter, StrMapNode::ContainerType::const_iterator>; /*! 
 \return begin iterator */
  inline iterator begin() const {
-    return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
+    return iterator(static_cast<const StrMapNode*>(data_.get())->data.begin());
  }
  /*! \return end iterator */
  inline iterator end() const {
-    return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
+    return iterator(static_cast<const StrMapNode*>(data_.get())->data.end());
  }
  /*! \return begin iterator */
  inline iterator find(const std::string& key) const {
-    return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
+    return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
  }
 };

diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h
index 23c5a3fafdab..e902e8fb6d44 100644
--- a/include/tvm/node/ir_functor.h
+++ b/include/tvm/node/ir_functor.h
@@ -34,10 +34,10 @@
 namespace tvm {
 /*!
- * \brief A dynamically dispatched functor on NodeRef in the first argument.
+ * \brief A dynamically dispatched functor on ObjectRef in the first argument.
  *
  * \code
- *   IRFunctor<std::string(const NodeRef& n, std::string prefix)> tostr;
+ *   IRFunctor<std::string(const ObjectRef& n, std::string prefix)> tostr;
  *   tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
  *     return prefix + "Add";
  *   });
@@ -60,10 +60,10 @@
 template
 class IRFunctor;

 template
-class IRFunctor<R(const NodeRef& n, Args... args)> {
+class IRFunctor<R(const ObjectRef& n, Args... args)> {
  private:
-  using Function = std::function<R(const NodeRef& n, Args... args)>;
-  using TSelf = IRFunctor<R(const NodeRef& n, Args... args)>;
+  using Function = std::function<R(const ObjectRef& n, Args... args)>;
+  using TSelf = IRFunctor<R(const ObjectRef& n, Args... args)>;
   /*! \brief internal function table */
   std::vector<Function> func_;
@@ -75,8 +75,8 @@ class IRFunctor {
    * \param n The node to be dispatched
    * \return Whether dispatching function is registered for n's type.
    */
-  inline bool can_dispatch(const NodeRef& n) const {
-    uint32_t type_index = n.type_index();
+  inline bool can_dispatch(const ObjectRef& n) const {
+    uint32_t type_index = n->type_index();
     return type_index < func_.size() && func_[type_index] != nullptr;
   }
   /*!
@@ -85,12 +85,12 @@ class IRFunctor {
    * \param args The additional arguments
    * \return The result.
    */
-  inline R operator()(const NodeRef& n, Args... args) const {
-    uint32_t type_index = n.type_index();
+  inline R operator()(const ObjectRef& n, Args... args) const {
+    uint32_t type_index = n->type_index();
     CHECK(type_index < func_.size() &&
           func_[type_index] != nullptr)
         << "IRFunctor calls un-registered function on type "
-        << Node::TypeIndex2Key(type_index);
+        << n->GetTypeKey();
     return func_[type_index](n, std::forward<Args>(args)...);
   }
   /*!
@@ -101,19 +101,19 @@
    */
   template
   inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
-    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
+    uint32_t tindex = TNode::RuntimeTypeIndex();
     if (func_.size() <= tindex) {
       func_.resize(tindex + 1, nullptr);
     }
     CHECK(func_[tindex] == nullptr)
-        << "Dispatch for " << Node::TypeIndex2Key(tindex)
+        << "Dispatch for " << TNode::_type_key
         << " is already set";
     func_[tindex] = f;
     return *this;
   }
   /*!
    * \brief set the dispatcher for type TNode
-   * This allows f to used detailed const Node pointer to replace NodeRef
+   * This allows f to use a detailed const Node pointer to replace ObjectRef
    *
    * \param f The function to be set.
    * \tparam TNode the type of Node to be dispatched.
@@ -121,8 +121,8 @@
    */
   template
   inline TSelf& set_dispatch(std::function f) {  // NOLINT(*)
-    Function fun = [f](const NodeRef& n, Args... args) {
-      return f(static_cast<const TNode*>(n.node_.get()),
+    Function fun = [f](const ObjectRef& n, Args...
args) { + return f(static_cast(n.get()), std::forward(args)...); }; return this->set_dispatch(fun); @@ -135,7 +135,7 @@ class IRFunctor { */ template inline TSelf& clear_dispatch() { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; @@ -172,7 +172,7 @@ class IRFunctor { * f(e, this); * } * - * using FType = IRFunctor; + * using FType = IRFunctor; * // function to return global function table * static FType& vtable(); * }; @@ -232,15 +232,15 @@ template class IRFunctorStaticRegistry; template -class IRFunctorStaticRegistry { +class IRFunctorStaticRegistry { private: - IRFunctor *irf_; + IRFunctor *irf_; std::shared_ptr free_list; - using TSelf = IRFunctorStaticRegistry; + using TSelf = IRFunctorStaticRegistry; public: - IRFunctorStaticRegistry(IRFunctor *irf) { + IRFunctorStaticRegistry(IRFunctor *irf) { irf_ = irf; free_list = std::make_shared(); } @@ -261,12 +261,12 @@ class IRFunctorStaticRegistry { * the compiler to deduce the template types. */ template -IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( - IRFunctor *irf) { - return IRFunctorStaticRegistry(irf); +IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( + IRFunctor *irf) { + return IRFunctorStaticRegistry(irf); } -#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName /*! diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h deleted file mode 100644 index 1bba57144e19..000000000000 --- a/include/tvm/node/memory.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/memory.h - * \brief Node memory management. - */ -#ifndef TVM_NODE_MEMORY_H_ -#define TVM_NODE_MEMORY_H_ - -#include -#include "node.h" - -namespace tvm { -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - */ -template -inline NodePtr make_node(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -// -template -class SimpleNodeAllocator { - public: - template - static T* New(Args&&... 
args) {
-    return new T(std::forward<Args>(args)...);
-  }
-  static NodeBase::FDeleter Deleter() {
-    return Deleter_;
-  }
-
- private:
-  static void Deleter_(NodeBase* ptr) {
-    delete static_cast<T*>(ptr);
-  }
-};
-
-template
-inline NodePtr make_node(Args&&... args) {
-  using Allocator = SimpleNodeAllocator;
-  static_assert(std::is_base_of<NodeBase, T>::value,
-                "make_node can only be used to create NodeBase");
-  T* node = Allocator::New(std::forward<Args>(args)...);
-  node->deleter_ = Allocator::Deleter();
-  return NodePtr<T>(node);
-}
-
-}  // namespace tvm
-#endif  // TVM_NODE_MEMORY_H_

diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index cb18e46e9a5c..4014c3700596 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -18,344 +18,143 @@
 */
 /*!
 * \file tvm/node/node.h
- * \brief Node system data structure.
+ * \brief Definitions and helper macros for IR/AST nodes.
+ *
+ *  The node folder contains base utilities for IR/AST nodes,
+ *  independent of the specific language dialect.
+ *
+ *  We implement AST/IR nodes as sub-classes of runtime::Object.
+ *  The base class Node is just an alias of runtime::Object.
+ *
+ *  Besides the runtime type checking provided by Object,
+ *  the node folder contains additional functionality, such as
+ *  reflection and serialization, which are important features
+ *  for building a compiler infrastructure.
 */
 #ifndef TVM_NODE_NODE_H_
 #define TVM_NODE_NODE_H_

-#include
 #include
-#include
+#include
+#include
+#include
+
 #include
 #include
 #include
 #include
-
 namespace tvm {

-// forward declaration
-class DataType;
-class Node;
-class NodeRef;

-namespace runtime {
-// forward declaration
-class NDArray;
-// forward declaration
-class ObjectRef;
-}  // namespace runtime
+using runtime::TypeIndex;
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+using runtime::GetRef;
+using runtime::Downcast;
+using runtime::ObjectHash;
+using runtime::ObjectEqual;
+using runtime::make_object;

-/*!
- * \brief Visitor class to each node content.
- *  The content is going to be called for each field.
- */
-class TVM_DLL AttrVisitor {
- public:
-//! \cond Doxygen_Suppress
-  virtual ~AttrVisitor() = default;
-  virtual void Visit(const char* key, double* value) = 0;
-  virtual void Visit(const char* key, int64_t* value) = 0;
-  virtual void Visit(const char* key, uint64_t* value) = 0;
-  virtual void Visit(const char* key, int* value) = 0;
-  virtual void Visit(const char* key, bool* value) = 0;
-  virtual void Visit(const char* key, std::string* value) = 0;
-  virtual void Visit(const char* key, void** value) = 0;
-  virtual void Visit(const char* key, DataType* value) = 0;
-  virtual void Visit(const char* key, NodeRef* value) = 0;
-  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
-  virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
-  template<typename ENum,
-           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
-  void Visit(const char* key, ENum* ptr) {
-    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
-                  "declare enum to be enum int to use visitor");
-    this->Visit(key, reinterpret_cast<int*>(ptr));
-  }
-//! \endcond
-};
+using NodeHash = ObjectHash;
+using NodeEqual = ObjectEqual;
+using Node = Object;

 /*!
- * \brief base class of node container in DSL AST.
+ * \brief Base class of all references to AST/IR nodes.
 */
-class TVM_DLL Node : public NodeBase {
+class NodeRef : public ObjectRef {
  public:
-  /*! \brief virtual destructor */
-  virtual ~Node() {}
-  /*! \return The unique type key of the node */
-  virtual const char* type_key() const = 0;
-  /*!
- * \brief Apply visitor to each field of the Node - * Visitor could mutate the content of the node. - * override if Node contains attribute fields. - * \param visitor The visitor - */ - virtual void VisitAttrs(AttrVisitor* visitor) {} - /*! \return the type index of the node */ - virtual uint32_t type_index() const = 0; - /*! - * \brief Whether this node derives from node with type_index=tid. - * Implemented by TVM_DECLARE_NODE_TYPE_INFO - * - * \param tid The type index. - * \return the check result. - */ - virtual bool _DerivedFrom(uint32_t tid) const; - /*! - * \brief get a runtime unique type index given a type key - * \param type_key Type key of a type. - * \return the corresponding type index. - */ - static uint32_t TypeKey2Index(const char* type_key); - /*! - * \brief get type key from type index. - * \param index The type index - * \return the corresponding type key. - */ - static const char* TypeIndex2Key(uint32_t index); - /*! - * \return whether the type is derived from - */ - template - inline bool derived_from() const; - /*! - * \return whether the node is of type T - * \tparam The type to be checked. - */ - template - inline bool is_type() const; - /*! - * \brief Get a NodePtr that holds reference to this Node. - * \return the NodePtr - */ - inline NodePtr GetNodePtr() const; - // node ref can see this - friend class NodeRef; - static constexpr const char* _type_key = "Node"; -}; - -/*! \brief Base class of all node reference object */ -class NodeRef { - public: - /*! \brief type indicate the container type */ - using ContainerType = Node; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator==(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool same_as(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator<(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator!=(const NodeRef& other) const; - /*! \return the hash function for NodeRef */ - inline size_t hash() const; - /*! \return whether the expression is null */ - inline bool defined() const; - /*! \return the internal type index of IRNode */ - inline uint32_t type_index() const; - /*! \return the internal node pointer */ - inline const Node* get() const; - /*! \return the internal node pointer */ - inline const Node* operator->() const; - /*! - * \brief Downcast this ir node to its actual type (e.g. Add, or - * Select). This returns nullptr if the node is not of the requested - * type. Example usage: - * - * if (const Add *add = node->as()) { - * // This is an add node - * } - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as() const; - /*! - * \brief A more powerful version of as that also works with - * intermediate base types. - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as_derived() const; - /*! \brief default constructor */ - NodeRef() = default; - explicit NodeRef(NodePtr node) : node_(node) {} - /*! \brief the internal node object, do not touch */ - NodePtr node_; + NodeRef() {} + explicit NodeRef(ObjectPtr n) : ObjectRef(n) {} }; /*! 
- * \brief Get a reference type from a Node ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. - * - * \param ptr The node pointer - * \tparam RefType The reference type - * \tparam NodeType The node type - * \return The corresponding RefType + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + * \note This function is an alias of make_object. */ -template -inline RefType GetRef(const NodeType* ptr); - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The inptut reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template -inline SubRef Downcast(BaseRef ref); +template +inline NodePtr make_node(Args&&... args) { + return runtime::make_object(std::forward(args)...); +} /*! * \brief helper macro to declare type information in a base node. */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - bool _DerivedFrom(uint32_t tid) const override { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) /*! * \brief helper macro to declare type information in a terminal node */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - const char* type_key() const final { \ - return TypeName::_type_key; \ - } \ - uint32_t type_index() const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - return tidx; \ - } \ - bool _DerivedFrom(uint32_t tid) const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -// implementations of inline functions after this -template -inline bool Node::derived_from() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return this->_DerivedFrom(type_id); -} - - -template -inline bool Node::is_type() const { - // use static field so query only happens once. 
- static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return type_id == this->type_index(); -} - - -inline NodePtr Node::GetNodePtr() const { - return NodePtr(const_cast(this)); -} - -template -inline RefType GetRef(const NodeType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return RefType(ptr->GetNodePtr()); -} - -template -inline SubRef Downcast(BaseRef ref) { - CHECK(ref->template is_type() || - ref->template derived_from()) - << "Downcast from " << ref->type_key() << " to " - << SubRef::ContainerType::_type_key << " failed."; - return SubRef(std::move(ref.node_)); -} +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); -inline const Node* NodeRef::get() const { - return node_.get(); -} - -inline const Node* NodeRef::operator->() const { - return node_.get(); -} - -inline bool NodeRef::defined() const { - return node_.get() != nullptr; -} - -inline bool NodeRef::operator==(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::same_as(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::operator<(const NodeRef& other) const { - return node_.get() < other.node_.get(); -} - -inline bool NodeRef::operator!=(const NodeRef& other) const { - return node_.get() != other.node_.get(); -} - -inline size_t NodeRef::hash() const { - return std::hash()(node_.get()); -} - -inline uint32_t NodeRef::type_index() const { - CHECK(node_.get() != nullptr) - << "null type"; - return get()->type_index(); -} -template -inline const T* NodeRef::as() const { - const Node* ptr = static_cast(get()); - if (ptr && ptr->is_type()) { - return static_cast(ptr); - } - return nullptr; -} - -template -inline const T* NodeRef::as_derived() const { - const Node* ptr = static_cast(get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; -} +/*! + * \brief Macro to define common node ref methods. + * \param TypeName The name of the NodeRef. + * \param BaseTypeName The Base type. + * \param NodeName The node container type. + */ +#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : BaseTypeName(n) {} \ + const NodeName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return this->defined(); } \ + using ContainerType = NodeName; -/*! \brief The hash function for nodes */ -struct NodeHash { - size_t operator()(const NodeRef& a) const { - return a.hash(); - } -}; +/*! + * \brief Macro to define CopyOnWrite function in a NodeRef. + * \param NodeName The Type of the Node. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. + * + * \code + * + * MyCOWNodeRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_NODE_REF_COW(NodeName) \ + NodeName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + NodePtr n = make_node(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } + +/*! 
 \brief Macro to make it easy to define node ref type given node */
+#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                       \
+  class TypeName : public ::tvm::NodeRef {                            \
+   public:                                                            \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName);  \
+  };                                                                  \

-/*! \brief The hash function for nodes */
-struct NodeHash {
-  size_t operator()(const NodeRef& a) const {
-    return a.hash();
-  }
-};
+/*!
+ * \brief Macro to make it easy to define node ref type that
+ *  has a CopyOnWrite member function.
+ */
+#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName)   \
+  class TypeName : public BaseType {                            \
+   public:                                                      \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName);  \
+    TVM_DEFINE_NODE_REF_COW(NodeName);                          \
+  };

-/*! \brief The equal comparator for nodes */
-struct NodeEqual {
-  bool operator()(const NodeRef& a, const NodeRef& b) const {
-    return a.get() == b.get();
-  }
-};
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_

diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
new file mode 100644
index 000000000000..e6caa443ab9c
--- /dev/null
+++ b/include/tvm/node/reflection.h
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/reflection.h
+ * \brief Reflection and serialization of compiler IR/AST nodes.
+ */
+#ifndef TVM_NODE_REFLECTION_H_
+#define TVM_NODE_REFLECTION_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace tvm {
+
+// forward declaration
+class DataType;
+
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+
+/*!
+ * \brief Visitor class to get the attributes of an AST/IR node.
+ *  The content is going to be called for each field.
+ *
+ *  Each object that wants reflection needs to implement
+ *  a VisitAttrs function and call visitor->Visit on each of its fields.
+ */
+class TVM_DLL AttrVisitor {
+ public:
+//! \cond Doxygen_Suppress
+  virtual ~AttrVisitor() = default;
+  virtual void Visit(const char* key, double* value) = 0;
+  virtual void Visit(const char* key, int64_t* value) = 0;
+  virtual void Visit(const char* key, uint64_t* value) = 0;
+  virtual void Visit(const char* key, int* value) = 0;
+  virtual void Visit(const char* key, bool* value) = 0;
+  virtual void Visit(const char* key, std::string* value) = 0;
+  virtual void Visit(const char* key, void** value) = 0;
+  virtual void Visit(const char* key, DataType* value) = 0;
+  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
+  virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
+  template<typename ENum,
+           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+  void Visit(const char* key, ENum* ptr) {
+    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
+                  "declare enum to be enum int to use visitor");
+    this->Visit(key, reinterpret_cast<int*>(ptr));
+  }
+//! \endcond
+};
+
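A node participates in reflection by exposing its fields to this visitor. What a conforming node looks like under the scheme this header sets up (hypothetical node, for illustration only):

#include <string>
#include <cstdint>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>

class MyNode : public tvm::Object {
 public:
  std::string name;
  int64_t value{0};

  // Called through the reflection vtable; note it is neither virtual nor final.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("value", &value);
  }

  static constexpr const char* _type_key = "test.MyNode";
  TVM_DECLARE_FINAL_OBJECT_INFO(MyNode, tvm::Object);
};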
+/*!
+ * \brief Virtual function table to support IR/AST node reflection.
+ *
+ * Functions are stored in a columnar manner.
+ * Each column is a vector indexed by Object's type_index.
+ */
+class ReflectionVTable {
+ public:
+  /*!
+   * \brief Visitor function.
+   * \note We use function pointer, instead of std::function
+   *       to reduce the dispatch overhead as field visit
+   *       does not need as much customization.
+   */
+  typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
+  /*!
+   * \brief creator function.
+   * \param global_key Key that identifies a global single object.
+   *        If this is not empty then FGlobalKey must be defined for the object.
+   * \return The created object.
+   */
+  using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>;
+  /*!
+   * \brief Global key function, only needed by global objects.
+   * \param node The node pointer.
+   * \return The global key to the node.
+   */
+  using FGlobalKey = std::function<std::string(const Object* node)>;
+  /*!
+   * \brief Dispatch the VisitAttrs function.
+   * \param self The pointer to the object.
+   * \param visitor The attribute visitor.
+   */
+  inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
+  /*!
+   * \brief Get global key of the object, if any.
+   * \param self The pointer to the object.
+   * \return the global key if object has one, otherwise return empty string.
+   */
+  inline std::string GetGlobalKey(Object* self) const;
+  /*!
+   * \brief Create an initial object using default constructor
+   *        by type_key and global key.
+   *
+   * \param type_key The type key of the object.
+   * \param global_key A global key that can be used to uniquely identify the object if any.
+   */
+  TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
+                                             const std::string& global_key = "") const;
+  /*!
+   * \brief Get a field object by the attr name.
+   * \param self The pointer to the object.
+   * \param attr_name The name of the field.
+   * \return The corresponding attribute value.
+   * \note This function will throw an exception if the object does not contain the field.
+   */
+  TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const;
+
+  /*!
+   * \brief List all the fields in the object.
+   * \return All the fields.
+   */
+  TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
+
+  /*! \return The global singleton. */
+  TVM_DLL static ReflectionVTable* Global();
+
+  class Registry;
+  template <typename T>
+  inline Registry Register();
+
+ private:
+  /*! \brief Attribute visitor. */
+  std::vector<FVisitAttrs> fvisit_attrs_;
+  /*! \brief Creation function. */
+  std::vector<FCreate> fcreate_;
+  /*! \brief Global key function. */
+  std::vector<FGlobalKey> fglobal_key_;
+};
+
+/*! \brief Registry of a reflection table. */
+class ReflectionVTable::Registry {
+ public:
+  Registry(ReflectionVTable* parent, uint32_t type_index)
+      : parent_(parent), type_index_(type_index) { }
+  /*!
+   * \brief Set fcreate function.
+   * \param f The creator function.
+   * \return reference to self.
+   */
+  Registry& set_creator(FCreate f) {  // NOLINT(*)
+    CHECK_LT(type_index_, parent_->fcreate_.size());
+    parent_->fcreate_[type_index_] = f;
+    return *this;
+  }
+  /*!
+   * \brief Set global_key function.
+   * \param f The global key function.
+   * \return reference to self.
+   */
+  Registry& set_global_key(FGlobalKey f) {  // NOLINT(*)
+    CHECK_LT(type_index_, parent_->fglobal_key_.size());
+    parent_->fglobal_key_[type_index_] = f;
+    return *this;
+  }
+
+ private:
+  ReflectionVTable* parent_;
+  uint32_t type_index_;
+};
+
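Registration, then, is one macro invocation per node type, optionally followed by the hooks above. For the hypothetical MyNode from the previous sketch:

// Hooks MyNode into both the object registry and the reflection table,
// with a default creator that default-constructs the node.
TVM_REGISTER_NODE_TYPE(MyNode);

// For a singleton-like node one could additionally set a global key
// (sketch; assumes FGlobalKey receives the object pointer, per the docs above):
// ::tvm::ReflectionVTable::Global()->Register<MySingletonNode>()
//     .set_global_key([](const tvm::Object* n) { return std::string("my.singleton"); });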
+/*!
+ * \brief Register a node type to object registry and reflection registry.
+ * \param TypeName The name of the type.
+ * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
+ */
+#define TVM_REGISTER_NODE_TYPE(TypeName)                            \
+  TVM_REGISTER_OBJECT_TYPE(TypeName);                               \
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry &  \
+  __make_Node ## _ ## TypeName ## __ =                              \
+      ::tvm::ReflectionVTable::Global()->Register<TypeName>()       \
+          .set_creator([](const std::string&) {                     \
+            return ::tvm::runtime::make_object<TypeName>();         \
+          })
+
+// Implementation details
+template <typename T>
+inline ReflectionVTable::Registry
+ReflectionVTable::Register() {
+  uint32_t tindex = T::RuntimeTypeIndex();
+  if (tindex >= fvisit_attrs_.size()) {
+    fvisit_attrs_.resize(tindex + 1, nullptr);
+    fcreate_.resize(tindex + 1, nullptr);
+    fglobal_key_.resize(tindex + 1, nullptr);
+  }
+  // functor that implements the redirection.
+  struct Functor {
+    static void VisitAttrs(Object* self, AttrVisitor* v) {
+      static_cast<T*>(self)->VisitAttrs(v);
+    }
+  };
+
+  fvisit_attrs_[tindex] = Functor::VisitAttrs;
+  return Registry(this, tindex);
+}
+
+inline void ReflectionVTable::
+VisitAttrs(Object* self, AttrVisitor* visitor) const {
+  uint32_t tindex = self->type_index();
+  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
+    LOG(FATAL) << "TypeError: " << self->GetTypeKey()
+               << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  fvisit_attrs_[tindex](self, visitor);
+}
+
+inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
+  uint32_t tindex = self->type_index();
+  if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
+    return fglobal_key_[tindex](self);
+  } else {
+    return std::string();
+  }
+}
+
+}  // namespace tvm
+#endif  // TVM_NODE_REFLECTION_H_

diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h
new file mode 100644
index 000000000000..ac675946e0eb
--- /dev/null
+++ b/include/tvm/node/serialization.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Utility functions for serialization.
+ * \file tvm/node/serialization.h
+ */
+#ifndef TVM_NODE_SERIALIZATION_H_
+#define TVM_NODE_SERIALIZATION_H_
+
+#include
+#include
+
+#include
+
+namespace tvm {
+/*!
+ * \brief Save the node, as well as all the nodes it depends on, as JSON.
+ *  This can be used to serialize any TVM object.
+ *
+ * \return the string representation of the node.
+ */
+TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node);
+
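Round-tripping through SaveJSON and the LoadJSON entry point declared just below is the common pattern (sketch; `ref` stands for any object whose type is registered via TVM_REGISTER_NODE_TYPE):

#include <string>
#include <tvm/node/serialization.h>

void RoundTrip(const tvm::runtime::ObjectRef& ref) {
  std::string json = tvm::SaveJSON(ref);               // serialize ref and its dependencies
  tvm::runtime::ObjectRef back = tvm::LoadJSON(json);  // reconstruct a fresh copy
  (void)back;  // 'back' is a deep copy of 'ref', not the same node
}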
+/*!
+ * \brief Internal implementation of LoadJSON.
+ *  Load a TVM Node object from JSON and return a shared_ptr of the Node.
+ * \param json_str The JSON string to load from.
+ *
+ * \return The shared_ptr of the Node.
+ */
+TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str);
+
+}  // namespace tvm
+#endif  // TVM_NODE_SERIALIZATION_H_

diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index b950aa952f04..f53c1ce56a93 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode {
       const std::unordered_map& dom_map,
       bool debug_keep_trivial_loop) const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
       bool debug_keep_trivial_loop) const final;
   size_t num_schedulable_dims() const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
       bool debug_keep_trivial_loop) const final;
   size_t num_schedulable_dims() const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("axis", &axis);
@@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode {
       const std::unordered_map& dom_map,
       bool debug_keep_trivial_loop) const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode {
       const std::unordered_map& dom_map,
       bool debug_keep_trivial_loop) const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode {
       const std::unordered_map& dom_map,
       bool debug_keep_trivial_loop) const final;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -651,7 +651,7 @@ inline Tensor compute(Array shape,
 // inline function.
 inline const OperationNode* Operation::operator->() const {
-  return static_cast<const OperationNode*>(node_.get());
+  return static_cast<const OperationNode*>(get());
 }
 }  // namespace tvm
 #endif  // TVM_OPERATION_H_

diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 5951594b873c..71f8f55b2655 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -20,7 +20,7 @@
 /*!
 * \file tvm/packed_func_ext.h
 * \brief Extension package to PackedFunc
- *  This enales pass NodeRef types into/from PackedFunc.
+ *  This enables passing ObjectRef types into/from PackedFunc.
 */
 #ifndef TVM_PACKED_FUNC_EXT_H_
 #define TVM_PACKED_FUNC_EXT_H_
@@ -37,6 +37,7 @@
 #include "runtime/packed_func.h"

 namespace tvm {
+
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 using runtime::PackedFunc;
@@ -47,103 +48,99 @@ namespace runtime {
 * \tparam T the type to be checked.
 */
 template
-struct NodeTypeChecker {
-  static inline bool Check(Node* sptr) {
-    // This is the only place in the project where RTTI is used
-    // It can be turned off, but will make non strict checking.
-    // TODO(tqchen) possibly find alternative to turn of RTTI
+struct ObjectTypeChecker {
+  static bool Check(const Object* ptr) {
     using ContainerType = typename T::ContainerType;
-    // always allow nullptr.
- if (sptr == nullptr) return true; - return sptr->derived_from(); + if (ptr == nullptr) return true; + return ptr->IsInstance(); } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + static void PrintName(std::ostream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; os << ContainerType::_type_key; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const ArrayNode* n = static_cast(ptr); for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) { + if (!ObjectTypeChecker::Check(p.get())) { return false; } } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; + static void PrintName(std::ostream& os) { // NOLINT(*) + os << "List["; + ObjectTypeChecker::PrintName(os); + os << "]"; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - StrMapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const StrMapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const MapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); + static void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "Map["; + ObjectTypeChecker::PrintName(os); os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -inline std::string NodeTypeName() { +inline std::string ObjectTypeName() { std::ostringstream os; - NodeTypeChecker::PrintName(os); + ObjectTypeChecker::PrintName(os); return os.str(); } // extensions for tvm arg value -template -inline TNodeRef TVMArgValue::AsNodeRef() const { +template +inline TObjectRef TVMArgValue::AsObjectRef() const { static_assert( - std::is_base_of::value, - "Conversion only works for NodeRef"); - if (type_code_ == kNull) return TNodeRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << 
sptr->type_key(); - return TNodeRef(sptr); + std::is_base_of::value, + "Conversion only works for ObjectRef"); + if (type_code_ == kNull) return TObjectRef(NodePtr(nullptr)); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Expr() const { @@ -156,18 +153,20 @@ inline TVMArgValue::operator tvm::Expr() const { if (type_code_ == kDLFloat) { return Expr(static_cast(value_.v_float64)); } - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - if (sptr->is_type()) { - return IterVar(sptr)->var; + + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + + if (ptr->IsInstance()) { + return IterVar(ObjectPtr(ptr))->var; } - if (sptr->is_type()) { - return Tensor(sptr)(); + if (ptr->IsInstance()) { + return Tensor(ObjectPtr(ptr))(); } - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Expr(sptr); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Expr(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Integer() const { @@ -177,68 +176,36 @@ inline TVMArgValue::operator tvm::Integer() const { CHECK_GE(value_.v_int64, std::numeric_limits::min()); return Integer(static_cast(value_.v_int64)); } - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Integer(sptr); -} - -inline NodePtr& TVMArgValue::node_sptr() { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return *ptr >(); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Integer(ObjectPtr(ptr)); } - -template -inline bool TVMArgValue::IsNodeType() const { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = - *ptr >(); - return NodeTypeChecker::Check(sptr.get()); +template +inline bool TVMPODValue_::IsObjectRef() const { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + return ObjectTypeChecker::Check(ptr); } // extensions for TVMRetValue -inline TVMRetValue& TVMRetValue::operator=( - const NodePtr& other) { - if (other.get() == nullptr) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other); - } - return *this; -} - -inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { - if (!other.defined()) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other.node_); - } - return *this; -} - -template -inline TNodeRef TVMRetValue::AsNodeRef() const { +template +inline TObjectRef TVMRetValue::AsObjectRef() const { static_assert( - std::is_base_of::value, - "Conversion only works for NodeRef"); - if (type_code_ == kNull) return TNodeRef(); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); -} + std::is_base_of::value, + "Conversion only works for ObjectRef"); + if (type_code_ == kNull) return TObjectRef(); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); 
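+ // Example (illustrative): with a TVMRetValue rv holding an object, + //   Expr e = rv.AsObjectRef<Expr>(); + // routes through ObjectTypeChecker<Expr>::Check(ptr) before wrapping + // the raw handle into the reference type.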
-inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) - if (other.defined()) { - values_[i].v_handle = const_cast*>(&(other.node_)); - type_codes_[i] = kNodeHandle; - } else { - type_codes_[i] = kNull; - } + Object* ptr = static_cast(value_.v_handle); + + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(ObjectPtr(ptr)); } // type related stuffs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4329c438e8a0..a74353239a00 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -52,7 +52,7 @@ class PatternNode : public RelayNode { class Pattern : public NodeRef { public: Pattern() {} - explicit Pattern(NodePtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : NodeRef(p) {} using ContainerType = PatternNode; }; @@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode { TVM_DLL static PatternWildcard make(); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } @@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode { TVM_DLL static PatternVar make(tvm::relay::Var var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("span", &span); } @@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode { tvm::Array inputs, GlobalTypeVar belong_to); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("inputs", &inputs); v->Visit("belong_to", &belong_to); @@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode { TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("constructor", &constructor); v->Visit("patterns", &patterns); v->Visit("span", &span); @@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode { TVM_DLL static PatternTuple make(tvm::Array var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); v->Visit("span", &span); } @@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode { /*! \brief The constructors. */ tvm::Array constructors; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("header", &header); v->Visit("type_vars", &type_vars); v->Visit("constructors", &constructors); @@ -240,7 +240,7 @@ class ClauseNode : public Node { /*! \brief The resulting value. 
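* (In a match clause, the lhs pattern guards this expression: when the * pattern matches, this rhs is evaluated with the pattern's variables bound.)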
*/ Expr rhs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("lhs", &lhs); v->Visit("rhs", &rhs); } @@ -269,7 +269,7 @@ class MatchNode : public ExprNode { */ bool complete; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f94ba5e26068..5a2326ece05d 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -83,10 +83,12 @@ using NodeEqual = ::tvm::NodeEqual; #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ class TypeName : public NodeRefBase { \ public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : NodeRefBase(n) { \ + } \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(get()); \ } \ operator bool() { return this->defined(); } \ using ContainerType = NodeName; \ @@ -105,7 +107,7 @@ class SourceNameNode : public Node { /*! \brief The source name. */ std::string name; // override attr visitor - void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "relay.SourceName"; TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); @@ -127,7 +129,7 @@ class SourceName : public NodeRef { * \return the pointer to the internal node container */ inline const SourceNameNode* operator->() const { - return static_cast(this->node_.get()); + return static_cast(get()); } /*! @@ -158,7 +160,7 @@ class SpanNode : public Node { /*! \brief column offset */ int col_offset; // override attr visitor - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("source", &source); v->Visit("lineno", &lineno); v->Visit("col_offset", &col_offset); @@ -202,7 +204,7 @@ class IdNode : public Node { */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b1b8d6a7154e..ff075e3a8970 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -95,7 +95,7 @@ class ConstantNode : public ExprNode { return data->ndim == 0; } - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -117,7 +117,7 @@ class TupleNode : public ExprNode { /*! \brief the fields of the tuple */ tvm::Array fields; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -165,7 +165,7 @@ class VarNode : public ExprNode { return vid->name_hint; } - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); v->Visit("span", &span); @@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode { /*! \brief The name of the variable, this only acts as a hint. 
*/ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -243,7 +243,7 @@ class FunctionNode : public ExprNode { */ tvm::Attrs attrs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); @@ -274,6 +274,19 @@ class FunctionNode : public ExprNode { tvm::Array ty_params, tvm::Attrs attrs = Attrs()); + /*! + * \brief Attach the function's parameters to its attributes for use in analysis. + * \return The function with its parameters attached. + */ + Function SetParams(const tvm::Map& parameters) const; + + /*! + * \brief Retrieve the function's parameters. + * + * \return The function's parameter. + */ + tvm::Map GetParams() const; + static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); }; @@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); - /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. @@ -327,7 +339,7 @@ class CallNode : public ExprNode { */ tvm::Array type_args; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); v->Visit("args", &args); v->Visit("attrs", &attrs); @@ -369,7 +381,7 @@ class LetNode : public ExprNode { /*! \brief The body of the let binding */ Expr body; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); @@ -407,7 +419,7 @@ class IfNode : public ExprNode { /*! \brief The expression evaluated when condition is false */ Expr false_branch; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); @@ -432,7 +444,7 @@ class TupleGetItemNode : public ExprNode { /*! \brief which value to get */ int index; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); v->Visit("span", &span); @@ -454,7 +466,7 @@ class RefCreateNode : public ExprNode { /*! \brief The initial value of the Reference. */ Expr value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -475,7 +487,7 @@ class RefReadNode : public ExprNode { /*! \brief The Reference Expression. */ Expr ref; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -498,7 +510,7 @@ class RefWriteNode : public ExprNode { /*! \brief The value to write into. 
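* (That is, the new value that the RefWrite stores into the reference.)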
*/ Expr value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); v->Visit("span", &span); @@ -541,10 +553,11 @@ RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); // implementataions inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) << "internal error: the type checker has " - "not populated the checked_type " - "field for " - << GetRef(this); + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); return this->checked_type_; } @@ -557,7 +570,7 @@ inline const TTypeNode* ExprNode::type_as() const { const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->type_key(); + << ", but get " << checked_type_->GetTypeKey(); return node; } diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e0d940c5d1a5..8bc87a27f66f 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -57,8 +57,8 @@ class ExprFunctor; #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -117,7 +117,7 @@ class ExprFunctor { virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d05099f781ac..d5d783d4804a 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -78,9 +78,9 @@ class ValueNode : public RelayNode { class Value : public NodeRef { public: Value() {} - explicit Value(NodePtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : NodeRef(n) {} const ValueNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = ValueNode; @@ -106,7 +106,7 @@ class ClosureNode : public ValueNode { ClosureNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("env", &env); v->Visit("func", &func); } @@ -119,6 +119,32 @@ class ClosureNode : public ValueNode { RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); +/*! \brief A Relay Recursive Closure. A closure that has a name. */ +class RecClosure; + +/*! \brief The container type of RecClosure. */ +class RecClosureNode : public ValueNode { + public: + /*! \brief The closure. */ + Closure clos; + /*! \brief variable the closure bind to. 
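+ * (A recursive function is interpreted as a RecClosure so that its body + * can refer back to itself through this variable.)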
*/ + Var bind; + + RecClosureNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("clos", &clos); + v->Visit("bind", &bind); + } + + TVM_DLL static RecClosure make(Closure clos, Var bind); + + static constexpr const char* _type_key = "relay.RecClosure"; + TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); + /*! \brief A tuple value. */ class TupleValue; @@ -128,7 +154,7 @@ struct TupleValueNode : ValueNode { TupleValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } TVM_DLL static TupleValue make(tvm::Array<Value> value); @@ -147,7 +173,7 @@ struct TensorValueNode : ValueNode { TensorValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); } /*! \brief Build a value from an NDArray. */ TVM_DLL static TensorValue make(runtime::NDArray data); @@ -166,7 +192,7 @@ struct RefValueNode : ValueNode { RefValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } @@ -189,7 +215,7 @@ struct ConstructorValueNode : ValueNode { /*! \brief Optional field tracking ADT constructor. */ Constructor constructor; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tag", &tag); v->Visit("fields", &fields); v->Visit("constructor", &constructor); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8b17020a1132..160ae5db8265 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -68,7 +68,7 @@ class ModuleNode : public RelayNode { ModuleNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("functions", &functions); v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); @@ -281,10 +281,10 @@ class ModuleNode : public RelayNode { struct Module : public NodeRef { Module() {} - explicit Module(NodePtr<Node> p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} - inline ModuleNode* operator->() const { - return static_cast<ModuleNode*>(node_.get()); + ModuleNode* operator->() const { + return static_cast<ModuleNode*>(get_mutable()); } using ContainerType = ModuleNode; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0a6d3725655f..7d2a1f653a93 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -24,6 +24,8 @@ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ +#include + +#include #include #include @@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode { */ int32_t support_level = 10; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("op_type", &op_type); v->Visit("description", &description); @@ -138,7 +140,7 @@ class Op : public relay::Expr { /*! \brief default constructor  */ Op() {} /*! \brief constructor from node pointer */ - explicit Op(NodePtr<Node> n) : Expr(n) {} + explicit Op(ObjectPtr<Object> n) : Expr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ @@ -221,11 +223,12 @@ class OpRegistry { const Attrs&, const TypeReporter&)> type_rel_func); /*! - * \brief Set the type key of attributes. - * \param type_key The type of of the attrs field. + * \brief Set the attrs type key and index to be AttrsType.
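+ * (For example, an operator registration might call + * set_attrs_type<RequantizeAttrs>(); an illustrative use, cf. the + * qnn/attrs.h change below.)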
+ * \tparam AttrsType the attribute type to be set. * \return reference to self. */ - inline OpRegistry& set_attrs_type_key(const std::string& type_key); + template <typename AttrsType> + inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. * \return reference to self. */ @@ -397,7 +400,7 @@ class OpMap { // implementations inline const OpNode* Op::operator->() const { - return static_cast<const OpNode*>(node_.get()); + return static_cast<const OpNode*>(get()); } template @@ -496,10 +499,10 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*) - const std::string& type_key) { - get()->attrs_type_key = type_key; - get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str()); +template <typename AttrsType> +inline OpRegistry& OpRegistry::set_attrs_type() {  // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 7f1c47e03592..c15523cb25de 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -57,8 +57,8 @@ class PatternFunctor; #define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast<const OP*>(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast<const OP*>(n.get()), \ std::forward<Args>(args)...); \ }); @@ -66,7 +66,7 @@ template class PatternFunctor { private: using TSelf = PatternFunctor; - using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>; + using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>; public: /*! \brief the result type of this functor */ @@ -103,7 +103,7 @@ class PatternFunctor { virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 83b55b04222a..e5f4ba94e12e 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -49,7 +49,7 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> { .describe("The scale of the output tensor."); TVM_ATTR_FIELD(output_zero_point) .describe("The zero point of the output tensor."); - TVM_ATTR_FIELD(rounding).set_default("TONEAREST") + TVM_ATTR_FIELD(rounding).set_default("UPWARD") .describe("Defines the rounding direction when the value is midway between" "two representable values. There are two supported modes - UPWARD" "or TONEAREST. Both modes behave exactly same except at the" diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a2119c90f750..82144d76e565 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -101,7 +101,7 @@ class PassContextNode : public RelayNode { PassContextNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); @@ -134,16 +134,16 @@ class PassContext : public NodeRef { * \return const access pointer. */ const PassContextNode* operator->() const { - CHECK(node_.get() != nullptr); - return static_cast<const PassContextNode*>(node_.get()); + CHECK(get() != nullptr); + return static_cast<const PassContextNode*>(get()); } /*! * \brief mutable accessor.
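* (Sketch, assuming a default factory such as PassContext::Create(): * PassContext ctx = PassContext::Create(); ctx->opt_level = 3;)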
* \return mutable access pointer. */ PassContextNode* operator->() { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get_mutable()); } /*! * \brief Construct a PassContext containing the default configurations. @@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode { PassInfoNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); @@ -221,6 +221,7 @@ class Pass; */ class PassNode : public RelayNode { public: + virtual ~PassNode() {} /*! * \brief Get the pass information/meta data. */ virtual PassInfo Info() const = 0; @@ -247,7 +248,7 @@ class PassNode : public RelayNode { virtual Module operator()(const Module& mod, const PassContext& pass_ctx) const = 0; - void VisitAttrs(tvm::AttrVisitor* v) override {} + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.Pass"; TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 16e36785c533..e0c056c1216b 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -58,7 +58,7 @@ class TypeNode : public RelayNode { class Type : public NodeRef { public: Type() {} - explicit Type(NodePtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : NodeRef(p) {} using ContainerType = TypeNode; }; @@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode { /*! \brief The content data type */ DataType dtype; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("shape", &shape); v->Visit("dtype", &dtype); v->Visit("span", &span); @@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("kind", &kind); v->Visit("span", &span); @@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("kind", &kind); v->Visit("span", &span); @@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode { /*! \brief The arguments. 
*/ tvm::Array args; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("func", &func); v->Visit("args", &args); v->Visit("span", &span); @@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode { public: Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("kind", &kind); v->Visit("span", &span); } @@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode { */ tvm::Array type_constraints; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("arg_types", &arg_types); v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); @@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode { TupleTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); v->Visit("span", &span); } @@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode { RefTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("span", &span); } @@ -417,7 +417,7 @@ class TypeReporterNode : public Node { TVM_DLL virtual Module GetModule() = 0; // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) final {} + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); @@ -430,10 +430,11 @@ class TypeReporterNode : public Node { class TypeReporter : public NodeRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { } TypeReporterNode* operator->() const { - return static_cast(node_.get()); + return const_cast( + static_cast(get())); } using ContainerType = TypeReporterNode; }; @@ -487,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode { /*! \brief Attributes to the relation function */ Attrs attrs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("func", &func); v->Visit("args", &args); v->Visit("num_inputs", &num_inputs); diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 54e6f98e8ee5..267504beb11a 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -98,13 +98,12 @@ typedef enum { kTVMType = 5U, kTVMContext = 6U, kArrayHandle = 7U, - kNodeHandle = 8U, + kObjectHandle = 8U, kModuleHandle = 9U, kFuncHandle = 10U, kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectCell = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. @@ -549,13 +548,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type, TVMStreamHandle dst); /*! - * \brief Get the tag from an object. + * \brief Get the type_index from an object. * * \param obj The object handle. - * \param tag The tag of object. + * \param out_tindex the output type index. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag); +TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. 
+ * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + +/*! + * \brief Free the object. + * + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectFree(TVMObjectHandle obj); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 68029c13cb93..bb362dcdec66 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*) os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")"; return os; } + #endif } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 6b4f01e4ac9b..d28552eaf7fd 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -69,7 +69,7 @@ class ObjAllocatorBase { "make_node can only be used to create NodeBase"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); - ptr->type_index_ = T::type_index(); + ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } @@ -82,6 +82,8 @@ class SimpleObjAllocator : template class Handler { public: + using StorageType = typename std::aligned_storage::type; + template static T* New(SimpleObjAllocator*, Args&&... args) { // NOTE: the first argument is not needed for SimpleObjAllocator @@ -91,7 +93,15 @@ class SimpleObjAllocator : // In the case of an object pool, an allocator needs to create // a special chunk memory that hides reference to the allocator // and call allocator's release function in the deleter. - return new T(std::forward(args)...); + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + StorageType* data = new StorageType(); + new (data) T(std::forward(args)...); + return reinterpret_cast(data); } static Object::FDeleter Deleter() { @@ -99,8 +109,17 @@ class SimpleObjAllocator : } private: - static void Deleter_(Object* ptr) { - delete static_cast(ptr); + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to T* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + T* tptr = static_cast(objptr); + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + delete reinterpret_cast(tptr); } }; }; diff --git a/include/tvm/runtime/node_base.h b/include/tvm/runtime/node_base.h deleted file mode 100644 index 8b47c18a09a7..000000000000 --- a/include/tvm/runtime/node_base.h +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/node_base.h - * \brief Base data structure for Node. - * - * \note Node is not a runtime feature. - * This file only exposes the signature of NodePtr for PackedFunc. - */ -#ifndef TVM_RUNTIME_NODE_BASE_H_ -#define TVM_RUNTIME_NODE_BASE_H_ - -#include -#include - -namespace tvm { - -// forward declarations -template -class NodePtr; -class Node; -class NodeRef; - -/*! - * \brief Base class of Node for runtime destructor purposes. - * - * Node is a reference counted object which is used to construct AST. - * Each node is backed by a custom deleter, which deletes the object. - * Do not call create raw Node pointer, always use tvm::make_node. - * - * \note In most cases, please inheritate tvm::Node. - * \sa Node, NodePtr, make_node - */ -class NodeBase { - public: - /*! - * \brief type of NodeBase deleter - * \param self pointer to the NodeBase. - */ - typedef void (*FDeleter)(NodeBase* self); - - protected: - // default constructor and copy constructor - NodeBase() {} - // override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - NodeBase(const NodeBase& other) { // NOLINT(*) - } - NodeBase(NodeBase&& other) { // NOLINT(*) - } - NodeBase& operator=(const NodeBase& other) { //NOLINT(*) - return *this; - } - NodeBase& operator=(NodeBase&& other) { //NOLINT(*) - return *this; - } - - private: - /*! \brief Internal reference counter */ - std::atomic ref_counter_{0}; - /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the Node must always set the deleter field properly. - */ - FDeleter deleter_ = nullptr; - // reference counting functions - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } - } - int use_count() const { - return ref_counter_.load(std::memory_order_relaxed); - } - // friend declaration - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; - -/*! - * \brief Smart pointer for Node containers, - * must be subclass of NodeBase - * \tparam T the content data type. - */ -template -class NodePtr { - public: - /*! \brief default constructor */ - NodePtr() {} - /*! \brief default constructor */ - NodePtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - } - /*! 
- * \brief copy constructor - * \param other The value to be moved - */ - template - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NodePtr() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray - */ - void swap(NodePtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { - return static_cast(data_); - } - /*! - * \return The pointer - */ - T* operator->() const { - return get(); - } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(const NodePtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - NodePtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(NodePtr&& other) { // NOLINT(*) - // copy-and-swap idiom - NodePtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } - /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } - /*! \return Whether two NodePtr do not equals each other */ - bool operator==(const NodePtr& other) const { - return data_ == other.data_; - } - /*! \return Whether two NodePtr equals each other */ - bool operator!=(const NodePtr& other) const { - return data_ != other.data_; - } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } - - private: - /*! \brief internal pointer field */ - NodeBase* data_{nullptr}; - /*! 
- * \brief constructor from NodeBase - * \param data The node base pointer - */ - explicit NodePtr(NodeBase* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } - } - // friend declaration - friend class Node; - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; -} // namespace tvm - -#endif // TVM_RUNTIME_NODE_BASE_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7b0653ae5485..cc4a295cc5d4 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -23,6 +23,7 @@ #ifndef TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_ +#include #include #include #include @@ -51,7 +52,7 @@ enum TypeIndex { kRoot = 0, kVMTensor = 1, kVMClosure = 2, - kVMDatatype = 3, + kVMADT = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd @@ -65,7 +66,7 @@ enum TypeIndex { * - _type_index: * Static type index of the object, if assigned to TypeIndex::kDynamic * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::type_index(); + * Runtime type index can be accessed by ObjectType::TypeIndex(); * - _type_key: * The unique string identifier of tyep type. * - _type_final: @@ -147,10 +148,23 @@ class Object { * \param self pointer to the Object. */ typedef void (*FDeleter)(Object* self); - /*! \return The internal type index of the object. */ + /*! \return The internal runtime type index of the object. */ uint32_t type_index() const { return type_index_; } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + return TypeIndex2Key(type_index_); + } + /*! + * \return A hash value of the return of GetTypeKey. + */ + size_t GetTypeKeyHash() const { + return TypeIndex2KeyHash(type_index_); + } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. @@ -159,19 +173,65 @@ class Object { template inline bool IsInstance() const; + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); + /*! + * \brief Get the type key hash of the corresponding index from runtime. + * \param tindex The type index. + * \return the related key-hash. + */ + TVM_DLL static size_t TypeIndex2KeyHash(uint32_t tindex); + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + * \return the result. + */ + TVM_DLL static uint32_t TypeKey2Index(const std::string& key); + #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic; #else using RefCounterType = int32_t; #endif - // Object type properties static constexpr const char* _type_key = "Object"; + + static uint32_t _GetOrAllocRuntimeTypeIndex() { + return TypeIndex::kRoot; + } + static uint32_t RuntimeTypeIndex() { + return TypeIndex::kRoot; + } + + // Default object type properties for sub-classes static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; - static const uint32_t _GetOrAllocRuntimeTypeIndex() { - return 0; + // NOTE: the following field is not type index of Object + // but was intended to be used by sub-classes as default value. 
+ // The type index of Object is TypeIndex::kRoot + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + + // Default constructor and copy constructor + Object() {} + // Override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + // This will enable us to use make_object(*obj_ptr) + // to copy an existing object. + Object(const Object& other) { // NOLINT(*) + } + Object(Object&& other) { // NOLINT(*) + } + Object& operator=(const Object& other) { //NOLINT(*) + return *this; + } + Object& operator=(Object&& other) { //NOLINT(*) + return *this; } protected: @@ -209,25 +269,12 @@ class Object { * \return The allocated type index. */ TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( - const char* key, + const std::string& key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow); - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - */ - TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); - - /*! - * \brief Get the type index of the corresponding key from runtime. - * \param key The type key. - */ - TVM_DLL static uint32_t TypeKey2Index(const char* key); - - private: // reference counter related operations /*! \brief developer function, increases reference counter. */ inline void IncRef(); @@ -253,8 +300,35 @@ class Object { template friend class ObjectPtr; friend class TVMRetValue; + friend class TVMObjectCAPI; }; +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam ObjectType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + /*! * \brief A custom smart pointer for Object. * \tparam T the content data type. @@ -388,7 +462,7 @@ class ObjectPtr { /*! \brief internal pointer field */ Object* data_{nullptr}; /*! - * \brief constructor from NodeBase + * \brief constructor from Object * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { @@ -399,6 +473,7 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; + friend struct ObjectHash; template friend class ObjectPtr; template @@ -406,6 +481,9 @@ class ObjectPtr { friend class TVMPODValue_; friend class TVMArgsSetter; friend class TVMRetValue; + friend class TVMArgValue; + template + friend RefType GetRef(const ObjType* ptr); }; /*! \brief Base class of all object reference */ @@ -415,10 +493,54 @@ class ObjectRef { ObjectRef() = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. 
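+ * \note Like same_as, this compares the underlying object pointers, + * so two references are equal only if they refer to the same object.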
+ */ + bool operator==(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { + return data_ != other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { + return data_.get() < other.data_.get(); + } + /*! \return whether the expression is null */ + bool defined() const { + return data_ != nullptr; + } /*! \return the internal object pointer */ - inline const Object* get() const; + const Object* get() const { + return data_.get(); + } /*! \return the internal node pointer */ - inline const Object* operator->() const; + const Object* operator->() const { + return get(); + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_.unique(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -433,25 +555,81 @@ class ObjectRef { template inline const ObjectType* as() const; - /*! \brief type indicate the container type */ + /*! \brief type indicate the container type. */ using ContainerType = Object; protected: /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { + return data_.get(); + } + /*! + * \brief Internal helper function downcast a ref without check. + * \note Only used for internal dev purposes. + * \tparam T The target reference type. + * \return The casted result. + */ + template + static T DowncastNoCheck(ObjectRef ref) { + return T(std::move(ref.data_)); + } + /*! + * \brief Internal helper function get data_ as ObjectPtr of ObjectType. + * \note only used for internal dev purpose. + * \tparam ObjectType The corresponding object type. + * \return the corresponding type. + */ + template + static ObjectPtr GetDataPtr(const ObjectRef& ref) { + return ObjectPtr(ref.data_.data_); + } // friend classes. + friend struct ObjectHash; friend class TVMRetValue; friend class TVMArgsSetter; + template + friend SubRef Downcast(BaseRef ref); }; + +/*! \brief ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + return operator()(a.data_); + } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } +}; + + +/*! \brief ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + return a.same_as(b); + } + + template + size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } +}; + + /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. 
* \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static const uint32_t type_index() { \ - if (_type_index != TypeIndex::kDynamic) return _type_index; \ + static const uint32_t RuntimeTypeIndex() { \ + if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return TypeName::_type_index; \ + } \ return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ @@ -550,11 +728,11 @@ inline bool Object::IsInstance() const { if (TargetType::_type_final) { // if the target type is a final type // then we only need to check the equivalence. - return self->type_index_ == TargetType::type_index(); + return self->type_index_ == TargetType::RuntimeTypeIndex(); } else { // if target type is a non-leaf type // Check if type index falls into the range of reserved slots. - uint32_t begin = TargetType::type_index(); + uint32_t begin = TargetType::RuntimeTypeIndex(); // The condition will be optimized by constant-folding. if (TargetType::_type_child_slots != 0) { uint32_t end = begin + TargetType::_type_child_slots; @@ -564,22 +742,15 @@ inline bool Object::IsInstance() const { } if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. - if (self->type_index_ < TargetType::type_index()) return false; + if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; // The rare slower-path, check type hierachy. - return self->DerivedFrom(TargetType::type_index()); + return self->DerivedFrom(TargetType::RuntimeTypeIndex()); } } else { return false; } } -inline const Object* ObjectRef::get() const { - return data_.data_; -} - -inline const Object* ObjectRef::operator->() const { - return get(); -} template inline const ObjectType* ObjectRef::as() const { @@ -590,7 +761,27 @@ inline const ObjectType* ObjectRef::as() const { return nullptr; } } + +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.data_)); +} + } // namespace runtime + +template +using NodePtr = runtime::ObjectPtr; + } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5b71bbc66142..a42946ac2d2c 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -40,7 +40,6 @@ #include "module.h" #include "ndarray.h" #include "object.h" -#include "node_base.h" // Whether use TVM runtime in header only mode. 
#ifndef TVM_RUNTIME_HEADER_ONLY @@ -490,9 +489,12 @@ class TVMPODValue_ { return NDArray(static_cast(value_.v_handle)); } operator ObjectRef() const { - if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); - return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); + if (type_code_ == kNull) { + return ObjectRef(ObjectPtr(nullptr)); + } + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + return ObjectRef( + ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); @@ -512,9 +514,14 @@ class TVMPODValue_ { CHECK_LT(type_code_, kExtEnd); return static_cast(value_.v_handle)[0]; } + template::value>::type> + inline bool IsObjectRef() const; int type_code() const { return type_code_; } + /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -567,6 +574,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -610,21 +618,15 @@ class TVMArgValue : public TVMPODValue_ { return value_; } // Deferred extension handler. - template - inline TNodeRef AsNodeRef() const; + template + inline TObjectRef AsObjectRef() const; template::value>::type> inline operator T() const; - template::value>::type> - inline bool IsNodeType() const; inline operator tvm::DataType() const; inline operator tvm::Expr() const; inline operator tvm::Integer() const; - // get internal node ptr, if it is node - inline NodePtr& node_sptr(); }; /*! @@ -663,6 +665,8 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } @@ -760,11 +764,19 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(ObjectRef other) { - this->Clear(); - type_code_ = kObjectCell; - // move the handle out - value_.v_handle = other.data_.data_; - other.data_.data_ = nullptr; + return operator=(std::move(other.data_)); + } + template + TVMRetValue& operator=(ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kObjectHandle; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -814,21 +826,19 @@ class TVMRetValue : public TVMPODValue_ { } /*! 
 \return The value field, if the data is POD
   */
  const TVMValue& value() const {
-    CHECK(type_code_ != kNodeHandle &&
+    CHECK(type_code_ != kObjectHandle &&
           type_code_ != kFuncHandle &&
           type_code_ != kModuleHandle &&
           type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
     return value_;
   }
-  // NodeRef related extenstions: in tvm/packed_func_ext.h
+  // ObjectRef related extensions: in tvm/packed_func_ext.h
   template<typename T,
            typename = typename std::enable_if<
              std::is_class<T>::value>::type>
   inline operator T() const;
-  template<typename TNodeRef>
-  inline TNodeRef AsNodeRef() const;
-  inline TVMRetValue& operator=(const NodeRef& other);
-  inline TVMRetValue& operator=(const NodePtr<Node>& other);
+  template<typename TObjectRef>
+  inline TObjectRef AsObjectRef() const;
   // type related
   inline operator tvm::DataType() const;
   inline TVMRetValue& operator=(const tvm::DataType& other);
@@ -857,12 +867,7 @@ class TVMRetValue : public TVMPODValue_ {
         *this = other.operator NDArray();
         break;
       }
-      case kNodeHandle: {
-        SwitchToClass<NodePtr<Node> >(
-            kNodeHandle, *other.template ptr<NodePtr<Node> >());
-        break;
-      }
-      case kObjectCell: {
+      case kObjectHandle: {
         *this = other.operator ObjectRef();
         break;
       }
@@ -908,12 +913,11 @@
       case kStr: delete ptr<std::string>(); break;
       case kFuncHandle: delete ptr<PackedFunc>(); break;
       case kModuleHandle: delete ptr<Module>(); break;
-      case kNodeHandle: delete ptr<NodePtr<Node> >(); break;
       case kNDArrayContainer: {
         static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
         break;
       }
-      case kObjectCell: {
+      case kObjectHandle: {
         static_cast<Object*>(value_.v_handle)->DecRef();
         break;
       }
@@ -939,14 +943,13 @@ inline const char* TypeCode2Str(int type_code) {
     case kBytes: return "bytes";
     case kHandle: return "handle";
     case kNull: return "NULL";
-    case kNodeHandle: return "NodeHandle";
     case kArrayHandle: return "ArrayHandle";
     case kTVMType: return "TVMType";
     case kTVMContext: return "TVMContext";
     case kFuncHandle: return "FunctionHandle";
     case kModuleHandle: return "ModuleHandle";
     case kNDArrayContainer: return "NDArrayContainer";
-    case kObjectCell: return "ObjectCell";
+    case kObjectHandle: return "ObjectHandle";
     default: LOG(FATAL) << "unknown type_code="
              << static_cast<int>(type_code); return "";
   }
@@ -1057,8 +1060,6 @@ inline PackedFunc::FType PackedFunc::body() const {
   return body_;
 }
-
-
 // internal namespace
 namespace detail {
@@ -1163,8 +1164,12 @@ class TVMArgsSetter {
     type_codes_[i] = kNDArrayContainer;
   }
   void operator()(size_t i, const ObjectRef& value) const {  // NOLINT(*)
-    values_[i].v_handle = value.data_.data_;
-    type_codes_[i] = kObjectCell;
+    if (value.defined()) {
+      values_[i].v_handle = value.data_.data_;
+      type_codes_[i] = kObjectHandle;
+    } else {
+      type_codes_[i] = kNull;
+    }
   }
   void operator()(size_t i, const TVMRetValue& value) const {  // NOLINT(*)
     if (value.type_code() == kStr) {
@@ -1181,8 +1186,6 @@
            typename = typename std::enable_if<
              extension_type_info<T>::code != 0>::type>
   inline void operator()(size_t i, const T& value) const;
-  // NodeRef related extenstions: in tvm/packed_func_ext.h
-  inline void operator()(size_t i, const NodeRef& other) const;  // NOLINT(*)
   inline void operator()(size_t i, const tvm::DataType& t) const;

 private:
@@ -1301,7 +1304,7 @@
 template<typename T, typename TSrc, bool is_ext, bool is_nd>
 struct TVMValueCast {
   static T Apply(const TSrc* self) {
     static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
-    return self->template AsNodeRef<T>();
+    return self->template AsObjectRef<T>();
   }
 };
diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
index 40e1a520cb67..d668984f50e2 100644
--- a/include/tvm/runtime/registry.h
+++ b/include/tvm/runtime/registry.h
@@ -91,7 +91,7 @@ class Registry {
  * Note that this will ignore default arg values and always require all arguments to be provided.
  *
  * \code
- *
+ *
  * int multiply(int x, int y) {
  *   return x * y;
  * }
@@ -115,7 +115,7 @@ class Registry {
  * Note that this will ignore default arg values and always require all arguments to be provided.
  *
  * \code
- *
+ *
  * // node subclass:
  * struct Example {
  *    int doThing(int x);
  * }
@@ -143,7 +143,7 @@ class Registry {
  * Note that this will ignore default arg values and always require all arguments to be provided.
  *
  * \code
- *
+ *
  * // node subclass:
  * struct Example {
  *    int doThing(int x);
  * }
@@ -168,22 +168,22 @@ class Registry {

  /*!
   * \brief set the body of the function to be the passed method pointer.
-   * Used when calling a method on a Node subclass through a NodeRef subclass.
+   * Used when calling a method on a Node subclass through an ObjectRef subclass.
   * Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
-   *
+   *
   * // node subclass:
   * struct ExampleNode: BaseNode {
   *    int doThing(int x);
   * }
-   *
+   *
   * // noderef subclass
-   * struct Example;
+   * struct Example;
   *
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
-   *
+   *
   * // note that just doing:
   * // .set_body_method(&ExampleNode::doThing);
   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
@@ -191,15 +191,15 @@ class Registry {
   * \endcode
   *
   * \param f the method pointer to forward to.
-   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TObjectRef the object reference type to call the method on
   * \tparam TNode the node type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
-  template<typename TNodeRef, typename TNode, typename R, typename... Args,
-           typename = typename std::enable_if<
-             std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  template<typename TObjectRef, typename TNode, typename R, typename... Args,
+           typename = typename std::enable_if<
+             std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  Registry& set_body_method(R (TNode::*f)(Args...)) {
-    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+    return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
      TNode* target = ref.operator->();
      // call method pointer
      return (target->*f)(params...);
@@ -208,22 +208,22 @@ class Registry {

  /*!
   * \brief set the body of the function to be the passed method pointer.
-   * Used when calling a method on a Node subclass through a NodeRef subclass.
+   * Used when calling a method on a Node subclass through an ObjectRef subclass.
   * Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
-   *
+   *
   * // node subclass:
   * struct ExampleNode: BaseNode {
   *    int doThing(int x);
   * }
-   *
+   *
   * // noderef subclass
-   * struct Example;
+   * struct Example;
   *
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
-   *
+   *
   * // note that just doing:
   * // .set_body_method(&ExampleNode::doThing);
   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
@@ -231,15 +231,15 @@ class Registry {
   * \endcode
   *
   * \param f the method pointer to forward to.
-   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TObjectRef the object reference type to call the method on
   * \tparam TNode the node type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
-  template<typename TNodeRef, typename TNode, typename R, typename... Args,
-           typename = typename std::enable_if<
-             std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  template<typename TObjectRef, typename TNode, typename R, typename... Args,
+           typename = typename std::enable_if<
+             std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  Registry& set_body_method(R (TNode::*f)(Args...) const) {
-    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args...
params) { + return set_body_typed([f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); // call method pointer return (target->*f)(params...); diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index aa8543d569af..ee973cb62092 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -56,31 +57,31 @@ class Tensor : public ObjectRef { /*! \brief An object representing a structure or enumeration. */ -class DatatypeObj : public Object { +class ADTObj : public Object { public: /*! \brief The tag representing the constructor used. */ size_t tag; /*! \brief The fields of the structure. */ std::vector fields; - static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype; - static constexpr const char* _type_key = "vm.Datatype"; - TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object); + static constexpr const uint32_t _type_index = TypeIndex::kVMADT; + static constexpr const char* _type_key = "vm.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); }; -/*! \brief reference to data type. */ -class Datatype : public ObjectRef { +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { public: - Datatype(size_t tag, std::vector fields); + ADT(size_t tag, std::vector fields); /*! * \brief construct a tuple object. * \param fields The fields of the tuple. * \return The constructed tuple type. */ - static Datatype Tuple(std::vector fields); + static ADT Tuple(std::vector fields); - TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj); + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); }; /*! \brief An object representing a closure. */ @@ -128,7 +129,7 @@ enum class Opcode { InvokePacked = 4U, AllocTensor = 5U, AllocTensorReg = 6U, - AllocDatatype = 7U, + AllocADT = 7U, AllocClosure = 8U, GetField = 9U, If = 10U, @@ -236,7 +237,7 @@ struct Instruction { /*! \brief The register to project from. */ RegName object; } get_tag; - struct /* AllocDatatype Operands */ { + struct /* AllocADT Operands */ { /*! \brief The datatype's constructor tag. */ Index constructor_tag; /*! \brief The number of fields to store in the datatype. */ @@ -293,7 +294,7 @@ struct Instruction { * \param dst The register name of the destination. * \return The allocate instruction tensor. */ - static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector& fields, + static Instruction AllocADT(Index tag, Index num_fields, const std::vector& fields, RegName dst); /*! \brief Construct an allocate closure instruction. * \param func_index The index of the function table. @@ -430,15 +431,184 @@ struct VMFrame { caller_return_register(0) {} }; +/*! \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + * + * - Global section, containing all globals. + * - Constant section, storing the constant pool. + * - Primitive name section, containing the function name of the primitive ops + * used by the virtual machine. + * - Code section, handling the VM functions and bytecode. + */ +class Executable : public ModuleNode { + public: + /*! + * \brief Get a PackedFunc from an executable module. + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc or nullptr when it is not available. 
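Stepping out of the Executable declaration for a moment: the Datatype-to-ADT rename above is easiest to see in use. A hedged sketch follows (not code from this patch; `a` and `b` stand in for real VM values such as vm::Tensor objects):

    // Hedged sketch of the renamed runtime ADT API (Datatype -> ADT).
    ADT MakePair(ObjectRef a, ObjectRef b) {
      std::vector<ObjectRef> fields{a, b};
      ADT tuple = ADT::Tuple(fields);  // tuples reuse the same ADT machinery
      ADT ctor0(0, fields);            // a value built with constructor tag 0
      CHECK_EQ(ctor0->tag, 0U);        // tag and fields live on the ADTObj payload
      return tuple;
    }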
+   */
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+
+  /*!
+   * \brief Serialize the executable into global section, constant section, and
+   * code section.
+   *
+   * \return The binary representation of the VM.
+   */
+  TVMByteArray Save();
+
+  /*!
+   * \brief Load the saved VM executable.
+   *
+   * \param code The bytecode as a string.
+   * \param lib The compiled runtime library.
+   *
+   * \return The constructed executable.
+   */
+  static runtime::Module Load(const std::string& code, const runtime::Module lib);
+
+  /*!
+   * \brief Get the serialized form of the `functions`. This is
+   * essentially bytecode serialization.
+   *
+   * \return The serialized vm bytecode.
+   *
+   * \note The bytecode is in the following format:
+   *   func_name reg_file_size num_instructions
+   *   param1 param2 ... paramM
+   *   instruction1
+   *   instruction2
+   *   ...
+   *   instructionN
+   *
+   * Each instruction is printed in the following format:
+   *   opcode num_fields field1 ... fieldX # The text format.
+   *
+   * Serializing an `Instruction` requires us to deal with the bytecode. Each line
+   * of the instructions could be serialized as the following format:
+   *   hash, opcode, f1, f2, ..., fX, field with variable length
+   * 1. hash: the hash of the instruction. This number will be used to help us
+   * validate if an instruction is well-formed during deserialization.
+   * 2. opcode: the opcode of the instruction.
+   * 3. f1, f2, ..., fX. These fields together represent the fixed fields in
+   * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
+   * example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
+   * 4. The rest of the line indicates the field with variable length, e.g.,
+   * the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
+
+   * The field starting from # is only used for debugging. The serialized code
+   * doesn't contain it, therefore the deserializer doesn't need to handle it.
+   */
+  std::string GetBytecode() const;
+
+  /*!
+   * \brief Print the detailed statistics of the given code, i.e. number of
+   * globals and constants, etc.
+   */
+  std::string Stats() const;
+
+  /*! \brief Get the `lib` module in an executable. Users have the flexibility to call
+   * `export_library` from the frontend to save the library to disk.
+   *
+   * \return The runtime module that contains the hardware-dependent code.
+   */
+  runtime::Module GetLib() const { return lib; }
+
+  virtual ~Executable() {}
+
+  const char* type_key() const final {
+    return "VMExecutable";
+  }
+
+  /*! \brief The runtime module/library that contains both the host and also the device
+   * code when executing on non-CPU devices. */
+  runtime::Module lib;
+  /*! \brief The global constant pool. */
+  std::vector<ObjectRef> constants;
+  /*! \brief A map from globals (as strings) to their index in the function map. */
+  std::unordered_map<std::string, Index> global_map;
+  /*! \brief A mapping from the packed function (as string) to the index that
+   * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
+   */
+  std::unordered_map<std::string, Index> primitive_map;
+  /*! \brief The virtual machine's function table. */
+  std::vector<VMFunction> functions;
+
+ private:
+  /*!
+   * \brief Save the globals.
+   *
+   * \param strm The output stream.
+   */
+  void SaveGlobalSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Save the constant pool.
+   *
+   * \param strm The output stream.
+   */
+  void SaveConstantSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Save primitive op names.
+   *
+   * \param strm The output stream.
+   */
+  void SavePrimitiveOpNames(dmlc::Stream* strm);
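A brief usage sketch may help reviewers here: round-tripping an executable through the new interface. This is hypothetical glue code, not part of the patch; how `exe` is obtained from the VM compiler frontend is out of scope.

    // Usage sketch only, under the assumptions stated above.
    void RoundTrip(Executable* exe) {
      TVMByteArray code = exe->Save();      // global/constant/primitive/code sections
      runtime::Module lib = exe->GetLib();  // hardware-dependent kernels, saved separately
      LOG(INFO) << exe->GetBytecode();      // text listing in the format documented above
      LOG(INFO) << exe->Stats();            // counts of globals, constants, etc.
      std::string blob(code.data, code.size);
      runtime::Module reloaded = Executable::Load(blob, lib);
      (void)reloaded;
    }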
+
+  /*!
+   * \brief Save the vm functions.
+   *
+   * \param strm The output stream.
+   */
+  void SaveCodeSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load the globals.
+   *
+   * \param strm The input stream.
+   */
+  void LoadGlobalSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load the constant pool.
+   *
+   * \param strm The input stream.
+   */
+  void LoadConstantSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load primitive op names.
+   *
+   * \param strm The input stream.
+   */
+  void LoadPrimitiveOpNames(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load the vm functions.
+   *
+   * \param strm The input stream.
+   */
+  void LoadCodeSection(dmlc::Stream* strm);
+
+  /*! \brief The serialized bytecode. */
+  std::string code_;
+};

 /*! \brief The virtual machine.
  *
  * The virtual machine contains all the current execution state,
- * as well as the global view of functions, the global constant
- * table, the compiled operators.
+ * as well as the executable.
  *
  * The goal is to have a single self-contained object,
  * enabling one to easily pass around VMs, execute them on
- * multiple threads, or serialized them to disk or over the
+ * multiple threads, or serialize them to disk or over the
  * wire.
  */
 class VirtualMachine : public runtime::ModuleNode {
@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode {
     return "VirtualMachine";
   }

-  /*! \brief The runtime module/library that contains generated code. */
-  runtime::Module lib;
+  VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
+
+  /*! \brief Load the executable for the virtual machine.
+   *  \param exec The executable.
+   */
+  void LoadExecutable(const Executable* exec);
+
+ protected:
   /*! \brief The virtual machine's packed function table. */
   std::vector<PackedFunc> packed_funcs;
-  /*! \brief The virtual machine's function table. */
-  std::vector<VMFunction> functions;
   /*! \brief The current stack of call frames. */
   std::vector<VMFrame> frames;
-  /*! \brief The global constant pool. */
-  std::vector<ObjectRef> constants;
   /*! \brief The function table index of the current function. */
   Index func_index;
   /*! \brief The current pointer to the code section. */
@@ -506,6 +678,9 @@
   /*! \brief The special return register. */
   ObjectRef return_register;

+  /*! \brief The executable the VM will operate on. */
+  const Executable* exec;
+
   /*! \brief The set of TVM contexts the VM is currently executing on. */
   std::vector<TVMContext> ctxs;
@@ -550,8 +725,6 @@
   */
  ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);

-  VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
-
  /*! \brief Initialize the virtual machine for a set of contexts.
   *  \param contexts The set of TVM contexts.
   */
@@ -565,21 +738,6 @@
   */
  TVMContext GetParamsContext() const;

-  /*!
-   * \brief Load parameters from the parameter bytearray.
-   * \param params The binary file that contains parameters.
-   */
-  void LoadParams(const std::string& params);
-
-  /*! \brief A map from globals (as strings) to their index in the function map.
-   */
-  std::unordered_map<std::string, Index> global_map;
-
-  /*! \brief A mapping from the packed function (as string) to the index that
-   * corresponds to the position of the `packed_funcs` list.
-   */
-  std::unordered_map<std::string, Index> primitive_map;
-
  private:
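Taken together, the refactor changes how a VM is driven: it no longer owns functions and constants itself but borrows them from a loaded Executable. A minimal, hedged sketch of the intended sequence (not code from the patch; it assumes the tvm::runtime::vm namespace, and member visibility at real call sites may differ):

    ObjectRef RunMain(const Executable* exe, const std::vector<ObjectRef>& args) {
      VirtualMachine vm;
      vm.LoadExecutable(exe);   // wires up code, constants and primitive functions
      TVMContext ctx;
      ctx.device_type = kDLCPU;
      ctx.device_id = 0;
      vm.Init({ctx});           // the contexts the VM executes on
      return vm.Invoke("main", args);
    }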
  /*! \brief Invoke a global setting up the VM state to execute.
   *
@@ -589,6 +747,12 @@ class VirtualMachine : public runtime::ModuleNode {

   /*! \brief The parameter name to data mapping. */
   std::unordered_map<std::string, ObjectRef> params_;
+
+  /*!
+   * \brief The constant pool for runtime. It caches the device-dependent
+   * objects to avoid relocation of constants during inference.
+   */
+  std::vector<ObjectRef> const_pool_;
 };

 }  // namespace vm
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index af3e929ac3fa..3f4ee38a7695 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -56,7 +56,7 @@ enum AttachType : int {
 class Stage : public NodeRef {
  public:
   Stage() {}
-  explicit Stage(NodePtr<Node> n) : NodeRef(n) {}
+  explicit Stage(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
    * \brief create a new schedule for op.
    * \param op The operator in the schedule
@@ -280,7 +280,7 @@
 class Schedule : public NodeRef {
  public:
   Schedule() {}
-  explicit Schedule(NodePtr<Node> n) : NodeRef(n) {}
+  explicit Schedule(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
    * \brief Get a copy of current schedule.
    * \return The copied schedule.
@@ -403,7 +403,7 @@
 class IterVarRelation : public NodeRef {
  public:
   IterVarRelation() {}
-  explicit IterVarRelation(NodePtr<Node> n) : NodeRef(n) {}
+  explicit IterVarRelation(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -417,7 +417,7 @@
 class IterVarAttr : public NodeRef {
  public:
   IterVarAttr() {}
-  explicit IterVarAttr(NodePtr<Node> n) : NodeRef(n) {}
+  explicit IterVarAttr(ObjectPtr<Object> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -495,7 +495,7 @@ class StageNode : public Node {
   /*! \brief Number of direct child stages, only used for group stage.*/
   int num_child_stages{0};

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("op", &op);
     v->Visit("origin_op", &origin_op);
     v->Visit("all_iter_vars", &all_iter_vars);
@@ -540,7 +540,7 @@ class ScheduleNode : public Node {
    */
   std::unordered_map<const Node*, Stage> op2stage_cache_;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
     v->Visit("groups", &groups);
@@ -617,7 +617,7 @@ class IterVarAttrNode : public Node {
    */
   Array<Expr> pragma_values;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("iter_type", &iter_type);
     v->Visit("bind_thread", &bind_thread);
     v->Visit("prefetch_data", &prefetch_data);
@@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode {
   /*! \brief Number of parts, only factor or nparts can be given */
   Expr nparts;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("parent", &parent);
     v->Visit("outer", &outer);
     v->Visit("inner", &inner);
@@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode {
   /*! \brief The target domain */
   IterVar fused;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("outer", &outer);
     v->Visit("inner", &inner);
     v->Visit("fused", &fused);
@@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode {
   /*! \brief The inner domain */
   IterVar rebased;

-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("parent", &parent);
     v->Visit("rebased", &rebased);
   }
@@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode {
  /*!
\brief The singleton iterator */ IterVar iter; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } @@ -745,25 +745,25 @@ class SingletonNode : public IterVarRelationNode { // implementations inline const StageNode* Stage::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline StageNode* Stage::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline ScheduleNode* Schedule::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline const IterVarAttrNode* IterVarAttr::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h index 1e3a7686ca00..86cb0e275609 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target_info.h @@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node { */ Expr head_address; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("unit_bits", &unit_bits); v->Visit("max_num_bits", &max_num_bits); v->Visit("max_simd_bits", &max_simd_bits); diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f37cc7bed7d1..599d6ff657d1 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -50,7 +50,7 @@ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(NodePtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -141,7 +141,7 @@ class Operation : public ir::FunctionRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(NodePtr n) : FunctionRef(n) {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -171,7 +171,7 @@ class TensorNode : public Node { /*! \brief constructor */ TensorNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("shape", &shape); v->Visit("dtype", &dtype); v->Visit("op", &op); @@ -189,7 +189,7 @@ class TensorNode : public Node { // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline size_t Tensor::ndim() const { @@ -250,19 +250,17 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> { - std::size_t operator()(const ::tvm::Operation& k) const { - return k.hash(); - } +struct hash<::tvm::Operation> : public ::tvm::NodeHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { + ::tvm::NodeHash hasher; if (k.defined() && k->op.defined()) { - return k->op.hash(); + return hasher(k->op); } else{ - return k.hash(); + return hasher(k); } } }; diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index b5ca6eb4358b..0d4795ad5440 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -87,7 +87,7 @@ class TensorIntrinNode : public Node { /*! 
\brief constructor */ TensorIntrinNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("op", &op); v->Visit("inputs", &inputs); @@ -112,7 +112,7 @@ class TensorIntrinNode : public Node { }; inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } // Internal node container of tensor intrinsic calling. @@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node { /*! \brief scalar expression inputs */ Array scalar_inputs; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("intrin", &intrin); v->Visit("tensors", &tensors); v->Visit("regions", ®ions); @@ -170,7 +170,7 @@ class TensorIntrinCallNode : public Node { }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java index 2fc97f65aca4..04888f568be3 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java @@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor { private final String host; private final int port; private final String key; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; private volatile Socket currSocket = new Socket(); private Runnable callback; @@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor { * @param host Proxy server host. * @param port Proxy server port. * @param key Proxy server key. - * @param sockFdGetter Method to get file descriptor from Java socket. 
*/ - public ConnectProxyServerProcessor(String host, int port, String key, - SocketFileDescriptorGetter sockFdGetter) { + public ConnectProxyServerProcessor(String host, int port, String key) { this.host = host; this.port = port; this.key = "server:" + key; - socketFileDescriptorGetter = sockFdGetter; } /** @@ -70,8 +66,8 @@ public void setStartTimeCallback(Runnable callback) { try { SocketAddress address = new InetSocketAddress(host, port); currSocket.connect(address, 6000); - InputStream in = currSocket.getInputStream(); - OutputStream out = currSocket.getOutputStream(); + final InputStream in = currSocket.getInputStream(); + final OutputStream out = currSocket.getOutputStream(); out.write(Utils.toBytes(RPC.RPC_MAGIC)); out.write(Utils.toBytes(key.length())); out.write(Utils.toBytes(key)); @@ -91,11 +87,10 @@ public void setStartTimeCallback(Runnable callback) { if (callback != null) { callback.run(); } - final int sockFd = socketFileDescriptorGetter.get(currSocket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + address); - } + + SocketChannel sockChannel = new SocketChannel(currSocket); + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + address); } catch (Throwable e) { e.printStackTrace(); throw new RuntimeException(e); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java index 47881eb350c3..c449bb18a565 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java @@ -37,7 +37,6 @@ */ public class ConnectTrackerServerProcessor implements ServerProcessor { private ServerSocket server; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; private final String trackerHost; private final int trackerPort; // device key @@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor { * @param trackerHost Tracker host. * @param trackerPort Tracker port. * @param key Device key. - * @param sockFdGetter Method to get file descriptor from Java socket. + * @param watchdog watch for timeout, etc. + * @throws java.io.IOException when socket fails to open. 
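On the native side, the same change means the server loop is handed a send and a receive callback rather than a raw socket file descriptor. A hedged C++ analogue of what such a callback looks like follows (not code from this patch; POSIX send stands in for the Java OutputStream, and the matching frecv would mirror it with recv and return the bytes read):

    #include <string>
    #include <sys/socket.h>
    #include <tvm/runtime/packed_func.h>

    // Build the "fsend" callback: takes the bytes to write, returns the count.
    tvm::runtime::PackedFunc MakeFSend(int sock) {
      using namespace tvm::runtime;
      return PackedFunc([sock](TVMArgs args, TVMRetValue* rv) {
        std::string data = args[0];  // bytes handed over by the RPC session
        int64_t sent = send(sock, data.data(), data.size(), 0);
        *rv = sent;                  // the native loop checks the returned count
      });
    }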
*/ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key, - SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException { + RPCWatchdog watchdog) throws IOException { while (true) { try { this.server = new ServerSocket(serverPort); @@ -81,7 +81,6 @@ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String } } System.err.println("using port: " + serverPort); - this.socketFileDescriptorGetter = sockFdGetter; this.trackerHost = trackerHost; this.trackerPort = trackerPort; this.key = key; @@ -163,11 +162,9 @@ public String getMatchKey() { System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); // received timeout in seconds watchdog.startTimeout(timeout * 1000); - final int sockFd = socketFileDescriptorGetter.get(socket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); - } + SocketChannel sockChannel = new SocketChannel(socket); + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); Utils.closeQuietly(socket); } catch (ConnectException e) { // if the tracker connection failed, wait a bit before retrying diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java index 255dabb438d5..697ce45fa04f 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java @@ -28,14 +28,17 @@ * Call native ServerLoop on socket file descriptor. */ public class NativeServerLoop implements Runnable { - private final int sockFd; + private final Function fsend; + private final Function frecv; /** * Constructor for NativeServerLoop. - * @param nativeSockFd native socket file descriptor. + * @param fsend socket.send function. + * @param frecv socket.recv function. */ - public NativeServerLoop(final int nativeSockFd) { - sockFd = nativeSockFd; + public NativeServerLoop(final Function fsend, final Function frecv) { + this.fsend = fsend; + this.frecv = frecv; } @Override public void run() { @@ -43,7 +46,7 @@ public NativeServerLoop(final int nativeSockFd) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); + RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java index 8ebf188b0667..278ef9fe8eef 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -200,6 +200,7 @@ public void upload(byte[] data, String target) { * Upload file to remote runtime temp folder. * @param data The file in local to upload. * @param target The path in remote. + * @throws java.io.IOException for network failure. */ public void upload(File data, String target) throws IOException { byte[] blob = getBytesFromFile(data); @@ -209,6 +210,7 @@ public void upload(File data, String target) throws IOException { /** * Upload file to remote runtime temp folder. * @param data The file in local to upload. + * @throws java.io.IOException for network failure. 
*/ public void upload(File data) throws IOException { upload(data, data.getName()); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java index c81faa0ca999..a9ea2d89a62c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -17,31 +17,12 @@ package ml.dmlc.tvm.rpc; -import sun.misc.SharedSecrets; - -import java.io.FileDescriptor; -import java.io.FileInputStream; import java.io.IOException; -import java.io.InputStream; -import java.net.Socket; /** * RPC Server. */ public class Server { - private static SocketFileDescriptorGetter defaultSocketFdGetter - = new SocketFileDescriptorGetter() { - @Override public int get(Socket socket) { - try { - InputStream is = socket.getInputStream(); - FileDescriptor fd = ((FileInputStream) is).getFD(); - return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); - } catch (IOException e) { - e.printStackTrace(); - return -1; - } - } - }; private final WorkerThread worker; private static class WorkerThread extends Thread { @@ -72,35 +53,10 @@ public void terminate() { /** * Start a standalone server. * @param serverPort Port. - * @param socketFdGetter Method to get system file descriptor of the server socket. - * @throws IOException if failed to bind localhost:port. - */ - public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException { - worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter)); - } - - /** - * Start a standalone server. - * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess - * to get file descriptor for the socket. - * @param serverPort Port. * @throws IOException if failed to bind localhost:port. */ public Server(int serverPort) throws IOException { - this(serverPort, defaultSocketFdGetter); - } - - /** - * Start a server connected to proxy. - * @param proxyHost The proxy server host. - * @param proxyPort The proxy server port. - * @param key The key to identify the server. - * @param socketFdGetter Method to get system file descriptor of the server socket. - */ - public Server(String proxyHost, int proxyPort, String key, - SocketFileDescriptorGetter socketFdGetter) { - worker = new WorkerThread( - new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter)); + worker = new WorkerThread(new StandaloneServerProcessor(serverPort)); } /** @@ -112,7 +68,8 @@ public Server(String proxyHost, int proxyPort, String key, * @param key The key to identify the server. */ public Server(String proxyHost, int proxyPort, String key) { - this(proxyHost, proxyPort, key, defaultSocketFdGetter); + worker = new WorkerThread( + new ConnectProxyServerProcessor(proxyHost, proxyPort, key)); } /** diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java new file mode 100644 index 000000000000..e72581b2358f --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java @@ -0,0 +1,49 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.TVMValue; +import ml.dmlc.tvm.TVMValueBytes; + +import java.io.IOException; +import java.net.Socket; + +public class SocketChannel { + private final Socket socket; + + SocketChannel(Socket sock) { + socket = sock; + } + + private Function fsend = Function.convertFunc(new Function.Callback() { + @Override public Object invoke(TVMValue... 
args) { + byte[] data = args[0].asBytes(); + try { + socket.getOutputStream().write(data); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + return data.length; + } + }); + + private Function frecv = Function.convertFunc(new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + long size = args[0].asLong(); + try { + return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size)); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + } + }); + + public Function getFsend() { + return fsend; + } + + public Function getFrecv() { + return frecv; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java deleted file mode 100644 index 4c35f720009d..000000000000 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ml.dmlc.tvm.rpc; - -import java.net.Socket; - -/** - * Interface for defining different socket fd getter. - */ -public interface SocketFileDescriptorGetter { - /** - * Get native socket file descriptor. - * @param socket Java socket. - * @return native socket fd. 
- */ - public int get(Socket socket); -} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java index 06e3303d1523..2d2303d3fe8a 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java @@ -28,12 +28,9 @@ */ public class StandaloneServerProcessor implements ServerProcessor { private final ServerSocket server; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; - public StandaloneServerProcessor(int serverPort, - SocketFileDescriptorGetter sockFdGetter) throws IOException { + public StandaloneServerProcessor(int serverPort) throws IOException { this.server = new ServerSocket(serverPort); - this.socketFileDescriptorGetter = sockFdGetter; } @Override public void terminate() { @@ -46,9 +43,9 @@ public StandaloneServerProcessor(int serverPort, @Override public void run() { try { - Socket socket = server.accept(); - InputStream in = socket.getInputStream(); - OutputStream out = socket.getOutputStream(); + final Socket socket = server.accept(); + final InputStream in = socket.getInputStream(); + final OutputStream out = socket.getOutputStream(); int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt(); if (magic != RPC.RPC_MAGIC) { Utils.closeQuietly(socket); @@ -66,12 +63,10 @@ public StandaloneServerProcessor(int serverPort, out.write(Utils.toBytes(serverKey)); } + SocketChannel sockChannel = new SocketChannel(socket); System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); - final int sockFd = socketFileDescriptorGetter.get(socket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); - } + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); Utils.closeQuietly(socket); } catch (Throwable e) { e.printStackTrace(); diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java index d719eb6f61e7..a29402867381 100644 --- a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java +++ b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java @@ -17,7 +17,10 @@ package ml.dmlc.tvm.contrib; -import ml.dmlc.tvm.*; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.NDArray; +import ml.dmlc.tvm.TVMContext; +import ml.dmlc.tvm.TestUtils; import ml.dmlc.tvm.rpc.Client; import ml.dmlc.tvm.rpc.RPCSession; import ml.dmlc.tvm.rpc.Server; diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc index 1eff6c45e1fc..b4bfd4270775 100644 --- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc @@ -242,7 +242,7 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kNodeHandle || tcode == kFuncHandle || tcode == kModuleHandle) { + if (tcode == kObjectHandle || tcode == kFuncHandle || tcode == kModuleHandle) { TVMCbArgToReturn(&arg, tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -259,8 +259,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, reinterpret_cast(resourceHandle), jargs); TVMFuncArgsThreadLocalEntry *e = 
TVMFuncArgsThreadLocalStore::Get(); - const int prevNumStrArg = e->tvmFuncArgPushedStrs.size(); - const int prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); + const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); + const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); // convert returned (java) TVMValue to (C) TVMValue env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue); diff --git a/jvm/pom.xml b/jvm/pom.xml index 99cfe0d7b5ec..150c3a00a894 100644 --- a/jvm/pom.xml +++ b/jvm/pom.xml @@ -164,8 +164,8 @@ maven-compiler-plugin 3.3 - 1.6 - 1.6 + 1.7 + 1.7 UTF-8 diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h index fa8b69f9b70a..9555c0e7b3ea 100644 --- a/nnvm/include/nnvm/compiler/util.h +++ b/nnvm/include/nnvm/compiler/util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -56,7 +56,7 @@ inline tvm::Array ShapeToArray(TShape shape) { * \return An Array of Expr, where each element is a constant int32 */ inline tvm::Array ShapeToIntArray(TShape shape) { - return tvm::Array(ShapeToArray(shape).node_); + return tvm::Downcast >(ShapeToArray(shape)); } } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 542455969b8b..5ce78d1d58d6 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -388,6 +388,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") *rv = ret; }); +TVM_REGISTER_NODE_TYPE(GraphFuncNode); +TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GraphFuncNode *op, IRPrinter *p) { p->stream << "GraphFunc(name=" << op->func_name diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 35287f5a9358..ec9a13b13b17 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node { /*! \brief The lowered functions */ tvm::Array funcs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target", &target); v->Visit("func_name", &func_name); v->Visit("inputs", &inputs); @@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node { /*! 
\brief Index of the master node for calling schedule*/ int master_idx; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("graph_func", &graph_func); v->Visit("use_count", &use_count); v->Visit("master_idx", &master_idx); @@ -92,7 +92,7 @@ class GraphCacheEntry : public ::tvm::NodeRef { GraphCacheEntry() {} explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} GraphCacheEntryNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } using ContainerType = GraphCacheEntryNode; }; diff --git a/nnvm/src/compiler/graph_hash.h b/nnvm/src/compiler/graph_hash.h index aed3462cf128..6966a152224b 100644 --- a/nnvm/src/compiler/graph_hash.h +++ b/nnvm/src/compiler/graph_hash.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node { // The graph hash key is ensured always not to be 0 mutable size_t cache_hash_key_{0}; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("inputs", &inputs); v->Visit("target", &target); } diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc index 3bfebe3ba4e8..d8ff3bf34bf8 100644 --- a/nnvm/src/compiler/graph_runtime.cc +++ b/nnvm/src/compiler/graph_runtime.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,11 +18,12 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file graph_runtime.cc * \brief Interface code with TVM graph runtime. */ #include +#include + #include #include "graph_runtime.h" diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 3a847de83d9f..770c98e83261 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -62,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node { std::string name; tvm::runtime::NDArray array; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("array", &array); } static constexpr const char* _type_key = "NDArrayWrapper"; - TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node); + TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, tvm::Node); }; TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode); diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index bbcc62a99ad8..45f1451663e6 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") const Array& out_info) -> Array { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); - if ((*ret.ptr<::tvm::NodePtr >())->derived_from()) { + if (ret.IsObjectRef()) { return {ret.operator Tensor()}; } else { return ret; diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 5496a4c674f6..c48ae0061f9e 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1237,7 +1237,7 @@ Array GetIntArray(Array arr) { CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Array(arr.node_); + return Downcast >(arr); } NNVM_REGISTER_OP(slice_like) diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 895c72d28d01..2f0b5babda4d 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement +# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import """Function configuration API.""" from __future__ import absolute_import @@ -32,8 +32,8 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .node import NodeBase -from . import node as _node +from .object import ObjectBase, _set_class_node +from . 
import object as _object FunctionHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p @@ -107,9 +107,9 @@ def _make_tvm_args(args, temp_args): values = (TVMValue * num_args)() type_codes = (ctypes.c_int * num_args)() for i, arg in enumerate(args): - if isinstance(arg, NodeBase): + if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None type_codes[i] = TypeCode.NULL @@ -147,7 +147,7 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle @@ -163,9 +163,6 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = arg.handle type_codes[i] = TypeCode.FUNC_HANDLE temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): - values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_CELL else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -225,7 +222,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.NODE_HANDLE + assert ret_tcode.value == TypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -246,7 +243,7 @@ def _handle_return_func(x): return _CLASS_FUNCTION(handle, False) # setup return handle for function type -_node.__init_by_constructor__ = __init_handle_by_constructor__ +_object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/object.py similarity index 52% rename from python/tvm/_ffi/_ctypes/node.py rename to python/tvm/_ffi/_ctypes/object.py index 39fe0ef35525..c3ae56822198 100644 --- a/python/tvm/_ffi/_ctypes/node.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -14,66 +14,59 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
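At the C ABI level, the ctypes changes above reduce to packing every node or object argument with the single unified type code. A hedged sketch of that packing (the helper name is hypothetical; the enum values appear in the cython base.pxi hunk further down):

    // What the Python-side argument packing amounts to in C terms.
    void PackObjectArg(void* handle, TVMValue* out_value, int* out_tcode) {
      out_value->v_handle = handle;  // the Object* held by the Python-side wrapper
      *out_tcode = kObjectHandle;    // code 8; kNodeHandle and kObjectCell are retired
    }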
-# pylint: disable=invalid-name, protected-access -# pylint: disable=no-member, missing-docstring, not-callable +# pylint: disable=invalid-name +"""Runtime Object api""" from __future__ import absolute_import import ctypes -from ..base import _LIB, check_call, c_str +from ..base import _LIB, check_call +from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from ..node_generic import _set_class_node_base -from .types import TVMValue, TypeCode -from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func -NodeHandle = ctypes.c_void_p + +ObjectHandle = ctypes.c_void_p __init_by_constructor__ = None -"""Maps node type to its constructor""" -NODE_TYPE = {} +"""Maps object type to its constructor""" +OBJECT_TYPE = {} + +_CLASS_NODE = None + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class + -def _register_node(index, cls): - """register node class""" - NODE_TYPE[index] = cls +def _register_object(index, cls): + """register object class""" + OBJECT_TYPE[index] = cls -def _return_node(x): - """Return node function""" + +def _return_object(x): handle = x.v_handle - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - tindex = ctypes.c_int() - check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) - cls = NODE_TYPE.get(tindex.value, NodeBase) + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + tindex = ctypes.c_uint() + check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ - node = cls.__new__(cls) - node.handle = handle - return node - + obj = cls.__new__(cls) + obj.handle = handle + return obj -RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node -C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( - _return_node, TypeCode.NODE_HANDLE) +RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, TypeCode.OBJECT_HANDLE) -class NodeBase(object): +class ObjectBase(object): + """Base object for all object types""" __slots__ = ["handle"] - # pylint: disable=no-member + def __del__(self): if _LIB is not None: - check_call(_LIB.TVMNodeFree(self.handle)) - - def __getattr__(self, name): - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - self.handle, c_str(name), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - if not ret_success.value: - raise AttributeError( - "'%s' object has no attribute '%s'" % (str(type(self)), name)) - return RETURN_SWITCH[ret_type_code.value](ret_val) + check_call(_LIB.TVMObjectFree(self.handle)) def __init_handle_by_constructor__(self, fconstructor, *args): """Initialize the handle by calling constructor function. 
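The new ObjectBase leans on two C-API entry points, declared in the cython base.pxi hunk below. A hedged C++-side sketch of that contract (the function name is hypothetical; the signatures match the declarations in this patch):

    // Type-index lookup for class dispatch, then reference release.
    void InspectAndRelease(void* obj) {
      unsigned tindex = 0;
      TVMObjectGetTypeIndex(obj, &tindex);  // drives the OBJECT_TYPE[tindex] dispatch
      TVMObjectFree(obj);                   // the counterpart of ObjectBase.__del__
    }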
@@ -95,8 +88,9 @@ def __init_handle_by_constructor__(self, fconstructor, *args): # assign handle first to avoid error raising self.handle = None handle = __init_by_constructor__(fconstructor, args) - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) self.handle = handle -_set_class_node_base(NodeBase) + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_ctypes/vmobj.py b/python/tvm/_ffi/_ctypes/vmobj.py deleted file mode 100644 index 59930e55c382..000000000000 --- a/python/tvm/_ffi/_ctypes/vmobj.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH - -ObjectHandle = ctypes.c_void_p - -"""Maps object type to its constructor""" -OBJECT_TYPE = {} - -def _register_object(index, cls): - """register object class""" - OBJECT_TYPE[index] = cls - - -def _return_object(x): - handle = x.v_handle - if not isinstance(handle, ObjectHandle): - handle = ObjectHandle(handle) - tag = ctypes.c_int() - check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag))) - cls = OBJECT_TYPE.get(tag.value, ObjectBase) - obj = cls(handle) - return obj - -RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object - - -class ObjectBase(object): - __slots__ = ["handle"] - - def __init__(self, handle): - self.handle = handle diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 63130ef67d38..4b7b2c88ffa5 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -31,13 +31,12 @@ cdef enum TVMTypeCode: kTVMType = 5 kTVMContext = 6 kArrayHandle = 7 - kNodeHandle = 8 + kObjectHandle = 8 kModuleHandle = 9 kFuncHandle = 10 kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectCell = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -78,7 +77,7 @@ ctypedef void* TVMStreamHandle ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* ObjectHandle -ctypedef void* NodeHandle + ctypedef struct TVMNDArrayContainer: DLTensor dl_tensor @@ -130,19 +129,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMArrayToDLPack(DLTensorHandle arr_from, DLManagedTensor** out) void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) - int TVMGetObjectTag(ObjectHandle obj, int* tag) - -cdef extern from "tvm/c_dsl_api.h": - int TVMNodeFree(NodeHandle handle) - int TVMNodeTypeKey2Index(const char* type_key, - int* out_index) - int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index) - int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* 
out_type_code, - int* out_success) + int TVMObjectFree(ObjectHandle obj) + int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) + cdef inline py_str(const char* x): if PY_MAJOR_VERSION < 3: diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index 4b8536c726aa..cbf9d5859046 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -16,7 +16,8 @@ # under the License. include "./base.pxi" -include "./node.pxi" +include "./object.pxi" +# include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" -include "./vmobj.pxi" + diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index cf1884c32486..a2360427b6c7 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -41,10 +41,9 @@ cdef int tvm_callback(TVMValue* args, for i in range(num_args): value = args[i] tcode = type_codes[i] - if (tcode == kNodeHandle or + if (tcode == kObjectHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectCell or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -98,9 +97,9 @@ cdef inline int make_arg(object arg, list temp_args) except -1: """Pack arguments into c args tvm call accept""" cdef unsigned long long ptr - if isinstance(arg, NodeBase): - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle + if isinstance(arg, ObjectBase): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, NDArrayBase): value[0].v_handle = (arg).chandle tcode[0] = (kNDArrayContainer if @@ -152,15 +151,12 @@ cdef inline int make_arg(object arg, temp_args.append(tstr) elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle - elif isinstance(arg, _CLASS_OBJECT): - value[0].v_handle = c_handle(arg.handle) - tcode[0] = kObjectCell elif isinstance(arg, FunctionBase): value[0].v_handle = (arg).chandle tcode[0] = kFuncHandle @@ -188,8 +184,8 @@ cdef inline bytearray make_ret_bytes(void* chandle): cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" - if tcode == kNodeHandle: - return make_ret_node(value.v_handle) + if tcode == kObjectHandle: + return make_ret_object(value.v_handle) elif tcode == kNull: return None elif tcode == kInt: @@ -212,8 +208,6 @@ cdef inline object make_ret(TVMValue value, int tcode): fobj = _CLASS_FUNCTION(None, False) (fobj).chandle = value.v_handle return fobj - elif tcode == kObjectCell: - return make_ret_object(value.v_handle) elif tcode in _TVM_EXT_RET: return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) @@ -314,6 +308,7 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None +_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -327,3 +322,7 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/object.pxi similarity index 64% rename from python/tvm/_ffi/_cython/node.pxi rename to python/tvm/_ffi/_cython/object.pxi index 5e0c366e5600..9561eab94ea2 100644 --- 
a/python/tvm/_ffi/_cython/node.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -15,43 +15,46 @@ # specific language governing permissions and limitations # under the License. -from ... import _api_internal -from ..base import string_types +"""Maps object type to its constructor""" from ..node_generic import _set_class_node_base -"""Maps node type to its constructor""" -NODE_TYPE = [] +OBJECT_TYPE = [] -def _register_node(int index, object cls): - """register node class""" - while len(NODE_TYPE) <= index: - NODE_TYPE.append(None) - NODE_TYPE[index] = cls +def _register_object(int index, object cls): + """register object class""" + while len(OBJECT_TYPE) <= index: + OBJECT_TYPE.append(None) + OBJECT_TYPE[index] = cls -cdef inline object make_ret_node(void* chandle): - global NODE_TYPE - cdef int tindex - cdef list node_type +cdef inline object make_ret_object(void* chandle): + global OBJECT_TYPE + global _CLASS_NODE + cdef unsigned tindex + cdef list object_type cdef object cls - node_type = NODE_TYPE - CALL(TVMNodeGetTypeIndex(chandle, &tindex)) - if tindex < len(node_type): - cls = node_type[tindex] + cdef object handle + object_type = OBJECT_TYPE + handle = ctypes_handle(chandle) + CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(object_type): + cls = object_type[tindex] if cls is not None: obj = cls.__new__(cls) else: - obj = NodeBase.__new__(NodeBase) + # default use node base class + # TODO(tqchen) change to object after Node unifies with Object + obj = _CLASS_NODE.__new__(_CLASS_NODE) else: - obj = NodeBase.__new__(NodeBase) - (obj).chandle = chandle + obj = _CLASS_NODE.__new__(_CLASS_NODE) + (obj).chandle = chandle return obj -cdef class NodeBase: +cdef class ObjectBase: cdef void* chandle - cdef _set_handle(self, handle): + cdef inline _set_handle(self, handle): cdef unsigned long long ptr if handle is None: self.chandle = NULL @@ -70,17 +73,7 @@ cdef class NodeBase: self._set_handle(value) def __dealloc__(self): - CALL(TVMNodeFree(self.chandle)) - - def __getattr__(self, name): - cdef TVMValue ret_val - cdef int ret_type_code, ret_succ - CALL(TVMNodeGetAttr(self.chandle, c_str(name), - &ret_val, &ret_type_code, &ret_succ)) - if ret_succ == 0: - raise AttributeError( - "'%s' object has no attribute '%s'" % (type(self), name)) - return make_ret(ret_val, ret_type_code) + CALL(TVMObjectFree(self.chandle)) def __init_handle_by_constructor__(self, fconstructor, *args): """Initialize the handle by calling constructor function. @@ -104,7 +97,8 @@ cdef class NodeBase: cdef void* chandle ConstructorCall( (fconstructor).chandle, - kNodeHandle, args, &chandle) + kObjectHandle, args, &chandle) self.chandle = chandle -_set_class_node_base(NodeBase) + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/vmobj.pxi b/python/tvm/_ffi/_cython/vmobj.pxi deleted file mode 100644 index 9b487566a6a6..000000000000 --- a/python/tvm/_ffi/_cython/vmobj.pxi +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Maps object type to its constructor""" -OBJECT_TYPE = [] - -def _register_object(int index, object cls): - """register node class""" - while len(OBJECT_TYPE) <= index: - OBJECT_TYPE.append(None) - OBJECT_TYPE[index] = cls - - -cdef inline object make_ret_object(void* chandle): - global OBJECT_TYPE - cdef int tag - cdef list object_type - cdef object cls - cdef object handle - object_type = OBJECT_TYPE - handle = ctypes_handle(chandle) - CALL(TVMGetObjectTag(chandle, &tag)) - if tag < len(object_type): - cls = object_type[tag] - if cls is not None: - obj = cls(handle) - else: - obj = ObjectBase(handle) - else: - obj = ObjectBase(handle) - return obj - - -cdef class ObjectBase: - cdef ObjectHandle chandle - - cdef inline _set_handle(self, handle): - if handle is None: - self.chandle = NULL - else: - self.chandle = c_handle(handle) - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes.cast(self.chandle, ctypes.c_void_p) - def __set__(self, value): - self._set_handle(value) - - def __init__(self, handle): - self._set_handle(handle) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 4bb31820548f..60e7aeb9aec5 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,6 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from . import vmobj as _vmobj IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index baca89d628b8..c6c151af9053 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -21,21 +21,8 @@ import ctypes import sys from .. 
import _api_internal +from .object import Object, register_object, _set_class_node from .node_generic import NodeGeneric, convert_to_node, const -from .base import _LIB, check_call, c_str, py_str, _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _register_node, NodeBase as _NodeBase - else: - from ._cy2.core import _register_node, NodeBase as _NodeBase -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.node import _register_node, NodeBase as _NodeBase def _new_object(cls): @@ -43,20 +30,22 @@ def _new_object(cls): return cls.__new__(cls) -class NodeBase(_NodeBase): +class NodeBase(Object): """NodeBase is the base class of all TVM language AST object.""" def __repr__(self): return _api_internal._format_str(self) def __dir__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - check_call(_LIB.TVMNodeListAttrNames( - self.handle, ctypes.byref(size), ctypes.byref(plist))) - names = [] - for i in range(size.value): - names.append(py_str(plist[i])) - return names + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) def __hash__(self): return _api_internal._raw_ptr(self) @@ -95,24 +84,6 @@ def same_as(self, other): return self.__hash__() == other.__hash__() -def register_node(type_key=None): - """register node type - - Parameters - ---------- - type_key : str or cls - The type key of the node - """ - node_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - tindex = ctypes.c_int() - ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)) - if ret == 0: - _register_node(tindex.value, cls) - return cls - - if isinstance(type_key, str): - return register - return register(type_key) +# pylint: disable=invalid-name +register_node = register_object +_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py new file mode 100644 index 000000000000..002fd27af0fd --- /dev/null +++ b/python/tvm/_ffi/object.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
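With the node.py changes above, ``NodeBase`` attribute access is routed through packed functions (``_api_internal._NodeGetAttr`` / ``_NodeListAttrNames``) instead of the removed C DSL API, and ``register_node`` becomes an alias of ``register_object``. A quick sketch of the user-visible behavior (``tvm.var`` is existing public API; the attribute names depend on the node type):

.. code-block:: python

    import tvm

    x = tvm.var("x", dtype="int32")  # a NodeBase, now backed by Object
    print(x.dtype)                   # resolved via _api_internal._NodeGetAttr
    print(dir(x))                    # names via _api_internal._NodeListAttrNames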
+# pylint: disable=invalid-name +"""Runtime Object API""" +from __future__ import absolute_import + +import sys +import ctypes +from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str + +IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError + +try: + # pylint: disable=wrong-import-position,unused-import + if _FFI_MODE == "ctypes": + raise ImportError() + if sys.version_info >= (3, 0): + from ._cy3.core import _set_class_object, _set_class_node + from ._cy3.core import ObjectBase as _ObjectBase + from ._cy3.core import _register_object + else: + from ._cy2.core import _set_class_object, _set_class_node + from ._cy2.core import ObjectBase as _ObjectBase + from ._cy2.core import _register_object +except IMPORT_EXCEPT: + # pylint: disable=wrong-import-position,unused-import + from ._ctypes.function import _set_class_object, _set_class_node + from ._ctypes.object import ObjectBase as _ObjectBase + from ._ctypes.object import _register_object + + +class Object(_ObjectBase): + """Base class for all tvm's runtime objects.""" + pass + + +def register_object(type_key=None): + """register object type. + + Parameters + ---------- + type_key : str or cls + The type key of the node + + Examples + -------- + The following code registers MyObject + using type key "test.MyObject" + + .. code-block:: python + + @tvm.register_object("test.MyObject") + class MyObject(Object): + pass + """ + object_name = type_key if isinstance(type_key, str) else type_key.__name__ + + def register(cls): + """internal register function""" + if hasattr(cls, "_type_index"): + tindex = cls._type_index + else: + tidx = ctypes.c_uint() + if not _RUNTIME_ONLY: + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + else: + # directly skip unknown objects during runtime. + ret = _LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx)) + if ret != 0: + return cls + tindex = tidx.value + _register_object(tindex, cls) + return cls + + if isinstance(type_key, str): + return register + + return register(type_key) + + +def getitem_helper(obj, elem_getter, length, idx): + """Helper function to implement a pythonic getitem function. + + Parameters + ---------- + obj: object + The original object + + elem_getter : function + A simple function that takes index and return a single element. + + length : int + The size of the array + + idx : int or slice + The argument passed to getitem + + Returns + ------- + result : object + The result of getitem + """ + if isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else length + step = idx.step if idx.step is not None else 1 + if start < 0: + start += length + if stop < 0: + stop += length + return [elem_getter(obj, i) for i in range(start, stop, step)] + + if idx < -length or idx >= length: + raise IndexError("Index out of range. 
size: {}, got index {}" + .format(length, idx)) + if idx < 0: + idx += length + return elem_getter(obj, idx) + + +_set_class_object(Object) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 0d28abd46cb2..2dbb67dfbf73 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -36,13 +36,12 @@ class TypeCode(object): TVM_TYPE = 5 TVM_CONTEXT = 6 ARRAY_HANDLE = 7 - NODE_HANDLE = 8 + OBJECT_HANDLE = 8 MODULE_HANDLE = 9 FUNC_HANDLE = 10 STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_CELL = 14 EXT_BEGIN = 15 diff --git a/python/tvm/_ffi/vmobj.py b/python/tvm/_ffi/vmobj.py deleted file mode 100644 index ea3431aa973c..000000000000 --- a/python/tvm/_ffi/vmobj.py +++ /dev/null @@ -1,61 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import sys -from .base import _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError - -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object - from ._cy3.core import ObjectBase as _ObjectBase - from ._cy3.core import _register_object - else: - from ._cy2.core import _set_class_object - from ._cy2.core import ObjectBase as _ObjectBase - from ._cy2.core import _register_object -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object - from ._ctypes.vmobj import ObjectBase as _ObjectBase - from ._ctypes.vmobj import _register_object - - -class ObjectTag(object): - """Type code used in API calls""" - TENSOR = 1 - CLOSURE = 2 - DATATYPE = 3 - - -class Object(_ObjectBase): - """The VM Object used in Relay virtual machine.""" - - -def register_object(cls): - _register_object(cls.tag, cls) - return cls - - -_set_class_object(Object) diff --git a/python/tvm/api.py b/python/tvm/api.py index e7523bd733f9..f0261be37e41 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -21,6 +21,7 @@ from numbers import Integral as _Integral from ._ffi.base import string_types +from ._ffi.object import register_object, Object from ._ffi.node import register_node, NodeBase from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node_generic import _scalar_type_inference diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 5b0294ef2d07..55be05f4b88f 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -52,8 +52,7 @@ def _build(func, def extract_from_program(func, params, ops, target, target_host=None): """ Extract tuning tasks from a relay program. 
- This function collects tuning tasks by building the program - with a "tracing" target and tracing all the calls to topi. + This function is the single program version of extract_from_multiple_program. Parameters ---------- @@ -73,66 +72,14 @@ def extract_from_program(func, params, ops, target, target_host=None): task: Array of autotvm.task.Task collected tasks """ - import tvm.relay.op - from tvm import relay - import topi - - env = TaskExtractEnv.get() - - # NOTE: To add more ops, you only need to change the following lists - # relay op -> topi compute - OP2TOPI = { - tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, - topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc], - tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], - tvm.relay.op.nn.dense: [topi.nn.dense], - tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw], - } - - topi_funcs = [] - for op_name in ops: - if op_name in OP2TOPI: - topi_funcs.extend(OP2TOPI[op_name]) - else: - warnings.warn("Op %s is not tunable, ignored" % op_name) - - # run compiler to collect all TOPI calls during compilation - env.reset(topi_funcs) - with env: - # disable logger temporarily - old_state = logger.disabled - logger.disabled = True - - relay.backend.compile_engine.get().clear() - # wrap build call in thread to avoid multiprocessing problems - mod = relay.Module.from_expr(func) - build_thread = threading.Thread(target=_build, - args=(mod, - target, - target_host, - params)) - build_thread.start() - build_thread.join() - - logger.disabled = old_state - - # create tasks for target - tasks = [] - for task_name, args in env.get_tasks(): - try: - tsk = create(task_name, args, - target=target, target_host=target_host, - template_key='direct') - tasks.append(tsk) - except topi.InvalidShapeError: - warnings.warn("Invalid shape during AutoTVM task creation") - return tasks + return extract_from_multiple_program([func], [params], ops, target, target_host) def extract_from_multiple_program(funcs, params, ops, target, target_host=None): """ Extract tuning tasks from multiple relay programs. - This function is the multiple program version of extract_from_program + This function collects tuning tasks by building a list of programs + with a "tracing" target and tracing all the calls to topi. 
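A minimal invocation sketch of the task-extraction API described above (``net`` and ``net_params`` stand in for a Relay function and its parameter dict, and are not names from this patch):

.. code-block:: python

    import tvm.relay.op
    from tvm import autotvm

    tasks = autotvm.task.extract_from_multiple_program(
        [net], [net_params], ops=(tvm.relay.op.nn.conv2d,),
        target="llvm")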
Parameters ---------- @@ -152,19 +99,20 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): task: Array of autotvm.task.Task collected tasks """ - env = TaskExtractEnv.get() import tvm.relay.op from tvm import relay import topi + env = TaskExtractEnv.get() + # NOTE: To add more ops, you only need to change the following lists # relay op -> topi compute OP2TOPI = { tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, - topi.nn.group_conv2d_nchw], + topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc], tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], tvm.relay.op.nn.dense: [topi.nn.dense], - tvm.relay.op.nn.contrib_deformable_conv2d: [topi.nn.deformable_conv2d_nchw], + tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw], } topi_funcs = [] @@ -185,11 +133,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems mod = relay.Module.from_expr(func) - build_thread = threading.Thread(target=my_build, - args=(mod, - target, - target_host, - params)) + build_thread = threading.Thread(target=_build, + args=(mod, target, target_host, param)) build_thread.start() build_thread.join() diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index e0db27574898..4f3cc90b474e 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None): elif x is None: workload = 0 else: - raise RuntimeError('Do not support type "%s" in argument. Consider to use' + raise RuntimeError('Do not support type "%s" in argument. Consider to use ' 'primitive types only' % type(x)) return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 09f08ad8b4ae..ac4683d4ae0b 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -176,9 +176,12 @@ def _topi_nn_conv2d(*args, **kwargs): args = deserialize_args(args) A, W = args[:2] layout = args[-2] - assert layout == 'NCHW', "only support NCHW currently" + assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently" C = topi.nn.conv2d(*args, **kwargs) - s = topi.generic.schedule_conv2d_nchw([C]) + if layout == 'NCHW': + s = topi.generic.schedule_conv2d_nchw([C]) + else: + s = topi.generic.schedule_conv2d_hwcn([C]) return s, [A, W, C] @register("topi_nn_depthwise_conv2d_nchw") diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 4cb09931616e..fe2f64142c56 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -413,7 +413,6 @@ def lower(sch, # Phase 3 stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.RemoveNoOp(stmt) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) @@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host): func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "warp") + func = ir_pass.InferFragment(func) warp_size = target.thread_warp_size func = ir_pass.LowerThreadAllreduce(func, warp_size) fsplits = [s for s in ir_pass.SplitHostDevice(func)] @@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host): assert not fdevice target_host = 
_target.create(target_host) + fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] + fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] @@ -568,10 +570,11 @@ def build(inputs, B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') s1 = tvm.create_schedule(C.op) - s2 = topi.cpp.cuda.schedule_injective("cuda", [C]) - f1 = tvm.lower(s1, [A, B, C], name="test_add1") - f2 = tvm.lower(s2, [A, B, C], name="test_add2") - m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm") + with tvm.target.cuda() as cuda_tgt: + s2 = topi.cuda.schedule_injective(cuda_tgt, [C]) + f1 = tvm.lower(s1, [A, B, C], name="test_add1") + f2 = tvm.lower(s2, [A, B, C], name="test_add2") + m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm") Note ---- diff --git a/python/tvm/error.py b/python/tvm/error.py index b5a7ed2374b7..a6d4f701d2a6 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -49,6 +49,7 @@ def __init__(self, msg): register_error("ValueError", ValueError) register_error("TypeError", TypeError) +register_error("AttributeError", AttributeError) @register_error diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index ceb98c4d251e..fff9c99e5007 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -37,8 +37,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import serializer -from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py deleted file mode 100644 index fde702b1cd04..000000000000 --- a/python/tvm/relay/backend/deserializer.py +++ /dev/null @@ -1,81 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine deserializer. - -Python interface for deserializing a Relay VM. -""" -from tvm import module -from tvm._ffi.runtime_ctypes import TVMByteArray -from . import _vm -from . import vm as rly_vm - -def _create_deserializer(code, lib): - """Create a deserializer object. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - - Returns - ------- - ret : Deserializer - The created virtual machine deserializer. 
- """ - if isinstance(code, (bytes, str)): - code = bytearray(code) - elif not isinstance(code, (bytearray, TVMByteArray)): - raise TypeError("vm is expected to be the type of bytearray or " + - "TVMByteArray, but received {}".format(type(code))) - - if not isinstance(lib, module.Module): - raise TypeError("lib is expected to be the type of tvm.module.Module" + - ", but received {}".format(type(lib))) - return _vm._Deserializer(code, lib) - - -class Deserializer: - """Relay VM deserializer. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - """ - def __init__(self, code, lib): - self.mod = _create_deserializer(code, lib) - self._deserialize = self.mod["deserialize"] - - def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM. - - Returns - ------- - ret : VirtualMachine - The deserialized Relay VM. - """ - return rly_vm.VirtualMachine(self._deserialize()) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ae60b7a89b2f..1d53f6a92b07 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -72,6 +72,11 @@ class Closure(Value): """A closure produced by the interpreter.""" +@register_relay_node +class RecClosure(Value): + """A recursive closure produced by the interpreter.""" + + @register_relay_node class ConstructorValue(Value): def __init__(self, tag, fields, constructor): diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 8ae3161e0b83..ded5d0d13bd7 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachineProfiler - The profile VM runtime. + exec : Executable + The executable with profiling code. """ compiler = VMCompilerProfiler() target = compiler.update_target(target) @@ -60,7 +60,11 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachineProfiler(compiler._get_vm()) + return vm.Executable(compiler._get_exec()) + +def enabled(): + """Whether vm profiler is enabled.""" + return hasattr(_vm, "_VMCompilerProfiler") class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" @@ -68,13 +72,17 @@ def __init__(self): super().__init__() self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] class VirtualMachineProfiler(vm.VirtualMachine): """Relay profile VM runtime.""" def __init__(self, mod): super().__init__(mod) + m = mod.module if isinstance(mod, vm.Executable) else mod + self.mod = _vm._VirtualMachineDebug(m) + self._init = self.mod["init"] + self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] def get_stat(self): diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py deleted file mode 100644 index b45ba9116a15..000000000000 --- a/python/tvm/relay/backend/serializer.py +++ /dev/null @@ -1,191 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine serializer. - -Python interface for serializing a Relay VM. -""" -import tvm -from . import _vm -from . import vm as rly_vm - -def _create_serializer(vm): - """Create a VM serializer. - - Parameters - ---------- - vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] - The virtual machine to be serialized. - - Returns - ------- - ret : Serializer - The created virtual machine serializer. - """ - if isinstance(vm, rly_vm.VirtualMachine): - vm = vm.module - elif not isinstance(vm, tvm.module.Module): - raise TypeError("vm is expected to be the type of VirtualMachine or " + - "tvm.Module, but received {}".format(type(vm))) - - return _vm._Serializer(vm) - - -class Serializer: - """Relay VM serializer.""" - def __init__(self, vm): - self.mod = _create_serializer(vm) - self._get_lib = self.mod["get_lib"] - self._get_bytecode = self.mod["get_bytecode"] - self._get_globals = self.mod["get_globals"] - self._get_stats = self.mod["get_stats"] - self._get_primitive_ops = self.mod["get_primitive_ops"] - self._serialize = self.mod["serialize"] - - @property - def stats(self): - """Get the statistics of the Relay VM. - - Returns - ------- - ret : String - The serialized statistic information. - """ - return self._get_stats() - - @property - def primitive_ops(self): - """Get the name of the primitive ops that are executed in the VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The list of primitive ops. - """ - return [prim_op.value for prim_op in self._get_primitive_ops()] - - @property - def bytecode(self): - """Get the bytecode of the Relay VM. - - Returns - ------- - ret : String - The serialized bytecode. - - Notes - ----- - The bytecode is in the following format: - func_name reg_file_size num_instructions - param1 param2 ... paramM - instruction1 - instruction2 - ... - instructionN - - Each instruction is printed in the following format: - hash opcode field1 ... fieldX # The text format. - - The part starting from # is only used for visualization and debugging. - The real serialized code doesn't contain it, therefore the deserializer - doesn't need to deal with it as well. - """ - return self._get_bytecode() - - @property - def globals(self): - """Get the globals used by the Relay VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The serialized globals. - """ - return [glb.value for glb in self._get_globals()] - - def serialize(self): - """Serialize the Relay VM. - - Returns - ------- - code : bytearray - The binary blob representing a serialized Relay VM. It can then be - saved to disk and later deserialized into a new VM. - - lib : :py:class:`~tvm.module.Module` - The runtime module that contains the generated code. It is - basically a library that is composed of hardware dependent code. 
- - Notes - ----- - The returned code is organized with the following sections in order. - - Global section. This section contains the globals used by the - virtual machine. - - Constant section. This section is used to store the constant pool of - a virtual machine. - - Primitive name section. This section is introduced to accommodate - the list of primitive operator names that will be invoked by the - virtual machine. - - Code section. The VM functions, including bytecode, are sitting in - this section. - - Examples - -------- - .. code-block:: python - - import numpy as np - import tvm - from tvm import relay - - # define a simple network. - x = relay.var('x', shape=(10, 10)) - f = relay.Function([x], x + x) - mod = relay.Module({"main": f}) - - # create a Relay VM. - ctx = tvm.cpu() - target = "llvm" - compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) - vm.init(ctx) - - # serialize. - ser = relay.serializer.Serializer(vm) - code, lib = ser.serialize() - - # save and load the code and lib file. - tmp = tvm.contrib.util.tempdir() - path_lib = tmp.relpath("lib.so") - lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: - fo.write(code) - - loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) - - # deserialize. - deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() - - # execute the deserialized vm. - des_vm.init(ctx) - x_data = np.random.rand(10, 10).astype('float32') - res = des_vm.run(x_data) - print(res.asnumpy()) - """ - return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index e54629dd1344..e190e3f1eb41 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,15 +24,18 @@ import tvm from tvm import autotvm -from tvm._ffi.runtime_ctypes import TVMByteArray from tvm.relay import expr as _expr +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj from .interpreter import Executor +Tensor = _obj.Tensor +ADT = _obj.ADT + def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): - cargs.append(_obj.tensor_object(arg)) + cargs.append(_obj.Tensor(arg)) elif isinstance(arg, (tuple, list)): field_args = [] for field in arg: @@ -41,6 +44,7 @@ def _convert(arg, cargs): else: raise "unsupported type" + def convert(args): cargs = [] for arg in args: @@ -49,12 +53,202 @@ def convert(args): return cargs +class Executable(object): + """Relay VM executable""" + def __init__(self, mod): + self.mod = mod + self._save = self.mod["save"] + self._get_lib = self.mod["get_lib"] + self._get_bytecode = self.mod["get_bytecode"] + self._get_stats = self.mod["get_stats"] + + def save(self): + """Save the Relay VM Executable. + + Returns + ------- + code : bytearray + The binary blob representing a serialized Relay VM executable. It + can then be saved to disk and later deserialized into a new + Executable. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. It is + basically a library that is composed of hardware dependent code. + + Notes + ----- + The returned code is organized with the following sections in order. + - Global section. This section contains the globals used by the + virtual machine. + - Constant section. This section is used to store the constant pool of + a virtual machine. + - Primitive name section. 
This section is introduced to accommodate
+            the list of primitive operator names that will be invoked by the
+            virtual machine.
+          - Code section. The VM functions, including bytecode, are sitting in
+            this section.
+
+        Examples
+        --------
+
+        .. code-block:: python
+
+            import numpy as np
+            import tvm
+            from tvm import relay
+            # define a simple network.
+            x = relay.var('x', shape=(10, 10))
+            f = relay.Function([x], x + x)
+            mod = relay.Module({"main": f})
+            # create a Relay VM.
+            ctx = tvm.cpu()
+            target = "llvm"
+            executable = relay.vm.compile(mod, target)
+            code, lib = executable.save()
+            # save and load the code and lib file.
+            tmp = tvm.contrib.util.tempdir()
+            path_lib = tmp.relpath("lib.so")
+            lib.export_library(path_lib)
+            with open(tmp.relpath("code.ro"), "wb") as fo:
+                fo.write(code)
+            loaded_lib = tvm.module.load(path_lib)
+            loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
+            # deserialize.
+            des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_lib)
+            # execute the deserialized executable.
+            x_data = np.random.rand(10, 10).astype('float32')
+            des_vm = relay.vm.VirtualMachine(des_exec)
+            des_vm.init(ctx)
+            res = des_vm.run(x_data)
+            print(res.asnumpy())
+        """
+        return self._save(), self._get_lib()
+
+    @staticmethod
+    def load_exec(bytecode, lib):
+        """Construct an executable from saved artifacts.
+
+        Parameters
+        ----------
+        bytecode : bytearray
+            The binary blob representing the Relay VM bytecode.
+
+        lib : :py:class:`~tvm.module.Module`
+            The runtime module that contains the generated code.
+
+        Returns
+        -------
+        exec: Executable
+            An executable constructed using the provided artifacts.
+        """
+        if isinstance(bytecode, (bytes, str)):
+            bytecode = bytearray(bytecode)
+        elif not isinstance(bytecode, (bytearray, TVMByteArray)):
+            raise TypeError("bytecode is expected to be the type of bytearray " +
+                            "or TVMByteArray, but received {}".format(type(bytecode)))
+
+        if not isinstance(lib, tvm.module.Module):
+            raise TypeError("lib is expected to be the type of tvm.module.Module" +
+                            ", but received {}".format(type(lib)))
+
+        return Executable(_vm.Load_Executable(bytecode, lib))
+
+    @property
+    def lib(self):
+        """Get the library that contains hardware dependent code.
+
+        Returns
+        -------
+        ret : :py:class:`~tvm.Module`
+            The runtime module that contains hardware dependent code.
+        """
+        return self._get_lib()
+
+    @property
+    def stats(self):
+        """Get the statistics of the Relay VM executable.
+
+        Returns
+        -------
+        ret : String
+            The statistic information of the VM executable.
+        """
+        return self._get_stats()
+
+    @property
+    def primitive_ops(self):
+        """Get the name of the primitive ops contained in the executable.
+
+        Returns
+        -------
+        ret : List[String]
+            The list of primitive ops.
+        """
+        ret = []
+        num_primitives = _vm.GetNumOfPrimitives(self.module)
+        for i in range(num_primitives):
+            ret.append(_vm.GetPrimitiveFields(self.module, i))
+        return ret
+
+    @property
+    def bytecode(self):
+        """Get the bytecode of the Relay VM executable.
+
+        Returns
+        -------
+        ret : String
+            The bytecode of the executable.
+
+        Notes
+        -----
+        The bytecode is in the following format:
+            func_name reg_file_size num_instructions
+            param1 param2 ... paramM
+            instruction1
+            instruction2
+            ...
+            instructionN
+
+        Each instruction is printed in the following format:
+            hash opcode field1 ... fieldX # The text format.
+
+        The part starting from # is only used for visualization and debugging.
+        The real serialized code doesn't contain it, so the deserializer
+        doesn't need to handle it either.
+        """
+        return self._get_bytecode()
+
+    @property
+    def globals(self):
+        """Get the globals used by the Relay VM executable.
+
+        Returns
+        -------
+        ret : List[String]
+            The globals contained in the executable.
+        """
+        ret = []
+        num_globals = _vm.GetNumOfGlobals(self.module)
+        for i in range(num_globals):
+            ret.append(_vm.GetGlobalFields(self.module, i))
+        return ret
+
+    @property
+    def module(self):
+        """Return the runtime module contained in a virtual machine executable."""
+        return self.mod
+
+
 class VirtualMachine(object):
     """Relay VM runtime."""
     def __init__(self, mod):
-        self.mod = mod
+        if not isinstance(mod, (Executable, tvm.module.Module)):
+            raise TypeError("mod is expected to be the type of Executable or " +
+                            "tvm.Module, but received {}".format(type(mod)))
+        m = mod.module if isinstance(mod, Executable) else mod
+        self.mod = _vm._VirtualMachine(m)
         self._init = self.mod["init"]
-        self._load_params = self.mod["load_params"]
         self._invoke = self.mod["invoke"]

     def init(self, ctx):
@@ -68,23 +262,6 @@ def init(self, ctx):
         args = [ctx.device_type, ctx.device_id]
         self._init(*args)

-    def load_params(self, params):
-        """Load parameters for the VM.
-
-        Parameters
-        ----------
-        params : Union[bytearray, Dict]
-            The dictionary that contains serialized parameters.
-        """
-        if isinstance(params, dict):
-            params = tvm.relay.save_param_dict(params)
-        elif isinstance(params, (bytes, str)):
-            params = bytearray(params)
-        if not isinstance(params, (bytearray, TVMByteArray)):
-            raise TypeError("params must be a bytearray")
-
-        self._load_params(bytearray(params))
-
     def invoke(self, func_name, *args):
         """Invoke a function.

@@ -119,11 +296,6 @@ def run(self, *args):
         """
         return self.invoke("main", *args)

-    @property
-    def module(self):
-        """Return the runtime module contained in a virtual machine."""
-        return self.mod
-

 def compile(mod, target=None, target_host=None, params=None):
     """
@@ -152,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None):
     Returns
     -------
-    vm : VirtualMachine
-        The VM runtime.
+    exec : Executable
+        The VM executable that contains both library code and bytecode.
     """
     compiler = VMCompiler()
@@ -164,14 +336,14 @@
     tophub_context = compiler.tophub_context(target)
     with tophub_context:
         compiler._compile(mod, target, target_host)
-    return VirtualMachine(compiler._get_vm())
+    return Executable(compiler._get_exec())

 class VMCompiler(object):
     """Build Relay module to run on VM runtime."""
     def __init__(self):
         self.mod = _vm._VMCompiler()
         self._compile = self.mod["compile"]
-        self._get_vm = self.mod["get_vm"]
+        self._get_exec = self.mod["get_executable"]
         self._set_params_func = self.mod["set_params"]

     def set_params(self, params):
@@ -237,7 +409,7 @@ class VMExecutor(Executor):
     mod : :py:class:`~tvm.relay.module.Module`
         The module to support the execution.

-    ctx : :py:class:`TVMContext`
+    ctx : :py:class:`~tvm.TVMContext`
         The runtime context to run the code on.

     target : :py:class:`Target`
@@ -249,7 +421,8 @@ def __init__(self, mod, ctx, target):
         self.mod = mod
         self.ctx = ctx
         self.target = target
-        self.vm = compile(mod, target)
+        self.executable = compile(mod, target)
+        self.vm = VirtualMachine(self.executable)
         self.vm.init(ctx)

     def _make_executor(self, expr=None):
diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py
index 4c92e9bf38a6..f3fdb763209d 100644
--- a/python/tvm/relay/backend/vmobj.py
+++ b/python/tvm/relay/backend/vmobj.py
@@ -18,32 +18,37 @@
 from __future__ import absolute_import as _abs
 import numpy as _np

-from tvm._ffi.vmobj import Object, ObjectTag, register_object
+from tvm._ffi.object import Object, register_object, getitem_helper
 from tvm import ndarray as _nd
 from . import _vmobj

-# TODO(@icemelon9): Add ClosureObject
-@register_object
-class TensorObject(Object):
-    """Tensor object."""
-    tag = ObjectTag.TENSOR
+@register_object("vm.Tensor")
+class Tensor(Object):
+    """Tensor object.

-    def __init__(self, handle):
-        """Constructs a Tensor object
-
-        Parameters
-        ----------
-        handle : object
-            Object handle
+    Parameters
+    ----------
+    arr : numpy.ndarray or tvm.nd.NDArray
+        The source array.

-        Returns
-        -------
-        obj : TensorObject
-            A tensor object.
-        """
-        super(TensorObject, self).__init__(handle)
-        self.data = _vmobj.GetTensorData(self)
+    ctx : TVMContext, optional
+        The device context to create the array.
+    """
+    def __init__(self, arr, ctx=None):
+        if isinstance(arr, _np.ndarray):
+            ctx = ctx if ctx else _nd.cpu(0)
+            self.__init_handle_by_constructor__(
+                _vmobj.Tensor, _nd.array(arr, ctx=ctx))
+        elif isinstance(arr, _nd.NDArray):
+            self.__init_handle_by_constructor__(
+                _vmobj.Tensor, arr)
+        else:
+            raise RuntimeError("Unsupported type for tensor object.")
+
+    @property
+    def data(self):
+        return _vmobj.GetTensorData(self)

     def asnumpy(self):
         """Convert data to numpy array
@@ -56,69 +61,38 @@ def asnumpy(self):
         return self.data.asnumpy()

-@register_object
-class DatatypeObject(Object):
-    """Datatype object."""
-    tag = ObjectTag.DATATYPE
+@register_object("vm.ADT")
+class ADT(Object):
+    """Algebraic data type (ADT) object.

-    def __init__(self, handle):
-        """Constructs a Datatype object
+    Parameters
+    ----------
+    tag : int
+        The tag of ADT.

-        Parameters
-        ----------
-        handle : object
-            Object handle
+    fields : list[Object] or tuple[Object]
+        The source tuple.
+    """
+    def __init__(self, tag, fields):
+        for f in fields:
+            assert isinstance(f, Object)
+        self.__init_handle_by_constructor__(
+            _vmobj.ADT, tag, *fields)

-        Returns
-        -------
-        obj : DatatypeObject
-            A Datatype object.
-        """
-        super(DatatypeObject, self).__init__(handle)
-        self.tag = _vmobj.GetDatatypeTag(self)
-        num_fields = _vmobj.GetDatatypeNumberOfFields(self)
-        self.fields = []
-        for i in range(num_fields):
-            self.fields.append(_vmobj.GetDatatypeFields(self, i))
+    @property
+    def tag(self):
+        return _vmobj.GetADTTag(self)

     def __getitem__(self, idx):
-        return self.fields[idx]
+        return getitem_helper(
+            self, _vmobj.GetADTFields, len(self), idx)

     def __len__(self):
-        return len(self.fields)
-
-    def __iter__(self):
-        return iter(self.fields)
-
-# TODO(icemelon9): Add closure object
-
-def tensor_object(arr, ctx=_nd.cpu(0)):
-    """Create a tensor object from source arr.
-
-    Parameters
-    ----------
-    arr : numpy.ndarray or tvm.nd.NDArray
-        The source array.
-
-    ctx : TVMContext, optional
-        The device context to create the array
-
-    Returns
-    -------
-    ret : TensorObject
-        The created object.
-    """
-    if isinstance(arr, _np.ndarray):
-        tensor = _vmobj.Tensor(_nd.array(arr, ctx))
-    elif isinstance(arr, _nd.NDArray):
-        tensor = _vmobj.Tensor(arr)
-    else:
-        raise RuntimeError("Unsupported type for tensor object.")
-    return tensor
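A short sketch of the constructors that replace ``tensor_object``/``datatype_object`` (assuming a TVM build with the Relay VM enabled):

.. code-block:: python

    import numpy as np
    from tvm.relay.backend import vmobj

    t = vmobj.Tensor(np.ones((2, 2), dtype="float32"))  # wraps a tvm.nd.NDArray
    adt = vmobj.ADT(0, [t, t])                          # tag 0 with two fields
    assert len(adt) == 2
    assert adt[0].asnumpy().shape == (2, 2)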
- """ - if isinstance(arr, _np.ndarray): - tensor = _vmobj.Tensor(_nd.array(arr, ctx)) - elif isinstance(arr, _nd.NDArray): - tensor = _vmobj.Tensor(arr) - else: - raise RuntimeError("Unsupported type for tensor object.") - return tensor + return _vmobj.GetADTNumberOfFields(self) def tuple_object(fields): - """Create a datatype object from source tuple. + """Create a ADT object from source tuple. Parameters ---------- @@ -127,30 +101,9 @@ def tuple_object(fields): Returns ------- - ret : DatatypeObject + ret : ADT The created object. """ for f in fields: assert isinstance(f, Object) return _vmobj.Tuple(*fields) - - -def datatype_object(tag, fields): - """Create a datatype object from tag and source fields. - - Parameters - ---------- - tag : int - The tag of datatype. - - fields : list[Object] or tuple[Object] - The source tuple. - - Returns - ------- - ret : DatatypeObject - The created object. - """ - for f in fields: - assert isinstance(f, Object) - return _vmobj.Datatype(tag, *fields) diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py index ee30f25d88c1..8887a7eb3c7c 100644 --- a/python/tvm/relay/debug.py +++ b/python/tvm/relay/debug.py @@ -17,12 +17,8 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import -from .base import NodeBase, register_relay_node from ..api import register_func -@register_relay_node -class InterpreterState(NodeBase): - pass # pylint: disable=unused-argument def _debugger_init(expr, stack): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88779dfd76e0..8d59e99d8388 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -27,6 +27,7 @@ from .._ffi import base as _base from .. import nd as _nd from .. import convert +from ..ndarray import NDArray # will be registered afterwards _op_make = None @@ -305,6 +306,17 @@ def __call__(self, *args): """ return Call(self, args, None, None) + def get_params(self): + return _expr.FunctionGetParams(self) + + def set_params(self, params): + for key in params: + value = params[key] + if isinstance(value, NDArray): + params[key] = Constant(value) + + return _expr.FunctionSetParams(self, params) + @register_relay_node class Call(Expr): diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 637e1f0860da..d4b9162d6f3d 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -259,7 +259,7 @@ def get_relay_op(op_name): op = None else: # try search op in various modules - for candidate in (_op, _op.nn, _op.image, _op.vision): + for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib): op = getattr(candidate, op_name, None) if op is not None: break diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a7f787484b2c..a1f51ad41fb0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -915,6 +915,7 @@ def _impl_v1(cls, inputs, attr, params): reps = attr.pop('repeats') # The number of times repeating the tensor data. 
+
 class Erf(OnnxOpConverter):
     """Operator converter for Erf
     """
@@ -922,6 +923,28 @@ def _impl_v1(cls, inputs, attr, params):
         return _op.erf(inputs[0])

+class Where(OnnxOpConverter):
+    """Operator converter for Where
+    """
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        return _op.where(inputs[0], inputs[1], inputs[2])
+
+
+class ConstantOfShape(Elemwise):
+    """Operator converter for ConstantOfShape
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if not isinstance(inputs, list) or len(inputs) < 2:
+            raise ValueError("Expect minimum 2 inputs")
+        # tile the value tensor to the shape given by the second input
+        try:
+            shape = tuple(params[inputs[1].name_hint].asnumpy().astype('int').tolist())
+        except Exception as e:
+            raise ValueError(e)
+        return _op.tile(inputs[0], reps=shape)
+

 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -1042,7 +1077,9 @@
     'Not': Not.get_converter(opset),
     'And': And.get_converter(opset),
     'Tile': Tile.get_converter(opset),
-    'Erf': Erf.get_converter(opset)
+    'Erf': Erf.get_converter(opset),
+    'Where': Where.get_converter(opset),
+    'ConstantOfShape': ConstantOfShape.get_converter(opset)
 }
@@ -1162,7 +1199,14 @@ def from_onnx(self, graph, opset):
                 self._params[i_name] = fill_value
                 self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
                 inputs.append(self._nodes[i_name])
-
+            if op_name == "ConstantOfShape":
+                t_proto = self._parse_attr(node.attribute)["value"]
+                i_name = node.output[0]
+                self._params[i_name] = self._parse_array(t_proto)
+                self._nodes[i_name] = new_var(i_name,
+                                              shape=list(t_proto.dims),
+                                              dtype=self._params[i_name].dtype)
+                inputs.append(self._nodes[i_name])
             i_name = self._parse_value_proto(node)
             attr['tvm_custom'] = {}
             attr['tvm_custom']['name'] = i_name
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 38f9c523e0b1..bfa3431ba29e 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -22,10 +22,14 @@
 import warnings
 from collections import defaultdict

+
 # Numpy support
 import numpy as np

 import tvm
+
+from tvm.relay.prelude import Prelude
+
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
@@ -432,6 +436,24 @@ def _impl(inputs, attr, params):
         return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
     return _impl

+def _assert():
+    # ToDo: In general people want asserts to be gone from TensorFlow graphs
+    # when they are optimizing them, so converting it to a no-op is
+    # reasonable. However, it would be nice to have the option to keep them
+    # once Relay gets a Halt or Assert op.
+    return _no_op()
+
+def _no_op():
+    def _impl(inputs, attr, params):
+        # ToDo: This should really be an op that returns nothing, which could
+        # be represented as an empty tuple. It turns out that TVM
+        # infrastructure doesn't like running functions that return None and
+        # also doesn't like running functions that return an empty tuple.
So it + # doesn't work, but it should be made to work and then this could be + # improved. In the mean time, it is hard to imagine a case where it + # matters in any real way that a no-op is converted to a constant 0. + return tvm.relay.const(0) + return _impl def _matmul(): def _impl(inputs, attr, params): @@ -508,6 +530,69 @@ def _impl(inputs, attr, params): return _op.concatenate(inputs_reshaped, axis) return _impl +def _tensor_array(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('dtype').name + tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) + return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) + return _impl + +def _tensor_array_scatter(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('T').name + values_rank = len(inputs[2].type_annotation.shape) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_var(unstack_name, dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) + return tensor_array_scatter_func(inputs[0], inputs[1], values) + return _impl + +def _tensor_array_gather(): + def _impl(inputs, attr, params, prelude): + return prelude.tensor_array_gather(inputs[2], inputs[1]) + return _impl + +def _tensor_array_size(): + def _impl(inputs, attr, params, prelude): + return prelude.length(inputs[0]) + return _impl + +def _tensor_array_write(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[2].type_annotation.shape) + dtype = attr.get('T').name + + tensor_name = 'tensor{}'.format(input_rank) + tensor_func = prelude.get_var(tensor_name, dtype) + v = tensor_func(inputs[2]) + write_func = prelude.get_var('tensor_array_write', dtype) + + return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) + return _impl + +def _tensor_array_read(): + def _impl(inputs, attr, params, prelude): + read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) + return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + return _impl + +def _tensor_array_split(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[1].type_annotation.shape) + dtype_str = attr.get('T').name + v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) + lengths = _op.cast(inputs[2], 'int32') + split_var = prelude.get_var('tensor_array_split', dtype_str) + return split_var(inputs[0], v, lengths) + return _impl + +def _tensor_array_concat(): + def _impl(inputs, attr, params, prelude): + concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) + return concat_func(inputs[1]) + return _impl + def _tile(): def _impl(inputs, attr, params): reps = _get_list_param(params, inputs.pop()) @@ -1238,6 +1323,13 @@ def _impl(inputs, attr, params): return _op.multiply(difference, difference) return _impl +def _size(): + def _impl(inputs, attr, params): + new_attr = attr + new_attr['out_type'] = attr['out_type'].name + return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr) + return _impl + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -1252,6 +1344,7 @@ def _impl(inputs, attr, params): 'All' : _reduce('all'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), + 'Assert' : _assert(), 'AvgPool' : _pooling('avg_pool'), 'BatchMatMul' : _batch_matmul(), 'BatchMatMulV2' : _batch_matmul(), @@ -1310,9 +1403,18 @@ def _impl(inputs, attr, params): 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), + 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), + 'TensorArrayV3' : _tensor_array(), + 'TensorArrayScatterV3' : _tensor_array_scatter(), + 'TensorArrayGatherV3' : _tensor_array_gather(), + 'TensorArraySizeV3' : _tensor_array_size(), + 'TensorArrayWriteV3' : _tensor_array_write(), + 'TensorArrayReadV3' : _tensor_array_read(), + 'TensorArraySplitV3' : _tensor_array_split(), + 'TensorArrayConcatV3' : _tensor_array_concat(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), @@ -1335,7 +1437,7 @@ def _impl(inputs, attr, params): 'Shape' : _shape(), 'Sigmoid' : AttrCvt('sigmoid'), 'Sign' : AttrCvt('sign'), - 'Size' : AttrCvt('ndarray_size'), + 'Size' : _size(), 'Slice' : _slice(), 'Softmax' : _softmax(), 'Softplus' : _softplus(), @@ -1860,6 +1962,7 @@ def __init__(self): self._loops = {} self._branches = {} self._mod = _module.Module({}) + self._prelude = Prelude(self._mod) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -2113,8 +2216,11 @@ def _parse_param(self, key, value, name, shape): if np_array.dtype == np.dtype(object): # Object types are generally tensorflow DT_STRING (DecodeJpeg op). # Just leave it as placeholder. - self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')] - + if shape: + var_shape = shape[name] + else: + var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape) + self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')] return array_ndim = len(np_array.shape) @@ -2335,7 +2441,11 @@ def _convert_operator(self, op_name, inputs, attrs, if op_name in identity_list: sym = get_relay_op(op_name)(*inputs, **attrs) elif op_name in convert_map: - sym = convert_map[op_name](inputs, attrs, self._params) + if 'TensorArray' in op_name: + sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = convert_map[op_name](inputs, attrs, self._params) + elif op_name in convert_map_rnn: sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 35bc85e09fdd..b042af9fbe65 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -82,6 +82,7 @@ def __init__(self, model, subgraph, exp_tab): 'REDUCE_MAX': self._convert_reduce_max, 'MEAN': self._convert_reduce_mean, 'REDUCE_PROD': self._convert_reduce_prod, + 'SUM': self._convert_reduce_sum, 'FULLY_CONNECTED': self.convert_fully_connected, 'PAD': self.convert_pad, 'PACK': self.convert_pack, @@ -224,6 +225,18 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor): return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \ lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point'] + def is_quantized(self, op): + """Check if an input tensor is quantized.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + 
assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + first_tensor = input_tensors[0] + return first_tensor.qnn_params is not None + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -498,7 +511,25 @@ def _convert_elemwise(self, relay_op, op): rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - out = relay_op(lhs_expr, rhs_expr) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + # If quantized, extracts qnn params and call QNN add operator. + if lhs_tensor.qnn_params: + assert rhs_tensor.qnn_params, "Both tensors should be quantized." + assert output_tensor.qnn_params, "Output tensor should be quantized." + out = relay_op(lhs=lhs_expr, + rhs=rhs_expr, + lhs_scale=lhs_tensor.qnn_params['scale'], + lhs_zero_point=lhs_tensor.qnn_params['zero_point'], + rhs_scale=rhs_tensor.qnn_params['scale'], + rhs_zero_point=rhs_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point']) + else: + out = relay_op(lhs_expr, rhs_expr) # Options (fused_activation_function) options = None @@ -517,36 +548,70 @@ def _convert_elemwise(self, relay_op, op): fused_activation_fn = options.FusedActivationFunction() # if we have activation fn if fused_activation_fn != ActivationFunctionType.NONE: + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'Elemwise operators with fused activation are not supported yet.') out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): """Convert TFLite ADD""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) def convert_sub(self, op): """Convert TFLite SUB""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized sub operator is not supported yet.') return self._convert_elemwise(_op.subtract, op) def convert_mul(self, op): """Convert TFLite MUL""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized mul operator is not supported yet.') return self._convert_elemwise(_op.multiply, op) def convert_div(self, op): """Convert TFLite DIV""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized div operator is not supported yet.') return self._convert_elemwise(_op.divide, op) def convert_pow(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized pow operator is not supported yet.') return self._convert_elemwise(_op.power, op) def convert_maximum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized maximum operator is not supported yet.') return self._convert_elemwise(_op.maximum, op) def convert_minimum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized minimum operator is not supported yet.') 
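For concreteness, this is roughly the relay.qnn.op.add call that _convert_elemwise emits on the quantized path; the tensor shapes and quantization parameters below are invented for illustration.

from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="uint8")
rhs = relay.var("rhs", shape=(1, 8), dtype="uint8")
out = relay.qnn.op.add(lhs=lhs, rhs=rhs,
                       lhs_scale=0.0039, lhs_zero_point=0,
                       rhs_scale=0.0039, rhs_zero_point=0,
                       output_scale=0.0078, output_zero_point=0)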
return self._convert_elemwise(_op.minimum, op) def convert_greater(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized greater operator is not supported yet.') return self._convert_elemwise(_op.greater, op) def convert_zeros_like(self, op): @@ -608,6 +673,9 @@ def _convert_reduce_mean(self, op): def _convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) + def _convert_reduce_sum(self, op): + return self._convert_reduce(_op.reduce.sum, op) + def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index f6b699f1e9cc..845ec4b9ba87 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -37,3 +37,4 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("mean", _schedule_reduce) _reg.register_schedule("variance", _schedule_reduce) _reg.register_schedule("nn.cross_entropy", _schedule_reduce) +_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da5804906269..188b3bb15956 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target): register_schedule("clip", schedule_elemwise) +@script +def _cast_shape_function(x): + out_ndim = len(x) + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = x[i] + return out + +def cast_shape_func(attrs, inputs, out_ndims): + return [_cast_shape_function(*inputs)] + +@script +def _expand_dims_shape_func(x): + ndim = len(x.shape) + out = output_tensor((ndim+1,), "int64") + out[0] = int64(1) + for i in const_range(0, ndim): + out[i+1] = int64(x.shape[i]) + return out + +def expand_dims_shape_func(attrs, inputs, out_ndims): + return [_expand_dims_shape_func(*inputs)] + # shape func @script def _broadcast_shape_func(x, y, ndim): @@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim): def broadcast_shape_func(attrs, inputs, out_ndims): return [_broadcast_shape_func(*inputs, out_ndims[0])] +register_shape_func("expand_dims", False, expand_dims_shape_func) +register_shape_func("cast", False, cast_shape_func) + register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("multiply", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 3a82e46e6a7d..d55cad7c7a2d 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -48,6 +48,9 @@ tile, transpose, where, + repeat, + expand_dims, + full_like ) @@ -198,6 +201,7 @@ def clip_grad(orig, grad): @register_gradient("nn.max_pool2d") def max_pool2d_grad(orig, grad): + """Returns the gradient of max_pool2d.""" attrs = orig.attrs pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size, strides=attrs.strides, padding=attrs.padding, @@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad): @register_gradient("nn.avg_pool2d") def avg_pool2d_grad(orig, grad): + """Returns the gradient of avg_pool2d.""" attrs = orig.attrs pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size, strides=attrs.strides, padding=attrs.padding, @@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad): return [pool_grad] 
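To pin down the contract of the expand_dims and cast shape functions registered earlier, here is a plain-Python analogue (an illustration, not the hybrid-script code itself): each maps an input shape to the corresponding output shape.

def expand_dims_shape(in_shape):
    # expand_dims with axis=0, num_newaxis=1 prepends a unit dimension
    return [1] + list(in_shape)

def cast_shape(in_shape):
    # cast changes only the dtype; the shape passes through unchanged
    return list(in_shape)

assert expand_dims_shape([2, 3]) == [1, 2, 3]
assert cast_shape([4, 5]) == [4, 5]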
+@register_gradient("nn.global_avg_pool2d") +def global_avg_pool2d_grad(orig, grad): + """Returns the gradient of global_avg_pool2d.""" + data = orig.args[0] + shape = data.checked_type.shape + layout = orig.attrs.layout + + # we assume NCHW or NHWC layout for now, but easy to add more + assert layout in ["NCHW", "NHWC"] + if layout == "NCHW": + pool_size = shape[2], shape[3] + elif layout == "NHWC": + pool_size = shape[1], shape[2] + + pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size, + strides=(1, 1), padding=(0, 0), + layout=layout) + return [pool_grad] + + # not implemented, this is only for testing. @register_gradient("concatenate") def concatenate_grad(orig, grad): @@ -287,16 +312,53 @@ def conv2d_grad(orig, grad): return [backward_data, backward_weight] +def _get_reduce_axis(call): + """Helper function that returns the reduce axis of the call as plain python ints.""" + x, axis = call.args[0], call.attrs.axis + shape = x.checked_type.concrete_shape + + # should never exclude when axis is None + assert not (axis is None and call.attrs.exclude) + + if axis is None: + return None + + # convert to nonnegative integers and sort + axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)]) + if call.attrs.exclude: + axis = [ax for ax in range(len(shape)) if ax not in axis] + return axis + + +def _unreduce_expand(x, axis): + """Helper function that returns x expanded on the reduced dimensions in axis.""" + # assume axis is sorted nonnegative ints + for ax in axis: + x = expand_dims(x, ax) + return x + + @register_gradient("max") def max_grad(orig, grad): """Returns the gradient of max""" - # Only support axis=0, since broadcasting orig to x behaves incorrectly - x, axis = orig.args[0], orig.attrs.axis - assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0) - orig = broadcast_to_like(orig, x) - grad = broadcast_to_like(grad, x) - indicators = cast_like(equal(orig, x), grad) - return [indicators * grad] + x, axis = orig.args[0], _get_reduce_axis(orig) + shape = x.checked_type.concrete_shape + + repeated = orig + if axis is None: + repeated = full_like(x, repeated) + else: + # expand dims (if necessary) and repeat along each axis + if not orig.attrs.keepdims: + repeated = _unreduce_expand(repeated, axis) + grad = _unreduce_expand(grad, axis) + for ax in axis: + repeated = repeat(repeated, shape[ax], ax) + + indicators = cast_like(equal(repeated, x), grad) + num_selected = _sum(indicators, axis, keepdims=True) + # spread error across all max weights + return [indicators * grad / num_selected] @register_gradient("nn.softmax") @@ -372,7 +434,11 @@ def negative_grad(orig, grad): @register_gradient("sum") def sum_grad(orig, grad): """Returns grad broadcasted to data dims""" - data = orig.args[0] + data, axis = orig.args[0], _get_reduce_axis(orig) + if not orig.attrs.keepdims: + if axis is None: + axis = list(range(len(data.checked_type.concrete_shape))) + grad = _unreduce_expand(grad, axis) return [broadcast_to_like(grad, data)] @@ -383,3 +449,12 @@ def cross_entropy_grad(orig, grad): batch_size = take(shape, const(0, dtype='int32'), axis=0) grad = grad / batch_size.astype('float32') return [-grad * y / x, -grad * log(x)] + + +@register_gradient("nn.cross_entropy_with_logits") +def cross_entropy_with_logits_grad(orig, grad): + x, y = orig.args + shape = shape_of(x) + batch_size = take(shape, const(0, dtype='int32'), axis=0) + grad = grad / batch_size.astype('float32') + return [-grad * y, -grad * x] diff --git 
a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 10c898538596..2b9d4bcd81bc 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -17,10 +17,10 @@ """Annotation operations.""" from __future__ import absolute_import as _abs from . import _make +from ..op import register_schedule, schedule_injective from .... import nd as _nd from .... import TVMContext as _TVMContext - def on_device(data, device): """Annotate an expression with a certain device type. @@ -61,3 +61,20 @@ def stop_fusion(data): The annotated expression. """ return _make.stop_fusion(data) + +def checkpoint(data): + """Annotate an expression to be a checkpoint for the checkpointing memory optimization. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.checkpoint(data) + +register_schedule("annotation.checkpoint", schedule_injective) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index b8572349fb9d..5786c228abc0 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target): out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype) - assert layout in ["NCHW", "NHWC", "NCHW4c"] + assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"] (dilation_h, dilation_w) = dilation if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") def _get_out_depth(): weight_shape = get_const_tuple(inputs[1].shape) - if kernel_layout == "HWOI": + if kernel_layout.startswith("HW"): return weight_shape[2] * weight_shape[3] return weight_shape[0] * weight_shape[1] @@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target): with target: if groups == 1 and layout == "NCHW": return topi.generic.schedule_conv2d_nchw(outs) - if groups == 1 and layout == "NCHW4c": + elif groups == 1 and layout == "NCHW4c": return topi.generic.schedule_conv2d_nchw(outs) - if groups == 1 and layout == "NHWC": + elif groups == 1 and layout == "NHWC": return topi.generic.schedule_conv2d_nhwc(outs) - if groups != 1: + elif groups == 1 and layout == "HWCN": + return topi.generic.schedule_conv2d_hwcn(outs) + elif groups != 1: # collect in_channels to distinguish depthwise and group conv2d op = _find_conv2d_op(outs[0].op) assert op is not None @@ -768,3 +770,12 @@ def schedule_bitserial_dense(attrs, outputs, target): def compute_cross_entropy(attrs, inputs, out_dtype, target): x, y = inputs return [-topi.sum(topi.log(x) * y) / x.shape[0]] + + +reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE) + + +@reg.register_compute("nn.cross_entropy_with_logits") +def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target): + x, y = inputs + return [-topi.sum(x * y) / x.shape[0]] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 9ddb3ece4ce2..1f289d1bd27a 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets): The computed result. """ return _make.cross_entropy(predictions, targets) + + +def cross_entropy_with_logits(predictions, targets): + """CrossEntropy with logits. + + Parameters + ---------- + predictions : tvm.relay.Expr + The predictions. + + targets : tvm.relay.Expr + The targets. 
+ + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.cross_entropy_with_logits(predictions, targets) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 803d8ef50db5..d27ffe512617 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,8 +16,513 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" +from .ty import GlobalTypeVar, TensorType, Any, scalar_type +from .expr import Var, Function, GlobalVar, If, const +from .op.tensor import add, subtract, equal +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard +from . import op from .module import Module +class TensorArrayOps(object): + """Contains tensor array related ops""" + + def __init__(self, prelude, dtype): + """Create tensor array ops registry""" + self.prelude = prelude + self.dtype = dtype + + def get_name(self, canonical): + """Get name corresponding to the canonical name""" + return self.prelude.get_name(canonical, self.dtype) + + def get_var(self, canonical): + """Get var corresponding to the canonical name""" + return self.prelude.get_var(canonical, self.dtype) + + def define_tensor_adt(self): + """Defines the dynamic tensor ADT, which is the container for tensors + with variable shapes.""" + tensor_type_name = self.get_name('tensor_t') + tensor_type_var = GlobalTypeVar(tensor_type_name) + setattr(self.prelude, tensor_type_name, tensor_type_var) + tensor0_type = TensorType([], self.dtype) + tensor1_type = TensorType([Any()], self.dtype) + tensor2_type = TensorType([Any(), Any()], self.dtype) + tensor3_type = TensorType([Any(), Any(), Any()], self.dtype) + tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype) + tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor_nil_name = self.get_name('tensor_nil') + tensor0_name = self.get_name('tensor0') + tensor1_name = self.get_name('tensor1') + tensor2_name = self.get_name('tensor2') + tensor3_name = self.get_name('tensor3') + tensor4_name = self.get_name('tensor4') + tensor5_name = self.get_name('tensor5') + tensor6_name = self.get_name('tensor6') + tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) + tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var) + tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var) + tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var) + tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var) + tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) + tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) + tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) + setattr(self.prelude, tensor_nil_name, tensor_nil_case) + setattr(self.prelude, tensor0_name, tensor0_case) + setattr(self.prelude, tensor1_name, tensor1_case) + setattr(self.prelude, tensor2_name, tensor2_case) + setattr(self.prelude, tensor3_name, tensor3_case) + setattr(self.prelude, tensor4_name, tensor4_case) + setattr(self.prelude, tensor5_name, tensor5_case) + setattr(self.prelude, tensor6_name, tensor6_case) + self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case, + tensor0_case, + tensor1_case, + tensor2_case, +
tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case]) + + def define_tensor_take(self): + """Defines a function to return a range of tensor_t on axis 0. + tensor_take(t, lower, upper) : + tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t + """ + take_name = self.get_name("tensor_take") + take_var = GlobalVar(take_name) + setattr(self.prelude, take_name, take_var) + tensor_t = self.get_var('tensor_t') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + t = Var('tensor', tensor_t()) + lower = Var('lower', scalar_type('int32')) + upper = Var('upper', scalar_type('int32')) + t1 = Var('t1') + t2 = Var('t2') + t3 = Var('t3') + t4 = Var('t4') + t5 = Var('t5') + t6 = Var('t6') + tensor1_case =\ + Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32')))) + tensor2_case =\ + Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor3_case =\ + Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor4_case =\ + Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor5_case =\ + Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor6_case =\ + Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]), + tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0))) + self.prelude.mod[take_var] =\ + Function([t, lower, upper], + Match(t, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case], False), + tensor_t(), []) + + def define_tensor_expand_dims(self): + """Defines a function to grow a tensor_t's rank by adding one dimension in front + of the original tensor_t. 
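In numpy terms, tensor_take behaves like a slice along axis 0 for every supported rank (1 through 6); a small analogue, illustrative only:

import numpy as np

def tensor_take(t, lower, upper):
    # mirrors op.take(t, arange(lower, upper, dtype='int32'), axis=0)
    return np.take(t, np.arange(lower, upper), axis=0)

t = np.arange(12).reshape(4, 3)
assert tensor_take(t, 1, 3).shape == (2, 3)   # rows 1 and 2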
+ tensor_expand_dims(t) : tensor_t -> tensor_t + """ + expand_dims_name = self.get_name("tensor_expand_dims") + expand_dims_var = GlobalVar(expand_dims_name) + setattr(self.prelude, expand_dims_name, expand_dims_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + t0 = Var("t0") + t1 = Var("t1") + t2 = Var("t2") + t3 = Var("t3") + t4 = Var("t4") + t5 = Var("t5") + tensor0_var = self.get_var('tensor0') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]), + tensor1_var(op.expand_dims(t0, 0, 1))) + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor2_var(op.expand_dims(t1, 0, 1))) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor3_var(op.expand_dims(t2, 0, 1))) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor4_var(op.expand_dims(t3, 0, 1))) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor5_var(op.expand_dims(t4, 0, 1))) + tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor6_var(op.expand_dims(t5, 0, 1))) + self.prelude.mod[expand_dims_var] =\ + Function([x], + Match(x, [tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case], False)) + + def define_tensor_concat(self): + """Defines a function to concatenate two tensor_t on the first axis + + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + concat_name = self.get_name("tensor_concatenate") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + y = Var("y", tensor_type_var()) + + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + t11 = Var("t11") + t12 = Var("t12") + t21 = Var("t21") + t22 = Var("t22") + t31 = Var("t31") + t32 = Var("t32") + t41 = Var("t41") + t42 = Var("t42") + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]), + Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]), + tensor1_var(op.concatenate([t11, t12], axis=0)))], + False)) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]), + Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]), + tensor2_var(op.concatenate([t21, t22], axis=0)))], + False)) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]), + Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]), + tensor3_var(op.concatenate([t31, t32], axis=0)))], + False)) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]), + Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]), + tensor4_var(op.concatenate([t41, t42], axis=0)))], + False)) + # op.concatenate does not support tensor with rank higher than 4 + self.prelude.mod[concat_var] =\ + Function([x, y], Match(x, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case], False)) + + def define_tensor_array(self): + """Defines a function to create a tensor array with size n. 
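tensor_expand_dims and tensor_concatenate compose into a stack, which is exactly how tensor_array_stack is assembled later in this file (map the expand, then fold the concat); in numpy terms, illustratively:

import numpy as np

a, b = np.ones((2, 3)), np.zeros((2, 3))
expanded = [t[np.newaxis, ...] for t in (a, b)]   # tensor_expand_dims on each element
stacked = np.concatenate(expanded, axis=0)        # tensor_concatenate, folded pairwise
assert stacked.shape == (2, 2, 3)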
+ tensor_array(n) : Tensor[(), int32] -> list[tensor_t] + """ + tensor_array_constructor_name = self.get_name("tensor_array") + tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) + setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) + tensor_nil_var = self.get_var('tensor_nil') + tensor_type_var = self.get_var('tensor_t') + n = Var("x", scalar_type('int32')) + body = If(equal(n, const(0)), + self.prelude.nil(), + self.prelude.cons(tensor_nil_var(), + tensor_array_constructor_var(subtract(n, const(1))))) + self.prelude.mod[tensor_array_constructor_var] = \ + Function([n], body, self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_read(self): + """Defines a function to read the n-th element of a tensor array. Assumes the + array holds more than n elements. + + tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t + """ + read_name = self.get_name("tensor_array_read") + read_var = GlobalVar(read_name) + setattr(self.prelude, read_name, read_var) + tensor_type_var = self.get_var('tensor_t') + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + self.prelude.mod[read_var] =\ + Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []) + + def define_tensor_array_write(self): + """Defines a function to update a tensor array at index n with value v. + tensor_array_write(ta, n, v) : + list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t] + """ + write_name = self.get_name("tensor_array_write") + write_var = GlobalVar(write_name) + setattr(self.prelude, write_name, write_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + v = Var("v", tensor_type_var()) + self.prelude.mod[write_var] =\ + Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor1(self): + """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array. + tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor1_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + tensor_type_var = self.get_var('tensor_t') + tensor0_var = self.get_var('tensor0') + helper_body =\ + If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(tensor0_var(op.take(tensor, i)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []) + unstack_name = self.get_name("tensor_array_unstack_tensor1") + unstack_var = GlobalVar(unstack_name) + setattr(self.prelude, unstack_name, unstack_var) + tensor1 = Var("tensor", TensorType([Any()], self.dtype)) + shape = op.shape_of(tensor1) + ndim = op.take(shape, const(0)) + self.prelude.mod[unstack_var] =\ + Function([tensor1], helper_var(const(0), ndim, tensor1), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor2(self): + """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array.
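Viewed as ordinary Python (an analogue of the ADT definitions above, not the real code), the tensor array is a persistent list and the helpers defined so far behave like:

def tensor_array(n):
    return [None] * n                 # n tensor_nil placeholders

def tensor_array_read(ta, n):
    return ta[n]                      # prelude.nth

def tensor_array_write(ta, n, v):
    return ta[:n] + [v] + ta[n + 1:]  # functional update; ta itself is unchanged

def tensor_array_unstack_tensor1(t):
    return list(t)                    # rank-1 tensor -> list of rank-0 tensors

ta = tensor_array_write(tensor_array(3), 1, 42)
assert tensor_array_read(ta, 1) == 42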
+ + tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor2_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") + tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name) + setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var) + tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype)) + shape = op.shape_of(tensor2) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor2_var] =\ + Function([tensor2], helper_var(const(0), ndim, tensor2), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_scatter(self): + """Defines a function to scatter the values of a tensor_t in indices of a tensor array. + tensor_array_scatter(ta, indices, value) : + list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] + """ + tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) + tensor_t = self.get_var('tensor_t') + ta = Var("ta", self.prelude.l(tensor_t())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + values_ = Var('values_', self.prelude.l(tensor_t())) + write_var = self.get_var('tensor_array_write') + read_var = self.get_var('tensor_array_read') + helper_body = If(equal(current, limit), + ta, + tensor_array_scatter_helper_var( + write_var(ta, op.take(indices_, current), + read_var(values_, current)), + add(current, const(1)), + limit, indices_, values_)) + self.prelude.mod[tensor_array_scatter_helper_var] =\ + Function([ta, current, limit, indices_, values_], + helper_body, self.prelude.l(tensor_t()), []) + tensor_array_scatter_name = self.get_name("tensor_array_scatter") + tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) + setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + indices = Var('indices', TensorType([Any()], 'int32')) + values = Var('values', self.prelude.l(tensor_t())) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) + self.prelude.mod[tensor_array_scatter_var] =\ + Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_split(self): + """Defines a function to split the values of a tensor_t into a tensor array. 
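tensor_array_scatter is a recursive chain of tensor_array_write calls, one per index; as a plain list analogue (illustrative):

def tensor_array_scatter(ta, indices, values):
    out = list(ta)                     # stand-in for the persistent update chain
    for i, idx in enumerate(indices):  # write values[i] at position indices[i]
        out[idx] = values[i]
    return out

assert tensor_array_scatter([None] * 4, [2, 0], ['a', 'b']) == ['b', None, 'a', None]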
+ tensor_array_split(ta, value, lengths) : + list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] + """ + tensor_t = self.get_var('tensor_t') + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) + setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + ta1 = Var("tensor_array", self.prelude.l(tensor_t())) + value1 = Var('value1', tensor_t()) + offset1 = Var('offset1', scalar_type('int32')) + current1 = Var('current1', scalar_type('int32')) + limit1 = Var('limit1', scalar_type('int32')) + lengths1 = Var('lengths', TensorType([Any()], 'int32')) + write_var = self.get_var('tensor_array_write') + take_var = self.get_var('tensor_take') + helper1_body = If(equal(current1, limit1), + ta1, + write_var( + tensor_array_split_helper_var( + ta1, + value1, + add(offset1, op.take(lengths1, current1)), + add(current1, const(1)), + limit1, + lengths1 + ), + current1, + take_var(value1, + offset1, + add(op.take(lengths1, current1), offset1)))) + self.prelude.mod[tensor_array_split_helper_var] = \ + Function([ta1, value1, offset1, current1, limit1, lengths1], + helper1_body, self.prelude.l(tensor_t()), []) + split_name = self.get_name("tensor_array_split") + split_var = GlobalVar(split_name) + setattr(self.prelude, split_name, split_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + value = Var('value', tensor_t()) + lengths = Var('lengths', TensorType([Any()], 'int32')) + lengths_shape = op.shape_of(lengths) + lengths_limit = op.take(lengths_shape, const(0)) + body = tensor_array_split_helper_var( + tensor_array, + value, + const(0), + const(0), + lengths_limit, + lengths) + self.prelude.mod[split_var] =\ + Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_concat(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + concat_name = self.get_name("tensor_array_concat") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_concat_var = self.get_var('tensor_concatenate') + tensor_t = self.get_var('tensor_t') + tensor_nil_var = self.get_var('tensor_nil') + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] =\ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []) + + def define_tensor_array_gather(self): + """Defines a function to return the selected values in a tensor array as tensor_t. 
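tensor_array_split and tensor_array_concat are intended to be near-inverses along axis 0; a numpy sketch of that contract (illustrative, ignoring the write positions the real helper uses inside the target array):

import numpy as np

def tensor_array_split(value, lengths):
    offsets = np.cumsum([0] + list(lengths))   # slice boundaries along axis 0
    return [value[offsets[i]:offsets[i + 1]] for i in range(len(lengths))]

def tensor_array_concat(ta):
    return np.concatenate(ta, axis=0)

v = np.arange(10)
assert np.array_equal(tensor_array_concat(tensor_array_split(v, [3, 7])), v)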
+ tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t + """ + helper_name = self.get_name("tensor_array_gather_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor_type_var = self.get_var('tensor_t') + stack_var = self.get_var('tensor_array_stack') + read_var = self.get_var('tensor_array_read') + ta = Var("ta", self.prelude.l(tensor_type_var())) + accu = Var("accu", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + helper_body =\ + If(equal(current, const(0)), + stack_var(accu), + helper_var( + ta, + self.prelude.cons( + read_var( + ta, op.take(indices_, subtract(current, const(1)))), accu), + subtract(current, const(1)), + limit, indices_)) + self.prelude.mod[helper_var] = \ + Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []) + gather_name = self.get_name("tensor_array_gather") + gather_var = GlobalVar(gather_name) + setattr(self.prelude, gather_name, gather_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + indices = Var('indices', TensorType([Any()], 'int32')) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + self.prelude.mod[gather_var] =\ + Function([tensor_array, indices], body, tensor_type_var(), []) + + def define_tensor_array_stack(self): + """Defines a function to get the values in the tensor array as a stack tensor_t. + tensor_array_stack(l) : list[tensor_t] -> tensor_t + """ + stack_name = self.get_name("tensor_array_stack") + stack_var = GlobalVar(stack_name) + setattr(self.prelude, stack_name, stack_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + expand_dims_var = self.get_var('tensor_expand_dims') + concat_var = self.get_var('tensor_concatenate') + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) + tensors = self.prelude.foldl(concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims)) + self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) + + def register(self): + """Register all tensor array ops in Prelude""" + self.define_tensor_adt() + self.define_tensor_take() + self.define_tensor_expand_dims() + self.define_tensor_concat() + self.define_tensor_array() + self.define_tensor_array_read() + self.define_tensor_array_write() + self.define_tensor_array_unstack_tensor1() + self.define_tensor_array_unstack_tensor2() + self.define_tensor_array_scatter() + self.define_tensor_array_split() + self.define_tensor_array_concat() + self.define_tensor_array_stack() + # TODO(wweic): Gather fails in PartialEvaluate + # self.define_tensor_array_gather() + class Prelude: """Contains standard definitions.""" @@ -27,6 +532,17 @@ def __init__(self, mod=None): self.mod = mod self.load_prelude() + def get_name(self, canonical, dtype): + """Get name corresponding to the canonical name""" + if canonical == 'tensor_t': + return 'tensor_{}_t'.format(dtype) + return "{}_{}".format(canonical, dtype) + + def get_var(self, canonical, dtype): + """Get var corresponding to the canonical name""" + name = self.get_name(canonical, dtype) + return getattr(self, name) + def load_prelude(self): """Parses the Prelude from Relay's text format into 
a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude @@ -74,3 +590,7 @@ def load_prelude(self): ] for global_def in GLOBAL_DEFS: setattr(self, global_def, self.mod.get_global_var(global_def)) + + for dtype in ['float32', 'int32']: + tensor_array_ops = TensorArrayOps(self, dtype) + tensor_array_ops.register() diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 0fdc0f3a3231..6b2e073822f1 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -100,7 +100,7 @@ def _is_int8_hw_support(target): Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake and above. """ - supported_arches = {'-mcpu=skylake-avx512',} + supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} return supported_arches.intersection(set(target.options)) # Collect the dtypes. diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index ed443abb5293..7faf62b4be14 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -27,7 +27,7 @@ def requantize(data, input_zero_point, output_scale, output_zero_point, - rounding="TONEAREST", + rounding="UPWARD", out_dtype="int8"): r"""Requantized operator. @@ -349,3 +349,45 @@ def dense(data, input_zero_point, kernel_zero_point, out_dtype) + + +def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, + output_scale, output_zero_point): + """Quantized multiplication with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side quantized input data. + + rhs : relay.Expr + The right hand side quantized input data. + + lhs_scale: float + The scale of the lhs quantized expr. + + lhs_zero_point: int + The zero point of lhs quantized expr. + + rhs_scale: float + The scale of the rhs quantized expr. + + rhs_zero_point: int + The zero point of rhs quantized expr. + + output_scale: float + The scale of the output quantized expr. + + output_zero_point: int + The zero point of output quantized expr. + + Returns + ------- + result : relay.Expr + The computed result. 
+ + """ + return _make.mul(lhs, rhs, + lhs_scale, lhs_zero_point, + rhs_scale, rhs_zero_point, + output_scale, output_zero_point) diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d661be73ad02..d7b59922b89d 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -203,8 +203,12 @@ def convert_module(self): for var, func in self.mod.functions.items(): # optimize the definition so any operators used are lowered opt_func = self.optimize(func) - converted_func, _ = self.convert_func_node(opt_func, var) - defs.append(converted_func) + try: + converted_func, _ = self.convert_func_node(opt_func, var) + defs.append(converted_func) + except TypeError: + # TODO(wweic): fix conversion for Any + pass return defs diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index f31625cd34ed..b9b29a7fe4a1 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -230,7 +230,7 @@ def call_handler(self, args): port, matchkey = args[2] self.pending_matchkeys.add(matchkey) # got custom address (from rpc server) - if args[3] is not None: + if len(args) >= 4 and args[3] is not None: value = (self, args[3], port, matchkey) else: value = (self, self._addr[0], port, matchkey) diff --git a/python/tvm/target.py b/python/tvm/target.py index 4548ffac4c88..087c9b47fd7a 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -128,6 +128,16 @@ def model(self): return opt.value[7:] return 'unknown' + @property + def mcpu(self): + """Returns the mcpu from the target if it exists.""" + mcpu = '' + if self.options is not None: + for opt in self.options: + if 'mcpu' in opt: + mcpu = opt.split('=')[1] + return mcpu + def __enter__(self): _api_internal._EnterTargetScope(self) return self @@ -496,6 +506,19 @@ def vta(model='unknown', options=None): return ret +def bifrost(model='unknown', options=None): + """Return an ARM Mali GPU target (Bifrost architecture). + + Parameters + ---------- + options : str or list of str + Additional options + """ + opts = ["-device=bifrost", '-model=%s' % model] + opts = _merge_opts(opts, options) + return _api_internal._TargetCreate("opencl", *opts) + + def create(target_str): """Get a target given target string. diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index d9399492264b..848d5c00ab3f 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -71,7 +71,7 @@ macro_rules! TVMPODValue { Context(TVMContext), Handle(*mut c_void), ArrayHandle(TVMArrayHandle), - NodeHandle(*mut c_void), + ObjectHandle(*mut c_void), ModuleHandle(TVMModuleHandle), FuncHandle(TVMFunctionHandle), NDArrayContainer(*mut c_void), @@ -92,7 +92,7 @@ macro_rules! TVMPODValue { TVMTypeCode_kTVMContext => Context($value.v_ctx), TVMTypeCode_kHandle => Handle($value.v_handle), TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle), + TVMTypeCode_kObjectHandle => ObjectHandle($value.v_handle), TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), @@ -124,7 +124,7 @@ macro_rules! 
TVMPODValue { TVMTypeCode_kArrayHandle, ) }, - NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kObjectHandle), ModuleHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), FuncHandle(val) => ( diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 948711276304..01d0c58cfc5d 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -264,7 +264,7 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kNodeHandle as c_int + if tcode == ffi::TVMTypeCode_kObjectHandle as c_int || tcode == ffi::TVMTypeCode_kFuncHandle as c_int || tcode == ffi::TVMTypeCode_kModuleHandle as c_int { diff --git a/src/README.md b/src/README.md index 0c6f30a881b8..599f41dfdc5f 100644 --- a/src/README.md +++ b/src/README.md @@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src. ## Modules - common: Internal common utilities. +- runtime: Minimum runtime related codes. +- node: base infra for IR/AST nodes that is dialect independent. - api: API function registration. - lang: The definition of DSL related data structure. - arithmetic: Arithmetic expression and set simplification. @@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src. - schedule: The operations on the schedule graph before converting to IR. - pass: The optimization pass on the IR structure. - codegen: The code generator. -- runtime: Minimum runtime related codes. - autotvm: The auto-tuning module. - relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks. - contrib: Contrib extension libraries. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index f31f02b1eaf4..c57e2afaa8eb 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -117,8 +117,7 @@ TVM_REGISTER_API("arith._CreateAnalyzer") }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - auto& sptr = args[1].node_sptr(); - if (sptr->is_type()) { + if (args[1].IsObjectRef()) { self->Bind(args[0], args[1].operator Range()); } else { self->Bind(args[0], args[1].operator Expr()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 28ebb4d65005..42367efb15bb 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,11 +26,12 @@ #include #include #include +#include namespace tvm { TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; os << args[0].operator NodeRef(); *ret = os.str(); @@ -38,16 +39,15 @@ TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_raw_ptr") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); - *ret = reinterpret_cast( - args[0].node_sptr().get()); + CHECK(args[0].type_code() == kObjectHandle); + *ret = reinterpret_cast(args[0].value().v_handle); }); TVM_REGISTER_API("_save_json") -.set_body_typed(SaveJSON); +.set_body_typed(SaveJSON); TVM_REGISTER_API("_load_json") -.set_body_typed(LoadJSON); +.set_body_typed(LoadJSON); TVM_REGISTER_API("_TVMSetStream") .set_body_typed(TVMSetStream); diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 73e26719cf15..f2ca67e6e2f9 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ namespace codegen { TVM_REGISTER_API("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); } else { *ret = Build(args[0], args[1]); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index b8ee1441fe12..9312c5532302 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -18,7 +18,6 @@ */ /*! 
- * Copyright (c) 2016 by Contributors * Implementation of API functions related to IR build * \file api_ir.cc */ diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index aa0ce47b4a37..f3d6c5f6ab62 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -57,25 +57,26 @@ TVM_REGISTER_API("_str") TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector > data; + std::vector data; for (int i = 0; i < args.size(); ++i) { if (args[i].type_code() != kNull) { - data.push_back(args[i].node_sptr()); + data.push_back(args[i].operator ObjectRef()); } else { - data.push_back(NodePtr(nullptr)); + data.push_back(ObjectRef(nullptr)); } } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = runtime::ObjectRef(node); }); TVM_REGISTER_API("_ArrayGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { int64_t i = args[1]; - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; *ret = n->data[static_cast(i)]; @@ -83,10 +84,11 @@ TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArraySize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); *ret = static_cast( - static_cast(sptr.get())->data.size()); + static_cast(ptr)->data.size()); }); TVM_REGISTER_API("_Map") @@ -98,10 +100,10 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].node_sptr())); + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -110,12 +112,12 @@ TVM_REGISTER_API("_Map") // Container node. 
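These container entry points back tvm.convert on the Python side, and the NodeHandle-to-ObjectHandle migration is meant to leave that behavior untouched; a quick check one might run against a built tree (a sketch, assuming a standard build of this branch):

import tvm

arr = tvm.convert([1, 2, 3])             # exercises _Array / _ArrayGetItem / _ArraySize
assert len(arr) == 3 and arr[2].value == 3

m = tvm.convert({"k": tvm.convert(7)})   # string-keyed map, the StrMapNode path
assert m["k"].value == 7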
MapNode::ContainerType data; for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kNodeHandle) + CHECK(args[i].type_code() == kObjectHandle) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].node_sptr(), - args[i + 1].node_sptr())); + data.emplace(std::make_pair(args[i].operator ObjectRef(), + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -125,31 +127,33 @@ TVM_REGISTER_API("_Map") TVM_REGISTER_API("_MapSize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } }); TVM_REGISTER_API("_MapGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - CHECK(args[1].type_code() == kNodeHandle); - auto* n = static_cast(sptr.get()); - auto it = n->data.find(args[1].node_sptr()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + CHECK(args[1].type_code() == kObjectHandle); + auto* n = static_cast(ptr); + auto it = n->data.find(args[1].operator ObjectRef()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; *ret = (*it).second; } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); auto it = n->data.find(args[1].operator std::string()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; @@ -159,16 +163,17 @@ TVM_REGISTER_API("_MapGetItem") TVM_REGISTER_API("_MapCount") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); - CHECK(args[1].type_code() == kNodeHandle); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + CHECK_EQ(args[0].type_code(), kObjectHandle); *ret = static_cast( - n->data.count(args[1].node_sptr())); + n->data.count(args[1].operator ObjectRef())); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast( n->data.count(args[1].operator std::string())); } @@ -176,9 +181,11 @@ TVM_REGISTER_API("_MapCount") TVM_REGISTER_API("_MapItems") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); auto rkvs = make_node(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); @@ -186,10 +193,10 @@ TVM_REGISTER_API("_MapItems") } *ret = rkvs; } else { - auto* n = static_cast(sptr.get()); + auto* n = static_cast(ptr); 
auto rkvs = make_node(); for (const auto& kv : n->data) { - rkvs->data.push_back(ir::StringImm::make(kv.first).node_); + rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); } *ret = rkvs; @@ -426,7 +433,7 @@ TVM_REGISTER_API("_ScheduleCacheRead") TVM_REGISTER_API("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsNodeType()) { + if (args[1].IsObjectRef()) { *ret = args[0].operator Schedule() .cache_write(args[1].operator Tensor(), args[2]); } else { diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 25cd5838385f..d7f621f3ade1 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -35,7 +35,7 @@ namespace ir { TVM_REGISTER_API("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = Simplify(args[0].operator Stmt(), args[1]); } else { @@ -52,7 +52,7 @@ TVM_REGISTER_API("ir_pass.Simplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); } else { @@ -69,7 +69,7 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify") TVM_REGISTER_API("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); } else { *ret = Substitute(args[0].operator Expr(), args[1].operator Map()); @@ -78,7 +78,7 @@ TVM_REGISTER_API("ir_pass.Substitute") TVM_REGISTER_API("ir_pass.Equal") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); @@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") }); }); +TVM_REGISTER_API("ir_pass.LowerStorageAccess") +.set_body([](TVMArgs args, TVMRetValue *ret) { + LoweredFunc f = args[0]; + auto n = make_node(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + *ret = LoweredFunc(n); +}); + // make from two arguments #define REGISTER_PASS(PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \ @@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); +REGISTER_PASS(LowerDeviceStorageAccessInfo) REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectDoubleBuffer); @@ -160,5 +169,7 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); +REGISTER_PASS(HoistIfThenElse); +REGISTER_PASS(InferFragment) } // namespace ir } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 177360bf2ebb..cf0e0f3c6b7a 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! 
/*! - * Copyright (c) 2017 by Contributors * Implementation of API functions related to schedule pass. * \file api_schedule.cc */
diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc
deleted file mode 100644
index 89e999f73edb..000000000000
--- a/src/api/dsl_api.cc
+++ /dev/null
@@ -1,230 +0,0 @@
-/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2016 by Contributors - * Implementation of DSL API - * \file dsl_api.cc - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include "../runtime/dsl_api.h" - -namespace tvm { -namespace runtime { -/*! \brief entry to to easily hold returning information */ -struct TVMAPIThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector<std::string> ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector<const char*> ret_vec_charp; - /*! \brief result holder for retruning string */ - std::string ret_str; -}; - -/*! \brief Thread local store that can be used to hold return values.
*/ -typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore; - -using TVMAPINode = NodePtr<Node>; - -struct APIAttrGetter : public AttrVisitor { - std::string skey; - TVMRetValue* ret; - bool found_ref_object{false}; - - void Visit(const char* key, double* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, int64_t* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, uint64_t* value) final { - CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) - << "cannot return too big constant"; - if (skey == key) *ret = static_cast<int64_t>(value[0]); - } - void Visit(const char* key, int* value) final { - if (skey == key) *ret = static_cast<int64_t>(value[0]); - } - void Visit(const char* key, bool* value) final { - if (skey == key) *ret = static_cast<int64_t>(value[0]); - } - void Visit(const char* key, void** value) final { - if (skey == key) *ret = static_cast<void*>(value[0]); - } - void Visit(const char* key, Type* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, std::string* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, NodeRef* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, runtime::NDArray* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, runtime::ObjectRef* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } -}; - -struct APIAttrDir : public AttrVisitor { - std::vector<std::string>* names; - - void Visit(const char* key, double* value) final { - names->push_back(key); - } - void Visit(const char* key, int64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, uint64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, bool* value) final { - names->push_back(key); - } - void Visit(const char* key, int* value) final { - names->push_back(key); - } - void Visit(const char* key, void** value) final { - names->push_back(key); - } - void Visit(const char* key, Type* value) final { - names->push_back(key); - } - void Visit(const char* key, std::string* value) final { - names->push_back(key); - } - void Visit(const char* key, NodeRef* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::NDArray* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::ObjectRef* value) final { - names->push_back(key); - } -}; - -class DSLAPIImpl : public DSLAPI { - public: - void NodeFree(NodeHandle handle) const final { - delete static_cast<TVMAPINode*>(handle); - } - void NodeTypeKey2Index(const char* type_key, - int* out_index) const final { - *out_index = static_cast<int>(Node::TypeKey2Index(type_key)); - } - void NodeGetTypeIndex(NodeHandle handle, - int* out_index) const final { - *out_index = static_cast<int>( - (*static_cast<TVMAPINode*>(handle))->type_index()); - } - void NodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* ret_val, - int* ret_type_code, - int* ret_success) const final { - TVMRetValue rv; - APIAttrGetter getter; - TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); - getter.skey = key; - getter.ret = &rv; - if (getter.skey == "type_key") { - ret_val->v_str = (*tnode)->type_key(); - *ret_type_code = kStr; - *ret_success = 1; - return; - } else if (!(*tnode)->is_type<DictAttrsNode>()) { - (*tnode)->VisitAttrs(&getter); - *ret_success = getter.found_ref_object || rv.type_code() != kNull; - } else { - // specially handle dict attr - DictAttrsNode* dnode =
static_cast<DictAttrsNode*>(tnode->get()); - auto it = dnode->dict.find(key); - if (it != dnode->dict.end()) { - *ret_success = 1; - rv = (*it).second; - } else { - *ret_success = 0; - } - } - if (*ret_success) { - if (rv.type_code() == kStr || - rv.type_code() == kTVMType) { - TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); - e->ret_str = rv.operator std::string(); - *ret_type_code = kStr; - ret_val->v_str = e->ret_str.c_str(); - } else { - rv.MoveToCHost(ret_val, ret_type_code); - } - } - } - void NodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array) const final { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - ret->ret_vec_str.clear(); - TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); - APIAttrDir dir; - dir.names = &(ret->ret_vec_str); - - if (!(*tnode)->is_type<DictAttrsNode>()) { - (*tnode)->VisitAttrs(&dir); - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get()); - for (const auto& kv : dnode->dict) { - ret->ret_vec_str.push_back(kv.first); - } - } - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast<int>(ret->ret_vec_str.size()); - } -}; - -TVM_REGISTER_GLOBAL("dsl_api.singleton") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static DSLAPIImpl impl; - void* ptr = &impl; - *rv = ptr; - }); -} // namespace runtime -} // namespace tvm
diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc
index acd964935c25..98e25742592d 100644
--- a/src/arithmetic/analyzer.cc
+++ b/src/arithmetic/analyzer.cc
@@ -36,9 +36,7 @@ Analyzer::Analyzer()
int_set(this) { } -void Analyzer::Bind(const VarExpr& v, const Expr& expr) { - Var var(v.node_); - +void Analyzer::Bind(const VarExpr& var, const Expr& expr) { Expr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr);
@@ -49,9 +47,8 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
this->canonical_simplify.Update(var, new_expr); } -void Analyzer::Bind(const VarExpr& v, const Range& range) { +void Analyzer::Bind(const VarExpr& var, const Range& range) { CHECK(range.defined()); - Var var(v.node_); if (is_one(range->extent)) { this->Bind(var, range->min); } else {
diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc
index 6f7b4d78da05..9c3a706e2ad0 100644
--- a/src/arithmetic/bound_deducer.cc
+++ b/src/arithmetic/bound_deducer.cc
@@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor {
if (!found_) path_.pop_back(); } - std::vector<const Node*> path_; + std::vector<const Object*> path_; private: bool found_{false}; Expr target_; - std::unordered_set<const Node*> visited_; + std::unordered_set<const Object*> visited_; }; // get the path to the variable, // return empty vector to represent failure -std::vector<const Node*> GetPath(Expr target, Expr expr) { +std::vector<const Object*> GetPath(Expr target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); return v.path_;
@@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor {
const std::unordered_map<const Variable*, IntSet>& hint_map_; const std::unordered_map<const Variable*, IntSet>& relax_map_; ExprIntSetMap expr_map_; - std::vector<const Node*> path_; + std::vector<const Object*> path_; size_t iter_{0}; // internal analzyer Analyzer analyzer_;
diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc
index d80e4969d5c2..1b576a645824 100644
--- a/src/arithmetic/canonical_simplify.cc
+++ b/src/arithmetic/canonical_simplify.cc
@@ -43,6 +43,7 @@ class SplitExpr;
*/ class CanonicalExprNode :
public BaseExprNode { public: + virtual ~CanonicalExprNode() {} /*! * \brief Return the normal Expr that is equivalent to self. * \note Can mutate the internal data structure.
@@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode {
virtual Expr Normalize() const = 0; // overrides - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { } static constexpr const char* _type_key = "arith.CanonicalExpr";
@@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \return Normalized expr. */ Expr Normalize(Expr expr) { - if (const auto* op = expr.as_derived<CanonicalExprNode>()) { + if (const auto* op = expr.as<CanonicalExprNode>()) { return op->Normalize(); } else { return expr;
@@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SumExprNode>()) { if (op->base == 0 && op->args.size() == 1) return op->args[0]; } - if (const auto* op = expr.as_derived<CanonicalExprNode>()) { + if (const auto* op = expr.as<CanonicalExprNode>()) { expr = op->Normalize(); } NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
@@ -629,7 +630,7 @@ Mutate_(const Mul* op, const Expr& self) {
} if (const auto* bconst = b.as<IntImm>()) { if (a.as<SumExprNode>()) { - SumExpr ret(std::move(a.node_)); + SumExpr ret = Downcast<SumExpr>(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); return std::move(ret); } else {
@@ -931,7 +932,7 @@ Mutate_(const Mod* op, const Expr& self) {
int64_t new_base = psum->base % cval; if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast<SumExpr>(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); }
@@ -992,7 +993,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
// Simplify the offset constant if necessary.
// floormod(x - 5, 3) => floormod(x + 1, 3) int64_t new_base = floormod(psum->base, cval); - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast<SumExpr>(std::move(a)); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); } else {
diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h
index 57f90534fbb4..86f1927f2abe 100644
--- a/src/arithmetic/const_fold.h
+++ b/src/arithmetic/const_fold.h
@@ -28,6 +28,7 @@
#include #include #include +#include #include "int_operator.h" namespace tvm {
diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc
index d5c012d302dc..168486ee0018 100644
--- a/src/arithmetic/const_int_bound.cc
+++ b/src/arithmetic/const_int_bound.cc
@@ -39,7 +39,7 @@ ConstIntBound::ConstIntBound(
auto node = make_node<ConstIntBoundNode>(); node->min_value = min_value; node->max_value = max_value; - node_ = std::move(node); + data_ = std::move(node); } inline void PrintBoundValue(std::ostream& os, int64_t val) {
diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc
index 3c5f12a7379e..7da020efc42a 100644
--- a/src/arithmetic/detect_linear_equation.cc
+++ b/src/arithmetic/detect_linear_equation.cc
@@ -176,7 +176,7 @@ bool DetectClipBound(
if (const Variable* v = n.as<Variable>()) { if (bmap->count(v)) { if (flag == 0) { - var = Var(n.node_); + var = Downcast<Var>(n); flag = 1; } else if (flag == 1) { if (!var.same_as(n)) {
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 0e24714daf1f..409477578758 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -40,7 +40,7 @@ IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
auto node = make_node<IntervalSetNode>(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); - node_ = std::move(node); + data_ = std::move(node); } IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
@@ -506,7 +506,7 @@ class IntervalSetEvaluator :
} IntervalSet VisitExprDefault_(const Node* op) final { - DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); }
@@ -807,6 +807,8 @@ IntSet EvalSet(Range r,
return EvalSet(r, ConvertDomMap(dom_map)); } +TVM_REGISTER_NODE_TYPE(IntervalSetNode); + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) { p->stream << "IntervalSet"
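The FloorMod branch above rebases the additive constant modulo the divisor, per the comment `floormod(x - 5, 3) => floormod(x + 1, 3)`. A quick numeric check of that identity, with `floormod` defined by floor-division semantics as TVM uses them:

```cpp
// Verifies floormod(x + b, c) == floormod(x + floormod(b, c), c) over a range,
// e.g. floormod(x - 5, 3) == floormod(x + 1, 3).
#include <cassert>

long long floormod(long long a, long long b) {
  long long r = a % b;                                // C++ % truncates toward zero
  return (r != 0 && (r < 0) != (b < 0)) ? r + b : r;  // shift into [0, b) for b > 0
}

int main() {
  for (long long x = -10; x <= 10; ++x) {
    assert(floormod(x - 5, 3) == floormod(x + 1, 3));
  }
  return 0;
}
```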
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
index 306361868759..831b44409030 100644
--- a/src/arithmetic/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode {
Expr max_value; // visitor overload. - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("min_value", &min_value); v->Visit("max_value", &max_value); }
diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc
index 04e166ae52c0..cda9d585ace1 100644
--- a/src/arithmetic/ir_mutator_with_analyzer.cc
+++ b/src/arithmetic/ir_mutator_with_analyzer.cc
@@ -87,7 +87,7 @@ Stmt IRMutatorWithAnalyzer::
Mutate_(const AttrStmt* op, const Stmt& s) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast<IterVar>(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h
index 71eea50e4c72..918f2e89501f 100644
--- a/src/arithmetic/ir_visitor_with_analyzer.h
+++ b/src/arithmetic/ir_visitor_with_analyzer.h
@@ -47,7 +47,7 @@ class IRVisitorWithAnalyzer final : public IRVisitor {
void Visit_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast<IterVar>(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value));
diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc
index 08454dd0ef5a..9e363e7cf99a 100644
--- a/src/arithmetic/modular_set.cc
+++ b/src/arithmetic/modular_set.cc
@@ -41,7 +41,7 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
node->coeff = coeff; node->base = base; // finish construction. - node_ = std::move(node); + data_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 3f1c32243a23..cfcb0607858f 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -34,6 +34,7 @@ namespace tvm {
TVM_REGISTER_NODE_TYPE(TargetNode); +TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<TargetNode>([](const TargetNode *op, IRPrinter *p) {
@@ -51,9 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
*/ Target CreateTarget(const std::string& target_name, const std::vector<std::string>& options) { - auto target = Target(make_node<TargetNode>()); - auto t = static_cast<TargetNode*>(target.node_.get()); - + auto t = make_node<TargetNode>(); t->target_name = target_name; std::string libs_flag = "-libs=";
@@ -137,7 +136,7 @@ Target CreateTarget(const std::string& target_name,
return target::stackvm(); } - return target; + return Target(t); } TVM_REGISTER_API("_TargetCreate")
@@ -423,7 +422,6 @@ Stmt BuildStmt(Schedule sch,
// Phase 2 stmt = ir::Simplify(stmt); - stmt = ir::LowerStorageAccessInfo(stmt); stmt = ir::RemoveNoOp(stmt); if (!(config->disable_select_rewriting))
@@ -518,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::BindDeviceType(func, target->device_type); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::LowerTVMBuiltin(func); fhost.Set(i, func); }
@@ -525,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); }
@@ -674,7 +674,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); struct GenericFunc::Manager { - std::unordered_map<std::string, NodePtr<Node> > fmap; +
std::unordered_map<std::string, GenericFunc> fmap; // mutex std::mutex mutex;
@@ -694,10 +694,11 @@ GenericFunc GenericFunc::Get(const std::string& name) {
if (it == m->fmap.end()) { auto f = make_node<GenericFuncNode>(); f->name_ = name; - m->fmap[name] = f; - return GenericFunc(f); + auto gf = GenericFunc(f); + m->fmap[name] = gf; + return gf; } else { - return GenericFunc(it->second); + return it->second; } }
@@ -707,12 +708,12 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name)
auto it = m->fmap.find(name); CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name; func->name_ = name; - m->fmap[name] = func.node_; + m->fmap[name] = func; } GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { - auto node = static_cast<GenericFuncNode*>(node_.get()); + bool allow_override) { + auto node = static_cast<GenericFuncNode*>(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) << "Generic function already registered for " << node->name_;
@@ -736,7 +737,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
} void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { - auto node = static_cast<const GenericFuncNode*>(node_.get()); + auto node = static_cast<const GenericFuncNode*>(get()); auto target = Target::Current(true); PackedFunc func;
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index ecf62ab0cfac..ab203f2aa28a 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -806,7 +806,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
void CodeGenC::VisitStmt_(const AttrStmt* op) { if (op->attr_key == ir::attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast<IterVar>(op->node); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { BindThreadIndex(iv);
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 241310fd00d4..39a3ab7df0cc 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -24,6 +24,7 @@
#include #include #include +#include #include #include #include "codegen_cuda.h"
@@ -74,6 +75,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <math_constants.h>\n"; } + if (need_mma_h_) { + decl_stream << "#include <mma.h>\n"; + } + return CodeGenC::Finish(); }
@@ -102,14 +107,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; + case 16: enable_fp16_ = true; + if (lanes == 1) { + os << "half"; + } else if (lanes <= 8) { + CHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "float" << lanes / 2; + } else { + fail = true; + } break; case 32: os << "float"; break; case 64: os << "double"; break; default: fail = true; break; } - if (!fail && lanes == 1) return; + if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; }
@@ -290,6 +303,113 @@ void CodeGenCUDA::PrintStorageScope(
} } +void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 6U); + os << "nvcuda::wmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], ";
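The wmma intrinsic printers being added in this hunk all reduce to the same streaming idiom: emit the callee name, then interleave printed operand expressions with punctuation. A standalone sketch of that idiom (`PrintCall` is a made-up helper, not a CodeGenC method):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Streams name(arg0, arg1, ...) the way the codegen above does, closing the
// parenthesis after the last operand instead of emitting a trailing comma.
void PrintCall(const std::string& name, const std::vector<std::string>& args,
               std::ostream& os) {
  os << name << "(";
  for (size_t i = 0; i < args.size(); ++i) {
    os << args[i] << (i + 1 < args.size() ? ", " : ")");
  }
}

int main() {
  std::ostringstream os;
  PrintCall("nvcuda::wmma::fill_fragment", {"a[0]", "0.000000e+00f"}, os);
  std::cout << os.str() << "\n";  // nvcuda::wmma::fill_fragment(a[0], 0.000000e+00f)
}
```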
+ this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[6], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::store_matrix_sync("; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[6], os); + if (const StringImm *str = op->args[7].as<StringImm>()) { + os << ", nvcuda::wmma::mem_" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", ": ")"); + } + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { + if (op->attr_key == attr::fragment_shape) { + const Variable* buffer = op->node.as<Variable>(); + const StringImm* shape_str = op->value.as<StringImm>(); + fragment_shapes[buffer] = shape_str->value; + } else if (op->attr_key == attr::fragment_layout) { + const Variable* buffer = op->node.as<Variable>(); + const StringImm* layout_str = op->value.as<StringImm>(); + fragment_layouts[buffer] = layout_str->value; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenCUDA::VisitStmt_(const Allocate* op) { + CHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + std::string new_data = PrintExpr(op->new_expr); + this->PrintIndent(); + PrintType(op->type, stream); + stream << "* "<< vid << '=' << new_data << ";\n"; + } else { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + const Variable* buffer = op->buffer_var.as<Variable>(); + std::string scope = alloc_storage_scope_.at(buffer); + if (scope.find("wmma.") == 0) { + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8)) + << "Matrix_a and matrix_b only support half or char or unsigned char type for now"; + } else { + CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32)) + << "Accumulator only support half, float and int type for now"; + } + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + PrintWmmaScope(scope, op->type, buffer, stream); + } else { + PrintStorageScope(scope, stream); + stream << ' '; + PrintType(op->type, stream); + } + stream << ' '<< vid << '[' + << constant_size << "];\n"; + } + RegisterHandleType(op->buffer_var.get(), op->type); + this->PrintStmt(op->body); +} + void CodeGenCUDA::VisitStmt_(const Evaluate *op) { if (is_const(op->value)) return; const Call* call = op->value.as<Call>();
@@ -392,5 +512,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
PrintConst(op, os, this); } +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, + const Variable* variable, std::ostream &os) { + std::stringstream type; + PrintType(t, type); + std::string shape_str = fragment_shapes[variable]; + if (scope == "wmma.matrix_a") { + need_mma_h_ = true; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">"; + }
else if (scope == "wmma.matrix_b") { + need_mma_h_ = true; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">"; + } else if (scope == "wmma.accumulator") { + need_mma_h_ = true; + os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str() << ">"; + } +} + +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, + const Variable* variable, int32_t size) { + std::string shape_str = fragment_shapes[variable]; + size_t m, n, k; + size_t last_pos = 0, pos = 0; + pos = shape_str.find(", ", last_pos); + m = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + pos = shape_str.find(", ", last_pos); + n = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); + if (scope == "wmma.matrix_a") { + return size / m / k; + } else if (scope == "wmma.matrix_b") { + return size / n / k; + } else if (scope == "wmma.accumulator") { + return size / m / n; + } + return 0; +} + } // namespace codegen } // namespace tvm
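`GetWmmaFragmentSize` above recovers m, n and k from a fragment shape string of the form `"16, 16, 16"`. A standalone check of that parsing (`ParseShape` is a hypothetical helper mirroring the same find/substr logic):

```cpp
#include <cassert>
#include <string>

// Splits "m, n, k" on the ", " delimiter, exactly as GetWmmaFragmentSize does.
void ParseShape(const std::string& s, size_t* m, size_t* n, size_t* k) {
  size_t last = 0, pos = s.find(", ", last);
  *m = std::stoul(s.substr(last, pos - last));
  last = pos + 2;
  pos = s.find(", ", last);
  *n = std::stoul(s.substr(last, pos - last));
  last = pos + 2;
  *k = std::stoul(s.substr(last));
}

int main() {
  size_t m, n, k;
  ParseShape("16, 16, 16", &m, &n, &k);
  assert(m == 16 && n == 16 && k == 16);
  // e.g. a wmma.matrix_a allocation of 16*16 half elements then holds
  // size / m / k == 1 fragment.
  return 0;
}
```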
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index 61c6fa3a5170..53e7db45efc6 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -28,6 +28,7 @@
#include #include #include +#include <unordered_map> #include "codegen_c.h" namespace tvm {
@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC {
void AddFunction(LoweredFunc f); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior void VisitStmt_(const ir::For* op) final;
@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC {
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImm *op, std::ostream& os) final; + void VisitExpr_(const Call *op, std::ostream& os) final; void VisitStmt_(const Evaluate *op) final; + void VisitStmt_(const Allocate *op) final; + void VisitStmt_(const AttrStmt *op) final; private: // Whether global barrier is needed.
@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_int8_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; + // whether need mma.h + bool need_mma_h_{false}; + + std::unordered_map<const Variable*, std::string> fragment_shapes; + std::unordered_map<const Variable*, std::string> fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); + void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size); };
} // namespace codegen
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 0b33bf43c151..3120bb543aea 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -22,6 +22,7 @@
* \file codegen_opencl.cc */ #include +#include #include #include "codegen_opencl.h"
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index d009290bb2fe..de54e242ff40 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -1173,7 +1173,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast<IterVar>(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv);
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc
index 7caf3a258b6f..6a3b0571c9ab 100644
--- a/src/codegen/spirv/codegen_spirv.cc
+++ b/src/codegen/spirv/codegen_spirv.cc
@@ -606,7 +606,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast<IterVar>(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc
index a046cc4f458c..fca9aa203f80 100644
--- a/src/codegen/spirv/intrin_rule_spirv.cc
+++ b/src/codegen/spirv/intrin_rule_spirv.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,9 +18,9 @@
*/ /*!
- * Copyright (c) 2017 by Contributors * \file intrin_rule_spirv.cc */ +#include #include #include #include
diff --git a/src/common/socket.h b/src/common/socket.h
index 2a2d9166a134..39bcff863c10 100644
--- a/src/common/socket.h
+++ b/src/common/socket.h
@@ -27,8 +27,10 @@
#define TVM_COMMON_SOCKET_H_ #if defined(_WIN32) +#define NOMINMAX #include <winsock2.h> #include <ws2tcpip.h> +#undef NOMINMAX using ssize_t = int; #ifdef _MSC_VER #pragma comment(lib, "Ws2_32.lib")
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 54616adc214e..778b6b1a7811 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == ir::attr::realize_scope) { - auto v = FunctionRef(op->node.node_); + auto v = Downcast<FunctionRef>(op->node); alloc_storage_scope_[v] = op->value.as<StringImm>()->value; PrintStmt(op->body); } else {
@@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() {
std::string CodeGenHybrid::GetVarID(const Variable *v) { if (binds_.count(v)) return binds_[v]; - auto key = std::make_pair(v->GetNodePtr().get(), 0); + auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; }
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index 498838fc908f..866756996f8d 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -18,7 +18,6 @@
*/ /*! - * Copyright (c) 2019 by Contributors * \file codegen_hybrid.h * \brief Common utilities to generated C style code. */
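The src/common/socket.h hunk above wraps the Windows headers in NOMINMAX because `<windows.h>` and friends otherwise define `min`/`max` macros that break `std::min`, `std::max`, and `std::numeric_limits<T>::max()`. The same guard in isolation:

```cpp
#if defined(_WIN32)
#define NOMINMAX          // stop windows headers from defining min()/max() macros
#include <windows.h>
#undef NOMINMAX
#endif
#include <algorithm>
#include <iostream>

int main() {
  // With the macros suppressed, the std versions resolve normally on Windows too.
  std::cout << std::min(1, 2) << "\n";  // prints 1
}
```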
diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc
index a87ce07cb602..0ccaee515acb 100644
--- a/src/contrib/sort/sort.cc
+++ b/src/contrib/sort/sort.cc
@@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
// Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " "to be float."; -#if (__ARM_FP16_FORMAT_IEEE != 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " "to be float32."; #endif
@@ -100,23 +100,23 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } if (is_ascend) { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } else { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif }
@@ -210,7 +210,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
} else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } else if (data_dtype == "float16") { if (out_dtype == "float16") { argsort<__fp16, __fp16>(input, output, axis, is_ascend);
diff --git a/src/lang/api_registry.cc b/src/lang/api_registry.cc
index e041f3a2dd2d..cd3d43b7dcf3 100644
--- a/src/lang/api_registry.cc
+++ b/src/lang/api_registry.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode) .set_creator(CreateEnvNode) -.set_global_key([](const Node* n) { +.set_global_key([](const Object* n) { return static_cast<const EnvFuncNode*>(n)->name; });
diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h
index 995dfb392e87..b9391e4895b9 100644
--- a/src/lang/attr_functor.h
+++ b/src/lang/attr_functor.h
@@ -44,17 +44,17 @@ class AttrFunctor;
#define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch<OP>( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast<const OP*>(n.get()), \ std::forward<Args>(args)...); \ }); \ // A functor for common attribute information. template<typename R, typename ...Args> -class AttrFunctor<R(const NodeRef& n, Args...)> { +class AttrFunctor<R(const ObjectRef& n, Args...)> { private: - using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>; - using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; + using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>; + using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; public: /*! \brief the result type of this functor */
@@ -65,7 +65,7 @@ class AttrFunctor {
* \param args Additional arguments. * \return The result of the call */ - virtual R VisitAttr(const NodeRef& n, Args... args) { + virtual R VisitAttr(const ObjectRef& n, Args...
args) { static FType vtable = InitVTable(); if (vtable.can_dispatch(n)) { return vtable(n, this, std::forward<Args>(args)...);
@@ -73,7 +73,7 @@ class AttrFunctor {
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); } } - virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; + virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
@@ -143,60 +143,60 @@ class AttrFunctor {
}; class AttrsEqualHandler : - protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> { + protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> { public: /*! * \brief Check if lhs equals rhs * \param lhs The left operand. * \param rhs The right operand. */ - bool Equal(const NodeRef& lhs, const NodeRef& rhs); + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs); protected: - bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final; - bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final; + bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mul* lhs, const
ObjectRef& other) final; + bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::NE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final; }; class AttrsHashHandler : - protected AttrFunctor<size_t(const NodeRef&)> { + protected AttrFunctor<size_t(const ObjectRef&)> { public: /*! * \brief Get hash value of node * \param node The node to be hashed. */ - size_t Hash(const NodeRef& node) { + size_t Hash(const ObjectRef& node) { if (!node.defined()) return 0; return this->VisitAttr(node); } protected: - size_t VisitAttrDefault_(const Node* lhs) final; + size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImm* lhs) final; size_t VisitAttr_(const ir::UIntImm* lhs) final; size_t VisitAttr_(const ir::FloatImm* lhs) final;
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index c5b14ac577ec..a299e17996e0 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -40,7 +40,7 @@ void DictAttrsNode::InitByPackedArgs(
for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; - if (val.type_code() == kNodeHandle) { + if (val.type_code() == kObjectHandle) { dict.Set(key, val.operator NodeRef()); } else if (val.type_code() == kStr) { dict.Set(key, Expr(val.operator std::string()));
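The AttrFunctor/ATTR_FUNCTOR_DISPATCH machinery that this file now instantiates over ObjectRef is, at heart, a per-type handler table: each concrete node type registers a trampoline that downcasts the base reference and forwards to the typed overload. A standalone sketch using std::type_index where TVM uses its own type indices (`AttrPrinter` and `Register` are illustrative names, not TVM APIs):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <typeindex>
#include <unordered_map>

struct Object { virtual ~Object() = default; };
struct IntImm : Object { long value; };
struct StringImm : Object { std::string value; };

class AttrPrinter {
 public:
  AttrPrinter() {
    Register<IntImm>([](const IntImm* op) { std::cout << op->value << "\n"; });
    Register<StringImm>([](const StringImm* op) { std::cout << op->value << "\n"; });
  }
  // Single entry point, like VisitAttr: look up the handler by dynamic type.
  void Visit(const Object& n) {
    auto it = vtable_.find(std::type_index(typeid(n)));
    if (it != vtable_.end()) it->second(&n);
  }

 private:
  template <typename OP>
  void Register(std::function<void(const OP*)> f) {
    // The trampoline downcasts, mirroring static_cast<const OP*>(n.get())
    // in ATTR_FUNCTOR_DISPATCH.
    vtable_[std::type_index(typeid(OP))] = [f](const Object* n) {
      f(static_cast<const OP*>(n));
    };
  }
  std::unordered_map<std::type_index, std::function<void(const Object*)>> vtable_;
};

int main() {
  IntImm i;
  i.value = 7;
  AttrPrinter().Visit(i);  // prints 7
}
```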
@@ -72,14 +72,14 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
using namespace ir; // Equal handler. -bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) { +bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; return this->VisitAttr(lhs, rhs); } -bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) { - if (lhs->derived_from<BaseAttrsNode>()) { +bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) { + if (lhs->IsInstance<BaseAttrsNode>()) { AttrsEqual equal; equal.handler_ = this; return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
@@ -88,58 +88,58 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other)
return lhs == other.get(); } -bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<IntImm>()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<UIntImm>()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<FloatImm>()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<StringImm>()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<ArrayNode>()) { if (rhs->data.size() != lhs->data.size()) return false; for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; + if (!Equal(lhs->data[i], rhs->data[i])) return false; } } return true; } -bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<StrMapNode>()) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); if (it == rhs->data.end()) return false; - if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false; + if (!Equal(kv.second, it->second)) return false; } } return true; } #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ - bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \ + bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \ if (const auto* rhs = other.as<NodeName>()) { \ if (!Equal(lhs->a, rhs->a)) return false; \ if (!Equal(lhs->b, rhs->b)) return false; \
@@ -167,7 +167,7 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And); TVM_DEFINE_ATTRS_BINOP_EQUAL(Or); -bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<Not>()) { return Equal(lhs->a, rhs->a); } else {
@@ -175,7 +175,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
} } -bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<Cast>()) { if (lhs->type != rhs->type) return false; return Equal(lhs->value, rhs->value);
@@ -184,7 +184,7 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
} } -bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { if (const auto* rhs = other.as<Call>()) { return lhs->name == rhs->name &&
@@ -196,7 +196,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
} } -bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) { if (const auto* rhs = other.as
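Most of the VisitAttr_ overloads above are generated by TVM_DEFINE_ATTRS_BINOP_EQUAL, since every binary op compares the same two operand fields. The pattern in miniature (`DEFINE_BINOP_EQUAL` is a stand-in macro, with plain structs instead of IR nodes):

```cpp
#include <iostream>

struct Add { int a, b; };
struct Sub { int a, b; };

// One macro stamps out an identical "compare both operands" body per node type,
// exactly as TVM_DEFINE_ATTRS_BINOP_EQUAL does for Add, Sub, Mul, Div, ...
#define DEFINE_BINOP_EQUAL(NodeName)                      \
  bool Equal(const NodeName& lhs, const NodeName& rhs) {  \
    return lhs.a == rhs.a && lhs.b == rhs.b;              \
  }

DEFINE_BINOP_EQUAL(Add)
DEFINE_BINOP_EQUAL(Sub)

int main() {
  std::cout << Equal(Add{1, 2}, Add{1, 2}) << "\n";  // 1
  std::cout << Equal(Sub{1, 2}, Sub{2, 1}) << "\n";  // 0
}
```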