From 757c82cd6c535162a62562d65e819ccc92dc3932 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Sat, 16 Jan 2021 17:38:53 +0300 Subject: [PATCH 01/27] Introduce Apple BNNS backend This is simple JSON based runtime which offload execution of some operation into Accelerate frameworks via BNNS api. Works only for: * macOS 11.0 and later * iOS 14.0 and later Supported primitives: * conv2d and fusing with bias and relu * dense and fusing with bias and relu/gelu * batch_matmul Signed-off-by: Alexander Peskov --- CMakeLists.txt | 2 + cmake/config.cmake | 3 + cmake/modules/contrib/BNNS.cmake | 30 + python/tvm/relay/op/contrib/__init__.py | 1 + python/tvm/relay/op/contrib/bnns.py | 247 ++++++ src/relay/backend/contrib/bnns/codegen.cc | 212 +++++ src/runtime/contrib/bnns/bnns_json_runtime.cc | 811 ++++++++++++++++++ src/runtime/contrib/json/json_runtime.h | 6 +- tests/python/contrib/test_bnns/__init__.py | 17 + .../contrib/test_bnns/infrastructure.py | 309 +++++++ tests/python/contrib/test_bnns/test_conv2d.py | 150 ++++ .../contrib/test_bnns/test_conv2d_patterns.py | 124 +++ tests/python/contrib/test_bnns/test_dense.py | 188 ++++ tests/python/contrib/test_bnns/test_matmul.py | 115 +++ 14 files changed, 2212 insertions(+), 3 deletions(-) create mode 100644 cmake/modules/contrib/BNNS.cmake create mode 100644 python/tvm/relay/op/contrib/bnns.py create mode 100644 src/relay/backend/contrib/bnns/codegen.cc create mode 100644 src/runtime/contrib/bnns/bnns_json_runtime.cc create mode 100644 tests/python/contrib/test_bnns/__init__.py create mode 100644 tests/python/contrib/test_bnns/infrastructure.py create mode 100644 tests/python/contrib/test_bnns/test_conv2d.py create mode 100644 tests/python/contrib/test_bnns/test_conv2d_patterns.py create mode 100644 tests/python/contrib/test_bnns/test_dense.py create mode 100644 tests/python/contrib/test_bnns/test_matmul.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 769a35318d9d..12c8e88044be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_COREML "Build with coreml support" OFF) +tvm_option(USE_BNNS "Build with BNNS support" OFF) tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) @@ -348,6 +349,7 @@ include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) +include(cmake/modules/contrib/BNNS.cmake) include(cmake/modules/contrib/ONNX.cmake) include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 872feb918a4f..5faeb9325dba 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -269,3 +269,6 @@ set(USE_HEXAGON_SDK /path/to/sdk) # Whether to use ONNX codegen set(USE_TARGET_ONNX OFF) + +# Whether enable BNNS runtime +set(USE_BNNS OFF) diff --git a/cmake/modules/contrib/BNNS.cmake b/cmake/modules/contrib/BNNS.cmake new file mode 100644 index 000000000000..e14aa2857ebc --- /dev/null +++ b/cmake/modules/contrib/BNNS.cmake @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_BNNS STREQUAL "ON") + add_definitions(-DUSE_JSON_RUNTIME=1) + file(GLOB BNNS_RELAY_CONTRIB_SRC src/relay/backend/contrib/bnns/*.cc) + list(APPEND COMPILER_SRCS ${BNNS_RELAY_CONTRIB_SRC}) + list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC}) + + list(APPEND TVM_RUNTIME_LINKER_LIBS "-framework Accelerate") + + file(GLOB BNNS_CONTRIB_SRC src/runtime/contrib/bnns/*.cc) + list(APPEND RUNTIME_SRCS ${BNNS_CONTRIB_SRC}) + message(STATUS "Build with BNNS JSON runtime: " ${EXTERN_LIBRARY_BNNS}) +endif() + diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 49abf36134b4..30c2db0ddf0b 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -20,6 +20,7 @@ from .arm_compute_lib import * from .dnnl import * +from .bnns import * from .coreml import * from .ethosn import * from .tensorrt import * diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py new file mode 100644 index 000000000000..17aad38d3063 --- /dev/null +++ b/python/tvm/relay/op/contrib/bnns.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""BNNS library supported operators. +Is a part of Accelerate framework on macOS/iOS platforms. Apple provide several APIs +to handle tensor processing. Particularly: + * BNNS (basic neural ) + * vDSP (1D and 2D tensor processing) + * BLAS (gemm provide) + +# There are two ways to registering a function for an op to indicate if it is +# supported by DNNL. + +# - The first and simplest way is to use the helper so that +# users only need to provide the operator name and a boolean value to indicate if +# it is supported. For example: +# +# .. code-block:: python +# +# add = _register_external_op_helper("add") +# add = _register_external_op_helper("add", True) +# add = _register_external_op_helper("add", False) +# +# - The other way is to implement the function by themselves to +# check the attributes of the op and decide if it should be offloaded to DNNL. +""" +import math +import tvm.ir +from ...dataflow_pattern import wildcard, is_op, is_expr, is_constant +from .register import register_pattern_table, get_pattern_table + +from tvm.relay import transform +from tvm.relay.expr import const +from tvm.relay.build_module import bind_params_by_name + +def partition_for_bnns(mod, params=None): + """Partition the graph greedily offloading supported + operators to BNNS. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + ret : annotated and partitioned module. + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.FoldConstant(), + transform.FoldScaleAxis(), + transform.DynamicToStatic(), + transform.AlterOpLayout(), + # TODO(apeskov): WA. AlterOpLayout call lead to constants shape transformation + # Some expand_dims op may appears after constants. It breaks BNNS fusing. + # So we have to call FoldConstant right before bnns composite passes. + transform.FoldConstant(), + transform.MergeComposite(get_pattern_table("bnns")), + transform.AnnotateTarget("bnns"), + # If you no need in per layer performance statistic you can + # uncomment next line + # transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + + return seq(mod) + + +def _register_external_op_helper(op_name, supported=True): + """The helper function to indicate that a given operator can be supported + by BNNS. + + Paramters + --------- + op_name : Str + The name of operator that will be registered. + + Returns + ------- + f : callable + A function that returns if the operator is supported by BNNS. + """ + + @tvm.ir.register_op_attr(op_name, "target.bnns") + def _func_wrapper(expr): + return supported + + return _func_wrapper + +_register_external_op_helper("nn.batch_matmul") + + +# TODO [apeskov]: +# 1. enlarge list of supported types on +# 2. clarify meaning of "" value +def dtype_is_supported(dtype): + return dtype == "float32" or dtype == "" + + +@tvm.ir.register_op_attr("nn.conv2d", "target.bnns") +def conv2d_check(expr): + """Check if the conv2d can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + if len(data_typ.shape) != 4 or data_typ.dtype != "float32": + return False + if not isinstance(args[1], tvm.relay.expr.Constant): + return False + kernel_typ = args[1].checked_type + if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32": + return False + if attrs.data_layout != "NCHW": + return False + if not dtype_is_supported(attrs.out_dtype): + return False + return True + + +def bias_check(expr): + """Check is bias added through the correct dimension""" + attrs, args = expr.attrs, expr.args + if not isinstance(args[1], tvm.relay.expr.Constant): + return False + if expr.op.name == "nn.bias_add": + return attrs.axis == 1 + elif expr.op.name == "add": + b_shape = args[1].checked_type.shape + if len(b_shape) == 4: + return bool(b_shape[0] == 1 and b_shape[2] == 1 and b_shape[3] == 1) + elif len(b_shape) == 3: + return bool(b_shape[1] == 1 and b_shape[2] == 1) + + return False + + +@tvm.ir.register_op_attr("nn.dense", "target.bnns") +def dense(expr): + """Check if the dense can be used in BNNS.""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + if data_typ.dtype != "float32": + return False + kernel_typ = args[1].checked_type + if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32": + return False + if attrs.out_dtype != "float32" and attrs.out_dtype != "": + return False + return True + + +def make_conv_relu_pattern(with_bias=True, with_relu=True): + data = wildcard() + weight = wildcard() + bias = wildcard() + pat = is_op("nn.conv2d")(data, weight) + if with_bias: + pat = is_op("add")(pat, bias) | is_op("nn.bias_add")(pat, bias) + if with_relu: + pat = is_op("nn.relu")(pat) + return pat + + +def check_conv(extract): + """Check conv pattern is supported by BNNS.""" + is_ok = True + + def visit(op): + nonlocal is_ok + if isinstance(op, tvm.relay.Call): + if op.op.name == "nn.conv2d": + is_ok &= conv2d_check(op) + elif op.op.name in ("nn.bias_add", "add"): + is_ok &= bias_check(op) + + tvm.relay.analysis.post_order_visit(extract, visit) + return is_ok + + +def make_dense_bias_pattern(): + data = wildcard() + weight = wildcard() + bias = wildcard() + d = is_op("nn.dense")(data, weight) + return is_op("add")(d, bias) + + +def make_dense_bias_gelu_pattern(): + dense_bias = make_dense_bias_pattern() + const1 = is_expr(const(0.044715)) + const2 = is_expr(const(math.sqrt(2 / math.pi))) + + gelu = is_op("power")(dense_bias, is_expr(const(3, dtype="float32"))) + gelu = is_op("multiply")(gelu, const1) + gelu = is_op("add")(gelu, dense_bias) + gelu = is_op("multiply")(gelu, const2) + gelu = is_op("tanh")(gelu) + gelu = is_op("add")(gelu, is_expr(const(1, dtype="float32"))) + gelu = is_op("multiply")(gelu, is_expr(const(0.5))) + gelu = is_op("multiply")(gelu, dense_bias) + return gelu + + +def check_dense(extract): + """Check conv pattern is supported by ACL.""" + call = extract + while call.op.name != "nn.dense": + call = call.args[0] + return dense(call) + + +@register_pattern_table("bnns") +def pattern_table(): + conv2d_bias_pat = ("bnns.conv2d_bias", make_conv_relu_pattern(with_bias=True, with_relu=False), check_conv) + conv2d_bias_relu_pat = ("bnns.conv2d_bias_relu", make_conv_relu_pattern(with_bias=True, with_relu=True), check_conv) + conv2d_relu_pat = ("bnns.conv2d_relu", make_conv_relu_pattern(with_bias=False, with_relu=True), check_conv) + dense_bias_gelu = ("bnns.dense_bias_gelu", make_dense_bias_gelu_pattern(), check_dense) + dense_bias = ("bnns.dense_bias", make_dense_bias_pattern(), check_dense) + bnns_patterns = [ + conv2d_bias_relu_pat, + conv2d_relu_pat, + conv2d_bias_pat, + dense_bias_gelu, + dense_bias, + ] + return bnns_patterns diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc new file mode 100644 index 000000000000..504425a225c5 --- /dev/null +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file + * \brief Implementation of BNNS codegen APIs. + */ + +#include +#include +#include +#include + +#include +#include + +#include "../../utils.h" + +#include "../../../../runtime/contrib/json/json_node.h" +#include "../codegen_json/codegen_json.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using namespace backend; + +/*! + * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in relu(add(conv2d)) + * \param call A Relay call node. Typically nn.relu when called the first time. + * \param max_depth The maximum number of calls before the root op, counting from current_call. + * \param root_name The name of expected "root" op in this fused call. + * \return A CallNode corresponding to the root op + */ +inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth, + const std::string& root_name) { + ICHECK(current_call && max_depth >= 0); + + if (max_depth == 0) { + ICHECK(current_call && IsOp(current_call, root_name)); + return current_call; + } + if (IsOp(current_call, root_name)) { + return current_call; + } + + ICHECK_GT(current_call->args.size(), 0); + + const auto* next_call = current_call->args[0].as(); + return FindCallWithName(next_call, max_depth - 1, root_name); +} + +class BNNSJSONSerializer : public backend::contrib::JSONSerializer { + using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; + + public: + BNNSJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {} + + std::vector VisitExpr_(const CallNode* cn) override { + Expr expr = GetRef(cn); + std::string name; + const CallNode* call = cn; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else if (const auto* fn = cn->op.as()) { + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()) << "BNNS JSON runtime only supports composite functions."; + name = comp.value(); + + auto body = fn->body.as(); + if (name == "bnns.conv2d_bias_relu") { + auto add_op_type = IsOp(body->args[0].as(), "add") ? "add" : "nn.bias_add"; + call = GetRootCall(body, 2, {"nn.conv2d", add_op_type, "nn.relu"}); + } else if (name == "bnns.conv2d_bias") { + auto add_op_type = IsOp(body, "add") ? "add" : "nn.bias_add"; + call = GetRootCall(body, 1, {"nn.conv2d", add_op_type}); + } else if (name == "bnns.conv2d_relu") { + call = GetRootCall(body, 1, {"nn.conv2d", "nn.relu"}); + ICHECK(call->op.as()) << "Not op node"; + } else if (name == "bnns.dense_bias") { + call = GetRootCall(fn->body.as(), 1, {"nn.dense", "add"}); + } else if (name == "bnns.dense_bias_gelu") { + call = FindCallWithName(fn->body.as(), 10, "nn.dense"); + } else { + LOG(FATAL) << "Unrecognized BNNS pattern: " << name; + } + } else { + LOG(FATAL) << "BNNS JSON runtime does not support calls to " << cn->op->GetTypeKey(); + } + + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, call); + return AddNode(node, GetRef(cn)); + } +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + */ +runtime::Module BNNSCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()); + auto func = Downcast(ref); + auto func_name = GetExtSymbol(func); + BNNSJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto params = serializer.GetParams(); + + const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; + auto mod = (*pf)(func_name, graph_json, params); + return mod; +} + +TVM_REGISTER_GLOBAL("relay.ext.bnns").set_body_typed(BNNSCompiler); + + +/** + * \brief A helper to expand the params by adding ones which used by BNNS runtime + * for a given expression. Same as default ConstantUpdater but skip constant from + * essential BNNS composed function ops. + */ +struct BNNSConstantUpdater : public ConstantUpdater { + public: + BNNSConstantUpdater(const std::string& symbol, + std::unordered_map* params, + std::vector &skip_mask) + : ConstantUpdater(symbol, params), skip_mask_(skip_mask) {} + using ConstantUpdater::VisitExpr_; + + /**! + * Like an original implementation but avoid visiting of body nodes + * for BNNS specific composite primitives. + */ + void VisitExpr_(const FunctionNode* op) final { + this->VisitSpan(op->span); + for (auto param : op->params) { + this->VisitExpr(param); + } + + if (!isBNNSSpecificCompositeFunc(op)) { + this->VisitExpr(op->body); + } + } + + private: + bool isBNNSSpecificCompositeFunc(const FunctionNode* op) { + auto comp = op->GetAttr(attr::kComposite); + if (!comp) + return false; + + auto comp_name = comp.value(); + + bool is_match = false; + for (const auto &mask : skip_mask_) { + if (std::string(comp_name).substr(0, mask.size()) == mask) { + is_match = true; + break; + } + } + return is_match; + } + + std::vector skip_mask_; +}; + +Map BNNSConstantUpdaterFunc(Expr expr, std::string symbol) { + std::vector bnns_composite_filter = {"bnns."}; + + // Visit all suitable constant nodes + std::unordered_map res; + BNNSConstantUpdater const_updater(symbol, &res, bnns_composite_filter); + const_updater(expr); + + // Convert to tvm::Map + Map ret; + for (const auto& kvp: res) + ret.Set(kvp.first, kvp.second); + return ret; +} + +TVM_REGISTER_GLOBAL("relay.ext.bnns.constant_updater") + .set_body_typed(BNNSConstantUpdaterFunc); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc new file mode 100644 index 000000000000..ffe1b037d347 --- /dev/null +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -0,0 +1,811 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * \file + * \brief Simple JSON runtime for Apple BNNS primitives + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#include "Accelerate/Accelerate.h" + +#define BNNS_TMP_CONCURRENCY 2 +#define BNNS_MAX_CONCURRENCY 8 + +template +bool one_of(T1 arg1, T2 arg2) { + return arg1 == arg2; +} + +template +bool one_of(T1 arg1, T2 arg2, T... args) { + return arg1 == arg2 || one_of(arg1, args...); +} + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime; +using namespace tvm::runtime::json; + +/** C++ wrapper on top of original BNNS C api */ +namespace BNNS { + using Dim = size_t; + using Shape = std::vector; + using Dtype = BNNSDataType; + + void* default_alloc(size_t size) { + // TODO: Clarify, should it have some alignment for better performance + // with SIMD execution.. may be TVMBackendAllocWorkspace is more + // preferable here. + // Note: Apple uses posix_memalign by default. + return malloc(size); + } + + void default_free(void* ptr) { + free(ptr); + } + + class Tensor { + public: + Tensor(Shape shape, Dtype dtype, void* hdl) + : real_shape(shape) { + ICHECK(shape.size() < BNNS_MAX_TENSOR_DIMENSION); + + if (hdl) { + data_handler = hdl; + is_external_data = true; + } else { + const size_t buff_size = getNumOfElements(shape) * getElementSize(dtype); + data_handler = default_alloc(buff_size); + is_external_data = false; + } + + bnns_nd_desc = { + BNNSNDArrayFlags(0), + getPlainLayout(shape), + {}, // shape + {}, // strides, empty value means use default dense strides + hdl, // data handler + dtype, // data type + nullptr, // table_data (clustering case), is not used + dtype, + 1.f, + 0.f + }; + std::copy(shape.rbegin(), shape.rend(), std::begin(bnns_nd_desc.size)); + } + + ~Tensor() { + if (data_handler && !is_external_data) { + default_free(data_handler); + data_handler = nullptr; + } + } + + void* get_data_hdl() { return data_handler; } + + const void* get_data_hdl() const { return data_handler; }; + + void set_data_hdl(void *hdl) { + if (data_handler && !is_external_data) { + default_free(data_handler); + data_handler = nullptr; + } + + data_handler = hdl; + is_external_data = true; + } + + size_t get_mb() const { + return real_shape[0]; + } + + size_t get_mb_stride() const { + return std::accumulate(real_shape.begin() + 1, real_shape.end(), + 1, std::multiplies()); + } + + const BNNSNDArrayDescriptor get_nd_desc(size_t nd = 0) const { + auto original_nd = real_shape.size(); + // Ask of original descriptor + if (original_nd == nd || nd == 0) + return bnns_nd_desc; + + // As of desc with excluded batch + if (original_nd == nd + 1) { + auto res = bnns_nd_desc; + res.size[original_nd - 1] = 0; + res.layout = BNNSDataLayout3DLastMajor; // TODO [apeskov] : hardcoded value. FIXME + return res; + } + LOG(FATAL) << "Unknown case of BNNS tensor interpretation"; + return bnns_nd_desc; + }; + + private: + static BNNSDataLayout getPlainLayout(const Shape &shape) { + return getPlainLayout(shape.size()); + } + + static BNNSDataLayout getPlainLayout(size_t rank) { + switch (rank) { + case 1: return BNNSDataLayout1DFirstMajor; + case 2: return BNNSDataLayout2DFirstMajor; + case 3: return BNNSDataLayout3DFirstMajor; + case 4: return BNNSDataLayout4DFirstMajor; + case 5: return BNNSDataLayout5DFirstMajor; + case 6: return BNNSDataLayout6DFirstMajor; + case 7: return BNNSDataLayout7DFirstMajor; + case 8: return BNNSDataLayout8DFirstMajor; + default: + LOG(FATAL) << "Unsupported tensor rank : " << rank + << " Supported cases is only 1-8 "; + return static_cast(0); + } + } + + /** + * return size in byte of element of provided type + * @param dtype + * @return size of element in bytes + */ + static size_t getElementSize(Dtype dtype) { + return (dtype & 0xFFFF) / sizeof(uint8_t); + } + + static size_t getNumOfElements(const Shape &shape) { + return std::accumulate(shape.begin(), shape.end(), + 1, std::multiplies()); + } + + private: + Shape real_shape = {}; + void* data_handler = nullptr; + bool is_external_data = false; + + BNNSNDArrayDescriptor bnns_nd_desc; + }; + + class Primitive { + public: + Primitive(BNNSFilter f) : num_filters(1), filters{f} {} + + Primitive(BNNSFilter fs[BNNS_MAX_CONCURRENCY]) { + std::copy(fs, fs + BNNS_MAX_CONCURRENCY, filters); + for (int i = 0; i < BNNS_MAX_CONCURRENCY; i++) { + if (filters[i] == nullptr) { + num_filters = i; + break; + } + } + } + + ~Primitive() { + for (size_t i = 0; i < num_filters; i++) { + auto &filter = filters[i]; + if (filter) { + BNNSFilterDestroy(filter); + filter = nullptr; + } + } + } + + void execute(std::vector srcs, Tensor &dst, int forceBatchSize = -1) { + ICHECK_LE(srcs.size(), 2) << "Currently BNNS runtime supports primitives with only 1 or 2 " + "data inputs."; + + run_ctx ctx { this, srcs[0], nullptr, &dst, forceBatchSize }; + if (srcs.size() > 1) + ctx.src2 = srcs[1]; + + auto res = TVMBackendParallelLaunch(run_task, &ctx, num_filters); + ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + } + + void set_input_stride(size_t stride1, size_t stride2 = 0) { + in1_hdl_stride = stride1; + in2_hdl_stride = stride2; + } + void set_output_stride(size_t stride) { out_hdl_stride = stride; } + + private: + struct run_ctx { + Primitive *prim; + const Tensor *src1; + const Tensor *src2; + Tensor *dst; + const int force_batch_size; + }; + + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto ctx = reinterpret_cast(cdata); + const auto *prim = ctx->prim; + + const auto &filter = prim->filters[task_id]; + + auto src1_hdl = ctx->src1->get_data_hdl(); + auto dst_hdl = ctx->dst->get_data_hdl(); + + auto src1_mb = ctx->src1->get_mb(); + auto dst_mb = ctx->dst->get_mb(); + + auto src1_mb_stride = ctx->src1->get_mb_stride(); + auto dst_mb_stride = ctx->dst->get_mb_stride(); + + src1_hdl = static_cast(src1_hdl) + task_id*prim->in1_hdl_stride; + dst_hdl = static_cast(dst_hdl) + task_id*prim->out_hdl_stride; + + ICHECK(src1_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; + + const void* src2_hdl = nullptr; + size_t src2_mb = 0; + size_t src2_mb_stride = 0; + + if (ctx->src2) { + src2_hdl = ctx->src2->get_data_hdl(); + src2_mb = ctx->src2->get_mb(); + src2_mb_stride = ctx->src2->get_mb_stride(); + src2_hdl = static_cast(src2_hdl) + task_id*prim->in2_hdl_stride; + ICHECK(src2_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; + } + + const auto mb = (ctx->force_batch_size == -1) ? dst_mb : ctx->force_batch_size; + + // NB! Limitations + // * Do not use simple BNNSFilterApply. There is a bug inside BNNS, + // and BNNSFilterApply doesn't work for grouped convolution. + // * Group convolution doesn't support arbitrary stride for Batch dim. + // The tensor should be dense. + auto sts = (ctx->src2) + ? BNNSFilterApplyTwoInputBatch(filter, mb, + src1_hdl, src1_mb_stride, + src2_hdl, src2_mb_stride, + dst_hdl, dst_mb_stride) + : BNNSFilterApplyBatch(filter, mb, + src1_hdl, src1_mb_stride, + dst_hdl, dst_mb_stride); + + return sts; + } + + private: + size_t num_filters = 0; + BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; + + // TODO: temporal solution with strides + size_t in1_hdl_stride = 0; + size_t in2_hdl_stride = 0; + size_t out_hdl_stride = 0; + }; +} + +struct BNNSConfig { + /** + * Internal parallelism level ov BNNS primitive specified via parameter. + * Has no real control from TVM level, so in fact it may be ignored by + * implementation. + */ + int internalConcurrency = 0; + + /** + * TVM level parallelism for BNNS primitive. In case if BNNS doesn't support + * internal parallelism we can add it by splitting primitive into independent + * parts and run it in parallel. May provide additional performance. + */ + int externalConcurrency = 0; +}; + +class BNNSJSONRuntime : public JSONRuntimeBase { + + public: + BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + const char* type_key() const override { return "bnns_json"; } + + void Init(const Array& consts) override { + SetupConstants(consts); + BuildEngine(); + + ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + } + + void Run() override { + // Wrap external handler into BNNS tensor representation + auto bind_ext_hdl_to_tensor = [this] (uint32_t eid) { + const auto &ext_dlt = *data_entry_[eid]; + auto &bnns_tensor = *entry_out_mem_[eid]; + bnns_tensor.set_data_hdl(ext_dlt.data); + }; + + // Bind all input/output external data object into internal abstractions + for (const auto &eid : input_var_eid_) { + bind_ext_hdl_to_tensor(eid); + } + for (const auto &out_entity : outputs_) { + bind_ext_hdl_to_tensor(EntryID(out_entity)); + } + + // Invoke primitives in topological order + for (int i = 0; i < primitives_.size(); ++i) { + auto res = entry_out_mem_.at(prim_results_[i]); + std::vector args; + for (auto arg_id : prim_args_[i]) + args.push_back(entry_out_mem_.at(arg_id).get()); + + int forceBatchSize = + (force_batch_size_.find(i) == force_batch_size_.end()) ? -1 : force_batch_size_.at(i); + primitives_.at(i)->execute(args, *res, forceBatchSize); + } + } + + private: + // Build up the engine based on the input graph. + void BuildEngine() { + // Build subgraph engine. + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() == "kernel") { + ICHECK_EQ(node.GetOpType(), "kernel"); + auto op_name = node.GetOpName(); + if ("nn.conv2d" == op_name) { + Conv2d(nid); + } else if ("bnns.conv2d_relu" == op_name) { + Conv2d(nid, true, false); + } else if ("bnns.conv2d_bias_relu" == op_name) { + Conv2d(nid, true, true); + } else if ("bnns.conv2d_bias" == op_name) { + Conv2d(nid, false, true); + } else if ("nn.dense" == op_name) { + Dense(nid); + } else if ("bnns.dense_bias" == op_name) { + Dense(nid, true); + } else if ("bnns.dense_bias_gelu" == op_name) { + Dense(nid, true, true); + } else if ("nn.batch_matmul" == op_name) { + MatMul(nid); + } else { + LOG(FATAL) << "Unsupported op: " << op_name; + } + } + } + } + + // Bind a JSON graph node entry to a BNNS tensor. + std::shared_ptr BindBNNSTensor(const JSONGraphNodeEntry& entry, void *hdl = nullptr) { + auto eid = EntryID(entry); + if (entry_out_mem_.count(eid) == 0) { + auto data_node = nodes_[entry.id_]; + auto dlshape = data_node.GetOpShape()[entry.index_]; + auto dltype = data_node.GetOpDataType()[entry.index_]; + + entry_out_mem_[eid] = std::make_shared( + BNNS::Shape{dlshape.begin(), dlshape.end()}, + convertToBNNS(dltype), hdl); + } + return entry_out_mem_[eid]; + } + + /** + * Function which split primitive into sub primitives to parallel execution + * + * @param orig_conv_param descriptor of original convolution + * @param batch batch value + * @param num number of part to split into. + * @return collection of Convolution descriptors plus strides for input and output tensors + */ + static std::tuple, size_t, size_t> + split_into_n(const BNNSLayerParametersConvolution& orig_conv_param, size_t batch, size_t num) { + size_t i_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t o_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t w_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t b_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t w_stride = 0; + size_t b_stride = 0; + size_t i_stride = 0; + size_t o_stride = 0; + + // TODO: In case of batch we can split through bach dimension. + // Meanwhile we just disable it... + if (batch > 1) { + return {{orig_conv_param}, 0, 0}; + } + + auto groups = orig_conv_param.groups; + // if groups > 1 split only by groups + // otherwise split inside one convolution by output channels + if (groups > 1) { + // fallback into sequential execution + if (groups % num != 0) + return {{orig_conv_param}, 0, 0}; + + std::copy(orig_conv_param.i_desc.size, orig_conv_param.i_desc.size + 3, i_shape); + std::copy(orig_conv_param.o_desc.size, orig_conv_param.o_desc.size + 3, o_shape); + std::copy(orig_conv_param.w_desc.size, orig_conv_param.w_desc.size + 4, w_shape); + std::copy(orig_conv_param.bias.size, orig_conv_param.bias.size + 1, b_shape); + + auto orig_w_buff_size = std::accumulate(w_shape, w_shape + 4, 1, std::multiplies()) + * sizeof(float); + + auto orig_b_buff_size = std::accumulate(b_shape, b_shape + 1, 1, std::multiplies()) + * sizeof(float); + + auto orig_i_buff_size = std::accumulate(i_shape, i_shape + 3, 1, std::multiplies()) + * sizeof(float); + + auto orig_o_buff_size = std::accumulate(o_shape, o_shape + 3, 1, std::multiplies()) + * sizeof(float); + + i_shape[2] /= num; + o_shape[2] /= num; + w_shape[3] /= num; + b_shape[0] /= num; + + w_stride = orig_w_buff_size / num; + b_stride = orig_b_buff_size / num; + i_stride = orig_i_buff_size / num; + o_stride = orig_o_buff_size / num; + groups = groups / num; + } else { + std::copy(orig_conv_param.i_desc.size, orig_conv_param.i_desc.size + 3, i_shape); + std::copy(orig_conv_param.o_desc.size, orig_conv_param.o_desc.size + 3, o_shape); + std::copy(orig_conv_param.w_desc.size, orig_conv_param.w_desc.size + 4, w_shape); + std::copy(orig_conv_param.bias.size, orig_conv_param.bias.size + 1, b_shape); + + auto orig_w_buff_size = std::accumulate(w_shape, w_shape + 4, 1, std::multiplies()) + * sizeof(float); + + auto orig_b_buff_size = std::accumulate(b_shape, b_shape + 1, 1, std::multiplies()) + * sizeof(float); + +// auto orig_i_buff_size = std::accumulate(i_shape, i_shape + 3, 1, std::multiplies()) +// * sizeof(float); + + auto orig_o_buff_size = std::accumulate(o_shape, o_shape + 3, 1, std::multiplies()) + * sizeof(float); + + o_shape[2] /= num; + w_shape[3] /= num; + b_shape[0] /= num; + + w_stride = orig_w_buff_size / num; + b_stride = orig_b_buff_size / num; + i_stride = 0; + o_stride = orig_o_buff_size / num; + } + + std::vector res(num); + for (size_t i=0; i < num; i++) { + auto &cur = res[i]; + cur = orig_conv_param; + + std::copy(i_shape, i_shape + 3, cur.i_desc.size); + std::copy(o_shape, o_shape + 3, cur.o_desc.size); + std::copy(w_shape, w_shape + 4, cur.w_desc.size); + std::copy(b_shape, b_shape + 1, cur.bias.size); + + cur.w_desc.data = static_cast(cur.w_desc.data) + w_stride * i; + if (cur.bias.data) + cur.bias.data = static_cast(cur.bias.data) + b_stride * i; + + cur.groups = groups; + } + return {res, i_stride, o_stride}; + } + + + void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) { + auto node = nodes_[nid]; + + // Setup attributes. + auto src_entry = node.GetInputs()[0]; + auto weight_entry = node.GetInputs()[1]; + auto dst_entry = JSONGraphNodeEntry(nid, 0); + + auto dl_input_shape = nodes_[src_entry.id_].GetOpShape()[src_entry.index_]; + auto dl_weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + BNNS::Shape input_shape {dl_input_shape.begin(), dl_input_shape.end()}; + BNNS::Shape weight_shape {dl_weight_shape.begin(), dl_weight_shape.end()}; + std::vector str_strides = node.GetAttr>("strides"); + std::vector str_dilation = node.GetAttr>("dilation"); + std::vector str_padding = node.GetAttr>("padding"); + BNNS::Dim groups = std::stoi(node.GetAttr>("groups")[0]); + + BNNS::Dim N = input_shape[0], // batch size + IC = input_shape[1], // input channels + IH = input_shape[2], // input height + IW = input_shape[2], // input width + OC = weight_shape[0], // output channels + KH = weight_shape[2], // weight height + KW = weight_shape[3], // weight width + PH_L = std::stoi(str_padding[0]), // height padding: left + PH_R = std::stoi(str_padding[2]), // height padding: right + PW_L = std::stoi(str_padding[1]), // width padding: left + PW_R = std::stoi(str_padding[3]), // width padding: right + SH = std::stoi(str_strides[0]), // height-wise stride + SW = std::stoi(str_strides[1]), // weight-wise stride + DH = std::stoi(str_dilation[0]), // height kernel dilation + DW = std::stoi(str_dilation[1]), // width kernel dilation + OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height + OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width + + // Memory shapes. + BNNS::Shape src_dims = {N, IC, IH, IW}; + BNNS::Shape weights_dims = {OC, IC, KH, KW}; + BNNS::Shape bias_dims = {OC}; + BNNS::Shape dst_dims = {N, OC, OH, OW}; + BNNS::Shape strides_dims = {SH, SW}; + BNNS::Shape padding_dims_l = {PH_L, PW_L}; + BNNS::Shape padding_dims_r = {PH_R, PW_R}; + + auto weight_data_entry = data_entry_[EntryID(weight_entry)]; + ICHECK(weight_data_entry) << "Convolution weights tensor should be constant and " + "available on initialization stage. Looks like weights " + "are not result of constant expression."; + + auto weight_ext_data_hdl = weight_data_entry->data; + + // Memory descriptions. + auto src_md = BindBNNSTensor(src_entry); + auto weights_md = BindBNNSTensor(weight_entry, weight_ext_data_hdl); + std::shared_ptr bias_md; + auto dst_md = BindBNNSTensor(dst_entry); + // TODO [apeskov]: check correctness of tensor shapes + + if (has_bias) { + auto bias_entry = node.GetInputs()[2]; + auto bias_data_entry = data_entry_[EntryID(bias_entry)]; + ICHECK(bias_data_entry) << "Convolution bias tensor should be constant and " + "available on initialization stage. Looks like bias " + "is not result of constant expression."; + + auto bias_data_hdl = bias_data_entry->data; + bias_md = BindBNNSTensor(bias_entry, bias_data_hdl); + } else { + bias_md = std::make_shared(BNNS::Shape {OC}, BNNSDataTypeFloat32, nullptr); + } + + BNNSActivation activation = { has_relu ? + BNNSActivationFunctionRectifiedLinear : + BNNSActivationFunctionIdentity }; + + auto src_candidate = src_md->get_nd_desc(3); + auto weights_candidate = weights_md->get_nd_desc(); + auto dst_candidate = dst_md->get_nd_desc(3); + auto bias_candidate = bias_md->get_nd_desc(); + src_candidate.layout = BNNSDataLayoutImageCHW; + dst_candidate.layout = BNNSDataLayoutImageCHW; + weights_candidate.layout = BNNSDataLayoutConvolutionWeightsOIHW; + bias_candidate.layout = BNNSDataLayoutVector; + + // TODO [apeskov]: Tmp WA, broadcast bias is here with tailing [1, 1] + if (bias_candidate.size[0] == 1 && bias_candidate.size[1] == 1 && + one_of(bias_candidate.size[3], 1, 0) && + std::all_of(bias_candidate.size + 4, bias_candidate.size + BNNS_MAX_TENSOR_DIMENSION, + [] ( size_t d) { return d == 0; })) { + auto element_count = bias_candidate.size[2]; + std::fill(bias_candidate.size, bias_candidate.size + BNNS_MAX_TENSOR_DIMENSION, 0); + bias_candidate.size[0] = element_count; + } + + BNNSLayerParametersConvolution conv_param = { + src_candidate, + weights_candidate, + dst_candidate, + bias_candidate, + activation, + SW, /* x_stride */ + SH, /* y_stride */ + DW, /* x_dilation_stride */ + DH, /* y_dilation_stride */ + 0, /* x_padding, explicit pads will be used */ + 0, /* y_padding, explicit pads will be used */ + groups, /* groups */ + {PW_L, PW_R, PH_L, PH_R} /* explicit pad values */ + }; + + BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; + + std::vector params; + size_t i_stride, o_stride; + std::tie(params, i_stride, o_stride) = split_into_n(conv_param, N, BNNS_TMP_CONCURRENCY); + for (int i = 0; i < params.size(); i++) { + filters[i] = BNNSFilterCreateLayerConvolution(¶ms[i], &common_filter_param); + ICHECK(filters[i]) << "BNNS primitive was not created. Unsupported attributes configuration"; + } + + primitives_.emplace_back(std::make_shared(filters)); + primitives_.back()->set_input_stride(i_stride); + primitives_.back()->set_output_stride(o_stride); + + prim_args_.push_back({EntryID(src_entry)}); + prim_results_.push_back({EntryID(dst_entry)}); + } + + void Dense(const size_t& nid, const bool has_bias = false, const bool has_gelu = false) { + auto node = nodes_[nid]; + + // Setup attributes. + auto src_entry = node.GetInputs()[0]; + auto weight_entry = node.GetInputs()[1]; + auto dst_entry = JSONGraphNodeEntry(nid, 0); + + auto w_data = data_entry_[EntryID(weight_entry)]->data; + // Memory descriptions. + auto src_md = BindBNNSTensor(src_entry); + auto weights_md = BindBNNSTensor(weight_entry, w_data); + auto dst_md = BindBNNSTensor(dst_entry); + + BNNSNDArrayDescriptor in_desc = src_md->get_nd_desc(1); + BNNSNDArrayDescriptor w_desc = weights_md->get_nd_desc(2); + BNNSNDArrayDescriptor out_desc = dst_md->get_nd_desc(1); + w_desc.layout = BNNSDataLayoutRowMajorMatrix; + in_desc.layout = BNNSDataLayoutVector; + out_desc.layout = BNNSDataLayoutVector; + w_desc.data = w_data; + BNNSNDArrayDescriptor bias = {}; + if (has_bias) { + auto bias_entry = node.GetInputs()[2]; + auto bias_data = data_entry_[EntryID(bias_entry)]->data; + auto bias_md = BindBNNSTensor(bias_entry, bias_data); + bias = bias_md->get_nd_desc(); + bias.layout = BNNSDataLayoutVector; + bias.data = bias_data; + } + BNNSActivation activation = {BNNSActivationFunctionIdentity}; + if (has_gelu) { + activation = {BNNSActivationFunctionGELUApproximation}; + activation.alpha = std::sqrt(2.0 / M_PI); + activation.beta = 0.044715; + } + + BNNSLayerParametersFullyConnected layerParameters = { + in_desc, + w_desc, + out_desc, + bias, + activation, + }; + + auto filter = BNNSFilterCreateLayerFullyConnected(&layerParameters, &common_filter_param); + ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + primitives_.emplace_back(std::make_shared(filter)); + prim_args_.push_back({EntryID(src_entry)}); + prim_results_.push_back({EntryID(dst_entry)}); + } + + void MatMul(const size_t& nid) { + auto node = nodes_[nid]; + + // Setup attributes. + auto a_entry = node.GetInputs()[0]; + auto b_entry = node.GetInputs()[1]; + auto dst_entry = JSONGraphNodeEntry(nid, 0); + bool a_is_weighted = data_entry_[EntryID(a_entry)] != nullptr; + bool b_is_weighted = data_entry_[EntryID(b_entry)] != nullptr; + + void* a_data = nullptr; + void* b_data = nullptr; + if (a_is_weighted) + a_data = data_entry_[EntryID(a_entry)]->data; + if (b_is_weighted) + b_data = data_entry_[EntryID(b_entry)]->data; + // Memory descriptions. + auto a_md = BindBNNSTensor(a_entry, a_data); + auto b_md = BindBNNSTensor(b_entry, b_data); + auto dst_md = BindBNNSTensor(dst_entry); + + BNNSNDArrayDescriptor a_desc = a_md->get_nd_desc(); + BNNSNDArrayDescriptor b_desc = b_md->get_nd_desc(); + BNNSNDArrayDescriptor out_desc = dst_md->get_nd_desc(); + std::reverse(a_desc.size, a_desc.size + 3); + std::reverse(b_desc.size, b_desc.size + 3); + std::reverse(out_desc.size, out_desc.size + 3); + a_desc.data = a_data; + b_desc.data = b_data; + + BNNSLayerParametersBroadcastMatMul layerParameters = { + 1, // alpha + 0, // beta + false, // transA + true, // transB + false, // quadratic + a_is_weighted, + b_is_weighted, + a_desc, + b_desc, + out_desc + }; + + auto filter = BNNSFilterCreateLayerBroadcastMatMul(&layerParameters, &common_filter_param); + ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + primitives_.emplace_back(std::make_shared(filter)); + std::vector args; + if (!a_is_weighted) + args.push_back(EntryID(a_entry)); + if (!b_is_weighted) + args.push_back(EntryID(b_entry)); + prim_args_.push_back(std::move(args)); + prim_results_.push_back(EntryID(dst_entry)); + force_batch_size_.insert({prim_args_.size() - 1, 1}); + } + + BNNS::Dtype convertToBNNS(const DLDataType &dl_dtype) { + if (dl_dtype.code == DLDataTypeCode::kDLFloat) { + if (dl_dtype.bits == 32) return BNNSDataTypeFloat32; + if (dl_dtype.bits == 16) return BNNSDataTypeFloat16; + } + if (dl_dtype.code == DLDataTypeCode::kDLInt) { + if (dl_dtype.bits == 32) return BNNSDataTypeInt32; + if (dl_dtype.bits == 16) return BNNSDataTypeInt16; + if (dl_dtype.bits == 8) return BNNSDataTypeInt8; + } + if (dl_dtype.code == DLDataTypeCode::kDLUInt) { + if (dl_dtype.bits == 32) return BNNSDataTypeUInt32; + if (dl_dtype.bits == 16) return BNNSDataTypeUInt16; + if (dl_dtype.bits == 8) return BNNSDataTypeUInt8; + } + LOG(FATAL) << "Unsupported data type for BNNS runtime"; + return BNNS::Dtype(0); + } + + // TODO(apeskov): Allow to specify num of threads and keep customer buffers. + // Should investigate this attributes. + BNNSFilterParameters common_filter_param {}; + + std::vector> primitives_; + std::vector> prim_args_; + std::vector prim_results_; + std::unordered_map force_batch_size_; + + /* The entry ID to its corresponding output memory. */ + std::unordered_map> entry_out_mem_; +}; + +runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate") + .set_body_typed(BNNSJSONRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") + .set_body_typed(BNNSJSONRuntime::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 3ae652ccaf24..55f16635b9e6 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -55,7 +55,7 @@ class JSONRuntimeBase : public ModuleNode { LoadGraph(graph_json_); } - const char* type_key() const { return "json"; } + const char* type_key() const override { return "json"; } /*! \brief Initialize a specific json runtime. */ virtual void Init(const Array& consts) = 0; @@ -69,7 +69,7 @@ class JSONRuntimeBase : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { if (name == "get_symbol") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); @@ -98,7 +98,7 @@ class JSONRuntimeBase : public ModuleNode { } } - virtual void SaveToBinary(dmlc::Stream* stream) { + void SaveToBinary(dmlc::Stream* stream) override { // Save the symbol stream->Write(symbol_name_); // Save the graph diff --git a/tests/python/contrib/test_bnns/__init__.py b/tests/python/contrib/test_bnns/__init__.py new file mode 100644 index 000000000000..724b23f1378b --- /dev/null +++ b/tests/python/contrib/test_bnns/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Infrastructure and tests for BNNS""" diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py new file mode 100644 index 000000000000..b3e15959107e --- /dev/null +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -0,0 +1,309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from itertools import zip_longest, combinations +import json +import os +import warnings + +import numpy as np + +import tvm +from tvm import relay +from tvm import rpc +from tvm.contrib import graph_runtime +from tvm.relay.op.contrib.bnns import partition_for_bnns +from tvm.contrib import utils +from tvm.autotvm.measure import request_remote + + +class Device: + """ + Common device configuration for python tests. + + Check tests/python/contrib/arm_compute_lib/ for the presence of an test_config.json file. + This file can be used to override the default configuration here which will attempt to run the Arm + Compute Library runtime tests locally if the runtime is available. Changing the configuration + will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example. + + Notes + ----- + The test configuration will be loaded once when the the class is created. If the configuration + changes between tests, any changes will not be picked up. + + Parameters + ---------- + device : RPCSession + Allows tests to connect to and use remote device. + + Attributes + ---------- + connection_type : str + Details the type of RPC connection to use. Options: + local - Use the local device, + tracker - Connect to a tracker to request a remote device, + remote - Connect to a remote device directly. + host : str + Specify IP address or hostname of remote target. + port : int + Specify port number of remote target. + target : str + The compilation target. + device_key : str + The device key of the remote target. Use when connecting to a remote device via a tracker. + cross_compile : str + Specify path to cross compiler to use when connecting a remote device from a non-arm platform. + """ + + connection_type = "local" + host = "localhost" + port = 9090 + target = "llvm" + device_key = "" + cross_compile = "" + + def __init__(self): + """Keep remote device for lifetime of object.""" + self.device = self._get_remote() + + @classmethod + def _get_remote(cls): + """Get a remote (or local) device to use for testing.""" + if cls.connection_type == "tracker": + device = request_remote(cls.device_key, cls.host, cls.port, timeout=1000) + elif cls.connection_type == "remote": + device = rpc.connect(cls.host, cls.port) + elif cls.connection_type == "local": + device = rpc.LocalSession() + else: + raise ValueError( + "connection_type in test_config.json should be one of: " "local, tracker, remote." + ) + + return device + + @classmethod + def load(cls, file_name): + """Load test config + + Load the test configuration by looking for file_name relative + to the test_arm_compute_lib directory. + """ + location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + config_file = os.path.join(location, file_name) + if not os.path.exists(config_file): + warnings.warn( + "Config file doesn't exist, resuming tests with default config." + ) + return + with open(config_file, mode="r") as config: + test_config = json.load(config) + + cls.connection_type = test_config["connection_type"] + cls.host = test_config["host"] + cls.port = test_config["port"] + cls.target = test_config["target"] + cls.device_key = test_config.get("device_key") or "" + cls.cross_compile = test_config.get("cross_compile") or "" + + +Device.target = "llvm" + + +def skip_runtime_test(): + """Skip test if it requires the runtime and it's not present.""" + # BNNS codegen not present. + if not tvm.get_global_func("relay.ext.bnns", True): + print("Skip because BNNS codegen is not available.") + return True + return False + + +def skip_codegen_test(): + """Skip test if it requires the BNNS codegen and it's not present.""" + if not tvm.get_global_func("relay.ext.bnns", True): + print("Skip because BNNS codegen is not available.") + return True + + +def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0): + """Build module with option to build for BNNS.""" + if isinstance(mod, tvm.relay.expr.Call): + mod = tvm.IRModule.from_expr(mod) + with tvm.transform.PassContext(opt_level=3): + if enable_bnns: + mod = partition_for_bnns(mod) + relay.backend.compile_engine.get().clear() + return relay.build(mod, target=target, target_host=target, params=params) + + +def build_and_run( + mod, + inputs, + outputs, + params, + device, + enable_bnns=True, + no_runs=1, + tvm_ops=0, + config=None, +): + """Build and run the relay module.""" + if config is None: + config = {} + + try: + lib = build_module(mod, device.target, params, enable_bnns, tvm_ops) + except Exception as e: + err_msg = "The module could not be built.\n" + if config: + err_msg += f"The test failed with the following parameters: {config}\n" + err_msg += str(e) + raise Exception(err_msg) + + lib = update_lib(lib, device.device, device.cross_compile) + gen_module = graph_runtime.GraphModule(lib["default"](device.device.cpu(0))) + gen_module.set_input(**inputs) + out = [] + for _ in range(no_runs): + gen_module.run() + out.append([gen_module.get_output(i) for i in range(outputs)]) + return out + + +def update_lib(lib, device, cross_compile): + """Export the library to the remote/local device.""" + lib_name = "mod.so" + temp = utils.tempdir() + lib_path = temp.relpath(lib_name) + if cross_compile: + lib.export_library(lib_path, cc=cross_compile) + else: + lib.export_library(lib_path) + device.upload(lib_path) + lib = device.load_module(lib_name) + return lib + + +def extract_bnns_modules(module): + """Get the BNNS module(s) from llvm module.""" + return list( + filter(lambda mod: mod.type_key == "bnns_json", module.get_lib().imported_modules) + ) + + +def verify(answers, atol, rtol, verify_saturation=False, config=None): + """Compare the array of answers. Each entry is a list of outputs.""" + if config is None: + config = {} + + if len(answers) < 2: + raise RuntimeError(f"No results to compare: expected at least two, found {len(answers)}") + for answer in zip_longest(*answers): + for outs in combinations(answer, 2): + try: + if verify_saturation: + assert ( + np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size + ), "Output is saturated: {}".format(outs[0]) + assert ( + np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size + ), "Output is saturated: {}".format(outs[0]) + tvm.testing.assert_allclose( + outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol + ) + except AssertionError as e: + err_msg = "Results not within the acceptable tolerance.\n" + if config: + err_msg += f"The test failed with the following parameters: {config}\n" + err_msg += str(e) + raise AssertionError(err_msg) + + +def verify_codegen( + module, + known_good_codegen, + num_bnns_modules, + tvm_ops=0, + target=Device.target, +): + """Check BNNS codegen against a known good output.""" + module = build_module(module, target, tvm_ops=tvm_ops) + bnns_modules = extract_bnns_modules(module) + + assert len(bnns_modules) == num_bnns_modules, ( + f"The number of BNNS modules produced ({len(bnns_modules)}) does not " + f"match the expected value ({num_bnns_modules})." + ) + + for mod in bnns_modules: + source = mod.get_source("json") + codegen = json.loads(source)["nodes"] + # remove input and const names as these cannot be predetermined + for node in range(len(codegen)): + if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": + codegen[node]["name"] = "" + codegen_str = json.dumps(codegen, sort_keys=True, indent=2) + known_good_codegen_str = json.dumps(known_good_codegen, sort_keys=True, indent=2) + + assert codegen_str == known_good_codegen_str, ( + f"The JSON produced by codegen does not match the expected result. \n" + f"Actual={codegen_str} \n" + f"Expected={known_good_codegen_str}" + ) + + +def generate_trials(space, r_factor=3): + """Generates a series of trials. + + This algorithm generates a series of non-deterministic trials given a + space of options to test. A trial is generated by pulling a value from + each option in the space. On some occasions the values are shuffled to + ensure a different trial on each r_factor iteration. The algorithm ensures + that each value from an option is used at least once. The total number of + trials is determined by the r_factor * the option with the largest number + of values. + + Parameters + ---------- + space: List[List[Any]] + A list of different options with varying values to test. + r_factor: (optional) int + The repeat factor. + + Returns + ------- + A list of trials specifying values for each option. + + """ + np.random.seed(0) + max_len = 1 + for option in space: + max_len = max(max_len, len(option)) + + num_trials = r_factor * max_len + trials = [] + for i in range(num_trials): + trial = [] + for option in space: + if i % len(option) == 0: + np.random.shuffle(option) + trial.append(option[i % len(option)]) + + trials.append(trial) + + return trials diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py new file mode 100644 index 000000000000..7f641ab1fd60 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration conv2d tests.""" + +import numpy as np + +import tvm +from tvm import relay + +from .infrastructure import Device +from .infrastructure import ( + skip_runtime_test, + skip_codegen_test, + build_and_run, + verify, + generate_trials, +) + +# TODO: Missed cases +# 1. Bias as add with 3d const tensor. Lead to additional unsqueeze op between +# 2. Check unsupported cases of fusion. Like bias add with axis != 1, add with broadcast by spatial dims +# 3. Check if bias/weights is not constants. Should fallback into LLVM or decompose it +# 4. Check if bias/weights is constants expr. Should works somehow. + +def _get_model( + shape, + kernel_h, + kernel_w, + padding, + strides, + dilation, + groups, + dtype, + channels, + var_names, + bias_type='none', + activation_type='none', +): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=shape, dtype=dtype) + weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + out = relay.nn.conv2d( + a, + weights, + kernel_size=(kernel_h, kernel_w), + dilation=dilation, + strides=strides, + padding=padding, + groups=groups, + channels=channels, + out_dtype=dtype, + ) + params = {"w": w} + if bias_type == "bias_add": + b = tvm.nd.array(np.random.uniform(-10, 10, weight_shape[0]).astype(dtype)) + biasc = relay.const(b, dtype) + out = relay.nn.bias_add(out, biasc, axis=1) + params["b"] = b + elif bias_type == "add_3d" or bias_type == "add_4d": + bias_shape = (weight_shape[0], 1, 1) if bias_type == "add_3d" else (1, weight_shape[0], 1, 1) + b = tvm.nd.array(np.random.uniform(-10, 10, bias_shape).astype(dtype)) + biasc = relay.const(b, dtype) + out = relay.add(out, biasc) + params["b"] = b + + if activation_type == "relu": + out = relay.nn.relu(out) + return out, params + + +def test_conv2d(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + kernel_hs = [1, 2, 3, 5] + kernel_ws = [1, 2, 3, 5] + pad = [(1, 1), (2, 2), (2, 1)] + strides = [(1, 1), (2, 2)] + dilation = [(1, 1)] + out_channels = [4, 8, 16] + input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] + batches = [1, 2] + groups = [1, 2] + bias_kind = ['none', 'add_3d', 'add_4d', 'bias.add'] + activation_kind = ['none', 'relu'] + dtype = "float32" + trials = generate_trials( + [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, + groups, batches, bias_kind, activation_kind], 3 + ) + + for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, group, batch, bias, activation in trials: + shape = (batch, *input_shapes) + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(0, 127, shape).astype(dtype)), + } + + func, params = _get_model( + shape, + kernel_h, + kernel_w, + pad, + stride, + dilation, + group, + dtype, + out_channels, + iter(inputs), + bias_type=bias, + activation_type=activation, + ) + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + config = { + "shape": shape, + "group": group, + "kernel size": (kernel_h, kernel_w), + "padding": pad, + "stride": stride, + "dilation": dilation, + "out channels": out_channels, + "bias": bias, + "activation": activation + } + verify(outputs, atol=0.002, rtol=0.007, config=config) + + +if __name__ == "__main__": + test_conv2d() diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py new file mode 100644 index 000000000000..6e19957f43fa --- /dev/null +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS pattern detection check""" + +import tvm +from tvm import relay +import numpy as np + +from tvm.relay.op.contrib.bnns import partition_for_bnns + +fp32 = "float32" + +def partition(exp): + """Apply BNNS specific partitioning transformation""" + mod = tvm.IRModule.from_expr(exp) + with tvm.transform.PassContext(opt_level=3): + mod = partition_for_bnns(mod) + return mod + + +def is_op_fused(func, op_name): + is_fused = False + + def visit(op): + if isinstance(op, tvm.relay.function.Function) \ + and op_name in op.attrs["PartitionedFromPattern"]: + nonlocal is_fused + is_fused = True + tvm.relay.analysis.post_order_visit(func.body, visit) + return is_fused + + +def test_pattern_conv2d_with_bias_add(): + for axis in (1, 2): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d( + a, w, + kernel_size=(3, 3), padding=(1, 1), + channels=8, out_dtype=fp32 + ) + b = relay.const(np.random.uniform(-10, 10, 8).astype(fp32)) + res = relay.nn.bias_add(res, b, axis=axis) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + + assert bias_is_fused if axis == 1 else not bias_is_fused + + +def test_pattern_conv2d_with_add(): + workloads = { + 8: False, + (8, 1): False, + (8, 1, 1): True, + (1, 8, 1, 1): True + } + + for b_shape, should_be_fused in workloads.items(): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d( + a, w, + kernel_size=(3, 3), padding=(1, 1), + channels=8, out_dtype=fp32 + ) + b = relay.const(np.random.uniform(-10, 10, b_shape).astype(fp32)) + res = relay.add(res, b) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "add") + + assert bias_is_fused == should_be_fused + + +def test_pattern_conv2d_with_non_cons_weights(): + for const_weights in (True, False): + a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) + if const_weights: + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + else: + w = relay.var("w", shape=(8, 7, 3, 3), dtype=fp32) + + res = relay.nn.conv2d( + a, w, + kernel_size=(3, 3), padding=(1, 1), + channels=8, out_dtype=fp32 + ) + + mod = partition(res) + use_bnns = len(mod.get_global_vars()) == 2 # GlobalVar: "main" and "bnns_0" + + assert use_bnns == const_weights + + +def test_pattern_conv2d_with_non_cons_bias(): + a = relay.var("a", shape=[2, 7, 8, 8], dtype=fp32) + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) + res = relay.nn.conv2d( + a, w, + kernel_size=(3, 3), padding=(1, 1), + channels=8, out_dtype=fp32 + ) + b = relay.var("b", shape=[8], dtype=fp32) + res = relay.nn.bias_add(res, b, axis=1) + + mod = partition(res) + bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + + assert not bias_is_fused diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py new file mode 100644 index 000000000000..4e4a03702688 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration dense tests.""" + +import numpy as np +import math + +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + build_and_run, + verify, + verify_codegen, + generate_trials, +) + + +def _get_model(shape, weight_shape, units, dtype, var_names, has_bias=False, has_gelu=False): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=shape, dtype=dtype) + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + out = relay.nn.dense(a, weights, units=units, out_dtype=dtype) + params = {"w": w} + if has_bias: + b = tvm.nd.array(np.random.randint(-128, 127, weight_shape[0]).astype(dtype)) + biasc = relay.const(b, dtype) + out = relay.op.add(out, biasc) + params["b"] = b + if has_gelu: + const1 = relay.const(0.044715) + const2 = relay.const(math.sqrt(2 / math.pi)) + bias = out + out = relay.op.power(bias, relay.const(3.0, "float32")) + out = relay.op.multiply(out, const1) + out = relay.op.add(out, bias) + out = relay.op.multiply(out, const2) + out = relay.op.tanh(out) + out = relay.op.add(out, relay.const(1, "float32")) + out = relay.op.multiply(out, relay.const(0.5)) + out = relay.op.multiply(out, bias) + return out, params + + +def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False, has_gelu=False): + output_shape = (shape[0], units) + name = "nn.dense" + if has_bias is True: + name = "bnns.dense_bias" + if has_bias is True and has_gelu is True: + name = "bnns.dense_bias_gelu" + + node = { + "op": "kernel", + "name": name, + "inputs": [], + "attrs": { + "num_outputs": "1", + "out_dtype": [["float32"]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "units": [[str(units)]], + }, + } + + inputs = [ + {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}}, + { + "op": "const", + "name": "", + "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]}, + }, + ] + + if has_bias: + inputs.append( + { + "op": "const", + "name": "", + "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [["float32"]]}, + } + ) + + input_idx = 0 + for _ in range(len(inputs)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + node["attrs"]["num_inputs"] = str(len(inputs)) + inputs.append(node) + return inputs + + +def test_dense(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + dtype = ["float32"] + shape = [ + ((1, 128), (16, 128), 16), + ((32, 32), (32, 32), 32), + ((1, 64), (1, 64), 1), + ((11, 2), (2, 2), 2), + ((2, 2), (1, 2), 1), + ] + composite = [False, True] + trials = generate_trials([dtype, shape, composite, composite], 3) + + for dtype, (shape, weight_shape, units), with_bias, with_gelu in trials: + outputs = [] + inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))} + func, params = _get_model( + shape, weight_shape, units, dtype, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu + ) + for bnns in [False, True]: + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=bnns, + )[0] + ) + + config = { + "shape": shape, + "weight_shape": weight_shape, + "units": units, + "dtype": dtype, + "with_bias": with_bias, + "with_gelu": with_gelu, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +def test_codegen_dense(): + if skip_codegen_test(): + return + + np.random.seed(0) + + dtype = ["float32"] + shape = [ + ((1, 128), (16, 128), 16), + ((32, 32), (32, 32), 32), + ((1, 64), (1, 64), 1), + ((11, 2), (2, 2), 2), + ((2, 2), (1, 2), 1), + ] + composite = [False, True] + trials = generate_trials([dtype, shape, composite, composite], 3) + + for dtype, (shape, weight_shape, units), with_bias, with_gelu in trials: + inputs = {"a"} + + args = (shape, weight_shape, units, dtype) + + func, params = _get_model(*args, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu) + exp_codegen = _get_expected_codegen(*args, has_bias=with_bias, has_gelu=with_gelu) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_dense() + test_codegen_dense() + diff --git a/tests/python/contrib/test_bnns/test_matmul.py b/tests/python/contrib/test_bnns/test_matmul.py new file mode 100644 index 000000000000..38683e8748b6 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_matmul.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration dense tests.""" + +import numpy as np +import math + +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + verify_codegen, + build_and_run, + verify, + generate_trials, +) + + +def _get_model(a_shape, b_shape, dtype, var_names, is_a_constant=False, is_b_constant=False): + """Return a model and any parameters it may have""" + a = relay.var(next(var_names), shape=a_shape, dtype=dtype) + b = relay.var(next(var_names), shape=b_shape, dtype=dtype) + params = {} + if is_b_constant is True: + b = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params['b'] = b + b = relay.const(b, dtype) + if is_a_constant is True: + a = tvm.nd.array(np.random.uniform(-128, 127, a_shape).astype(dtype)) + params['a'] = a + a = relay.const(a, dtype) + out = relay.nn.batch_matmul(a, b) + return out, params + + +def test_matmul(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + dtype = "float32" + + # C[N, I, J] = A[N, I, K] * B[N, J, K] + shapes_config = [ + # B, I, J, K + [1, 4, 4, 3], + [1, 16, 32, 32], + [2, 1, 1, 3], + [2, 16, 32, 32], + [5, 1, 1, 3], + ] + data_config = [ + # A_is_constant, B_is_constant + [False, True], + [True, False], + [False, False], + ] + + for N, I, J, K in shapes_config: + a_shape = [N, I, K] + b_shape = [N, J, K] + for is_a_constant, is_b_constant in data_config: + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(-128, 127, a_shape).astype(dtype)), + "b": tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)), + } + func, params = _get_model( + a_shape, b_shape, dtype, + var_names=iter(inputs), + is_a_constant=is_a_constant, + is_b_constant=is_b_constant + ) + for enable_bnns in [False, True]: + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=enable_bnns, + )[0] + ) + + config = { + "a_shape": a_shape, + "b_shape": b_shape, + "dtype": dtype, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +if __name__ == "__main__": + test_matmul() + + From 3552b8af50bb779b21540efa59b4612ccd53e9c4 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 18 Jan 2021 15:49:02 +0300 Subject: [PATCH 02/27] [BNNS] Add conv2d DW test Also fix some pylint issues Signed-off-by: Alexander Peskov --- python/tvm/relay/op/contrib/bnns.py | 40 ++++------------ tests/python/contrib/test_bnns/test_conv2d.py | 48 ++++++++++++++++++- .../contrib/test_bnns/test_conv2d_patterns.py | 1 + tests/python/contrib/test_bnns/test_dense.py | 2 - 4 files changed, 58 insertions(+), 33 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index 17aad38d3063..bee92a228869 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -20,33 +20,17 @@ to handle tensor processing. Particularly: * BNNS (basic neural ) * vDSP (1D and 2D tensor processing) - * BLAS (gemm provide) - -# There are two ways to registering a function for an op to indicate if it is -# supported by DNNL. - -# - The first and simplest way is to use the helper so that -# users only need to provide the operator name and a boolean value to indicate if -# it is supported. For example: -# -# .. code-block:: python -# -# add = _register_external_op_helper("add") -# add = _register_external_op_helper("add", True) -# add = _register_external_op_helper("add", False) -# -# - The other way is to implement the function by themselves to -# check the attributes of the op and decide if it should be offloaded to DNNL. """ import math import tvm.ir -from ...dataflow_pattern import wildcard, is_op, is_expr, is_constant +from ...dataflow_pattern import wildcard, is_op, is_expr from .register import register_pattern_table, get_pattern_table from tvm.relay import transform from tvm.relay.expr import const from tvm.relay.build_module import bind_params_by_name + def partition_for_bnns(mod, params=None): """Partition the graph greedily offloading supported operators to BNNS. @@ -109,6 +93,7 @@ def _func_wrapper(expr): return _func_wrapper + _register_external_op_helper("nn.batch_matmul") @@ -184,18 +169,13 @@ def make_conv_relu_pattern(with_bias=True, with_relu=True): def check_conv(extract): """Check conv pattern is supported by BNNS.""" - is_ok = True - - def visit(op): - nonlocal is_ok - if isinstance(op, tvm.relay.Call): - if op.op.name == "nn.conv2d": - is_ok &= conv2d_check(op) - elif op.op.name in ("nn.bias_add", "add"): - is_ok &= bias_check(op) - - tvm.relay.analysis.post_order_visit(extract, visit) - return is_ok + bias_is_ok = True + call = extract + while call.op.name != "nn.conv2d": + if call.op.name in ("nn.bias_add", "add"): + bias_is_ok &= bias_check(call) + call = call.args[0] + return conv2d_check(call) and bias_is_ok def make_dense_bias_pattern(): diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index 7f641ab1fd60..5dbeddcc5573 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -24,7 +24,6 @@ from .infrastructure import Device from .infrastructure import ( skip_runtime_test, - skip_codegen_test, build_and_run, verify, generate_trials, @@ -36,6 +35,7 @@ # 3. Check if bias/weights is not constants. Should fallback into LLVM or decompose it # 4. Check if bias/weights is constants expr. Should works somehow. + def _get_model( shape, kernel_h, @@ -146,5 +146,51 @@ def test_conv2d(): verify(outputs, atol=0.002, rtol=0.007, config=config) +def test_conv2d_dw(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + dtype = "float32" + shape = [4, 5, 5] + kernel = [3, 3] + pad = [1, 1] + + for batch in [1, 2]: + i_shape = (batch, *shape) + channels = shape[0] + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), + } + + a = relay.var("a", shape=i_shape, dtype=dtype) + weight_shape = [channels, 1, *kernel] + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + func = relay.nn.conv2d( + a, + weights, + kernel_size=kernel, + padding=pad, + groups=channels, + channels=channels, + out_dtype=dtype, + ) + params = {"w": w} + + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + config = { + "shape": shape, + "kernel size": kernel, + "padding": pad, + "out channels": channels, + } + verify(outputs, atol=0.002, rtol=0.007, config=config) + + if __name__ == "__main__": test_conv2d() diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py index 6e19957f43fa..9dc0695d57c4 100644 --- a/tests/python/contrib/test_bnns/test_conv2d_patterns.py +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -24,6 +24,7 @@ fp32 = "float32" + def partition(exp): """Apply BNNS specific partitioning transformation""" mod = tvm.IRModule.from_expr(exp) diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py index 4e4a03702688..f995194d1775 100644 --- a/tests/python/contrib/test_bnns/test_dense.py +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -21,7 +21,6 @@ import tvm from tvm import relay -from tvm import testing from .infrastructure import ( Device, skip_runtime_test, @@ -185,4 +184,3 @@ def test_codegen_dense(): if __name__ == "__main__": test_dense() test_codegen_dense() - From 12e6f3e6c2d110ab97b665b0bd8eefddc64e2cdf Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 18 Jan 2021 15:58:40 +0300 Subject: [PATCH 03/27] [BNNS] Fix clang-format issues Signed-off-by: Alexander Peskov --- src/relay/backend/contrib/bnns/codegen.cc | 4 +- src/runtime/contrib/bnns/bnns_json_runtime.cc | 441 +++++++++--------- 2 files changed, 226 insertions(+), 219 deletions(-) diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 504425a225c5..c7c5646d971a 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -149,7 +149,7 @@ struct BNNSConstantUpdater : public ConstantUpdater { public: BNNSConstantUpdater(const std::string& symbol, std::unordered_map* params, - std::vector &skip_mask) + const std::vector &skip_mask) : ConstantUpdater(symbol, params), skip_mask_(skip_mask) {} using ConstantUpdater::VisitExpr_; @@ -199,7 +199,7 @@ Map BNNSConstantUpdaterFunc(Expr expr, std::string sym // Convert to tvm::Map Map ret; - for (const auto& kvp: res) + for (const auto& kvp : res) ret.Set(kvp.first, kvp.second); return ret; } diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index ffe1b037d347..563f212ae369 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -59,257 +59,263 @@ using namespace tvm::runtime::json; /** C++ wrapper on top of original BNNS C api */ namespace BNNS { - using Dim = size_t; - using Shape = std::vector; - using Dtype = BNNSDataType; - - void* default_alloc(size_t size) { - // TODO: Clarify, should it have some alignment for better performance - // with SIMD execution.. may be TVMBackendAllocWorkspace is more - // preferable here. - // Note: Apple uses posix_memalign by default. - return malloc(size); - } +using Dim = size_t; +using Shape = std::vector; +using Dtype = BNNSDataType; + +void* default_alloc(size_t size) { + // TODO(apeskov): Clarify, should it have some alignment for better performance + // with SIMD execution.. may be TVMBackendAllocWorkspace is more preferable here. + // Note: Apple uses posix_memalign by default. + return malloc(size); +} - void default_free(void* ptr) { - free(ptr); - } +void default_free(void* ptr) { + free(ptr); +} - class Tensor { - public: - Tensor(Shape shape, Dtype dtype, void* hdl) - : real_shape(shape) { - ICHECK(shape.size() < BNNS_MAX_TENSOR_DIMENSION); - - if (hdl) { - data_handler = hdl; - is_external_data = true; - } else { - const size_t buff_size = getNumOfElements(shape) * getElementSize(dtype); - data_handler = default_alloc(buff_size); - is_external_data = false; - } +class Tensor { + public: + Tensor(Shape shape, Dtype dtype, void* hdl) + : real_shape(shape) { + ICHECK(shape.size() < BNNS_MAX_TENSOR_DIMENSION); - bnns_nd_desc = { - BNNSNDArrayFlags(0), - getPlainLayout(shape), - {}, // shape - {}, // strides, empty value means use default dense strides - hdl, // data handler - dtype, // data type - nullptr, // table_data (clustering case), is not used - dtype, - 1.f, - 0.f - }; - std::copy(shape.rbegin(), shape.rend(), std::begin(bnns_nd_desc.size)); + if (hdl) { + data_handler = hdl; + is_external_data = true; + } else { + const size_t buff_size = getNumOfElements(shape) * getElementSize(dtype); + data_handler = default_alloc(buff_size); + is_external_data = false; } - ~Tensor() { - if (data_handler && !is_external_data) { - default_free(data_handler); - data_handler = nullptr; - } - } + bnns_nd_desc = { + BNNSNDArrayFlags(0), + getPlainLayout(shape), + {}, // shape + {}, // strides, empty value means use default dense strides + hdl, // data handler + dtype, // data type + nullptr, // table_data (clustering case), is not used + dtype, + 1.f, + 0.f + }; + std::copy(shape.rbegin(), shape.rend(), std::begin(bnns_nd_desc.size)); + } - void* get_data_hdl() { return data_handler; } + ~Tensor() { + if (data_handler && !is_external_data) { + default_free(data_handler); + data_handler = nullptr; + } + } - const void* get_data_hdl() const { return data_handler; }; + void* get_data_hdl() { return data_handler; } - void set_data_hdl(void *hdl) { - if (data_handler && !is_external_data) { - default_free(data_handler); - data_handler = nullptr; - } + const void* get_data_hdl() const { return data_handler; } - data_handler = hdl; - is_external_data = true; + void set_data_hdl(void *hdl) { + if (data_handler && !is_external_data) { + default_free(data_handler); + data_handler = nullptr; } - size_t get_mb() const { - return real_shape[0]; - } + data_handler = hdl; + is_external_data = true; + } - size_t get_mb_stride() const { - return std::accumulate(real_shape.begin() + 1, real_shape.end(), - 1, std::multiplies()); - } + size_t get_mb() const { + return real_shape[0]; + } - const BNNSNDArrayDescriptor get_nd_desc(size_t nd = 0) const { - auto original_nd = real_shape.size(); - // Ask of original descriptor - if (original_nd == nd || nd == 0) - return bnns_nd_desc; - - // As of desc with excluded batch - if (original_nd == nd + 1) { - auto res = bnns_nd_desc; - res.size[original_nd - 1] = 0; - res.layout = BNNSDataLayout3DLastMajor; // TODO [apeskov] : hardcoded value. FIXME - return res; - } - LOG(FATAL) << "Unknown case of BNNS tensor interpretation"; + size_t get_mb_stride() const { + return std::accumulate(real_shape.begin() + 1, real_shape.end(), + 1, std::multiplies()); + } + + const BNNSNDArrayDescriptor get_nd_desc(size_t nd = 0) const { + auto original_nd = real_shape.size(); + // Ask of original descriptor + if (original_nd == nd || nd == 0) return bnns_nd_desc; - }; - private: - static BNNSDataLayout getPlainLayout(const Shape &shape) { - return getPlainLayout(shape.size()); + // As of desc with excluded batch + if (original_nd == nd + 1) { + auto res = bnns_nd_desc; + res.size[original_nd - 1] = 0; + res.layout = BNNSDataLayout3DLastMajor; // TODO(apeskov): hardcoded value. FIXME + return res; } + LOG(FATAL) << "Unknown case of BNNS tensor interpretation"; + return bnns_nd_desc; + } - static BNNSDataLayout getPlainLayout(size_t rank) { - switch (rank) { - case 1: return BNNSDataLayout1DFirstMajor; - case 2: return BNNSDataLayout2DFirstMajor; - case 3: return BNNSDataLayout3DFirstMajor; - case 4: return BNNSDataLayout4DFirstMajor; - case 5: return BNNSDataLayout5DFirstMajor; - case 6: return BNNSDataLayout6DFirstMajor; - case 7: return BNNSDataLayout7DFirstMajor; - case 8: return BNNSDataLayout8DFirstMajor; - default: - LOG(FATAL) << "Unsupported tensor rank : " << rank - << " Supported cases is only 1-8 "; - return static_cast(0); - } - } + private: + static BNNSDataLayout getPlainLayout(const Shape &shape) { + return getPlainLayout(shape.size()); + } - /** - * return size in byte of element of provided type - * @param dtype - * @return size of element in bytes - */ - static size_t getElementSize(Dtype dtype) { - return (dtype & 0xFFFF) / sizeof(uint8_t); + static BNNSDataLayout getPlainLayout(size_t rank) { + switch (rank) { + case 1: return BNNSDataLayout1DFirstMajor; + case 2: return BNNSDataLayout2DFirstMajor; + case 3: return BNNSDataLayout3DFirstMajor; + case 4: return BNNSDataLayout4DFirstMajor; + case 5: return BNNSDataLayout5DFirstMajor; + case 6: return BNNSDataLayout6DFirstMajor; + case 7: return BNNSDataLayout7DFirstMajor; + case 8: return BNNSDataLayout8DFirstMajor; + default: + LOG(FATAL) << "Unsupported tensor rank : " << rank + << " Supported cases is only 1-8 "; + return static_cast(0); } + } - static size_t getNumOfElements(const Shape &shape) { - return std::accumulate(shape.begin(), shape.end(), - 1, std::multiplies()); - } + /** + * return size in byte of element of provided type + * @param dtype + * @return size of element in bytes + */ + static size_t getElementSize(Dtype dtype) { + return (dtype & 0xFFFF) / sizeof(uint8_t); + } - private: - Shape real_shape = {}; - void* data_handler = nullptr; - bool is_external_data = false; + static size_t getNumOfElements(const Shape &shape) { + return std::accumulate(shape.begin(), shape.end(), + 1, std::multiplies()); + } - BNNSNDArrayDescriptor bnns_nd_desc; - }; + private: + Shape real_shape = {}; + void* data_handler = nullptr; + bool is_external_data = false; - class Primitive { - public: - Primitive(BNNSFilter f) : num_filters(1), filters{f} {} + BNNSNDArrayDescriptor bnns_nd_desc; +}; - Primitive(BNNSFilter fs[BNNS_MAX_CONCURRENCY]) { - std::copy(fs, fs + BNNS_MAX_CONCURRENCY, filters); - for (int i = 0; i < BNNS_MAX_CONCURRENCY; i++) { - if (filters[i] == nullptr) { - num_filters = i; - break; - } +class Primitive { +public: + explicit Primitive(BNNSFilter f) : num_filters(1), filters{f} {} + + explicit Primitive(BNNSFilter fs[BNNS_MAX_CONCURRENCY]) { + std::copy(fs, fs + BNNS_MAX_CONCURRENCY, filters); + for (int i = 0; i < BNNS_MAX_CONCURRENCY; i++) { + if (filters[i] == nullptr) { + num_filters = i; + break; } } + } - ~Primitive() { - for (size_t i = 0; i < num_filters; i++) { - auto &filter = filters[i]; - if (filter) { - BNNSFilterDestroy(filter); - filter = nullptr; - } + ~Primitive() { + for (size_t i = 0; i < num_filters; i++) { + auto &filter = filters[i]; + if (filter) { + BNNSFilterDestroy(filter); + filter = nullptr; } } + } - void execute(std::vector srcs, Tensor &dst, int forceBatchSize = -1) { - ICHECK_LE(srcs.size(), 2) << "Currently BNNS runtime supports primitives with only 1 or 2 " - "data inputs."; + void execute(std::vector srcs, Tensor *dst, int forceBatchSize = -1) { + ICHECK_LE(srcs.size(), 2) << "Currently BNNS runtime supports primitives with only 1 or 2 " + "data inputs."; - run_ctx ctx { this, srcs[0], nullptr, &dst, forceBatchSize }; - if (srcs.size() > 1) - ctx.src2 = srcs[1]; + run_ctx ctx { this, srcs[0], nullptr, dst, forceBatchSize }; + if (srcs.size() > 1) + ctx.src2 = srcs[1]; - auto res = TVMBackendParallelLaunch(run_task, &ctx, num_filters); - ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; - } + auto res = TVMBackendParallelLaunch(run_task, &ctx, num_filters); + ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + } - void set_input_stride(size_t stride1, size_t stride2 = 0) { - in1_hdl_stride = stride1; - in2_hdl_stride = stride2; - } - void set_output_stride(size_t stride) { out_hdl_stride = stride; } - - private: - struct run_ctx { - Primitive *prim; - const Tensor *src1; - const Tensor *src2; - Tensor *dst; - const int force_batch_size; - }; + void set_input_stride(size_t stride1, size_t stride2 = 0) { + in1_hdl_stride = stride1; + in2_hdl_stride = stride2; + } + void set_output_stride(size_t stride) { out_hdl_stride = stride; } - static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { - auto ctx = reinterpret_cast(cdata); - const auto *prim = ctx->prim; + private: + struct run_ctx { + Primitive *prim; + const Tensor *src1; + const Tensor *src2; + Tensor *dst; + const int force_batch_size; + }; - const auto &filter = prim->filters[task_id]; + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto ctx = reinterpret_cast(cdata); + const auto *prim = ctx->prim; - auto src1_hdl = ctx->src1->get_data_hdl(); - auto dst_hdl = ctx->dst->get_data_hdl(); + const auto &filter = prim->filters[task_id]; - auto src1_mb = ctx->src1->get_mb(); - auto dst_mb = ctx->dst->get_mb(); + auto src1_hdl = ctx->src1->get_data_hdl(); + auto dst_hdl = ctx->dst->get_data_hdl(); - auto src1_mb_stride = ctx->src1->get_mb_stride(); - auto dst_mb_stride = ctx->dst->get_mb_stride(); + auto src1_mb = ctx->src1->get_mb(); + auto dst_mb = ctx->dst->get_mb(); - src1_hdl = static_cast(src1_hdl) + task_id*prim->in1_hdl_stride; - dst_hdl = static_cast(dst_hdl) + task_id*prim->out_hdl_stride; + auto src1_mb_stride = ctx->src1->get_mb_stride(); + auto dst_mb_stride = ctx->dst->get_mb_stride(); - ICHECK(src1_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; + src1_hdl = static_cast(src1_hdl) + task_id*prim->in1_hdl_stride; + dst_hdl = static_cast(dst_hdl) + task_id*prim->out_hdl_stride; - const void* src2_hdl = nullptr; - size_t src2_mb = 0; - size_t src2_mb_stride = 0; + ICHECK(src1_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; - if (ctx->src2) { - src2_hdl = ctx->src2->get_data_hdl(); - src2_mb = ctx->src2->get_mb(); - src2_mb_stride = ctx->src2->get_mb_stride(); - src2_hdl = static_cast(src2_hdl) + task_id*prim->in2_hdl_stride; - ICHECK(src2_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; - } + const void* src2_hdl = nullptr; + size_t src2_mb = 0; + size_t src2_mb_stride = 0; - const auto mb = (ctx->force_batch_size == -1) ? dst_mb : ctx->force_batch_size; - - // NB! Limitations - // * Do not use simple BNNSFilterApply. There is a bug inside BNNS, - // and BNNSFilterApply doesn't work for grouped convolution. - // * Group convolution doesn't support arbitrary stride for Batch dim. - // The tensor should be dense. - auto sts = (ctx->src2) - ? BNNSFilterApplyTwoInputBatch(filter, mb, - src1_hdl, src1_mb_stride, - src2_hdl, src2_mb_stride, - dst_hdl, dst_mb_stride) - : BNNSFilterApplyBatch(filter, mb, - src1_hdl, src1_mb_stride, - dst_hdl, dst_mb_stride); - - return sts; + if (ctx->src2) { + src2_hdl = ctx->src2->get_data_hdl(); + src2_mb = ctx->src2->get_mb(); + src2_mb_stride = ctx->src2->get_mb_stride(); + src2_hdl = static_cast(src2_hdl) + task_id*prim->in2_hdl_stride; + ICHECK(src2_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; } - private: - size_t num_filters = 0; - BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; + const auto mb = (ctx->force_batch_size == -1) ? dst_mb : ctx->force_batch_size; - // TODO: temporal solution with strides - size_t in1_hdl_stride = 0; - size_t in2_hdl_stride = 0; - size_t out_hdl_stride = 0; - }; -} + // WA + if (mb == 1) { + src1_mb_stride = prim->in1_hdl_stride / sizeof(float); + dst_mb_stride = prim->out_hdl_stride / sizeof(float); + } -struct BNNSConfig { + // NB! Limitations + // * Do not use simple BNNSFilterApply. There is a bug inside BNNS, + // and BNNSFilterApply doesn't work for grouped convolution. + // * Group convolution doesn't support arbitrary stride for Batch dim. + // The tensor should be dense. + auto sts = (ctx->src2) + ? BNNSFilterApplyTwoInputBatch(filter, mb, + src1_hdl, src1_mb_stride, + src2_hdl, src2_mb_stride, + dst_hdl, dst_mb_stride) + : BNNSFilterApplyBatch(filter, mb, + src1_hdl, src1_mb_stride, + dst_hdl, dst_mb_stride); + + return sts; + } + + private: + size_t num_filters = 0; + BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; + + // TODO(apeskov): temporal solution with strides + size_t in1_hdl_stride = 0; + size_t in2_hdl_stride = 0; + size_t out_hdl_stride = 0; +}; + +} // namespace BNNS + +struct BNNSThreadingConfig { /** * Internal parallelism level ov BNNS primitive specified via parameter. * Has no real control from TVM level, so in fact it may be ignored by @@ -326,7 +332,6 @@ struct BNNSConfig { }; class BNNSJSONRuntime : public JSONRuntimeBase { - public: BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) @@ -367,7 +372,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { int forceBatchSize = (force_batch_size_.find(i) == force_batch_size_.end()) ? -1 : force_batch_size_.at(i); - primitives_.at(i)->execute(args, *res, forceBatchSize); + primitives_.at(i)->execute(args, res.get(), forceBatchSize); } } @@ -404,7 +409,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { } // Bind a JSON graph node entry to a BNNS tensor. - std::shared_ptr BindBNNSTensor(const JSONGraphNodeEntry& entry, void *hdl = nullptr) { + std::shared_ptr BindBNNSTensor(const JSONGraphNodeEntry& entry, + void *hdl = nullptr) { auto eid = EntryID(entry); if (entry_out_mem_.count(eid) == 0) { auto data_node = nodes_[entry.id_]; @@ -427,7 +433,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { * @return collection of Convolution descriptors plus strides for input and output tensors */ static std::tuple, size_t, size_t> - split_into_n(const BNNSLayerParametersConvolution& orig_conv_param, size_t batch, size_t num) { + split_into_n(const BNNSLayerParametersConvolution& orig_conv_param, + size_t batch, size_t num) { size_t i_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; size_t o_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; size_t w_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; @@ -437,8 +444,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { size_t i_stride = 0; size_t o_stride = 0; - // TODO: In case of batch we can split through bach dimension. - // Meanwhile we just disable it... + // TODO(apeskov): In case of batch we can split through bach dimension. + // Meanwhile we just disable it... if (batch > 1) { return {{orig_conv_param}, 0, 0}; } @@ -582,7 +589,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto weights_md = BindBNNSTensor(weight_entry, weight_ext_data_hdl); std::shared_ptr bias_md; auto dst_md = BindBNNSTensor(dst_entry); - // TODO [apeskov]: check correctness of tensor shapes + // TODO(apeskov): check correctness of tensor shapes if (has_bias) { auto bias_entry = node.GetInputs()[2]; @@ -610,7 +617,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { weights_candidate.layout = BNNSDataLayoutConvolutionWeightsOIHW; bias_candidate.layout = BNNSDataLayoutVector; - // TODO [apeskov]: Tmp WA, broadcast bias is here with tailing [1, 1] + // TODO(apeskov): Tmp WA, broadcast bias is here with tailing [1, 1] if (bias_candidate.size[0] == 1 && bias_candidate.size[1] == 1 && one_of(bias_candidate.size[3], 1, 0) && std::all_of(bias_candidate.size + 4, bias_candidate.size + BNNS_MAX_TENSOR_DIMENSION, @@ -737,11 +744,11 @@ class BNNSJSONRuntime : public JSONRuntimeBase { b_desc.data = b_data; BNNSLayerParametersBroadcastMatMul layerParameters = { - 1, // alpha - 0, // beta - false, // transA - true, // transB - false, // quadratic + 1, // alpha + 0, // beta + false, // transA + true, // transB + false, // quadratic a_is_weighted, b_is_weighted, a_desc, From a9a9f108d73e0a1f51f63659264a381752f1e7af Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Tue, 19 Jan 2021 13:59:07 +0300 Subject: [PATCH 04/27] [BNNS] Refactoring. Add TView abstraction Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns.hpp | 407 +++++++++++ src/runtime/contrib/bnns/bnns_json_runtime.cc | 660 ++++-------------- 2 files changed, 536 insertions(+), 531 deletions(-) create mode 100644 src/runtime/contrib/bnns/bnns.hpp diff --git a/src/runtime/contrib/bnns/bnns.hpp b/src/runtime/contrib/bnns/bnns.hpp new file mode 100644 index 000000000000..dc071032250f --- /dev/null +++ b/src/runtime/contrib/bnns/bnns.hpp @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * \file + * \brief C++ wrappers and helpers to handle BNNS objects + */ + +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { +namespace BNNS { + +using Dim = size_t; +using Shape = std::vector; +using Dtype = BNNSDataType; +using HDL = void*; + +void* default_alloc(size_t size) { + return malloc(size); +} + +void default_free(void* ptr) { + free(ptr); +} + +/** + * Main abstraction for tensor representation + * + * Contains buffer handler and common attributes like shape and dtype. + */ +class Tensor { + public: + Tensor() = delete; + Tensor(Tensor&) = delete; + + Tensor(Shape shape, Dtype dtype, void* hdl) { + auto rank = shape.size(); + ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION); + + desc_ = { + BNNSNDArrayFlags(0), + getPlainLayout(rank), + {}, {}, // shape and strides + hdl, // data handler + dtype, // data type + nullptr, dtype, 1.f, 0.f // table_data (clustering case), is not used + }; + std::copy(shape.rbegin(), shape.rend(), std::begin(desc_.size)); + + if (hdl) { + desc_.data = hdl; + is_external_data = true; + } else { + const size_t buff_size = getSize(desc_) * getElementSize(desc_); + desc_.data = default_alloc(buff_size); + is_external_data = false; + } + } + + ~Tensor() { + if (desc_.data && !is_external_data) { + default_free(desc_.data); + desc_.data = nullptr; + } + } + + void* get_data_hdl() const { return desc_.data; } + + void set_data_hdl(void *hdl) { + if (desc_.data && !is_external_data) { + default_free(desc_.data); + desc_.data = nullptr; + } + + desc_.data = hdl; + is_external_data = true; + } + + const BNNSNDArrayDescriptor& get_desc() const { return desc_; } + + static BNNSDataLayout getPlainLayout(size_t rank) { + ICHECK(rank <= BNNS_MAX_TENSOR_DIMENSION); + return static_cast((rank << 16) | 0x8001); + } + + static size_t getRank(BNNSDataLayout layout) { + return (layout & 0xF0000) >> 16; + } + + static size_t getRank(BNNSNDArrayDescriptor desc) { + return getRank(desc.layout); + } + + static size_t getSize(BNNSNDArrayDescriptor desc) { + auto rank = getRank(desc); + return std::accumulate(desc.size, desc.size + rank, 1, std::multiplies()); + } + + /** return size of element in bytes */ + static size_t getElementSize(Dtype dtype) { + return (dtype & 0xFFFF) / 8; + } + + /** return size of element in bytes */ + static size_t getElementSize(const BNNSNDArrayDescriptor &desc) { + return getElementSize(desc.data_type); + } + + private: + bool is_external_data = false; + BNNSNDArrayDescriptor desc_; +}; + +using TensorPtr = std::shared_ptr; + +/** + * Tensor View object which represent how provided BNNS::Tensor will be considered + * + * The single BNNS::Tensor can be treated in different form depend on particular primitive + * expectation. More other some primitive supports only external form of batching. So we have + * some abstraction to describe how primitive will handle provided tensor. + * + * Batched View + * View with extracted dimension as external batch value + * example: Tensor [2, 3, 224, 224] -> View [3, 224, 224] with ext batch 2 + * + * Party View + * The collection of view on the same tensor, can be the same view or with some stride + * example: Tensor [6, 5, 3, 3] -> 3 x View [2, 5, 3, 3] with stride 45 + */ +class TView { + public: + /** Make view on provided tensor as is */ + static TView as_is(const TensorPtr &origin) { + TView res; + res.origin_ = origin; + res.view_desc_ = origin->get_desc(); + return res; + } + + /** Extract outer dimension to separate batch field. TView will became batched view */ + TView extract_outer_dim() const { + auto rank = Tensor::getRank(view_desc_); + TView res = *this; + res.batch_size_ = view_desc_.size[rank - 1]; + res.batch_stride_ = std::accumulate(view_desc_.size, view_desc_.size + rank - 1, + 1, std::multiplies<>()); + res.view_desc_.size[rank - 1] = 0; + res.view_desc_.layout = Tensor::getPlainLayout(rank - 1); + return res; + }; + + /** Squeeze all dims equal 1 */ + TView squeeze() const { + auto rank = Tensor::getRank(view_desc_); + size_t squeezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t squeezed_rank = 0; + for (int i = 0; i < rank; i++) + if (view_desc_.size[i] != 1) + squeezed_shape[squeezed_rank++] = view_desc_.size[i]; + + TView res = *this; + std::copy(squeezed_shape, squeezed_shape + squeezed_rank, res.view_desc_.size); + std::fill(res.view_desc_.size + squeezed_rank, res.view_desc_.size + rank, 0); + res.view_desc_.layout = Tensor::getPlainLayout(squeezed_rank); + return res; + }; + + /** Construct new TView with specified layout if it applicable */ + TView with_layout(BNNSDataLayout layout) const { + ICHECK_EQ(Tensor::getRank(view_desc_), Tensor::getRank(layout)); + + TView res = *this; + res.view_desc_.layout = layout; + return res; + } + + /** Construct party TView by splitting original TView into num parts */ + TView party_split_n(size_t num) const { + ICHECK_EQ(party_size_, 1); + + TView res = *this; + size_t rank = Tensor::getRank(view_desc_); + size_t size = Tensor::getSize(view_desc_); + res.party_size_ = num; + res.party_stride_ = size / num; + + if (res.batch_size_ != 1) { + res.batch_size_ /= num; + } else { + res.view_desc_.size[rank - 1] /= num; + res.batch_stride_ /= num; + } + return res; + } + + /** Construct party TView by duplicating original TView num times */ + TView party_duplicate_n(size_t num) const { + ICHECK_EQ(party_size_, 1); + + TView res = *this; + res.party_size_ = num; + res.party_stride_ = 0; + + return res; + } + + /** Return data buffer handler */ + HDL get_data_hdl() const { return view_desc_.data; } + + /** Return external batch dimension value */ + size_t get_batch_size() const { return batch_size_; } + + /** Return external batch dimension stride */ + size_t get_stride() const { return batch_stride_; } + + /** Return party element by index */ + TView operator [](size_t i) const { + ICHECK_LT(i, party_size_); + + TView res = *this; + res.party_size_ = 1; + if (origin_) { + auto hdl = reinterpret_cast(origin_->get_data_hdl()); + hdl += i * party_stride_ * Tensor::getElementSize(view_desc_.data_type); + res.view_desc_.data = hdl; + } + return res; + } + + /** Check if view is empty and doesn't relay to any tensor */ + operator bool() const { return view_desc_.data != nullptr; } + + /** Get BNNS descriptor for particular View. Batch and Party attributed are ignored. */ + const BNNSNDArrayDescriptor& get_bnns_view() const { + return view_desc_; + } + + private: + /** Original tensor object to view on */ + TensorPtr origin_; + + /** Batched view parameters */ + BNNSNDArrayDescriptor view_desc_ = {}; + size_t batch_size_ = 1; + size_t batch_stride_ = 0; + + /** Party representation parameters */ + size_t party_size_ = 1; + size_t party_stride_ = 0; +}; + +/** + * Wrapper on top of BNNSFilter and src/dst TensorView. + * + * Support decomposed representation of filter and can execute sub primitives in parallel. + */ +class Primitive { + public: + Primitive(const std::vector fs, + const TView& src, + const TView& dst) + : filters(fs), src_view(src), src2_view(), dst_view(dst) {} + + Primitive(const std::vector fs, + const TView& src, + const TView& src2, + const TView& dst) + : filters(fs), src_view(src), src2_view(src2), dst_view(dst) {} + + ~Primitive() { + for (auto &filter : filters) + if (filter) { + BNNSFilterDestroy(filter); + filter = nullptr; + } + } + + /** Execute primitive with using specified src/dst */ + void execute() { + auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); + ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + } + + private: + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto prim = reinterpret_cast(cdata); + const auto filter = prim->filters[task_id]; + const auto src_view = prim->src_view[task_id]; + const auto dst_view = prim->dst_view[task_id]; + TView src2_view; + if (prim->src2_view) + src2_view = prim->src2_view[task_id]; + + size_t mb = src_view.get_batch_size(); + + // NB! BNNS limitations + // * Do not use simple BNNSFilterApply. There is a bug inside BNNS, + // BNNSFilterApply doesn't work for grouped convolution. + // * Group convolution doesn't support arbitrary stride for Batch dim. + // The tensor should be dense. + auto sts = (prim->src2_view) + ? BNNSFilterApplyTwoInputBatch(filter, mb, + src_view.get_data_hdl(), src_view.get_stride(), + src2_view.get_data_hdl(), src2_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()) + : BNNSFilterApplyBatch(filter, mb, + src_view.get_data_hdl(), src_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()); + return sts; + } + + private: + /** BNNS kernels/filters collect which will execute primitive */ + std::vector filters = {}; + const TView src_view, src2_view; + const TView dst_view; +}; + +/** + * Function which split primitive into sub primitives to parallel execution + * + * @param num requested num of sub primitives + * @param orig_conv_param original convolution descriptor + * @param src_view source tensor view + * @param wgh_view weight tensor view + * @param b_view bias tensor view + * @param dst_view destination tensor view + * @param num number of part to split into + * @return collection of Convolution descriptors plus corresponding src/dst tensors view + */ +static std::tuple, TView, TView> split_to_n( + size_t num, + const BNNSLayerParametersConvolution& orig_conv_param, + const TView &src_view, + const TView &wgh_view, + const TView &b_view, + const TView &dst_view) { + size_t batch = src_view.get_batch_size(); + size_t groups = orig_conv_param.groups; + + BNNS::TView src_view_new; + BNNS::TView wgh_view_new; + BNNS::TView b_view_new; + BNNS::TView dst_view_new; + + // TODO(apeskov): Add split by batch dim. Meanwhile we just disable it... + if (batch > 1 || (groups > 1 && groups % num != 0)) { + return {{orig_conv_param}, src_view, dst_view}; + } + + // if groups > 1 split only by groups + // otherwise split inside one convolution by output channels + if (groups > 1) { + src_view_new = src_view.party_split_n(num); + groups = groups / num; + } else { + src_view_new = src_view.party_duplicate_n(num); + } + + wgh_view_new = wgh_view.party_split_n(num); + b_view_new = b_view.party_split_n(num); + dst_view_new = dst_view.party_split_n(num); + + std::vector res(num); + for (size_t i=0; i < num; i++) { + auto &cur = res[i]; + cur = orig_conv_param; + + cur.i_desc = src_view_new[i].get_bnns_view(); + cur.o_desc = dst_view_new[i].get_bnns_view(); + cur.w_desc = wgh_view_new[i].get_bnns_view(); + cur.bias = b_view_new[i].get_bnns_view(); + cur.groups = groups; + } + return {res, src_view_new, dst_view_new}; +} + +} // namespace BNNS +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 563f212ae369..cc13d10bca88 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -29,308 +29,62 @@ #include #include #include -#include #include #include "../json/json_node.h" #include "../json/json_runtime.h" -#include "Accelerate/Accelerate.h" - -#define BNNS_TMP_CONCURRENCY 2 -#define BNNS_MAX_CONCURRENCY 8 - -template -bool one_of(T1 arg1, T2 arg2) { - return arg1 == arg2; -} - -template -bool one_of(T1 arg1, T2 arg2, T... args) { - return arg1 == arg2 || one_of(arg1, args...); -} +#include "bnns.hpp" namespace tvm { namespace runtime { namespace contrib { -using namespace tvm::runtime; -using namespace tvm::runtime::json; - -/** C++ wrapper on top of original BNNS C api */ -namespace BNNS { -using Dim = size_t; -using Shape = std::vector; -using Dtype = BNNSDataType; - -void* default_alloc(size_t size) { - // TODO(apeskov): Clarify, should it have some alignment for better performance - // with SIMD execution.. may be TVMBackendAllocWorkspace is more preferable here. - // Note: Apple uses posix_memalign by default. - return malloc(size); -} - -void default_free(void* ptr) { - free(ptr); -} - -class Tensor { - public: - Tensor(Shape shape, Dtype dtype, void* hdl) - : real_shape(shape) { - ICHECK(shape.size() < BNNS_MAX_TENSOR_DIMENSION); - - if (hdl) { - data_handler = hdl; - is_external_data = true; - } else { - const size_t buff_size = getNumOfElements(shape) * getElementSize(dtype); - data_handler = default_alloc(buff_size); - is_external_data = false; - } - - bnns_nd_desc = { - BNNSNDArrayFlags(0), - getPlainLayout(shape), - {}, // shape - {}, // strides, empty value means use default dense strides - hdl, // data handler - dtype, // data type - nullptr, // table_data (clustering case), is not used - dtype, - 1.f, - 0.f - }; - std::copy(shape.rbegin(), shape.rend(), std::begin(bnns_nd_desc.size)); - } - - ~Tensor() { - if (data_handler && !is_external_data) { - default_free(data_handler); - data_handler = nullptr; - } - } - - void* get_data_hdl() { return data_handler; } - - const void* get_data_hdl() const { return data_handler; } - - void set_data_hdl(void *hdl) { - if (data_handler && !is_external_data) { - default_free(data_handler); - data_handler = nullptr; - } - - data_handler = hdl; - is_external_data = true; - } - - size_t get_mb() const { - return real_shape[0]; - } - - size_t get_mb_stride() const { - return std::accumulate(real_shape.begin() + 1, real_shape.end(), - 1, std::multiplies()); - } - - const BNNSNDArrayDescriptor get_nd_desc(size_t nd = 0) const { - auto original_nd = real_shape.size(); - // Ask of original descriptor - if (original_nd == nd || nd == 0) - return bnns_nd_desc; - - // As of desc with excluded batch - if (original_nd == nd + 1) { - auto res = bnns_nd_desc; - res.size[original_nd - 1] = 0; - res.layout = BNNSDataLayout3DLastMajor; // TODO(apeskov): hardcoded value. FIXME - return res; - } - LOG(FATAL) << "Unknown case of BNNS tensor interpretation"; - return bnns_nd_desc; - } - - private: - static BNNSDataLayout getPlainLayout(const Shape &shape) { - return getPlainLayout(shape.size()); - } - - static BNNSDataLayout getPlainLayout(size_t rank) { - switch (rank) { - case 1: return BNNSDataLayout1DFirstMajor; - case 2: return BNNSDataLayout2DFirstMajor; - case 3: return BNNSDataLayout3DFirstMajor; - case 4: return BNNSDataLayout4DFirstMajor; - case 5: return BNNSDataLayout5DFirstMajor; - case 6: return BNNSDataLayout6DFirstMajor; - case 7: return BNNSDataLayout7DFirstMajor; - case 8: return BNNSDataLayout8DFirstMajor; - default: - LOG(FATAL) << "Unsupported tensor rank : " << rank - << " Supported cases is only 1-8 "; - return static_cast(0); - } - } +using namespace ::tvm::runtime; +using namespace ::tvm::runtime::json; +using namespace ::tvm::runtime::contrib::BNNS; +struct ThreadingConfig { /** - * return size in byte of element of provided type - * @param dtype - * @return size of element in bytes - */ - static size_t getElementSize(Dtype dtype) { - return (dtype & 0xFFFF) / sizeof(uint8_t); - } - - static size_t getNumOfElements(const Shape &shape) { - return std::accumulate(shape.begin(), shape.end(), - 1, std::multiplies()); - } - - private: - Shape real_shape = {}; - void* data_handler = nullptr; - bool is_external_data = false; - - BNNSNDArrayDescriptor bnns_nd_desc; -}; - -class Primitive { -public: - explicit Primitive(BNNSFilter f) : num_filters(1), filters{f} {} - - explicit Primitive(BNNSFilter fs[BNNS_MAX_CONCURRENCY]) { - std::copy(fs, fs + BNNS_MAX_CONCURRENCY, filters); - for (int i = 0; i < BNNS_MAX_CONCURRENCY; i++) { - if (filters[i] == nullptr) { - num_filters = i; - break; - } - } - } - - ~Primitive() { - for (size_t i = 0; i < num_filters; i++) { - auto &filter = filters[i]; - if (filter) { - BNNSFilterDestroy(filter); - filter = nullptr; - } - } - } - - void execute(std::vector srcs, Tensor *dst, int forceBatchSize = -1) { - ICHECK_LE(srcs.size(), 2) << "Currently BNNS runtime supports primitives with only 1 or 2 " - "data inputs."; - - run_ctx ctx { this, srcs[0], nullptr, dst, forceBatchSize }; - if (srcs.size() > 1) - ctx.src2 = srcs[1]; - - auto res = TVMBackendParallelLaunch(run_task, &ctx, num_filters); - ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; - } - - void set_input_stride(size_t stride1, size_t stride2 = 0) { - in1_hdl_stride = stride1; - in2_hdl_stride = stride2; - } - void set_output_stride(size_t stride) { out_hdl_stride = stride; } - - private: - struct run_ctx { - Primitive *prim; - const Tensor *src1; - const Tensor *src2; - Tensor *dst; - const int force_batch_size; - }; - - static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { - auto ctx = reinterpret_cast(cdata); - const auto *prim = ctx->prim; - - const auto &filter = prim->filters[task_id]; - - auto src1_hdl = ctx->src1->get_data_hdl(); - auto dst_hdl = ctx->dst->get_data_hdl(); - - auto src1_mb = ctx->src1->get_mb(); - auto dst_mb = ctx->dst->get_mb(); - - auto src1_mb_stride = ctx->src1->get_mb_stride(); - auto dst_mb_stride = ctx->dst->get_mb_stride(); - - src1_hdl = static_cast(src1_hdl) + task_id*prim->in1_hdl_stride; - dst_hdl = static_cast(dst_hdl) + task_id*prim->out_hdl_stride; - - ICHECK(src1_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; - - const void* src2_hdl = nullptr; - size_t src2_mb = 0; - size_t src2_mb_stride = 0; - - if (ctx->src2) { - src2_hdl = ctx->src2->get_data_hdl(); - src2_mb = ctx->src2->get_mb(); - src2_mb_stride = ctx->src2->get_mb_stride(); - src2_hdl = static_cast(src2_hdl) + task_id*prim->in2_hdl_stride; - ICHECK(src2_mb == dst_mb) << "Mismatch of batch dimension of input/output tensors"; - } - - const auto mb = (ctx->force_batch_size == -1) ? dst_mb : ctx->force_batch_size; - - // WA - if (mb == 1) { - src1_mb_stride = prim->in1_hdl_stride / sizeof(float); - dst_mb_stride = prim->out_hdl_stride / sizeof(float); - } - - // NB! Limitations - // * Do not use simple BNNSFilterApply. There is a bug inside BNNS, - // and BNNSFilterApply doesn't work for grouped convolution. - // * Group convolution doesn't support arbitrary stride for Batch dim. - // The tensor should be dense. - auto sts = (ctx->src2) - ? BNNSFilterApplyTwoInputBatch(filter, mb, - src1_hdl, src1_mb_stride, - src2_hdl, src2_mb_stride, - dst_hdl, dst_mb_stride) - : BNNSFilterApplyBatch(filter, mb, - src1_hdl, src1_mb_stride, - dst_hdl, dst_mb_stride); - - return sts; - } - - private: - size_t num_filters = 0; - BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; - - // TODO(apeskov): temporal solution with strides - size_t in1_hdl_stride = 0; - size_t in2_hdl_stride = 0; - size_t out_hdl_stride = 0; -}; - -} // namespace BNNS - -struct BNNSThreadingConfig { - /** - * Internal parallelism level ov BNNS primitive specified via parameter. - * Has no real control from TVM level, so in fact it may be ignored by - * implementation. + * Internal parallelism level ov BNNS primitive specified via BNNSFilterParameters + * struct. BNNS doesn't provide real control of internal threading, so it may be + * ignored by BNNS implementation. + * + * Valid values: + * 0 use default num of threads suggested by BNNS implementation + * >0 suggests to use this num of internal BNNS threads */ - int internalConcurrency = 0; + size_t internalConcurrency = 0; /** - * TVM level parallelism for BNNS primitive. In case if BNNS doesn't support - * internal parallelism we can add it by splitting primitive into independent - * parts and run it in parallel. May provide additional performance. + * TVM level parallelism for BNNS runtime. + * BNNS runtime will split primitive into set of independent sub primitives which + * can be executed in parallel. As a rule the splitting are performed through output + * channels, so the effective shape of executed primitive is changed. + * + * Valid values: + * 0 do not use graph level treading + * >0 split into this num of primitives */ - int externalConcurrency = 0; + size_t externalConcurrency = 0; }; +/** + * Depends on platform hardware the optimal ThreadingConfig may differ. + * This function contains a priori knowledge about some Apple platforms + * and their specific. + * + * @return default ThreadingConfig suggested for this platform + */ +ThreadingConfig getDefaultThreadingConfig() { + // TODO(apeskov): have to implement CPU/iOS version check. + // meanwhile will use {0, 2} stub to utilize big cores of A13/A14 CPU. + return {0, 2}; +} + +/** + * Main entry point to BNNS runtime + */ class BNNSJSONRuntime : public JSONRuntimeBase { public: BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json, @@ -340,40 +94,30 @@ class BNNSJSONRuntime : public JSONRuntimeBase { const char* type_key() const override { return "bnns_json"; } void Init(const Array& consts) override { + ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + SetupConstants(consts); BuildEngine(); - - ICHECK_EQ(consts.size(), const_idx_.size()) - << "The number of input constants must match the number of required."; } void Run() override { // Wrap external handler into BNNS tensor representation auto bind_ext_hdl_to_tensor = [this] (uint32_t eid) { const auto &ext_dlt = *data_entry_[eid]; - auto &bnns_tensor = *entry_out_mem_[eid]; - bnns_tensor.set_data_hdl(ext_dlt.data); + auto &bnns_tensor = entry_out_mem_.at(eid); + bnns_tensor->set_data_hdl(ext_dlt.data); }; // Bind all input/output external data object into internal abstractions - for (const auto &eid : input_var_eid_) { + for (const auto &eid : input_var_eid_) bind_ext_hdl_to_tensor(eid); - } - for (const auto &out_entity : outputs_) { + for (const auto &out_entity : outputs_) bind_ext_hdl_to_tensor(EntryID(out_entity)); - } // Invoke primitives in topological order - for (int i = 0; i < primitives_.size(); ++i) { - auto res = entry_out_mem_.at(prim_results_[i]); - std::vector args; - for (auto arg_id : prim_args_[i]) - args.push_back(entry_out_mem_.at(arg_id).get()); - - int forceBatchSize = - (force_batch_size_.find(i) == force_batch_size_.end()) ? -1 : force_batch_size_.at(i); - primitives_.at(i)->execute(args, res.get(), forceBatchSize); - } + for (const auto &prim : primitives_) + prim->execute(); } private: @@ -424,125 +168,16 @@ class BNNSJSONRuntime : public JSONRuntimeBase { return entry_out_mem_[eid]; } - /** - * Function which split primitive into sub primitives to parallel execution - * - * @param orig_conv_param descriptor of original convolution - * @param batch batch value - * @param num number of part to split into. - * @return collection of Convolution descriptors plus strides for input and output tensors - */ - static std::tuple, size_t, size_t> - split_into_n(const BNNSLayerParametersConvolution& orig_conv_param, - size_t batch, size_t num) { - size_t i_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; - size_t o_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; - size_t w_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; - size_t b_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; - size_t w_stride = 0; - size_t b_stride = 0; - size_t i_stride = 0; - size_t o_stride = 0; - - // TODO(apeskov): In case of batch we can split through bach dimension. - // Meanwhile we just disable it... - if (batch > 1) { - return {{orig_conv_param}, 0, 0}; - } - - auto groups = orig_conv_param.groups; - // if groups > 1 split only by groups - // otherwise split inside one convolution by output channels - if (groups > 1) { - // fallback into sequential execution - if (groups % num != 0) - return {{orig_conv_param}, 0, 0}; - - std::copy(orig_conv_param.i_desc.size, orig_conv_param.i_desc.size + 3, i_shape); - std::copy(orig_conv_param.o_desc.size, orig_conv_param.o_desc.size + 3, o_shape); - std::copy(orig_conv_param.w_desc.size, orig_conv_param.w_desc.size + 4, w_shape); - std::copy(orig_conv_param.bias.size, orig_conv_param.bias.size + 1, b_shape); - - auto orig_w_buff_size = std::accumulate(w_shape, w_shape + 4, 1, std::multiplies()) - * sizeof(float); - - auto orig_b_buff_size = std::accumulate(b_shape, b_shape + 1, 1, std::multiplies()) - * sizeof(float); - - auto orig_i_buff_size = std::accumulate(i_shape, i_shape + 3, 1, std::multiplies()) - * sizeof(float); - - auto orig_o_buff_size = std::accumulate(o_shape, o_shape + 3, 1, std::multiplies()) - * sizeof(float); - - i_shape[2] /= num; - o_shape[2] /= num; - w_shape[3] /= num; - b_shape[0] /= num; - - w_stride = orig_w_buff_size / num; - b_stride = orig_b_buff_size / num; - i_stride = orig_i_buff_size / num; - o_stride = orig_o_buff_size / num; - groups = groups / num; - } else { - std::copy(orig_conv_param.i_desc.size, orig_conv_param.i_desc.size + 3, i_shape); - std::copy(orig_conv_param.o_desc.size, orig_conv_param.o_desc.size + 3, o_shape); - std::copy(orig_conv_param.w_desc.size, orig_conv_param.w_desc.size + 4, w_shape); - std::copy(orig_conv_param.bias.size, orig_conv_param.bias.size + 1, b_shape); - - auto orig_w_buff_size = std::accumulate(w_shape, w_shape + 4, 1, std::multiplies()) - * sizeof(float); - - auto orig_b_buff_size = std::accumulate(b_shape, b_shape + 1, 1, std::multiplies()) - * sizeof(float); - -// auto orig_i_buff_size = std::accumulate(i_shape, i_shape + 3, 1, std::multiplies()) -// * sizeof(float); - - auto orig_o_buff_size = std::accumulate(o_shape, o_shape + 3, 1, std::multiplies()) - * sizeof(float); - - o_shape[2] /= num; - w_shape[3] /= num; - b_shape[0] /= num; - - w_stride = orig_w_buff_size / num; - b_stride = orig_b_buff_size / num; - i_stride = 0; - o_stride = orig_o_buff_size / num; - } - - std::vector res(num); - for (size_t i=0; i < num; i++) { - auto &cur = res[i]; - cur = orig_conv_param; - - std::copy(i_shape, i_shape + 3, cur.i_desc.size); - std::copy(o_shape, o_shape + 3, cur.o_desc.size); - std::copy(w_shape, w_shape + 4, cur.w_desc.size); - std::copy(b_shape, b_shape + 1, cur.bias.size); - - cur.w_desc.data = static_cast(cur.w_desc.data) + w_stride * i; - if (cur.bias.data) - cur.bias.data = static_cast(cur.bias.data) + b_stride * i; - - cur.groups = groups; - } - return {res, i_stride, o_stride}; - } - - void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) { auto node = nodes_[nid]; // Setup attributes. auto src_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; + auto wgh_entry = node.GetInputs()[1]; auto dst_entry = JSONGraphNodeEntry(nid, 0); auto dl_input_shape = nodes_[src_entry.id_].GetOpShape()[src_entry.index_]; - auto dl_weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + auto dl_weight_shape = nodes_[wgh_entry.id_].GetOpShape()[wgh_entry.index_]; BNNS::Shape input_shape {dl_input_shape.begin(), dl_input_shape.end()}; BNNS::Shape weight_shape {dl_weight_shape.begin(), dl_weight_shape.end()}; std::vector str_strides = node.GetAttr>("strides"); @@ -550,13 +185,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector str_padding = node.GetAttr>("padding"); BNNS::Dim groups = std::stoi(node.GetAttr>("groups")[0]); - BNNS::Dim N = input_shape[0], // batch size - IC = input_shape[1], // input channels - IH = input_shape[2], // input height - IW = input_shape[2], // input width - OC = weight_shape[0], // output channels - KH = weight_shape[2], // weight height - KW = weight_shape[3], // weight width + BNNS::Dim PH_L = std::stoi(str_padding[0]), // height padding: left PH_R = std::stoi(str_padding[2]), // height padding: right PW_L = std::stoi(str_padding[1]), // width padding: left @@ -564,20 +193,9 @@ class BNNSJSONRuntime : public JSONRuntimeBase { SH = std::stoi(str_strides[0]), // height-wise stride SW = std::stoi(str_strides[1]), // weight-wise stride DH = std::stoi(str_dilation[0]), // height kernel dilation - DW = std::stoi(str_dilation[1]), // width kernel dilation - OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height - OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width - - // Memory shapes. - BNNS::Shape src_dims = {N, IC, IH, IW}; - BNNS::Shape weights_dims = {OC, IC, KH, KW}; - BNNS::Shape bias_dims = {OC}; - BNNS::Shape dst_dims = {N, OC, OH, OW}; - BNNS::Shape strides_dims = {SH, SW}; - BNNS::Shape padding_dims_l = {PH_L, PW_L}; - BNNS::Shape padding_dims_r = {PH_R, PW_R}; - - auto weight_data_entry = data_entry_[EntryID(weight_entry)]; + DW = std::stoi(str_dilation[1]); // width kernel dilation + + auto weight_data_entry = data_entry_[EntryID(wgh_entry)]; ICHECK(weight_data_entry) << "Convolution weights tensor should be constant and " "available on initialization stage. Looks like weights " "are not result of constant expression."; @@ -585,11 +203,14 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto weight_ext_data_hdl = weight_data_entry->data; // Memory descriptions. - auto src_md = BindBNNSTensor(src_entry); - auto weights_md = BindBNNSTensor(weight_entry, weight_ext_data_hdl); - std::shared_ptr bias_md; - auto dst_md = BindBNNSTensor(dst_entry); - // TODO(apeskov): check correctness of tensor shapes + const auto &src_t = BindBNNSTensor(src_entry); + const auto &wgh_t = BindBNNSTensor(wgh_entry, weight_ext_data_hdl); + const auto &dst_t = BindBNNSTensor(dst_entry); + + auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutConvolutionWeightsOIHW); + auto dst_view = TView::as_is(dst_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + TView bias_view; if (has_bias) { auto bias_entry = node.GetInputs()[2]; @@ -599,39 +220,19 @@ class BNNSJSONRuntime : public JSONRuntimeBase { "is not result of constant expression."; auto bias_data_hdl = bias_data_entry->data; - bias_md = BindBNNSTensor(bias_entry, bias_data_hdl); - } else { - bias_md = std::make_shared(BNNS::Shape {OC}, BNNSDataTypeFloat32, nullptr); + auto bias_t = BindBNNSTensor(bias_entry, bias_data_hdl); + bias_view = TView::as_is(bias_t).squeeze().with_layout(BNNSDataLayoutVector); } BNNSActivation activation = { has_relu ? - BNNSActivationFunctionRectifiedLinear : - BNNSActivationFunctionIdentity }; - - auto src_candidate = src_md->get_nd_desc(3); - auto weights_candidate = weights_md->get_nd_desc(); - auto dst_candidate = dst_md->get_nd_desc(3); - auto bias_candidate = bias_md->get_nd_desc(); - src_candidate.layout = BNNSDataLayoutImageCHW; - dst_candidate.layout = BNNSDataLayoutImageCHW; - weights_candidate.layout = BNNSDataLayoutConvolutionWeightsOIHW; - bias_candidate.layout = BNNSDataLayoutVector; - - // TODO(apeskov): Tmp WA, broadcast bias is here with tailing [1, 1] - if (bias_candidate.size[0] == 1 && bias_candidate.size[1] == 1 && - one_of(bias_candidate.size[3], 1, 0) && - std::all_of(bias_candidate.size + 4, bias_candidate.size + BNNS_MAX_TENSOR_DIMENSION, - [] ( size_t d) { return d == 0; })) { - auto element_count = bias_candidate.size[2]; - std::fill(bias_candidate.size, bias_candidate.size + BNNS_MAX_TENSOR_DIMENSION, 0); - bias_candidate.size[0] = element_count; - } + BNNSActivationFunctionRectifiedLinear : + BNNSActivationFunctionIdentity }; BNNSLayerParametersConvolution conv_param = { - src_candidate, - weights_candidate, - dst_candidate, - bias_candidate, + src_view.get_bnns_view(), + wgh_view.get_bnns_view(), + dst_view.get_bnns_view(), + bias_view.get_bnns_view(), activation, SW, /* x_stride */ SH, /* y_stride */ @@ -643,22 +244,20 @@ class BNNSJSONRuntime : public JSONRuntimeBase { {PW_L, PW_R, PH_L, PH_R} /* explicit pad values */ }; - BNNSFilter filters[BNNS_MAX_CONCURRENCY] = {}; - + size_t num_sub_prim = default_thread_config.externalConcurrency; std::vector params; - size_t i_stride, o_stride; - std::tie(params, i_stride, o_stride) = split_into_n(conv_param, N, BNNS_TMP_CONCURRENCY); + std::tie(params, src_view, dst_view) = split_to_n(num_sub_prim, conv_param, + src_view, wgh_view, bias_view, dst_view); + + std::vector filters (params.size(), nullptr); for (int i = 0; i < params.size(); i++) { + auto common_filter_param = getCommonFilterParams(); filters[i] = BNNSFilterCreateLayerConvolution(¶ms[i], &common_filter_param); ICHECK(filters[i]) << "BNNS primitive was not created. Unsupported attributes configuration"; } - primitives_.emplace_back(std::make_shared(filters)); - primitives_.back()->set_input_stride(i_stride); - primitives_.back()->set_output_stride(o_stride); - - prim_args_.push_back({EntryID(src_entry)}); - prim_results_.push_back({EntryID(dst_entry)}); + primitives_.emplace_back( + std::make_shared(filters, src_view, dst_view)); } void Dense(const size_t& nid, const bool has_bias = false, const bool has_gelu = false) { @@ -671,26 +270,22 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto w_data = data_entry_[EntryID(weight_entry)]->data; // Memory descriptions. - auto src_md = BindBNNSTensor(src_entry); - auto weights_md = BindBNNSTensor(weight_entry, w_data); - auto dst_md = BindBNNSTensor(dst_entry); - - BNNSNDArrayDescriptor in_desc = src_md->get_nd_desc(1); - BNNSNDArrayDescriptor w_desc = weights_md->get_nd_desc(2); - BNNSNDArrayDescriptor out_desc = dst_md->get_nd_desc(1); - w_desc.layout = BNNSDataLayoutRowMajorMatrix; - in_desc.layout = BNNSDataLayoutVector; - out_desc.layout = BNNSDataLayoutVector; - w_desc.data = w_data; - BNNSNDArrayDescriptor bias = {}; + auto src_t = BindBNNSTensor(src_entry); + auto wgh_t = BindBNNSTensor(weight_entry, w_data); + auto dst_t = BindBNNSTensor(dst_entry); + + auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutVector); + auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutRowMajorMatrix); + auto dst_view = TView::as_is(dst_t).extract_outer_dim().with_layout(BNNSDataLayoutVector); + + TView bias_view; if (has_bias) { auto bias_entry = node.GetInputs()[2]; auto bias_data = data_entry_[EntryID(bias_entry)]->data; auto bias_md = BindBNNSTensor(bias_entry, bias_data); - bias = bias_md->get_nd_desc(); - bias.layout = BNNSDataLayoutVector; - bias.data = bias_data; + bias_view = TView::as_is(bias_md).with_layout(BNNSDataLayoutVector); } + BNNSActivation activation = {BNNSActivationFunctionIdentity}; if (has_gelu) { activation = {BNNSActivationFunctionGELUApproximation}; @@ -699,18 +294,19 @@ class BNNSJSONRuntime : public JSONRuntimeBase { } BNNSLayerParametersFullyConnected layerParameters = { - in_desc, - w_desc, - out_desc, - bias, + src_view.get_bnns_view(), + wgh_view.get_bnns_view(), + dst_view.get_bnns_view(), + bias_view.get_bnns_view(), activation, }; + auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerFullyConnected(&layerParameters, &common_filter_param); ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; - primitives_.emplace_back(std::make_shared(filter)); - prim_args_.push_back({EntryID(src_entry)}); - prim_results_.push_back({EntryID(dst_entry)}); + std::vector filters = {filter}; + primitives_.emplace_back( + std::make_shared(filters, src_view, dst_view)); } void MatMul(const size_t& nid) { @@ -730,18 +326,13 @@ class BNNSJSONRuntime : public JSONRuntimeBase { if (b_is_weighted) b_data = data_entry_[EntryID(b_entry)]->data; // Memory descriptions. - auto a_md = BindBNNSTensor(a_entry, a_data); - auto b_md = BindBNNSTensor(b_entry, b_data); - auto dst_md = BindBNNSTensor(dst_entry); - - BNNSNDArrayDescriptor a_desc = a_md->get_nd_desc(); - BNNSNDArrayDescriptor b_desc = b_md->get_nd_desc(); - BNNSNDArrayDescriptor out_desc = dst_md->get_nd_desc(); - std::reverse(a_desc.size, a_desc.size + 3); - std::reverse(b_desc.size, b_desc.size + 3); - std::reverse(out_desc.size, out_desc.size + 3); - a_desc.data = a_data; - b_desc.data = b_data; + auto a_t = BindBNNSTensor(a_entry, a_data); + auto b_t = BindBNNSTensor(b_entry, b_data); + auto dst_t = BindBNNSTensor(dst_entry); + + auto a_view = TView::as_is(a_t); + auto b_view = TView::as_is(b_t); + auto dst_view = TView::as_is(dst_t); BNNSLayerParametersBroadcastMatMul layerParameters = { 1, // alpha @@ -751,22 +342,28 @@ class BNNSJSONRuntime : public JSONRuntimeBase { false, // quadratic a_is_weighted, b_is_weighted, - a_desc, - b_desc, - out_desc + a_view.get_bnns_view(), + b_view.get_bnns_view(), + dst_view.get_bnns_view() }; + // BNNS limitation: MatMul use reverse dims values. However strides are calculated correctly + // based on BNNSNDArrayDescriptor::layout value. + std::reverse(layerParameters.iA_desc.size, layerParameters.iA_desc.size + 3); + std::reverse(layerParameters.iB_desc.size, layerParameters.iB_desc.size + 3); + std::reverse(layerParameters.o_desc.size, layerParameters.o_desc.size + 3); + + auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerBroadcastMatMul(&layerParameters, &common_filter_param); ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; - primitives_.emplace_back(std::make_shared(filter)); - std::vector args; - if (!a_is_weighted) - args.push_back(EntryID(a_entry)); - if (!b_is_weighted) - args.push_back(EntryID(b_entry)); - prim_args_.push_back(std::move(args)); - prim_results_.push_back(EntryID(dst_entry)); - force_batch_size_.insert({prim_args_.size() - 1, 1}); + + std::vector filters {filter}; + if (a_is_weighted || b_is_weighted) { + auto src_view = a_is_weighted ? b_view : a_view; + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); + } else { + primitives_.emplace_back(std::make_shared(filters, a_view, b_view, dst_view)); + } } BNNS::Dtype convertToBNNS(const DLDataType &dl_dtype) { @@ -788,14 +385,15 @@ class BNNSJSONRuntime : public JSONRuntimeBase { return BNNS::Dtype(0); } - // TODO(apeskov): Allow to specify num of threads and keep customer buffers. - // Should investigate this attributes. - BNNSFilterParameters common_filter_param {}; + BNNSFilterParameters getCommonFilterParams() { + // NOTE: To force weights tensor copy on stage of filter create + // just change : BNNSFlagsUseClientPtr -> 0 + return { BNNSFlagsUseClientPtr, default_thread_config.internalConcurrency }; + } + + const ThreadingConfig default_thread_config = getDefaultThreadingConfig(); std::vector> primitives_; - std::vector> prim_args_; - std::vector prim_results_; - std::unordered_map force_batch_size_; /* The entry ID to its corresponding output memory. */ std::unordered_map> entry_out_mem_; @@ -813,6 +411,6 @@ TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate") TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") .set_body_typed(BNNSJSONRuntime::LoadFromBinary); -} // namespace contrib +} } // namespace runtime } // namespace tvm From ca72944879853249523edd857c2029eeb1eeb6ac Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Tue, 19 Jan 2021 14:00:04 +0300 Subject: [PATCH 05/27] [BNNS] Add several more onnx topologies into tests Signed-off-by: Alexander Peskov --- tests/python/contrib/test_bnns/test_conv2d.py | 1 + .../contrib/test_bnns/test_onnx_topologies.py | 135 ++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 tests/python/contrib/test_bnns/test_onnx_topologies.py diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index 5dbeddcc5573..9a00269d2869 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -184,6 +184,7 @@ def test_conv2d_dw(): outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) config = { + "batch": batch, "shape": shape, "kernel size": kernel, "padding": pad, diff --git a/tests/python/contrib/test_bnns/test_onnx_topologies.py b/tests/python/contrib/test_bnns/test_onnx_topologies.py new file mode 100644 index 000000000000..ebc255830811 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_onnx_topologies.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS pattern detection check""" + +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.contrib import utils, graph_runtime +from tvm.contrib.download import download_testdata +from tvm.relay.op.contrib.bnns import partition_for_bnns + +import onnx +import numpy as np +import pytest + +bnns_is_absent = tvm.get_global_func("relay.ext.bnns", True) is None + +TARGET = "llvm" +INPUT_SHAPE = [1, 3, 224, 224] + +BASE_MODEL_URL = "https://github.com/onnx/models/raw/master/" +MODEL_URL_COLLECTION = { + "BERT": "text/machine_comprehension/bert-squad/model/bertsquad-10.onnx", + "MobileNet-v2": "vision/classification/mobilenet/model/mobilenetv2-7.onnx", + "ResNet50-v1": "vision/classification/resnet/model/resnet50-v1-7.onnx", + "ResNet50-v2": "vision/classification/resnet/model/resnet50-v2-7.onnx", + "SqueezeNet-v1.1": "vision/classification/squeezenet/model/squeezenet1.1-7.onnx", + "SqueezeNet-v1.0": "vision/classification/squeezenet/model/squeezenet1.0-7.onnx", + "Inception-v1": "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx", + "Inception-v2": "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx", +} + + +def get_onnx_input_name(model): + inputs = [node.name for node in model.graph.input] + initializer = [node.name for node in model.graph.initializer] + + inputs = list(set(inputs) - set(initializer)) + return inputs + + +def get_model_url(model_name): + return BASE_MODEL_URL + MODEL_URL_COLLECTION[model_name] + + +def get_name_from_url(url): + return url[url.rfind('/') + 1:].strip() + + +def find_of_download(model_name): + model_url = get_model_url(model_name) + model_file_name = get_name_from_url(model_url) + return download_testdata(model_url, model_file_name, module="models") + + +def get_model(model_name): + model_path = find_of_download(model_name) + onnx_model = onnx.load(model_path) + input_names = get_onnx_input_name(onnx_model) + input_dict = {} + for name in input_names: + input_dict[name] = INPUT_SHAPE # TODO: hardcode + mod, params = relay.frontend.from_onnx(onnx_model, input_dict, freeze_params=True) + return mod, params, input_dict + + +def simplify_model(mod): + """ + Simplify execution graph + + At least merge BatchNorm into convolution. For this purpose decompose BN primitive + into simple operation which can be calculated as const expr and after that merged + into nearest conv/dense primitive. + """ + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.FoldConstant(), + transform.SimplifyInference(), + transform.FoldScaleAxis(), + ] + ) + return seq(mod) + + +def process(model_name): + temp = utils.tempdir() + model, params, input_dict = get_model(model_name) + + def run(mod, target, simplify=True, with_bnns=False): + with tvm.transform.PassContext(opt_level=3): + if simplify: + mod = simplify_model(mod) + if with_bnns: + mod = partition_for_bnns(mod) + graph_module = relay.build(mod, target=target, target_host=target, params=params) + + lib_name = "deploy.tar" + path_dso = temp.relpath(lib_name) + graph_module.export_library(path_dso) + + ctx = tvm.cpu(0) + loaded_lib = tvm.runtime.load_module(path_dso) + + module = graph_runtime.GraphModule(loaded_lib["default"](ctx)) + module.run() + return module.get_output(0).asnumpy() + + res_llvm = run(model, TARGET, simplify=True, with_bnns=False) + res_bnns = run(model, TARGET, simplify=True, with_bnns=True) + + tvm.testing.assert_allclose( + res_llvm, res_bnns, atol=0.002, rtol=0.007, + ) + + +@pytest.mark.skip(reason="Manually disabled because of huge complexity") +@pytest.mark.skipif(bnns_is_absent, reason="BNNS runtime is absent") +@pytest.mark.parametrize("model_name", MODEL_URL_COLLECTION.keys()) +def test_topology(model_name): + process(model_name) From c3bf919f8e232db1b56867b908fc05f6e410e8da Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 25 Jan 2021 18:43:38 +0300 Subject: [PATCH 06/27] [BNNS] Avoid redundant tensor allocation Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns.hpp | 22 ++-- src/runtime/contrib/bnns/bnns_json_runtime.cc | 112 ++++++++++-------- 2 files changed, 78 insertions(+), 56 deletions(-) diff --git a/src/runtime/contrib/bnns/bnns.hpp b/src/runtime/contrib/bnns/bnns.hpp index dc071032250f..62ee9289597b 100644 --- a/src/runtime/contrib/bnns/bnns.hpp +++ b/src/runtime/contrib/bnns/bnns.hpp @@ -72,14 +72,8 @@ class Tensor { }; std::copy(shape.rbegin(), shape.rend(), std::begin(desc_.size)); - if (hdl) { - desc_.data = hdl; - is_external_data = true; - } else { - const size_t buff_size = getSize(desc_) * getElementSize(desc_); - desc_.data = default_alloc(buff_size); - is_external_data = false; - } + desc_.data = hdl; + is_external_data = true; } ~Tensor() { @@ -89,6 +83,16 @@ class Tensor { } } + void allocate_memory() { + if (desc_.data && !is_external_data) { + default_free(desc_.data); + } + const size_t buff_size = getSize(desc_) * getElementSize(desc_); + desc_.data = default_alloc(buff_size); + ICHECK(desc_.data); + is_external_data = false; + } + void* get_data_hdl() const { return desc_.data; } void set_data_hdl(void *hdl) { @@ -254,7 +258,7 @@ class TView { } /** Check if view is empty and doesn't relay to any tensor */ - operator bool() const { return view_desc_.data != nullptr; } + operator bool() const { return origin_ != nullptr; } /** Get BNNS descriptor for particular View. Batch and Party attributed are ignored. */ const BNNSNDArrayDescriptor& get_bnns_view() const { diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index cc13d10bca88..dad871b2f42b 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -98,6 +98,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { << "The number of input constants must match the number of required."; SetupConstants(consts); + BindInputsAndOutputs(); + AllocateIntermediateTensors(); BuildEngine(); } @@ -105,7 +107,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { // Wrap external handler into BNNS tensor representation auto bind_ext_hdl_to_tensor = [this] (uint32_t eid) { const auto &ext_dlt = *data_entry_[eid]; - auto &bnns_tensor = entry_out_mem_.at(eid); + auto &bnns_tensor = tensors_eid_[eid]; bnns_tensor->set_data_hdl(ext_dlt.data); }; @@ -121,6 +123,47 @@ class BNNSJSONRuntime : public JSONRuntimeBase { } private: + /** Make corresponding input/output tensor stubs */ + void BindInputsAndOutputs() { + tensors_eid_.resize(data_entry_.size()); + auto createTensor = [&](JSONGraphNodeEntry entry) { + auto node = nodes_[entry.id_]; + auto dlshape = node.GetOpShape()[entry.index_]; + auto dltype = node.GetOpDataType()[entry.index_]; + void *data = nullptr; + if (data_entry_[entry.id_] != nullptr) + data = data_entry_[entry.id_]->data; + tensors_eid_[entry.id_] = + std::make_shared(BNNS::Shape{dlshape.begin(), dlshape.end()}, + convertToBNNS(dltype), data); + }; + + for (auto& id : input_nodes_) { + auto eid = JSONGraphNodeEntry(id, 0); + createTensor(eid); + } + + for (auto entry : outputs_) { + createTensor(entry); + } + } + + /** Allocate intermediate tensors */ + void AllocateIntermediateTensors() { + for (int i = 0; i < nodes_.size(); ++i) { + auto eid = JSONGraphNodeEntry(i, 0); + if (tensors_eid_[eid.id_] != nullptr) + continue; + auto node = nodes_[i]; + auto dlshape = node.GetOpShape()[0]; + auto dltype = node.GetOpDataType()[0]; + tensors_eid_[eid.id_] = + std::make_shared(BNNS::Shape{dlshape.begin(), dlshape.end()}, + convertToBNNS(dltype), nullptr); + tensors_eid_[eid.id_]->allocate_memory(); + } + } + // Build up the engine based on the input graph. void BuildEngine() { // Build subgraph engine. @@ -152,20 +195,11 @@ class BNNSJSONRuntime : public JSONRuntimeBase { } } - // Bind a JSON graph node entry to a BNNS tensor. - std::shared_ptr BindBNNSTensor(const JSONGraphNodeEntry& entry, - void *hdl = nullptr) { + // Get BNNS tensor. + std::shared_ptr GetBNNSTensor(const JSONGraphNodeEntry& entry) { auto eid = EntryID(entry); - if (entry_out_mem_.count(eid) == 0) { - auto data_node = nodes_[entry.id_]; - auto dlshape = data_node.GetOpShape()[entry.index_]; - auto dltype = data_node.GetOpDataType()[entry.index_]; - - entry_out_mem_[eid] = std::make_shared( - BNNS::Shape{dlshape.begin(), dlshape.end()}, - convertToBNNS(dltype), hdl); - } - return entry_out_mem_[eid]; + ICHECK(eid < tensors_eid_.size()); + return tensors_eid_[eid]; } void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) { @@ -195,17 +229,10 @@ class BNNSJSONRuntime : public JSONRuntimeBase { DH = std::stoi(str_dilation[0]), // height kernel dilation DW = std::stoi(str_dilation[1]); // width kernel dilation - auto weight_data_entry = data_entry_[EntryID(wgh_entry)]; - ICHECK(weight_data_entry) << "Convolution weights tensor should be constant and " - "available on initialization stage. Looks like weights " - "are not result of constant expression."; - - auto weight_ext_data_hdl = weight_data_entry->data; - // Memory descriptions. - const auto &src_t = BindBNNSTensor(src_entry); - const auto &wgh_t = BindBNNSTensor(wgh_entry, weight_ext_data_hdl); - const auto &dst_t = BindBNNSTensor(dst_entry); + const auto &src_t = GetBNNSTensor(src_entry); + const auto &wgh_t = GetBNNSTensor(wgh_entry); + const auto &dst_t = GetBNNSTensor(dst_entry); auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutConvolutionWeightsOIHW); @@ -214,13 +241,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { if (has_bias) { auto bias_entry = node.GetInputs()[2]; - auto bias_data_entry = data_entry_[EntryID(bias_entry)]; - ICHECK(bias_data_entry) << "Convolution bias tensor should be constant and " - "available on initialization stage. Looks like bias " - "is not result of constant expression."; - auto bias_data_hdl = bias_data_entry->data; - auto bias_t = BindBNNSTensor(bias_entry, bias_data_hdl); + auto bias_t = GetBNNSTensor(bias_entry); bias_view = TView::as_is(bias_t).squeeze().with_layout(BNNSDataLayoutVector); } @@ -268,11 +290,10 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto weight_entry = node.GetInputs()[1]; auto dst_entry = JSONGraphNodeEntry(nid, 0); - auto w_data = data_entry_[EntryID(weight_entry)]->data; // Memory descriptions. - auto src_t = BindBNNSTensor(src_entry); - auto wgh_t = BindBNNSTensor(weight_entry, w_data); - auto dst_t = BindBNNSTensor(dst_entry); + auto src_t = GetBNNSTensor(src_entry); + auto wgh_t = GetBNNSTensor(weight_entry); + auto dst_t = GetBNNSTensor(dst_entry); auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutVector); auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutRowMajorMatrix); @@ -281,8 +302,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { TView bias_view; if (has_bias) { auto bias_entry = node.GetInputs()[2]; - auto bias_data = data_entry_[EntryID(bias_entry)]->data; - auto bias_md = BindBNNSTensor(bias_entry, bias_data); + auto bias_md = GetBNNSTensor(bias_entry); bias_view = TView::as_is(bias_md).with_layout(BNNSDataLayoutVector); } @@ -319,16 +339,10 @@ class BNNSJSONRuntime : public JSONRuntimeBase { bool a_is_weighted = data_entry_[EntryID(a_entry)] != nullptr; bool b_is_weighted = data_entry_[EntryID(b_entry)] != nullptr; - void* a_data = nullptr; - void* b_data = nullptr; - if (a_is_weighted) - a_data = data_entry_[EntryID(a_entry)]->data; - if (b_is_weighted) - b_data = data_entry_[EntryID(b_entry)]->data; // Memory descriptions. - auto a_t = BindBNNSTensor(a_entry, a_data); - auto b_t = BindBNNSTensor(b_entry, b_data); - auto dst_t = BindBNNSTensor(dst_entry); + auto a_t = GetBNNSTensor(a_entry); + auto b_t = GetBNNSTensor(b_entry); + auto dst_t = GetBNNSTensor(dst_entry); auto a_view = TView::as_is(a_t); auto b_view = TView::as_is(b_t); @@ -391,12 +405,16 @@ class BNNSJSONRuntime : public JSONRuntimeBase { return { BNNSFlagsUseClientPtr, default_thread_config.internalConcurrency }; } + /** Default threading config. Should be used if there are + * no other threading specificator. */ const ThreadingConfig default_thread_config = getDefaultThreadingConfig(); + /** Collection of all primitives in topological order */ std::vector> primitives_; - /* The entry ID to its corresponding output memory. */ - std::unordered_map> entry_out_mem_; + /** Vector with BNNS tensors. Index of tensor matched with + * corresponding EntryID from base JSONRuntimeBase. */ + std::vector tensors_eid_; }; runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, From d54576d392075c1c3a9133487c5f7d8373e85523 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Wed, 27 Jan 2021 14:01:11 +0300 Subject: [PATCH 07/27] [BNNS] Fix conv_splitter issue Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns.hpp | 5 +- tests/python/contrib/test_bnns/test_conv2d.py | 50 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/bnns/bnns.hpp b/src/runtime/contrib/bnns/bnns.hpp index 62ee9289597b..3ecf936b44a4 100644 --- a/src/runtime/contrib/bnns/bnns.hpp +++ b/src/runtime/contrib/bnns/bnns.hpp @@ -366,6 +366,7 @@ static std::tuple, TView, TView> spl const TView &b_view, const TView &dst_view) { size_t batch = src_view.get_batch_size(); + size_t oc = dst_view.get_bnns_view().size[2]; size_t groups = orig_conv_param.groups; BNNS::TView src_view_new; @@ -374,7 +375,9 @@ static std::tuple, TView, TView> spl BNNS::TView dst_view_new; // TODO(apeskov): Add split by batch dim. Meanwhile we just disable it... - if (batch > 1 || (groups > 1 && groups % num != 0)) { + if (batch > 1 || + oc % num != 0 || + (groups > 1 && groups % num != 0)) { return {{orig_conv_param}, src_view, dst_view}; } diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index 9a00269d2869..1d3ee20559d2 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -96,7 +96,7 @@ def test_conv2d(): pad = [(1, 1), (2, 2), (2, 1)] strides = [(1, 1), (2, 2)] dilation = [(1, 1)] - out_channels = [4, 8, 16] + out_channels = [1, 4, 8, 16] input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] batches = [1, 2] groups = [1, 2] @@ -193,5 +193,53 @@ def test_conv2d_dw(): verify(outputs, atol=0.002, rtol=0.007, config=config) +def test_conv2d_with_oc1(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + dtype = "float32" + shape = [3, 5, 5] + kernel = [3, 3] + pad = [1, 1] + oc = 1 # <= test on conv with one output channel + + for batch in [1, 2]: + i_shape = (batch, *shape) + ic = shape[0] + inputs = { + "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), + } + + a = relay.var("a", shape=i_shape, dtype=dtype) + weight_shape = [oc, ic, *kernel] + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + func = relay.nn.conv2d( + a, + weights, + kernel_size=kernel, + padding=pad, + groups=1, + channels=oc, + out_dtype=dtype, + ) + params = {"w": w} + + outputs = [] + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + config = { + "batch": batch, + "shape": shape, + "kernel size": kernel, + "padding": pad, + "out channels": oc, + } + verify(outputs, atol=0.002, rtol=0.007, config=config) + + if __name__ == "__main__": test_conv2d() From 331d8e0bb0b94c7b08831cf2f9fd70ec62522da1 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Wed, 27 Jan 2021 21:18:34 +0300 Subject: [PATCH 08/27] [BNNS] Fix isse with bias {1,1,1,1} Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns.hpp | 7 ++- tests/python/contrib/test_bnns/test_conv2d.py | 55 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/bnns/bnns.hpp b/src/runtime/contrib/bnns/bnns.hpp index 3ecf936b44a4..4532fbcb4930 100644 --- a/src/runtime/contrib/bnns/bnns.hpp +++ b/src/runtime/contrib/bnns/bnns.hpp @@ -180,7 +180,7 @@ class TView { }; /** Squeeze all dims equal 1 */ - TView squeeze() const { + TView squeeze(size_t min_rank = 1) const { auto rank = Tensor::getRank(view_desc_); size_t squeezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; size_t squeezed_rank = 0; @@ -188,6 +188,11 @@ class TView { if (view_desc_.size[i] != 1) squeezed_shape[squeezed_rank++] = view_desc_.size[i]; + if (min_rank > squeezed_rank) { + std::fill(squeezed_shape + squeezed_rank, squeezed_shape + min_rank, 1); + squeezed_rank = min_rank; + } + TView res = *this; std::copy(squeezed_shape, squeezed_shape + squeezed_rank, res.view_desc_.size); std::fill(res.view_desc_.size + squeezed_rank, res.view_desc_.size + rank, 0); diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index 1d3ee20559d2..acede04f153e 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -241,5 +241,60 @@ def test_conv2d_with_oc1(): verify(outputs, atol=0.002, rtol=0.007, config=config) +def test_conv2d_with_scalar_bias(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + dtype = "float32" + shape = [3, 5, 5] + kernel = [3, 3] + pad = [1, 1] + oc = 1 + + for batch in [1, 2]: + i_shape = (batch, *shape) + ic = shape[0] + inputs = { + "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), + } + params = {} + + a = relay.var("a", shape=i_shape, dtype=dtype) + weight_shape = [oc, ic, *kernel] + w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) + weights = relay.const(w, dtype) + func = relay.nn.conv2d( + a, + weights, + kernel_size=kernel, + padding=pad, + groups=1, + channels=oc, + out_dtype=dtype, + ) + params["w"] = w + + b = tvm.nd.array(np.random.uniform(-10, 10, [1, oc, 1, 1]).astype(dtype)) # <= Check with 1, 1, 1, 1 version of bias + + bias = relay.const(b, dtype) + func = relay.add(func, bias) + params["b"] = b + + outputs = [] + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + config = { + "batch": batch, + "shape": shape, + "kernel size": kernel, + "padding": pad, + "out channels": oc, + } + verify(outputs, atol=0.002, rtol=0.007, config=config) + + if __name__ == "__main__": test_conv2d() From d31063ee77125d3505c2383126384f1d72c2760e Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Wed, 27 Jan 2021 21:22:58 +0300 Subject: [PATCH 09/27] [BNNS] Min. Rename file Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns_json_runtime.cc | 2 +- src/runtime/contrib/bnns/{bnns.hpp => bnns_wrp.h} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/runtime/contrib/bnns/{bnns.hpp => bnns_wrp.h} (100%) diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index dad871b2f42b..cdb2307c3980 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -34,7 +34,7 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" -#include "bnns.hpp" +#include "bnns_wrp.h" namespace tvm { namespace runtime { diff --git a/src/runtime/contrib/bnns/bnns.hpp b/src/runtime/contrib/bnns/bnns_wrp.h similarity index 100% rename from src/runtime/contrib/bnns/bnns.hpp rename to src/runtime/contrib/bnns/bnns_wrp.h From d67f6b7b3c3c8054a51e89e2e34c7a2b162f8be1 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Wed, 27 Jan 2021 14:00:05 +0300 Subject: [PATCH 10/27] Fix review comments. Initial Signed-off-by: Alexander Peskov --- python/tvm/relay/op/contrib/bnns.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index bee92a228869..c00ddfe842b6 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -76,10 +76,10 @@ def _register_external_op_helper(op_name, supported=True): """The helper function to indicate that a given operator can be supported by BNNS. - Paramters - --------- + Parameters + ---------- op_name : Str - The name of operator that will be registered. + The name of supported operator that will be registered. Returns ------- @@ -97,11 +97,8 @@ def _func_wrapper(expr): _register_external_op_helper("nn.batch_matmul") -# TODO [apeskov]: -# 1. enlarge list of supported types on -# 2. clarify meaning of "" value def dtype_is_supported(dtype): - return dtype == "float32" or dtype == "" + return dtype in ("", "float32") @tvm.ir.register_op_attr("nn.conv2d", "target.bnns") @@ -203,7 +200,7 @@ def make_dense_bias_gelu_pattern(): def check_dense(extract): - """Check conv pattern is supported by ACL.""" + """Check conv pattern is supported by BNNS.""" call = extract while call.op.name != "nn.dense": call = call.args[0] From 7da6a26b1c63f2a95e9c2ff38decbecb78591f9f Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Fri, 29 Jan 2021 16:15:39 +0300 Subject: [PATCH 11/27] [BNNS] test refactoring Signed-off-by: Alexander Peskov --- .../contrib/test_bnns/infrastructure.py | 28 +++ tests/python/contrib/test_bnns/test_conv2d.py | 224 +++--------------- 2 files changed, 67 insertions(+), 185 deletions(-) diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index b3e15959107e..74b489e2207f 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -29,6 +29,7 @@ from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.contrib import utils from tvm.autotvm.measure import request_remote +from tvm.relay.analysis import analysis class Device: @@ -267,6 +268,33 @@ def verify_codegen( ) +def compare_inference_with_ref(func, params, atol=0.002, rtol=0.007): + """Compare scoring results for compilation with and without BNNS. + + Provided function will be compiled two times with and without BNNS. + The scoring results for both type of compilation will be compared + with provided atol and rtol. The input data will be automatically + generated based of shape and dtype info provided for var nodes. + + """ + # Generate input tensor values + inputs = {} + for free_param in analysis.free_vars(func): + name = free_param.name_hint + dtype = free_param.type_annotation.dtype + shape = [s.value for s in free_param.type_annotation.shape] + inputs[name] = tvm.nd.array(np.random.uniform(0, 127, shape).astype(dtype)) + + # Run for both type of compilation + device = Device() + outputs = [] + for bnns in [False, True]: + outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) + + # Compare result tensors + verify(outputs, atol=atol, rtol=rtol) + + def generate_trials(space, r_factor=3): """Generates a series of trials. diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index acede04f153e..d02efb33362d 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -17,17 +17,11 @@ """BNNS integration conv2d tests.""" import numpy as np - +import pytest import tvm from tvm import relay -from .infrastructure import Device -from .infrastructure import ( - skip_runtime_test, - build_and_run, - verify, - generate_trials, -) +from .infrastructure import skip_runtime_test, compare_inference_with_ref, generate_trials # TODO: Missed cases # 1. Bias as add with 3d const tensor. Lead to additional unsqueeze op between @@ -38,27 +32,28 @@ def _get_model( shape, - kernel_h, - kernel_w, - padding, - strides, - dilation, - groups, - dtype, - channels, - var_names, + kernel=(3, 3), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + groups=1, + dtype="float32", + channels=-1, # -1 means same as input channels bias_type='none', activation_type='none', ): """Return a model and any parameters it may have""" - a = relay.var(next(var_names), shape=shape, dtype=dtype) - weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + if channels == -1: + channels = shape[1] + + a = relay.var("a", shape=shape, dtype=dtype) + weight_shape = (channels, shape[1] // groups, *kernel) w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) weights = relay.const(w, dtype) out = relay.nn.conv2d( a, weights, - kernel_size=(kernel_h, kernel_w), + kernel_size=kernel, dilation=dilation, strides=strides, padding=padding, @@ -84,11 +79,8 @@ def _get_model( return out, params +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_conv2d(): - if skip_runtime_test(): - return - - device = Device() np.random.seed(0) kernel_hs = [1, 2, 3, 5] @@ -102,199 +94,61 @@ def test_conv2d(): groups = [1, 2] bias_kind = ['none', 'add_3d', 'add_4d', 'bias.add'] activation_kind = ['none', 'relu'] - dtype = "float32" trials = generate_trials( [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, groups, batches, bias_kind, activation_kind], 3 ) for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, group, batch, bias, activation in trials: - shape = (batch, *input_shapes) - outputs = [] - inputs = { - "a": tvm.nd.array(np.random.uniform(0, 127, shape).astype(dtype)), - } - func, params = _get_model( - shape, - kernel_h, - kernel_w, - pad, - stride, - dilation, - group, - dtype, - out_channels, - iter(inputs), + shape=(batch, *input_shapes), + kernel=(kernel_h, kernel_w), + padding=pad, + strides=stride, + dilation=dilation, + groups=group, + channels=out_channels, bias_type=bias, activation_type=activation, ) - for bnns in [False, True]: - outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) - - config = { - "shape": shape, - "group": group, - "kernel size": (kernel_h, kernel_w), - "padding": pad, - "stride": stride, - "dilation": dilation, - "out channels": out_channels, - "bias": bias, - "activation": activation - } - verify(outputs, atol=0.002, rtol=0.007, config=config) + compare_inference_with_ref(func, params) +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_conv2d_dw(): if skip_runtime_test(): return - device = Device() np.random.seed(0) - dtype = "float32" shape = [4, 5, 5] - kernel = [3, 3] - pad = [1, 1] for batch in [1, 2]: - i_shape = (batch, *shape) - channels = shape[0] - outputs = [] - inputs = { - "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), - } - - a = relay.var("a", shape=i_shape, dtype=dtype) - weight_shape = [channels, 1, *kernel] - w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) - weights = relay.const(w, dtype) - func = relay.nn.conv2d( - a, - weights, - kernel_size=kernel, - padding=pad, - groups=channels, - channels=channels, - out_dtype=dtype, + mod, params = _get_model( + shape=(batch, *shape), + groups=shape[0] ) - params = {"w": w} - - for bnns in [False, True]: - outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) - - config = { - "batch": batch, - "shape": shape, - "kernel size": kernel, - "padding": pad, - "out channels": channels, - } - verify(outputs, atol=0.002, rtol=0.007, config=config) + compare_inference_with_ref(mod, params) +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_conv2d_with_oc1(): if skip_runtime_test(): return - device = Device() - np.random.seed(0) - dtype = "float32" - shape = [3, 5, 5] - kernel = [3, 3] - pad = [1, 1] - oc = 1 # <= test on conv with one output channel - - for batch in [1, 2]: - i_shape = (batch, *shape) - ic = shape[0] - inputs = { - "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), - } - - a = relay.var("a", shape=i_shape, dtype=dtype) - weight_shape = [oc, ic, *kernel] - w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) - weights = relay.const(w, dtype) - func = relay.nn.conv2d( - a, - weights, - kernel_size=kernel, - padding=pad, - groups=1, - channels=oc, - out_dtype=dtype, - ) - params = {"w": w} - - outputs = [] - for bnns in [False, True]: - outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) - - config = { - "batch": batch, - "shape": shape, - "kernel size": kernel, - "padding": pad, - "out channels": oc, - } - verify(outputs, atol=0.002, rtol=0.007, config=config) - - -def test_conv2d_with_scalar_bias(): - if skip_runtime_test(): - return - - device = Device() np.random.seed(0) - dtype = "float32" shape = [3, 5, 5] - kernel = [3, 3] - pad = [1, 1] - oc = 1 for batch in [1, 2]: - i_shape = (batch, *shape) - ic = shape[0] - inputs = { - "a": tvm.nd.array(np.random.uniform(0, 127, i_shape).astype(dtype)), - } - params = {} - - a = relay.var("a", shape=i_shape, dtype=dtype) - weight_shape = [oc, ic, *kernel] - w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) - weights = relay.const(w, dtype) - func = relay.nn.conv2d( - a, - weights, - kernel_size=kernel, - padding=pad, - groups=1, - channels=oc, - out_dtype=dtype, - ) - params["w"] = w - - b = tvm.nd.array(np.random.uniform(-10, 10, [1, oc, 1, 1]).astype(dtype)) # <= Check with 1, 1, 1, 1 version of bias - - bias = relay.const(b, dtype) - func = relay.add(func, bias) - params["b"] = b - - outputs = [] - for bnns in [False, True]: - outputs.append(build_and_run(func, inputs, 1, params, device, enable_bnns=bnns)[0]) - - config = { - "batch": batch, - "shape": shape, - "kernel size": kernel, - "padding": pad, - "out channels": oc, - } - verify(outputs, atol=0.002, rtol=0.007, config=config) + for bias in ["none", "add_4d"]: + mod, params = _get_model( + shape=(batch, *shape), + channels=1, + bias_type=bias + ) + compare_inference_with_ref(mod, params) if __name__ == "__main__": test_conv2d() + test_conv2d_dw() + test_conv2d_with_oc1() From 6aecc42fda220e65dcfb325c12a3e90c367aacf3 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Fri, 29 Jan 2021 16:27:27 +0300 Subject: [PATCH 12/27] [BNNS] Fix cpplint issues Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns_json_runtime.cc | 10 ++++++---- src/runtime/contrib/bnns/bnns_wrp.h | 17 +++++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index cdb2307c3980..1f25ca5e1165 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -271,7 +271,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::tie(params, src_view, dst_view) = split_to_n(num_sub_prim, conv_param, src_view, wgh_view, bias_view, dst_view); - std::vector filters (params.size(), nullptr); + std::vector filters(params.size(), nullptr); for (int i = 0; i < params.size(); i++) { auto common_filter_param = getCommonFilterParams(); filters[i] = BNNSFilterCreateLayerConvolution(¶ms[i], &common_filter_param); @@ -374,9 +374,11 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector filters {filter}; if (a_is_weighted || b_is_weighted) { auto src_view = a_is_weighted ? b_view : a_view; - primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); + primitives_.emplace_back( + std::make_shared(filters, src_view, dst_view)); } else { - primitives_.emplace_back(std::make_shared(filters, a_view, b_view, dst_view)); + primitives_.emplace_back( + std::make_shared(filters, a_view, b_view, dst_view)); } } @@ -429,6 +431,6 @@ TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate") TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") .set_body_typed(BNNSJSONRuntime::LoadFromBinary); -} +} // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 4532fbcb4930..0cb293acc9fe 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -22,14 +22,18 @@ * \brief C++ wrappers and helpers to handle BNNS objects */ +#ifndef TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_ +#define TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_ + +#include + #include #include #include #include +#include #include -#include - namespace tvm { namespace runtime { namespace contrib { @@ -68,7 +72,7 @@ class Tensor { {}, {}, // shape and strides hdl, // data handler dtype, // data type - nullptr, dtype, 1.f, 0.f // table_data (clustering case), is not used + nullptr, dtype, 1.f, 0.f // table_data (clustering case), is not used }; std::copy(shape.rbegin(), shape.rend(), std::begin(desc_.size)); @@ -177,7 +181,7 @@ class TView { res.view_desc_.size[rank - 1] = 0; res.view_desc_.layout = Tensor::getPlainLayout(rank - 1); return res; - }; + } /** Squeeze all dims equal 1 */ TView squeeze(size_t min_rank = 1) const { @@ -198,7 +202,7 @@ class TView { std::fill(res.view_desc_.size + squeezed_rank, res.view_desc_.size + rank, 0); res.view_desc_.layout = Tensor::getPlainLayout(squeezed_rank); return res; - }; + } /** Construct new TView with specified layout if it applicable */ TView with_layout(BNNSDataLayout layout) const { @@ -249,7 +253,7 @@ class TView { size_t get_stride() const { return batch_stride_; } /** Return party element by index */ - TView operator [](size_t i) const { + TView operator[](size_t i) const { ICHECK_LT(i, party_size_); TView res = *this; @@ -417,3 +421,4 @@ static std::tuple, TView, TView> spl } // namespace contrib } // namespace runtime } // namespace tvm +#endif // TVM_RUNTIME_CONTRIB_BNNS_BNNS_WRP_H_ From 4cdacf18ec502c766537e94b55496f529a44681f Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Fri, 29 Jan 2021 17:54:23 +0300 Subject: [PATCH 13/27] [BNNS] Fix clang-format issues Signed-off-by: Alexander Peskov --- src/relay/backend/contrib/bnns/codegen.cc | 20 ++- src/runtime/contrib/bnns/bnns_json_runtime.cc | 136 ++++++++---------- src/runtime/contrib/bnns/bnns_wrp.h | 111 ++++++-------- .../contrib/test_bnns/infrastructure.py | 9 +- 4 files changed, 115 insertions(+), 161 deletions(-) diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index c7c5646d971a..40cba3073e9a 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -30,9 +30,8 @@ #include #include -#include "../../utils.h" - #include "../../../../runtime/contrib/json/json_node.h" +#include "../../utils.h" #include "../codegen_json/codegen_json.h" namespace tvm { @@ -42,7 +41,8 @@ namespace contrib { using namespace backend; /*! - * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in relu(add(conv2d)) + * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in + * relu(add(conv2d)) * \param call A Relay call node. Typically nn.relu when called the first time. * \param max_depth The maximum number of calls before the root op, counting from current_call. * \param root_name The name of expected "root" op in this fused call. @@ -139,7 +139,6 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.bnns").set_body_typed(BNNSCompiler); - /** * \brief A helper to expand the params by adding ones which used by BNNS runtime * for a given expression. Same as default ConstantUpdater but skip constant from @@ -149,7 +148,7 @@ struct BNNSConstantUpdater : public ConstantUpdater { public: BNNSConstantUpdater(const std::string& symbol, std::unordered_map* params, - const std::vector &skip_mask) + const std::vector& skip_mask) : ConstantUpdater(symbol, params), skip_mask_(skip_mask) {} using ConstantUpdater::VisitExpr_; @@ -171,13 +170,12 @@ struct BNNSConstantUpdater : public ConstantUpdater { private: bool isBNNSSpecificCompositeFunc(const FunctionNode* op) { auto comp = op->GetAttr(attr::kComposite); - if (!comp) - return false; + if (!comp) return false; auto comp_name = comp.value(); bool is_match = false; - for (const auto &mask : skip_mask_) { + for (const auto& mask : skip_mask_) { if (std::string(comp_name).substr(0, mask.size()) == mask) { is_match = true; break; @@ -199,13 +197,11 @@ Map BNNSConstantUpdaterFunc(Expr expr, std::string sym // Convert to tvm::Map Map ret; - for (const auto& kvp : res) - ret.Set(kvp.first, kvp.second); + for (const auto& kvp : res) ret.Set(kvp.first, kvp.second); return ret; } -TVM_REGISTER_GLOBAL("relay.ext.bnns.constant_updater") - .set_body_typed(BNNSConstantUpdaterFunc); +TVM_REGISTER_GLOBAL("relay.ext.bnns.constant_updater").set_body_typed(BNNSConstantUpdaterFunc); } // namespace contrib } // namespace relay diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 1f25ca5e1165..8c07f3cc02ff 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -22,18 +22,17 @@ * \brief Simple JSON runtime for Apple BNNS primitives */ +#include #include #include -#include #include #include -#include #include +#include #include "../json/json_node.h" #include "../json/json_runtime.h" - #include "bnns_wrp.h" namespace tvm { @@ -95,7 +94,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { void Init(const Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) - << "The number of input constants must match the number of required."; + << "The number of input constants must match the number of required."; SetupConstants(consts); BindInputsAndOutputs(); @@ -105,21 +104,18 @@ class BNNSJSONRuntime : public JSONRuntimeBase { void Run() override { // Wrap external handler into BNNS tensor representation - auto bind_ext_hdl_to_tensor = [this] (uint32_t eid) { - const auto &ext_dlt = *data_entry_[eid]; - auto &bnns_tensor = tensors_eid_[eid]; + auto bind_ext_hdl_to_tensor = [this](uint32_t eid) { + const auto& ext_dlt = *data_entry_[eid]; + auto& bnns_tensor = tensors_eid_[eid]; bnns_tensor->set_data_hdl(ext_dlt.data); }; // Bind all input/output external data object into internal abstractions - for (const auto &eid : input_var_eid_) - bind_ext_hdl_to_tensor(eid); - for (const auto &out_entity : outputs_) - bind_ext_hdl_to_tensor(EntryID(out_entity)); + for (const auto& eid : input_var_eid_) bind_ext_hdl_to_tensor(eid); + for (const auto& out_entity : outputs_) bind_ext_hdl_to_tensor(EntryID(out_entity)); // Invoke primitives in topological order - for (const auto &prim : primitives_) - prim->execute(); + for (const auto& prim : primitives_) prim->execute(); } private: @@ -130,12 +126,10 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto node = nodes_[entry.id_]; auto dlshape = node.GetOpShape()[entry.index_]; auto dltype = node.GetOpDataType()[entry.index_]; - void *data = nullptr; - if (data_entry_[entry.id_] != nullptr) - data = data_entry_[entry.id_]->data; - tensors_eid_[entry.id_] = - std::make_shared(BNNS::Shape{dlshape.begin(), dlshape.end()}, - convertToBNNS(dltype), data); + void* data = nullptr; + if (data_entry_[entry.id_] != nullptr) data = data_entry_[entry.id_]->data; + tensors_eid_[entry.id_] = std::make_shared( + BNNS::Shape{dlshape.begin(), dlshape.end()}, convertToBNNS(dltype), data); }; for (auto& id : input_nodes_) { @@ -152,14 +146,12 @@ class BNNSJSONRuntime : public JSONRuntimeBase { void AllocateIntermediateTensors() { for (int i = 0; i < nodes_.size(); ++i) { auto eid = JSONGraphNodeEntry(i, 0); - if (tensors_eid_[eid.id_] != nullptr) - continue; + if (tensors_eid_[eid.id_] != nullptr) continue; auto node = nodes_[i]; auto dlshape = node.GetOpShape()[0]; auto dltype = node.GetOpDataType()[0]; - tensors_eid_[eid.id_] = - std::make_shared(BNNS::Shape{dlshape.begin(), dlshape.end()}, - convertToBNNS(dltype), nullptr); + tensors_eid_[eid.id_] = std::make_shared( + BNNS::Shape{dlshape.begin(), dlshape.end()}, convertToBNNS(dltype), nullptr); tensors_eid_[eid.id_]->allocate_memory(); } } @@ -212,27 +204,26 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto dl_input_shape = nodes_[src_entry.id_].GetOpShape()[src_entry.index_]; auto dl_weight_shape = nodes_[wgh_entry.id_].GetOpShape()[wgh_entry.index_]; - BNNS::Shape input_shape {dl_input_shape.begin(), dl_input_shape.end()}; - BNNS::Shape weight_shape {dl_weight_shape.begin(), dl_weight_shape.end()}; + BNNS::Shape input_shape{dl_input_shape.begin(), dl_input_shape.end()}; + BNNS::Shape weight_shape{dl_weight_shape.begin(), dl_weight_shape.end()}; std::vector str_strides = node.GetAttr>("strides"); std::vector str_dilation = node.GetAttr>("dilation"); std::vector str_padding = node.GetAttr>("padding"); BNNS::Dim groups = std::stoi(node.GetAttr>("groups")[0]); - BNNS::Dim - PH_L = std::stoi(str_padding[0]), // height padding: left - PH_R = std::stoi(str_padding[2]), // height padding: right - PW_L = std::stoi(str_padding[1]), // width padding: left - PW_R = std::stoi(str_padding[3]), // width padding: right - SH = std::stoi(str_strides[0]), // height-wise stride - SW = std::stoi(str_strides[1]), // weight-wise stride - DH = std::stoi(str_dilation[0]), // height kernel dilation - DW = std::stoi(str_dilation[1]); // width kernel dilation + BNNS::Dim PH_L = std::stoi(str_padding[0]), // height padding: left + PH_R = std::stoi(str_padding[2]), // height padding: right + PW_L = std::stoi(str_padding[1]), // width padding: left + PW_R = std::stoi(str_padding[3]), // width padding: right + SH = std::stoi(str_strides[0]), // height-wise stride + SW = std::stoi(str_strides[1]), // weight-wise stride + DH = std::stoi(str_dilation[0]), // height kernel dilation + DW = std::stoi(str_dilation[1]); // width kernel dilation // Memory descriptions. - const auto &src_t = GetBNNSTensor(src_entry); - const auto &wgh_t = GetBNNSTensor(wgh_entry); - const auto &dst_t = GetBNNSTensor(dst_entry); + const auto& src_t = GetBNNSTensor(src_entry); + const auto& wgh_t = GetBNNSTensor(wgh_entry); + const auto& dst_t = GetBNNSTensor(dst_entry); auto src_view = TView::as_is(src_t).extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); auto wgh_view = TView::as_is(wgh_t).with_layout(BNNSDataLayoutConvolutionWeightsOIHW); @@ -246,9 +237,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { bias_view = TView::as_is(bias_t).squeeze().with_layout(BNNSDataLayoutVector); } - BNNSActivation activation = { has_relu ? - BNNSActivationFunctionRectifiedLinear : - BNNSActivationFunctionIdentity }; + BNNSActivation activation = {has_relu ? BNNSActivationFunctionRectifiedLinear + : BNNSActivationFunctionIdentity}; BNNSLayerParametersConvolution conv_param = { src_view.get_bnns_view(), @@ -256,20 +246,20 @@ class BNNSJSONRuntime : public JSONRuntimeBase { dst_view.get_bnns_view(), bias_view.get_bnns_view(), activation, - SW, /* x_stride */ - SH, /* y_stride */ - DW, /* x_dilation_stride */ - DH, /* y_dilation_stride */ - 0, /* x_padding, explicit pads will be used */ - 0, /* y_padding, explicit pads will be used */ - groups, /* groups */ + SW, /* x_stride */ + SH, /* y_stride */ + DW, /* x_dilation_stride */ + DH, /* y_dilation_stride */ + 0, /* x_padding, explicit pads will be used */ + 0, /* y_padding, explicit pads will be used */ + groups, /* groups */ {PW_L, PW_R, PH_L, PH_R} /* explicit pad values */ }; size_t num_sub_prim = default_thread_config.externalConcurrency; std::vector params; - std::tie(params, src_view, dst_view) = split_to_n(num_sub_prim, conv_param, - src_view, wgh_view, bias_view, dst_view); + std::tie(params, src_view, dst_view) = + split_to_n(num_sub_prim, conv_param, src_view, wgh_view, bias_view, dst_view); std::vector filters(params.size(), nullptr); for (int i = 0; i < params.size(); i++) { @@ -278,8 +268,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { ICHECK(filters[i]) << "BNNS primitive was not created. Unsupported attributes configuration"; } - primitives_.emplace_back( - std::make_shared(filters, src_view, dst_view)); + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } void Dense(const size_t& nid, const bool has_bias = false, const bool has_gelu = false) { @@ -308,9 +297,9 @@ class BNNSJSONRuntime : public JSONRuntimeBase { BNNSActivation activation = {BNNSActivationFunctionIdentity}; if (has_gelu) { - activation = {BNNSActivationFunctionGELUApproximation}; - activation.alpha = std::sqrt(2.0 / M_PI); - activation.beta = 0.044715; + activation = {BNNSActivationFunctionGELUApproximation}; + activation.alpha = std::sqrt(2.0 / M_PI); + activation.beta = 0.044715; } BNNSLayerParametersFullyConnected layerParameters = { @@ -325,8 +314,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto filter = BNNSFilterCreateLayerFullyConnected(&layerParameters, &common_filter_param); ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters = {filter}; - primitives_.emplace_back( - std::make_shared(filters, src_view, dst_view)); + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } void MatMul(const size_t& nid) { @@ -348,18 +336,16 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto b_view = TView::as_is(b_t); auto dst_view = TView::as_is(dst_t); - BNNSLayerParametersBroadcastMatMul layerParameters = { - 1, // alpha - 0, // beta - false, // transA - true, // transB - false, // quadratic - a_is_weighted, - b_is_weighted, - a_view.get_bnns_view(), - b_view.get_bnns_view(), - dst_view.get_bnns_view() - }; + BNNSLayerParametersBroadcastMatMul layerParameters = {1, // alpha + 0, // beta + false, // transA + true, // transB + false, // quadratic + a_is_weighted, + b_is_weighted, + a_view.get_bnns_view(), + b_view.get_bnns_view(), + dst_view.get_bnns_view()}; // BNNS limitation: MatMul use reverse dims values. However strides are calculated correctly // based on BNNSNDArrayDescriptor::layout value. @@ -371,18 +357,17 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto filter = BNNSFilterCreateLayerBroadcastMatMul(&layerParameters, &common_filter_param); ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; - std::vector filters {filter}; + std::vector filters{filter}; if (a_is_weighted || b_is_weighted) { auto src_view = a_is_weighted ? b_view : a_view; - primitives_.emplace_back( - std::make_shared(filters, src_view, dst_view)); + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } else { primitives_.emplace_back( std::make_shared(filters, a_view, b_view, dst_view)); } } - BNNS::Dtype convertToBNNS(const DLDataType &dl_dtype) { + BNNS::Dtype convertToBNNS(const DLDataType& dl_dtype) { if (dl_dtype.code == DLDataTypeCode::kDLFloat) { if (dl_dtype.bits == 32) return BNNSDataTypeFloat32; if (dl_dtype.bits == 16) return BNNSDataTypeFloat16; @@ -404,7 +389,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { BNNSFilterParameters getCommonFilterParams() { // NOTE: To force weights tensor copy on stage of filter create // just change : BNNSFlagsUseClientPtr -> 0 - return { BNNSFlagsUseClientPtr, default_thread_config.internalConcurrency }; + return {BNNSFlagsUseClientPtr, default_thread_config.internalConcurrency}; } /** Default threading config. Should be used if there are @@ -425,8 +410,7 @@ runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate") - .set_body_typed(BNNSJSONRuntimeCreate); +TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") .set_body_typed(BNNSJSONRuntime::LoadFromBinary); diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 0cb293acc9fe..722b74cf4699 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -27,12 +27,12 @@ #include -#include -#include -#include #include #include +#include #include +#include +#include namespace tvm { namespace runtime { @@ -44,13 +44,9 @@ using Shape = std::vector; using Dtype = BNNSDataType; using HDL = void*; -void* default_alloc(size_t size) { - return malloc(size); -} +void* default_alloc(size_t size) { return malloc(size); } -void default_free(void* ptr) { - free(ptr); -} +void default_free(void* ptr) { free(ptr); } /** * Main abstraction for tensor representation @@ -66,14 +62,16 @@ class Tensor { auto rank = shape.size(); ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION); - desc_ = { - BNNSNDArrayFlags(0), - getPlainLayout(rank), - {}, {}, // shape and strides - hdl, // data handler - dtype, // data type - nullptr, dtype, 1.f, 0.f // table_data (clustering case), is not used - }; + desc_ = {BNNSNDArrayFlags(0), + getPlainLayout(rank), + {}, // shape + {}, // strides + hdl, // data handler + dtype, // data type + nullptr, // table_data (clustering case), is not used + dtype, + 1.f, + 0.f}; std::copy(shape.rbegin(), shape.rend(), std::begin(desc_.size)); desc_.data = hdl; @@ -99,7 +97,7 @@ class Tensor { void* get_data_hdl() const { return desc_.data; } - void set_data_hdl(void *hdl) { + void set_data_hdl(void* hdl) { if (desc_.data && !is_external_data) { default_free(desc_.data); desc_.data = nullptr; @@ -116,13 +114,9 @@ class Tensor { return static_cast((rank << 16) | 0x8001); } - static size_t getRank(BNNSDataLayout layout) { - return (layout & 0xF0000) >> 16; - } + static size_t getRank(BNNSDataLayout layout) { return (layout & 0xF0000) >> 16; } - static size_t getRank(BNNSNDArrayDescriptor desc) { - return getRank(desc.layout); - } + static size_t getRank(BNNSNDArrayDescriptor desc) { return getRank(desc.layout); } static size_t getSize(BNNSNDArrayDescriptor desc) { auto rank = getRank(desc); @@ -130,12 +124,10 @@ class Tensor { } /** return size of element in bytes */ - static size_t getElementSize(Dtype dtype) { - return (dtype & 0xFFFF) / 8; - } + static size_t getElementSize(Dtype dtype) { return (dtype & 0xFFFF) / 8; } /** return size of element in bytes */ - static size_t getElementSize(const BNNSNDArrayDescriptor &desc) { + static size_t getElementSize(const BNNSNDArrayDescriptor& desc) { return getElementSize(desc.data_type); } @@ -164,7 +156,7 @@ using TensorPtr = std::shared_ptr; class TView { public: /** Make view on provided tensor as is */ - static TView as_is(const TensorPtr &origin) { + static TView as_is(const TensorPtr& origin) { TView res; res.origin_ = origin; res.view_desc_ = origin->get_desc(); @@ -176,8 +168,8 @@ class TView { auto rank = Tensor::getRank(view_desc_); TView res = *this; res.batch_size_ = view_desc_.size[rank - 1]; - res.batch_stride_ = std::accumulate(view_desc_.size, view_desc_.size + rank - 1, - 1, std::multiplies<>()); + res.batch_stride_ = + std::accumulate(view_desc_.size, view_desc_.size + rank - 1, 1, std::multiplies<>()); res.view_desc_.size[rank - 1] = 0; res.view_desc_.layout = Tensor::getPlainLayout(rank - 1); return res; @@ -189,8 +181,7 @@ class TView { size_t squeezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; size_t squeezed_rank = 0; for (int i = 0; i < rank; i++) - if (view_desc_.size[i] != 1) - squeezed_shape[squeezed_rank++] = view_desc_.size[i]; + if (view_desc_.size[i] != 1) squeezed_shape[squeezed_rank++] = view_desc_.size[i]; if (min_rank > squeezed_rank) { std::fill(squeezed_shape + squeezed_rank, squeezed_shape + min_rank, 1); @@ -270,9 +261,7 @@ class TView { operator bool() const { return origin_ != nullptr; } /** Get BNNS descriptor for particular View. Batch and Party attributed are ignored. */ - const BNNSNDArrayDescriptor& get_bnns_view() const { - return view_desc_; - } + const BNNSNDArrayDescriptor& get_bnns_view() const { return view_desc_; } private: /** Original tensor object to view on */ @@ -295,19 +284,14 @@ class TView { */ class Primitive { public: - Primitive(const std::vector fs, - const TView& src, - const TView& dst) + Primitive(const std::vector fs, const TView& src, const TView& dst) : filters(fs), src_view(src), src2_view(), dst_view(dst) {} - Primitive(const std::vector fs, - const TView& src, - const TView& src2, - const TView& dst) + Primitive(const std::vector fs, const TView& src, const TView& src2, const TView& dst) : filters(fs), src_view(src), src2_view(src2), dst_view(dst) {} ~Primitive() { - for (auto &filter : filters) + for (auto& filter : filters) if (filter) { BNNSFilterDestroy(filter); filter = nullptr; @@ -323,12 +307,11 @@ class Primitive { private: static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { auto prim = reinterpret_cast(cdata); - const auto filter = prim->filters[task_id]; + const auto filter = prim->filters[task_id]; const auto src_view = prim->src_view[task_id]; const auto dst_view = prim->dst_view[task_id]; TView src2_view; - if (prim->src2_view) - src2_view = prim->src2_view[task_id]; + if (prim->src2_view) src2_view = prim->src2_view[task_id]; size_t mb = src_view.get_batch_size(); @@ -337,14 +320,14 @@ class Primitive { // BNNSFilterApply doesn't work for grouped convolution. // * Group convolution doesn't support arbitrary stride for Batch dim. // The tensor should be dense. - auto sts = (prim->src2_view) - ? BNNSFilterApplyTwoInputBatch(filter, mb, - src_view.get_data_hdl(), src_view.get_stride(), - src2_view.get_data_hdl(), src2_view.get_stride(), - dst_view.get_data_hdl(), dst_view.get_stride()) - : BNNSFilterApplyBatch(filter, mb, - src_view.get_data_hdl(), src_view.get_stride(), - dst_view.get_data_hdl(), dst_view.get_stride()); + auto sts = + (prim->src2_view) + ? BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), + src_view.get_stride(), src2_view.get_data_hdl(), + src2_view.get_stride(), dst_view.get_data_hdl(), + dst_view.get_stride()) + : BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()); return sts; } @@ -368,12 +351,8 @@ class Primitive { * @return collection of Convolution descriptors plus corresponding src/dst tensors view */ static std::tuple, TView, TView> split_to_n( - size_t num, - const BNNSLayerParametersConvolution& orig_conv_param, - const TView &src_view, - const TView &wgh_view, - const TView &b_view, - const TView &dst_view) { + size_t num, const BNNSLayerParametersConvolution& orig_conv_param, const TView& src_view, + const TView& wgh_view, const TView& b_view, const TView& dst_view) { size_t batch = src_view.get_batch_size(); size_t oc = dst_view.get_bnns_view().size[2]; size_t groups = orig_conv_param.groups; @@ -384,9 +363,7 @@ static std::tuple, TView, TView> spl BNNS::TView dst_view_new; // TODO(apeskov): Add split by batch dim. Meanwhile we just disable it... - if (batch > 1 || - oc % num != 0 || - (groups > 1 && groups % num != 0)) { + if (batch > 1 || oc % num != 0 || (groups > 1 && groups % num != 0)) { return {{orig_conv_param}, src_view, dst_view}; } @@ -404,14 +381,14 @@ static std::tuple, TView, TView> spl dst_view_new = dst_view.party_split_n(num); std::vector res(num); - for (size_t i=0; i < num; i++) { - auto &cur = res[i]; + for (size_t i = 0; i < num; i++) { + auto& cur = res[i]; cur = orig_conv_param; cur.i_desc = src_view_new[i].get_bnns_view(); cur.o_desc = dst_view_new[i].get_bnns_view(); cur.w_desc = wgh_view_new[i].get_bnns_view(); - cur.bias = b_view_new[i].get_bnns_view(); + cur.bias = b_view_new[i].get_bnns_view(); cur.groups = groups; } return {res, src_view_new, dst_view_new}; diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 74b489e2207f..2f4833200763 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -46,10 +46,6 @@ class Device: The test configuration will be loaded once when the the class is created. If the configuration changes between tests, any changes will not be picked up. - Parameters - ---------- - device : RPCSession - Allows tests to connect to and use remote device. Attributes ---------- @@ -310,12 +306,13 @@ def generate_trials(space, r_factor=3): ---------- space: List[List[Any]] A list of different options with varying values to test. - r_factor: (optional) int + r_factor: Optional[int] The repeat factor. Returns ------- - A list of trials specifying values for each option. + result: List[Tuple] + A list of trials specifying values for each option. """ np.random.seed(0) From 1060b98d2804fe36e3acd9ef69aa6554e2447d19 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 1 Feb 2021 15:33:52 +0300 Subject: [PATCH 14/27] Fix python format Signed-off-by: Alexander Peskov --- python/tvm/relay/op/contrib/bnns.py | 18 +++++-- .../contrib/test_bnns/infrastructure.py | 8 +-- tests/python/contrib/test_bnns/test_conv2d.py | 53 +++++++++++++------ .../contrib/test_bnns/test_conv2d_patterns.py | 38 ++++--------- tests/python/contrib/test_bnns/test_dense.py | 12 ++++- tests/python/contrib/test_bnns/test_matmul.py | 12 ++--- .../contrib/test_bnns/test_onnx_topologies.py | 19 ++++--- 7 files changed, 90 insertions(+), 70 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index c00ddfe842b6..a8d8eecf8483 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -209,9 +209,21 @@ def check_dense(extract): @register_pattern_table("bnns") def pattern_table(): - conv2d_bias_pat = ("bnns.conv2d_bias", make_conv_relu_pattern(with_bias=True, with_relu=False), check_conv) - conv2d_bias_relu_pat = ("bnns.conv2d_bias_relu", make_conv_relu_pattern(with_bias=True, with_relu=True), check_conv) - conv2d_relu_pat = ("bnns.conv2d_relu", make_conv_relu_pattern(with_bias=False, with_relu=True), check_conv) + conv2d_bias_pat = ( + "bnns.conv2d_bias", + make_conv_relu_pattern(with_bias=True, with_relu=False), + check_conv, + ) + conv2d_bias_relu_pat = ( + "bnns.conv2d_bias_relu", + make_conv_relu_pattern(with_bias=True, with_relu=True), + check_conv, + ) + conv2d_relu_pat = ( + "bnns.conv2d_relu", + make_conv_relu_pattern(with_bias=False, with_relu=True), + check_conv, + ) dense_bias_gelu = ("bnns.dense_bias_gelu", make_dense_bias_gelu_pattern(), check_dense) dense_bias = ("bnns.dense_bias", make_dense_bias_pattern(), check_dense) bnns_patterns = [ diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 2f4833200763..e407e21fd868 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -103,9 +103,7 @@ def load(cls, file_name): location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) config_file = os.path.join(location, file_name) if not os.path.exists(config_file): - warnings.warn( - "Config file doesn't exist, resuming tests with default config." - ) + warnings.warn("Config file doesn't exist, resuming tests with default config.") return with open(config_file, mode="r") as config: test_config = json.load(config) @@ -198,9 +196,7 @@ def update_lib(lib, device, cross_compile): def extract_bnns_modules(module): """Get the BNNS module(s) from llvm module.""" - return list( - filter(lambda mod: mod.type_key == "bnns_json", module.get_lib().imported_modules) - ) + return list(filter(lambda mod: mod.type_key == "bnns_json", module.get_lib().imported_modules)) def verify(answers, atol, rtol, verify_saturation=False, config=None): diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index d02efb33362d..a0507b6ad798 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -39,8 +39,8 @@ def _get_model( groups=1, dtype="float32", channels=-1, # -1 means same as input channels - bias_type='none', - activation_type='none', + bias_type="none", + activation_type="none", ): """Return a model and any parameters it may have""" if channels == -1: @@ -68,7 +68,9 @@ def _get_model( out = relay.nn.bias_add(out, biasc, axis=1) params["b"] = b elif bias_type == "add_3d" or bias_type == "add_4d": - bias_shape = (weight_shape[0], 1, 1) if bias_type == "add_3d" else (1, weight_shape[0], 1, 1) + bias_shape = ( + (weight_shape[0], 1, 1) if bias_type == "add_3d" else (1, weight_shape[0], 1, 1) + ) b = tvm.nd.array(np.random.uniform(-10, 10, bias_shape).astype(dtype)) biasc = relay.const(b, dtype) out = relay.add(out, biasc) @@ -92,14 +94,38 @@ def test_conv2d(): input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] batches = [1, 2] groups = [1, 2] - bias_kind = ['none', 'add_3d', 'add_4d', 'bias.add'] - activation_kind = ['none', 'relu'] + bias_kind = ["none", "add_3d", "add_4d", "bias.add"] + activation_kind = ["none", "relu"] trials = generate_trials( - [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, - groups, batches, bias_kind, activation_kind], 3 + [ + kernel_hs, + kernel_ws, + pad, + strides, + dilation, + out_channels, + input_shapes, + groups, + batches, + bias_kind, + activation_kind, + ], + 3, ) - for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, group, batch, bias, activation in trials: + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + input_shapes, + group, + batch, + bias, + activation, + ) in trials: func, params = _get_model( shape=(batch, *input_shapes), kernel=(kernel_h, kernel_w), @@ -123,10 +149,7 @@ def test_conv2d_dw(): shape = [4, 5, 5] for batch in [1, 2]: - mod, params = _get_model( - shape=(batch, *shape), - groups=shape[0] - ) + mod, params = _get_model(shape=(batch, *shape), groups=shape[0]) compare_inference_with_ref(mod, params) @@ -140,11 +163,7 @@ def test_conv2d_with_oc1(): for batch in [1, 2]: for bias in ["none", "add_4d"]: - mod, params = _get_model( - shape=(batch, *shape), - channels=1, - bias_type=bias - ) + mod, params = _get_model(shape=(batch, *shape), channels=1, bias_type=bias) compare_inference_with_ref(mod, params) diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py index 9dc0695d57c4..b10504bbc961 100644 --- a/tests/python/contrib/test_bnns/test_conv2d_patterns.py +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -37,10 +37,13 @@ def is_op_fused(func, op_name): is_fused = False def visit(op): - if isinstance(op, tvm.relay.function.Function) \ - and op_name in op.attrs["PartitionedFromPattern"]: + if ( + isinstance(op, tvm.relay.function.Function) + and op_name in op.attrs["PartitionedFromPattern"] + ): nonlocal is_fused is_fused = True + tvm.relay.analysis.post_order_visit(func.body, visit) return is_fused @@ -49,11 +52,7 @@ def test_pattern_conv2d_with_bias_add(): for axis in (1, 2): a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) - res = relay.nn.conv2d( - a, w, - kernel_size=(3, 3), padding=(1, 1), - channels=8, out_dtype=fp32 - ) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) b = relay.const(np.random.uniform(-10, 10, 8).astype(fp32)) res = relay.nn.bias_add(res, b, axis=axis) @@ -64,21 +63,12 @@ def test_pattern_conv2d_with_bias_add(): def test_pattern_conv2d_with_add(): - workloads = { - 8: False, - (8, 1): False, - (8, 1, 1): True, - (1, 8, 1, 1): True - } + workloads = {8: False, (8, 1): False, (8, 1, 1): True, (1, 8, 1, 1): True} for b_shape, should_be_fused in workloads.items(): a = relay.var("a", shape=(2, 7, 8, 8), dtype=fp32) w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) - res = relay.nn.conv2d( - a, w, - kernel_size=(3, 3), padding=(1, 1), - channels=8, out_dtype=fp32 - ) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) b = relay.const(np.random.uniform(-10, 10, b_shape).astype(fp32)) res = relay.add(res, b) @@ -96,11 +86,7 @@ def test_pattern_conv2d_with_non_cons_weights(): else: w = relay.var("w", shape=(8, 7, 3, 3), dtype=fp32) - res = relay.nn.conv2d( - a, w, - kernel_size=(3, 3), padding=(1, 1), - channels=8, out_dtype=fp32 - ) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) mod = partition(res) use_bnns = len(mod.get_global_vars()) == 2 # GlobalVar: "main" and "bnns_0" @@ -111,11 +97,7 @@ def test_pattern_conv2d_with_non_cons_weights(): def test_pattern_conv2d_with_non_cons_bias(): a = relay.var("a", shape=[2, 7, 8, 8], dtype=fp32) w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype(fp32)) - res = relay.nn.conv2d( - a, w, - kernel_size=(3, 3), padding=(1, 1), - channels=8, out_dtype=fp32 - ) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype=fp32) b = relay.var("b", shape=[8], dtype=fp32) res = relay.nn.bias_add(res, b, axis=1) diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py index f995194d1775..60280b1daea6 100644 --- a/tests/python/contrib/test_bnns/test_dense.py +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -129,7 +129,13 @@ def test_dense(): outputs = [] inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))} func, params = _get_model( - shape, weight_shape, units, dtype, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu + shape, + weight_shape, + units, + dtype, + var_names=iter(inputs), + has_bias=with_bias, + has_gelu=with_gelu, ) for bnns in [False, True]: outputs.append( @@ -176,7 +182,9 @@ def test_codegen_dense(): args = (shape, weight_shape, units, dtype) - func, params = _get_model(*args, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu) + func, params = _get_model( + *args, var_names=iter(inputs), has_bias=with_bias, has_gelu=with_gelu + ) exp_codegen = _get_expected_codegen(*args, has_bias=with_bias, has_gelu=with_gelu) verify_codegen(func, exp_codegen, 1) diff --git a/tests/python/contrib/test_bnns/test_matmul.py b/tests/python/contrib/test_bnns/test_matmul.py index 38683e8748b6..408cd3b1dc6c 100644 --- a/tests/python/contrib/test_bnns/test_matmul.py +++ b/tests/python/contrib/test_bnns/test_matmul.py @@ -40,11 +40,11 @@ def _get_model(a_shape, b_shape, dtype, var_names, is_a_constant=False, is_b_con params = {} if is_b_constant is True: b = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) - params['b'] = b + params["b"] = b b = relay.const(b, dtype) if is_a_constant is True: a = tvm.nd.array(np.random.uniform(-128, 127, a_shape).astype(dtype)) - params['a'] = a + params["a"] = a a = relay.const(a, dtype) out = relay.nn.batch_matmul(a, b) return out, params @@ -84,10 +84,12 @@ def test_matmul(): "b": tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)), } func, params = _get_model( - a_shape, b_shape, dtype, + a_shape, + b_shape, + dtype, var_names=iter(inputs), is_a_constant=is_a_constant, - is_b_constant=is_b_constant + is_b_constant=is_b_constant, ) for enable_bnns in [False, True]: outputs.append( @@ -111,5 +113,3 @@ def test_matmul(): if __name__ == "__main__": test_matmul() - - diff --git a/tests/python/contrib/test_bnns/test_onnx_topologies.py b/tests/python/contrib/test_bnns/test_onnx_topologies.py index ebc255830811..241dac70eb9a 100644 --- a/tests/python/contrib/test_bnns/test_onnx_topologies.py +++ b/tests/python/contrib/test_bnns/test_onnx_topologies.py @@ -34,14 +34,14 @@ BASE_MODEL_URL = "https://github.com/onnx/models/raw/master/" MODEL_URL_COLLECTION = { - "BERT": "text/machine_comprehension/bert-squad/model/bertsquad-10.onnx", - "MobileNet-v2": "vision/classification/mobilenet/model/mobilenetv2-7.onnx", - "ResNet50-v1": "vision/classification/resnet/model/resnet50-v1-7.onnx", - "ResNet50-v2": "vision/classification/resnet/model/resnet50-v2-7.onnx", + "BERT": "text/machine_comprehension/bert-squad/model/bertsquad-10.onnx", + "MobileNet-v2": "vision/classification/mobilenet/model/mobilenetv2-7.onnx", + "ResNet50-v1": "vision/classification/resnet/model/resnet50-v1-7.onnx", + "ResNet50-v2": "vision/classification/resnet/model/resnet50-v2-7.onnx", "SqueezeNet-v1.1": "vision/classification/squeezenet/model/squeezenet1.1-7.onnx", "SqueezeNet-v1.0": "vision/classification/squeezenet/model/squeezenet1.0-7.onnx", - "Inception-v1": "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx", - "Inception-v2": "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx", + "Inception-v1": "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx", + "Inception-v2": "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx", } @@ -58,7 +58,7 @@ def get_model_url(model_name): def get_name_from_url(url): - return url[url.rfind('/') + 1:].strip() + return url[url.rfind("/") + 1 :].strip() def find_of_download(model_name): @@ -124,7 +124,10 @@ def run(mod, target, simplify=True, with_bnns=False): res_bnns = run(model, TARGET, simplify=True, with_bnns=True) tvm.testing.assert_allclose( - res_llvm, res_bnns, atol=0.002, rtol=0.007, + res_llvm, + res_bnns, + atol=0.002, + rtol=0.007, ) From 7203493a96dba1b606ad31e6b4c29b613359423c Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Tue, 2 Feb 2021 20:17:45 +0300 Subject: [PATCH 15/27] Fix pylint issues Signed-off-by: Alexander Peskov --- python/tvm/relay/op/contrib/bnns.py | 13 +- tests/cpp/contrib/bnns.cc | 307 ++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+), 3 deletions(-) create mode 100644 tests/cpp/contrib/bnns.cc diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index a8d8eecf8483..68efceb2052d 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -23,13 +23,15 @@ """ import math import tvm.ir -from ...dataflow_pattern import wildcard, is_op, is_expr + from .register import register_pattern_table, get_pattern_table from tvm.relay import transform from tvm.relay.expr import const from tvm.relay.build_module import bind_params_by_name +from ...dataflow_pattern import wildcard, is_op, is_expr + def partition_for_bnns(mod, params=None): """Partition the graph greedily offloading supported @@ -98,6 +100,7 @@ def _func_wrapper(expr): def dtype_is_supported(dtype): + """Check if data type is supported by BNNS backend""" return dtype in ("", "float32") @@ -127,11 +130,11 @@ def bias_check(expr): return False if expr.op.name == "nn.bias_add": return attrs.axis == 1 - elif expr.op.name == "add": + if expr.op.name == "add": b_shape = args[1].checked_type.shape if len(b_shape) == 4: return bool(b_shape[0] == 1 and b_shape[2] == 1 and b_shape[3] == 1) - elif len(b_shape) == 3: + if len(b_shape) == 3: return bool(b_shape[1] == 1 and b_shape[2] == 1) return False @@ -153,6 +156,7 @@ def dense(expr): def make_conv_relu_pattern(with_bias=True, with_relu=True): + """Make pattern for bnns.conv2d primitive""" data = wildcard() weight = wildcard() bias = wildcard() @@ -176,6 +180,7 @@ def check_conv(extract): def make_dense_bias_pattern(): + """Make pattern for bnns.dense primitive""" data = wildcard() weight = wildcard() bias = wildcard() @@ -184,6 +189,7 @@ def make_dense_bias_pattern(): def make_dense_bias_gelu_pattern(): + """Make pattern for bnns.dense primitive with fused bias and gelu activation""" dense_bias = make_dense_bias_pattern() const1 = is_expr(const(0.044715)) const2 = is_expr(const(math.sqrt(2 / math.pi))) @@ -209,6 +215,7 @@ def check_dense(extract): @register_pattern_table("bnns") def pattern_table(): + """Get BNNS specific fusing patterns collection""" conv2d_bias_pat = ( "bnns.conv2d_bias", make_conv_relu_pattern(with_bias=True, with_relu=False), diff --git a/tests/cpp/contrib/bnns.cc b/tests/cpp/contrib/bnns.cc new file mode 100644 index 000000000000..1efd487caff9 --- /dev/null +++ b/tests/cpp/contrib/bnns.cc @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +TEST(PackedFunc, Basic) { + using namespace tvm; + using namespace tvm::tir; + using namespace tvm::runtime; + int x = 0; + void* handle = &x; + DLTensor a; + + Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.num_args == 3); + ICHECK(args.values[0].v_float64 == 1.0); + ICHECK(args.type_codes[0] == kDLFloat); + ICHECK(args.values[1].v_handle == &a); + ICHECK(args.type_codes[1] == kTVMDLTensorHandle); + ICHECK(args.values[2].v_handle == &x); + ICHECK(args.type_codes[2] == kTVMOpaqueHandle); + *rv = Var("a"); + })(1.0, &a, handle); + ICHECK(v->name_hint == "a"); +} + +TEST(PackedFunc, Node) { + using namespace tvm; + using namespace tvm::tir; + using namespace tvm::runtime; + Var x; + Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.num_args == 1); + ICHECK(args[0].IsObjectRef()); + Var b = args[0]; + ICHECK(x.same_as(b)); + *rv = b; + })(x); + ICHECK(t.same_as(x)); +} + +TEST(PackedFunc, NDArray) { + using namespace tvm; + using namespace tvm::runtime; + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); + reinterpret_cast(x->data)[0] = 10.0f; + ICHECK(x.use_count() == 1); + + PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); + + NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + NDArray y = args[0]; + DLTensor* ptr = args[0]; + ICHECK(ptr == x.operator->()); + ICHECK(x.same_as(y)); + ICHECK(x.use_count() == 2); + *rv = forward(y); + })(x); + ICHECK(ret.use_count() == 2); + ICHECK(ret.same_as(x)); +} + +TEST(PackedFunc, str) { + using namespace tvm; + using namespace tvm::runtime; + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.num_args == 1); + std::string x = args[0]; + ICHECK(x == "hello"); + String y = args[0]; + ICHECK(y == "hello"); + *rv = x; + })("hello"); + + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK(args.num_args == 1); + runtime::String s = args[0]; + ICHECK(s == "hello"); + })(runtime::String("hello")); +} + +TEST(PackedFunc, func) { + using namespace tvm; + using namespace tvm::runtime; + PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); + // function as arguments + int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); + ICHECK_EQ(r0, 2); + + int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + // TVMArgValue -> TVMRetValue + *rv = args[1]; + })(2, 100); + ICHECK_EQ(r1, 100); + + int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + // re-assignment + *rv = args[0]; + // TVMRetValue -> Function argument + *rv = addone(args[0].operator PackedFunc()(args[1], 1)); + })(addone, 100); + ICHECK_EQ(r2, 102); +} + +TEST(PackedFunc, Expr) { + using namespace tvm; + using namespace tvm::runtime; + // automatic conversion of int to expr + PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { + PrimExpr x = args[0]; + *rv = x.as()->value + 1; + }); + int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); + ICHECK_EQ(r0, 2); +} + +TEST(PackedFunc, Type) { + using namespace tvm; + using namespace tvm::runtime; + auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + DataType x = args[0]; + *rv = x; + }); + auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); + ICHECK(get_type("int32").operator DataType() == DataType::Int(32)); + ICHECK(get_type("float").operator DataType() == DataType::Float(32)); + ICHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2)); +} + +TEST(TypedPackedFunc, HighOrder) { + using namespace tvm; + using namespace tvm::runtime; + using Int1Func = TypedPackedFunc; + using Int2Func = TypedPackedFunc; + using BindFunc = TypedPackedFunc; + BindFunc ftyped; + ftyped = [](Int2Func f1, int value) -> Int1Func { + auto binded = [f1, value](int x) { return f1(value, x); }; + Int1Func x(binded); + return x; + }; + auto add = [](int x, int y) { return x + y; }; + ICHECK_EQ(ftyped(Int2Func(add), 1)(2), 3); + PackedFunc f = ftyped(Int2Func(add), 1); + ICHECK_EQ(f(3).operator int(), 4); + // call the type erased version. + Int1Func f1 = ftyped.packed()(Int2Func(add), 1); + ICHECK_EQ(f1(3), 4); +} + +TEST(TypedPackedFunc, Deduce) { + using namespace tvm::runtime; + using tvm::runtime::detail::function_signature; + + TypedPackedFunc x; + auto f = [](int x) -> int { return x + 1; }; + std::function y; + + static_assert(std::is_same::FType, int(float)>::value, + "invariant1"); + static_assert(std::is_same::FType, int(int)>::value, + "invariant2"); + static_assert(std::is_same::FType, void(float)>::value, + "invariant3"); +} + +TEST(PackedFunc, ObjectConversion) { + using namespace tvm; + using namespace tvm::tir; + using namespace tvm::runtime; + TVMRetValue rv; + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); + // assign null + rv = ObjectRef(); + ICHECK_EQ(rv.type_code(), kTVMNullptr); + + // Can assign NDArray to ret type + rv = x; + ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(x); + ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle); + // Check convert back + ICHECK(rv.operator NDArray().same_as(x)); + ICHECK(rv.operator ObjectRef().same_as(x)); + ICHECK(!rv.IsObjectRef()); + + auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); + ICHECK(args[0].operator NDArray().same_as(x)); + ICHECK(args[0].operator ObjectRef().same_as(x)); + ICHECK(args[1].operator ObjectRef().get() == nullptr); + ICHECK(args[1].operator NDArray().get() == nullptr); + ICHECK(args[1].operator Module().get() == nullptr); + ICHECK(args[1].operator Array().get() == nullptr); + ICHECK(!args[0].IsObjectRef()); + }); + pf1(x, ObjectRef()); + pf1(ObjectRef(x), NDArray()); + + // testcases for modules + auto* pf = tvm::runtime::Registry::Get("runtime.SourceModuleCreate"); + ICHECK(pf != nullptr); + Module m = (*pf)("", "xyz"); + rv = m; + ICHECK_EQ(rv.type_code(), kTVMModuleHandle); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(m); + ICHECK_EQ(rv.type_code(), kTVMModuleHandle); + // Check convert back + ICHECK(rv.operator Module().same_as(m)); + ICHECK(rv.operator ObjectRef().same_as(m)); + ICHECK(!rv.IsObjectRef()); + + auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args[0].type_code(), kTVMModuleHandle); + ICHECK(args[0].operator Module().same_as(m)); + ICHECK(args[0].operator ObjectRef().same_as(m)); + ICHECK(args[1].operator ObjectRef().get() == nullptr); + ICHECK(args[1].operator NDArray().get() == nullptr); + ICHECK(args[1].operator Module().get() == nullptr); + ICHECK(!args[0].IsObjectRef()); + }); + pf2(m, ObjectRef()); + pf2(ObjectRef(m), Module()); +} + +TEST(TypedPackedFunc, RValue) { + using namespace tvm; + using namespace tvm::runtime; + { + auto inspect = [](TVMArgs args, TVMRetValue* rv) { + for (int i = 0; i < args.size(); ++i) { + ICHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); + } + }; + PackedFunc finspect(inspect); + finspect(tir::Var("x")); + } + { + auto f = [](tir::Var x, bool move) { + if (move) { + ICHECK(x.unique()); + } else { + ICHECK(!x.unique()); + } + ICHECK(x->name_hint == "x"); + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + ICHECK(var.unique()); + tf(var, false); + // move the result to the function. + tir::Var ret = tf(std::move(var), true); + ICHECK(!var.defined()); + } + + { + // pass child class. + auto f = [](PrimExpr x, bool move) { + if (move) { + ICHECK(x.unique()); + } else { + ICHECK(!x.unique()); + } + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + ICHECK(var.unique()); + tf(var, false); + tf(std::move(var), true); + // auto conversion. + tf(1, true); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} From 83b7be3688cd9d792a0dad1ad05bbfe9f19781e0 Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 8 Feb 2021 15:58:22 +0300 Subject: [PATCH 16/27] [BNNS] Fix pylint. Second attempt Signed-off-by: Alexander Peskov --- python/tvm/relay/op/contrib/bnns.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index 68efceb2052d..6929728bcee5 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -24,12 +24,11 @@ import math import tvm.ir -from .register import register_pattern_table, get_pattern_table - from tvm.relay import transform from tvm.relay.expr import const from tvm.relay.build_module import bind_params_by_name +from .register import register_pattern_table, get_pattern_table from ...dataflow_pattern import wildcard, is_op, is_expr From 7ede1135e94085f00a093602f9a0e2a41d120e83 Mon Sep 17 00:00:00 2001 From: dlexplorer Date: Mon, 25 Jan 2021 23:02:46 +0300 Subject: [PATCH 17/27] [BNNS] Add integration documentation --- docs/deploy/bnns.rst | 184 ++++++++++++++++++++++++++++++++++++++++++ docs/deploy/index.rst | 1 + 2 files changed, 185 insertions(+) create mode 100644 docs/deploy/bnns.rst diff --git a/docs/deploy/bnns.rst b/docs/deploy/bnns.rst new file mode 100644 index 000000000000..8aef44e25075 --- /dev/null +++ b/docs/deploy/bnns.rst @@ -0,0 +1,184 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Relay BNNS Integration +========================== +**Author**: `Egor Churaev `_ + +Introduction +------------ + +Apple BNNS library is a collection of functions that can be used to construct neural networks +for inference (and train). It’s supported in macOS, iOS, tvOS, and watchOS. BNNS provides +primitives executed on all CPU supported on those platforms and optimized for high performance +and low-energy consumption. This integration will offload as many operators as possible from Relay to BNNS. + +BNNS runtime is a part of platform API and available on all modern Apple operating systems. +Application using BNNS will not depends on any additional external dependencies. + +BNNS functions uses Apple private hardware capabilities which are not exposed yet by Apple. Example +of such capabilities can be AMX Apple cpu extension. + +This guide will demonstrate how to build TVM with BNNS codegen and runtime enabled. It will also provide example +code to compile and run models using BNNS runtime. Finally, we document the supported operators. + +Building TVM with BNNS support +---------------------------------- + +To turn on TVM BNNS codegen and TVM BNNS runtime you need to turn on the only USE_BNNS flag + +* USE_BNNS=ON/OFF - This flag will enable compiling a network with offloading subgraphs to BNNS primitives + and will link tvm library to the BNNS runtime module. + +Enabling of this flag will cause to search the default Accelerate Frameworks on current target SDK. +The minimal versions of required SDK is macOS 11.0, iOS 14.0, tvOS 14.0 and watchOS 7.0. + +Example setting in config.cmake file: + +.. code:: cmake + + set(USE_BNNS ON) + +BNNS partitioning of Relay graph +---------------------------------------- + +Operations to be offloaded on BNNS execution must be annotated before passing of module for compilation. +All opps annotated by `partition_for_bnns` will be offloaded for BNNS execution. The rest of the ops +will go through the LLVM compilation and code generation. + +Important note: BNNS support primitives only with constant weights. To satisfy this requirements we have +to map constants to related tensor abstraction in relay representation. To freeze tensors and operate +with them as with constants you may need to call ONNX importer with special flag "freeze_params=True" +or performer binding manually. In general cases all relay importers don't do that by default. +For your convenience "partition_for_bnns" can do this for you if params dictionary is passed as the argument. + +.. code:: python + + from tvm.relay.op.contrib.bnns import partition_for_bnns + with tvm.transform.PassContext(opt_level=3): + model = partition_for_bnns(model, params=params) + + +Input data layout for operations to be offloaded to BNNS execution +---------------------------------------- + +BNNS kernels support only planar format of input data. The partitioner will require to have NCHW input +layout for conv2d input. + +To use BNNS integration for models with interleave input layout, they should be converted before +passing of module to `partition_for_bnns`. The layout conversion will happen only for explicitly +enumerated types of ops. It might happen that depending on topology there might be regular data reorder +around conv2d to interleave and planar layout. This will be reflected in performance penalties and affect +execution time. It is recommended to analyze the whole topology and extend below list to convert all +intermediate tensors to NCHW data layout. + +Example of input layouts change: + +.. code:: python + + # For models with NHWC input layout + with tvm.transform.PassContext(opt_level=3): + mod = relay.transform.InferType()(mod) + mod = relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"], + "nn.bias_add": ["NCHW", "default"], + "nn.relu": ["NCHW"]})(mod) + + +Example: Build and Deploy Mobilenet v2 1.0 with BNNS +---------------------------------------- + +Create a Relay graph from a MXNet Mobilenet v2 1.0 model. + +.. code:: python + + import tvm + from tvm import relay + import mxnet + from mxnet.gluon.model_zoo.vision import get_model + + dtype = "float32" + input_shape = (1, 3, 224, 224) + block = get_model('mobilenetv2_1.0', pretrained=True) + module, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + + +Markup the parts of graphs to be offloaded to BNNS primitives. All ops which are supported by the BNNS +integration will be handled by BNNS invocations, the rest of the ops will go through the +regular TVM llvm compilation and code generation. + +After that you need to compile new module with target corresponding to required Apple platform + +.. code:: python + + from tvm.relay.op.contrib.bnns import partition_for_bnns + + # target for macOS Big Sur 11.1: + target = "llvm -mtriple=x86_64-apple-darwin20.2.0" + + with tvm.transform.PassContext(opt_level=3): + model = partition_for_bnns(model, params=params) # to markup operations to be offloaded to BNNS + lib = relay.build(model, target=target, target_host=target, params=params) + +Export the module. + +.. code:: python + + lib.export_library('compiled.dylib') + + +Load module and run inference on the target machine with TVM built with ``USE_BNNS`` enabled + +.. code:: python + + import tvm + import numpy as np + from tvm.contrib import graph_runtime + + ctx = tvm.cpu(0) + loaded_lib = tvm.runtime.load_module('compiled.dylib') + gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib['default'](ctx)) + + dtype = "float32" + input_shape = (1, 3, 224, 224) + input_data = np.random.uniform(0, 1, input_shape).astype(dtype) + gen_module.run(data=input_data) + + + +Operator support +---------------- + ++------------------------+------------------------------------------------------------------------------+ +| Relay Node | Remarks | ++========================+==============================================================================+ +| nn.conv2d | | ++------------------------+------------------------------------------------------------------------------+ +| nn.batch_norm | Supported by BNNS integration only in nn.conv2d-batch_norm pattern | ++------------------------+------------------------------------------------------------------------------+ +| nn.dense | | ++------------------------+------------------------------------------------------------------------------+ +| nn.batch_matmul | | ++------------------------+------------------------------------------------------------------------------+ +| nn.bias_add | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense | +| | fusion | ++------------------------+------------------------------------------------------------------------------+ +| add | Supported by BNNS integration only as a bias part of nn.conv2d or nn.dense fusion | ++------------------------+------------------------------------------------------------------------------+ +| nn.relu | Supported by BNNS integration only as a part of nn.conv2d or nn.dense fusion | ++------------------------+------------------------------------------------------------------------------+ +| nn.gelu | Supported by BNNS integration only as a part of nn.conv2d or nn.dense fusion | ++------------------------+------------------------------------------------------------------------------+ diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index 2b37f734c3c3..3cbbb10bd74b 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -71,3 +71,4 @@ target device without relying on RPC. see the following resources on how to do s arm_compute_lib tensorrt vitis_ai + bnns From 17478963e6f3604768680517cfe5d04302a08bac Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Thu, 25 Feb 2021 00:41:22 +0300 Subject: [PATCH 18/27] Check onnx import before use Signed-off-by: Alexander Peskov --- tests/python/contrib/test_bnns/test_onnx_topologies.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_bnns/test_onnx_topologies.py b/tests/python/contrib/test_bnns/test_onnx_topologies.py index 241dac70eb9a..87885cb8f5da 100644 --- a/tests/python/contrib/test_bnns/test_onnx_topologies.py +++ b/tests/python/contrib/test_bnns/test_onnx_topologies.py @@ -16,6 +16,8 @@ # under the License. """BNNS pattern detection check""" +import pytest + import tvm from tvm import relay from tvm.relay import transform @@ -23,9 +25,8 @@ from tvm.contrib.download import download_testdata from tvm.relay.op.contrib.bnns import partition_for_bnns -import onnx import numpy as np -import pytest +pytest.importorskip("onnx") bnns_is_absent = tvm.get_global_func("relay.ext.bnns", True) is None From 1a39265d784d0f46c1196a4e21e133204189ae42 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Mon, 15 Feb 2021 16:11:27 +0300 Subject: [PATCH 19/27] [BNNS] Add instance normalization operator --- python/tvm/relay/op/contrib/bnns.py | 17 +- src/runtime/contrib/bnns/bnns_json_runtime.cc | 68 +++++- src/runtime/contrib/bnns/bnns_wrp.h | 125 +++++++++-- .../contrib/test_bnns/test_normalization.py | 208 ++++++++++++++++++ 4 files changed, 398 insertions(+), 20 deletions(-) create mode 100644 tests/python/contrib/test_bnns/test_normalization.py diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index 6929728bcee5..ddffcb74638c 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -205,13 +205,28 @@ def make_dense_bias_gelu_pattern(): def check_dense(extract): - """Check conv pattern is supported by BNNS.""" + """Check dense pattern is supported by BNNS.""" call = extract while call.op.name != "nn.dense": call = call.args[0] return dense(call) +@tvm.ir.register_op_attr("nn.instance_norm", "target.bnns") +def instance_norm_check(expr): + """Check if the nn.instance_norm can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + rank = len(data_typ.shape) + if rank < 3 or rank > 4 or data_typ.dtype != "float32": + return False + if not isinstance(args[1], tvm.relay.expr.Constant) or not isinstance(args[2], tvm.relay.expr.Constant): + return False + if attrs.axis == 0 and rank == 3 or attrs.axis == 1 and rank == 4: + return True + return False + + @register_pattern_table("bnns") def pattern_table(): """Get BNNS specific fusing patterns collection""" diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 8c07f3cc02ff..6c3e0ccb9ec2 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -180,6 +180,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { Dense(nid, true, true); } else if ("nn.batch_matmul" == op_name) { MatMul(nid); + } else if ("nn.instance_norm" == op_name) { + InstanceNormalization(nid); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -363,10 +365,74 @@ class BNNSJSONRuntime : public JSONRuntimeBase { primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } else { primitives_.emplace_back( - std::make_shared(filters, a_view, b_view, dst_view)); + std::make_shared(filters, a_view, b_view, dst_view)); } } + void InstanceNormalization(const size_t& nid) { + auto node = nodes_[nid]; + size_t axis = std::stoi(node.GetAttr>("axis")[0]); + float epsilon = std::stof(node.GetAttr>("epsilon")[0]); + bool center = std::stoi(node.GetAttr>("center")[0]); + bool scale = std::stoi(node.GetAttr>("scale")[0]); + + // Setup attributes. + auto src_entry = node.GetInputs()[0]; + auto scale_entry = node.GetInputs()[1]; + auto bias_entry = node.GetInputs()[2]; + auto dst_entry = JSONGraphNodeEntry(nid, 0); + + // Memory descriptions. + auto src_t = GetBNNSTensor(src_entry); + auto scale_t = GetBNNSTensor(scale_entry); + auto bias_t = GetBNNSTensor(bias_entry); + auto dst_t = GetBNNSTensor(dst_entry); + + auto src_view = TView::as_is(src_t); + auto dst_view = TView::as_is(dst_t); + size_t src_rank = Tensor::getRank(src_view.get_bnns_view()); + size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view()); + ICHECK_EQ(src_rank, dst_rank); + ICHECK_LE(src_rank, 4); + if (src_rank < 4) { + src_view = src_view.unsqueeze(4); + dst_view = dst_view.unsqueeze(4); + } + src_view = src_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + auto scale_view = TView::as_is(scale_t).with_layout(BNNSDataLayoutVector); + auto bias_view = TView::as_is(bias_t).with_layout(BNNSDataLayoutVector); + BNNSActivation activation = {BNNSActivationFunctionIdentity}; + + auto b_desc = bias_view.get_bnns_view(); + if (!center) + b_desc = {}; + auto s_desc = scale_view.get_bnns_view(); + if (!scale) + s_desc = {}; + + // NOTE: Axis option is ignored in BNNS. The result doesn't depends on value of axis. + BNNSLayerParametersNormalization layerParameters = {src_view.get_bnns_view(), // i_desc + dst_view.get_bnns_view(), // o_desc + b_desc, // beta_desc + s_desc, // gamma_desc + {}, // moving_mean_desc + {}, // moving_variance_desc + 1.f, // momentum + epsilon, // epsilon + activation, // activation + 1, // num_groups + axis}; // normalization_axis + + BNNSFilterType filter_type = BNNSInstanceNorm; + auto common_filter_param = getCommonFilterParams(); + auto filter = BNNSFilterCreateLayerNormalization(filter_type, &layerParameters, &common_filter_param); + ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + + std::vector filters{filter}; + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); + } + BNNS::Dtype convertToBNNS(const DLDataType& dl_dtype) { if (dl_dtype.code == DLDataTypeCode::kDLFloat) { if (dl_dtype.bits == 32) return BNNSDataTypeFloat32; diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 722b74cf4699..c2b147b859e2 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -195,6 +195,37 @@ class TView { return res; } + /** Expand the shape of an array */ + TView expand_dims(std::vector axes) const { + auto rank = Tensor::getRank(view_desc_); + TView res = *this; + size_t unsqueezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; + size_t unsqueezed_rank = axes.size() + rank; + ICHECK_LE(unsqueezed_rank, BNNS_MAX_TENSOR_DIMENSION); + for (const auto& axis : axes) { + ICHECK_LT(axis, unsqueezed_rank); + unsqueezed_shape[axis] = 1; + } + for (int i = 0, orig_idx = 0; i < unsqueezed_rank; ++i) { + if (unsqueezed_shape[i] == 1) + continue; + unsqueezed_shape[i] = view_desc_.size[orig_idx++]; + } + std::copy(unsqueezed_shape, unsqueezed_shape + unsqueezed_rank, res.view_desc_.size); + res.view_desc_.layout = Tensor::getPlainLayout(unsqueezed_rank); + return res; + } + + /** Unsqueeze tensor to a new rank */ + TView unsqueeze(size_t new_rank) const { + ICHECK_LE(new_rank, BNNS_MAX_TENSOR_DIMENSION); + auto rank = Tensor::getRank(view_desc_); + ICHECK_GT(new_rank, rank); + std::vector axes(new_rank - rank); + std::iota(axes.begin(), axes.end(), rank); + return expand_dims(axes); + } + /** Construct new TView with specified layout if it applicable */ TView with_layout(BNNSDataLayout layout) const { ICHECK_EQ(Tensor::getRank(view_desc_), Tensor::getRank(layout)); @@ -285,12 +316,9 @@ class TView { class Primitive { public: Primitive(const std::vector fs, const TView& src, const TView& dst) - : filters(fs), src_view(src), src2_view(), dst_view(dst) {} + : filters(fs), src_view(src), dst_view(dst) {} - Primitive(const std::vector fs, const TView& src, const TView& src2, const TView& dst) - : filters(fs), src_view(src), src2_view(src2), dst_view(dst) {} - - ~Primitive() { + virtual ~Primitive() { for (auto& filter : filters) if (filter) { BNNSFilterDestroy(filter); @@ -299,7 +327,7 @@ class Primitive { } /** Execute primitive with using specified src/dst */ - void execute() { + virtual void execute() { auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; } @@ -310,8 +338,6 @@ class Primitive { const auto filter = prim->filters[task_id]; const auto src_view = prim->src_view[task_id]; const auto dst_view = prim->dst_view[task_id]; - TView src2_view; - if (prim->src2_view) src2_view = prim->src2_view[task_id]; size_t mb = src_view.get_batch_size(); @@ -320,22 +346,85 @@ class Primitive { // BNNSFilterApply doesn't work for grouped convolution. // * Group convolution doesn't support arbitrary stride for Batch dim. // The tensor should be dense. - auto sts = - (prim->src2_view) - ? BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), - src_view.get_stride(), src2_view.get_data_hdl(), - src2_view.get_stride(), dst_view.get_data_hdl(), - dst_view.get_stride()) - : BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), - dst_view.get_data_hdl(), dst_view.get_stride()); + auto sts = BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()); return sts; } - private: + protected: /** BNNS kernels/filters collect which will execute primitive */ std::vector filters = {}; - const TView src_view, src2_view; + const TView src_view; const TView dst_view; + bool isNormalization; +}; + +/** + * Wrapper on top of BNNS::Primitive + * + * This primitive should be used for executing primitive with two inputs. + */ +class TwoInputPrimitive : public Primitive { + public: + TwoInputPrimitive(const std::vector fs, const TView& src, const TView& src2, + const TView& dst) + : Primitive(fs, src, dst), src2_view(src2) {} + + /** Execute primitive with using specified src/dst */ + void execute() override { + auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); + ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + } + + private: + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto prim = reinterpret_cast(cdata); + const auto filter = prim->filters[task_id]; + const auto src_view = prim->src_view[task_id]; + TView src2_view = prim->src2_view[task_id]; + const auto dst_view = prim->dst_view[task_id]; + + size_t mb = src_view.get_batch_size(); + + auto sts = BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), + src_view.get_stride(), src2_view.get_data_hdl(), + src2_view.get_stride(), dst_view.get_data_hdl(), + dst_view.get_stride()); + return sts; + } + protected: + const TView src2_view; +}; + +/** + * Wrapper on top of BNNS::Primitive + * + * This primitive should be used for executing normalization filter + */ +class NormPrimitive : public Primitive { + public: + NormPrimitive(const std::vector fs, const TView& src, const TView& dst) + : Primitive(fs, src, dst) {} + + /** Execute primitive with using specified src/dst */ + void execute() override { + auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); + ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + } + + private: + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto prim = reinterpret_cast(cdata); + const auto filter = prim->filters[task_id]; + const auto src_view = prim->src_view[task_id]; + const auto dst_view = prim->dst_view[task_id]; + + size_t mb = src_view.get_batch_size(); + auto sts = BNNSNormalizationFilterApplyBatch(filter, mb, src_view.get_data_hdl(), + src_view.get_stride(), dst_view.get_data_hdl(), + dst_view.get_stride(), false); + return sts; + } }; /** diff --git a/tests/python/contrib/test_bnns/test_normalization.py b/tests/python/contrib/test_bnns/test_normalization.py new file mode 100644 index 000000000000..263996e42ba1 --- /dev/null +++ b/tests/python/contrib/test_bnns/test_normalization.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration normalization tests.""" + +import numpy as np +import math + +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + Device, + skip_runtime_test, + skip_codegen_test, + verify_codegen, + build_and_run, + verify, + generate_trials, +) + + +def _get_model(shape, b_shape, s_shape, dtype, var_names, axis=1, epsilon=1e-5, center=True, scale=True): + """Return a model and any parameters it may have""" + src = relay.var(next(var_names), shape=shape, dtype=dtype) + params = {} + b = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params["b"] = b + b = relay.const(b, dtype) + s = tvm.nd.array(np.random.uniform(-128, 127, b_shape).astype(dtype)) + params["b"] = s + s = relay.const(s, dtype) + out = relay.nn.instance_norm(src, s, b, axis, epsilon, center, scale) + + return out, params + + +def _get_expected_codegen(shape, axis, center, scale, dtype, offload_on_bnns): + output_shape = shape + name = "nn.instance_norm" + + node = { + "op": "kernel", + "name": name, + "inputs": [], + "attrs": { + "num_outputs": "1", + "axis": [[str(axis)]], + "center": [[str(int(center))]], + "scale": [[str(int(scale))]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "epsilon": [["1.0000000000000001e-05"]], + }, + } + + inputs = [ + {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}}, + { + "op": "const", + "name": "", + "attrs": {"shape": [[[shape[axis]]]], "dtype": [[str(dtype)]]}, + }, + { + "op": "const", + "name": "", + "attrs": {"shape": [[[shape[axis]]]], "dtype": [[str(dtype)]]}, + }, + ] + + input_idx = 0 + for _ in range(len(inputs)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + node["attrs"]["num_inputs"] = str(len(inputs)) + inputs.append(node) + return inputs + + +def test_normalization(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + dtype = "float32" + + shapes_config = [ + [1, 2, 3, 4], + [3, 2, 3, 4], + [2, 2, 3], + [16, 32, 32], + [5, 3], + ] + axes = [-1, 0, 1, 2] + + for shape in shapes_config: + for axis in axes: + if len(shape) == 2 and axis != 0: + continue + for center in [False, True]: + for scale in [False, True]: + outputs = [] + inputs = { + "src": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)), + } + func, params = _get_model( + shape, + [shape[axis]], + [shape[axis]], + dtype, + var_names=iter(inputs), + axis=axis, + center=center, + scale=scale + ) + for enable_bnns in [False, True]: + print('shape: ', shape) + print('axis: ', axis) + print('bnns_!: ', enable_bnns) + outputs.append( + build_and_run( + func, + inputs, + 1, + params, + device, + enable_bnns=enable_bnns, + )[0] + ) + + config = { + "dtype": dtype, + } + verify(outputs, atol=0.001, rtol=0.01, config=config) + + +def test_codegen_normalization(): + if skip_codegen_test(): + return + + np.random.seed(0) + + dtype = "float32" + shapes_config = [ + [1, 2, 3, 4], + [3, 2, 3, 4], + [2, 2, 3], + [16, 32, 32], + [5, 3], + ] + axes = [-1, 0, 1, 2] + + def check_normalization(rank, axis): + if rank < 3 or rank > 4: + return False + if axis == 0 and rank == 3 or axis == 1 and rank == 4: + return True + return False + + for shape in shapes_config: + for axis in axes: + if len(shape) == 2 and axis != 0: + continue + for center in [False, True]: + for scale in [False, True]: + inputs = {"src"} + + args = (shape, axis, center, scale, dtype) + + func, params = _get_model( + shape, + [shape[axis]], + [shape[axis]], + dtype, + var_names=iter(inputs), + axis=axis, + center=center, + scale=scale + ) + + offload_on_bnns = check_normalization(len(shape), axis) + if offload_on_bnns is True: + bnns_blocks = 1 + else: + bnns_blocks = 0 + exp_codegen = _get_expected_codegen(*args, offload_on_bnns) + verify_codegen(func, exp_codegen, bnns_blocks) + + +if __name__ == "__main__": + test_normalization() + test_codegen_normalization() + + From b70a89ba7db021679191b933fd3181fc66fe07d7 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Sat, 20 Feb 2021 13:12:35 +0300 Subject: [PATCH 20/27] Add fusing sigmoid activation after conv2d --- python/tvm/relay/op/contrib/bnns.py | 26 +++++++++++++++---- src/relay/backend/contrib/bnns/codegen.cc | 7 +++++ src/runtime/contrib/bnns/bnns_json_runtime.cc | 19 +++++++++----- tests/python/contrib/test_bnns/test_conv2d.py | 6 ++++- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index ddffcb74638c..049544a41bdc 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -146,6 +146,8 @@ def dense(expr): data_typ = args[0].checked_type if data_typ.dtype != "float32": return False + if not isinstance(args[1], tvm.relay.expr.Constant): + return False kernel_typ = args[1].checked_type if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32": return False @@ -154,7 +156,7 @@ def dense(expr): return True -def make_conv_relu_pattern(with_bias=True, with_relu=True): +def make_conv_pattern(with_bias=True, activation="none"): """Make pattern for bnns.conv2d primitive""" data = wildcard() weight = wildcard() @@ -162,8 +164,10 @@ def make_conv_relu_pattern(with_bias=True, with_relu=True): pat = is_op("nn.conv2d")(data, weight) if with_bias: pat = is_op("add")(pat, bias) | is_op("nn.bias_add")(pat, bias) - if with_relu: + if activation == "relu": pat = is_op("nn.relu")(pat) + elif activation == "sigmoid": + pat = is_op("sigmoid")(pat) return pat @@ -232,17 +236,27 @@ def pattern_table(): """Get BNNS specific fusing patterns collection""" conv2d_bias_pat = ( "bnns.conv2d_bias", - make_conv_relu_pattern(with_bias=True, with_relu=False), + make_conv_pattern(with_bias=True), check_conv, ) conv2d_bias_relu_pat = ( "bnns.conv2d_bias_relu", - make_conv_relu_pattern(with_bias=True, with_relu=True), + make_conv_pattern(with_bias=True, activation="relu"), check_conv, ) conv2d_relu_pat = ( "bnns.conv2d_relu", - make_conv_relu_pattern(with_bias=False, with_relu=True), + make_conv_pattern(with_bias=False, activation="relu"), + check_conv, + ) + conv2d_bias_sigmoid_pat = ( + "bnns.conv2d_bias_sigmoid", + make_conv_pattern(with_bias=True, activation="sigmoid"), + check_conv, + ) + conv2d_sigmoid_pat = ( + "bnns.conv2d_sigmoid", + make_conv_pattern(with_bias=False, activation="sigmoid"), check_conv, ) dense_bias_gelu = ("bnns.dense_bias_gelu", make_dense_bias_gelu_pattern(), check_dense) @@ -250,6 +264,8 @@ def pattern_table(): bnns_patterns = [ conv2d_bias_relu_pat, conv2d_relu_pat, + conv2d_bias_sigmoid_pat, + conv2d_sigmoid_pat, conv2d_bias_pat, dense_bias_gelu, dense_bias, diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 40cba3073e9a..72c32fb5b19e 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -94,6 +94,13 @@ class BNNSJSONSerializer : public backend::contrib::JSONSerializer { } else if (name == "bnns.conv2d_relu") { call = GetRootCall(body, 1, {"nn.conv2d", "nn.relu"}); ICHECK(call->op.as()) << "Not op node"; + } else if (name == "bnns.conv2d_bias_sigmoid") { + auto add_op_type = IsOp(body->args[0].as(), "add") ? "add" : "nn.bias_add"; + call = GetRootCall(body, 2, {"nn.conv2d", add_op_type, "sigmoid"}); + ICHECK(call->op.as()) << "Not op node"; + } else if (name == "bnns.conv2d_sigmoid") { + call = GetRootCall(body, 1, {"nn.conv2d", "sigmoid"}); + ICHECK(call->op.as()) << "Not op node"; } else if (name == "bnns.dense_bias") { call = GetRootCall(fn->body.as(), 1, {"nn.dense", "add"}); } else if (name == "bnns.dense_bias_gelu") { diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 6c3e0ccb9ec2..6b260acbee0d 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -167,11 +167,15 @@ class BNNSJSONRuntime : public JSONRuntimeBase { if ("nn.conv2d" == op_name) { Conv2d(nid); } else if ("bnns.conv2d_relu" == op_name) { - Conv2d(nid, true, false); + Conv2d(nid, false, "relu"); } else if ("bnns.conv2d_bias_relu" == op_name) { - Conv2d(nid, true, true); + Conv2d(nid, true, "relu"); + } else if ("bnns.conv2d_sigmoid" == op_name) { + Conv2d(nid, false, "sigmoid"); + } else if ("bnns.conv2d_bias_sigmoid" == op_name) { + Conv2d(nid, true, "sigmoid"); } else if ("bnns.conv2d_bias" == op_name) { - Conv2d(nid, false, true); + Conv2d(nid, true); } else if ("nn.dense" == op_name) { Dense(nid); } else if ("bnns.dense_bias" == op_name) { @@ -196,7 +200,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { return tensors_eid_[eid]; } - void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) { + void Conv2d(const size_t& nid, const bool has_bias = false, const std::string activation_type = "none") { auto node = nodes_[nid]; // Setup attributes. @@ -239,8 +243,11 @@ class BNNSJSONRuntime : public JSONRuntimeBase { bias_view = TView::as_is(bias_t).squeeze().with_layout(BNNSDataLayoutVector); } - BNNSActivation activation = {has_relu ? BNNSActivationFunctionRectifiedLinear - : BNNSActivationFunctionIdentity}; + BNNSActivation activation = {BNNSActivationFunctionIdentity}; + if (activation_type == "relu") + activation = {BNNSActivationFunctionRectifiedLinear}; + else if (activation_type == "sigmoid") + activation = {BNNSActivationFunctionSigmoid}; BNNSLayerParametersConvolution conv_param = { src_view.get_bnns_view(), diff --git a/tests/python/contrib/test_bnns/test_conv2d.py b/tests/python/contrib/test_bnns/test_conv2d.py index a0507b6ad798..886958cf3076 100644 --- a/tests/python/contrib/test_bnns/test_conv2d.py +++ b/tests/python/contrib/test_bnns/test_conv2d.py @@ -78,6 +78,8 @@ def _get_model( if activation_type == "relu": out = relay.nn.relu(out) + elif activation_type == "sigmoid": + out = relay.op.sigmoid(out) return out, params @@ -95,7 +97,7 @@ def test_conv2d(): batches = [1, 2] groups = [1, 2] bias_kind = ["none", "add_3d", "add_4d", "bias.add"] - activation_kind = ["none", "relu"] + activation_kind = ["none", "relu", "sigmoid"] trials = generate_trials( [ kernel_hs, @@ -126,6 +128,8 @@ def test_conv2d(): bias, activation, ) in trials: + if out_channels % group != 0: + continue func, params = _get_model( shape=(batch, *input_shapes), kernel=(kernel_h, kernel_w), From 9d5945a045c7ea0ed736d7070f6ede680eae159b Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Thu, 25 Feb 2021 00:42:07 +0300 Subject: [PATCH 21/27] min changes Signed-off-by: Alexander Peskov --- src/runtime/contrib/bnns/bnns_wrp.h | 67 +++++++++++------------------ 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index c2b147b859e2..8ff386343f44 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -327,17 +327,16 @@ class Primitive { } /** Execute primitive with using specified src/dst */ - virtual void execute() { + void execute() { auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; } private: - static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { - auto prim = reinterpret_cast(cdata); - const auto filter = prim->filters[task_id]; - const auto src_view = prim->src_view[task_id]; - const auto dst_view = prim->dst_view[task_id]; + virtual int execute_impl(int part_idx) { + const auto filter = this->filters[part_idx]; + const auto src_view = this->src_view[part_idx]; + const auto dst_view = this->dst_view[part_idx]; size_t mb = src_view.get_batch_size(); @@ -346,9 +345,13 @@ class Primitive { // BNNSFilterApply doesn't work for grouped convolution. // * Group convolution doesn't support arbitrary stride for Batch dim. // The tensor should be dense. - auto sts = BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), - dst_view.get_data_hdl(), dst_view.get_stride()); - return sts; + return BNNSFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()); + } + + static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + auto prim = reinterpret_cast(cdata); + return prim->execute_impl(task_id); } protected: @@ -356,7 +359,6 @@ class Primitive { std::vector filters = {}; const TView src_view; const TView dst_view; - bool isNormalization; }; /** @@ -370,27 +372,19 @@ class TwoInputPrimitive : public Primitive { const TView& dst) : Primitive(fs, src, dst), src2_view(src2) {} - /** Execute primitive with using specified src/dst */ - void execute() override { - auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); - ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; - } - private: - static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { - auto prim = reinterpret_cast(cdata); - const auto filter = prim->filters[task_id]; - const auto src_view = prim->src_view[task_id]; - TView src2_view = prim->src2_view[task_id]; - const auto dst_view = prim->dst_view[task_id]; + int execute_impl(int task_id) override { + const auto filter = this->filters[task_id]; + const auto src_view = this->src_view[task_id]; + const auto src2_view = this->src2_view[task_id]; + const auto dst_view = this->dst_view[task_id]; size_t mb = src_view.get_batch_size(); - auto sts = BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), + return BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), src2_view.get_data_hdl(), src2_view.get_stride(), dst_view.get_data_hdl(), dst_view.get_stride()); - return sts; } protected: const TView src2_view; @@ -403,27 +397,18 @@ class TwoInputPrimitive : public Primitive { */ class NormPrimitive : public Primitive { public: - NormPrimitive(const std::vector fs, const TView& src, const TView& dst) - : Primitive(fs, src, dst) {} - - /** Execute primitive with using specified src/dst */ - void execute() override { - auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); - ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; - } + using Primitive::Primitive; private: - static int run_task(int task_id, TVMParallelGroupEnv* penv, void* cdata) { - auto prim = reinterpret_cast(cdata); - const auto filter = prim->filters[task_id]; - const auto src_view = prim->src_view[task_id]; - const auto dst_view = prim->dst_view[task_id]; + int execute_impl(int task_id) { + const auto filter = this->filters[task_id]; + const auto src_view = this->src_view[task_id]; + const auto dst_view = this->dst_view[task_id]; size_t mb = src_view.get_batch_size(); - auto sts = BNNSNormalizationFilterApplyBatch(filter, mb, src_view.get_data_hdl(), - src_view.get_stride(), dst_view.get_data_hdl(), - dst_view.get_stride(), false); - return sts; + return BNNSNormalizationFilterApplyBatch(filter, mb, src_view.get_data_hdl(), + src_view.get_stride(), dst_view.get_data_hdl(), + dst_view.get_stride(), false); } }; From 30e2c0f1dd64c31a76a417da36a9dcc53bd2d484 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 24 Feb 2021 17:54:05 +0300 Subject: [PATCH 22/27] Add pooling operations to BNNS runtime Supports `nn.max_pool2d`, `nn.avg_pool2d`, `nn.global_max_pool2d` and `nn.global_avg_pool2d` operations --- python/tvm/relay/op/contrib/bnns.py | 52 ++++ src/runtime/contrib/bnns/bnns_json_runtime.cc | 81 +++++ src/runtime/contrib/bnns/bnns_wrp.h | 23 +- .../python/contrib/test_bnns/test_pooling.py | 293 ++++++++++++++++++ 4 files changed, 448 insertions(+), 1 deletion(-) create mode 100644 tests/python/contrib/test_bnns/test_pooling.py diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index 049544a41bdc..187fdda95e21 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -98,6 +98,58 @@ def _func_wrapper(expr): _register_external_op_helper("nn.batch_matmul") +@tvm.ir.register_op_attr("nn.max_pool2d", "target.bnns") +def max_pool2d_check(expr): + """Check if the nn.max_pool2d can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + rank = len(data_typ.shape) + if rank < 3 or rank > 4 or data_typ.dtype != "float32": + return False + if attrs.layout != "NCHW": + return False + return True + + +@tvm.ir.register_op_attr("nn.avg_pool2d", "target.bnns") +def avg_pool2d_check(expr): + """Check if the nn.avg_pool2d can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + rank = len(data_typ.shape) + if rank < 3 or rank > 4 or data_typ.dtype != "float32": + return False + if attrs.layout != "NCHW": + return False + return True + + +@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.bnns") +def global_max_pool2d_check(expr): + """Check if the nn.global_max_pool2d can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + rank = len(data_typ.shape) + if rank < 3 or rank > 4 or data_typ.dtype != "float32": + return False + if attrs.layout != "NCHW": + return False + return True + + +@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.bnns") +def global_avg_pool2d_check(expr): + """Check if the nn.global_avg_pool2d can be executed in BNNS""" + attrs, args = expr.attrs, expr.args + data_typ = args[0].checked_type + rank = len(data_typ.shape) + if rank < 3 or rank > 4 or data_typ.dtype != "float32": + return False + if attrs.layout != "NCHW": + return False + return True + + def dtype_is_supported(dtype): """Check if data type is supported by BNNS backend""" return dtype in ("", "float32") diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 6b260acbee0d..271f3fa73d48 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -186,6 +186,14 @@ class BNNSJSONRuntime : public JSONRuntimeBase { MatMul(nid); } else if ("nn.instance_norm" == op_name) { InstanceNormalization(nid); + } else if ("nn.max_pool2d" == op_name) { + Pooling(nid, false); + } else if ("nn.avg_pool2d" == op_name) { + Pooling(nid, true); + } else if ("nn.global_max_pool2d" == op_name) { + Pooling(nid, false, true); + } else if ("nn.global_avg_pool2d" == op_name) { + Pooling(nid, true, true); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -440,6 +448,79 @@ class BNNSJSONRuntime : public JSONRuntimeBase { primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } + void Pooling(const size_t& nid, bool avg_pooling, bool global = false) { + auto node = nodes_[nid]; + + auto src_entry = node.GetInputs()[0]; + auto dst_entry = JSONGraphNodeEntry(nid, 0); + + // Memory descriptions. + auto src_t = GetBNNSTensor(src_entry); + auto dst_t = GetBNNSTensor(dst_entry); + + auto src_view = TView::as_is(src_t); + auto dst_view = TView::as_is(dst_t); + size_t src_rank = Tensor::getRank(src_view.get_bnns_view()); + size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view()); + ICHECK_EQ(src_rank, dst_rank); + ICHECK_LE(src_rank, 4); + if (src_rank < 4) { + src_view = src_view.unsqueeze(4); + dst_view = dst_view.unsqueeze(4); + } + src_view = src_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); + BNNSActivation activation = {BNNSActivationFunctionIdentity}; + BNNSPoolingFunction pf = {BNNSPoolingFunctionMax}; + if (avg_pooling) + pf = {BNNSPoolingFunctionAverageCountExcludePadding}; + + // Setup attributes. + size_t k_height = 0; + size_t k_width = 0; + size_t y_padding = 0; + size_t x_padding = 0; + size_t y_stride = 1; + size_t x_stride = 1; + if (!global) { + std::vector pool_size = node.GetAttr>("pool_size"); + std::vector padding = node.GetAttr>("padding"); + std::vector strides = node.GetAttr>("strides"); + k_height = std::stoi(pool_size[0]); + k_width = std::stoi(pool_size[1]); + y_padding = std::stoi(padding[0]); + x_padding = std::stoi(padding[1]); + y_stride = std::stoi(strides[0]); + x_stride = std::stoi(strides[1]); + } else { + auto sv = src_view.get_bnns_view(); + k_height = sv.size[1]; + k_width = sv.size[0]; + } + + BNNSLayerParametersPooling layerParameters = {src_view.get_bnns_view(), // i_desc + dst_view.get_bnns_view(), // o_desc + {}, // bias + activation, // activation + pf, // pooling_function + k_width, // k_width + k_height, // k_height + x_stride, // x_stride + y_stride, // y_stride + 0, // x_dilation_stride + 0, // y_dilation_stride + x_padding, // x_padding + y_padding, // y_padding + {}}; // pad left, right, up, down padding + + auto common_filter_param = getCommonFilterParams(); + auto filter = BNNSFilterCreateLayerPooling(&layerParameters, &common_filter_param); + ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + + std::vector filters{filter}; + primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); + } + BNNS::Dtype convertToBNNS(const DLDataType& dl_dtype) { if (dl_dtype.code == DLDataTypeCode::kDLFloat) { if (dl_dtype.bits == 32) return BNNSDataTypeFloat32; diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 8ff386343f44..59276edd17dd 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -400,7 +400,7 @@ class NormPrimitive : public Primitive { using Primitive::Primitive; private: - int execute_impl(int task_id) { + int execute_impl(int task_id) override { const auto filter = this->filters[task_id]; const auto src_view = this->src_view[task_id]; const auto dst_view = this->dst_view[task_id]; @@ -412,6 +412,27 @@ class NormPrimitive : public Primitive { } }; +/** + * Wrapper on top of BNNS::Primitive + * + * This primitive should be used for executing pooling filter + */ +class PoolingPrimitive : public Primitive { + public: + using Primitive::Primitive; + + private: + int execute_impl(int task_id) override { + const auto filter = this->filters[task_id]; + const auto src_view = this->src_view[task_id]; + const auto dst_view = this->dst_view[task_id]; + + size_t mb = src_view.get_batch_size(); + return BNNSPoolingFilterApplyBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride(), nullptr, 0); + } +}; + /** * Function which split primitive into sub primitives to parallel execution * diff --git a/tests/python/contrib/test_bnns/test_pooling.py b/tests/python/contrib/test_bnns/test_pooling.py new file mode 100644 index 000000000000..5eff82ae440d --- /dev/null +++ b/tests/python/contrib/test_bnns/test_pooling.py @@ -0,0 +1,293 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BNNS integration pooling tests.""" + +import numpy as np + +import tvm +from tvm import relay +from tvm import testing +from .infrastructure import ( + skip_runtime_test, + skip_codegen_test, + build_and_run, + verify, + verify_codegen, +) +from .infrastructure import Device + + +def _calculate_output_shape(shape, sizes, padding, strides): + """Calculate pooling output shape.""" + output_height = ((shape[2] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1 + output_width = ((shape[3] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1 + return 1, shape[1], int(output_height), int(output_width) + + +def _get_pooling_model( + shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad, var_names +): + """Return a model and any parameters it may have.""" + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + out = relay.var(next(var_names), shape=shape, dtype=dtype) + + if typef == "nn.max_pool2d": + out = relay.nn.max_pool2d( + out, + pool_size=sizes, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + elif typef == "nn.avg_pool2d": + out = relay.nn.avg_pool2d( + out, + pool_size=sizes, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + raise ValueError("Function not supported") + + return out + + +def _get_global_pooling_model(shape, dtype, typef, var_names): + """Return a model and any parameters it may have.""" + out = relay.var(next(var_names), shape=shape, dtype=dtype) + + if typef == "nn.global_max_pool2d": + out = relay.nn.global_max_pool2d(out) + elif typef == "nn.global_avg_pool2d": + out = relay.nn.global_avg_pool2d(out) + else: + raise ValueError("Function not supported") + + return out + + +def _get_expected_pooling_codegen( + shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad +): + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + output_shape = _calculate_output_shape(shape, sizes, padding, strides) + + node = { + "op": "kernel", + "name": typef, + "inputs": [[0, 0, 0]], + "attrs": { + "num_inputs": "1", + "num_outputs": "1", + "layout": [["NCHW"]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "padding": [[str(p) for p in padding]], + "strides": [[str(s) for s in strides]], + "pool_size": [[str(s) for s in sizes]], + "ceil_mode": [[str(1 if ceil_mode else 0)]], + }, + } + + if typef == "nn.avg_pool2d" or typef == "nn.l2_pool2d": + node["attrs"]["count_include_pad"] = [["1" if count_include_pad else "0"]] + + input = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}} + return [input, node] + + +def _get_expected_global_pooling_codegen(shape, dtype, typef): + node = { + "op": "kernel", + "name": typef, + "inputs": [[0, 0, 0]], + "attrs": { + "num_inputs": "1", + "num_outputs": "1", + "layout": [["NCHW"]], + "shape": [[[1, shape[1], 1, 1]]], + "dtype": [[dtype]], + }, + } + + input = {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}} + return [input, node] + + +def test_pooling(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + dtype = "float32" + trials = [ + ["nn.max_pool2d", (3, 3), (2, 2), (0, 0), False, False, (27, 27, 512)], + ["nn.max_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.max_pool2d", (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.avg_pool2d", (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ] + + for ( + typef, + size, + stride, + pad, + ceil_mode, + count_include_pad, + input_shape, + ) in trials: + shape = (1, *input_shape) + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(-127, 128, shape).astype(dtype)), + } + + func = _get_pooling_model( + shape, dtype, typef, size, stride, pad, ceil_mode, count_include_pad, iter(inputs) + ) + + config = { + "size": size, + "stride": stride, + "shape": shape, + "pooling type": typef, + "dtype": dtype, + "padding": pad, + "ceil_mode": ceil_mode, + "count_include_pad": count_include_pad, + "inputs": inputs, + } + + params = None + for enable_bnns in [False, True]: + outputs.append( + build_and_run(func, inputs, 1, params, device, enable_bnns=enable_bnns, config=config)[0] + ) + + verify(outputs, atol=0.001, rtol=0.001, config=config) + + +def test_global_pooling(): + if skip_runtime_test(): + return + + device = Device() + np.random.seed(0) + + dtype = "float32" + + trials = [ + ["nn.global_max_pool2d", (8, 8, 16)], + ["nn.global_max_pool2d", (9, 9, 16)], + ["nn.global_max_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (9, 9, 16)], + ] + + for typef, input_shape in trials: + shape = (1, *input_shape) + outputs = [] + inputs = { + "a": tvm.nd.array(np.random.uniform(-127, 128, shape).astype(dtype)), + } + + func = _get_global_pooling_model(shape, dtype, typef, iter(inputs)) + config = { + "shape": shape, + "pooling type": typef, + "dtype": dtype, + } + + for enable_bnns in [False, True]: + outputs.append( + build_and_run(func, inputs, 1, None, device, enable_bnns=enable_bnns, config=config)[0] + ) + + verify(outputs, atol=0.001, rtol=0.001, config=config) + + +def test_codegen_pooling(): + if skip_codegen_test(): + return + + dtype = "float32" + + trials = [ + ["nn.max_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.max_pool2d", (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], + ["nn.avg_pool2d", (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], + ] + + for ( + typef, + size, + stride, + pad, + ceil_mode, + count_include_pad, + input_shape, + ) in trials: + shape = (1, *input_shape) + inputs = {"a"} + args = (shape, dtype, typef, size, stride, pad, False, False) + func = _get_pooling_model(*args, iter(inputs)) + exp_codegen = _get_expected_pooling_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +def test_codegen_global_pooling(): + if skip_codegen_test(): + return + + dtype = "float32" + + trials = [ + ["nn.global_max_pool2d", (8, 8, 16)], + ["nn.global_max_pool2d", (9, 9, 16)], + ["nn.global_max_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (8, 8, 16)], + ["nn.global_avg_pool2d", (9, 9, 16)], + ] + + for typef, input_shape in trials: + shape = (1, *input_shape) + inputs = {"a"} + args = (shape, dtype, typef) + func = _get_global_pooling_model(*args, iter(inputs)) + exp_codegen = _get_expected_global_pooling_codegen(*args) + verify_codegen(func, exp_codegen, 1) + + +if __name__ == "__main__": + test_pooling() + test_global_pooling() + test_codegen_pooling() + test_codegen_global_pooling() From c7f370586fa7b9815860eb6f1931cc11822a0399 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 25 Feb 2021 09:02:24 +0300 Subject: [PATCH 23/27] Fix lint --- tests/python/contrib/test_bnns/test_onnx_topologies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_bnns/test_onnx_topologies.py b/tests/python/contrib/test_bnns/test_onnx_topologies.py index 87885cb8f5da..86f98eb6e8de 100644 --- a/tests/python/contrib/test_bnns/test_onnx_topologies.py +++ b/tests/python/contrib/test_bnns/test_onnx_topologies.py @@ -26,6 +26,7 @@ from tvm.relay.op.contrib.bnns import partition_for_bnns import numpy as np + pytest.importorskip("onnx") bnns_is_absent = tvm.get_global_func("relay.ext.bnns", True) is None From d444846ea7ad9916b8ae21096d6ac85348c78786 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 25 Feb 2021 09:43:44 +0300 Subject: [PATCH 24/27] Fix lint --- python/tvm/relay/op/contrib/bnns.py | 4 ++- src/runtime/contrib/bnns/bnns_json_runtime.cc | 25 +++++++++---------- src/runtime/contrib/bnns/bnns_wrp.h | 13 +++++----- .../contrib/test_bnns/test_normalization.py | 13 ++++------ .../python/contrib/test_bnns/test_pooling.py | 8 ++++-- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/op/contrib/bnns.py b/python/tvm/relay/op/contrib/bnns.py index 187fdda95e21..2ace502e6528 100644 --- a/python/tvm/relay/op/contrib/bnns.py +++ b/python/tvm/relay/op/contrib/bnns.py @@ -276,7 +276,9 @@ def instance_norm_check(expr): rank = len(data_typ.shape) if rank < 3 or rank > 4 or data_typ.dtype != "float32": return False - if not isinstance(args[1], tvm.relay.expr.Constant) or not isinstance(args[2], tvm.relay.expr.Constant): + if not isinstance(args[1], tvm.relay.expr.Constant) or not isinstance( + args[2], tvm.relay.expr.Constant + ): return False if attrs.axis == 0 and rank == 3 or attrs.axis == 1 and rank == 4: return True diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 271f3fa73d48..87b01567cd30 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -208,7 +208,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { return tensors_eid_[eid]; } - void Conv2d(const size_t& nid, const bool has_bias = false, const std::string activation_type = "none") { + void Conv2d(const size_t& nid, const bool has_bias = false, + const std::string activation_type = "none") { auto node = nodes_[nid]; // Setup attributes. @@ -253,9 +254,9 @@ class BNNSJSONRuntime : public JSONRuntimeBase { BNNSActivation activation = {BNNSActivationFunctionIdentity}; if (activation_type == "relu") - activation = {BNNSActivationFunctionRectifiedLinear}; + activation = {BNNSActivationFunctionRectifiedLinear}; else if (activation_type == "sigmoid") - activation = {BNNSActivationFunctionSigmoid}; + activation = {BNNSActivationFunctionSigmoid}; BNNSLayerParametersConvolution conv_param = { src_view.get_bnns_view(), @@ -410,8 +411,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { ICHECK_EQ(src_rank, dst_rank); ICHECK_LE(src_rank, 4); if (src_rank < 4) { - src_view = src_view.unsqueeze(4); - dst_view = dst_view.unsqueeze(4); + src_view = src_view.unsqueeze(4); + dst_view = dst_view.unsqueeze(4); } src_view = src_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); @@ -420,11 +421,9 @@ class BNNSJSONRuntime : public JSONRuntimeBase { BNNSActivation activation = {BNNSActivationFunctionIdentity}; auto b_desc = bias_view.get_bnns_view(); - if (!center) - b_desc = {}; + if (!center) b_desc = {}; auto s_desc = scale_view.get_bnns_view(); - if (!scale) - s_desc = {}; + if (!scale) s_desc = {}; // NOTE: Axis option is ignored in BNNS. The result doesn't depends on value of axis. BNNSLayerParametersNormalization layerParameters = {src_view.get_bnns_view(), // i_desc @@ -441,7 +440,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { BNNSFilterType filter_type = BNNSInstanceNorm; auto common_filter_param = getCommonFilterParams(); - auto filter = BNNSFilterCreateLayerNormalization(filter_type, &layerParameters, &common_filter_param); + auto filter = + BNNSFilterCreateLayerNormalization(filter_type, &layerParameters, &common_filter_param); ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters{filter}; @@ -472,8 +472,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { dst_view = dst_view.extract_outer_dim().with_layout(BNNSDataLayoutImageCHW); BNNSActivation activation = {BNNSActivationFunctionIdentity}; BNNSPoolingFunction pf = {BNNSPoolingFunctionMax}; - if (avg_pooling) - pf = {BNNSPoolingFunctionAverageCountExcludePadding}; + if (avg_pooling) pf = {BNNSPoolingFunctionAverageCountExcludePadding}; // Setup attributes. size_t k_height = 0; @@ -511,7 +510,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { 0, // y_dilation_stride x_padding, // x_padding y_padding, // y_padding - {}}; // pad left, right, up, down padding + {}}; // pad left, right, up, down padding auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerPooling(&layerParameters, &common_filter_param); diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 59276edd17dd..b31e97e554da 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -207,9 +207,8 @@ class TView { unsqueezed_shape[axis] = 1; } for (int i = 0, orig_idx = 0; i < unsqueezed_rank; ++i) { - if (unsqueezed_shape[i] == 1) - continue; - unsqueezed_shape[i] = view_desc_.size[orig_idx++]; + if (unsqueezed_shape[i] == 1) continue; + unsqueezed_shape[i] = view_desc_.size[orig_idx++]; } std::copy(unsqueezed_shape, unsqueezed_shape + unsqueezed_rank, res.view_desc_.size); res.view_desc_.layout = Tensor::getPlainLayout(unsqueezed_rank); @@ -381,11 +380,11 @@ class TwoInputPrimitive : public Primitive { size_t mb = src_view.get_batch_size(); - return BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), - src_view.get_stride(), src2_view.get_data_hdl(), - src2_view.get_stride(), dst_view.get_data_hdl(), - dst_view.get_stride()); + return BNNSFilterApplyTwoInputBatch(filter, mb, src_view.get_data_hdl(), src_view.get_stride(), + src2_view.get_data_hdl(), src2_view.get_stride(), + dst_view.get_data_hdl(), dst_view.get_stride()); } + protected: const TView src2_view; }; diff --git a/tests/python/contrib/test_bnns/test_normalization.py b/tests/python/contrib/test_bnns/test_normalization.py index 263996e42ba1..9fffd077a0dd 100644 --- a/tests/python/contrib/test_bnns/test_normalization.py +++ b/tests/python/contrib/test_bnns/test_normalization.py @@ -33,7 +33,9 @@ ) -def _get_model(shape, b_shape, s_shape, dtype, var_names, axis=1, epsilon=1e-5, center=True, scale=True): +def _get_model( + shape, b_shape, s_shape, dtype, var_names, axis=1, epsilon=1e-5, center=True, scale=True +): """Return a model and any parameters it may have""" src = relay.var(next(var_names), shape=shape, dtype=dtype) params = {} @@ -125,12 +127,9 @@ def test_normalization(): var_names=iter(inputs), axis=axis, center=center, - scale=scale + scale=scale, ) for enable_bnns in [False, True]: - print('shape: ', shape) - print('axis: ', axis) - print('bnns_!: ', enable_bnns) outputs.append( build_and_run( func, @@ -189,7 +188,7 @@ def check_normalization(rank, axis): var_names=iter(inputs), axis=axis, center=center, - scale=scale + scale=scale, ) offload_on_bnns = check_normalization(len(shape), axis) @@ -204,5 +203,3 @@ def check_normalization(rank, axis): if __name__ == "__main__": test_normalization() test_codegen_normalization() - - diff --git a/tests/python/contrib/test_bnns/test_pooling.py b/tests/python/contrib/test_bnns/test_pooling.py index 5eff82ae440d..5c3ff7f14f5f 100644 --- a/tests/python/contrib/test_bnns/test_pooling.py +++ b/tests/python/contrib/test_bnns/test_pooling.py @@ -184,7 +184,9 @@ def test_pooling(): params = None for enable_bnns in [False, True]: outputs.append( - build_and_run(func, inputs, 1, params, device, enable_bnns=enable_bnns, config=config)[0] + build_and_run( + func, inputs, 1, params, device, enable_bnns=enable_bnns, config=config + )[0] ) verify(outputs, atol=0.001, rtol=0.001, config=config) @@ -224,7 +226,9 @@ def test_global_pooling(): for enable_bnns in [False, True]: outputs.append( - build_and_run(func, inputs, 1, None, device, enable_bnns=enable_bnns, config=config)[0] + build_and_run( + func, inputs, 1, None, device, enable_bnns=enable_bnns, config=config + )[0] ) verify(outputs, atol=0.001, rtol=0.001, config=config) From 81479db404c8cb9bdc573b2e3952b09fd13022a3 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 4 Mar 2021 15:41:42 +0300 Subject: [PATCH 25/27] Apply comments --- docs/deploy/bnns.rst | 5 ++--- python/tvm/driver/tvmc/composite_target.py | 5 +++++ tests/python/contrib/test_bnns/test_dense.py | 10 +++------- tests/python/contrib/test_bnns/test_matmul.py | 6 ++---- .../contrib/test_bnns/test_normalization.py | 10 +++------- tests/python/contrib/test_bnns/test_pooling.py | 18 +++++------------- 6 files changed, 20 insertions(+), 34 deletions(-) diff --git a/docs/deploy/bnns.rst b/docs/deploy/bnns.rst index 8aef44e25075..a966da244dfd 100644 --- a/docs/deploy/bnns.rst +++ b/docs/deploy/bnns.rst @@ -69,8 +69,7 @@ For your convenience "partition_for_bnns" can do this for you if params dictiona .. code:: python from tvm.relay.op.contrib.bnns import partition_for_bnns - with tvm.transform.PassContext(opt_level=3): - model = partition_for_bnns(model, params=params) + model = partition_for_bnns(model, params=params) Input data layout for operations to be offloaded to BNNS execution @@ -129,8 +128,8 @@ After that you need to compile new module with target corresponding to required # target for macOS Big Sur 11.1: target = "llvm -mtriple=x86_64-apple-darwin20.2.0" + model = partition_for_bnns(model, params=params) # to markup operations to be offloaded to BNNS with tvm.transform.PassContext(opt_level=3): - model = partition_for_bnns(model, params=params) # to markup operations to be offloaded to BNNS lib = relay.build(model, target=target, target_host=target, params=params) Export the module. diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 0a2592685646..886160ad000c 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -21,6 +21,7 @@ from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib from tvm.relay.op.contrib.ethosn import partition_for_ethosn +from tvm.relay.op.contrib.bnns import partition_for_bnns from .common import TVMCException @@ -40,6 +41,10 @@ "config_key": "relay.ext.ethos-n.options", "pass_pipeline": partition_for_ethosn, }, + "bnns": { + "config_key": None, + "pass_pipeline": partition_for_bnns, + }, } diff --git a/tests/python/contrib/test_bnns/test_dense.py b/tests/python/contrib/test_bnns/test_dense.py index 60280b1daea6..c2cf9bf71373 100644 --- a/tests/python/contrib/test_bnns/test_dense.py +++ b/tests/python/contrib/test_bnns/test_dense.py @@ -18,7 +18,7 @@ import numpy as np import math - +import pytest import tvm from tvm import relay from .infrastructure import ( @@ -107,10 +107,8 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False, has return inputs +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_dense(): - if skip_runtime_test(): - return - device = Device() np.random.seed(0) @@ -160,10 +158,8 @@ def test_dense(): verify(outputs, atol=0.001, rtol=0.01, config=config) +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") def test_codegen_dense(): - if skip_codegen_test(): - return - np.random.seed(0) dtype = ["float32"] diff --git a/tests/python/contrib/test_bnns/test_matmul.py b/tests/python/contrib/test_bnns/test_matmul.py index 408cd3b1dc6c..7bf4d48f8e88 100644 --- a/tests/python/contrib/test_bnns/test_matmul.py +++ b/tests/python/contrib/test_bnns/test_matmul.py @@ -18,7 +18,7 @@ import numpy as np import math - +import pytest import tvm from tvm import relay from tvm import testing @@ -50,10 +50,8 @@ def _get_model(a_shape, b_shape, dtype, var_names, is_a_constant=False, is_b_con return out, params +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_matmul(): - if skip_runtime_test(): - return - device = Device() np.random.seed(0) dtype = "float32" diff --git a/tests/python/contrib/test_bnns/test_normalization.py b/tests/python/contrib/test_bnns/test_normalization.py index 9fffd077a0dd..094cfb041c3c 100644 --- a/tests/python/contrib/test_bnns/test_normalization.py +++ b/tests/python/contrib/test_bnns/test_normalization.py @@ -18,7 +18,7 @@ import numpy as np import math - +import pytest import tvm from tvm import relay from tvm import testing @@ -92,10 +92,8 @@ def _get_expected_codegen(shape, axis, center, scale, dtype, offload_on_bnns): return inputs +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_normalization(): - if skip_runtime_test(): - return - device = Device() np.random.seed(0) dtype = "float32" @@ -147,10 +145,8 @@ def test_normalization(): verify(outputs, atol=0.001, rtol=0.01, config=config) +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") def test_codegen_normalization(): - if skip_codegen_test(): - return - np.random.seed(0) dtype = "float32" diff --git a/tests/python/contrib/test_bnns/test_pooling.py b/tests/python/contrib/test_bnns/test_pooling.py index 5c3ff7f14f5f..77a78d4bf7e1 100644 --- a/tests/python/contrib/test_bnns/test_pooling.py +++ b/tests/python/contrib/test_bnns/test_pooling.py @@ -17,7 +17,7 @@ """BNNS integration pooling tests.""" import numpy as np - +import pytest import tvm from tvm import relay from tvm import testing @@ -132,10 +132,8 @@ def _get_expected_global_pooling_codegen(shape, dtype, typef): return [input, node] +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_pooling(): - if skip_runtime_test(): - return - device = Device() np.random.seed(0) @@ -192,10 +190,8 @@ def test_pooling(): verify(outputs, atol=0.001, rtol=0.001, config=config) +@pytest.mark.skipif(skip_runtime_test(), reason="Skip because BNNS codegen is not available") def test_global_pooling(): - if skip_runtime_test(): - return - device = Device() np.random.seed(0) @@ -234,10 +230,8 @@ def test_global_pooling(): verify(outputs, atol=0.001, rtol=0.001, config=config) +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") def test_codegen_pooling(): - if skip_codegen_test(): - return - dtype = "float32" trials = [ @@ -266,10 +260,8 @@ def test_codegen_pooling(): verify_codegen(func, exp_codegen, 1) +@pytest.mark.skipif(skip_codegen_test(), reason="Skip because BNNS codegen is not available") def test_codegen_global_pooling(): - if skip_codegen_test(): - return - dtype = "float32" trials = [ From eaffeab7560ecf9e9723d1ecbc21362f3d139172 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Sat, 6 Mar 2021 09:33:23 +0300 Subject: [PATCH 26/27] Fix documentation --- docs/deploy/bnns.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/deploy/bnns.rst b/docs/deploy/bnns.rst index a966da244dfd..cb15a4f3bd54 100644 --- a/docs/deploy/bnns.rst +++ b/docs/deploy/bnns.rst @@ -16,7 +16,7 @@ under the License. Relay BNNS Integration -========================== +====================== **Author**: `Egor Churaev `_ Introduction @@ -37,7 +37,7 @@ This guide will demonstrate how to build TVM with BNNS codegen and runtime enabl code to compile and run models using BNNS runtime. Finally, we document the supported operators. Building TVM with BNNS support ----------------------------------- +------------------------------ To turn on TVM BNNS codegen and TVM BNNS runtime you need to turn on the only USE_BNNS flag @@ -54,15 +54,15 @@ Example setting in config.cmake file: set(USE_BNNS ON) BNNS partitioning of Relay graph ----------------------------------------- +-------------------------------- Operations to be offloaded on BNNS execution must be annotated before passing of module for compilation. -All opps annotated by `partition_for_bnns` will be offloaded for BNNS execution. The rest of the ops +All ops annotated by `partition_for_bnns` will be offloaded for BNNS execution. The rest of the ops will go through the LLVM compilation and code generation. Important note: BNNS support primitives only with constant weights. To satisfy this requirements we have to map constants to related tensor abstraction in relay representation. To freeze tensors and operate -with them as with constants you may need to call ONNX importer with special flag "freeze_params=True" +with them as constants you may need to call ONNX importer with special flag "freeze_params=True" or performer binding manually. In general cases all relay importers don't do that by default. For your convenience "partition_for_bnns" can do this for you if params dictionary is passed as the argument. @@ -73,7 +73,7 @@ For your convenience "partition_for_bnns" can do this for you if params dictiona Input data layout for operations to be offloaded to BNNS execution ----------------------------------------- +------------------------------------------------------------------ BNNS kernels support only planar format of input data. The partitioner will require to have NCHW input layout for conv2d input. @@ -98,7 +98,7 @@ Example of input layouts change: Example: Build and Deploy Mobilenet v2 1.0 with BNNS ----------------------------------------- +---------------------------------------------------- Create a Relay graph from a MXNet Mobilenet v2 1.0 model. From 0103e70f02954b18dc0ab95e51dfa3e83668d122 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 10 Mar 2021 07:49:45 +0300 Subject: [PATCH 27/27] Fix comment to refer to BNNS --- tests/python/contrib/test_bnns/infrastructure.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index e407e21fd868..0107de54a04f 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -37,9 +37,9 @@ class Device: Common device configuration for python tests. Check tests/python/contrib/arm_compute_lib/ for the presence of an test_config.json file. - This file can be used to override the default configuration here which will attempt to run the Arm - Compute Library runtime tests locally if the runtime is available. Changing the configuration - will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example. + This file can be used to override the default configuration here which will attempt to run the BNNS + runtime tests locally if the runtime is available. Changing the configuration will allow these + runtime tests to be offloaded to a remote device with BNNS via a tracker for example. Notes ----- @@ -98,7 +98,7 @@ def load(cls, file_name): """Load test config Load the test configuration by looking for file_name relative - to the test_arm_compute_lib directory. + to the test_bnns directory. """ location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) config_file = os.path.join(location, file_name)