diff --git a/CMakeLists.txt b/CMakeLists.txt
index 769a35318d9d..12c8e88044be 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -74,6 +74,7 @@ tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
 tvm_option(USE_TFLITE "Build with tflite support" OFF)
 tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
 tvm_option(USE_COREML "Build with coreml support" OFF)
+tvm_option(USE_BNNS "Build with BNNS support" OFF)
 tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
@@ -348,6 +349,7 @@ include(cmake/modules/contrib/HybridDump.cmake)
 include(cmake/modules/contrib/TFLite.cmake)
 include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
 include(cmake/modules/contrib/CoreML.cmake)
+include(cmake/modules/contrib/BNNS.cmake)
 include(cmake/modules/contrib/ONNX.cmake)
 include(cmake/modules/contrib/ArmComputeLib.cmake)
 include(cmake/modules/contrib/TensorRT.cmake)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 872feb918a4f..5faeb9325dba 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -269,3 +269,6 @@ set(USE_HEXAGON_SDK /path/to/sdk)
 # Whether to use ONNX codegen
 set(USE_TARGET_ONNX OFF)
+
+# Whether to enable the BNNS runtime
+set(USE_BNNS OFF)
diff --git a/cmake/modules/contrib/BNNS.cmake b/cmake/modules/contrib/BNNS.cmake
new file mode 100644
index 000000000000..e14aa2857ebc
--- /dev/null
+++ b/cmake/modules/contrib/BNNS.cmake
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(USE_BNNS STREQUAL "ON")
+  add_definitions(-DUSE_JSON_RUNTIME=1)
+  file(GLOB BNNS_RELAY_CONTRIB_SRC src/relay/backend/contrib/bnns/*.cc)
+  list(APPEND COMPILER_SRCS ${BNNS_RELAY_CONTRIB_SRC})
+  list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})
+
+  list(APPEND TVM_RUNTIME_LINKER_LIBS "-framework Accelerate")
+
+  file(GLOB BNNS_CONTRIB_SRC src/runtime/contrib/bnns/*.cc)
+  list(APPEND RUNTIME_SRCS ${BNNS_CONTRIB_SRC})
+  message(STATUS "Build with BNNS JSON runtime enabled")
+endif()
+
diff --git a/docs/deploy/bnns.rst b/docs/deploy/bnns.rst
new file mode 100644
index 000000000000..cb15a4f3bd54
--- /dev/null
+++ b/docs/deploy/bnns.rst
@@ -0,0 +1,183 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements.  See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership.  The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License.  You may obtain a copy of the License at
+
+..   http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied.  See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+Relay BNNS Integration
+======================
+**Author**: Egor Churaev
+
+Introduction
+------------
+
+The Apple BNNS library is a collection of functions that can be used to construct neural networks
+for inference (and training). It is supported on macOS, iOS, tvOS, and watchOS. BNNS provides
+primitives executed on all CPUs supported by those platforms, optimized for high performance
+and low energy consumption. This integration offloads as many operators as possible from Relay to BNNS.
+
+The BNNS runtime is a part of the platform API and is available on all modern Apple operating systems.
+An application using BNNS does not depend on any additional external libraries.
+
+BNNS functions use private Apple hardware capabilities which are not otherwise exposed by Apple. An
+example of such a capability is the AMX CPU extension.
+
+This guide demonstrates how to build TVM with BNNS codegen and runtime enabled. It also gives example
+code to compile and run models using the BNNS runtime. Finally, we document the supported operators.
+
+Building TVM with BNNS support
+------------------------------
+
+To turn on TVM BNNS codegen and the TVM BNNS runtime you need to enable a single flag, USE_BNNS:
+
+* USE_BNNS=ON/OFF - Enables compiling a network with offloading of subgraphs to BNNS primitives
+  and links the tvm library to the BNNS runtime module.
+
+Enabling this flag makes the build search for the default Accelerate framework in the current target SDK.
+The minimum required SDK versions are macOS 11.0, iOS 14.0, tvOS 14.0 and watchOS 7.0.
+
+Example setting in the config.cmake file:
+
+.. code:: cmake
+
+   set(USE_BNNS ON)
+
+BNNS partitioning of Relay graph
+--------------------------------
+
+Operations to be offloaded to BNNS must be annotated before passing the module to compilation.
+All ops annotated by `partition_for_bnns` will be offloaded for BNNS execution. The rest of the ops
+will go through the regular LLVM compilation and code generation.
+
+Important note: BNNS supports primitives only with constant weights. To satisfy this requirement we
+have to map constants to the related tensor abstractions in the Relay representation. To freeze
+tensors and operate with them as constants you may need to call the ONNX importer with the special
+flag "freeze_params=True" or perform the binding manually. In general, Relay importers don't do
+that by default. For your convenience, "partition_for_bnns" can do this for you if the params
+dictionary is passed as an argument.
+
+.. code:: python
+
+    from tvm.relay.op.contrib.bnns import partition_for_bnns
+    model = partition_for_bnns(model, params=params)
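+
+The snippet below is a minimal sketch of the ONNX import path mentioned above. The model file name
+and the input name/shape are placeholders for illustration; substitute your own model.
+
+.. code:: python
+
+    import onnx
+    from tvm import relay
+
+    # "model.onnx" and the input name "data" are hypothetical examples.
+    onnx_model = onnx.load("model.onnx")
+    # freeze_params=True turns the model weights into Relay constants,
+    # which the BNNS primitives require.
+    mod, params = relay.frontend.from_onnx(
+        onnx_model, shape={"data": (1, 3, 224, 224)}, freeze_params=True
+    )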
+
+Input data layout for operations to be offloaded to BNNS execution
+------------------------------------------------------------------
+
+BNNS kernels support only the planar format of input data. The partitioner requires the NCHW input
+layout for conv2d inputs.
+
+To use the BNNS integration for models with an interleaved input layout, the layout should be
+converted before passing the module to `partition_for_bnns`. The layout conversion happens only for
+explicitly enumerated types of ops. Depending on the topology, this may leave regular data reorders
+around conv2d between the interleaved and planar layouts. They show up as performance penalties and
+affect execution time. It is recommended to analyze the whole topology and extend the list below to
+convert all intermediate tensors to the NCHW data layout.
+
+Example of input layout change:
+
+.. code:: python
+
+    # For models with NHWC input layout
+    with tvm.transform.PassContext(opt_level=3):
+        mod = relay.transform.InferType()(mod)
+        mod = relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"],
+                                             "nn.bias_add": ["NCHW", "default"],
+                                             "nn.relu": ["NCHW"]})(mod)
+
+
+Example: Build and Deploy Mobilenet v2 1.0 with BNNS
+----------------------------------------------------
+
+Create a Relay graph from an MXNet Mobilenet v2 1.0 model.
+
+.. code:: python
+
+    import tvm
+    from tvm import relay
+    import mxnet
+    from mxnet.gluon.model_zoo.vision import get_model
+
+    dtype = "float32"
+    input_shape = (1, 3, 224, 224)
+    block = get_model('mobilenetv2_1.0', pretrained=True)
+    module, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+
+Mark up the parts of the graph to be offloaded to BNNS primitives. All ops which are supported by
+the BNNS integration will be handled by BNNS invocations; the rest of the ops will go through the
+regular TVM LLVM compilation and code generation.
+
+After that you need to compile the new module with a target corresponding to the required Apple
+platform.
+
+.. code:: python
+
+    from tvm.relay.op.contrib.bnns import partition_for_bnns
+
+    # target for macOS Big Sur 11.1:
+    target = "llvm -mtriple=x86_64-apple-darwin20.2.0"
+
+    model = partition_for_bnns(model, params=params)  # to mark up operations to be offloaded to BNNS
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(model, target=target, target_host=target, params=params)
+
+Export the module.
+
+.. code:: python
+
+    lib.export_library('compiled.dylib')
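+
+Before deploying, you can optionally check how much of the graph was actually offloaded.
+A small sketch; it relies only on the fact that partitioned subgraphs become separate
+functions whose names carry the "bnns" compiler tag:
+
+.. code:: python
+
+    bnns_subgraphs = [gv for gv in model.get_global_vars() if "bnns" in gv.name_hint]
+    print("Subgraphs offloaded to BNNS:", len(bnns_subgraphs))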
+
+Load the module and run inference on the target machine with TVM built with ``USE_BNNS`` enabled.
+
+.. code:: python
+
+    import tvm
+    import numpy as np
+    from tvm.contrib import graph_runtime
+
+    ctx = tvm.cpu(0)
+    loaded_lib = tvm.runtime.load_module('compiled.dylib')
+    gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib['default'](ctx))
+
+    dtype = "float32"
+    input_shape = (1, 3, 224, 224)
+    input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
+    gen_module.run(data=input_data)
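+
+To get a quick latency estimate, you can reuse the standard ``time_evaluator`` helper on the
+loaded module. A minimal sketch continuing the snippet above:
+
+.. code:: python
+
+    ftimer = gen_module.module.time_evaluator("run", ctx, number=10, repeat=3)
+    print("Mean inference time: %.2f ms" % (ftimer().mean * 1000))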
+
+
+Operator support
+----------------
+
++------------------------+-------------------------------------------------------------------------------+
+| Relay Node             | Remarks                                                                       |
++========================+===============================================================================+
+| nn.conv2d              |                                                                               |
++------------------------+-------------------------------------------------------------------------------+
+| nn.batch_norm          | Supported by BNNS integration only in nn.conv2d-batch_norm pattern            |
++------------------------+-------------------------------------------------------------------------------+
+| nn.dense               |                                                                               |
++------------------------+-------------------------------------------------------------------------------+
+| nn.batch_matmul        |                                                                               |
++------------------------+-------------------------------------------------------------------------------+
+| nn.bias_add            | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense    |
+|                        | fusion                                                                        |
++------------------------+-------------------------------------------------------------------------------+
+| add                    | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense    |
+|                        | fusion                                                                        |
++------------------------+-------------------------------------------------------------------------------+
+| nn.relu                | Supported by BNNS integration only as a part of nn.conv2d or nn.dense fusion  |
++------------------------+-------------------------------------------------------------------------------+
+| nn.gelu                | Supported by BNNS integration only as a part of nn.conv2d or nn.dense fusion  |
++------------------------+-------------------------------------------------------------------------------+
diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst
index 2b37f734c3c3..3cbbb10bd74b 100644
--- a/docs/deploy/index.rst
+++ b/docs/deploy/index.rst
@@ -71,3 +71,4 @@ target device without relying on RPC. see the following resources on how to do so.
    arm_compute_lib
    tensorrt
    vitis_ai
+   bnns
diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py
index 0a2592685646..886160ad000c 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -21,6 +21,7 @@
 from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
 from tvm.relay.op.contrib.ethosn import partition_for_ethosn
+from tvm.relay.op.contrib.bnns import partition_for_bnns
 from .common import TVMCException
@@ -40,6 +41,10 @@
         "config_key": "relay.ext.ethos-n.options",
         "pass_pipeline": partition_for_ethosn,
     },
+    "bnns": {
+        "config_key": None,
+        "pass_pipeline": partition_for_bnns,
+    },
 }
diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py
index 49abf36134b4..30c2db0ddf0b 100644
--- a/python/tvm/relay/op/contrib/__init__.py
+++ b/python/tvm/relay/op/contrib/__init__.py
@@ -20,6 +20,7 @@
 from .arm_compute_lib import *
 from .dnnl import *
+from .bnns import *
 from .coreml import *
 from .ethosn import *
 from .tensorrt import *
diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py
new file mode 100644
index 000000000000..2ace502e6528
--- /dev/null
+++ b/python/tvm/relay/op/contrib/bnns.py
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""BNNS library supported operators.
+BNNS is a part of the Accelerate framework on macOS/iOS platforms. Apple provides several APIs
+to handle tensor processing. In particular:
+    * BNNS (Basic Neural Network Subroutines)
+    * vDSP (1D and 2D tensor processing)
+"""
+import math
+import tvm.ir
+
+from tvm.relay import transform
+from tvm.relay.expr import const
+from tvm.relay.build_module import bind_params_by_name
+
+from .register import register_pattern_table, get_pattern_table
+from ...dataflow_pattern import wildcard, is_op, is_expr
+
+
+def partition_for_bnns(mod, params=None):
+    """Partition the graph greedily, offloading supported
+    operators to BNNS.
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+
+    Returns
+    -------
+    ret : Module
+        Annotated and partitioned module.
+    """
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+            transform.DynamicToStatic(),
+            transform.AlterOpLayout(),
+            # TODO(apeskov): WA. The AlterOpLayout call may lead to constant shape
+            # transformations, and some expand_dims ops may appear after constants.
+            # That breaks BNNS fusing, so we have to call FoldConstant right before
+            # the BNNS composite passes.
+            transform.FoldConstant(),
+            transform.MergeComposite(get_pattern_table("bnns")),
+            transform.AnnotateTarget("bnns"),
+            # If you don't need per-layer performance statistics you can
+            # uncomment the next line:
+            # transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
+
+    return seq(mod)
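+
+
+# A minimal usage sketch of the partitioner above, for illustration only
+# ("mod" and "params" are assumed to come from any Relay frontend):
+#
+#     from tvm.relay.op.contrib.bnns import partition_for_bnns
+#     mod = partition_for_bnns(mod, params=params)
+#     with tvm.transform.PassContext(opt_level=3):
+#         lib = relay.build(mod, target="llvm", params=params)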
+
+
+def _register_external_op_helper(op_name, supported=True):
+    """The helper function to indicate that a given operator can be supported
+    by BNNS.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the supported operator that will be registered.
+
+    Returns
+    -------
+    f : callable
+        A function that returns whether the operator is supported by BNNS.
+    """
+
+    @tvm.ir.register_op_attr(op_name, "target.bnns")
+    def _func_wrapper(expr):
+        return supported
+
+    return _func_wrapper
+
+
+_register_external_op_helper("nn.batch_matmul")
+
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.bnns")
+def max_pool2d_check(expr):
+    """Check if nn.max_pool2d can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    rank = len(data_typ.shape)
+    if rank < 3 or rank > 4 or data_typ.dtype != "float32":
+        return False
+    if attrs.layout != "NCHW":
+        return False
+    return True
+
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.bnns")
+def avg_pool2d_check(expr):
+    """Check if nn.avg_pool2d can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    rank = len(data_typ.shape)
+    if rank < 3 or rank > 4 or data_typ.dtype != "float32":
+        return False
+    if attrs.layout != "NCHW":
+        return False
+    return True
+
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.bnns")
+def global_max_pool2d_check(expr):
+    """Check if nn.global_max_pool2d can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    rank = len(data_typ.shape)
+    if rank < 3 or rank > 4 or data_typ.dtype != "float32":
+        return False
+    if attrs.layout != "NCHW":
+        return False
+    return True
+
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.bnns")
+def global_avg_pool2d_check(expr):
+    """Check if nn.global_avg_pool2d can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    rank = len(data_typ.shape)
+    if rank < 3 or rank > 4 or data_typ.dtype != "float32":
+        return False
+    if attrs.layout != "NCHW":
+        return False
+    return True
+
+
+def dtype_is_supported(dtype):
+    """Check if the data type is supported by the BNNS backend"""
+    return dtype in ("", "float32")
+
+
+@tvm.ir.register_op_attr("nn.conv2d", "target.bnns")
+def conv2d_check(expr):
+    """Check if conv2d can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    if len(data_typ.shape) != 4 or data_typ.dtype != "float32":
+        return False
+    if not isinstance(args[1], tvm.relay.expr.Constant):
+        return False
+    kernel_typ = args[1].checked_type
+    if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32":
+        return False
+    if attrs.data_layout != "NCHW":
+        return False
+    if not dtype_is_supported(attrs.out_dtype):
+        return False
+    return True
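+
+
+# Illustration (not executed here): conv2d_check above accepts a float32 NCHW
+# convolution with constant weights, e.g.
+#
+#     data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
+#     weight = relay.const(np.zeros((8, 3, 3, 3), dtype="float32"))
+#     out = relay.nn.conv2d(data, weight)  # eligible for BNNS offload
+#
+# and rejects NHWC layouts, non-constant weights and non-float32 dtypes.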
+
+
+def bias_check(expr):
+    """Check whether the bias is added along the correct dimension"""
+    attrs, args = expr.attrs, expr.args
+    if not isinstance(args[1], tvm.relay.expr.Constant):
+        return False
+    if expr.op.name == "nn.bias_add":
+        return attrs.axis == 1
+    if expr.op.name == "add":
+        b_shape = args[1].checked_type.shape
+        if len(b_shape) == 4:
+            return bool(b_shape[0] == 1 and b_shape[2] == 1 and b_shape[3] == 1)
+        if len(b_shape) == 3:
+            return bool(b_shape[1] == 1 and b_shape[2] == 1)
+
+    return False
+
+
+@tvm.ir.register_op_attr("nn.dense", "target.bnns")
+def dense(expr):
+    """Check if nn.dense can be used in BNNS."""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    if data_typ.dtype != "float32":
+        return False
+    if not isinstance(args[1], tvm.relay.expr.Constant):
+        return False
+    kernel_typ = args[1].checked_type
+    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32":
+        return False
+    if attrs.out_dtype != "float32" and attrs.out_dtype != "":
+        return False
+    return True
+
+
+def make_conv_pattern(with_bias=True, activation="none"):
+    """Make pattern for the bnns.conv2d primitive"""
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    pat = is_op("nn.conv2d")(data, weight)
+    if with_bias:
+        pat = is_op("add")(pat, bias) | is_op("nn.bias_add")(pat, bias)
+    if activation == "relu":
+        pat = is_op("nn.relu")(pat)
+    elif activation == "sigmoid":
+        pat = is_op("sigmoid")(pat)
+    return pat
+
+
+def check_conv(extract):
+    """Check that the conv pattern is supported by BNNS."""
+    bias_is_ok = True
+    call = extract
+    while call.op.name != "nn.conv2d":
+        if call.op.name in ("nn.bias_add", "add"):
+            bias_is_ok &= bias_check(call)
+        call = call.args[0]
+    return conv2d_check(call) and bias_is_ok
+
+
+def make_dense_bias_pattern():
+    """Make pattern for the bnns.dense primitive"""
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    d = is_op("nn.dense")(data, weight)
+    return is_op("add")(d, bias)
+
+
+def make_dense_bias_gelu_pattern():
+    """Make pattern for the bnns.dense primitive with fused bias and GELU activation.
+
+    Matches the tanh approximation of GELU decomposed into elementwise ops:
+    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+    """
+    dense_bias = make_dense_bias_pattern()
+    const1 = is_expr(const(0.044715))
+    const2 = is_expr(const(math.sqrt(2 / math.pi)))
+
+    gelu = is_op("power")(dense_bias, is_expr(const(3, dtype="float32")))
+    gelu = is_op("multiply")(gelu, const1)
+    gelu = is_op("add")(gelu, dense_bias)
+    gelu = is_op("multiply")(gelu, const2)
+    gelu = is_op("tanh")(gelu)
+    gelu = is_op("add")(gelu, is_expr(const(1, dtype="float32")))
+    gelu = is_op("multiply")(gelu, is_expr(const(0.5)))
+    gelu = is_op("multiply")(gelu, dense_bias)
+    return gelu
+
+
+def check_dense(extract):
+    """Check that the dense pattern is supported by BNNS."""
+    call = extract
+    while call.op.name != "nn.dense":
+        call = call.args[0]
+    return dense(call)
+
+
+@tvm.ir.register_op_attr("nn.instance_norm", "target.bnns")
+def instance_norm_check(expr):
+    """Check if nn.instance_norm can be executed in BNNS"""
+    attrs, args = expr.attrs, expr.args
+    data_typ = args[0].checked_type
+    rank = len(data_typ.shape)
+    if rank < 3 or rank > 4 or data_typ.dtype != "float32":
+        return False
+    if not isinstance(args[1], tvm.relay.expr.Constant) or not isinstance(
+        args[2], tvm.relay.expr.Constant
+    ):
+        return False
+    if (attrs.axis == 0 and rank == 3) or (attrs.axis == 1 and rank == 4):
+        return True
+    return False
+
+
+@register_pattern_table("bnns")
+def pattern_table():
+    """Get the BNNS-specific fusing pattern collection"""
+    conv2d_bias_pat = (
+        "bnns.conv2d_bias",
+        make_conv_pattern(with_bias=True),
+        check_conv,
+    )
+    conv2d_bias_relu_pat = (
+        "bnns.conv2d_bias_relu",
+        make_conv_pattern(with_bias=True, activation="relu"),
+        check_conv,
+    )
+    conv2d_relu_pat = (
+        "bnns.conv2d_relu",
+        make_conv_pattern(with_bias=False, activation="relu"),
+        check_conv,
+    )
+    conv2d_bias_sigmoid_pat = (
+        "bnns.conv2d_bias_sigmoid",
+        make_conv_pattern(with_bias=True, activation="sigmoid"),
+        check_conv,
+    )
+    conv2d_sigmoid_pat = (
+        "bnns.conv2d_sigmoid",
+        make_conv_pattern(with_bias=False, activation="sigmoid"),
+        check_conv,
+    )
+    dense_bias_gelu = ("bnns.dense_bias_gelu", make_dense_bias_gelu_pattern(), check_dense)
+    dense_bias = ("bnns.dense_bias", make_dense_bias_pattern(), check_dense)
+    bnns_patterns = [
+        conv2d_bias_relu_pat,
+        conv2d_relu_pat,
+        conv2d_bias_sigmoid_pat,
+        conv2d_sigmoid_pat,
+        conv2d_bias_pat,
+        dense_bias_gelu,
+        dense_bias,
+    ]
+    return bnns_patterns
diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc
new file mode 100644
index 000000000000..72c32fb5b19e
--- /dev/null
+++ b/src/relay/backend/contrib/bnns/codegen.cc
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file
+ * \brief Implementation of BNNS codegen APIs.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+
+#include <memory>
+#include <string>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../utils.h"
+#include "../codegen_json/codegen_json.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+
+using namespace backend;
+
+/*!
+ * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in
+ *        relu(add(conv2d))
+ * \param current_call A Relay call node. Typically nn.relu when called the first time.
+ * \param max_depth The maximum number of calls before the root op, counting from current_call.
+ * \param root_name The name of the expected "root" op in this fused call.
+ * \return A CallNode corresponding to the root op
+ */
+inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth,
+                                        const std::string& root_name) {
+  ICHECK(current_call && max_depth >= 0);
+
+  if (max_depth == 0) {
+    ICHECK(current_call && IsOp(current_call, root_name));
+    return current_call;
+  }
+  if (IsOp(current_call, root_name)) {
+    return current_call;
+  }
+
+  ICHECK_GT(current_call->args.size(), 0);
+
+  const auto* next_call = current_call->args[0].as<CallNode>();
+  return FindCallWithName(next_call, max_depth - 1, root_name);
+}
+
+class BNNSJSONSerializer : public backend::contrib::JSONSerializer {
+  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+ public:
+  BNNSJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
+    Expr expr = GetRef<Expr>(cn);
+    std::string name;
+    const CallNode* call = cn;
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      name = op_node->name;
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto comp = fn->GetAttr<String>(attr::kComposite);
+      ICHECK(comp.defined()) << "BNNS JSON runtime only supports composite functions.";
+      name = comp.value();
+
+      auto body = fn->body.as<CallNode>();
+      if (name == "bnns.conv2d_bias_relu") {
+        auto add_op_type = IsOp(body->args[0].as<CallNode>(), "add") ? "add" : "nn.bias_add";
+        call = GetRootCall(body, 2, {"nn.conv2d", add_op_type, "nn.relu"});
+      } else if (name == "bnns.conv2d_bias") {
+        auto add_op_type = IsOp(body, "add") ? "add" : "nn.bias_add";
+        call = GetRootCall(body, 1, {"nn.conv2d", add_op_type});
+      } else if (name == "bnns.conv2d_relu") {
+        call = GetRootCall(body, 1, {"nn.conv2d", "nn.relu"});
+        ICHECK(call->op.as<OpNode>()) << "Not an op node";
+      } else if (name == "bnns.conv2d_bias_sigmoid") {
+        auto add_op_type = IsOp(body->args[0].as<CallNode>(), "add") ? "add" : "nn.bias_add";
+        call = GetRootCall(body, 2, {"nn.conv2d", add_op_type, "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not an op node";
+      } else if (name == "bnns.conv2d_sigmoid") {
+        call = GetRootCall(body, 1, {"nn.conv2d", "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not an op node";
+      } else if (name == "bnns.dense_bias") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+      } else if (name == "bnns.dense_bias_gelu") {
+        call = FindCallWithName(fn->body.as<CallNode>(), 10, "nn.dense");
+      } else {
+        LOG(FATAL) << "Unrecognized BNNS pattern: " << name;
+      }
+    } else {
+      LOG(FATAL) << "BNNS JSON runtime does not support calls to " << cn->op->GetTypeKey();
+    }
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : cn->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    auto node = std::make_shared<JSONGraphNode>(name,     /* name_ */
+                                                "kernel", /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    SetCallNodeAttribute(node, call);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+};
+
+/*!
+ * \brief The external compiler/codegen tool. It takes a Relay expression/module and
+ * compiles it into a runtime module.
+ */
+runtime::Module BNNSCompiler(const ObjectRef& ref) {
+  ICHECK(ref->IsInstance<FunctionNode>());
+  auto func = Downcast<Function>(ref);
+  auto func_name = GetExtSymbol(func);
+  BNNSJSONSerializer serializer(func_name, func);
+  serializer.serialize();
+  std::string graph_json = serializer.GetJSON();
+  auto params = serializer.GetParams();
+
+  const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate");
+  ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
+  auto mod = (*pf)(func_name, graph_json, params);
+  return mod;
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.bnns").set_body_typed(BNNSCompiler);
+
+/**
+ * \brief A helper to expand the params by adding those used by the BNNS runtime
+ * for a given expression. Same as the default ConstantUpdater, but it skips constants from
+ * essential BNNS composite function ops.
+ */
+struct BNNSConstantUpdater : public ConstantUpdater {
+ public:
+  BNNSConstantUpdater(const std::string& symbol,
+                      std::unordered_map<std::string, runtime::NDArray>* params,
+                      const std::vector<std::string>& skip_mask)
+      : ConstantUpdater(symbol, params), skip_mask_(skip_mask) {}
+  using ConstantUpdater::VisitExpr_;
+
+  /**!
+   * Like the original implementation, but avoids visiting body nodes
+   * of BNNS-specific composite primitives.
+   */
+  void VisitExpr_(const FunctionNode* op) final {
+    this->VisitSpan(op->span);
+    for (auto param : op->params) {
+      this->VisitExpr(param);
+    }
+
+    if (!isBNNSSpecificCompositeFunc(op)) {
+      this->VisitExpr(op->body);
+    }
+  }
+
+ private:
+  bool isBNNSSpecificCompositeFunc(const FunctionNode* op) {
+    auto comp = op->GetAttr<String>(attr::kComposite);
+    if (!comp) return false;
+
+    auto comp_name = comp.value();
+
+    bool is_match = false;
+    for (const auto& mask : skip_mask_) {
+      if (std::string(comp_name).substr(0, mask.size()) == mask) {
+        is_match = true;
+        break;
+      }
+    }
+    return is_match;
+  }
+
+  std::vector<std::string> skip_mask_;
+};
+
+Map<String, runtime::NDArray> BNNSConstantUpdaterFunc(Expr expr, std::string symbol) {
+  std::vector<std::string> bnns_composite_filter = {"bnns."};
+
+  // Visit all suitable constant nodes
+  std::unordered_map<std::string, runtime::NDArray> res;
+  BNNSConstantUpdater const_updater(symbol, &res, bnns_composite_filter);
+  const_updater(expr);
+
+  // Convert to tvm::Map
+  Map<String, runtime::NDArray> ret;
+  for (const auto& kvp : res) ret.Set(kvp.first, kvp.second);
+  return ret;
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.bnns.constant_updater").set_body_typed(BNNSConstantUpdaterFunc);
+
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc
new file mode 100644
index 000000000000..87b01567cd30
--- /dev/null
+++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc
@@ -0,0 +1,573 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * \file
+ * \brief Simple JSON runtime for Apple BNNS primitives
+ */
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "bnns_wrp.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace ::tvm::runtime;
+using namespace ::tvm::runtime::json;
+using namespace ::tvm::runtime::contrib::BNNS;
+
+struct ThreadingConfig {
+  /**
+   * Internal parallelism level of a BNNS primitive, specified via the
+   * BNNSFilterParameters struct. BNNS doesn't provide real control of internal
+   * threading, so it may be ignored by the BNNS implementation.
+   *
+   * Valid values:
+   *   0   use the default number of threads suggested by the BNNS implementation
+   *   >0  suggest using this number of internal BNNS threads
+   */
+  size_t internalConcurrency = 0;
+
+  /**
+   * TVM-level parallelism for the BNNS runtime.
+   * The BNNS runtime will split a primitive into a set of independent sub-primitives
+   * which can be executed in parallel. As a rule the splitting is performed along output
+   * channels, so the effective shape of the executed primitive is changed.
+   *
+   * Valid values:
+   *   0   do not use graph-level threading
+   *   >0  split into this number of primitives
+   */
+  size_t externalConcurrency = 0;
+};
+
+/**
+ * The optimal ThreadingConfig may differ depending on the platform hardware.
+ * This function contains a priori knowledge about some Apple platforms
+ * and their specifics.
+ *
+ * @return default ThreadingConfig suggested for this platform
+ */
+ThreadingConfig getDefaultThreadingConfig() {
+  // TODO(apeskov): have to implement CPU/iOS version check.
+  // meanwhile will use {0, 2} stub to utilize big cores of A13/A14 CPU.
+  return {0, 2};
+}
+
+/**
+ * Main entry point to the BNNS runtime
+ */
+class BNNSJSONRuntime : public JSONRuntimeBase {
+ public:
+  BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+                  const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  const char* type_key() const override { return "bnns_json"; }
+
+  void Init(const Array<NDArray>& consts) override {
+    ICHECK_EQ(consts.size(), const_idx_.size())
+        << "The number of input constants must match the number required.";
+
+    SetupConstants(consts);
+    BindInputsAndOutputs();
+    AllocateIntermediateTensors();
+    BuildEngine();
+  }
+
+  void Run() override {
+    // Wrap an external handler into the BNNS tensor representation
+    auto bind_ext_hdl_to_tensor = [this](uint32_t eid) {
+      const auto& ext_dlt = *data_entry_[eid];
+      auto& bnns_tensor = tensors_eid_[eid];
+      bnns_tensor->set_data_hdl(ext_dlt.data);
+    };
+
+    // Bind all input/output external data objects into internal abstractions
+    for (const auto& eid : input_var_eid_) bind_ext_hdl_to_tensor(eid);
+    for (const auto& out_entity : outputs_) bind_ext_hdl_to_tensor(EntryID(out_entity));
+
+    // Invoke primitives in topological order
+    for (const auto& prim : primitives_) prim->execute();
+  }
+
+ private:
+  /** Make the corresponding input/output tensor stubs */
+  void BindInputsAndOutputs() {
+    tensors_eid_.resize(data_entry_.size());
+    auto createTensor = [&](JSONGraphNodeEntry entry) {
+      auto node = nodes_[entry.id_];
+      auto dlshape = node.GetOpShape()[entry.index_];
+      auto dltype = node.GetOpDataType()[entry.index_];
+      void* data = nullptr;
+      if (data_entry_[entry.id_] != nullptr) data = data_entry_[entry.id_]->data;
+      tensors_eid_[entry.id_] = std::make_shared<BNNS::Tensor>(
+          BNNS::Shape{dlshape.begin(), dlshape.end()}, convertToBNNS(dltype), data);
+    };
+
+    for (auto& id : input_nodes_) {
+      auto eid = JSONGraphNodeEntry(id, 0);
+      createTensor(eid);
+    }
+
+    for (auto entry : outputs_) {
+      createTensor(entry);
+    }
+  }
+
+  /** Allocate intermediate tensors */
+  void AllocateIntermediateTensors() {
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      auto eid = JSONGraphNodeEntry(i, 0);
+      if (tensors_eid_[eid.id_] != nullptr) continue;
+      auto node = nodes_[i];
+      auto dlshape = node.GetOpShape()[0];
+      auto dltype = node.GetOpDataType()[0];
+      tensors_eid_[eid.id_] = std::make_shared<BNNS::Tensor>(
+          BNNS::Shape{dlshape.begin(), dlshape.end()}, convertToBNNS(dltype), nullptr);
+      tensors_eid_[eid.id_]->allocate_memory();
+    }
+  }
+
+  // Build up the engine based on the input graph.
+  void BuildEngine() {
+    // Build the subgraph engine.
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        ICHECK_EQ(node.GetOpType(), "kernel");
+        auto op_name = node.GetOpName();
+        if ("nn.conv2d" == op_name) {
+          Conv2d(nid);
+        } else if ("bnns.conv2d_relu" == op_name) {
+          Conv2d(nid, false, "relu");
+        } else if ("bnns.conv2d_bias_relu" == op_name) {
+          Conv2d(nid, true, "relu");
+        } else if ("bnns.conv2d_sigmoid" == op_name) {
+          Conv2d(nid, false, "sigmoid");
+        } else if ("bnns.conv2d_bias_sigmoid" == op_name) {
+          Conv2d(nid, true, "sigmoid");
+        } else if ("bnns.conv2d_bias" == op_name) {
+          Conv2d(nid, true);
+        } else if ("nn.dense" == op_name) {
+          Dense(nid);
+        } else if ("bnns.dense_bias" == op_name) {
+          Dense(nid, true);
+        } else if ("bnns.dense_bias_gelu" == op_name) {
+          Dense(nid, true, true);
+        } else if ("nn.batch_matmul" == op_name) {
+          MatMul(nid);
+        } else if ("nn.instance_norm" == op_name) {
+          InstanceNormalization(nid);
+        } else if ("nn.max_pool2d" == op_name) {
+          Pooling(nid, false);
+        } else if ("nn.avg_pool2d" == op_name) {
+          Pooling(nid, true);
+        } else if ("nn.global_max_pool2d" == op_name) {
+          Pooling(nid, false, true);
+        } else if ("nn.global_avg_pool2d" == op_name) {
+          Pooling(nid, true, true);
+        } else {
+          LOG(FATAL) << "Unsupported op: " << op_name;
+        }
+      }
+    }
+  }
+
+  // Get a BNNS tensor.
+  std::shared_ptr<BNNS::Tensor> GetBNNSTensor(const JSONGraphNodeEntry& entry) {
+    auto eid = EntryID(entry);
+    ICHECK(eid < tensors_eid_.size());
+    return tensors_eid_[eid];
+  }
+
+  void Conv2d(const size_t& nid, const bool has_bias = false,
+              const std::string activation_type = "none") {
+    auto node = nodes_[nid];
+
+    // Setup attributes.
+    auto src_entry = node.GetInputs()[0];
+    auto wgh_entry = node.GetInputs()[1];
+    auto dst_entry = JSONGraphNodeEntry(nid, 0);
+
+    auto dl_input_shape = nodes_[src_entry.id_].GetOpShape()[src_entry.index_];
+    auto dl_weight_shape = nodes_[wgh_entry.id_].GetOpShape()[wgh_entry.index_];
+    BNNS::Shape input_shape{dl_input_shape.begin(), dl_input_shape.end()};
+    BNNS::Shape weight_shape{dl_weight_shape.begin(), dl_weight_shape.end()};
+    std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> str_dilation = node.GetAttr<std::vector<std::string>>("dilation");
+    std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
+    BNNS::Dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
+
+    BNNS::Dim PH_L = std::stoi(str_padding[0]),  // height padding: left
+        PH_R = std::stoi(str_padding[2]),        // height padding: right
+        PW_L = std::stoi(str_padding[1]),        // width padding: left
+        PW_R = std::stoi(str_padding[3]),        // width padding: right
+        SH = std::stoi(str_strides[0]),          // height-wise stride
+        SW = std::stoi(str_strides[1]),          // width-wise stride
+        DH = std::stoi(str_dilation[0]),         // height kernel dilation
+        DW = std::stoi(str_dilation[1]);         // width kernel dilation
+
+    // Memory descriptions.
+    const auto& src_t = GetBNNSTensor(src_entry);
+    const auto& wgh_t = GetBNNSTensor(wgh_entry);
+    const auto& dst_t = GetBNNSTensor(dst_entry);
+
+    auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutConvolutionWeightsOIHW);
+    auto dst_view = TView::as_is(dst_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    TView bias_view;
+
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+
+      auto bias_t = GetBNNSTensor(bias_entry);
+      bias_view = TView::as_is(bias_t).squeeze().with_layout(BNNSDataLayoutVector);
+    }
+
+    BNNSActivation activation = {BNNSActivationFunctionIdentity};
+    if (activation_type == "relu")
+      activation = {BNNSActivationFunctionRectifiedLinear};
+    else if (activation_type == "sigmoid")
+      activation = {BNNSActivationFunctionSigmoid};
+
+    BNNSLayerParametersConvolution conv_param = {
+        src_view.get_bnns_view(),
+        wgh_view.get_bnns_view(),
+        dst_view.get_bnns_view(),
+        bias_view.get_bnns_view(),
+        activation,
+        SW,                       /* x_stride */
+        SH,                       /* y_stride */
+        DW,                       /* x_dilation_stride */
+        DH,                       /* y_dilation_stride */
+        0,                        /* x_padding, explicit pads will be used */
+        0,                        /* y_padding, explicit pads will be used */
+        groups,                   /* groups */
+        {PW_L, PW_R, PH_L, PH_R}  /* explicit pad values */
+    };
+
+    size_t num_sub_prim = default_thread_config.externalConcurrency;
+    std::vector<BNNSLayerParametersConvolution> params;
+    std::tie(params, src_view, dst_view) =
+        split_to_n(num_sub_prim, conv_param, src_view, wgh_view, bias_view, dst_view);
+
+    std::vector<BNNSFilter> filters(params.size(), nullptr);
+    for (size_t i = 0; i < params.size(); i++) {
+      auto common_filter_param = getCommonFilterParams();
+      filters[i] = BNNSFilterCreateLayerConvolution(&params[i], &common_filter_param);
+      ICHECK(filters[i]) << "BNNS primitive was not created. Unsupported attributes configuration";
+    }
+
+    primitives_.emplace_back(std::make_shared<BNNS::Primitive>(filters, src_view, dst_view));
+  }
+
+  void Dense(const size_t& nid, const bool has_bias = false, const bool has_gelu = false) {
+    auto node = nodes_[nid];
+
+    // Setup attributes.
+    auto src_entry = node.GetInputs()[0];
+    auto weight_entry = node.GetInputs()[1];
+    auto dst_entry = JSONGraphNodeEntry(nid, 0);
+
+    // Memory descriptions.
+    auto src_t = GetBNNSTensor(src_entry);
+    auto wgh_t = GetBNNSTensor(weight_entry);
+    auto dst_t = GetBNNSTensor(dst_entry);
+
+    auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutVector);
+    auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutRowMajorMatrix);
+    auto dst_view = TView::as_is(dst_t).extract_outer_dim().with_layout(BNNSDataLayoutVector);
+
+    TView bias_view;
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      auto bias_md = GetBNNSTensor(bias_entry);
+      bias_view = TView::as_is(bias_md).with_layout(BNNSDataLayoutVector);
+    }
+
+    BNNSActivation activation = {BNNSActivationFunctionIdentity};
+    if (has_gelu) {
+      activation = {BNNSActivationFunctionGELUApproximation};
+      activation.alpha = std::sqrt(2.0 / M_PI);
+      activation.beta = 0.044715;
+    }
+
+    BNNSLayerParametersFullyConnected layerParameters = {
+        src_view.get_bnns_view(),
+        wgh_view.get_bnns_view(),
+        dst_view.get_bnns_view(),
+        bias_view.get_bnns_view(),
+        activation,
+    };
+
+    auto common_filter_param = getCommonFilterParams();
+    auto filter = BNNSFilterCreateLayerFullyConnected(&layerParameters, &common_filter_param);
+    ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration";
+    std::vector<BNNSFilter> filters = {filter};
+    primitives_.emplace_back(std::make_shared<BNNS::Primitive>(filters, src_view, dst_view));
+  }
+
+  void MatMul(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    // Setup attributes.
+    auto a_entry = node.GetInputs()[0];
+    auto b_entry = node.GetInputs()[1];
+    auto dst_entry = JSONGraphNodeEntry(nid, 0);
+    bool a_is_weighted = data_entry_[EntryID(a_entry)] != nullptr;
+    bool b_is_weighted = data_entry_[EntryID(b_entry)] != nullptr;
+
+    // Memory descriptions.
+    auto a_t = GetBNNSTensor(a_entry);
+    auto b_t = GetBNNSTensor(b_entry);
+    auto dst_t = GetBNNSTensor(dst_entry);
+
+    auto a_view = TView::as_is(a_t);
+    auto b_view = TView::as_is(b_t);
+    auto dst_view = TView::as_is(dst_t);
+
+    BNNSLayerParametersBroadcastMatMul layerParameters = {1,      // alpha
+                                                          0,      // beta
+                                                          false,  // transA
+                                                          true,   // transB
+                                                          false,  // quadratic
+                                                          a_is_weighted,
+                                                          b_is_weighted,
+                                                          a_view.get_bnns_view(),
+                                                          b_view.get_bnns_view(),
+                                                          dst_view.get_bnns_view()};
+
+    // BNNS limitation: MatMul uses reversed dim values. However, strides are calculated correctly
+    // based on the BNNSNDArrayDescriptor::layout value.
+    std::reverse(layerParameters.iA_desc.size, layerParameters.iA_desc.size + 3);
+    std::reverse(layerParameters.iB_desc.size, layerParameters.iB_desc.size + 3);
+    std::reverse(layerParameters.o_desc.size, layerParameters.o_desc.size + 3);
+
+    auto common_filter_param = getCommonFilterParams();
+    auto filter = BNNSFilterCreateLayerBroadcastMatMul(&layerParameters, &common_filter_param);
+    ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration";
+
+    std::vector<BNNSFilter> filters{filter};
+    if (a_is_weighted || b_is_weighted) {
+      auto src_view = a_is_weighted ? b_view : a_view;
+      primitives_.emplace_back(std::make_shared<BNNS::Primitive>(filters, src_view, dst_view));
+    } else {
+      primitives_.emplace_back(
+          std::make_shared<BNNS::TwoInputPrimitive>(filters, a_view, b_view, dst_view));
+    }
+  }
+
+  void InstanceNormalization(const size_t& nid) {
+    auto node = nodes_[nid];
+    size_t axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+    bool center = std::stoi(node.GetAttr<std::vector<std::string>>("center")[0]);
+    bool scale = std::stoi(node.GetAttr<std::vector<std::string>>("scale")[0]);
+
+    // Setup attributes.
+    auto src_entry = node.GetInputs()[0];
+    auto scale_entry = node.GetInputs()[1];
+    auto bias_entry = node.GetInputs()[2];
+    auto dst_entry = JSONGraphNodeEntry(nid, 0);
+
+    // Memory descriptions.
+    auto src_t = GetBNNSTensor(src_entry);
+    auto scale_t = GetBNNSTensor(scale_entry);
+    auto bias_t = GetBNNSTensor(bias_entry);
+    auto dst_t = GetBNNSTensor(dst_entry);
+
+    auto src_view = TView::as_is(src_t);
+    auto dst_view = TView::as_is(dst_t);
+    size_t src_rank = Tensor::getRank(src_view.get_bnns_view());
+    size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view());
+    ICHECK_EQ(src_rank, dst_rank);
+    ICHECK_LE(src_rank, 4);
+    if (src_rank < 4) {
+      src_view = src_view.unsqueeze(4);
+      dst_view = dst_view.unsqueeze(4);
+    }
+    src_view = src_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    auto scale_view = TView::as_is(scale_t).with_layout(BNNSDataLayoutVector);
+    auto bias_view = TView::as_is(bias_t).with_layout(BNNSDataLayoutVector);
+    BNNSActivation activation = {BNNSActivationFunctionIdentity};
+
+    auto b_desc = bias_view.get_bnns_view();
+    if (!center) b_desc = {};
+    auto s_desc = scale_view.get_bnns_view();
+    if (!scale) s_desc = {};
+
+    // NOTE: The axis option is ignored in BNNS. The result doesn't depend on the value of axis.
+    BNNSLayerParametersNormalization layerParameters = {src_view.get_bnns_view(),  // i_desc
+                                                        dst_view.get_bnns_view(),  // o_desc
+                                                        b_desc,      // beta_desc
+                                                        s_desc,      // gamma_desc
+                                                        {},          // moving_mean_desc
+                                                        {},          // moving_variance_desc
+                                                        1.f,         // momentum
+                                                        epsilon,     // epsilon
+                                                        activation,  // activation
+                                                        1,           // num_groups
+                                                        axis};       // normalization_axis
+
+    BNNSFilterType filter_type = BNNSInstanceNorm;
+    auto common_filter_param = getCommonFilterParams();
+    auto filter =
+        BNNSFilterCreateLayerNormalization(filter_type, &layerParameters, &common_filter_param);
+    ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration";
+
+    std::vector<BNNSFilter> filters{filter};
+    primitives_.emplace_back(std::make_shared<BNNS::NormPrimitive>(filters, src_view, dst_view));
+  }
+
+  void Pooling(const size_t& nid, bool avg_pooling, bool global = false) {
+    auto node = nodes_[nid];
+
+    auto src_entry = node.GetInputs()[0];
+    auto dst_entry = JSONGraphNodeEntry(nid, 0);
+
+    // Memory descriptions.
+    auto src_t = GetBNNSTensor(src_entry);
+    auto dst_t = GetBNNSTensor(dst_entry);
+
+    auto src_view = TView::as_is(src_t);
+    auto dst_view = TView::as_is(dst_t);
+    size_t src_rank = Tensor::getRank(src_view.get_bnns_view());
+    size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view());
+    ICHECK_EQ(src_rank, dst_rank);
+    ICHECK_LE(src_rank, 4);
+    if (src_rank < 4) {
+      src_view = src_view.unsqueeze(4);
+      dst_view = dst_view.unsqueeze(4);
+    }
+    src_view = src_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW);
+    BNNSActivation activation = {BNNSActivationFunctionIdentity};
+    BNNSPoolingFunction pf = {BNNSPoolingFunctionMax};
+    if (avg_pooling) pf = {BNNSPoolingFunctionAverageCountExcludePadding};
+
+    // Setup attributes.
+    size_t k_height = 0;
+    size_t k_width = 0;
+    size_t y_padding = 0;
+    size_t x_padding = 0;
+    size_t y_stride = 1;
+    size_t x_stride = 1;
+    if (!global) {
+      std::vector<std::string> pool_size = node.GetAttr<std::vector<std::string>>("pool_size");
+      std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
+      std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
+      k_height = std::stoi(pool_size[0]);
+      k_width = std::stoi(pool_size[1]);
+      y_padding = std::stoi(padding[0]);
+      x_padding = std::stoi(padding[1]);
+      y_stride = std::stoi(strides[0]);
+      x_stride = std::stoi(strides[1]);
+    } else {
+      auto sv = src_view.get_bnns_view();
+      k_height = sv.size[1];
+      k_width = sv.size[0];
+    }
+
+    BNNSLayerParametersPooling layerParameters = {src_view.get_bnns_view(),  // i_desc
+                                                  dst_view.get_bnns_view(),  // o_desc
+                                                  {},          // bias
+                                                  activation,  // activation
+                                                  pf,          // pooling_function
+                                                  k_width,     // k_width
+                                                  k_height,    // k_height
+                                                  x_stride,    // x_stride
+                                                  y_stride,    // y_stride
+                                                  0,           // x_dilation_stride
+                                                  0,           // y_dilation_stride
+                                                  x_padding,   // x_padding
+                                                  y_padding,   // y_padding
+                                                  {}};         // pad: left, right, up, down
+
+    auto common_filter_param = getCommonFilterParams();
+    auto filter = BNNSFilterCreateLayerPooling(&layerParameters, &common_filter_param);
+    ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration";
+
+    std::vector<BNNSFilter> filters{filter};
+    primitives_.emplace_back(std::make_shared<BNNS::PoolingPrimitive>(filters, src_view, dst_view));
+  }
+
+  BNNS::Dtype convertToBNNS(const DLDataType& dl_dtype) {
+    if (dl_dtype.code == DLDataTypeCode::kDLFloat) {
+      if (dl_dtype.bits == 32) return BNNSDataTypeFloat32;
+      if (dl_dtype.bits == 16) return BNNSDataTypeFloat16;
+    }
+    if (dl_dtype.code == DLDataTypeCode::kDLInt) {
+      if (dl_dtype.bits == 32) return BNNSDataTypeInt32;
+      if (dl_dtype.bits == 16) return BNNSDataTypeInt16;
+      if (dl_dtype.bits == 8) return BNNSDataTypeInt8;
+    }
+    if (dl_dtype.code == DLDataTypeCode::kDLUInt) {
+      if (dl_dtype.bits == 32) return BNNSDataTypeUInt32;
+      if (dl_dtype.bits == 16) return BNNSDataTypeUInt16;
+      if (dl_dtype.bits == 8) return BNNSDataTypeUInt8;
+    }
+    LOG(FATAL) << "Unsupported data type for BNNS runtime";
+    return BNNS::Dtype(0);
+  }
+
+  BNNSFilterParameters getCommonFilterParams() {
+    // NOTE: To force a weights tensor copy at the filter creation stage,
+    // just change BNNSFlagsUseClientPtr -> 0
+    return {BNNSFlagsUseClientPtr, default_thread_config.internalConcurrency};
+  }
+
+  /** Default threading config. Should be used if there is
+   *  no other threading specification. */
+  const ThreadingConfig default_thread_config = getDefaultThreadingConfig();
+
+  /** Collection of all primitives in topological order */
+  std::vector<std::shared_ptr<BNNS::Primitive>> primitives_;
+
+  /** Vector of BNNS tensors. The index of a tensor matches the
+   *  corresponding EntryID from the base JSONRuntimeBase.
+   */
+  std::vector<BNNS::TensorPtr> tensors_eid_;
+};
+
+runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json,
+                                      const Array<String>& const_names) {
+  auto n = make_object<BNNSJSONRuntime>(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json")
+    .set_body_typed(BNNSJSONRuntime::LoadFromBinary<BNNSJSONRuntime>);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h
new file mode 100644
index 000000000000..b31e97e554da
--- /dev/null
+++ b/src/runtime/contrib/bnns/bnns_wrp.h
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * \file
+ * \brief C++ wrappers and helpers to handle BNNS objects
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_
+#define TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_
+
+#include <Accelerate/Accelerate.h>
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <tuple>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+namespace BNNS {
+
+using Dim = size_t;
+using Shape = std::vector<Dim>;
+using Dtype = BNNSDataType;
+using HDL = void*;
+
+void* default_alloc(size_t size) { return malloc(size); }
+
+void default_free(void* ptr) { free(ptr); }
+
+/**
+ * Main abstraction for tensor representation
+ *
+ * Contains the buffer handler and common attributes like shape and dtype.
+ */
+class Tensor {
+ public:
+  Tensor() = delete;
+  Tensor(Tensor&) = delete;
+
+  Tensor(Shape shape, Dtype dtype, void* hdl) {
+    auto rank = shape.size();
+    ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION);
+
+    desc_ = {BNNSNDArrayFlags(0),
+             getPlainLayout(rank),
+             {},       // shape
+             {},       // strides
+             hdl,      // data handler
+             dtype,    // data type
+             nullptr,  // table_data (clustering case), is not used
+             dtype,
+             1.f,
+             0.f};
+    std::copy(shape.rbegin(), shape.rend(), std::begin(desc_.size));
+
+    desc_.data = hdl;
+    is_external_data = true;
+  }
+
+  ~Tensor() {
+    if (desc_.data && !is_external_data) {
+      default_free(desc_.data);
+      desc_.data = nullptr;
+    }
+  }
+
+  void allocate_memory() {
+    if (desc_.data && !is_external_data) {
+      default_free(desc_.data);
+    }
+    const size_t buff_size = getSize(desc_) * getElementSize(desc_);
+    desc_.data = default_alloc(buff_size);
+    ICHECK(desc_.data);
+    is_external_data = false;
+  }
+
+  void* get_data_hdl() const { return desc_.data; }
+
+  void set_data_hdl(void* hdl) {
+    if (desc_.data && !is_external_data) {
+      default_free(desc_.data);
+      desc_.data = nullptr;
+    }
+
+    desc_.data = hdl;
+    is_external_data = true;
+  }
+
+  const BNNSNDArrayDescriptor& get_desc() const { return desc_; }
+
+  static BNNSDataLayout getPlainLayout(size_t rank) {
+    ICHECK(rank <= BNNS_MAX_TENSOR_DIMENSION);
+    return static_cast<BNNSDataLayout>((rank << 16) | 0x8001);
+  }
+
+  static size_t getRank(BNNSDataLayout layout) { return (layout & 0xF0000) >> 16; }
+
+  static size_t getRank(BNNSNDArrayDescriptor desc) { return getRank(desc.layout); }
+
+  static size_t getSize(BNNSNDArrayDescriptor desc) {
+    auto rank = getRank(desc);
+    return std::accumulate(desc.size, desc.size + rank, 1, std::multiplies<int>());
+  }
+
+  /** return the size of an element in bytes */
+  static size_t getElementSize(Dtype dtype) { return (dtype & 0xFFFF) / 8; }
+
+  /** return the size of an element in bytes */
+  static size_t getElementSize(const BNNSNDArrayDescriptor& desc) {
+    return getElementSize(desc.data_type);
+  }
+
+ private:
+  bool is_external_data = false;
+  BNNSNDArrayDescriptor desc_;
+};
+
+using TensorPtr = std::shared_ptr<Tensor>;
+
+/**
+ * Tensor View object which represents how a provided BNNS::Tensor will be considered
+ *
+ * A single BNNS::Tensor can be treated in different forms depending on a particular primitive's
+ * expectations. Moreover, some primitives support only an external form of batching. So we have
+ * an abstraction to describe how a primitive will handle the provided tensor.
+ *
+ * Batched View
+ *   View with an extracted dimension as an external batch value
+ *   example: Tensor [2, 3, 224, 224] -> View [3, 224, 224] with ext batch 2
+ *
+ * Party View
+ *   A collection of views on the same tensor; they can be the same view or offset by some stride
+ *   example: Tensor [6, 5, 3, 3] -> 3 x View [2, 5, 3, 3] with stride 45
+ */
+class TView {
+ public:
+  /** Make a view on the provided tensor as is */
+  static TView as_is(const TensorPtr& origin) {
+    TView res;
+    res.origin_ = origin;
+    res.view_desc_ = origin->get_desc();
+    return res;
+  }
+  /** Extract the outer dimension into a separate batch field. The TView becomes a batched view */
+  TView extract_outer_dim() const {
+    auto rank = Tensor::getRank(view_desc_);
+    TView res = *this;
+    res.batch_size_ = view_desc_.size[rank - 1];
+    res.batch_stride_ =
+        std::accumulate(view_desc_.size, view_desc_.size + rank - 1, 1, std::multiplies<>());
+    res.view_desc_.size[rank - 1] = 0;
+    res.view_desc_.layout = Tensor::getPlainLayout(rank - 1);
+    return res;
+  }
+
+  /** Squeeze all dims equal to 1 */
+  TView squeeze(size_t min_rank = 1) const {
+    auto rank = Tensor::getRank(view_desc_);
+    size_t squeezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {};
+    size_t squeezed_rank = 0;
+    for (size_t i = 0; i < rank; i++)
+      if (view_desc_.size[i] != 1) squeezed_shape[squeezed_rank++] = view_desc_.size[i];
+
+    if (min_rank > squeezed_rank) {
+      std::fill(squeezed_shape + squeezed_rank, squeezed_shape + min_rank, 1);
+      squeezed_rank = min_rank;
+    }
+
+    TView res = *this;
+    std::copy(squeezed_shape, squeezed_shape + squeezed_rank, res.view_desc_.size);
+    std::fill(res.view_desc_.size + squeezed_rank, res.view_desc_.size + rank, 0);
+    res.view_desc_.layout = Tensor::getPlainLayout(squeezed_rank);
+    return res;
+  }
+
+  /** Expand the shape of an array */
+  TView expand_dims(std::vector<size_t> axes) const {
+    auto rank = Tensor::getRank(view_desc_);
+    TView res = *this;
+    size_t unsqueezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {};
+    size_t unsqueezed_rank = axes.size() + rank;
+    ICHECK_LE(unsqueezed_rank, BNNS_MAX_TENSOR_DIMENSION);
+    for (const auto& axis : axes) {
+      ICHECK_LT(axis, unsqueezed_rank);
+      unsqueezed_shape[axis] = 1;
+    }
+    for (size_t i = 0, orig_idx = 0; i < unsqueezed_rank; ++i) {
+      if (unsqueezed_shape[i] == 1) continue;
+      unsqueezed_shape[i] = view_desc_.size[orig_idx++];
+    }
+    std::copy(unsqueezed_shape, unsqueezed_shape + unsqueezed_rank, res.view_desc_.size);
+    res.view_desc_.layout = Tensor::getPlainLayout(unsqueezed_rank);
+    return res;
+  }
+
+  /** Unsqueeze the tensor to a new rank */
+  TView unsqueeze(size_t new_rank) const {
+    ICHECK_LE(new_rank, BNNS_MAX_TENSOR_DIMENSION);
+    auto rank = Tensor::getRank(view_desc_);
+    ICHECK_GT(new_rank, rank);
+    std::vector<size_t> axes(new_rank - rank);
+    std::iota(axes.begin(), axes.end(), rank);
+    return expand_dims(axes);
+  }
+
+  /** Construct a new TView with the specified layout, if applicable */
+  TView with_layout(BNNSDataLayout layout) const {
+    ICHECK_EQ(Tensor::getRank(view_desc_), Tensor::getRank(layout));
+
+    TView res = *this;
+    res.view_desc_.layout = layout;
+    return res;
+  }
+
+  /** Construct a party TView by splitting the original TView into num parts */
+  TView party_split_n(size_t num) const {
+    ICHECK_EQ(party_size_, 1);
+
+    TView res = *this;
+    size_t rank = Tensor::getRank(view_desc_);
+    size_t size = Tensor::getSize(view_desc_);
+    res.party_size_ = num;
+    res.party_stride_ = size / num;
+
+    if (res.batch_size_ != 1) {
+      res.batch_size_ /= num;
+    } else {
+      res.view_desc_.size[rank - 1] /= num;
+      res.batch_stride_ /= num;
+    }
+    return res;
+  }
+
+  /** Construct a party TView by duplicating the original TView num times */
+  TView party_duplicate_n(size_t num) const {
+    ICHECK_EQ(party_size_, 1);
+
+    TView res = *this;
+    res.party_size_ = num;
+    res.party_stride_ = 0;
+
+    return res;
+  }
+
+  /** Return the data buffer handler */
+  HDL get_data_hdl() const { return view_desc_.data; }
+
+  /** Return the external batch dimension value */
+  size_t get_batch_size() const { return batch_size_; }
+
+  /** Return the external batch dimension stride */
+  size_t get_stride() const { return batch_stride_; }
+
+  /** Return a party element by index */
+
+  /** Return party element by index */
+  TView operator[](size_t i) const {
+    ICHECK_LT(i, party_size_);
+
+    TView res = *this;
+    res.party_size_ = 1;
+    if (origin_) {
+      auto hdl = reinterpret_cast<uint8_t*>(origin_->get_data_hdl());
+      hdl += i * party_stride_ * Tensor::getElementSize(view_desc_.data_type);
+      res.view_desc_.data = hdl;
+    }
+    return res;
+  }
+
+  /** Check whether the view is empty, i.e. does not refer to any tensor */
+  operator bool() const { return origin_ != nullptr; }
+
+  /** Get the BNNS descriptor for this particular view. Batch and party attributes are ignored. */
+  const BNNSNDArrayDescriptor& get_bnns_view() const { return view_desc_; }
+
+ private:
+  /** Original tensor object the view refers to */
+  TensorPtr origin_;
+
+  /** Batched view parameters */
+  BNNSNDArrayDescriptor view_desc_ = {};
+  size_t batch_size_ = 1;
+  size_t batch_stride_ = 0;
+
+  /** Party representation parameters */
+  size_t party_size_ = 1;
+  size_t party_stride_ = 0;
+};
+
+/**
+ * Wrapper on top of BNNSFilter and src/dst TensorView.
+ *
+ * Supports a decomposed representation of the filter and can execute sub-primitives in parallel.
+ */
+class Primitive {
+ public:
+  Primitive(const std::vector<BNNSFilter> fs, const TView& src, const TView& dst)
+      : filters(fs), src_view(src), dst_view(dst) {}
+
+  virtual ~Primitive() {
+    for (auto& filter : filters)
+      if (filter) {
+        BNNSFilterDestroy(filter);
+        filter = nullptr;
+      }
+  }
+
+  /** Execute the primitive using the specified src/dst views */
+  void execute() {
+    auto res = TVMBackendParallelLaunch(run_task, this, filters.size());
+    ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly";
+  }
+
+ private:
+  virtual int execute_impl(int part_idx) {
+    const auto filter = this->filters[part_idx];
+    const auto src_view = this->src_view[part_idx];
+    const auto dst_view = this->dst_view[part_idx];
+
+    size_t mb = src_view.get_batch_size();
+
+    // NB! BNNS limitations
+    //  * Do not use simple BNNSFilterApply. There is a bug inside BNNS,
+    //    BNNSFilterApply doesn't work for grouped convolution.
+    //  * Grouped convolution doesn't support an arbitrary stride for the batch dim.
+    //    The tensor should be dense.
+    return BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(),
+                                dst_view.get_data_hdl(), dst_view.get_stride());
+  }
+
+  static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
+    auto prim = reinterpret_cast<Primitive*>(cdata);
+    return prim->execute_impl(task_id);
+  }
+
+ protected:
+  /** Collection of BNNS kernels/filters which will execute the primitive */
+  std::vector<BNNSFilter> filters = {};
+  const TView src_view;
+  const TView dst_view;
+};
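+
+// Execution sketch (illustrative): a primitive decomposed into two filters issues
+// one BNNSFilterApplyBatch call per filter, fanned out via TVMBackendParallelLaunch:
+//   Primitive prim({filter_a, filter_b}, src_view, dst_view);
+//   prim.execute();  // part i consumes src_view[i] and produces dst_view[i]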
+
+/**
+ * Wrapper on top of BNNS::Primitive
+ *
+ * This primitive should be used for executing a primitive with two inputs.
+ */
+class TwoInputPrimitive : public Primitive {
+ public:
+  TwoInputPrimitive(const std::vector<BNNSFilter> fs, const TView& src, const TView& src2,
+                    const TView& dst)
+      : Primitive(fs, src, dst), src2_view(src2) {}
+
+ private:
+  int execute_impl(int task_id) override {
+    const auto filter = this->filters[task_id];
+    const auto src_view = this->src_view[task_id];
+    const auto src2_view = this->src2_view[task_id];
+    const auto dst_view = this->dst_view[task_id];
+
+    size_t mb = src_view.get_batch_size();
+
+    return BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(),
+                                        src2_view.get_data_hdl(), src2_view.get_stride(),
+                                        dst_view.get_data_hdl(), dst_view.get_stride());
+  }
+
+ protected:
+  const TView src2_view;
+};
+
+/**
+ * Wrapper on top of BNNS::Primitive
+ *
+ * This primitive should be used for executing a normalization filter
+ */
+class NormPrimitive : public Primitive {
+ public:
+  using Primitive::Primitive;
+
+ private:
+  int execute_impl(int task_id) override {
+    const auto filter = this->filters[task_id];
+    const auto src_view = this->src_view[task_id];
+    const auto dst_view = this->dst_view[task_id];
+
+    size_t mb = src_view.get_batch_size();
+    return BNNSNormalizationFilterApplyBatch(filter, mb, src_view.get_data_hdl(),
+                                             src_view.get_stride(), dst_view.get_data_hdl(),
+                                             dst_view.get_stride(), false);
+  }
+};
+
+/**
+ * Wrapper on top of BNNS::Primitive
+ *
+ * This primitive should be used for executing a pooling filter
+ */
+class PoolingPrimitive : public Primitive {
+ public:
+  using Primitive::Primitive;
+
+ private:
+  int execute_impl(int task_id) override {
+    const auto filter = this->filters[task_id];
+    const auto src_view = this->src_view[task_id];
+    const auto dst_view = this->dst_view[task_id];
+
+    size_t mb = src_view.get_batch_size();
+    return BNNSPoolingFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(),
+                                       dst_view.get_data_hdl(), dst_view.get_stride(), nullptr, 0);
+  }
+};
+
+/**
+ * Function which splits a convolution primitive into sub-primitives for parallel execution
+ *
+ * @param num number of parts to split into
+ * @param orig_conv_param original convolution descriptor
+ * @param src_view source tensor view
+ * @param wgh_view weight tensor view
+ * @param b_view bias tensor view
+ * @param dst_view destination tensor view
+ * @return collection of convolution descriptors plus the corresponding src/dst tensor views
+ */
+static std::tuple<std::vector<BNNSLayerParametersConvolution>, TView, TView> split_to_n(
+    size_t num, const BNNSLayerParametersConvolution& orig_conv_param, const TView& src_view,
+    const TView& wgh_view, const TView& b_view, const TView& dst_view) {
+  size_t batch = src_view.get_batch_size();
+  size_t oc = dst_view.get_bnns_view().size[2];
+  size_t groups = orig_conv_param.groups;
+
+  BNNS::TView src_view_new;
+  BNNS::TView wgh_view_new;
+  BNNS::TView b_view_new;
+  BNNS::TView dst_view_new;
+
+  // TODO(apeskov): Add split by batch dim. Meanwhile we just disable it...
+  if (batch > 1 || oc % num != 0 || (groups > 1 && groups % num != 0)) {
+    return {{orig_conv_param}, src_view, dst_view};
+  }
+
+  // If groups > 1, split only by groups;
+  // otherwise split inside one convolution by output channels.
+  if (groups > 1) {
+    src_view_new = src_view.party_split_n(num);
+    groups = groups / num;
+  } else {
+    src_view_new = src_view.party_duplicate_n(num);
+  }
+
+  wgh_view_new = wgh_view.party_split_n(num);
+  b_view_new = b_view.party_split_n(num);
+  dst_view_new = dst_view.party_split_n(num);
+
+  std::vector<BNNSLayerParametersConvolution> res(num);
+  for (size_t i = 0; i < num; i++) {
+    auto& cur = res[i];
+    cur = orig_conv_param;
+
+    cur.i_desc = src_view_new[i].get_bnns_view();
+    cur.o_desc = dst_view_new[i].get_bnns_view();
+    cur.w_desc = wgh_view_new[i].get_bnns_view();
+    cur.bias = b_view_new[i].get_bnns_view();
+    cur.groups = groups;
+  }
+  return {res, src_view_new, dst_view_new};
+}
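+
+// Splitting sketch (illustrative): for a grouped convolution with groups == 4 and
+// num == 2, split_to_n returns two descriptors with groups == 2 each; the src, weight,
+// bias and dst views are party-split so that sub-primitive i addresses view[i].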
+
+}  // namespace BNNS
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_
diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h
index 3ae652ccaf24..55f16635b9e6 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -55,7 +55,7 @@ class JSONRuntimeBase : public ModuleNode {
     LoadGraph(graph_json_);
   }
 
-  const char* type_key() const { return "json"; }
+  const char* type_key() const override { return "json"; }
 
   /*! \brief Initialize a specific json runtime. */
   virtual void Init(const Array<NDArray>& consts) = 0;
@@ -69,7 +69,7 @@
    * \param sptr_to_self The pointer to the module node.
    * \return The packed function.
    */
-  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
     if (name == "get_symbol") {
       return PackedFunc(
           [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
@@ -98,7 +98,7 @@
     }
   }
 
-  virtual void SaveToBinary(dmlc::Stream* stream) {
+  void SaveToBinary(dmlc::Stream* stream) override {
     // Save the symbol
     stream->Write(symbol_name_);
     // Save the graph
diff --git a/tests/cpp/contrib/bnns.cc b/tests/cpp/contrib/bnns.cc
new file mode 100644
index 000000000000..1efd487caff9
--- /dev/null
+++ b/tests/cpp/contrib/bnns.cc
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+TEST(PackedFunc, Basic) {
+  using namespace tvm;
+  using namespace tvm::tir;
+  using namespace tvm::runtime;
+  int x = 0;
+  void* handle = &x;
+  DLTensor a;
+
+  Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK(args.num_args == 3);
+    ICHECK(args.values[0].v_float64 == 1.0);
+    ICHECK(args.type_codes[0] == kDLFloat);
+    ICHECK(args.values[1].v_handle == &a);
+    ICHECK(args.type_codes[1] == kTVMDLTensorHandle);
+    ICHECK(args.values[2].v_handle == &x);
+    ICHECK(args.type_codes[2] == kTVMOpaqueHandle);
+    *rv = Var("a");
+  })(1.0, &a, handle);
+  ICHECK(v->name_hint == "a");
+}
+
+TEST(PackedFunc, Node) {
+  using namespace tvm;
+  using namespace tvm::tir;
+  using namespace tvm::runtime;
+  Var x;
+  Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK(args.num_args == 1);
+    ICHECK(args[0].IsObjectRef<Var>());
+    Var b = args[0];
+    ICHECK(x.same_as(b));
+    *rv = b;
+  })(x);
+  ICHECK(t.same_as(x));
+}
+
+TEST(PackedFunc, NDArray) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
+  reinterpret_cast<float*>(x->data)[0] = 10.0f;
+  ICHECK(x.use_count() == 1);
+
+  PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });
+
+  NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    NDArray y = args[0];
+    DLTensor* ptr = args[0];
+    ICHECK(ptr == x.operator->());
+    ICHECK(x.same_as(y));
+    ICHECK(x.use_count() == 2);
+    *rv = forward(y);
+  })(x);
+  ICHECK(ret.use_count() == 2);
+  ICHECK(ret.same_as(x));
+}
+
+TEST(PackedFunc, str) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK(args.num_args == 1);
+    std::string x = args[0];
+    ICHECK(x == "hello");
+    String y = args[0];
+    ICHECK(y == "hello");
+    *rv = x;
+  })("hello");
+
+  PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK(args.num_args == 1);
+    runtime::String s = args[0];
+    ICHECK(s == "hello");
+  })(runtime::String("hello"));
+}
+
+TEST(PackedFunc, func) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; });
+  // function as arguments
+  int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    PackedFunc f = args[0];
+    // TVMArgValue -> Arguments as function
+    *rv = f(args[1]).operator int();
+  })(addone, 1);
+  ICHECK_EQ(r0, 2);
+
+  int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    // TVMArgValue -> TVMRetValue
+    *rv = args[1];
+  })(2, 100);
+  ICHECK_EQ(r1, 100);
+
+  int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    // re-assignment
+    *rv = args[0];
+    // TVMRetValue -> Function argument
+    *rv = addone(args[0].operator PackedFunc()(args[1], 1));
+  })(addone, 100);
+  ICHECK_EQ(r2, 102);
+}
+
+TEST(PackedFunc, Expr) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  // automatic conversion of int to expr
+  PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
+    PrimExpr x = args[0];
+    *rv = x.as<IntImmNode>()->value + 1;
+  });
+  int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    PackedFunc f = args[0];
+    // TVMArgValue -> Arguments as function
+    *rv = f(args[1]).operator int();
+  })(addone, 1);
+  ICHECK_EQ(r0, 2);
+}
+
+TEST(PackedFunc, Type) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    DataType x = args[0];
+    *rv = x;
+  });
+  auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });
+  ICHECK(get_type("int32").operator DataType() == DataType::Int(32));
+  ICHECK(get_type("float").operator DataType() == DataType::Float(32));
+  ICHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2));
+}
+
+TEST(TypedPackedFunc, HighOrder) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  using Int1Func = TypedPackedFunc<int(int)>;
+  using Int2Func = TypedPackedFunc<int(int, int)>;
+  using BindFunc = TypedPackedFunc<Int1Func(Int2Func, int)>;
+  BindFunc ftyped;
+  ftyped = [](Int2Func f1, int value) -> Int1Func {
+    auto binded = [f1, value](int x) { return f1(value, x); };
+    Int1Func x(binded);
+    return x;
+  };
+  auto add = [](int x, int y) { return x + y; };
+  ICHECK_EQ(ftyped(Int2Func(add), 1)(2), 3);
+  PackedFunc f = ftyped(Int2Func(add), 1);
+  ICHECK_EQ(f(3).operator int(), 4);
+  // call the type erased version.
+  Int1Func f1 = ftyped.packed()(Int2Func(add), 1);
+  ICHECK_EQ(f1(3), 4);
+}
+
+TEST(TypedPackedFunc, Deduce) {
+  using namespace tvm::runtime;
+  using tvm::runtime::detail::function_signature;
+
+  TypedPackedFunc<int(float)> x;
+  auto f = [](int x) -> int { return x + 1; };
+  std::function<void(float)> y;
+
+  static_assert(std::is_same<function_signature<decltype(x)>::FType, int(float)>::value,
+                "invariant1");
+  static_assert(std::is_same<function_signature<decltype(f)>::FType, int(int)>::value,
+                "invariant2");
+  static_assert(std::is_same<function_signature<decltype(y)>::FType, void(float)>::value,
+                "invariant3");
+}
+
+TEST(PackedFunc, ObjectConversion) {
+  using namespace tvm;
+  using namespace tvm::tir;
+  using namespace tvm::runtime;
+  TVMRetValue rv;
+  auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
+  // assign null
+  rv = ObjectRef();
+  ICHECK_EQ(rv.type_code(), kTVMNullptr);
+
+  // Can assign NDArray to ret type
+  rv = x;
+  ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
+  // Even if we assign base type it still shows as NDArray
+  rv = ObjectRef(x);
+  ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
+  // Check convert back
+  ICHECK(rv.operator NDArray().same_as(x));
+  ICHECK(rv.operator ObjectRef().same_as(x));
+  ICHECK(!rv.IsObjectRef<PrimExpr>());
+
+  auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK_EQ(args[0].type_code(), kTVMNDArrayHandle);
+    ICHECK(args[0].operator NDArray().same_as(x));
+    ICHECK(args[0].operator ObjectRef().same_as(x));
+    ICHECK(args[1].operator ObjectRef().get() == nullptr);
+    ICHECK(args[1].operator NDArray().get() == nullptr);
+    ICHECK(args[1].operator Module().get() == nullptr);
+    ICHECK(args[1].operator Array<NDArray>().get() == nullptr);
+    ICHECK(!args[0].IsObjectRef<PrimExpr>());
+  });
+  pf1(x, ObjectRef());
+  pf1(ObjectRef(x), NDArray());
+
+  // testcases for modules
+  auto* pf = tvm::runtime::Registry::Get("runtime.SourceModuleCreate");
+  ICHECK(pf != nullptr);
+  Module m = (*pf)("", "xyz");
+  rv = m;
+  ICHECK_EQ(rv.type_code(), kTVMModuleHandle);
+  // Even if we assign base type it still shows as Module
+  rv = ObjectRef(m);
+  ICHECK_EQ(rv.type_code(), kTVMModuleHandle);
+  // Check convert back
+  ICHECK(rv.operator Module().same_as(m));
+  ICHECK(rv.operator ObjectRef().same_as(m));
+  ICHECK(!rv.IsObjectRef<PrimExpr>());
+
+  auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
+    ICHECK_EQ(args[0].type_code(), kTVMModuleHandle);
+    ICHECK(args[0].operator Module().same_as(m));
+    ICHECK(args[0].operator ObjectRef().same_as(m));
+    ICHECK(args[1].operator ObjectRef().get() == nullptr);
+    ICHECK(args[1].operator NDArray().get() == nullptr);
+    ICHECK(args[1].operator Module().get() == nullptr);
+    ICHECK(!args[0].IsObjectRef<PrimExpr>());
+  });
+  pf2(m, ObjectRef());
+  pf2(ObjectRef(m), Module());
+}
+
+TEST(TypedPackedFunc, RValue) {
+  using namespace tvm;
+  using namespace tvm::runtime;
+  {
+    auto inspect = [](TVMArgs args, TVMRetValue* rv) {
+      for (int i = 0; i < args.size(); ++i) {
+        ICHECK_EQ(args[i].type_code(), kTVMObjectRValueRefArg);
+      }
+    };
+    PackedFunc finspect(inspect);
+    finspect(tir::Var("x"));
+  }
+  {
+    auto f = [](tir::Var x, bool move) {
+      if (move) {
+        ICHECK(x.unique());
+      } else {
+        ICHECK(!x.unique());
+      }
+      ICHECK(x->name_hint == "x");
+      return x;
+    };
+    TypedPackedFunc<tir::Var(tir::Var, bool)> tf(f);
+
+    tir::Var var("x");
+    ICHECK(var.unique());
+    tf(var, false);
+    // move the result to the function.
+    tir::Var ret = tf(std::move(var), true);
+    ICHECK(!var.defined());
+  }
+
+  {
+    // pass child class.
+    auto f = [](PrimExpr x, bool move) {
+      if (move) {
+        ICHECK(x.unique());
+      } else {
+        ICHECK(!x.unique());
+      }
+      return x;
+    };
+    TypedPackedFunc<PrimExpr(PrimExpr, bool)> tf(f);
+
+    tir::Var var("x");
+    ICHECK(var.unique());
+    tf(var, false);
+    tf(std::move(var), true);
+    // auto conversion.
+    tf(1, true);
+  }
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
diff --git a/tests/python/contrib/test_bnns/__init__.py b/tests/python/contrib/test_bnns/__init__.py
new file mode 100644
index 000000000000..724b23f1378b
--- /dev/null
+++ b/tests/python/contrib/test_bnns/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Infrastructure and tests for BNNS"""
diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py
new file mode 100644
index 000000000000..0107de54a04f
--- /dev/null
+++ b/tests/python/contrib/test_bnns/infrastructure.py
@@ -0,0 +1,330 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from itertools import zip_longest, combinations
+import json
+import os
+import warnings
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import rpc
+from tvm.contrib import graph_runtime
+from tvm.relay.op.contrib.bnns import partition_for_bnns
+from tvm.contrib import utils
+from tvm.autotvm.measure import request_remote
+from tvm.relay.analysis import analysis
+
+
+class Device:
+    """
+    Common device configuration for python tests.
+
+    Check tests/python/contrib/test_bnns/ for the presence of a test_config.json file.
+    This file can be used to override the default configuration here, which will attempt to
+    run the BNNS runtime tests locally if the runtime is available. Changing the configuration
+    will allow these runtime tests to be offloaded to a remote device with BNNS, for example
+    via a tracker.
+
+    Notes
+    -----
+    The test configuration will be loaded once when the class is created. If the configuration
+    changes between tests, any changes will not be picked up.
+
+
+    Attributes
+    ----------
+    connection_type : str
+        Details the type of RPC connection to use. Options:
+        local - Use the local device,
+        tracker - Connect to a tracker to request a remote device,
+        remote - Connect to a remote device directly.
+    host : str
+        Specify IP address or hostname of remote target.
+    port : int
+        Specify port number of remote target.
+    target : str
+        The compilation target.
+    device_key : str
+        The device key of the remote target. Use when connecting to a remote device via a tracker.
+    cross_compile : str
+        Specify path to cross compiler to use when connecting a remote device from a non-arm
+        platform.
+    """
+
+    connection_type = "local"
+    host = "localhost"
+    port = 9090
+    target = "llvm"
+    device_key = ""
+    cross_compile = ""
+
+    def __init__(self):
+        """Keep remote device for lifetime of object."""
+        self.device = self._get_remote()
+
+    @classmethod
+    def _get_remote(cls):
+        """Get a remote (or local) device to use for testing."""
+        if cls.connection_type == "tracker":
+            device = request_remote(cls.device_key, cls.host, cls.port, timeout=1000)
+        elif cls.connection_type == "remote":
+            device = rpc.connect(cls.host, cls.port)
+        elif cls.connection_type == "local":
+            device = rpc.LocalSession()
+        else:
+            raise ValueError(
+                "connection_type in test_config.json should be one of: local, tracker, remote."
+            )
+
+        return device
+
+    @classmethod
+    def load(cls, file_name):
+        """Load test config
+
+        Load the test configuration by looking for file_name relative
+        to the test_bnns directory.
+        """
+        location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+        config_file = os.path.join(location, file_name)
+        if not os.path.exists(config_file):
+            warnings.warn("Config file doesn't exist, resuming tests with default config.")
+            return
+        with open(config_file, mode="r") as config:
+            test_config = json.load(config)
+
+        cls.connection_type = test_config["connection_type"]
+        cls.host = test_config["host"]
+        cls.port = test_config["port"]
+        cls.target = test_config["target"]
+        cls.device_key = test_config.get("device_key") or ""
+        cls.cross_compile = test_config.get("cross_compile") or ""
+
+
+Device.target = "llvm"
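+
+# Illustrative test_config.json (hypothetical values; the keys mirror what
+# Device.load() above reads):
+# {
+#     "connection_type": "tracker",
+#     "host": "127.0.0.1",
+#     "port": 9190,
+#     "target": "llvm",
+#     "device_key": "my_device",
+#     "cross_compile": "aarch64-linux-gnu-g++"
+# }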
+
+
+def skip_runtime_test():
+    """Skip test if it requires the runtime and it's not present."""
+    # BNNS codegen not present.
+    if not tvm.get_global_func("relay.ext.bnns", True):
+        print("Skip because BNNS codegen is not available.")
+        return True
+    return False
+
+
+def skip_codegen_test():
+    """Skip test if it requires the BNNS codegen and it's not present."""
+    if not tvm.get_global_func("relay.ext.bnns", True):
+        print("Skip because BNNS codegen is not available.")
+        return True
+    return False
+
+
+def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0):
+    """Build module with option to build for BNNS."""
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+    with tvm.transform.PassContext(opt_level=3):
+        if enable_bnns:
+            mod = partition_for_bnns(mod)
+        relay.backend.compile_engine.get().clear()
+        return relay.build(mod, target=target, target_host=target, params=params)
+
+
+def build_and_run(
+    mod,
+    inputs,
+    outputs,
+    params,
+    device,
+    enable_bnns=True,
+    no_runs=1,
+    tvm_ops=0,
+    config=None,
+):
+    """Build and run the relay module."""
+    if config is None:
+        config = {}
+
+    try:
+        lib = build_module(mod, device.target, params, enable_bnns, tvm_ops)
+    except Exception as e:
+        err_msg = "The module could not be built.\n"
+        if config:
+            err_msg += f"The test failed with the following parameters: {config}\n"
+        err_msg += str(e)
+        raise Exception(err_msg)
+
+    lib = update_lib(lib, device.device, device.cross_compile)
+    gen_module = graph_runtime.GraphModule(lib["default"](device.device.cpu(0)))
+    gen_module.set_input(**inputs)
+    out = []
+    for _ in range(no_runs):
+        gen_module.run()
+        out.append([gen_module.get_output(i) for i in range(outputs)])
+    return out
+
+
+def update_lib(lib, device, cross_compile):
+    """Export the library to the remote/local device."""
+    lib_name = "mod.so"
+    temp = utils.tempdir()
+    lib_path = temp.relpath(lib_name)
+    if cross_compile:
+        lib.export_library(lib_path, cc=cross_compile)
+    else:
+        lib.export_library(lib_path)
+    device.upload(lib_path)
+    lib = device.load_module(lib_name)
+    return lib
+
+
+def extract_bnns_modules(module):
+    """Get the BNNS module(s) from the llvm module."""
+    return list(filter(lambda mod: mod.type_key == "bnns_json", module.get_lib().imported_modules))
+
+
+def verify(answers, atol, rtol, verify_saturation=False, config=None):
+    """Compare the array of answers.
Each entry is a list of outputs.""" + if config is None: + config = {} + + if len(answers) < 2: + raise RuntimeError(f"No results to compare: expected at least two, found {len(answers)}") + for answer in zip_longest(*answers): + for outs in combinations(answer, 2): + try: + if verify_saturation: + assert ( + np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size + ), "Output is saturated: {}".format(outs[0]) + assert ( + np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size + ), "Output is saturated: {}".format(outs[0]) + tvm.testing.assert_allclose( + outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol + ) + except AssertionError as e: + err_msg = "Results not within the acceptable tolerance.\n" + if config: + err_msg += f"The test failed with the following parameters: {config}\n" + err_msg += str(e) + raise AssertionError(err_msg) + + +def verify_codegen( + module, + known_good_codegen, + num_bnns_modules, + tvm_ops=0, + target=Device.target, +): + """Check BNNS codegen against a known good output.""" + module = build_module(module, target, tvm_ops=tvm_ops) + bnns_modules = extract_bnns_modules(module) + + assert len(bnns_modules) == num_bnns_modules, ( + f"The number of BNNS modules produced ({len(bnns_modules)}) does not " + f"match the expected value ({num_bnns_modules})." + ) + + for mod in bnns_modules: + source = mod.get_source("json") + codegen = json.loads(source)["nodes"] + # remove input and const names as these cannot be predetermined + for node in range(len(codegen)): + if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": + codegen[node]["name"] = "" + codegen_str = json.dumps(codegen, sort_keys=True, indent=2) + known_good_codegen_str = json.dumps(known_good_codegen, sort_keys=True, indent=2) + + assert codegen_str == known_good_codegen_str, ( + f"The JSON produced by codegen does not match the expected result. \n" + f"Actual={codegen_str} \n" + f"Expected={known_good_codegen_str}" + ) + + +def compare_inference_with_ref(func, params, atol=0.002, rtol=0.007): + """Compare scoring results for compilation with and without BNNS. + + Provided function will be compiled two times with and without BNNS. + The scoring results for both type of compilation will be compared + with provided atol and rtol. The input data will be automatically + generated based of shape and dtype info provided for var nodes. + + """ + # Generate input tensor values + inputs = {} + for free_param in analysis.free_vars(func): + name = free_param.name_hint + dtype = free_param.type_annotation.dtype + shape = [s.value for s in free_param.type_annotation.shape] + inputs[name] = tvm.nd.array(np.random.uniform(0, 127, shape).astype(dtype)) + + # Run for both type of compilation + device = Device() + outputs = [] + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + # Compare result tensors + verify(outputs, atol=atol, rtol=rtol) + + +def generate_trials(space, r_factor=3): + """Generates a series of trials. + + This algorithm generates a series of non-deterministic trials given a + space of options to test. A trial is generated by pulling a value from + each option in the space. On some occasions the values are shuffled to + ensure a different trial on each r_factor iteration. The algorithm ensures + that each value from an option is used at least once. The total number of + trials is determined by the r_factor * the option with the largest number + of values. 
+
+    Parameters
+    ----------
+    space: List[List[Any]]
+        A list of different options with varying values to test.
+    r_factor: Optional[int]
+        The repeat factor.
+
+    Returns
+    -------
+    result: List[Tuple]
+        A list of trials specifying values for each option.
+
+    """
+    np.random.seed(0)
+    max_len = 1
+    for option in space:
+        max_len = max(max_len, len(option))
+
+    num_trials = r_factor * max_len
+    trials = []
+    for i in range(num_trials):
+        trial = []
+        for option in space:
+            if i % len(option) == 0:
+                np.random.shuffle(option)
+            trial.append(option[i % len(option)])
+
+        trials.append(trial)
+
+    return trials
diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py
new file mode 100644
index 000000000000..886958cf3076
--- /dev/null
+++ b/tests/python/contrib/test_bnns/test_conv2d.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""BNNS integration conv2d tests."""
+
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+
+from .infrastructure import skip_runtime_test, compare_inference_with_ref, generate_trials
+
+# TODO: Missing cases
+#   1. Bias as add with a 3d const tensor. Leads to an additional unsqueeze op between
+#      the convolution and the add.
+#   2. Check unsupported cases of fusion, like bias add with axis != 1 or add with
+#      broadcast by spatial dims.
+#   3. Check the case where bias/weights are not constants. Should fall back to LLVM
+#      or be decomposed.
+#   4. Check the case where bias/weights are constant expressions. Should work somehow.
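+
+# Note (illustrative): generate_trials() from infrastructure.py draws one value per
+# option list per trial, reshuffling an option each time its index wraps around, e.g.
+#   generate_trials([[1, 2], ["a", "b", "c"]], r_factor=2)  # -> 2 * 3 = 6 trials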
+
+
+def _get_model(
+    shape,
+    kernel=(3, 3),
+    padding=(1, 1),
+    strides=(1, 1),
+    dilation=(1, 1),
+    groups=1,
+    dtype="float32",
+    channels=-1,  # -1 means same as input channels
+    bias_type="none",
+    activation_type="none",
+):
+    """Return a model and any parameters it may have"""
+    if channels == -1:
+        channels = shape[1]
+
+    a = relay.var("a", shape=shape, dtype=dtype)
+    weight_shape = (channels, shape[1] // groups, *kernel)
+    w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
+    weights = relay.const(w, dtype)
+    out = relay.nn.conv2d(
+        a,
+        weights,
+        kernel_size=kernel,
+        dilation=dilation,
+        strides=strides,
+        padding=padding,
+        groups=groups,
+        channels=channels,
+        out_dtype=dtype,
+    )
+    params = {"w": w}
+    if bias_type == "bias_add":
+        b = tvm.nd.array(np.random.uniform(-10, 10, weight_shape[0]).astype(dtype))
+        biasc = relay.const(b, dtype)
+        out = relay.nn.bias_add(out, biasc, axis=1)
+        params["b"] = b
+    elif bias_type == "add_3d" or bias_type == "add_4d":
+        bias_shape = (
+            (weight_shape[0], 1, 1) if bias_type == "add_3d" else (1, weight_shape[0], 1, 1)
+        )
+        b = tvm.nd.array(np.random.uniform(-10, 10, bias_shape).astype(dtype))
+        biasc = relay.const(b, dtype)
+        out = relay.add(out, biasc)
+        params["b"] = b
+
+    if activation_type == "relu":
+        out = relay.nn.relu(out)
+    elif activation_type == "sigmoid":
+        out = relay.op.sigmoid(out)
+    return out, params
+
+
+@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available")
+def test_conv2d():
+    np.random.seed(0)
+
+    kernel_hs = [1, 2, 3, 5]
+    kernel_ws = [1, 2, 3, 5]
+    pad = [(1, 1), (2, 2), (2, 1)]
+    strides = [(1, 1), (2, 2)]
+    dilation = [(1, 1)]
+    out_channels = [1, 4, 8, 16]
+    input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
+    batches = [1, 2]
+    groups = [1, 2]
+    # NB: the key must match the "bias_add" branch in _get_model above.
+    bias_kind = ["none", "add_3d", "add_4d", "bias_add"]
+    activation_kind = ["none", "relu", "sigmoid"]
+    trials = generate_trials(
+        [
+            kernel_hs,
+            kernel_ws,
+            pad,
+            strides,
+            dilation,
+            out_channels,
+            input_shapes,
+            groups,
+            batches,
+            bias_kind,
+            activation_kind,
+        ],
+        3,
+    )
+
+    for (
+        kernel_h,
+        kernel_w,
+        pad,
+        stride,
+        dilation,
+        out_channels,
+        input_shapes,
+        group,
+        batch,
+        bias,
+        activation,
+    ) in trials:
+        if out_channels % group != 0:
+            continue
+        func, params = _get_model(
+            shape=(batch, *input_shapes),
+            kernel=(kernel_h, kernel_w),
+            padding=pad,
+            strides=stride,
+            dilation=dilation,
+            groups=group,
+            channels=out_channels,
+            bias_type=bias,
+            activation_type=activation,
+        )
+        compare_inference_with_ref(func, params)
+
+
+@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available")
+def test_conv2d_dw():
+    if skip_runtime_test():
+        return
+
+    np.random.seed(0)
+    shape = [4, 5, 5]
+
+    for batch in [1, 2]:
+        mod, params = _get_model(shape=(batch, *shape), groups=shape[0])
+        compare_inference_with_ref(mod, params)
+
+
+@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available")
+def test_conv2d_with_oc1():
+    if skip_runtime_test():
+        return
+
+    np.random.seed(0)
+    shape = [3, 5, 5]
+
+    for batch in [1, 2]:
+        for bias in ["none", "add_4d"]:
+            mod, params = _get_model(shape=(batch, *shape), channels=1, bias_type=bias)
+            compare_inference_with_ref(mod, params)
+
+
+if __name__ == "__main__":
+    test_conv2d()
+    test_conv2d_dw()
+    test_conv2d_with_oc1()
diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py
new file mode 100644
index
000000000000..b10504bbc961 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS pattern detection check""" + +import tvm +from tvm import relay +import numpy as np + +from tvm.relay.op.contrib.bnns import partition_for_bnns + +fp32 = "float32" + + +def partition(exp): + """Apply BNNS specific partitioning transformation""" + mod = tvm.IRModule.from_expr(exp) + with tvm.transform.PassContext(opt_level=3): + mod = partition_for_bnns(mod) + return mod + + +def is_op_fused(func, op_name): + is_fused = False + + def visit(op): + if ( + isinstance(op, tvm.relay.function.Function) + and op_name in op.attrs["PartitionedFromPattern"] + ): + nonlocal is_fused + is_fused = True + + tvm.relay.analysis.post_order_visit(func.body, visit) + return is_fused + + +def test_pattern_conv2d_with_bias_add(): + for axis in (1, 2): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) + b = relay.const(np.random.uniform(-10, 10, 8).astype(fp32)) + res = relay.nn.bias_add(res, b, axis=axis) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + + assert bias_is_fused if axis == 1 else not bias_is_fused + + +def test_pattern_conv2d_with_add(): + workloads = {8: False, (8, 1): False, (8, 1, 1): True, (1, 8, 1, 1): True} + + for b_shape, should_be_fused in workloads.items(): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) + b = relay.const(np.random.uniform(-10, 10, b_shape).astype(fp32)) + res = relay.add(res, b) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "add") + + assert bias_is_fused == should_be_fused + + +def test_pattern_conv2d_with_non_cons_weights(): + for const_weights in (True, False): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + if const_weights: + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + else: + w = relay.var("w", shape=(8, 7, 3, 3), dtype=fp32) + + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) + + mod = partition(res) + use_bnns = len(mod.get_global_vars()) == 2 # GlobalVar: "main" and "bnns_0" + + assert use_bnns == const_weights + + +def test_pattern_conv2d_with_non_cons_bias(): + a = relay.var("a", shape=[2, 7, 8, 8], dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) + b 
= relay.var("b", shape=[8], dtype=fp32) + res = relay.nn.bias_add(res, b, axis=1) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + + assert not bias_is_fused diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py new file mode 100644 index 000000000000..c2cf9bf71373 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration dense tests.""" + +import numpy as np +import math +import pytest +import tvm +from tvm import relay +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + build_and_run, + verify, + verify_codegen, + generate_trials, +) + + +def _get_model(shape, weight_shape, units, dtype, var_names, has_bias=False, has_gelu=False): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=shape, dtype=dtype) + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + out = relay.nn.dense(a, weights, units=units, out_dtype=dtype) + params = {"w": w} + if has_bias: + b = tvm.nd.array(np.random.randint(-128, 127, weight_shape[0]).astype(dtype)) + biasc = relay.const(b, dtype) + out = relay.op.add(out, biasc) + params["b"] = b + if has_gelu: + const1 = relay.const(0.044715) + const2 = relay.const(math.sqrt(2 / math.pi)) + bias = out + out = relay.op.power(bias, relay.const(3.0, "float32")) + out = relay.op.multiply(out, const1) + out = relay.op.add(out, bias) + out = relay.op.multiply(out, const2) + out = relay.op.tanh(out) + out = relay.op.add(out, relay.const(1, "float32")) + out = relay.op.multiply(out, relay.const(0.5)) + out = relay.op.multiply(out, bias) + return out, params + + +def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False, has_gelu=False): + output_shape = (shape[0], units) + name = "nn.dense" + if has_bias is True: + name = "bnns.dense_bias" + if has_bias is True and has_gelu is True: + name = "bnns.dense_bias_gelu" + + node = { + "op": "kernel", + "name": name, + "inputs": [], + "attrs": { + "num_outputs": "1", + "out_dtype": [["float32"]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "units": [[str(units)]], + }, + } + + inputs = [ + {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}}, + { + "op": "const", + "name": "", + "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]}, + }, + ] + + if has_bias: + inputs.append( + { + "op": "const", + "name": "", + "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [["float32"]]}, + } + ) + + input_idx = 0 + for _ in range(len(inputs)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + 
node["attrs"]["num_inputs"] = str(len(inputs)) + inputs.append(node) + return inputs + + +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") +def test_dense(): + device = Device() + np.random.seed(0) + + dtype = ["float32"] + shape = [ + ((1, 128), (16, 128), 16), + ((32, 32), (32, 32), 32), + ((1, 64), (1, 64), 1), + ((11, 2), (2, 2), 2), + ((2, 2), (1, 2), 1), + ] + composite = [False, True] + trials = generate_trials([dtype, shape, composite, composite], 3) + + for dtype, (shape, weight_shape, units), with_bias, with_gelu in trials: + outputs = [] + inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))} + func, params = _get_model( + shape, + weight_shape, + units, + dtype, + var_names=iter(inputs), + has_bias=with_bias, + has_gelu=with_gelu, + ) + for bnns in [False, True]: + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=bnns, + )[0] + ) + + config = { + "shape": shape, + "weight_shape": weight_shape, + "units": units, + "dtype": dtype, + "with_bias": with_bias, + "with_gelu": with_gelu, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") +def test_codegen_dense(): + np.random.seed(0) + + dtype = ["float32"] + shape = [ + ((1, 128), (16, 128), 16), + ((32, 32), (32, 32), 32), + ((1, 64), (1, 64), 1), + ((11, 2), (2, 2), 2), + ((2, 2), (1, 2), 1), + ] + composite = [False, True] + trials = generate_trials([dtype, shape, composite, composite], 3) + + for dtype, (shape, weight_shape, units), with_bias, with_gelu in trials: + inputs = {"a"} + + args = (shape, weight_shape, units, dtype) + + func, params = _get_model( + *args, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu + ) + exp_codegen = _get_expected_codegen(*args, has_bias=with_bias, has_gelu=with_gelu) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_dense() + test_codegen_dense() diff --git a/tests/python/contrib/test_bnns/test_matmul.py b/tests/python/contrib/test_bnns/test_matmul.py new file mode 100644 index 000000000000..7bf4d48f8e88 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_matmul.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""BNNS integration dense tests.""" + +import numpy as np +import math +import pytest +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + verify_codegen, + build_and_run, + verify, + generate_trials, +) + + +def _get_model(a_shape, b_shape, dtype, var_names, is_a_constant=False, is_b_constant=False): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=a_shape, dtype=dtype) + b = relay.var(next(var_names), shape=b_shape, dtype=dtype) + params = {} + if is_b_constant is True: + b = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params["b"] = b + b = relay.const(b, dtype) + if is_a_constant is True: + a = tvm.nd.array(np.random.uniform(-128, 127, a_shape).astype(dtype)) + params["a"] = a + a = relay.const(a, dtype) + out = relay.nn.batch_matmul(a, b) + return out, params + + +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") +def test_matmul(): + device = Device() + np.random.seed(0) + dtype = "float32" + + # C[N, I, J] = A[N, I, K] * B[N, J, K] + shapes_config = [ + # B, I, J, K + [1, 4, 4, 3], + [1, 16, 32, 32], + [2, 1, 1, 3], + [2, 16, 32, 32], + [5, 1, 1, 3], + ] + data_config = [ + # A_is_constant, B_is_constant + [False, True], + [True, False], + [False, False], + ] + + for N, I, J, K in shapes_config: + a_shape = [N, I, K] + b_shape = [N, J, K] + for is_a_constant, is_b_constant in data_config: + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(-128, 127, a_shape).astype(dtype)), + "b": tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)), + } + func, params = _get_model( + a_shape, + b_shape, + dtype, + var_names=iter(inputs), + is_a_constant=is_a_constant, + is_b_constant=is_b_constant, + ) + for enable_bnns in [False, True]: + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=enable_bnns, + )[0] + ) + + config = { + "a_shape": a_shape, + "b_shape": b_shape, + "dtype": dtype, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +if __name__ == "__main__": + test_matmul() diff --git a/tests/python/contrib/test_bnns/test_normalization.py b/tests/python/contrib/test_bnns/test_normalization.py new file mode 100644 index 000000000000..094cfb041c3c --- /dev/null +++ b/tests/python/contrib/test_bnns/test_normalization.py @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""BNNS integration normalization tests.""" + +import numpy as np +import math +import pytest +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + verify_codegen, + build_and_run, + verify, + generate_trials, +) + + +def _get_model( + shape, b_shape, s_shape, dtype, var_names, axis=1, epsilon=1e-5, center=True, scale=True +): + """Return a model and any parameters it may have""" + src = relay.var(next(var_names), shape=shape, dtype=dtype) + params = {} + b = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params["b"] = b + b = relay.const(b, dtype) + s = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params["b"] = s + s = relay.const(s, dtype) + out = relay.nn.instance_norm(src, s, b, axis, epsilon, center, scale) + + return out, params + + +def _get_expected_codegen(shape, axis, center, scale, dtype, offload_on_bnns): + output_shape = shape + name = "nn.instance_norm" + + node = { + "op": "kernel", + "name": name, + "inputs": [], + "attrs": { + "num_outputs": "1", + "axis": [[str(axis)]], + "center": [[str(int(center))]], + "scale": [[str(int(scale))]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "epsilon": [["1.0000000000000001e-05"]], + }, + } + + inputs = [ + {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}}, + { + "op": "const", + "name": "", + "attrs": {"shape": [[[shape[axis]]]], "dtype": [[str(dtype)]]}, + }, + { + "op": "const", + "name": "", + "attrs": {"shape": [[[shape[axis]]]], "dtype": [[str(dtype)]]}, + }, + ] + + input_idx = 0 + for _ in range(len(inputs)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + node["attrs"]["num_inputs"] = str(len(inputs)) + inputs.append(node) + return inputs + + +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") +def test_normalization(): + device = Device() + np.random.seed(0) + dtype = "float32" + + shapes_config = [ + [1, 2, 3, 4], + [3, 2, 3, 4], + [2, 2, 3], + [16, 32, 32], + [5, 3], + ] + axes = [-1, 0, 1, 2] + + for shape in shapes_config: + for axis in axes: + if len(shape) == 2 and axis != 0: + continue + for center in [False, True]: + for scale in [False, True]: + outputs = [] + inputs = { + "src": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)), + } + func, params = _get_model( + shape, + [shape[axis]], + [shape[axis]], + dtype, + var_names=iter(inputs), + axis=axis, + center=center, + scale=scale, + ) + for enable_bnns in [False, True]: + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=enable_bnns, + )[0] + ) + + config = { + "dtype": dtype, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") +def test_codegen_normalization(): + np.random.seed(0) + + dtype = "float32" + shapes_config = [ + [1, 2, 3, 4], + [3, 2, 3, 4], + [2, 2, 3], + [16, 32, 32], + [5, 3], + ] + axes = [-1, 0, 1, 2] + + def check_normalization(rank, axis): + if rank < 3 or rank > 4: + return False + if axis == 0 and rank == 3 or axis == 1 and rank == 4: + return True + return False + + for shape in shapes_config: + for axis in axes: + if len(shape) == 2 and axis != 0: + continue + for center in [False, True]: + for scale in [False, True]: + inputs = {"src"} + + args = (shape, axis, center, scale, dtype) + + func, params = _get_model( + shape, + [shape[axis]], + 
+
+
+@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available")
+def test_codegen_normalization():
+    np.random.seed(0)
+
+    dtype = "float32"
+    shapes_config = [
+        [1, 2, 3, 4],
+        [3, 2, 3, 4],
+        [2, 2, 3],
+        [16, 32, 32],
+        [5, 3],
+    ]
+    axes = [-1, 0, 1, 2]
+
+    def check_normalization(rank, axis):
+        if rank < 3 or rank > 4:
+            return False
+        if axis == 0 and rank == 3 or axis == 1 and rank == 4:
+            return True
+        return False
+
+    for shape in shapes_config:
+        for axis in axes:
+            if len(shape) == 2 and axis != 0:
+                continue
+            for center in [False, True]:
+                for scale in [False, True]:
+                    inputs = {"src"}
+
+                    args = (shape, axis, center, scale, dtype)
+
+                    func, params = _get_model(
+                        shape,
+                        [shape[axis]],
+                        [shape[axis]],
+                        dtype,
+                        var_names=iter(inputs),
+                        axis=axis,
+                        center=center,
+                        scale=scale,
+                    )
+
+                    offload_on_bnns = check_normalization(len(shape), axis)
+                    bnns_blocks = 1 if offload_on_bnns else 0
+                    exp_codegen = _get_expected_codegen(*args, offload_on_bnns)
+                    verify_codegen(func, exp_codegen, bnns_blocks)
+
+
+if __name__ == "__main__":
+    test_normalization()
+    test_codegen_normalization()
diff --git a/tests/python/contrib/test_bnns/test_onnx_topologies.py b/tests/python/contrib/test_bnns/test_onnx_topologies.py
new file mode 100644
index 000000000000..86f98eb6e8de
--- /dev/null
+++ b/tests/python/contrib/test_bnns/test_onnx_topologies.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""BNNS integration tests on pre-trained ONNX topologies."""
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import relay
+from tvm.relay import transform
+from tvm.contrib import utils, graph_runtime
+from tvm.contrib.download import download_testdata
+from tvm.relay.op.contrib.bnns import partition_for_bnns
+
+import numpy as np
+
+pytest.importorskip("onnx")
+import onnx
+
+bnns_is_absent = tvm.get_global_func("relay.ext.bnns", True) is None
+
+TARGET = "llvm"
+INPUT_SHAPE = [1, 3, 224, 224]
+
+BASE_MODEL_URL = "https://github.com/onnx/models/raw/master/"
+MODEL_URL_COLLECTION = {
+    "BERT": "text/machine_comprehension/bert-squad/model/bertsquad-10.onnx",
+    "MobileNet-v2": "vision/classification/mobilenet/model/mobilenetv2-7.onnx",
+    "ResNet50-v1": "vision/classification/resnet/model/resnet50-v1-7.onnx",
+    "ResNet50-v2": "vision/classification/resnet/model/resnet50-v2-7.onnx",
+    "SqueezeNet-v1.1": "vision/classification/squeezenet/model/squeezenet1.1-7.onnx",
+    "SqueezeNet-v1.0": "vision/classification/squeezenet/model/squeezenet1.0-7.onnx",
+    "Inception-v1": "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx",
+    "Inception-v2": "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx",
+}
+
+
+def get_onnx_input_name(model):
+    inputs = [node.name for node in model.graph.input]
+    initializer = [node.name for node in model.graph.initializer]
+
+    inputs = list(set(inputs) - set(initializer))
+    return inputs
+
+
+def get_model_url(model_name):
+    return BASE_MODEL_URL + MODEL_URL_COLLECTION[model_name]
+
+
+def get_name_from_url(url):
+    return url[url.rfind("/") + 1 :].strip()
+
+
+def find_of_download(model_name):
+    model_url = get_model_url(model_name)
+    model_file_name = get_name_from_url(model_url)
+    return download_testdata(model_url, model_file_name, module="models")
+
+
+def get_model(model_name):
+    model_path = find_of_download(model_name)
+    onnx_model = onnx.load(model_path)
+    input_names = get_onnx_input_name(onnx_model)
+    input_dict = {}
+    for name in input_names:
+        input_dict[name] =
INPUT_SHAPE # TODO: hardcode + mod, params = relay.frontend.from_onnx(onnx_model, input_dict, freeze_params=True) + return mod, params, input_dict + + +def simplify_model(mod): + """ + Simplify execution graph + + At least merge BatchNorm into convolution. For this purpose decompose BN primitive + into simple operation which can be calculated as const expr and after that merged + into nearest conv/dense primitive. + """ + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.FoldConstant(), + transform.SimplifyInference(), + transform.FoldScaleAxis(), + ] + ) + return seq(mod) + + +def process(model_name): + temp = utils.tempdir() + model, params, input_dict = get_model(model_name) + + def run(mod, target, simplify=True, with_bnns=False): + with tvm.transform.PassContext(opt_level=3): + if simplify: + mod = simplify_model(mod) + if with_bnns: + mod = partition_for_bnns(mod) + graph_module = relay.build(mod, target=target, target_host=target, params=params) + + lib_name = "deploy.tar" + path_dso = temp.relpath(lib_name) + graph_module.export_library(path_dso) + + ctx = tvm.cpu(0) + loaded_lib = tvm.runtime.load_module(path_dso) + + module = graph_runtime.GraphModule(loaded_lib["default"](ctx)) + module.run() + return module.get_output(0).asnumpy() + + res_llvm = run(model, TARGET, simplify=True, with_bnns=False) + res_bnns = run(model, TARGET, simplify=True, with_bnns=True) + + tvm.testing.assert_allclose( + res_llvm, + res_bnns, + atol=0.002, + rtol=0.007, + ) + + +@pytest.mark.skip(reason="Manually disabled because of huge complexity") +@pytest.mark.skipif(bnns_is_absent, reason="BNNS runtime is absent") +@pytest.mark.parametrize("model_name", MODEL_URL_COLLECTION.keys()) +def test_topology(model_name): + process(model_name) diff --git a/tests/python/contrib/test_bnns/test_pooling.py b/tests/python/contrib/test_bnns/test_pooling.py new file mode 100644 index 000000000000..77a78d4bf7e1 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_pooling.py @@ -0,0 +1,289 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+"""BNNS integration pooling tests."""
+
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm import testing
+from .infrastructure import (
+    skip_runtime_test,
+    skip_codegen_test,
+    build_and_run,
+    verify,
+    verify_codegen,
+)
+from .infrastructure import Device
+
+
+def _calculate_output_shape(shape, sizes, padding, strides):
+    """Calculate pooling output shape."""
+    output_height = ((shape[2] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1
+    output_width = ((shape[3] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1
+    return 1, shape[1], int(output_height), int(output_width)
+
+
+def _get_pooling_model(
+    shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad, var_names
+):
+    """Return a pooling model for the requested pooling type."""
+    if len(padding) == 2:
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    out = relay.var(next(var_names), shape=shape, dtype=dtype)
+
+    if typef == "nn.max_pool2d":
+        out = relay.nn.max_pool2d(
+            out,
+            pool_size=sizes,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+        )
+    elif typef == "nn.avg_pool2d":
+        out = relay.nn.avg_pool2d(
+            out,
+            pool_size=sizes,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+        )
+    else:
+        raise ValueError("Unsupported pooling type: " + typef)
+
+    return out
+
+
+def _get_global_pooling_model(shape, dtype, typef, var_names):
+    """Return a global pooling model for the requested pooling type."""
+    out = relay.var(next(var_names), shape=shape, dtype=dtype)
+
+    if typef == "nn.global_max_pool2d":
+        out = relay.nn.global_max_pool2d(out)
+    elif typef == "nn.global_avg_pool2d":
+        out = relay.nn.global_avg_pool2d(out)
+    else:
+        raise ValueError("Unsupported pooling type: " + typef)
+
+    return out
+
+
+def _get_expected_pooling_codegen(
+    shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad
+):
+    if len(padding) == 2:
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    output_shape = _calculate_output_shape(shape, sizes, padding, strides)
+
+    node = {
+        "op": "kernel",
+        "name": typef,
+        "inputs": [[0, 0, 0]],
+        "attrs": {
+            "num_inputs": "1",
+            "num_outputs": "1",
+            "layout": [["NCHW"]],
+            "shape": [[list(output_shape)]],
+            "dtype": [[dtype]],
+            "padding": [[str(p) for p in padding]],
+            "strides": [[str(s) for s in strides]],
+            "pool_size": [[str(s) for s in sizes]],
+            "ceil_mode": [[str(1 if ceil_mode else 0)]],
+        },
+    }
+
+    if typef in ("nn.avg_pool2d", "nn.l2_pool2d"):
+        node["attrs"]["count_include_pad"] = [["1" if count_include_pad else "0"]]
+
+    input_node = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
+    return [input_node, node]
+
+
+def _get_expected_global_pooling_codegen(shape, dtype, typef):
+    node = {
+        "op": "kernel",
+        "name": typef,
+        "inputs": [[0, 0, 0]],
+        "attrs": {
+            "num_inputs": "1",
+            "num_outputs": "1",
+            "layout": [["NCHW"]],
+            "shape": [[[1, shape[1], 1, 1]]],
+            "dtype": [[dtype]],
+        },
+    }
+
+    input_node = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
+    return [input_node, node]
+
+
+@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS runtime is not available")
+def test_pooling():
+    device = Device()
+    np.random.seed(0)
+
+    dtype = "float32"
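+    # Trial layout: op name, pool_size, strides, padding, ceil_mode,
+    # count_include_pad, input (C, H, W) shape (batch dim 1 is prepended below).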
+    trials = [
+        ["nn.max_pool2d", (3, 3), (2, 2), (0, 0), False, False, (27, 27, 512)],
+        ["nn.max_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.max_pool2d", (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.max_pool2d", (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+    ]
+
+    for (
+        typef,
+        size,
+        stride,
+        pad,
+        ceil_mode,
+        count_include_pad,
+        input_shape,
+    ) in trials:
+        shape = (1, *input_shape)
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-127, 128, shape).astype(dtype)),
+        }
+
+        func = _get_pooling_model(
+            shape, dtype, typef, size, stride, pad, ceil_mode, count_include_pad, iter(inputs)
+        )
+
+        config = {
+            "size": size,
+            "stride": stride,
+            "shape": shape,
+            "pooling type": typef,
+            "dtype": dtype,
+            "padding": pad,
+            "ceil_mode": ceil_mode,
+            "count_include_pad": count_include_pad,
+            "inputs": inputs,
+        }
+
+        params = None
+        for enable_bnns in [False, True]:
+            outputs.append(
+                build_and_run(
+                    func, inputs, 1, params, device, enable_bnns=enable_bnns, config=config
+                )[0]
+            )
+
+        verify(outputs, atol=0.001, rtol=0.001, config=config)
+
+
+@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS runtime is not available")
+def test_global_pooling():
+    device = Device()
+    np.random.seed(0)
+
+    dtype = "float32"
+
+    trials = [
+        ["nn.global_max_pool2d", (8, 8, 16)],
+        ["nn.global_max_pool2d", (9, 9, 16)],
+        ["nn.global_max_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (9, 9, 16)],
+    ]
+
+    for typef, input_shape in trials:
+        shape = (1, *input_shape)
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-127, 128, shape).astype(dtype)),
+        }
+
+        func = _get_global_pooling_model(shape, dtype, typef, iter(inputs))
+        config = {
+            "shape": shape,
+            "pooling type": typef,
+            "dtype": dtype,
+        }
+
+        for enable_bnns in [False, True]:
+            outputs.append(
+                build_and_run(
+                    func, inputs, 1, None, device, enable_bnns=enable_bnns, config=config
+                )[0]
+            )
+
+        verify(outputs, atol=0.001, rtol=0.001, config=config)
+
+
+@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available")
+def test_codegen_pooling():
+    dtype = "float32"
+
+    trials = [
+        ["nn.max_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.max_pool2d", (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+        ["nn.max_pool2d", (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+        ["nn.avg_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+        ["nn.avg_pool2d", (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+    ]
+
+    for (
+        typef,
+        size,
+        stride,
+        pad,
+        ceil_mode,
+        count_include_pad,
+        input_shape,
+    ) in trials:
+        shape = (1, *input_shape)
+        inputs = {"a"}
+        args = (shape, dtype, typef, size, stride, pad, False, False)  # ceil_mode/count_include_pad fixed to False here
+        func = _get_pooling_model(*args, iter(inputs))
+        exp_codegen = _get_expected_pooling_codegen(*args)
+        verify_codegen(func, exp_codegen, 1)
+
+
+@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available")
+def test_codegen_global_pooling():
+    dtype = "float32"
+
+    trials = [
+        ["nn.global_max_pool2d", (8, 8, 16)],
+        ["nn.global_max_pool2d", (9, 9, 16)],
+        ["nn.global_max_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (8, 8, 16)],
+        ["nn.global_avg_pool2d", (9, 9, 16)],
+    ]
+
+    for typef, input_shape in trials:
+        shape = (1, *input_shape)
+        inputs = {"a"}
+        args = (shape, dtype, typef)
+        func = _get_global_pooling_model(*args, iter(inputs))
+        exp_codegen = _get_expected_global_pooling_codegen(*args)
+        verify_codegen(func, exp_codegen, 1)
+
+
+if __name__ == "__main__":
+    test_pooling()
+    test_global_pooling()
+    test_codegen_pooling()
+    test_codegen_global_pooling()