diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py
index 13c808f96b95..693352d650ba 100644
--- a/python/tvm/relay/op/strategy/hexagon.py
+++ b/python/tvm/relay/op/strategy/hexagon.py
@@ -30,7 +30,7 @@ def batch_matmul_strategy_hexagon(attrs, inputs, out_type, target):
     """batch_matmul strategy for Hexagon"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.nn.batch_matmul, need_out_dtype=True),
         wrap_topi_schedule(topi.hexagon.schedule_batch_matmul),
         name="batch_matmul.hexagon",
     )
@@ -187,3 +187,38 @@ def schedule_reduce_hexagon(attrs, outs, target):
     """Schedule reduction ops for Hexagon"""
     with target:
         return topi.hexagon.schedule_reduce(outs)
+
+
+@conv2d_NCHWc_strategy.register("hexagon")
+def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target):
+    """conv2d_NCHWc hexagon strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(
+            topi.hexagon.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+        ),
+        wrap_topi_schedule(topi.hexagon.schedule_conv2d_NCHWc_int8),
+        name="conv2d_NCHWc_int8.hexagon",
+    )
+    return strategy
+
+
+@dense_pack_strategy.register("hexagon")
+def dense_pack_strategy_hexagon(attrs, inputs, out_type, target):
+    """dense_pack hexagon strategy"""
+    strategy = _op.OpStrategy()
+
+    if (
+        inputs[0].dtype == "uint8"
+        and inputs[1].dtype == "uint8"
+        and out_type.dtype == "int32"
+        and attrs["weight_layout"] == "NC32n4c"
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.hexagon.dense.dense_u8u8i32_vrmpy_compute),
+            wrap_topi_schedule(topi.hexagon.dense.dense_u8u8i32_vrmpy_schedule),
+            name="dense_uint8.hexagon",
+            plevel=12,
+        )
+
+    return strategy
diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 48b2a2f97146..76cd9a7d69d1 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -139,7 +139,16 @@ def schedule_conv_NCHWc_cpu_common_int8(
     More details - https://software.intel.com/en-us/articles/
     lower-numerical-precision-deep-learning-inference-and-training
     """
-    reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
+    if isinstance(cfg["tile_ow"], int):
+        reg_n = cfg["tile_ow"]
+    else:
+        reg_n = cfg["tile_ow"].size[-1]
+
+    if isinstance(cfg["unroll_kw"], (int, bool)):
+        unroll_kw = cfg["unroll_kw"]
+    else:
+        unroll_kw = cfg["unroll_kw"].val
+
     _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
diff --git a/python/tvm/topi/hexagon/__init__.py b/python/tvm/topi/hexagon/__init__.py
index 295152d11631..b94526e5b919 100644
--- a/python/tvm/topi/hexagon/__init__.py
+++ b/python/tvm/topi/hexagon/__init__.py
@@ -29,3 +29,5 @@
 from .resize2d import *
 from .tensor_intrin import *
 from .qnn import *
+from .dense_alter_op import *
+from .conv2d_alter_op import *
diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py
index d8f44d663843..aa1b7e57e464 100644
--- a/python/tvm/topi/hexagon/conv2d.py
+++ b/python/tvm/topi/hexagon/conv2d.py
@@ -14,11 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name
 """Schedule for conv2d"""
 
 import tvm
+from tvm import te
+from .. import nn
 from ..utils import traverse_inline
+from .tensor_intrin import dot_vrmpy
+from ..generic import conv2d as conv2d_generic
 
 
 def schedule_conv2d_nhwc(outs):
@@ -86,3 +90,46 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+def conv2d_NCHWc_int8(
+    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+):
+    """Compute definition for int8 conv2d in NCHWc layout"""
+    n_elems = int(kernel.shape[-1])
+    return nn.conv2d_NCHWc_int8(
+        data, kernel, stride, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
+    )
+
+
+def schedule_conv2d_NCHWc_int8(outs):
+    """Schedule for int8 conv2d in NCHWc layout using vrmpy tensorization"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "conv2d_NCHWc_int8" in op.tag:
+            conv_out = op.output(0)
+            kernel_vec = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            out_width = conv_out.shape[3]
+
+            reg_n = 1
+            for n in range(31, 0, -1):
+                if out_width % n == 0:
+                    reg_n = n
+                    break
+
+            cfg = {"tile_ow": reg_n, "unroll_kw": False}
+            args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
+            intrin = dot_vrmpy(data_vec.dtype, kernel_vec.dtype)
+
+            conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
+                *args,
+                int32_lanes=32,
+                int8_elems=4,
+                intrin=intrin,
+                inline_fused=True,
+            )
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py
new file mode 100644
index 000000000000..201b6f804352
--- /dev/null
+++ b/python/tvm/topi/hexagon/conv2d_alter_op.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2d alter op functions for Hexagon"""
+
+from tvm import relay
+from ..utils import get_const_tuple
+from .. import nn
+from ..nn import conv2d_alter_layout
+from ..generic.conv2d import conv2d_alter_int8_common
+
+
+@conv2d_alter_layout.register("hexagon")
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    """Convert nn.conv2d into nn.contrib_conv2d_nchwc if vrmpy is applicable."""
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
+    out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+    if (
+        "int8" in data_tensor.dtype
+        and "int8" in kernel_tensor.dtype
+        and out_channel % 32 == 0
+        and in_channel % 4 == 0
+        and data_layout == "NCHW"
+        and kernel_layout == "OIHW"
+    ):
+        out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+        n_elems = 4
+        oc_bn = 32
+        ic_bn = min(in_channel, 32)
+
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+        new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+
+        return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
+
+    return None
+
+
+@nn.conv2d_legalize.register("hexagon")
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalize conv2d op for vrmpy tensorization.
+
+    If the inputs are signed or unsigned int8, the input and output channels are padded to be
+    a multiple of 4 and 32 respectively.
+
+    If the input data types are (int8, int8), they are converted to (uint8, int8) and
+    the vector-by-vector variant of vrmpy is applied.
+    If the input data types are (uint8, uint8), the more efficient vector-by-scalar variant of
+    vrmpy is applied.
+
+    Unlike the nn.dense case (see dense_alter_op.py), we do not convert (uint8, int8) to
+    (uint8, uint8). That would introduce another convolution by a constant (128 or 1) filter,
+    to compensate for the dtype legalization. In the nn.dense case, such a compensation factor
+    is just a sum over the K axis.
+    """
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+
+    output_tensor = arg_types[2]
+
+    data, kernel = inputs
+
+    if data_layout != "NCHW" or kernel_layout != "OIHW":
+        return None
+
+    data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+
+    if "int8" in data_tensor.dtype and "int8" in kernel_tensor.dtype:
+        output_tensor = arg_types[2]
+        data, kernel = inputs
+        desired_data_dtype = "uint8"
+        in_channel_vector_length = 4
+        out_channel_vector_length = 32
+
+        return conv2d_alter_int8_common(
+            data,
+            data_tensor,
+            kernel,
+            kernel_tensor,
+            output_tensor,
+            attrs,
+            desired_data_dtype,
+            in_channel_vector_length,
+            out_channel_vector_length,
+        )
+
+    return None
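Note on the "OIHW8i32o4i" kernel layout requested above (with ic_bn = 32, oc_bn = 32, n_elems = 4): the sub-axis nesting is hard to read from the layout string alone. A minimal NumPy sketch of the equivalent repacking, assuming O and I are multiples of 32; the function name and shapes are illustrative, not part of this patch:

import numpy as np

def pack_kernel_oihw_8i32o4i(kernel, ic_bn=32, oc_bn=32, n_elems=4):
    # Split O into (O // oc_bn, oc_bn) and I into (I // ic_bn, ic_bn // n_elems, n_elems),
    # then order the axes as (O_outer, I_outer, H, W, i_outer, o_inner, i_inner).
    O, I, H, W = kernel.shape
    k = kernel.reshape(O // oc_bn, oc_bn, I // ic_bn, ic_bn // n_elems, n_elems, H, W)
    return k.transpose(0, 2, 5, 6, 3, 1, 4)

packed = pack_kernel_oihw_8i32o4i(np.zeros((256, 64, 3, 3), dtype="int8"))
print(packed.shape)  # (8, 2, 3, 3, 8, 32, 4)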
diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py
index afe53f515fa9..02ad141ecb5a 100644
--- a/python/tvm/topi/hexagon/dense.py
+++ b/python/tvm/topi/hexagon/dense.py
@@ -14,10 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name
 """Schedule for dense operator"""
 
 import tvm
+from tvm.topi.utils import traverse_inline
+from tvm import te
+from .. import tag
+from .tensor_intrin import dot_vrmpy
 
 
 def schedule_dense(outs):
@@ -38,3 +42,70 @@ def schedule_dense(outs):
     s = tvm.te.create_schedule([x.op for x in outs])
     tvm.te.schedule.AutoInlineInjective(s)
     return s
+
+
+def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype):
+    """Compute for uint8 x uint8 -> int32 dense using vrmpy"""
+    assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32"
+    m, k = X.shape
+    n_o, _, n_i, _ = packed_w.shape
+    assert n_i == 32
+    ak = te.reduce_axis((0, k), name="k")
+
+    C = te.compute(
+        (m, n_o * n_i),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_w[tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4].astype(
+                "int32"
+            ),
+            axis=ak,
+        ),
+        tag="dense_u8u8i32_vrmpy",
+        name="compute",
+    )
+
+    if bias is not None:
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
+
+    return C
+
+
+def dense_u8u8i32_vrmpy_schedule(outs):
+    """Schedule for vrmpy dense"""
+    s = te.create_schedule([x.op for x in outs])
+    # O: The output of the fused op
+    O = outs[0]
+
+    def _schedule_dense(s, C, O):
+        (a_k,) = C.op.reduce_axis
+        a_y = C.op.axis[-2]
+        a_yo, a_yi = s[C].split(a_y, factor=32)
+        a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32)
+        a_ko, a_ki = s[C].split(a_k, factor=4)
+
+        s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
+
+        pc = dot_vrmpy("uint8", "uint8")
+        s[C].tensorize(a_xi, pc)
+        s[C].parallel(s[C].fuse(a_yo, a_xo))
+
+        if C != O:
+            a_y = O.op.axis[-2]
+            a_yo, a_yi = s[O].split(a_y, factor=32)
+            a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32)
+
+            s[O].reorder(a_yo, a_xo, a_yi, a_xi)
+            s[O].vectorize(a_xi)
+            s[C].compute_at(s[O], a_yi)
+            s[O].parallel(s[O].fuse(a_yo, a_xo))
+
+    def _callback(op):
+        if "u8u8i32_vrmpy" in op.tag:
+            # C: The output of GEMM
+            C = op.output(0)
+            _schedule_dense(s, C, O)
+
+    traverse_inline(s, outs[0].op, _callback)
+
+    return s
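Note on the NC32n4c weight layout consumed by dense_u8u8i32_vrmpy_compute above: the compute indexes the packed weight as packed_w[j // 32, k // 4, j % 32, k % 4], so a row-major (N, K) weight is repacked as below. A minimal NumPy sketch; the names are illustrative:

import numpy as np

def pack_weight_nc32n4c(weight):
    # packed[n_o, k_o, n_i, k_i] == weight[n_o * 32 + n_i, k_o * 4 + k_i]
    N, K = weight.shape
    assert N % 32 == 0 and K % 4 == 0
    return weight.reshape(N // 32, 32, K // 4, 4).transpose(0, 2, 1, 3)

w = np.arange(64 * 8, dtype="int32").reshape(64, 8)
packed = pack_weight_nc32n4c(w)
assert packed[1, 1, 3, 2] == w[35, 6]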
diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py
new file mode 100644
index 000000000000..cb5feb56d68e
--- /dev/null
+++ b/python/tvm/topi/hexagon/dense_alter_op.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Dense alter op functions for Hexagon"""
+
+import tvm
+from tvm import relay
+from .. import nn
+from ..nn import dense_alter_layout
+
+
+def check_vrmpy_applicable(x, y):
+    return (
+        "int8" in x.dtype and "int8" in y.dtype and y.shape[-2] % 32 == 0 and y.shape[-1] % 4 == 0
+    )
+
+
+@dense_alter_layout.register(["hexagon"])
+def _alter_dense_layout(attrs, inputs, tinfos, out_type):
+    data_tensor, weight_tensor = tinfos
+    out_dtype = out_type.dtype
+
+    if check_vrmpy_applicable(data_tensor, weight_tensor):
+        weight_layout = "NC32n4c"
+        return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
+    else:
+        return None
+
+
+def vrmpy_legalize(x, w, arg_types, op, attrs):
+    """
+    Legalizes int8 inputs to dense for vrmpy.
+    X'_u8 = X_s8 + 128
+    X_s8 * W_s8 = (X'_u8 - 128) * (W'_u8 - 128)
+                = X'_u8 * W'_u8 - X'_u8 * 128 - 128 * W'_u8 + 128 * 128
+    X_u8 * W_s8 = X_u8 * (W'_u8 - 128)
+                = X_u8 * W'_u8 - X_u8 * 128
+    """
+    if not check_vrmpy_applicable(arg_types[0], arg_types[1]):
+        return None
+
+    def cast_to_uint8(x):
+        x = relay.cast(x, "int32")
+        x = relay.add(x, relay.const(128, "int32"))
+        return relay.cast(x, "uint8")
+
+    if arg_types[0].dtype == "int8" and arg_types[1].dtype == "int8":
+        x = cast_to_uint8(x)
+        w = cast_to_uint8(w)
+
+        W_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(w, "int32"), axis=[-1])
+        X_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1])
+        X_u8x128 = relay.expand_dims(X_u8x128, axis=1)
+
+        out = op(x, w, **attrs)
+
+        out += W_u8x128
+        out += X_u8x128
+
+        k_dim = int(arg_types[0].shape[-1])
+        return out + relay.const(128 * 128 * k_dim, "int32")
+
+    if arg_types[0].dtype == "uint8" and arg_types[1].dtype == "int8":
+        w = cast_to_uint8(w)
+
+        X_u8x128 = relay.expand_dims(
+            relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]), axis=1
+        )
+
+        out = op(x, w, **attrs)
+
+        return out + X_u8x128
+
+    return None
+
+
+@nn.dense_legalize.register("hexagon")
+def _dense_legalize(attrs, inputs, arg_types):
+    """Legalize dense op for HVX vectorization and vrmpy tensorization.
+
+    Given a workload with a matrix X of shape (M, K) and a matrix Y of (N, K),
+    we first pad the N dimension to be a multiple of the output vector length.
+
+    And if the inputs are signed or unsigned int8 and the Y matrix can be packed into the
+    NC32n4c layout, we convert both inputs to uint8 to apply the most efficient variant of vrmpy.
+    """
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    # Collect the input tensors.
+    x_tensor, y_tensor = arg_types[0], arg_types[1]
+    dtype = x_tensor.dtype
+
+    # Collect the output tensor.
+    output_tensor = arg_types[2]
+
+    # Collect the input exprs.
+    x, y = inputs
+
+    N, _ = y_tensor.shape
+
+    if dtype == "float16":
+        vec_len = 64
+    elif "int8" in dtype:
+        vec_len = 32
+    else:
+        return None
+
+    if N % vec_len != 0:
+        N_padded = ((N + vec_len) // vec_len) * vec_len
+        dn = N_padded - N
+
+        y_ = relay.nn.pad(y, pad_width=((0, dn), (0, 0)))
+
+        # If units is explicitly specified, it is used to compute the output shape.
+        # We need to update units after padding to prevent a type error.
+        if attrs["units"] is not None:
+            new_attrs["units"] = N + dn
+
+        arg_types = [
+            arg_types[0],
+            tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype),
+        ]
+
+        vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs)
+
+        if vrmpy_out is None:
+            out_ = relay.nn.dense(x, y_, **new_attrs)
+        else:
+            out_ = vrmpy_out
+
+        out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape])
+        return out
+
+    return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs)
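The zero-point identity that vrmpy_legalize relies on can be checked numerically outside TVM. A small NumPy verification of the (int8, int8) case:

import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 8, 16
x = rng.integers(-128, 128, size=(M, K)).astype("int32")  # X_s8
w = rng.integers(-128, 128, size=(N, K)).astype("int32")  # W_s8
xu, wu = x + 128, w + 128  # X'_u8, W'_u8

ref = x @ w.T
# X'_u8 * W'_u8 - 128 * sum_k X'_u8 - 128 * sum_k W'_u8 + 128 * 128 * K
legalized = xu @ wu.T - 128 * xu.sum(1, keepdims=True) - 128 * wu.sum(1) + 128 * 128 * K
assert (ref == legalized).all()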
+ if attrs["units"] is not None: + new_attrs["units"] = N + dn + + arg_types = [ + arg_types[0], + tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype), + ] + + vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs) + + if vrmpy_out is None: + out_ = relay.nn.dense(x, y_, **new_attrs) + else: + out_ = vrmpy_out + + out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) + return out + + return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs) diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py index b1d1e1541961..bd06cb8ecd16 100644 --- a/python/tvm/topi/hexagon/injective.py +++ b/python/tvm/topi/hexagon/injective.py @@ -42,8 +42,9 @@ def schedule_injective(outs): # Fuse axes and vectorize inner elements for x in outs: fused = s[x].fuse(*x.op.axis) - _, inner = s[x].split(fused, factor=128 // np.dtype(x.dtype).itemsize) + outer, inner = s[x].split(fused, factor=128 // np.dtype(x.dtype).itemsize) s[x].vectorize(inner) + s[x].parallel(outer) return s diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index bdc63854328b..adea4690d4a7 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Optimized implementation of q_multiply_shift based on LLVM intrinsics""" import tvm from tvm.ir import register_intrin_lowering +from tvm import te def _q_multiply_shift_hexagon(op): @@ -69,3 +71,87 @@ def _q_multiply_shift_hexagon(op): register_intrin_lowering( "tir.q_multiply_shift", target="hexagon", f=_q_multiply_shift_hexagon, level=99 ) + + +def dot_vrmpy(x_ty, y_ty): + """Generates vrmpy instruciton for tensorization.""" + int32_lanes = 32 + num_int8_elements = 4 # 4 int8 elements in int32 + data = te.placeholder((num_int8_elements,), dtype=x_ty, name="data") + kernel = te.placeholder((int32_lanes, num_int8_elements), dtype=y_ty, name="kernel") + k = te.reduce_axis((0, num_int8_elements), name="k") + C = te.compute( + (int32_lanes,), + lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k), + name="C", + ) + + a_buffer = tvm.tir.decl_buffer( + data.shape, dtype=x_ty, name="a_buffer", offset_factor=1, strides=[1] + ) + b_buffer = tvm.tir.decl_buffer( + kernel.shape, dtype=y_ty, name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1] + ) + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.tir.ir_builder.create() + if index == 1: + ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x32"))) + return ib.get() + + vec_zero = tvm.tir.const(0, "int32x32") + + if x_ty == "uint8" and y_ty == "uint8": + a_uint8 = ins[0].vload([0], "uint8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_uint8) + vec_b = ins[1].vload([0, 0], "uint8x128") + + vrmpy_inst_name = "llvm.hexagon.V6.vrmpyub.acc.128B" + + vec_bi32 = tvm.tir.call_intrin("int32x32", "tir.reinterpret", vec_b) + + quad_reduction = tvm.tir.call_llvm_pure_intrin( + "int32x32", + vrmpy_inst_name, + tvm.tir.const(3, "uint32"), + vec_zero, + vec_bi32, + re_int32, + ) + elif x_ty == "uint8" and y_ty == "int8": + a_uint8 = ins[0].vload([0], "uint8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_uint8) + vec_b = ins[1].vload([0, 0], "int8x128") + + vrmpy_inst_name = 
"llvm.hexagon.V6.vrmpybusv.acc.128B" + + vec_bi32 = tvm.tir.call_intrin("int32x32", "tir.reinterpret", vec_b) + + quad_reduction = tvm.tir.call_llvm_pure_intrin( + "int32x32", + vrmpy_inst_name, + tvm.tir.const(3, "uint32"), + vec_zero, + re_int32.astype("int32x32"), + vec_bi32, + ) + else: + raise ValueError(f"Only (u8, u8) or (u8, i8) dtype pairs are supported by vrmpy.") + + if index == 0: + ib.emit(outs[0].vstore(0, quad_reduction)) + else: + ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x32"))) + return ib.get() + + # body, reset, update + return _instr(0), _instr(1), _instr(2) + + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin( + C.op, + _intrin_func, + binds={data: a_buffer, kernel: b_buffer}, + default_buffer_params=buffer_params, + ) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 9321ddf71d3b..7431871524aa 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +# pylint: disable=invalid-name,missing-function-docstring,redefined-outer-name """ Test rpc based launcher for hexagon """ +import pytest import numpy as np @@ -424,5 +425,151 @@ def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target, tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5) +data_dtype = tvm.testing.parameter("int8", "uint8") +weight_dtype = tvm.testing.parameter("int8", "uint8") + + +@tvm.testing.requires_hexagon +def test_conv2d_relay_vrmpy(hexagon_session, data_dtype, weight_dtype): + if data_dtype == "int8" and weight_dtype == "uint8": + pytest.skip("(i8, u8) input pair is not supported") + + def get_conv2d_nchw(d_shape, w_shape, padding, strides=(1, 1)): + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype=out_dtype, + ) + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + I, O, H, W = 64, 256, 56, 56 + kH = kW = 3 + padding = (1, 1) + strides = (1, 1) + + data_shape = (1, I, H, W) + weight_shape = (O, I, kH, kW) + bias_shape = (weight_shape[0],) + + bias = relay.var("bias", shape=bias_shape, dtype="int32") + + conv2d = get_conv2d_nchw( + data_shape, + weight_shape, + padding, + strides=strides, + ) + bias_add = relay.nn.bias_add(conv2d, bias) + mod = tvm.IRModule.from_expr(bias_add) + + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") + + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + + bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") + params = {"weight": weight_np, "bias": bias_np} + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + with tvm.transform.PassContext( + opt_level=3, 
diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py
index 9321ddf71d3b..7431871524aa 100644
--- a/tests/python/contrib/test_hexagon/test_launcher.py
+++ b/tests/python/contrib/test_hexagon/test_launcher.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name,missing-function-docstring,redefined-outer-name
 """ Test rpc based launcher for hexagon """
+import pytest
 import numpy as np
 
@@ -424,5 +425,151 @@ def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target,
     tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)
 
 
+data_dtype = tvm.testing.parameter("int8", "uint8")
+weight_dtype = tvm.testing.parameter("int8", "uint8")
+
+
+@tvm.testing.requires_hexagon
+def test_conv2d_relay_vrmpy(hexagon_session, data_dtype, weight_dtype):
+    if data_dtype == "int8" and weight_dtype == "uint8":
+        pytest.skip("(i8, u8) input pair is not supported")
+
+    def get_conv2d_nchw(d_shape, w_shape, padding, strides=(1, 1)):
+        out_dtype = "int32"
+
+        data = relay.var("data", shape=d_shape, dtype=data_dtype)
+        weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
+        out_channel = w_shape[0]
+        return relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=w_shape[2:],
+            channels=out_channel,
+            padding=padding,
+            strides=strides,
+            out_dtype=out_dtype,
+        )
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    I, O, H, W = 64, 256, 56, 56
+    kH = kW = 3
+    padding = (1, 1)
+    strides = (1, 1)
+
+    data_shape = (1, I, H, W)
+    weight_shape = (O, I, kH, kW)
+    bias_shape = (weight_shape[0],)
+
+    bias = relay.var("bias", shape=bias_shape, dtype="int32")
+
+    conv2d = get_conv2d_nchw(
+        data_shape,
+        weight_shape,
+        padding,
+        strides=strides,
+    )
+    bias_add = relay.nn.bias_add(conv2d, bias)
+    mod = tvm.IRModule.from_expr(bias_add)
+
+    if data_dtype == "uint8":
+        data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8")
+    else:
+        data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8")
+
+    if weight_dtype == "uint8":
+        weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8")
+    else:
+        weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8")
+
+    bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32")
+    params = {"weight": weight_np, "bias": bias_np}
+
+    ref = (
+        relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
+        .evaluate()(*[data_np, weight_np, bias_np])
+        .numpy()
+    )
+
+    with tvm.transform.PassContext(
+        opt_level=3,
+    ):
+        executor = relay.backend.Executor("graph", {"link-params": True})
+        lib = relay.build(mod, target=target, params=params, executor=executor)
+
+    asm = lib.lib.get_source("asm")
+    assert "vrmpy" in asm
+
+    rt_mod = hexagon_session.get_executor_from_factory(lib)
+
+    rt_mod.set_input("data", data_np)
+
+    rt_mod.run()
+
+    out = rt_mod.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
+@tvm.testing.requires_hexagon
+def test_dense_relay_vrmpy(hexagon_session, data_dtype, weight_dtype):
+    if data_dtype == "int8" and weight_dtype == "uint8":
+        pytest.skip("(i8, u8) input pair is not supported")
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+
+    M = 128
+    N = 1000
+    K = 2048
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype)
+
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+
+    if data_dtype == "uint8":
+        data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8")
+    else:
+        data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8")
+
+    if weight_dtype == "uint8":
+        weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8")
+    else:
+        weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8")
+
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias)
+    mod = tvm.IRModule.from_expr(bias_add)
+
+    with tvm.transform.PassContext(
+        opt_level=3,
+    ):
+        executor = relay.backend.Executor("graph", {"link-params": True})
+        lib = relay.build(mod, target=target, params=params, executor=executor)
+
+    asm = lib.lib.get_source("asm")
+    assert "vrmpy" in asm
+
+    rt_mod = hexagon_session.get_executor_from_factory(lib)
+
+    rt_mod.set_input("data", data_np)
+
+    rt_mod.run()
+
+    out = rt_mod.get_output(0).numpy()
+
+    ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32"))
+    ref += bias_np
+
+    np.testing.assert_equal(out, ref)
+
+
 if __name__ == "__main__":
     tvm.testing.main()