diff --git a/CMakeLists.txt b/CMakeLists.txt index bf18ffc9e856..19e026f95d82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -252,6 +252,7 @@ include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) +include(cmake/modules/contrib/Extern.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 1ef956c7ee18..dabe1e930a60 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -163,6 +163,11 @@ set(USE_ROCBLAS OFF) # Whether use contrib sort set(USE_SORT ON) +# Whether use contrib extern (use ";" to separate multiple externs) +# Available externs: +# dnnl +set(USE_EXTERN none) + # Build ANTLR parser for Relay text format # Possible values: # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar) diff --git a/cmake/modules/contrib/Extern.cmake b/cmake/modules/contrib/Extern.cmake new file mode 100644 index 000000000000..cf381a080b88 --- /dev/null +++ b/cmake/modules/contrib/Extern.cmake @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +message(STATUS "Build with relay.backend.contrib") + +file(GLOB GCC_RELAY_CONTRIB_SRC src/relay/backend/contrib/gcc/codegen.cc) +list(APPEND COMPILER_SRCS ${GCC_RELAY_CONTRIB_SRC}) + +list(FIND USE_EXTERN "dnnl" DNNL_IDX) +if(DNNL_IDX GREATER -1) + file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc) + list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + + find_library(EXTERN_LIBRARY_DNNL dnnl) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) + file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*) + list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC}) + message(STATUS "Use extern library: MKLDNN" ${EXTERN_LIBRARY_DNNL}) +endif() + diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index fd21db5a9c14..cc7803ecde6f 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Options for the subgraph operators. + */ +struct SubgraphAttrs : public tvm::AttrsNode { + /*! \brief The 3rd party compiler for subgraph code generation. 
*/ + std::string compiler; + + TVM_DECLARE_ATTRS(SubgraphAttrs, "relay.attrs.SubgraphAttrs") { + TVM_ATTR_FIELD(compiler) + .describe("The 3rd compiler used for subgraph code generation."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ANNOTATION_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2aa88099a69c..ee6db9342bb7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -268,6 +268,14 @@ class FunctionNode : public ExprNode { */ bool IsPrimitive() const; + /*! + * \brief Check whether the function is an external function. + * External functions are subgraphes that supported by external libraries. + * + * \return Whether the function is external or not. + */ + bool IsExternal() const; + TVM_DLL static Function make(tvm::Array params, Expr body, Type ret_type, @@ -588,6 +596,25 @@ std::string AsText(const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); +/*! \brief namespace of the attributes that are attached to a function. */ +namespace attr { +/*! \brief Mark the function as a primitive function. */ +constexpr const char* kPrimitive = "Primitive"; +/*! + * \brief Mark the function as an external function that needs to be handled by + * the external codegen tool/backend. + */ +constexpr const char* kExternal = "External"; +/*! \brief Indicate if the function is a closure. */ +constexpr const char* kClosure = "Closure"; +/*! \brief Store a Var to parameter/Constant mapping on a Function. */ +constexpr const char* kParams = "__params__"; +/*! \brief Store the function name. */ +constexpr const char* kFuncName = "FuncName"; +/*! \brief Mark if the function should be avoided being optimized. */ +constexpr const char* kSkipOptimization = "SkipOptimization"; +} // namespace attr + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 741e8b478828..d86745916fb6 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -122,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc< * operator with other expressions. This function will be invoked * in AlterOpLayout pass. * \param attrs The attribute of the original node. - * \param inputs The input symbols of the original node. + * \param args The input symbols of the original node. * \param tinfos An array of placeholders, use for getting the inferred shape * and dtype of the inputs. * \return new_expr The modified expression. @@ -136,8 +137,8 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< * \brief Legalizes an expression with another expression. This function will be * invoked in Legalize pass. It is a target-dependent pass. * \param attrs The attribute of the original node. - * \param inputs The input symbols of the original node. - * \param tinfos An array of placeholders, use for getting the inferred shape + * \param args The input symbols of the original node. + * \param arg_types An array of placeholders, use for getting the inferred shape * and dtype of the inputs. * \return new_expr The modified expression. */ @@ -146,6 +147,22 @@ using FTVMLegalize = runtime::TypedPackedFunc< const Array& args, const Array& arg_types)>; +/*! + * \brief Annotates an expression to indicate which external codegen tool an op + * should be scheduled to. It is a hardware dependent pass. 
+ * + * \param attrs The attribute of the original expr. + * \param args The arguments of the original expr. + * \param compiler The external compiler that is used for external ops. + * + * \return true if this op should be registered with external codegen tool, + * otherwise, false. + */ +using FTVMExternOp = runtime::TypedPackedFunc< + bool(const Attrs& attrs, // NOLINT(*) + const Array& args, + const std::string& compiler)>; + /*! * \brief Forward rewriting rule for a specific op. * diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ddadbe4fc31d..92eb99f2cd94 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -576,6 +576,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); */ TVM_DLL Pass PrintIR(bool show_meta_data = true); +/*! + * \brief Partition a Relay program into regions that can be executed on + * different backends. + * + * \return The pass. + */ +TVM_DLL Pass PartitionGraph(); + } // namespace transform /*! diff --git a/include/tvm/runtime/contrib/dnnl/dnnl_kernel.h b/include/tvm/runtime/contrib/dnnl/dnnl_kernel.h new file mode 100644 index 000000000000..39ebcc2aa55a --- /dev/null +++ b/include/tvm/runtime/contrib/dnnl/dnnl_kernel.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/runtime/contrib/dnnl/dnnl_kernel.h + * \brief Use external dnnl library kernels. 
+ */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ + +#include "dnnl.hpp" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace dnnl; + +extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, + int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, + int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, + int p_Sh_, int p_Sw_); + +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, + int p_I_, int p_O_); + +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, + int p_W_); + +extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, float* out, int p_n_, int p_c_, + int p_h_, int p_w_, int p_e_); + +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_n_, + int p_c_, int p_h_, int p_w_); + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ diff --git a/python/tvm/module.py b/python/tvm/module.py index 2790227f32c7..d9676169cc5a 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -133,7 +133,16 @@ def export_library(self, self.save(path_obj) files = [path_obj] is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")() + has_imported_c_file = False if self.imported_modules: + for i, m in enumerate(self.imported_modules): + if m.type_key == "c": + has_imported_c_file = True + c_file_name = "tmp_" + str(i) + ".cc" + path_cc = temp.relpath(c_file_name) + with open(path_cc, "w") as f: + f.write(m.get_source()) + files.append(path_cc) path_cc = temp.relpath("devc.cc") with open(path_cc, "w") as f: f.write(_PackImportsToC(self, is_system_lib)) @@ -143,7 +152,7 @@ def export_library(self, fcompile = _tar.tar else: fcompile = _cc.create_shared - if self.type_key == "c": + if self.type_key == "c" or has_imported_c_file: options = [] if "options" in kwargs: opts = kwargs["options"] diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c7cbcf096a6c..60057e3387b4 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -29,7 +29,7 @@ from . import adt from . import analysis from . import transform -from .build_module import build, create_executor, optimize +from .build_module import build, create_executor, optimize, build_extern from .transform import build_config from . import prelude from . import parser diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 28ce16b9b452..c9280552216b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -30,6 +30,7 @@ from .module import Module as _Module from .backend import interpreter as _interpreter from .backend.vm import VMExecutor +from . import transform as _transform def _update_target(target): target = target if target else _target.current_target() @@ -296,6 +297,33 @@ def optimize(mod, target=None, params=None): return mod, params +def build_extern(mod, target): + """Helper function that builds a Relay function to run on external codegen + tools. + + Parameters + ---------- + mod : relay.Module + The module to build. Using relay.Function is deprecated. + + target : str + The name of the external compilation target. + + Returns + ------- + mod : relay.Module + The relay module contains partitioned subgraphes for external codegen + tools. 
+ """ + if isinstance(mod, _expr.Function): + mod = _Module.from_expr(mod) + + seq = _transform.Sequential([_transform.ExternOp(target), + _transform.PartitionGraph()]) + mod = seq(mod) + return mod + + class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index a089cab669c9..f246750e5cd9 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -19,7 +19,7 @@ # operator defs from .op import get, register, register_schedule, register_compute, register_gradient, \ register_pattern, register_alter_op_layout, register_legalize, \ - schedule_injective, Op, OpPattern, debug + register_extern_op, schedule_injective, Op, OpPattern, debug # Operators from .reduce import * diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 2b9d4bcd81bc..835a04c5bec9 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -62,6 +62,7 @@ def stop_fusion(data): """ return _make.stop_fusion(data) + def checkpoint(data): """Annotate an expression to be a checkpoint for the checkpointing memory optimization. @@ -78,3 +79,42 @@ def checkpoint(data): return _make.checkpoint(data) register_schedule("annotation.checkpoint", schedule_injective) + + +def subgraph_begin(data, compiler): + """Annotate an expression to indicate that it is the beginning of + a subgraph. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + compiler : Str + The compiler used to generate code of a subgraph. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.subgraph_begin(data, compiler) + + +def subgraph_end(data, compiler): + """Annotate an expression to indicate that it is the end of a subgraph. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + compiler : Str + The compiler used to generate code of a subgraph. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.subgraph_end(data, compiler) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 3159006486b3..a369f143d4c0 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -18,4 +18,5 @@ """Neural network related operators.""" from __future__ import absolute_import as _abs from .contrib import * +from .extern_op import * from . import _contrib diff --git a/python/tvm/relay/op/contrib/dnnl/__init__.py b/python/tvm/relay/op/contrib/dnnl/__init__.py new file mode 100644 index 000000000000..0da426ab4741 --- /dev/null +++ b/python/tvm/relay/op/contrib/dnnl/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .extern_op import * diff --git a/python/tvm/relay/op/contrib/dnnl/extern_op.py b/python/tvm/relay/op/contrib/dnnl/extern_op.py new file mode 100644 index 000000000000..fb967872a588 --- /dev/null +++ b/python/tvm/relay/op/contrib/dnnl/extern_op.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""CBLAS library supported operators.""" +from __future__ import absolute_import + + +def conv2d(attrs, args): + """Check if the external codegen should be used. + """ + return True + + +def dense(attrs, args): + """Check if the external codegen should be used. + """ + return True + + +def relu(attrs, args): + """Check if the external codegen should be used. + """ + return True + + +def batch_norm(attrs, args): + """Check if the external codegen should be used. + FIXME: Turn off due to not support of multiple outputs. + """ + return False + + +def add(attrs, args): + """Check if the external codegen should be used. + """ + return True diff --git a/python/tvm/relay/op/contrib/extern_op.py b/python/tvm/relay/op/contrib/extern_op.py new file mode 100644 index 000000000000..e1310f7a25bd --- /dev/null +++ b/python/tvm/relay/op/contrib/extern_op.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +""" +External compiler related feature registration. + +It implements dispatchers that check if an operator should use the external +codegen tool. + +Each compiler can customize the support of the operator. For example, they can +check the attribute of an operator and/or the features of the input arguments +to decide if we should use the external compiler. 
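+
+As a sketch, a new backend would be plugged in by adding a package under this
+directory that exposes an ``extern_op`` module with one predicate per op
+(``mycodegen`` and the layout check below are only illustrative):
+
+.. code-block:: python
+
+    # python/tvm/relay/op/contrib/mycodegen/extern_op.py
+    def conv2d(attrs, args):
+        """Offload conv2d to mycodegen only for NCHW layouts."""
+        return attrs.data_layout == "NCHW"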
+""" +from __future__ import absolute_import + +import logging +import pkgutil +from pathlib import Path +from importlib import import_module + +from .. import op as reg + +logger = logging.getLogger('ExternOp') + +# Load available contrib compilers +compilers = {} +for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]): + compilers[name] = import_module( + '.%s' % name, package='.'.join(__name__.split('.')[:-1])) + + +def get_extern_op(compiler, op_name): + """Get the extern op function from the registered compiler + """ + if compiler in compilers: + if hasattr(compilers[compiler], 'extern_op'): + extern_op = getattr(compilers[compiler], 'extern_op') + if hasattr(extern_op, op_name): + return getattr(extern_op, op_name) + + logger.warning("%s in %s is not registered. Fallback to CPU", op_name, + compiler) + return lambda x, y: False + + +@reg.register_extern_op("nn.conv2d") +def external_conv2d(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'conv2d')(attrs, args) + + +@reg.register_extern_op("nn.dense") +def external_dense(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'dense')(attrs, args) + + +@reg.register_extern_op("nn.relu") +def external_relu(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'relu')(attrs, args) + + +@reg.register_extern_op("nn.batch_norm") +def external_batch_norm(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'batch_norm')(attrs, args) + + +@reg.register_extern_op("subtract") +def external_subtract(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'subtract')(attrs, args) + + +@reg.register_extern_op("add") +def external_add(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'add')(attrs, args) + + +@reg.register_extern_op("multiply") +def external_multiply(attrs, args, compiler): + """Check if the external compiler should be used. + """ + return get_extern_op(compiler, 'multiply')(attrs, args) diff --git a/python/tvm/relay/op/contrib/gcc/__init__.py b/python/tvm/relay/op/contrib/gcc/__init__.py new file mode 100644 index 000000000000..0da426ab4741 --- /dev/null +++ b/python/tvm/relay/op/contrib/gcc/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .extern_op import * diff --git a/python/tvm/relay/op/contrib/gcc/extern_op.py b/python/tvm/relay/op/contrib/gcc/extern_op.py new file mode 100644 index 000000000000..1d85f1916992 --- /dev/null +++ b/python/tvm/relay/op/contrib/gcc/extern_op.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""GCC compiler supported operators.""" +from __future__ import absolute_import + +def conv2d(attrs, args): + """Check if the external codegen should be used. + """ + return False + +def subtract(attrs, args): + """Check if the external codegen should be used. + """ + return True + +def add(attrs, args): + """Check if the external codegen should be used. + """ + return True + +def multiply(attrs, args): + """Check if the external codegen should be used. + """ + return True diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 355496e42b48..a30688c9fafc 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -229,6 +229,7 @@ def register_pattern(op_name, pattern, level=10): """ return register(op_name, "TOpPattern", pattern, level) + def register_gradient(op_name, fgradient=None, level=10): """Register operator pattern for an op. @@ -266,6 +267,25 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): get(op_name).set_attr("TShapeDataDependant", data_dependant, level) return register(op_name, "FShapeFunc", shape_func, level) +def register_extern_op(op_name, fextern=None, level=10): + """Register the external codegen tool for an op. + + Parameters + ---------- + op_name : str + The name of the operator. + + fextern : function (attrs: Attrs, args: List[Expr], compiler: str) + -> new_expr: Expr + The function for wrapping a call expr with subgraph_start and + subgraph_end. + + level : int + The priority level + """ + return register(op_name, "FTVMExternOp", fextern, level) + + _init_api("relay.op", __name__) @register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 540c1f5b79cd..81474e207233 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -480,6 +480,24 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): return _transform.Legalize(legalize_map_attr_name) +def ExternOp(compiler): + """Set ops in an experession as external ops so that it will use the + external codegen tool. + + Parameters + ---------- + compiler : str + The compiler used for external codegen. + + Returns + ------- + ret : tvm.relay.Pass + The annotated pass that wrapps ops with subgraph_start and + subgraph_end. 
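+
+    Examples
+    --------
+    A sketch of invoking the pass directly ("gcc" is the example codegen
+    shipped under contrib); ``build_extern`` wraps the same sequence:
+
+    .. code-block:: python
+
+        seq = relay.transform.Sequential([
+            relay.transform.ExternOp("gcc"),
+            relay.transform.PartitionGraph(),
+        ])
+        mod = seq(mod)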
+ """ + return _transform.ExternOp(compiler) + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. @@ -635,6 +653,18 @@ def PrintIR(show_meta_data=True): return _transform.PrintIR(show_meta_data) +def PartitionGraph(): + """Partition a Relay program into regions that can be executed on different + backends. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that partitions the Relay program. + """ + return _transform.PartitionGraph() + + def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index 8464e3dbbb2a..921b5a0bca5c 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "Only support simply one-level hierarchy"; std::string tkey = im->type_key(); stream->Write(tkey); + if (tkey == "c") continue; im->SaveToBinary(stream); } // translate to C program diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 9254c7e3e7b9..6d0fe581f9d2 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -73,6 +73,14 @@ struct GraphCodegen { return CallFunc("get_graph_json", nullptr); } + Array GetExternalFuncs() { + return CallFunc >("get_external_funcs", nullptr); + } + + runtime::Module GetExternalModule() { + return CallFunc("get_external_module", nullptr); + } + Map > GetLoweredFunc() { return CallFunc > >("get_lowered_funcs", nullptr); } @@ -148,6 +156,14 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_codegen_->GetLoweredFunc(); }); + } else if (name == "get_external_funcs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->graph_codegen_->GetExternalFuncs(); + }); + } else if (name == "get_external_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->graph_codegen_->GetExternalModule(); + }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 2); @@ -474,6 +490,16 @@ class RelayBuildModule : public runtime::ModuleNode { target_host_, BuildConfig::Current()); } + Array external_funcs = graph_codegen_->GetExternalFuncs(); + if (!external_funcs.empty()) { + auto ext_rt_mod = graph_codegen_->GetExternalModule(); + // Execute the whole module using external runtime. + if (lowered_funcs.size() == 0) { + ret_.mod = ext_rt_mod; + } else { + ret_.mod.Import(ext_rt_mod); + } + } } protected: diff --git a/src/relay/backend/contrib/contrib_codegen.h b/src/relay/backend/contrib/contrib_codegen.h new file mode 100644 index 000000000000..f7e651251b97 --- /dev/null +++ b/src/relay/backend/contrib/contrib_codegen.h @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/contrib_codegen.h + * \brief The base class for external codegen tools. + */ +#ifndef TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { + +class ExternCodegenBase { + public: + ExternCodegenBase() = default; + + /*! + * \brief Create a runtime module for the external library. For example, it + * could be a CSourceModule that can be directly compiled and linked together + * with a DSOModule, or a json style module that emitts a json artifact that + * is able to be executed by a customized json runtime. + * + * \param ref The subgraph Relay expression/module to be executed using extern ops. + * + * \return A runtime module. + */ + virtual runtime::Module CreateExternModule(const NodeRef& ref) = 0; + + /*! + * \brief Split the Relay function name to tokens. + * + * \param func The provided function. + * \param prefix The prefix of the function name, i.e. dnnl. + * + * \return A vector of tokenized function name splitted by "_". + */ + std::string GetSubgraphID(const Function& func, const std::string& prefix) const { + const auto name_node = + FunctionGetAttr(func, attr::kFuncName).as(); + CHECK(name_node != nullptr) << "Fail to retrieve subgraph name."; + std::string name = name_node->value; + return GetSubgraphID(name, prefix); + } + + /*! + * \brief Split the encoded function name to tokens. + * + * \param the function name string. + * + * \return a vector of tokenized function name splitted by "_". + */ + std::string GetSubgraphID(const std::string& name, const std::string& prefix) const { + std::string temp = name; + std::vector tokens; + std::string delimiter = "_"; + size_t pos = 0; + std::string token; + while ((pos = temp.find(delimiter)) != std::string::npos) { + token = temp.substr(0, pos); + tokens.push_back(token); + temp.erase(0, pos + delimiter.length()); + } + tokens.push_back(temp); + + CHECK(tokens.size() >= 2) << "Invalid subgraph name: " << name; + CHECK(tokens[0] == prefix) + << "Function name: " << name + << " does not start with: " << prefix; + return tokens[1]; + } +}; + +// A helper class to write the declaration of external functions. +class ExternSourcePrinter { + protected: + /*! \brief Print indents using spaces. */ + void PrintIndents() { + for (int i = 0; i < indent_; i++) { + code_stream_ << ' '; + } + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + CHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! + * \brief Gerenate a wrapper for the subgraph that will use external codegen. + * + * \param func_name The name of wrapper function. + * \param arg_cnt The expected number of arguments for the wrapper. + * + * \code + * + * // An example code for the wrapper. 
+ * extern "C" void foo(TVMValue* value, int* type_code, int nargs) { + * if (nargs != 3) { + * printf("foo expects 3 args, but received %d\n", nargs); + * return 1; + * } + * + * DLTensor* arg0 = static_cast(value[0].v_handle); + * DLTensor* arg1 = static_cast(value[1].v_handle); + * DLTensor* out = static_cast(value[2].v_handle); + * + * foo_(static_cast(arg0->data), + * static_cast(arg1->data), + * static_cast(out->data)); + * return 0; + * } + * + * \endcode + */ + void GenerateSubgraphWrapper(const std::string& func_name, int arg_cnt) { + // Print signature + code_stream_ << "\n"; + code_stream_ << "extern \"C\" int " << func_name; + code_stream_ << "(TVMValue* value, int* type_code, int nargs) {\n"; + EnterScope(); + // Print guard + PrintIndents(); + code_stream_ << "if (nargs != " << arg_cnt << "){\n"; + EnterScope(); + PrintIndents(); + code_stream_ << "printf(\"" << func_name << " expects " << arg_cnt + << "arguments, but received %d\\n\", nargs);\n"; + PrintIndents(); + code_stream_ << "return 1;\n"; + ExitScope(); + PrintIndents(); + code_stream_ << "}\n"; + + // According to TVM's calling convention, the last one is output. + for (int i = 0; i < arg_cnt; i++) { + PrintIndents(); + code_stream_ << "DLTensor* arg" << i << " = " + << "static_cast(value[" << i << "].v_handle);\n"; + } + // Generate the call. + PrintIndents(); + code_stream_ << func_name << "_("; + for (int i = 0; i < arg_cnt - 1; i++) { + code_stream_ << "static_cast(arg" << i << "->data), "; + } + if (arg_cnt > 0) { + code_stream_ << "static_cast(arg" << arg_cnt - 1 << "->data)"; + } + code_stream_ << ");\n\n"; + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}"; + } + + /*! + * \brief Emit the code for external runtime. + * + * \return The code string. + */ + virtual std::string JIT() = 0; + + /*! + * \brief Extract the shape from a Relay tensor type. + * + * \param type The provided type. + * + * \return The extracted shape in a list. + */ + std::vector GetShape(const Type& type) const { + const auto* ttype = type.as(); + CHECK(ttype) << "Expect TensorTypeNode"; + std::vector shape; + for (size_t i = 0; i < ttype->shape.size(); ++i) { + auto* val = ttype->shape[i].as(); + CHECK(val); + shape.push_back(val->value); + } + return shape; + } + + /*! + * \briefa A common interface that that used by various external runtime to + * generate the wrapper to invoke external kernels. + * + * \param subgraph_id The unique id of an external function. It will be used + * during runtime to pick the correct external function. + * \param args The arguments used by the external function. + * \param buf_decl The declaration of temporary buffers that used to store the + * intermeidate of each external kernel. + * \param body The statements of the external function. + * \param out The name and id pairs for output. + * + * \return The emitted code string. + */ + std::string JitImpl(std::string subgraph_id, + std::vector args, + std::vector buf_decl, + std::vector body, + std::vector> out) { + // Create the signature. 
For example, it could be: + // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} + code_stream_ << "extern \"C\" void " << subgraph_id << "_("; + + for (const auto& arg : args) { + code_stream_ << "float* " << arg << ", "; + } + code_stream_ << "float* out) {\n"; + this->EnterScope(); + + // Function body + for (auto decl : buf_decl) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : body) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + // Copy output + CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; + this->PrintIndents(); + code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n"; + + // Free buffers + for (size_t i = 0; i < buf_decl.size(); i++) { + this->PrintIndents(); + code_stream_ << "std::free(buf_" << i << ");\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + // Create the wrapper to call the subgraph + this->GenerateSubgraphWrapper(subgraph_id, args.size() + 1 /* output */); + return code_stream_.str(); + } + + /*! \brief The external function source code stream. */ + std::ostringstream code_stream_; + + private: + /*! \brief Indent of the source code. */ + int indent_{0}; +}; + +} // namespace contrib +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_BACKEND_CONTRIB_CONTRIB_CODEGEN_H_ diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc new file mode 100644 index 000000000000..1057af560665 --- /dev/null +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/dnnl/codegen.cc + * \brief Implementation of DNNL codegen APIs. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../contrib_codegen.h" + +namespace tvm { +namespace relay { +namespace contrib { + +// TODO(@zhiics, @comaniac): This is basic implementation. We should implement +// all utilities and make a base class for users to implement. 
+class DnnlBuilder : public ExprVisitor, public ExternSourcePrinter { + public: + explicit DnnlBuilder(const std::string& id) { this->subgraph_id_ = id; } + + void VisitExpr_(const VarNode* node) final { + subgraph_args_.push_back(node->name_hint()); + out_.clear(); + out_.push_back({node->name_hint(), 0}); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + // Do nothing + } + + void VisitExpr_(const CallNode* call) final { + std::ostringstream decl_stream; + std::ostringstream buf_stream; + // Args: ID + std::vector args; + + if (IsOp(call, "nn.conv2d")) { + decl_stream << "dnnl_conv2d"; + const auto* conv2d_attr = call->attrs.as(); + + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw + args.push_back(std::to_string(wshape[0])); + args.push_back(std::to_string(conv2d_attr->groups)); + args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); + args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); + args.push_back(std::to_string(wshape[2])); + args.push_back(std::to_string(wshape[3])); + args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); + args.push_back(std::to_string(conv2d_attr->strides[1].as()->value)); + } else if (IsOp(call, "nn.dense")) { + decl_stream << "dnnl_dense"; + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, O + args.push_back(std::to_string(ishape[0])); + args.push_back(std::to_string(ishape[1])); + args.push_back(std::to_string(wshape[0])); + + } else if (IsOp(call, "nn.relu")) { + decl_stream << "dnnl_relu"; + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + } else if (IsOp(call, "nn.batch_norm")) { + decl_stream << "dnnl_bn"; + const auto* bn_attr = call->attrs.as(); + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + // Args: epilson + args.push_back(std::to_string(bn_attr->epsilon)); + } else if (IsOp(call, "add")) { + decl_stream << "dnnl_add"; + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + } else { + LOG(FATAL) << "Unsupported op: " << AsText(call->op, false); + } + + // Make function call with input buffers when visiting arguments + bool first = true; + decl_stream << "("; + for (size_t i = 0; i < call->args.size(); ++i) { + VisitExpr(call->args[i]); + for (auto out : out_) { + if (!first) { + decl_stream << ", "; + } + first = false; + decl_stream << out.first; + } + } + + // Analyze the output buffer + auto type_node = call->checked_type().as(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + std::string out = "buf_" + std::to_string(buf_idx_++); + auto out_shape = GetShape(call->checked_type()); + int out_size = 1; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_size *= out_shape[i]; + } + this->PrintIndents(); + buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; + buf_decl_.push_back(buf_stream.str()); + decl_stream << ", " << out; + + // Attach attribute arguments + for (size_t i = 0; i < args.size(); ++i) 
{ + decl_stream << ", " << args[i]; + } + decl_stream << ");"; + subgraph_body.push_back(decl_stream.str()); + + // Update output buffer + out_.clear(); + out_.push_back({out, out_size}); + } + + std::string JIT(void) { + return JitImpl(subgraph_id_, subgraph_args_, buf_decl_, subgraph_body, out_); + } + + private: + /*! \brief The id of the external dnnl subgraph. */ + std::string subgraph_id_{""}; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped external function. */ + std::vector subgraph_args_; + /*! \brief statement of the external function. */ + std::vector subgraph_body; + /*! \brief The declaration of intermeidate buffers. */ + std::vector buf_decl_; + /*! \brief The name of the the outputs. */ + std::vector> out_; + + /*! + * \brief Check if a call has the provided name. + * + * \param call A Relay call node. + * \param op_name The name of the expected call. + * + * \return true if the call's name is equivalent to the given name. Otherwise, + * false. + */ + bool IsOp(const CallNode* call, std::string op_name) const { + const auto* op_node = call->op.as(); + CHECK(op_node) << "Expects a single op."; + Op op = GetRef(op_node); + return op == Op::Get(op_name); + } +}; + +/*! + * \brief The DNNL codegen helper to generate wrapepr function calls of DNNL + * libraries. The code is a CSourceModule that can be compiled separately and + * linked together with a DSOModule. + */ +class DNNLCodegen : public ExternCodegenBase { + public: + // Create a corresponding external function for the given relay Function. + void CreateExternFunction(const Function& func) { + CHECK(func.defined()) + << "Input error: external codegen expects a Relay function."; + const auto* call = func->body.as(); + CHECK(call) << "DNNL expects a single convolution or dense op"; + + // Record subgraph ID for runtime invoke. + auto sid = GetSubgraphID(func, "dnnl"); + + auto builder = DnnlBuilder("dnnl_" + sid); + builder.VisitExpr(func->body); + code_stream_ << builder.JIT(); + } + + /*! + * \brief The overridden function that will create a CSourceModule. In order + * to compile the generated C source code, users need to specify the paths to + * some libraries, including some TVM required and dnnl specific ones. To make + * linking simpiler, the DNNL kernels are wrapped in a TVM compatible manner + * and are live under include/tvm/runtime/contrib/dnnl folder. + * + * \param ref A object ref that could be either a Relay function or module. + * + * \return The runtime module that contains C source code. 
+ */ + runtime::Module CreateExternModule(const NodeRef& ref) override { + // Create headers + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime::contrib;\n"; + code_stream_ << "\n"; + + if (ref->IsInstance()) { + CreateExternFunction(Downcast(ref)); + } else if (ref->IsInstance()) { + relay::Module mod = Downcast(ref); + for (const auto& it : mod->functions) { + CreateExternFunction(Downcast(it.second)); + } + } else { + LOG(FATAL) << "The input ref is expected to be a Relay function or module" + << "\n"; + } + + // Create a CSourceModule + const auto* pf = runtime::Registry::Get("module.csource_module_create"); + CHECK(pf != nullptr) << "Cannot find csource module to create the external function"; + return (*pf)(code_stream_.str(), "cc"); + } + + private: + /*! \brief The code stream that prints the external functions. */ + std::ostringstream code_stream_; +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + */ +runtime::Module DNNLCompiler(const NodeRef& ref) { + DNNLCodegen dnnl; + return dnnl.CreateExternModule(ref); +} + +TVM_REGISTER_API("relay.ext.dnnl") +.set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/gcc/codegen.cc b/src/relay/backend/contrib/gcc/codegen.cc new file mode 100644 index 000000000000..0530dec9ae79 --- /dev/null +++ b/src/relay/backend/contrib/gcc/codegen.cc @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include +#include + +#include "../contrib_codegen.h" + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief An example codegen that is only used for quick prototyping and testing + * purpose. Only several binary options are covered in the GCC builder. Users + * may need to extend them to cover more operators. 
+ */ +class GccBuilder : public ExprVisitor, public ExternSourcePrinter { + public: + explicit GccBuilder(const std::string& id) { this->subgraph_id_ = id; } + + void VisitExpr_(const VarNode* node) { + subgraph_args_.push_back(node->name_hint()); + out_.clear(); + out_.push_back({node->name_hint(), 0}); + } + + void VisitExpr_(const CallNode* call) final { + std::ostringstream macro_stream; + std::ostringstream decl_stream; + std::ostringstream buf_stream; + + auto op_node = call->op.as(); + std::string func_name = subgraph_id_ + "_" + std::to_string(func_idx++); + + // Make function declaration + macro_stream << "GCC_BINARY_OP_" << call->args.size() << "D(" << func_name << ", "; + + if (GetRef(op_node) == Op::Get("add")) { + macro_stream << "+"; + } else if (GetRef(op_node) == Op::Get("subtract")) { + macro_stream << "-"; + } else if (GetRef(op_node) == Op::Get("multiply")) { + macro_stream << "*"; + } else { + LOG(FATAL) << "Unrecognized op"; + } + + auto in_shape = GetShape(call->args[0]->checked_type()); + for (size_t i = 0; i < in_shape.size(); ++i) { + macro_stream << ", " << in_shape[i]; + } + macro_stream << ");"; + func_decl_.push_back(macro_stream.str()); + + // Make function call when visiting arguments + bool first = true; + decl_stream << func_name << "("; + for (size_t i = 0; i < call->args.size(); ++i) { + VisitExpr(call->args[i]); + for (auto out : out_) { + if (!first) { + decl_stream << ", "; + } + first = false; + decl_stream << out.first; + } + } + + auto type_node = call->checked_type().as(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + std::string out = "buf_" + std::to_string(buf_idx_++); + auto out_shape = GetShape(call->checked_type()); + int out_size = 1; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_size *= out_shape[i]; + } + buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; + buf_decl_.push_back(buf_stream.str()); + + decl_stream << ", " << out << ");"; + subgraph_body.push_back(decl_stream.str()); + + // Update output buffer + out_.clear(); + out_.push_back({out, out_size}); + } + + /*! + * \brief Emit the source code that invokes gcc compatible wrappers. + * + * \return The emitted code. + */ + std::string JIT() { + // Write function macros + for (auto decl : func_decl_) { + code_stream_ << decl << "\n"; + } + return JitImpl(subgraph_id_, subgraph_args_, buf_decl_, subgraph_body, out_); + } + + private: + /*! \brief The subgraph id that represents an GCC external function. */ + std::string subgraph_id_ = ""; + /*! \brief The index of an external function. */ + int func_idx = 0; + /*! \brief The index of allocated buffers. */ + int buf_idx_ = 0; + /*! \brief The arguments of a GCC compatible external function. */ + std::vector subgraph_args_; + /*! \brief The statements of a GCC compatible external function. */ + std::vector subgraph_body; + /*! \brief The declaration statements of a GCC compatible external function. */ + std::vector func_decl_; + /*! \brief The declaration statements of buffers. */ + std::vector buf_decl_; + /*! \brief The name and index pairs for output. */ + std::vector> out_; +}; + +class GccCodegen : public ExternCodegenBase { + public: + void CreateExternFunction(const Function& func) { + CHECK(func.defined()) + << "Input error: external codegen expects a Relay function."; + + // Record subgraph ID for runtime invoke. 
+ auto sid = GetSubgraphID(func, "gcc"); + + auto builder = GccBuilder("gcc_" + sid); + builder.VisitExpr(func->body); + code_stream_ << builder.JIT(); + } + + runtime::Module CreateExternModule(const NodeRef& ref) override { + // Create headers + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + + // Append some common macro for operator definition. + const char* operator_macro = R"op_marco( + #define GCC_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_) \ + extern "C" void p_ID_(float* a, float* b, float* out) { \ + for (int64_t i = 0; i < p_DIM1_; ++i) { \ + out[i] = a[i] p_OP_ b[i]; \ + } \ + } + + #define GCC_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \ + extern "C" void p_ID_(float* a, float* b, float* out) { \ + for (int64_t i = 0; i < p_DIM1_; ++i) { \ + for (int64_t j = 0; j < p_DIM2_; ++j) { \ + int64_t k = i * p_DIM2_ + j; \ + out[k] = a[k] p_OP_ b[k]; \ + } \ + } \ + } + )op_marco"; + + code_stream_ << operator_macro << "\n\n"; + + if (ref->IsInstance()) { + CreateExternFunction(Downcast(ref)); + } else if (ref->IsInstance()) { + relay::Module mod = Downcast(ref); + for (const auto& it : mod->functions) { + CreateExternFunction(Downcast(it.second)); + } + } else { + LOG(FATAL) << "The input ref is expected to be a Relay function or module" + << "\n"; + } + + // Create a CSourceModule + const auto* pf = runtime::Registry::Get("module.csource_module_create"); + CHECK(pf != nullptr) << "Cannot find csource module to create the external function"; + return (*pf)(code_stream_.str(), "cc"); + } + + private: + std::ostringstream code_stream_; +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + * + * The external codegen tool should have been registered similiarly to LLVM, + * CUDA, etc, under TVM, so the generated code could be packed in a runtime + * module. This module simplifies code serialization and invocation. 
+ */ +runtime::Module GccCompiler(const NodeRef& ref) { + GccCodegen gcc; + return gcc.CreateExternModule(ref); +} + +TVM_REGISTER_API("relay.ext.gcc") +.set_body_typed(GccCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index e2881785766c..cf5f26fedfa7 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -55,6 +56,7 @@ using TargetsMap = std::unordered_map; struct LoweredOutput { std::string graph_json; Map > lowered_funcs; + Array external_funcs; std::unordered_map params; }; @@ -212,6 +214,7 @@ class GraphRuntimeCodegen LoweredOutput ret; ret.graph_json = os.str(); ret.params = params_; + ret.external_funcs = external_funcs_; for (auto& kv : lowered_funcs_) { if (ret.lowered_funcs.count(kv.first) == 0) { ret.lowered_funcs.Set(kv.first, Array()); @@ -380,6 +383,28 @@ class GraphRuntimeCodegen } return fields; } + + std::vector InvokeExternalCodegen(const CallNode* op, const Function& func) { + CHECK(func->IsExternal()); + std::vector inputs; + for (auto arg : op->args) { + auto res = VisitExpr(arg); + for (auto nr : res) { + inputs.push_back(nr); + } + } + external_funcs_.push_back(func); + const auto name_node = FunctionGetAttr(func, attr::kFuncName).as(); + CHECK(name_node != nullptr) << "External function has not been attached a name yet."; + std::string op_name = name_node->value; + auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name), + GraphAttrs(), + op_name, + inputs, + GraphAttrs()); + return AddNode(node, GetRef(op)); + } + std::vector VisitExpr_(const CallNode* op) override { Expr expr = GetRef(op); Function func; @@ -390,6 +415,9 @@ class GraphRuntimeCodegen LOG(FATAL) << "Not implemented"; } else if (op->op.as()) { func = GetRef(op->op.as()); + if (func->IsExternal()) { + return InvokeExternalCodegen(op, func); + } } else { LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); } @@ -470,7 +498,7 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - throw std::invalid_argument("function not supported"); + CHECK(op->IsExternal()) << "Only external function is supported"; return {}; } std::vector VisitExpr_(const RefCreateNode* op) override { @@ -587,6 +615,8 @@ class GraphRuntimeCodegen std::unordered_map name_map_; /*! \brief compile engine */ CompileEngine compile_engine_; + /*! 
\brief external functions */ + Array external_funcs_; }; class GraphRuntimeCodegenModule : public runtime::ModuleNode { @@ -628,7 +658,6 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } *rv = ret; }); - } else if (name == "get_param_by_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string key = args[0]; @@ -639,6 +668,35 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.lowered_funcs; }); + } else if (name == "get_external_funcs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.external_funcs; + }); + } else if (name == "get_external_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(!this->output_.external_funcs.empty()) << "No external function is annotated."; + // Invoke the external codegen to generate a external runtime module. + auto compiler = FunctionGetAttr(output_.external_funcs[0], attr::kExternal); + const tvm::ir::StringImm* code_gen = compiler.as(); + CHECK(code_gen) << "No external codegen is set"; + std::string ext_name = "relay.ext." + code_gen->value; + auto pf = tvm::runtime::Registry::Get(ext_name); + CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; + + // Invoke the 3rd party codegen to generate a library for the external + // functions. + relay::Module rly_mod = relay::ModuleNode::make({}, {}); + for (const auto& func : output_.external_funcs) { + auto ext_func_name = FunctionGetAttr(func, attr::kFuncName); + const tvm::ir::StringImm* func_name = ext_func_name.as(); + CHECK(func_name) << "No external function name is set for:\n" << AsText(func, false); + auto gv = GlobalVarNode::make(func_name->value); + rly_mod->Add(gv, func); + } + runtime::Module ext_mod = (*pf)(rly_mod); + CHECK(ext_mod.defined()) << "No external runtime is generated."; + *rv = ext_mod; + }); } else { return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 7f21defc9d12..c841f87dd836 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -37,21 +37,19 @@ namespace tvm { namespace relay { namespace vm { -static const char* kIsClosure = "IsClosure"; - inline std::string GenerateName(const Function& func) { size_t hash = StructuralHash()(func); return std::string("lifted_name") + std::to_string(hash); } bool IsClosure(const Function& func) { - NodeRef res = FunctionGetAttr(func, kIsClosure); + NodeRef res = FunctionGetAttr(func, attr::kClosure); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } Function MarkClosure(const Function& func) { - return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); + return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1)); } /* The goal of this class is to lift out any nested functions into top-level diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 47e735f20fc8..3673e9d4449b 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -157,13 +157,13 @@ FuncType FunctionNode::func_type_annotation() const { } bool FunctionNode::IsPrimitive() const { - NodeRef res = FunctionGetAttr(GetRef(this), "Primitive"); + NodeRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } Function FunctionNode::SetParams(const tvm::Map& parameters) const { - return 
FunctionSetAttr(GetRef(this), "__params__", parameters); + return FunctionSetAttr(GetRef(this), attr::kParams, parameters); } TVM_REGISTER_API("relay._expr.FunctionSetParams") @@ -173,7 +173,7 @@ TVM_REGISTER_API("relay._expr.FunctionSetParams") }); tvm::Map FunctionNode::GetParams() const { - auto node_ref = FunctionGetAttr(GetRef(this), "__params__"); + auto node_ref = FunctionGetAttr(GetRef(this), attr::kParams); return Downcast>(node_ref); } @@ -182,6 +182,12 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") return func->GetParams(); }); +bool FunctionNode::IsExternal() const { + NodeRef res = FunctionGetAttr(GetRef(this), attr::kExternal); + const ir::StringImm* pval = res.as(); + return pval; +} + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index f5674fa06adb..76525071006b 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -171,5 +171,51 @@ Mark a checkpoint for checkpointing memory optimization. return outputs; }); +RELAY_REGISTER_OP("annotation.subgraph_begin") +.describe(R"code(Begin region of a subgraph.)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_support_level(10) +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kOpaque) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype, const Target& target) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_API("relay.op.annotation._make.subgraph_begin") +.set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_node(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.subgraph_begin"); + return CallNode::make(op, {expr}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("annotation.subgraph_end") +.describe(R"code(End region of a subgraph.)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_support_level(10) +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kOpaque) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype, const Target& target) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_API("relay.op.annotation._make.subgraph_end") +.set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_node(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.subgraph_end"); + return CallNode::make(op, {expr}, Attrs(attrs), {}); +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/extern_op.cc b/src/relay/pass/extern_op.cc new file mode 100644 index 000000000000..e63b506f41be --- /dev/null +++ b/src/relay/pass/extern_op.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/pass/extern_op.cc + * \brief Wraps a call with subgraph_begin and subgraph_end to indicate that + * the op of this call node will use external compiler. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace extern_op { + +// A helper class to insert annotation boundaries for subgraphs. +class ExternOpWrapper : public ExprMutator { + public: + explicit ExternOpWrapper(const std::string& compiler) : compiler_(compiler) {} + + Expr VisitExpr_(const CallNode* cn) { + auto new_e = ExprMutator::VisitExpr_(cn); + + Call call = Downcast(new_e); + static auto fextern = Op::GetAttr("FTVMExternOp"); + Op op = Downcast(call->op); + CHECK(op.operator->()); + + if (fextern.count(op)) { + bool external = fextern[op](call->attrs, call->args, compiler_); + if (external) { + tvm::Array subgraph_begins; + for (const auto& it : call->args) { + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.subgraph_begin"); + CHECK(begin_op); + Expr begin = (*begin_op)(it, compiler_); + subgraph_begins.push_back(begin); + } + Expr update_call = CallNode::make(call->op, subgraph_begins, call->attrs); + const auto* end_op = + runtime::Registry::Get("relay.op.annotation._make.subgraph_end"); + CHECK(end_op); + Expr end = (*end_op)(update_call, compiler_); + return end; + } + } else { + LOG(WARNING) << op.operator->()->name << " in " << compiler_ << " is not registered"; + } + return new_e; + } + + private: + std::string compiler_; +}; + +Expr ExternOp(const Expr& expr, const std::string& compiler) { + return ExternOpWrapper(compiler).Mutate(expr); +} + +} // namespace extern_op + +namespace transform { + +Pass ExternOp(const std::string& compiler) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::extern_op::ExternOp(f, compiler)); + }; + auto func_pass = CreateFunctionPass(pass_func, 1, "ExternOpFunc", + {ir::StringImm::make("InferType")}); + return transform::Sequential({func_pass, InferType()}, "ExternOp"); +} + +TVM_REGISTER_API("relay._transform.ExternOp") +.set_body_typed(ExternOp); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 904d24657cad..9aba1aca9a5b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -239,7 +239,8 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Finally if the operator position is not a call node we will // need to call Update, as it may be an arbitrary expression. 
OpPatternKind op_pattern = kOpaque; - if (const OpNode* opnode = call->op.as()) { + const OpNode* opnode = call->op.as(); + if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) { op_pattern = static_cast(fpattern[GetRef(opnode)]); } else { this->Update(call->op, node, kOpaque); @@ -932,7 +933,7 @@ class FuseMutator : private ExprMutator { visitor(body); const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call)); + func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call)); return CallNode::make(func, ginfo.arguments, Attrs()); } diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc new file mode 100644 index 000000000000..f873c322c288 --- /dev/null +++ b/src/relay/pass/partition_graph.cc @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/relay/pass/partition_graph.cc + * + * \brief Partition an input function into multiple Functions according based + * on the inserted annotation nodes (i.e. begin and end). These nodes are used + * as boundaries to partition the Relay function into multiple regions that can + * be offloaded to different accelerators. + * + * Each of these paritioned functions, a.k.a subgraphs, will be viewed as + * external functions, and they will use external tools for codegen. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace graph_partitioning { + +/*! + * \brief The subgraph properties for partition. + */ +struct Subgraph { + /*! \brief The subgraph ID. */ + int id; + + /*! \brief The input arguments of this subgraph. */ + std::vector> args; + + /*! \brief Nodes in this subgraph. */ + std::unordered_set nodes; +}; + +/*! + * \brief The checker that verifies if a Relay program is annotated correctly + * for graph partitioning. 
+ */ +class AnnotationChecker : public ExprVisitor { + public: + bool Check() { + if (!this->found_start && !this->found_end) { + LOG(WARNING) << "No subgraph annotation found"; + } else if (!this->found_start) { + LOG(ERROR) << "Subgraph start annotation is missing"; + return false; + } else if (!this->found_end) { + LOG(ERROR) << "Subgraph end annotation is missing"; + return false; + } + return true; + } + + void VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + if (op_node == nullptr || call->attrs.as() == nullptr) { + return; + } else if (GetRef(op_node) == Op::Get("annotation.subgraph_begin")) { + this->found_start = true; + } else if (GetRef(op_node) == Op::Get("annotation.subgraph_end")) { + this->found_end = true; + } + } + + private: + bool found_start = false; + bool found_end = false; +}; + +/*! \brief This class partitions the graph labeled with begin and end annoations + * into function containing multiple subgraphs. Each subgraph is labeled as + * external. + * + * TODO(@zhiics) This following algorithm is not adequate to handle all cases, + * i.e. multiple `end` nodes. + */ +class Partitioner : public ExprMutator { + public: + Subgraph* GetSubgraph(const Expr node) { + for (auto candidate : this->subgraphs_) { + if (candidate->nodes.find(node) != candidate->nodes.end()) { + return candidate; + } + } + return nullptr; + } + + void MergeSubgraph(Subgraph* subgraph1, Subgraph* subgraph2) { + if (subgraph1 == subgraph2) { + return; + } + + // Merge subgraph 2 to subgraph 1 and erase subgraph 2. + subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); + for (auto arg : subgraph2->args) { + subgraph1->args.push_back(arg); + } + this->subgraphs_.erase(subgraph2); + } + + void AddToSubgraph(Subgraph* subgraph, const Expr expr) { + auto subgraph2 = GetSubgraph(expr); + if (subgraph2) { + MergeSubgraph(subgraph, subgraph2); + } else { + subgraph->nodes.insert(expr); + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + + if (op_node == nullptr || call->attrs.as() == nullptr) { + // Propogate subgraph to arguments + auto subgraph = GetSubgraph(GetRef(call)); + if (subgraph) { + for (auto arg : call->args) { + AddToSubgraph(subgraph, arg); + } + } + return ExprMutator::VisitExpr_(call); + } else if (GetRef(op_node) == Op::Get("annotation.subgraph_begin")) { + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + // Traverse the rest graph. + auto input_expr = VisitExpr(call->args[0]); + + // Replace the begin annotation with an external call input variable. + auto subgraph_attrs = call->attrs.as(); + auto var = VarNode::make(subgraph_attrs->compiler + "_input" + std::to_string(var_id_++), + input_expr->checked_type_); + + // Find the corresponding subgraph and add the argument. + auto subgraph = GetSubgraph(GetRef(call)); + if (!subgraph) { + throw Error(RELAY_ERROR("Cannot find the corresponding subgraph for start annotation:\n" + << AsText(GetRef(call), false))); + } + subgraph->args.push_back({var, input_expr}); + return std::move(var); + } else { + CHECK(GetRef(op_node) == Op::Get("annotation.subgraph_end")); + // The annotation node is inserted on edge so it must have only one argument. 
+ CHECK_EQ(call->args.size(), 1U); + + auto subgraph_attrs = call->attrs.as(); + + // Check if the argument is already belonged to an exist subgraph + auto subgraph = GetSubgraph(call->args[0]); + if (!subgraph) { + auto ret = this->subgraphs_.emplace(new Subgraph()); + subgraph = *ret.first; + subgraph->nodes.insert(call->args[0]); + subgraph->id = this->subgraph_id_++; + } + subgraph->nodes.insert(GetRef(call)); + + // Traverse towarding to subgraph inputs. + auto input = VisitExpr(call->args[0]); + Array params; + Array args; + + // The subgraph may be merged so we need to update it again. + subgraph = GetSubgraph(GetRef(call)); + for (auto pair : subgraph->args) { + params.push_back(pair.first); + args.push_back(pair.second); + } + + auto subgraph_func = + FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs()); + + Expr arg0 = call->args[0]; + std::string name = subgraph_attrs->compiler + "_" + std::to_string(subgraph->id); + subgraph_func = + FunctionSetAttr(subgraph_func, attr::kFuncName, tvm::ir::StringImm::make(name)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kExternal, + tvm::ir::StringImm::make(subgraph_attrs->compiler)); + return CallNode::make(subgraph_func, args); + } + } + + Expr VisitExpr_(const TupleNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + for (auto field : op->fields) { + AddToSubgraph(subgraph, field); + } + Array fields; + for (auto field : op->fields) { + fields.push_back(VisitExpr(field)); + } + return TupleNode::make(fields); + } + } + + Expr VisitExpr_(const TupleGetItemNode* g) final { + auto subgraph = GetSubgraph(GetRef(g)); + if (!subgraph) { + return ExprMutator::VisitExpr_(g); + } else { + AddToSubgraph(subgraph, g->tuple); + auto t = VisitExpr(g->tuple); + return TupleGetItemNode::make(t, g->index); + } + } + + Expr VisitExpr_(const FunctionNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + Array params; + for (auto param : op->params) { + AddToSubgraph(subgraph, param); + } + for (auto param : op->params) { + Var new_param = Downcast(VisitExpr(param)); + params.push_back(new_param); + } + auto body = VisitExpr(op->body); + return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs); + } + } + + Expr VisitExpr_(const LetNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->var); + AddToSubgraph(subgraph, op->value); + AddToSubgraph(subgraph, op->body); + Var var = Downcast(VisitExpr(op->var)); + auto value = VisitExpr(op->value); + auto body = VisitExpr(op->body); + + return LetNode::make(var, value, body); + } + } + + Expr VisitExpr_(const IfNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->cond); + AddToSubgraph(subgraph, op->true_branch); + AddToSubgraph(subgraph, op->false_branch); + auto guard = VisitExpr(op->cond); + auto true_b = VisitExpr(op->true_branch); + auto false_b = VisitExpr(op->false_branch); + return IfNode::make(guard, true_b, false_b); + } + } + + Expr VisitExpr_(const RefCreateNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + 
AddToSubgraph(subgraph, op->value); + Expr value = VisitExpr(op->value); + return RefCreateNode::make(value); + } + } + + Expr VisitExpr_(const RefReadNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = VisitExpr(op->ref); + return RefReadNode::make(ref); + } + } + + Expr VisitExpr_(const RefWriteNode* op) final { + auto subgraph = GetSubgraph(GetRef(op)); + if (!subgraph) { + return ExprMutator::VisitExpr_(op); + } else { + AddToSubgraph(subgraph, op->ref); + Expr ref = VisitExpr(op->ref); + Expr value = VisitExpr(op->value); + return RefWriteNode::make(ref, value); + } + } + + private: + int var_id_{0}; + int subgraph_id_{0}; + std::unordered_set subgraphs_; +}; + +/*! + * \brief TODO(@zhiics, @comaniac) Combine parallel subgraphs that belong to + * the same codegen backend. This reduces rounds trips between TVM and external + * backends. + * + * For example, sg1 and sg2 should be combined if they belong to the same + * codegen tool in the following case. + * + * op1 + * / \ + * sg1 sg2 + * + * | + * \|/ + * + * op1 + * | + * sg1_sg2 + * + * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two + * inputs that obtained from the tuple. + */ +class ParallelSubgraphCombiner : public ExprMutator { + using ParallelGroup = std::vector>; + + public: + Expr Combine(const Expr& expr) { + ParallelGroup groups = GroupFinder().FindGroups(expr); + return expr; + } + + private: + class GroupFinder : public ExprVisitor { + public: + ParallelGroup FindGroups(const Expr& expr) { + this->VisitExpr(expr); + return groups_; + } + + void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); } + + private: + ParallelGroup groups_; + }; +}; + +Expr PartitionGraph(const Expr& expr) { + Partitioner part; + return part.Mutate(expr); +} + +} // namespace graph_partitioning + +namespace transform { + +Pass PartitionGraph() { + runtime::TypedPackedFunc part_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(graph_partitioning::PartitionGraph(f)); + }; + auto partitioned = CreateFunctionPass(part_func, 1, "PartitionGraph", {}); + return Sequential({partitioned, InferType()}); +} + +TVM_REGISTER_API("relay._transform.PartitionGraph") +.set_body_typed(transform::PartitionGraph); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index b025d3787f9e..fd834a679a93 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -329,12 +329,12 @@ Module FunctionPassNode::operator()(const Module& mod, return updated_mod; } -// TODO(zhiics) Create an enum attribute for FunctionNode -// enum Attribute {kPrimitive, kSkipOptimization} bool FunctionPassNode::SkipFunction(const Function& func) const { - NodeRef res = FunctionGetAttr(func, "SkipOptimization"); - const ir::IntImm* pval = res.as(); - return pval && pval->value != 0; + NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); + NodeRef ext = FunctionGetAttr(func, attr::kExternal); + const ir::IntImm* pval = skip_opt.as(); + const ir::StringImm* sval = ext.as(); + return (pval && pval->value != 0) || (sval && sval->value.size() > 0); } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc new file mode 100644 index 000000000000..1412a9bc0dae --- /dev/null +++ 
b/src/runtime/contrib/dnnl/dnnl.cc @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl.cc + * \brief TVM compatible wrappers for dnnl kernels. + */ + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace dnnl; + +typedef struct { + void** data; +} DnnlPackedArgs; + +// Read from memory, write to handle +inline void read_from_dnnl_memory(void* handle, const memory& mem) { + size_t bytes = mem.get_desc().get_size(); + + uint8_t* src = static_cast(mem.get_data_handle()); + std::copy(src, src + bytes, reinterpret_cast(handle)); +} + +extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, + int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, + int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, + int p_Sh_, int p_Sw_) { + using tag = memory::format_tag; + using dt = memory::data_type; + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims conv2d_src_tz = {p_N_, p_C_, p_H_, p_W_}; + memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; + if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; + memory::dims conv2d_bias_tz = {p_O_}; + memory::dims conv2d_dst_tz = {p_N_, p_O_, + (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_, + (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; + memory::dims conv2d_strides = {p_Sh_, p_Sw_}; + memory::dims conv2d_padding = {p_Ph_, p_Pw_}; + + std::vector conv2d_bias(p_O_, 0); + + auto user_src_memory = + memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); + auto user_weights_memory = memory( + {{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? 
tag::goihw : tag::oihw}, eng, + weights); + auto conv2d_user_bias_memory = + memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data()); + + auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); + auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); + auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any); + auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); + + auto conv2d_desc = convolution_forward::desc( + prop_kind::forward_inference, algorithm::convolution_direct, + conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, + conv2d_strides, conv2d_padding, conv2d_padding); + auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng); + + auto conv2d_src_memory = user_src_memory; + auto conv2d_weights_memory = user_weights_memory; + auto conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng); + + auto conv = convolution_forward(conv2d_prim_desc); + conv.execute(s, {{DNNL_ARG_SRC, conv2d_src_memory}, + {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, + {DNNL_ARG_BIAS, conv2d_user_bias_memory}, + {DNNL_ARG_DST, conv2d_dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, conv2d_dst_memory); +} + +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, + int p_I_, int p_O_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_B_, p_I_}; + memory::dims weight_tz = {p_O_, p_I_}; + memory::dims bias_tz = {p_O_}; + memory::dims dst_tz = {p_B_, p_O_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nc}; + auto weight_md = memory::desc({{weight_tz}, dt::f32, tag::nc}); + auto bias_md = memory::desc({{bias_tz}, dt::f32, tag::x}); + auto dst_md = memory::desc({{dst_tz}, dt::f32, tag::nc}); + + std::vector bias(p_O_, 0); + auto data_memory = memory(data_md, eng, data); + auto weight_memory = memory(weight_md, eng, weight); + auto bias_memory = memory(bias_md, eng, bias.data()); + auto dst_memory = memory(dst_md, eng); + + auto dense_desc = inner_product_forward::desc( + prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); + auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng); + assert(dst_md == dense_prim_desc.dst_desc()); + + auto dense = inner_product_forward(dense_prim_desc); + dense.execute(s, {{DNNL_ARG_SRC, data_memory}, + {DNNL_ARG_WEIGHTS, weight_memory}, + {DNNL_ARG_BIAS, bias_memory}, + {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, + int p_W_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + + auto data_memory = memory(data_md, eng, data); + auto dst_memory = memory(data_md, eng); + + auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, data_md, 0); + auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng); + assert(data_md == relu_prim_desc.dst_desc()); + + auto relu = eltwise_forward(relu_prim_desc); + relu.execute(s, {{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, 
float* out, int p_N_, int p_C_, + int p_H_, int p_W_, int p_E_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + + auto data_memory = memory(data_md, eng, data); + auto dst_memory = memory(data_md, eng); + + auto bn_desc = batch_normalization_forward::desc( + prop_kind::forward_inference, data_md, p_E_, + normalization_flags::use_global_stats | + normalization_flags::use_scale_shift); + auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng); + assert(data_md == bn_prim_desc.dst_desc()); + + float* weight = reinterpret_cast(malloc(sizeof(float) * 2 * p_C_)); + memcpy(weight, gamma, sizeof(float) * p_C_); + memcpy(weight + p_C_, beta, sizeof(float) * p_C_); + + auto weight_memory = memory(bn_prim_desc.weights_desc(), eng, weight); + auto mean_memory = memory(bn_prim_desc.mean_desc(), eng, mean); + auto variance_memory = memory(bn_prim_desc.variance_desc(), eng, variance); + + auto bn = batch_normalization_forward(bn_prim_desc); + bn.execute(s, {{DNNL_ARG_SRC, data_memory}, + {DNNL_ARG_DST, dst_memory}, + {DNNL_ARG_SCALE_SHIFT, weight_memory}, + {DNNL_ARG_MEAN, mean_memory}, + {DNNL_ARG_VARIANCE, variance_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); + free(weight); +} + +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, + int p_C_, int p_H_, int p_W_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + + auto data_memory = memory(data_md, eng, data); + auto weight_memory = memory(weight_md, eng, weight); + auto dst_memory = memory(dst_md, eng); + + auto add_desc = + binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + auto add_prim_desc = binary::primitive_desc(add_desc, eng); + assert(dst_md == add_prim_desc.dst_desc()); + + auto add = binary(add_prim_desc); + add.execute(s, {{DNNL_ARG_SRC_0, data_memory}, + {DNNL_ARG_SRC_1, weight_memory}, + {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/module_util.cc b/src/runtime/module_util.cc index 445bfd343653..c20e52414cc3 100644 --- a/src/runtime/module_util.cc +++ b/src/runtime/module_util.cc @@ -49,6 +49,7 @@ void ImportModuleBlob(const char* mblob, std::vector* mlist) { for (uint64_t i = 0; i < size; ++i) { std::string tkey; CHECK(stream->Read(&tkey)); + if (tkey == "c") continue; std::string fkey = "module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); CHECK(f != nullptr) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py new file mode 100644 index 000000000000..13ed3580e09b --- /dev/null +++ b/tests/python/relay/test_pass_partition_graph.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for graph partitioning.""" +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform +from tvm import relay +from tvm.contrib import util +from tvm.relay.annotation import subgraph_begin, subgraph_end +from tvm.relay.expr_functor import ExprMutator + + +class GCCAnnotator(ExprMutator): + """ + A simple annotator that creates the following subgraph: + | + -- begin -- + | + add + | + subtract + | + multiply + | + -- end -- + | + """ + + def __init__(self): + super(GCCAnnotator, self).__init__() + self.in_subgraph = 0 + + def visit_call(self, call): + if call.op.name == "add": # Annotate begin at args + if self.in_subgraph == 1: + lhs = subgraph_begin(super().visit(call.args[0]), "gcc") + rhs = subgraph_begin(super().visit(call.args[1]), "gcc") + op = relay.add(lhs, rhs) + self.in_subgraph = 2 + return op + elif call.op.name == "subtract": + if self.in_subgraph == 1: + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = subgraph_begin(lhs, "gcc") + if isinstance(rhs, relay.expr.Var): + rhs = subgraph_begin(rhs, "gcc") + return relay.subtract(lhs, rhs) + elif call.op.name == "multiply": # Annotate end at output + self.in_subgraph = 1 + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = subgraph_begin(lhs, "gcc") + if isinstance(rhs, relay.expr.Var): + rhs = subgraph_begin(rhs, "gcc") + op = relay.multiply(lhs, rhs) + if self.in_subgraph == 2: + op = subgraph_end(op, "gcc") + self.in_subgraph = 0 + return op + return super().visit_call(call) + + +class WholeGraphAnnotator(ExprMutator): + """ + An annotator that creates a subgraph for an entire graph. + """ + + def __init__(self, compiler): + super(WholeGraphAnnotator, self).__init__() + self.compiler = compiler + self.last_call = True + + def visit_call(self, call): + curr_last = self.last_call + self.last_call = False + + params = [] + for arg in call.args: + param = super().visit(arg) + if isinstance(param, relay.expr.Var): + param = subgraph_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + if curr_last: + new_call = subgraph_end(new_call, self.compiler) + return new_call + + +class MobileNetAnnotator(ExprMutator): + """ + Annotate mobilenet until global_avg_pool. 
+ """ + + def __init__(self, compiler): + super(MobileNetAnnotator, self).__init__() + self.compiler = compiler + self.subgraph_open = False + + def visit_call(self, call): + + if call.op.name == 'nn.global_avg_pool2d': + self.subgraph_open = True + subgraph_open = self.subgraph_open + + params = [] + for arg in call.args: + param = super().visit(arg) + if call.op.name == 'nn.global_avg_pool2d': + param = subgraph_end(param, self.compiler) + if subgraph_open and isinstance(param, relay.expr.Var): + param = subgraph_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + return new_call + +def check_result(mod, map_inputs, out_shape, result, tol=1e-7): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, _ = relay.build(mod, "llvm") + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11"] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = tvm.module.load(lib_path) + + ctx = tvm.cpu() + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + +def test_multi_node_subgraph(): + x = relay.var('x', shape=(10, 10)) + w0 = relay.var('w0', shape=(10, 10)) + w1 = relay.var('w1', shape=(10, 10)) + w2 = relay.var('w2', shape=(10, 10)) + w3 = relay.var('w3', shape=(10, 10)) + w4 = relay.var('w4', shape=(10, 10)) + w5 = relay.var('w5', shape=(10, 10)) + w6 = relay.var('w6', shape=(10, 10)) + w7 = relay.var('w7', shape=(10, 10)) + + # Subgraph on GCC + # FIXME: We generate two subgraphs for this case but they should be merged to one + # due to the common input (x). 
+ z0 = relay.add(x, w0) + p0 = relay.subtract(z0, w1) + q0 = relay.multiply(p0, w2) + + z1 = relay.add(x, w3) + p1 = relay.subtract(z1, w4) + q1 = relay.multiply(p1, w5) + + # Other parts on TVM + z2 = relay.add(x, w6) + q2 = relay.subtract(z2, w7) + + r = relay.concatenate((q0, q1, q2), axis=0) + f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = relay.Module() + ann = GCCAnnotator() + mod["main"] = ann.visit(f) + mod = relay.transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) + + x_data = np.random.rand(10, 10).astype('float32') + w_data = [] + for _ in range(8): + w_data.append(np.random.rand(10, 10).astype('float32')) + + map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} + map_inputs["x"] = x_data + check_result( + mod, map_inputs, (30, 10), + np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2], + ((x_data + w_data[3]) - w_data[4]) * w_data[5], + x_data + w_data[6] - w_data[7]), + axis=0)) + + +def test_extern_gcc_single_op(): + x = relay.var('x', shape=(8, 8)) + y = relay.var('y', shape=(8, 8)) + z = x + y + f = relay.Function([x, y], z) + x_data = np.random.rand(8, 8).astype('float32') + y_data = np.random.rand(8, 8).astype('float32') + mod = relay.Module() + mod["main"] = f + mod = relay.build_extern(mod, "gcc") + + check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) + + +def test_extern_gcc(): + x = relay.var('x', shape=(2, 2)) + y = relay.var('y', shape=(2, 2)) + z = x + x + p = y * y + f = relay.Function([x, y], p - z) + x_data = np.random.rand(2, 2).astype('float32') + y_data = np.random.rand(2, 2).astype('float32') + mod = relay.Module() + mod["main"] = f + mod = relay.build_extern(mod, "gcc") + + check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) + + +@pytest.mark.skip(reason="Only for DEMO purpose") +def test_extern_dnnl(): + dtype = 'float32' + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + data = relay.var('data', shape=(ishape), dtype=dtype) + weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data, weight1], out) + + mod = relay.Module() + mod['main'] = WholeGraphAnnotator('dnnl').visit(f) + mod = relay.transform.PartitionGraph()(mod) + + ref_mod = relay.Module() + ref_mod['main'] = f + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu()) + ref_res = ref_ex.evaluate()(i_data, w1_data) + check_result(mod, {"data": i_data, "weight1": w1_data}, + (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) + + +@pytest.mark.skip(reason="Only for DEMO purpose") +def test_extern_dnnl_mobilenet(): + # FIXME: This test is only for demo purpose and supposed to be removed. 
+ dtype = 'float32' + ishape = (1, 3, 224, 224) + mod, params = relay.testing.mobilenet.get_workload( + batch_size=1, dtype='float32') + + mod = relay.build_extern(mod, "dnnl") + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(0)) + res = ex.evaluate()(i_data, **params) + + # FIXME: When subgraph has only one op, Relay executor will use the cache value instead + # of re-computing, so the following checking logic does not work. + #ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype='float32') + #ref_ex = relay.create_executor("debug", mod=ref_mod, ctx=tvm.cpu(0)) + #ref_res = ref_ex.evaluate()(i_data, **params) + + #tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + + +if __name__ == "__main__": + test_multi_node_subgraph() + test_extern_gcc_single_op() + test_extern_gcc() + # test_extern_dnnl() + # test_extern_dnnl_mobilenet() diff --git a/tutorials/dev/custom_relay_backend.py b/tutorials/dev/custom_relay_backend.py new file mode 100644 index 000000000000..ba8ce514d27d --- /dev/null +++ b/tutorials/dev/custom_relay_backend.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" + +.. _tutorial-custom-relay-backend: + +Bring Your Own Codegen To TVM +============================= +**Author**: `Zhi Chen `_, `Cody Hao Yu `_ + +As the number of hardware devices targeted by deep learning workloads keeps increasing, the required knowledge +for users to achieve high performance on various devices keeps increasing as well. To free data +scientists from worrying about the performance when developing a new model, hardware vendors either +provide libraries such as MKLDNN or cuDNN with many commonly used deep learning operators, +or provide frameworks such as TensorRT to let users describe their models in a certain way to +achieve high performance. However, users have to learn a new programming interface when they +attempt to work on a new library or device. As a result, the demand of a unified programming +interface becomes more and more important to 1) let all users and hardware vendors stand on the +same page, and 2) provide a feasible solution to allow a specialized hardware or library to only +support widely used operators with extremely high performance, but fallback unsupported operators +to general devices like CPU/GPU. + +In this tutorial, we demonstrate how a hardware vendor can easily implement +a Relay backend to support a specialized hardware device/library. It mainly +takes three steps: 1) define whether an operator is supported under a given +template, 2) specify how to compile and serialize the supported operators so +that it can ingest TVM specific data format, e.g. 
NDArray, and 3) specify how
+to execute the compiled operators on a certain device. We will demonstrate how
+to add a new backend that uses open source compilers (e.g. GCC, LLVM, etc.) or any
+proprietary compilers to execute a subgraph of a model without exposing the
+IP of the customer's codegen tool chain. Note that you will need to add the
+specialized Relay backend to the TVM codebase and rebuild TVM to enable it.
+
+"""
+
+######################################################################
+# Define The Supported Operators
+# ------------------------------
+# The first step is to define which operators are supported by your backend.
+# A template is provided to ease vendors' effort to add the supported
+# operators.
+#
+# For example, we create a new Python file at python/relay/backend/op/contrib/gcc/extern_op.py,
+# and implement a set of boolean functions with corresponding operator names. A boolean
+# function should return `True` if we allow the operator to be executed by the given
+# backend; `False` otherwise.
+
+from __future__ import absolute_import
+
+def conv2d(attrs, args):
+    """Check if the external codegen should be used.
+    """
+    return False
+
+def subtract(attrs, args):
+    """Check if the external codegen should be used.
+    """
+    return True
+
+def add(attrs, args):
+    """Check if the external codegen should be used.
+    """
+    return True
+
+def multiply(attrs, args):
+    """Check if the external codegen should be used.
+    """
+    return True
+
+######################################################################
+# Note that since we include `attrs` and `args` in the function signature, we
+# can define more complicated rules. For example, we can support only conv2d
+# with the float32 data type or with a 1x1 kernel size. In addition, vendors
+# can decide whether a given operator is supported by checking the fields in
+# `attrs`. In an even more complicated but interesting scenario, developers can
+# also check the sequence of operators by iterating over `args`. However, this
+# is only unidirectional, as only the inputs are visible.
+#
+# After annotating whether each operator can be executed on the given backend,
+# users can directly invoke the partitioning pass to separate the graph into
+# multiple segments. The C++ backend implements a partitioning pass to fulfill
+# the task and creates subgraphs/sub-functions with the *External* attribute,
+# indicating that these functions will be handled by the external codegen tool.
+# Therefore, Relay passes should skip optimizations on them.
+
+######################################################################
+# Customize Subgraph Annotations
+# ------------------------------
+# In addition to specifying a set of rules for supported operators, we can also implement
+# a Relay IR mutator to find the supported subgraphs, which may include multiple operators,
+# for the target backend. Here we implement an annotator that offloads an entire Relay
+# graph. Specifically, we are going to perform two tasks:
+# - insert `subgraph_begin` after all input variables
+# - insert `subgraph_end` before the primary output.
For example, given a Relay graph as follows: +# input_a +# | +# add --- input_b +# | +# subtract --- input_c +# | +# multiply --- input_d +# | +# out +# +# Our goal is to mutate the graph to the following: +# +# input_a +# | +# subgraph_begin +# | +# add --- subgraph_begin --- input_b +# | +# subtract --- subgraph_begin --- input_c +# | +# multiply --- subgraph_begin --- input_d +# | +# subgraph_end +# | +# out +# +# The implementation is shown as follows. As can be seen, the annotator is derived from +# `ExprMutator` that traverses a Relay graph and allows us to mutate it. We know that all ops +# are `call` nodes in Relay graph, so we override the call node mutator `visit_call` in +# `ExprMutator` and insert annotations. + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +from tvm.relay.annotation import subgraph_begin, subgraph_end + +class WholeGraphAnnotator(ExprMutator): + """ + An annotator that creates a subgraph for an entire graph. + """ + def __init__(self, compiler): + super(WholeGraphAnnotator, self).__init__() + self.compiler = compiler + self.last_call = True + + def visit_call(self, call): + curr_last = self.last_call + self.last_call = False + + params = [] + for arg in call.args: + param = super().visit(arg) + if isinstance(param, relay.expr.Var): + param = subgraph_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + if curr_last: + new_call = subgraph_end(new_call, self.compiler) + return new_call + +###################################################################### +# Finally, we apply the annotator to our workload. Let's first build a Relay +# function: + +input_a = relay.var('a', shape=(10, 10)) +input_b = relay.var('b', shape=(10, 10)) +input_c = relay.var('c', shape=(10, 10)) +input_d = relay.var('d', shape=(10, 10)) + +temp_1 = relay.add(input_a, input_b) +temp_2 = relay.subtract(temp_1, input_c) +out = relay.multiply(temp_2, input_d) +func = relay.Function([input_a, input_b, input_c, input_d], out) + +###################################################################### +# The above Relay function results in the following IR: + +print(func) + +###################################################################### +# Then we apply the annotator to the IR and partition the graph: + +mod = relay.Module() +mod['main'] = WholeGraphAnnotator('gcc').visit(func) +mod = relay.transform.PartitionGraph()(mod) + +###################################################################### +# Accordingly, the IR is transformed to the following. We can see that the +# entire Relay graph is enclosed in a function with `External="gcc"` attribute. +# It indicates that this function will be offloaded to an external backend +# during the runtime. + +print(mod['main']) + +###################################################################### +# Implement The Codegen +# --------------------- +# The second and the third step are implemented in C++ instead of Python. +# Specifically, we create src/relay/backend/contrib/gcc/codegen.cc and +# implement the codegen and runtime dispatcher here. For the codegen, +# we need to implement two functions: `CompileExternalLib()` and `Build()`. +# `Build()` accepts a Relay module or subgraph and generate the library or device +# code accordingly. In the GCC example, we implement a Relay IR visitor to generate +# C++ code for subgraphs. 
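+#
+# As a rough sketch only (the class and member names below are illustrative
+# placeholders rather than the actual API of this patch; see
+# src/relay/backend/contrib/gcc/codegen.cc for the real implementation), such
+# a visitor has roughly the following shape:
+#
+#   // Walk a partitioned subgraph and emit one C statement per supported call.
+#   class GccCodegen : public ExprVisitor {
+#    public:
+#     void VisitExpr_(const CallNode* call) final {
+#       // For a supported op such as "add", emit an element-wise loop that
+#       // computes out[i] = lhs[i] + rhs[i], then recurse into the arguments.
+#       ExprVisitor::VisitExpr_(call);
+#     }
+#     std::string GetCode() const { return code_stream_.str(); }
+#    private:
+#     std::ostringstream code_stream_;
+#   };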
+
+######################################################################
+# In addition, `CompileExternalLib()` specifies how to generate and
+# serialize an external library for the generated device code (C++ in this
+# example). The generated library/executable binary can either be materialized
+# to disk and loaded back at runtime, or stored in memory directly for
+# later use via whatever user-defined mechanism. In the GCC case, the
+# standard system calls, e.g. dlopen/dlsym or LoadLibrary/GetProcAddress, are
+# used for Linux and Windows, respectively.
+
+######################################################################
+# Implement The Runtime Dispatcher
+# --------------------------------
+# The last step is invoking the generated external library at runtime.
+# We create a runtime module `GccModule` derived from `ExternModuleBase`
+# in src/runtime/contrib/gcc/gcc.h for the Relay runtime to dispatch the
+# generated library/executable. Then, we implement the dispatcher in
+# src/runtime/contrib/gcc/gcc.cc. Note that although the `GccModule` constructor
+# accepts the path of the generated library/executable for runtime initialization,
+# it can be customized by each external backend to accept any type of required
+# artifact.
+
+######################################################################
+# In addition, we implement a TVM runtime `Module` compatible
+# `GetFunction()`. The function takes a subgraph name and returns
+# a `PackedFunc` that executes the subgraph with runtime input data. Note that
+# the runtime data in TVM is provided in the TVM `NDArray` format. It is the
+# vendor's responsibility to deserialize it into the format that their library
+# can ingest. For example, we unpack it and extract the raw pointers for
+# MKL-DNN. If the subgraph is compiled by `Build` in advance and the shared
+# library or executable binary is available, then we can invoke it here.
+#
+# `GetFunction()` will be invoked by the Relay runtime, including the interpreter,
+# graph runtime, and VM, meaning that this one implementation works for all
+# kinds of Relay runtimes.
+
+######################################################################
+# Add Codegen to TVM Building Process
+# -----------------------------------
+# Finally, we include the implemented codegen in the cmake config so that
+# it will be built along with TVM. In cmake/modules/contrib/Extern.cmake:
+#
+# list(FIND USE_EXTERN "gcc" _gcc_idx)
+# if(_gcc_idx GREATER -1)
+#   file(GLOB GCC_RELAY_CONTRIB_SRC src/relay/backend/contrib/gcc/codegen.cc)
+#   list(APPEND COMPILER_SRCS ${GCC_RELAY_CONTRIB_SRC})
+#   file(GLOB GCC_CONTRIB_SRC src/runtime/contrib/gcc/*.cc)
+#   list(APPEND RUNTIME_SRCS ${GCC_CONTRIB_SRC})
+#   message(STATUS "Use extern library: GCC")
+# endif()
+
+
+######################################################################
+# We can now build TVM with the external GCC backend and test its correctness:
+# 1. cd build
+# 2. set(USE_EXTERN gcc) in config.cmake
+# 3. cmake ..; make -j
+#
+# .. note::
+#     The complete GCC backend implementation is in the TVM codebase
+#     so we can directly use it in this tutorial for demonstration.
+#
+#     Multiple external backends can be enabled simultaneously using ";".
+# For example: set(USE_EXTERN gcc;dnnl) + +import numpy as np + +a_data = np.random.rand(10, 10).astype('float32') +b_data = np.random.rand(10, 10).astype('float32') +c_data = np.random.rand(10, 10).astype('float32') +d_data = np.random.rand(10, 10).astype('float32') + +ex = relay.create_executor('debug', mod=mod, ctx=tvm.cpu(0)) +result = ex.evaluate()(a_data, b_data, c_data, d_data) +tvm.testing.assert_allclose(result.asnumpy(), (a_data + b_data - c_data) * d_data) + +print('Results are correct!')
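+
+######################################################################
+# Finally, note that writing a custom annotator is optional. The unit tests
+# added in this patch (tests/python/relay/test_pass_partition_graph.py) drive
+# the same flow purely from the operator-level rules, e.g.:
+#
+#   mod = relay.Module()
+#   mod['main'] = func
+#   mod = relay.build_extern(mod, "gcc")
+#
+# which, based on those tests, appears to annotate the graph according to the
+# registered rules and then partition it, so a hand-written annotator such as
+# `WholeGraphAnnotator` is only needed when the default rules are not
+# expressive enough.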