From e9d1ceab355aead4f9cb42c95f00f5bd4bb53166 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 18:48:34 +0900 Subject: [PATCH 01/14] [Hexagon] Support vrmpy tensorization for conv2d and dense schedules --- python/tvm/relay/op/strategy/hexagon.py | 97 ++++- python/tvm/tir/tensor_intrin/hexagon.py | 50 +++ python/tvm/topi/hexagon/__init__.py | 2 + python/tvm/topi/hexagon/conv2d.py | 155 ++++++++ python/tvm/topi/hexagon/conv2d_alter_op.py | 180 +++++++++ python/tvm/topi/hexagon/dense.py | 73 ++++ python/tvm/topi/hexagon/dense_alter_op.py | 150 ++++++++ python/tvm/topi/hexagon/injective.py | 3 +- python/tvm/topi/hexagon/tensor_intrin.py | 84 +++++ .../contrib/test_hexagon/test_conv2d_vrmpy.py | 140 +++++++ .../contrib/test_hexagon/test_dense_vrmpy.py | 201 ++++++++++ tests/python/contrib/test_hexagon/test_qnn.py | 347 ++++++++++++++++++ 12 files changed, 1478 insertions(+), 4 deletions(-) create mode 100644 python/tvm/topi/hexagon/conv2d_alter_op.py create mode 100644 python/tvm/topi/hexagon/dense_alter_op.py create mode 100644 tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py create mode 100644 tests/python/contrib/test_hexagon/test_dense_vrmpy.py create mode 100644 tests/python/contrib/test_hexagon/test_qnn.py diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index 13c808f96b95..72e837b5d156 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -30,7 +30,7 @@ def batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): """batch_matmul strategy for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_compute_batch_matmul(topi.nn.batch_matmul, need_out_dtype=True), wrap_topi_schedule(topi.hexagon.schedule_batch_matmul), name="batch_matmul.hexagon", ) @@ -62,10 +62,38 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): if groups == 1: if data_layout == "NHWC" and kernel_layout == "HWIO": strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc), + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_meta_schedule_layout=True), wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc), name="conv2d_nhwc.hexagon", ) + + # kernel_h, kernel_w, _, co = get_const_tuple(kernel.shape) + # stride_h, stride_w = get_const_tuple(attrs.strides) + # dilation_h, dilation_w = get_const_tuple(attrs.dilation) + + # judge_winograd_auto_scheduler = ( + # "float" in data.dtype + # and "float" in kernel.dtype + # and kernel_h == 3 + # and kernel_w == 3 + # and stride_h == 1 + # and stride_w == 1 + # and dilation_h == 1 + # and dilation_w == 1 + # ) + + # # register auto-scheduler implementations + # if judge_winograd_auto_scheduler: + # strategy.add_implementation( + # wrap_compute_conv2d( + # topi.nn.conv2d_winograd_nhwc, + # need_meta_schedule_layout=True + # ), + # naive_schedule, # this implementation should never be picked by autotvm + # name="conv2d_nhwc.winograd", + # plevel=15, + # ) + elif data_layout == "NCHW" and kernel_layout == "OIHW": strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nchw), @@ -100,12 +128,37 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): return strategy +@conv2d_winograd_without_weight_transfrom_strategy.register("hexagon") +def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transfrom cpu strategy""" + dilation = attrs.get_int_tuple("dilation") + groups = 
attrs.get_int("groups") + layout = attrs.data_layout + strides = attrs.get_int_tuple("strides") + assert dilation == (1, 1), "Do not support dilate now" + assert strides == (1, 1), "Do not support strides now" + assert groups == 1, "Do not supoort arbitrary group number" + strategy = _op.OpStrategy() + + if layout == "NHWC": + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc_without_weight_transform, + need_auto_scheduler_layout=False, + need_meta_schedule_layout=True, + ), + naive_schedule, + name="ansor.winograd", + ) + return strategy + + @dense_strategy.register("hexagon") def dense_strategy_hexagon(attrs, inputs, out_type, target): """Dense strategy for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_dense(topi.nn.dense), + wrap_compute_dense(topi.nn.dense, need_meta_schedule_layout=True), wrap_topi_schedule(topi.hexagon.schedule_dense), name="dense.hexagon", ) @@ -187,3 +240,41 @@ def schedule_reduce_hexagon(attrs, outs, target): """Schedule reduction ops for Hexagon""" with target: return topi.hexagon.schedule_reduce(outs) + + +@conv2d_NCHWc_strategy.register("hexagon") +def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + data, kernel = inputs + strategy.add_implementation( + wrap_compute_conv2d( + topi.hexagon.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True + ), + wrap_topi_schedule(topi.hexagon.schedule_conv2d_NCHWc_int8), + name="conv2d_NCHWc_int8.hexagon", + ) + return strategy + + +@dense_pack_strategy.register("hexagon") +def dense_pack_strategy_hexagon(attrs, inputs, out_type, target): + """dense_pack hexagon strategy""" + strategy = _op.OpStrategy() + + if ( + # inputs[0].dtype == "uint8" + # and inputs[1].dtype == "uint8" + "int8" in inputs[0].dtype + and "int8" in inputs[1].dtype + and out_type.dtype == "int32" + and attrs["weight_layout"] == "NC32n4c" + ): + strategy.add_implementation( + wrap_compute_dense(topi.hexagon.dense.dense_u8u8i32_vrmpy_compute), + wrap_topi_schedule(topi.hexagon.dense.dense_u8u8i32_vrmpy_schedule), + # wrap_topi_schedule(topi.hexagon.dense.schedule_dense), + name="dense_uint8.hexagon", + plevel=12, + ) + + return strategy diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 0227312d6373..3cad94006dd8 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -64,8 +64,58 @@ def dot_product_32x4_u8u8i32_vrmpy( ) +@T.prim_func +def dot_product_32x4_u8i8i32_desc( + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((32, 4), "int8", offset_factor=1), + C: T.Buffer((32,), "int32", offset_factor=1), +) -> None: + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + for i in T.serial(0, 32): + with T.init(): + C[i] = T.int32(0) + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + +@T.prim_func +def dot_product_32x4_u8i8i32_vrmpy( + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((32, 4), "int8", offset_factor=1), + C: T.Buffer((32,), "int32", offset_factor=1), +) -> None: + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + + B_i8x128 = B.vload([0, 0], dtype="int8x128") + B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") + + C[T.ramp(T.int32(0), 1, 32)] = 
T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"), + T.uint32(3), + C[T.ramp(T.int32(0), 1, 32)], + T.broadcast(A_i32, 32), + B_i32x32, + dtype="int32x32", + ) + + VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy" TensorIntrin.register( VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy ) + +VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy" + +TensorIntrin.register( + VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy +) diff --git a/python/tvm/topi/hexagon/__init__.py b/python/tvm/topi/hexagon/__init__.py index 295152d11631..b94526e5b919 100644 --- a/python/tvm/topi/hexagon/__init__.py +++ b/python/tvm/topi/hexagon/__init__.py @@ -29,3 +29,5 @@ from .resize2d import * from .tensor_intrin import * from .qnn import * +from .dense_alter_op import * +from .conv2d_alter_op import * diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index d8f44d663843..0289ae110ae2 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -18,7 +18,13 @@ """Schedule for conv2d""" import tvm +from tvm import te +from tvm.topi.nn.pad import pad +from .. import nn from ..utils import traverse_inline +from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple +from .tensor_intrin import dot_vrmpy def schedule_conv2d_nhwc(outs): @@ -86,3 +92,152 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +def conv2d_NCHWc_int8( + data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32" +): + n_elems = int(kernel.shape[-1]) + return nn.conv2d_NCHWc_int8( + data, kernel, stride, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems + ) + + +def schedule_conv2d_NCHWc_int8(outs): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "conv2d_NCHWc_int8" in op.tag: + conv_out = op.output(0) + kernel_vec = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = ( + data_vec.op.input_tensors[0] + if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag + else data_vec + ) + if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + out_width = conv_out.shape[3] + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break + + args = [s, data_vec, conv_out, outs[0]] + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) + # assert n_elems == 4 + intrin = dot_vrmpy(data.dtype, kernel_vec.dtype) + + inline_fused = True + + schedule_conv_NCHWc_cpu_common_int8( + *args, reg_n=reg_n, int32_lanes=32, int8_elems=4, intrin=intrin, inline_fused=inline_fused + ) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def schedule_conv_NCHWc_cpu_common_int8( + s, + data_vec, + conv_out, + last, + reg_n, + int32_lanes=32, + int8_elems=4, + intrin=None, + inline_fused=True, +): + unroll_kw = False + _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + # schedule pad + if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag: + batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis + parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) + # s[data_vec].parallel(parallel_axis) + data_vec = data_vec.op.input_tensors[0] + + # schedule 5-D NCHW[x]c conv + C, O = conv_out, last 
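+    # Stage the conv output through a cache-write buffer (CC) so the reduction
+    # can be tiled separately from the final store: the 32-lane output split
+    # (oc_s_inner) together with the 4-element ic_s_inner reduction below is
+    # what gets tensorized with the vrmpy intrinsic.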
+ CC = s.cache_write(C, "global") + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) + s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) + s[C].vectorize(oc_block) + + if C == O: + s[C].parallel(parallel_axis) + + s[CC].compute_at(s[C], parallel_axis) + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + + assert oc_bn % int32_lanes == 0, f"oc_bn={oc_bn} % int32_lanes={int32_lanes} != 0" + assert ( + ic_bn % int8_elems == 0 + ), f"ic_bn={ic_bn} % int8_elems={int8_elems} != 0" # (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + if unroll_kw: + s[CC].reorder( + oc_chunk, + oh, + ow_chunk, + ic_outer, + kh, + ic_f_inner, + kw, + ow_block, + oc_f_inner, + oc_s_inner, + ic_s_inner, + ) + s[CC].unroll(kw) + else: + s[CC].reorder( + oc_chunk, + oh, + ow_chunk, + ic_outer, + kh, + kw, + ic_f_inner, + ow_block, + oc_f_inner, + oc_s_inner, + ic_s_inner, + ) + + s[CC].tensorize(oc_s_inner, intrin) + + s[CC].unroll(ow_block) + s[CC].unroll(oc_f_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + + if inline_fused: + s[C].compute_at(s[O], ow_block) + else: + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py new file mode 100644 index 000000000000..cf46a602e9f6 --- /dev/null +++ b/python/tvm/topi/hexagon/conv2d_alter_op.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Dense alter op functions for ARM""" + +import tvm +from tvm import relay +from tvm import autotvm +from ..utils import get_const_tuple +from .. import nn +from ..nn.utils import get_pad_tuple +from ..nn import conv2d_legalize, conv2d_alter_layout +from ..generic.conv2d import conv2d_alter_int8_common + + +def check_vrmpy_applicable(x, y): + out_channel, in_channel, _, _ = get_const_tuple(y.shape) + return ( + "int8" in x.dtype and "int8" in y.dtype and out_channel % 32 == 0 and in_channel % 4 == 0 + ) + + +@conv2d_alter_layout.register("hexagon") +def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): + target = tvm.target.Target.current(allow_none=False) + dispatch_ctx = autotvm.task.DispatchContext.current + new_attrs = {k: attrs[k] for k in attrs.keys()} + + # Parse the attributes. 
+ padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + data_layout = attrs["data_layout"] + kernel_layout = attrs["kernel_layout"] + data_tensor, kernel_tensor = tinfos + data_dtype = data_tensor.dtype + kernel_dtype = kernel_tensor.dtype + out_dtype = out_type.dtype + + impl, outs = relay.backend.te_compiler.select_implementation( + relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target + ) + if impl.name.find("winograd") != -1: + if dilation != (1, 1): + return None + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data_tensor.shape) + KH, KW, _, CO = get_const_tuple(kernel_tensor.shape) + + # Pre-compute weight transformation in winograd + tile_size = 4 + # HWIO -> OIHW + kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) + # alpha, alpha, CO, CI + weight = relay.nn.contrib_conv2d_winograd_weight_transform( + kernel_transform, tile_size=tile_size + ) + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + if not check_vrmpy_applicable(data_tensor, kernel_tensor) or data_layout != "NCHW" or kernel_layout != "OIHW": + return None + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) + data_dtype = data_tensor.dtype + kernel_dtype = kernel_tensor.dtype + + n_elems = 4 + ic_bn, oc_bn = 32, 32 + + if ic_bn > in_channel: + assert in_channel == 4 + ic_bn = in_channel + + new_attrs = {k: attrs[k] for k in attrs.keys()} + + new_attrs["channels"] = out_channel + new_attrs["data_layout"] = "NCHW%dc" % ic_bn + new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems) + new_attrs["out_layout"] = "NCHW%dc" % oc_bn + + return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) + + +@nn.conv2d_legalize.register("hexagon") +def _conv2d_legalize(attrs, inputs, arg_types): + data_layout = attrs["data_layout"] + kernel_layout = attrs["kernel_layout"] + + output_tensor = arg_types[2] + + # Collect the input exprs. + data, kernel = inputs + + if data_layout == "NHWC" and kernel_layout == "HWIO": + # Collect the input tensors. + data_tensor, kernel_tensor = arg_types[0], arg_types[1] + out_channel = kernel_tensor.shape[0] + + # Dilation not supported yet. Return None if dilation is not (1, 1) + dilation = attrs.get_int_tuple("dilation") + if not (dilation[0] == 1 and dilation[1] == 1): + return None + + # No legalization for depthwise convolutions yet. + groups = attrs.get_int("groups") + if groups != 1: + return None + + # Get the conv attrs + new_attrs = {k: attrs[k] for k in attrs.keys()} + + padding = attrs.get_int_tuple("padding") + kh, kw = attrs.get_int_tuple("kernel_size") + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) + + # TODO: pad on input channel? 
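+        # For now only the output channel is padded up to the HVX vector length
+        # below; the input channel is left untouched (vector length 1).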
+ in_channel_vector_length = 1 + in_channel = data_tensor.shape[3].value + + out_channel_vector_length = 64 if output_tensor.dtype == "float16" else 128 + out_channel = kernel_tensor.shape[3].value + + if out_channel % out_channel_vector_length != 0: + new_out_channel = ( + (out_channel + out_channel_vector_length) // out_channel_vector_length + ) * out_channel_vector_length + diff = new_out_channel - out_channel + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, diff))) + + new_attrs["channels"] = new_out_channel + out = relay.nn.conv2d(data, kernel, **new_attrs) + original_out_shape = [x.value for x in output_tensor.shape] + return relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) + else: + return relay.nn.conv2d(data, kernel, **new_attrs) + + if data_layout != "NCHW" or kernel_layout != "OIHW": + return None + + # Collect the input tensors. + data_tensor, kernel_tensor = arg_types[0], arg_types[1] + out_channel = kernel_tensor.shape[0] + + if "int8" in data_tensor.dtype and "int8" in data_tensor.dtype and out_channel % 32 == 0: + data_dtype = data_tensor.dtype + kernel_dtype = kernel_tensor.dtype + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Collect the input exprs. + data, kernel = inputs + + data_dtype = "uint8" + + return conv2d_alter_int8_common( + data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, data_dtype, 4, 32 + ) + + return None diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py index afe53f515fa9..18e34d106d1c 100644 --- a/python/tvm/topi/hexagon/dense.py +++ b/python/tvm/topi/hexagon/dense.py @@ -18,6 +18,10 @@ """Schedule for dense operator""" import tvm +from tvm.topi.utils import get_const_tuple, traverse_inline +from tvm import te +from .. 
import tag +from .tensor_intrin import dot_vrmpy def schedule_dense(outs): @@ -38,3 +42,72 @@ def schedule_dense(outs): s = tvm.te.create_schedule([x.op for x in outs]) tvm.te.schedule.AutoInlineInjective(s) return s + + +def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype): + """Compute for uint8 x uint8 -> int32 dense""" + # assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32" + m, k = X.shape + n_o, _, n_i, _ = packed_w.shape + assert n_i == 32 + ak = te.reduce_axis((0, k), name="k") + + C = te.compute( + (m, n_o * n_i), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packed_w[tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4].astype( + "int32" + ), + axis=ak, + ), + tag="dense_u8u8i32_vrmpy", + name="compute", + attrs={"schedule_rule": "meta_schedule.dense_u8u8i32_vrmpy"}, + ) + + if bias is not None: + C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) + + # a_y, _ = C.op.axis + # cfg.define_split("tile_y", a_y, num_outputs=2) + + return C + + +def dense_u8u8i32_vrmpy_common(s, C, O): + (a_k,) = C.op.reduce_axis + a_y = C.op.axis[-2] + a_yo, a_yi = s[C].split(a_y, factor=32) + a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32) + a_ko, a_ki = s[C].split(a_k, factor=4) + + s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) + + pc = dot_vrmpy("uint8", "uint8") + s[C].tensorize(a_xi, pc) + + if C != O: + a_y = O.op.axis[-2] + a_yo, a_yi = s[O].split(a_y, factor=32) + a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32) + + s[O].reorder(a_yo, a_xo, a_yi, a_xi) + s[O].vectorize(a_xi) + s[C].compute_at(s[O], a_yi) + + +def dense_u8u8i32_vrmpy_schedule(outs): + s = te.create_schedule([x.op for x in outs]) + # O: The output of the fused op + O = outs[0] + + def _callback(op): + if "u8u8i32_vrmpy" in op.tag: + # C: The output of GEMM + C = op.output(0) + dense_u8u8i32_vrmpy_common(s, C, O) + + traverse_inline(s, outs[0].op, _callback) + + return s diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py new file mode 100644 index 000000000000..dbf4c5d7b361 --- /dev/null +++ b/python/tvm/topi/hexagon/dense_alter_op.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Dense alter op functions for ARM""" + +import tvm +from tvm import te +from tvm import relay +from tvm import autotvm +from ..utils import get_const_tuple +from .. 
import nn +from ..nn import dense_alter_layout + + +def check_vrmpy_applicable(x, y): + return ( + "int8" in x.dtype and "int8" in y.dtype and y.shape[-2] % 32 == 0 and y.shape[-1] % 4 == 0 + ) + + + +@dense_alter_layout.register(["hexagon"]) +def _alter_dense_layout(attrs, inputs, tinfos, out_type): + data_tensor, weight_tensor = tinfos + out_dtype = out_type.dtype + M, K = get_const_tuple(data_tensor.shape) + N, _ = get_const_tuple(weight_tensor.shape) + + if check_vrmpy_applicable(data_tensor, weight_tensor): # and data_tensor.dtype == "uint8": + weight_layout = "NC32n4c" + return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) + else: + return None + + +def vrmpy_legalize(x, w, arg_types, op, attrs, is_batched_mm): + """ + Legalizes s8, s8 -> s32 GEMM op for VRMPY. + X'_u8 = X_s8 + 128 + X_s8 * W_s8 = (X'_u8 - 128) * (W'_u8 - 128) + = X'_u8 * W'_u8 - X'_u8 * 128 - 128 * W'_u8 + 128 * 128 + X_u8 * W_s8 = X_u8 * (W'_u8 - 128) + = X'_u8 * W'_u8 - X_u8 * 128 + """ + def cast_to_uint8(x): + x = relay.cast(x, "int32") + x = relay.add(x, relay.const(128, "int32")) + return relay.cast(x, "uint8") + + if check_vrmpy_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "int8" and arg_types[1].dtype == "int8": + x = cast_to_uint8(x) + w = cast_to_uint8(w) + + W_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(w, "int32"), axis=[-1]) + X_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]) + + if is_batched_mm: + X_u8x128 = relay.expand_dims(X_u8x128, axis=2) + W_u8x128 = relay.expand_dims(W_u8x128, axis=1) + else: + X_u8x128 = relay.expand_dims(X_u8x128, axis=1) + + out = op(x, w, **attrs) + + out += W_u8x128 + out += X_u8x128 + + k_dim = int(arg_types[0].shape[-1]) + return out + relay.const(128 * 128 * k_dim, "int32") + + if check_vrmpy_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "uint8" and arg_types[1].dtype == "int8": + w = cast_to_uint8(w) + + X_u8x128 = relay.expand_dims(relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]), axis=1) + + out = op(x, w, **attrs) + + return out + X_u8x128 + + return None + + +@nn.dense_legalize.register("hexagon") +def _dense_legalize(attrs, inputs, arg_types): + new_attrs = {k: attrs[k] for k in attrs.keys()} + # Collect the input tensors. + x_tensor, y_tensor = arg_types[0], arg_types[1] + dtype = x_tensor.dtype + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Collect the input exprs. + x, y = inputs + + M, K = x_tensor.shape + N, K = y_tensor.shape + try: + M = M.value + K = K.value + N = N.value + except AttributeError: + # todo: deal with unfixed shape when compiling wdl model + return None + + # vec_len = 1024 // + if dtype == "float16": + vec_len = 64 + elif "int8" in dtype: + vec_len = 32 + + if N % vec_len != 0: + N_padded = ((N + vec_len) // vec_len) * vec_len + dn = N_padded - N + + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, 0))) + + # If units is explicitly specified, it is used to compute the output shape. + # We need to update units after padding to prevent a type error. 
+ if attrs["units"] is not None: + new_attrs["units"] = N + dn + + arg_types = [arg_types[0], + tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype)] + + vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs, False) + + if vrmpy_out is None: + out_ = relay.nn.dense(x, y_, **new_attrs) + else: + out_ = vrmpy_out + + out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) + return out + + return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs, False) diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py index b1d1e1541961..bd06cb8ecd16 100644 --- a/python/tvm/topi/hexagon/injective.py +++ b/python/tvm/topi/hexagon/injective.py @@ -42,8 +42,9 @@ def schedule_injective(outs): # Fuse axes and vectorize inner elements for x in outs: fused = s[x].fuse(*x.op.axis) - _, inner = s[x].split(fused, factor=128 // np.dtype(x.dtype).itemsize) + outer, inner = s[x].split(fused, factor=128 // np.dtype(x.dtype).itemsize) s[x].vectorize(inner) + s[x].parallel(outer) return s diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index bdc63854328b..4a5371135778 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -18,6 +18,7 @@ import tvm from tvm.ir import register_intrin_lowering +from tvm import te def _q_multiply_shift_hexagon(op): @@ -69,3 +70,86 @@ def _q_multiply_shift_hexagon(op): register_intrin_lowering( "tir.q_multiply_shift", target="hexagon", f=_q_multiply_shift_hexagon, level=99 ) + + +def dot_vrmpy(x_ty, y_ty): + int32_lanes = 32 + num_int8_elements = 4 # 4 int8 elements in int32 + data = te.placeholder((num_int8_elements,), dtype=x_ty, name="data") + kernel = te.placeholder((int32_lanes, num_int8_elements), dtype=y_ty, name="kernel") + k = te.reduce_axis((0, num_int8_elements), name="k") + C = te.compute( + (int32_lanes,), + lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k), + name="C", + ) + + a_buffer = tvm.tir.decl_buffer( + data.shape, dtype=x_ty, name="a_buffer", offset_factor=1, strides=[1] + ) + b_buffer = tvm.tir.decl_buffer( + kernel.shape, dtype=y_ty, name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1] + ) + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.tir.ir_builder.create() + if index == 1: + ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x32"))) + return ib.get() + + vec_zero = tvm.tir.const(0, "int32x32") + + if x_ty == "uint8" and y_ty == "uint8": + a_uint8 = ins[0].vload([0], "uint8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_uint8) + vec_b = ins[1].vload([0, 0], "uint8x128") + + vrmpy_inst_name = "llvm.hexagon.V6.vrmpyub.acc.128B" + + vec_bi32 = tvm.tir.call_intrin("int32x32", "tir.reinterpret", vec_b) + + quad_reduction = tvm.tir.call_llvm_pure_intrin( + "int32x32", + vrmpy_inst_name, + tvm.tir.const(3, "uint32"), + vec_zero, + vec_bi32, + re_int32, + ) + elif x_ty == "uint8" and y_ty == "int8": + a_uint8 = ins[0].vload([0], "uint8x4") + re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_uint8) + vec_b = ins[1].vload([0, 0], "int8x128") + + vrmpy_inst_name = "llvm.hexagon.V6.vrmpybusv.acc.128B" + + vec_bi32 = tvm.tir.call_intrin("int32x32", "tir.reinterpret", vec_b) + + quad_reduction = tvm.tir.call_llvm_pure_intrin( + "int32x32", + vrmpy_inst_name, + tvm.tir.const(3, "uint32"), + vec_zero, + re_int32.astype("int32x32"), + vec_bi32, + ) + else: + assert 
False, "Not supported" + + if index == 0: + ib.emit(outs[0].vstore(0, quad_reduction)) + else: + ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x32"))) + return ib.get() + + # body, reset, update + return _instr(0), _instr(1), _instr(2) + + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin( + C.op, + _intrin_func, + binds={data: a_buffer, kernel: b_buffer}, + default_buffer_params=buffer_params, + ) diff --git a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py new file mode 100644 index 000000000000..494313b82cc9 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np + +import tvm.testing +from tvm import relay + + +def get_conv2d_nchw( + d_shape, + w_shape, + padding, + strides=(1, 1), + data_dtype = "int8", + weight_dtype = "int8" +): + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype=out_dtype, + ) + + +@tvm.testing.requires_hexagon +def test_conv2d_u8u8i32_vrmpy(hexagon_session): + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + I = 64 + O = 256 + H = 56 + W = 56 + kH = 3 + kW = 3 + padding = (1, 1) + strides = (1, 1) + + data_shape = (1, I, H, W) + weight_shape = (O, I, kH, kW) + bias_shape = (weight_shape[0],) + + bias = relay.var("bias", shape=bias_shape, dtype="int32") + + data_dtype = "uint8" + weight_dtype = "int8" + conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides, data_dtype=data_dtype, weight_dtype=weight_dtype) + bias_add = relay.nn.bias_add(conv2d, bias) + + use_bias = True + + if use_bias: + out = bias_add + else: + out = conv2d + + mod = tvm.IRModule.from_expr(out) + + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") + + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + + bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") + params = {"weight": weight_np, "bias": bias_np} + + out_ty = relay.transform.InferType()(mod) + + _, _, P, Q = out_ty["main"].body.checked_type.shape + + target_llvm = tvm.target.Target("llvm") + + with tvm.transform.PassContext( + opt_level=3, + ): + 
lib_ref = relay.build(mod, target=target_llvm, params=params) + + # return + + with tvm.transform.PassContext( + opt_level=3, + ): + # opt_mod, _ = relay.optimize(mod, target=target, params=params) + # print(opt_mod) + # return + executor = relay.backend.Executor("graph", {"link-params": True}) + lib = relay.build(mod, target=target, params=params, executor=executor) + + asm = lib.lib.get_source("asm") + assert "vrmpy" in asm + + rt_mod = hexagon_session.get_executor_from_factory(lib) + + rt_mod.set_input("data", data_np) + + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + + rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0))) + + rt_mod_ref.set_input("data", data_np) + + rt_mod_ref.run() + + ref = rt_mod_ref.get_output(0).numpy() + + np.testing.assert_equal(out, ref) + + # gops = (O * P * Q * I * kH * kW) * 2 / 1e9 + # time_ms = rt_mod.benchmark(hexagon_session.device, number=1, repeat=50).mean * 1e3 + + # print("time elapsed: ", time_ms) diff --git a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py new file mode 100644 index 000000000000..381600f9556f --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
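+"""Test offloading nn.dense to the Hexagon vrmpy intrinsic via a fixed meta-schedule."""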
+ +import numpy as np + +import tvm.testing +from tvm import relay +from tvm.relay.backend import Executor, Runtime +from tvm.script import tir as T +from tvm.tir import TensorIntrin +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.utils import apply_fixed_schedules + + +@T.prim_func +def dot_product_32x4_u8u8i32_desc( + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((32, 4), "uint8", offset_factor=1), + C: T.Buffer((32,), "int32", offset_factor=1), +) -> None: + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + for i in T.serial(0, 32): + with T.init(): + C[i] = T.int32(0) + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + +@T.prim_func +def dot_product_32x4_u8u8i32_vrmpy( + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((32, 4), "uint8", offset_factor=1), + C: T.Buffer((32,), "int32", offset_factor=1), +) -> None: + with T.block("root"): + T.reads(C[0:32], A[0:4], B[0:32, 0:4]) + T.writes(C[0:32]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + + B_i8x128 = B.vload([0, 0], dtype="uint8x128") + B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") + + C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"), + T.uint32(3), + C[T.ramp(T.int32(0), 1, 32)], + B_i32x32, + A_i32, + dtype="int32x32", + ) + + +VRMPY_INTRIN = "dot_32x4_vrmpy" + +TensorIntrin.register(VRMPY_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy) + + +def schedule_matmul_common(sch, block, batched, M): + a_y, a_x, _ = sch.get_loops(block)[-3:] + outer_block = block + + a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 32)]) + + a_xo, a_xi = sch.split(a_x, factors=[None, 32]) + sch.reorder(a_yo, a_xo, a_yi, a_xi) + + a_xi, a_k = sch.get_loops(block)[-2:] + a_ko, a_ki = sch.split(a_k, factors=[None, 4]) + sch.reorder(a_ko, a_xi, a_ki) + + if batched: + a_b = sch.get_loops(outer_block)[0] + fused = sch.fuse(a_b, a_yo, a_xo) + else: + fused = sch.fuse(a_yo, a_xo) + + # sch.parallel(fused) + + dec = sch.decompose_reduction(block, a_ko) + + init_loop = sch.get_loops(dec)[-1] + sch.vectorize(init_loop) + + sch.tensorize(a_xi, VRMPY_INTRIN) + + return fused + + +def schedule_dense(dense_block, M, sch): + schedule_matmul_common(sch, dense_block, False, M) + + +@tvm.testing.requires_hexagon +def test_dense_u8u8i32_vrmpy(hexagon_session): + target_hexagon = tvm.target.hexagon("v68", link_params=True) + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + M = 128 + N = 768 + K = 768 + data_shape = (M, K) + weight_shape = (N, K) + + dtype = "uint8" + data = relay.var("data", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=weight_shape, dtype=dtype) + + dense = relay.nn.dense(data, weight, out_dtype="int32") + + use_bias = False + + if dtype == "uint8": + data_np = np.random.uniform(1, 255, size=data_shape).astype(dtype) + weight_np = np.random.uniform(1, 255, size=weight_shape).astype(dtype) + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype(dtype) + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype(dtype) + + # data_np = np.ones(data_shape).astype(dtype) * 127 + # weight_np = np.ones(weight_shape).astype(dtype) * 127 + + bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + params = {"weight": weight_np, "bias": bias_np} + + 
if use_bias: + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + out = relay.nn.bias_add(dense, bias) + else: + out = dense + + mod = tvm.IRModule.from_expr(out) + + def schedule_fn(task, sch): + if "dense" not in task.task_name: + return False + + block = sch.get_block("compute") + + schedule_rule = sch.get(block).annotations["schedule_rule"] + + assert "dense_u8u8i32_vrmpy" in schedule_rule + + schedule_dense(block, M, sch) + +# print(sch.mod.script()) + + return True + + database = apply_fixed_schedules(mod, target, params, schedule_fn) + + with ms.ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(mod, target=target, params=params) + + asm = lib.lib.get_source("asm") +# assert "vrmpy" in asm + + rt_mod = hexagon_session.get_executor_from_factory(lib) + + rt_mod.set_input("data", data_np) + + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32")) + + if use_bias: + ref += bias_np + + np.testing.assert_equal(out, ref) + print(ref) + + gops = (N * M * K) * 2 / 1e9 + time_ms = rt_mod.benchmark(hexagon_session.device, number=1, repeat=50).mean * 1e3 + + print("time elapsed: ", time_ms) + print("GOPS:", gops / (time_ms / 1e3)) diff --git a/tests/python/contrib/test_hexagon/test_qnn.py b/tests/python/contrib/test_hexagon/test_qnn.py new file mode 100644 index 000000000000..ae17d0054a18 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_qnn.py @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
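+"""Tests for running quantized (QNN) models and subgraphs on Hexagon,
+with and without meta-schedule auto-tensorization."""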
+import numpy as np + +import tvm +import tvm.testing +from tvm import relay +from tvm.contrib.hexagon.session import Session +from tvm.meta_schedule import postproc, schedule_rule +from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN +from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner +from tvm import meta_schedule as ms + + +executor = relay.backend.Executor("graph", {"link-params": True}) +target_hexagon = tvm.target.hexagon("v68") +target_llvm = tvm.target.Target("llvm") + + +@tvm.testing.requires_hexagon +def test_resnet50(hexagon_session: Session): + with open("qresnet50.json", "r") as fi: + mod = tvm.ir.load_json(fi.read()) + + with open("qresnet50.params", "rb") as fi: + params = relay.load_param_dict(fi.read()) + + # print(relay.transform.InferType()(mod)) + # return + + inp = np.random.randn(1, 3, 224, 224).astype("float32") + input_name = "image" + + with tvm.transform.PassContext( + opt_level=3, + ): + # opt_mod, _ = relay.optimize( + # mod, + # tvm.target.Target(target_hexagon, host=target_hexagon), + # params=params, + # ) + + # print(opt_mod) + + # return + + hexagon_lowered = relay.build( + mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + params=params, + executor=executor, + ) + + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + mod, + tvm.target.Target(target_llvm, host=target_llvm), + params=params, + ) + + # assert "vrmpy" in hexagon_lowered.lib.get_source("asm") + # print(hexagon_lowered.lib.get_source("asm")) + + # debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) + # print(debug_ex.profile(input_name=inp)) + + # return + + graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) + graph_mod.set_input(input_name, inp.copy()) + + llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_graph_mod.set_input(input_name, inp.copy()) + + import time + + t0 = time.time() + graph_mod.run() + hexagon_output = graph_mod.get_output(0).numpy() + print("run finished in ", time.time() - t0) + + llvm_graph_mod.run() + ref_result = llvm_graph_mod.get_output(0).numpy() + print(np.max(np.abs(ref_result - hexagon_output)), np.mean(np.abs(ref_result - hexagon_output))) + + time_ms = graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 + + print("time elapsed: ", time_ms) + + debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) + print(debug_ex.profile(input_name=inp.copy())) + + +def tune_ms(mod, params, hexagon_launcher): + sch_rules = [ + schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + schedule_rule.MultiLevelTilingWithIntrin( + VRMPY_u8i8i32_INTRIN, + structure="SRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=128, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + ] + + postprocs = [ + postproc.DisallowDynamicLoop(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(vectorize_init_loop=True), + ] + + work_dir = 
"work_auto_tensorize" + config = ms.TuneConfig( + strategy="replay_trace", + # strategy="evolutionary", + num_trials_per_iter=8, + max_trials_per_task=8, + max_trials_global=20000, + ) + + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + if False: + return ms.tune_relay( + mod=mod, + params=params, + target=target, + config=config, + work_dir=work_dir, + builder=get_hexagon_local_builder(), + runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), + executor=executor, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + else: + pass_config = {"relay.FuseOps.link_params": True, + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default" + } + + from tvm.meta_schedule.tune import tune_extracted_tasks + from tvm.meta_schedule.relay_integration import extract_task_from_relay + + extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config) + + tune_tasks = [] + + for task in extracted_tasks: + # if "conv2d" in task.task_name: + if True: + tune_tasks.append(task) + + database = tune_extracted_tasks( + tune_tasks, + config, + work_dir, + builder=get_hexagon_local_builder(), + runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), + num_threads=32, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + + with target, database: + with tvm.transform.PassContext( + opt_level=3, + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", + "relay.backend.tir_converter": "default", + }, + ): + return relay.build(mod, target=target, params=params, executor=executor) + + +@tvm.testing.requires_hexagon +def test_resnet50_auto_tensorize(hexagon_launcher): + with open("qresnet50.json", "r") as fi: + mod = tvm.ir.load_json(fi.read()) + + with open("qresnet50.params", "rb") as fi: + params = relay.load_param_dict(fi.read()) + + inp = np.random.randn(1, 3, 224, 224).astype("float32") + input_name = "image" + + hexagon_lowered = tune_ms(mod, params, hexagon_launcher) + + with tvm.transform.PassContext(opt_level=3): + llvm_lowered = tvm.relay.build( + mod, + tvm.target.Target(target_llvm, host=target_llvm), + params=params, + ) + + with hexagon_launcher.start_session() as session: + graph_mod = session.get_executor_from_factory(hexagon_lowered) + graph_mod.set_input(input_name, inp.copy()) + + llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_graph_mod.set_input(input_name, inp.copy()) + + import time + + t0 = time.time() + graph_mod.run() + hexagon_output = graph_mod.get_output(0).numpy() + print("run finished in ", time.time() - t0) + + llvm_graph_mod.run() + ref_result = llvm_graph_mod.get_output(0).numpy() + print(np.max(np.abs(ref_result - hexagon_output)), np.mean(np.abs(ref_result - hexagon_output))) + + time_ms = graph_mod.benchmark(session.device, number=1, repeat=20).mean * 1e3 + + print("time elapsed: ", time_ms) + + debug_ex = session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) + print(debug_ex.profile(input_name=inp.copy())) + + +@tvm.testing.requires_hexagon +def test_qnn_conv2d(hexagon_launcher): + with open("qnn_conv2d.json", "r") as fi: + mod = tvm.ir.load_json(fi.read()) + + with open("qnn_conv2d.params", "rb") as fi: + params = relay.load_param_dict(fi.read()) + + if True: + hexagon_lowered = tune_ms(mod, params ,hexagon_launcher) + else: + with tvm.transform.PassContext( + opt_level=3, + ): + hexagon_lowered = relay.build( + mod, + 
tvm.target.Target(target_hexagon, host=target_hexagon), + params=params, + executor=executor, + ) + + inp = np.load("qconv2d_input.npy") + input_name = "input" + + with hexagon_launcher.start_session() as session: + graph_mod = session.get_executor_from_factory(hexagon_lowered) + graph_mod.set_input(input_name, inp.copy()) + # graph_mod.set_input(**params) + + import time + + t0 = time.time() + graph_mod.run() + hexagon_output = graph_mod.get_output(0).numpy() + print("run finished in ", time.time() - t0) + + pt_result = np.load("qconv2d_output.npy") + print(np.max(np.abs(pt_result - hexagon_output)), np.mean(np.abs(pt_result - hexagon_output))) + + # time_ms = graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 + + # print("time elapsed: ", time_ms) + + +@tvm.testing.requires_hexagon +def test_qconv2d_subgraph(hexagon_session: Session): + mod = tvm.parser.fromtext( +""" +#[version = "0.0.5"] +def @main(%p070: Tensor[(1, 2, 56, 56, 32), uint8], %p150: Tensor[(8, 2, 1, 1, 8, 32, 4), int8], %p250: Tensor[(1, 8, 1, 1, 32), int32] , %p350: Tensor[(1, 8, 1, 1, 32), int64], %p450: Tensor[(1, 8, 1, 1, 32), int64], %p550: Tensor[(1, 8, 1, 1, 32), int64], %p617: Tensor[(1), int32] ) -> Tensor[(1, 8, 56, 56, 32), int32] { + %546 = nn.contrib_conv2d_NCHWc(%p070, %p150, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1], data_layout="NCHW32c", kernel_layout="OIHW8i32o4i", out_layout="NCHW32c", out_dtype="int32") /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + %547 = add(%546, %p250) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + %548 = cast(%547, dtype="int64") /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; + %549 = multiply(%548, %p350) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; + %550 = add(%549, %p450) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; + %551 = right_shift(%550, %p550) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; + %552 = cast(%551, dtype="int32") /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + %553 = add(77 /* ty=int32 */, %552) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + %554 = clip(%553, a_min=0f, a_max=255f) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + %555 = subtract(%554, %p617) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; + fixed_point_multiply(%555, multiplier=1147032118, shift=2) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */ + } +""") + + mod2 = tvm.parser.fromtext( +""" +#[version = "0.0.5"] +def @main(%p070: Tensor[(1, 2, 56, 56, 32), uint8] /* ty=Tensor[(1, 2, 56, 56, 32), uint8] */, %p150: Tensor[(8, 2, 1, 1, 8, 32, 4), int8] /* ty=Tensor[(8, 2, 1, 1, 8, 32, 4), int8] */, %p250: Tensor[(1, 8, 1, 1, 32), int32] /* ty=Tensor[(1, 8, 1, 1, 32), int32] */, %p350: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p450: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p550: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p617: Tensor[(1), int32] /* ty=Tensor[(1), int32] */, kernel_layout="OIHW8i32o4i", data_layout="NCHW32c", out_layout="NCHW32c") { + nn.contrib_conv2d_NCHWc(%p070, %p150, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1], data_layout="NCHW32c", kernel_layout="OIHW8i32o4i", out_layout="NCHW32c", out_dtype="int32") + } +""") + + params = {} + + with tvm.transform.PassContext(opt_level=3): + hexagon_lowered = relay.build( + mod2, + tvm.target.Target(target_hexagon, host=target_hexagon), + params=params, + ) + + # print(hexagon_lowered.lib.get_source("asm")) + graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) + graph_mod.run() + time_ms = 
graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 + + print("time elapsed: ", time_ms) + + # debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) + # print(debug_ex.profile()) From 2cd5ae84a05427a146baebca09ce78621f611864 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 19:04:45 +0900 Subject: [PATCH 02/14] update --- python/tvm/relay/op/strategy/hexagon.py | 64 +--- python/tvm/tir/tensor_intrin/hexagon.py | 50 --- python/tvm/topi/hexagon/conv2d_alter_op.py | 31 +- python/tvm/topi/hexagon/dense.py | 6 +- python/tvm/topi/hexagon/dense_alter_op.py | 11 +- .../contrib/test_hexagon/test_conv2d_vrmpy.py | 5 - .../contrib/test_hexagon/test_dense_vrmpy.py | 125 +------ tests/python/contrib/test_hexagon/test_qnn.py | 347 ------------------ 8 files changed, 13 insertions(+), 626 deletions(-) delete mode 100644 tests/python/contrib/test_hexagon/test_qnn.py diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index 72e837b5d156..45a3d87e6af4 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -62,38 +62,10 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): if groups == 1: if data_layout == "NHWC" and kernel_layout == "HWIO": strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_meta_schedule_layout=True), + wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc), name="conv2d_nhwc.hexagon", ) - - # kernel_h, kernel_w, _, co = get_const_tuple(kernel.shape) - # stride_h, stride_w = get_const_tuple(attrs.strides) - # dilation_h, dilation_w = get_const_tuple(attrs.dilation) - - # judge_winograd_auto_scheduler = ( - # "float" in data.dtype - # and "float" in kernel.dtype - # and kernel_h == 3 - # and kernel_w == 3 - # and stride_h == 1 - # and stride_w == 1 - # and dilation_h == 1 - # and dilation_w == 1 - # ) - - # # register auto-scheduler implementations - # if judge_winograd_auto_scheduler: - # strategy.add_implementation( - # wrap_compute_conv2d( - # topi.nn.conv2d_winograd_nhwc, - # need_meta_schedule_layout=True - # ), - # naive_schedule, # this implementation should never be picked by autotvm - # name="conv2d_nhwc.winograd", - # plevel=15, - # ) - elif data_layout == "NCHW" and kernel_layout == "OIHW": strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nchw), @@ -128,37 +100,12 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): return strategy -@conv2d_winograd_without_weight_transfrom_strategy.register("hexagon") -def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_type, target): - """conv2d_winograd_without_weight_transfrom cpu strategy""" - dilation = attrs.get_int_tuple("dilation") - groups = attrs.get_int("groups") - layout = attrs.data_layout - strides = attrs.get_int_tuple("strides") - assert dilation == (1, 1), "Do not support dilate now" - assert strides == (1, 1), "Do not support strides now" - assert groups == 1, "Do not supoort arbitrary group number" - strategy = _op.OpStrategy() - - if layout == "NHWC": - strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc_without_weight_transform, - need_auto_scheduler_layout=False, - need_meta_schedule_layout=True, - ), - naive_schedule, - name="ansor.winograd", - ) - return strategy - - @dense_strategy.register("hexagon") def dense_strategy_hexagon(attrs, inputs, out_type, target): """Dense strategy 
for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_dense(topi.nn.dense, need_meta_schedule_layout=True), + wrap_compute_dense(topi.nn.dense), wrap_topi_schedule(topi.hexagon.schedule_dense), name="dense.hexagon", ) @@ -262,17 +209,14 @@ def dense_pack_strategy_hexagon(attrs, inputs, out_type, target): strategy = _op.OpStrategy() if ( - # inputs[0].dtype == "uint8" - # and inputs[1].dtype == "uint8" - "int8" in inputs[0].dtype - and "int8" in inputs[1].dtype + inputs[0].dtype == "uint8" + and inputs[1].dtype == "uint8" and out_type.dtype == "int32" and attrs["weight_layout"] == "NC32n4c" ): strategy.add_implementation( wrap_compute_dense(topi.hexagon.dense.dense_u8u8i32_vrmpy_compute), wrap_topi_schedule(topi.hexagon.dense.dense_u8u8i32_vrmpy_schedule), - # wrap_topi_schedule(topi.hexagon.dense.schedule_dense), name="dense_uint8.hexagon", plevel=12, ) diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 3cad94006dd8..0227312d6373 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -64,58 +64,8 @@ def dot_product_32x4_u8u8i32_vrmpy( ) -@T.prim_func -def dot_product_32x4_u8i8i32_desc( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "int8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - for i in T.serial(0, 32): - with T.init(): - C[i] = T.int32(0) - for k in T.serial(0, 4): - with T.block("update"): - vi, vk = T.axis.remap("SR", [i, k]) - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") - - -@T.prim_func -def dot_product_32x4_u8i8i32_vrmpy( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "int8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - - A_u8x4 = A.vload([0], "uint8x4") - A_i32 = T.reinterpret(A_u8x4, dtype="int32") - - B_i8x128 = B.vload([0, 0], dtype="int8x128") - B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") - - C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"), - T.uint32(3), - C[T.ramp(T.int32(0), 1, 32)], - T.broadcast(A_i32, 32), - B_i32x32, - dtype="int32x32", - ) - - VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy" TensorIntrin.register( VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy ) - -VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy" - -TensorIntrin.register( - VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy -) diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py index cf46a602e9f6..a0ca5c386711 100644 --- a/python/tvm/topi/hexagon/conv2d_alter_op.py +++ b/python/tvm/topi/hexagon/conv2d_alter_op.py @@ -51,31 +51,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): kernel_dtype = kernel_tensor.dtype out_dtype = out_type.dtype - impl, outs = relay.backend.te_compiler.select_implementation( - relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target - ) - if impl.name.find("winograd") != -1: - if dilation != (1, 1): - return None - - assert data_layout == "NHWC" and kernel_layout == "HWIO" - N, H, W, CI = get_const_tuple(data_tensor.shape) - KH, KW, _, CO = get_const_tuple(kernel_tensor.shape) - - # Pre-compute weight transformation in winograd - 
tile_size = 4 - # HWIO -> OIHW - kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) - # alpha, alpha, CO, CI - weight = relay.nn.contrib_conv2d_winograd_weight_transform( - kernel_transform, tile_size=tile_size - ) - new_attrs["tile_size"] = tile_size - new_attrs["channels"] = CO - return relay.nn.contrib_conv2d_winograd_without_weight_transform( - inputs[0], weight, **new_attrs - ) - if not check_vrmpy_applicable(data_tensor, kernel_tensor) or data_layout != "NCHW" or kernel_layout != "OIHW": return None @@ -133,11 +108,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): kh, kw = attrs.get_int_tuple("kernel_size") pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) - # TODO: pad on input channel? - in_channel_vector_length = 1 - in_channel = data_tensor.shape[3].value - - out_channel_vector_length = 64 if output_tensor.dtype == "float16" else 128 + out_channel_vector_length = 64 if output_tensor.dtype == "float16" else 32 out_channel = kernel_tensor.shape[3].value if out_channel % out_channel_vector_length != 0: diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py index 18e34d106d1c..59190ba83efb 100644 --- a/python/tvm/topi/hexagon/dense.py +++ b/python/tvm/topi/hexagon/dense.py @@ -46,7 +46,7 @@ def schedule_dense(outs): def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype): """Compute for uint8 x uint8 -> int32 dense""" - # assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32" + assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32" m, k = X.shape n_o, _, n_i, _ = packed_w.shape assert n_i == 32 @@ -63,15 +63,11 @@ def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype): ), tag="dense_u8u8i32_vrmpy", name="compute", - attrs={"schedule_rule": "meta_schedule.dense_u8u8i32_vrmpy"}, ) if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) - # a_y, _ = C.op.axis - # cfg.define_split("tile_y", a_y, num_outputs=2) - return C diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py index dbf4c5d7b361..496b842a6a16 100644 --- a/python/tvm/topi/hexagon/dense_alter_op.py +++ b/python/tvm/topi/hexagon/dense_alter_op.py @@ -109,18 +109,11 @@ def _dense_legalize(attrs, inputs, arg_types): M, K = x_tensor.shape N, K = y_tensor.shape - try: - M = M.value - K = K.value - N = N.value - except AttributeError: - # todo: deal with unfixed shape when compiling wdl model - return None - # vec_len = 1024 // if dtype == "float16": vec_len = 64 - elif "int8" in dtype: + else: + assert "int8" in dtype vec_len = 32 if N % vec_len != 0: diff --git a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py index 494313b82cc9..f3db343a53ed 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py @@ -133,8 +133,3 @@ def test_conv2d_u8u8i32_vrmpy(hexagon_session): ref = rt_mod_ref.get_output(0).numpy() np.testing.assert_equal(out, ref) - - # gops = (O * P * Q * I * kH * kW) * 2 / 1e9 - # time_ms = rt_mod.benchmark(hexagon_session.device, number=1, repeat=50).mean * 1e3 - - # print("time elapsed: ", time_ms) diff --git a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py index 381600f9556f..a479cb930c3d 100644 --- a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py @@ -19,95 +19,7 @@ 
import tvm.testing from tvm import relay -from tvm.relay.backend import Executor, Runtime -from tvm.script import tir as T -from tvm.tir import TensorIntrin -from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.utils import apply_fixed_schedules - - -@T.prim_func -def dot_product_32x4_u8u8i32_desc( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "uint8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - for i in T.serial(0, 32): - with T.init(): - C[i] = T.int32(0) - for k in T.serial(0, 4): - with T.block("update"): - vi, vk = T.axis.remap("SR", [i, k]) - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") - - -@T.prim_func -def dot_product_32x4_u8u8i32_vrmpy( - A: T.Buffer((4,), "uint8", offset_factor=1), - B: T.Buffer((32, 4), "uint8", offset_factor=1), - C: T.Buffer((32,), "int32", offset_factor=1), -) -> None: - with T.block("root"): - T.reads(C[0:32], A[0:4], B[0:32, 0:4]) - T.writes(C[0:32]) - - A_u8x4 = A.vload([0], "uint8x4") - A_i32 = T.reinterpret(A_u8x4, dtype="int32") - - B_i8x128 = B.vload([0, 0], dtype="uint8x128") - B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32") - - C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"), - T.uint32(3), - C[T.ramp(T.int32(0), 1, 32)], - B_i32x32, - A_i32, - dtype="int32x32", - ) - - -VRMPY_INTRIN = "dot_32x4_vrmpy" - -TensorIntrin.register(VRMPY_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy) - - -def schedule_matmul_common(sch, block, batched, M): - a_y, a_x, _ = sch.get_loops(block)[-3:] - outer_block = block - - a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 32)]) - - a_xo, a_xi = sch.split(a_x, factors=[None, 32]) - sch.reorder(a_yo, a_xo, a_yi, a_xi) - - a_xi, a_k = sch.get_loops(block)[-2:] - a_ko, a_ki = sch.split(a_k, factors=[None, 4]) - sch.reorder(a_ko, a_xi, a_ki) - - if batched: - a_b = sch.get_loops(outer_block)[0] - fused = sch.fuse(a_b, a_yo, a_xo) - else: - fused = sch.fuse(a_yo, a_xo) - - # sch.parallel(fused) - - dec = sch.decompose_reduction(block, a_ko) - - init_loop = sch.get_loops(dec)[-1] - sch.vectorize(init_loop) - - sch.tensorize(a_xi, VRMPY_INTRIN) - - return fused - - -def schedule_dense(dense_block, M, sch): - schedule_matmul_common(sch, dense_block, False, M) +from tvm.relay.backend import Executor @tvm.testing.requires_hexagon @@ -151,30 +63,10 @@ def test_dense_u8u8i32_vrmpy(hexagon_session): mod = tvm.IRModule.from_expr(out) - def schedule_fn(task, sch): - if "dense" not in task.task_name: - return False - - block = sch.get_block("compute") - - schedule_rule = sch.get(block).annotations["schedule_rule"] - - assert "dense_u8u8i32_vrmpy" in schedule_rule - - schedule_dense(block, M, sch) - -# print(sch.mod.script()) - - return True - - database = apply_fixed_schedules(mod, target, params, schedule_fn) - - with ms.ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - lib = relay.build(mod, target=target, params=params) + with tvm.transform.PassContext( + opt_level=3, + ): + lib = relay.build(mod, target=target, params=params) asm = lib.lib.get_source("asm") # assert "vrmpy" in asm @@ -192,10 +84,3 @@ def schedule_fn(task, sch): ref += bias_np np.testing.assert_equal(out, ref) - print(ref) - - gops = (N * M * K) * 2 / 1e9 - time_ms = 
rt_mod.benchmark(hexagon_session.device, number=1, repeat=50).mean * 1e3 - - print("time elapsed: ", time_ms) - print("GOPS:", gops / (time_ms / 1e3)) diff --git a/tests/python/contrib/test_hexagon/test_qnn.py b/tests/python/contrib/test_hexagon/test_qnn.py deleted file mode 100644 index ae17d0054a18..000000000000 --- a/tests/python/contrib/test_hexagon/test_qnn.py +++ /dev/null @@ -1,347 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np - -import tvm -import tvm.testing -from tvm import relay -from tvm.contrib.hexagon.session import Session -from tvm.meta_schedule import postproc, schedule_rule -from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN -from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner -from tvm import meta_schedule as ms - - -executor = relay.backend.Executor("graph", {"link-params": True}) -target_hexagon = tvm.target.hexagon("v68") -target_llvm = tvm.target.Target("llvm") - - -@tvm.testing.requires_hexagon -def test_resnet50(hexagon_session: Session): - with open("qresnet50.json", "r") as fi: - mod = tvm.ir.load_json(fi.read()) - - with open("qresnet50.params", "rb") as fi: - params = relay.load_param_dict(fi.read()) - - # print(relay.transform.InferType()(mod)) - # return - - inp = np.random.randn(1, 3, 224, 224).astype("float32") - input_name = "image" - - with tvm.transform.PassContext( - opt_level=3, - ): - # opt_mod, _ = relay.optimize( - # mod, - # tvm.target.Target(target_hexagon, host=target_hexagon), - # params=params, - # ) - - # print(opt_mod) - - # return - - hexagon_lowered = relay.build( - mod, - tvm.target.Target(target_hexagon, host=target_hexagon), - params=params, - executor=executor, - ) - - with tvm.transform.PassContext(opt_level=3): - llvm_lowered = tvm.relay.build( - mod, - tvm.target.Target(target_llvm, host=target_llvm), - params=params, - ) - - # assert "vrmpy" in hexagon_lowered.lib.get_source("asm") - # print(hexagon_lowered.lib.get_source("asm")) - - # debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) - # print(debug_ex.profile(input_name=inp)) - - # return - - graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) - graph_mod.set_input(input_name, inp.copy()) - - llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) - llvm_graph_mod.set_input(input_name, inp.copy()) - - import time - - t0 = time.time() - graph_mod.run() - hexagon_output = graph_mod.get_output(0).numpy() - print("run finished in ", time.time() - t0) - - llvm_graph_mod.run() - ref_result = llvm_graph_mod.get_output(0).numpy() - print(np.max(np.abs(ref_result - hexagon_output)), np.mean(np.abs(ref_result - hexagon_output))) - - time_ms = 
graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 - - print("time elapsed: ", time_ms) - - debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) - print(debug_ex.profile(input_name=inp.copy())) - - -def tune_ms(mod, params, hexagon_launcher): - sch_rules = [ - schedule_rule.AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ), - schedule_rule.MultiLevelTilingWithIntrin( - VRMPY_u8i8i32_INTRIN, - structure="SRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=schedule_rule.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - schedule_rule.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=128, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - ] - - postprocs = [ - postproc.DisallowDynamicLoop(), - postproc.RewriteParallelVectorizeUnroll(), - postproc.RewriteReductionBlock(), - postproc.RewriteTensorize(vectorize_init_loop=True), - ] - - work_dir = "work_auto_tensorize" - config = ms.TuneConfig( - strategy="replay_trace", - # strategy="evolutionary", - num_trials_per_iter=8, - max_trials_per_task=8, - max_trials_global=20000, - ) - - target = tvm.target.Target(target_hexagon, host=target_hexagon) - - if False: - return ms.tune_relay( - mod=mod, - params=params, - target=target, - config=config, - work_dir=work_dir, - builder=get_hexagon_local_builder(), - runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), - executor=executor, - sch_rules=lambda: sch_rules, - postprocs=lambda: postprocs, - ) - else: - pass_config = {"relay.FuseOps.link_params": True, - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default" - } - - from tvm.meta_schedule.tune import tune_extracted_tasks - from tvm.meta_schedule.relay_integration import extract_task_from_relay - - extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config) - - tune_tasks = [] - - for task in extracted_tasks: - # if "conv2d" in task.task_name: - if True: - tune_tasks.append(task) - - database = tune_extracted_tasks( - tune_tasks, - config, - work_dir, - builder=get_hexagon_local_builder(), - runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), - num_threads=32, - sch_rules=lambda: sch_rules, - postprocs=lambda: postprocs, - ) - - with target, database: - with tvm.transform.PassContext( - opt_level=3, - config={ - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", - "relay.backend.tir_converter": "default", - }, - ): - return relay.build(mod, target=target, params=params, executor=executor) - - -@tvm.testing.requires_hexagon -def test_resnet50_auto_tensorize(hexagon_launcher): - with open("qresnet50.json", "r") as fi: - mod = tvm.ir.load_json(fi.read()) - - with open("qresnet50.params", "rb") as fi: - params = relay.load_param_dict(fi.read()) - - inp = np.random.randn(1, 3, 224, 224).astype("float32") - input_name = "image" - - hexagon_lowered = tune_ms(mod, params, hexagon_launcher) - - with tvm.transform.PassContext(opt_level=3): - llvm_lowered = tvm.relay.build( - mod, - tvm.target.Target(target_llvm, host=target_llvm), - params=params, - ) - - with hexagon_launcher.start_session() as session: - graph_mod = session.get_executor_from_factory(hexagon_lowered) - 
graph_mod.set_input(input_name, inp.copy()) - - llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) - llvm_graph_mod.set_input(input_name, inp.copy()) - - import time - - t0 = time.time() - graph_mod.run() - hexagon_output = graph_mod.get_output(0).numpy() - print("run finished in ", time.time() - t0) - - llvm_graph_mod.run() - ref_result = llvm_graph_mod.get_output(0).numpy() - print(np.max(np.abs(ref_result - hexagon_output)), np.mean(np.abs(ref_result - hexagon_output))) - - time_ms = graph_mod.benchmark(session.device, number=1, repeat=20).mean * 1e3 - - print("time elapsed: ", time_ms) - - debug_ex = session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) - print(debug_ex.profile(input_name=inp.copy())) - - -@tvm.testing.requires_hexagon -def test_qnn_conv2d(hexagon_launcher): - with open("qnn_conv2d.json", "r") as fi: - mod = tvm.ir.load_json(fi.read()) - - with open("qnn_conv2d.params", "rb") as fi: - params = relay.load_param_dict(fi.read()) - - if True: - hexagon_lowered = tune_ms(mod, params ,hexagon_launcher) - else: - with tvm.transform.PassContext( - opt_level=3, - ): - hexagon_lowered = relay.build( - mod, - tvm.target.Target(target_hexagon, host=target_hexagon), - params=params, - executor=executor, - ) - - inp = np.load("qconv2d_input.npy") - input_name = "input" - - with hexagon_launcher.start_session() as session: - graph_mod = session.get_executor_from_factory(hexagon_lowered) - graph_mod.set_input(input_name, inp.copy()) - # graph_mod.set_input(**params) - - import time - - t0 = time.time() - graph_mod.run() - hexagon_output = graph_mod.get_output(0).numpy() - print("run finished in ", time.time() - t0) - - pt_result = np.load("qconv2d_output.npy") - print(np.max(np.abs(pt_result - hexagon_output)), np.mean(np.abs(pt_result - hexagon_output))) - - # time_ms = graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 - - # print("time elapsed: ", time_ms) - - -@tvm.testing.requires_hexagon -def test_qconv2d_subgraph(hexagon_session: Session): - mod = tvm.parser.fromtext( -""" -#[version = "0.0.5"] -def @main(%p070: Tensor[(1, 2, 56, 56, 32), uint8], %p150: Tensor[(8, 2, 1, 1, 8, 32, 4), int8], %p250: Tensor[(1, 8, 1, 1, 32), int32] , %p350: Tensor[(1, 8, 1, 1, 32), int64], %p450: Tensor[(1, 8, 1, 1, 32), int64], %p550: Tensor[(1, 8, 1, 1, 32), int64], %p617: Tensor[(1), int32] ) -> Tensor[(1, 8, 56, 56, 32), int32] { - %546 = nn.contrib_conv2d_NCHWc(%p070, %p150, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1], data_layout="NCHW32c", kernel_layout="OIHW8i32o4i", out_layout="NCHW32c", out_dtype="int32") /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - %547 = add(%546, %p250) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - %548 = cast(%547, dtype="int64") /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; - %549 = multiply(%548, %p350) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; - %550 = add(%549, %p450) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; - %551 = right_shift(%550, %p550) /* ty=Tensor[(1, 8, 56, 56, 32), int64] */; - %552 = cast(%551, dtype="int32") /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - %553 = add(77 /* ty=int32 */, %552) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - %554 = clip(%553, a_min=0f, a_max=255f) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - %555 = subtract(%554, %p617) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */; - fixed_point_multiply(%555, multiplier=1147032118, shift=2) /* ty=Tensor[(1, 8, 56, 56, 32), int32] */ - } -""") - - mod2 = tvm.parser.fromtext( -""" 
-#[version = "0.0.5"] -def @main(%p070: Tensor[(1, 2, 56, 56, 32), uint8] /* ty=Tensor[(1, 2, 56, 56, 32), uint8] */, %p150: Tensor[(8, 2, 1, 1, 8, 32, 4), int8] /* ty=Tensor[(8, 2, 1, 1, 8, 32, 4), int8] */, %p250: Tensor[(1, 8, 1, 1, 32), int32] /* ty=Tensor[(1, 8, 1, 1, 32), int32] */, %p350: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p450: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p550: Tensor[(1, 8, 1, 1, 32), int64] /* ty=Tensor[(1, 8, 1, 1, 32), int64] */, %p617: Tensor[(1), int32] /* ty=Tensor[(1), int32] */, kernel_layout="OIHW8i32o4i", data_layout="NCHW32c", out_layout="NCHW32c") { - nn.contrib_conv2d_NCHWc(%p070, %p150, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1], data_layout="NCHW32c", kernel_layout="OIHW8i32o4i", out_layout="NCHW32c", out_dtype="int32") - } -""") - - params = {} - - with tvm.transform.PassContext(opt_level=3): - hexagon_lowered = relay.build( - mod2, - tvm.target.Target(target_hexagon, host=target_hexagon), - params=params, - ) - - # print(hexagon_lowered.lib.get_source("asm")) - graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) - graph_mod.run() - time_ms = graph_mod.benchmark(hexagon_session.device, number=1, repeat=20).mean * 1e3 - - print("time elapsed: ", time_ms) - - # debug_ex = hexagon_session.get_graph_debug_executor(hexagon_lowered.get_graph_json(), hexagon_lowered.lib) - # print(debug_ex.profile()) From d8b6be7482e78608475c3cae4fd2fe2fe38690c2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 03:15:47 -0700 Subject: [PATCH 03/14] clean up --- .../contrib/test_hexagon/test_conv2d_vrmpy.py | 24 +++----------- .../contrib/test_hexagon/test_dense_vrmpy.py | 33 ++++++++++--------- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py index f3db343a53ed..c380fff16f6d 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py @@ -95,21 +95,15 @@ def test_conv2d_u8u8i32_vrmpy(hexagon_session): _, _, P, Q = out_ty["main"].body.checked_type.shape - target_llvm = tvm.target.Target("llvm") - - with tvm.transform.PassContext( - opt_level=3, - ): - lib_ref = relay.build(mod, target=target_llvm, params=params) - - # return + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) with tvm.transform.PassContext( opt_level=3, ): - # opt_mod, _ = relay.optimize(mod, target=target, params=params) - # print(opt_mod) - # return executor = relay.backend.Executor("graph", {"link-params": True}) lib = relay.build(mod, target=target, params=params, executor=executor) @@ -124,12 +118,4 @@ def test_conv2d_u8u8i32_vrmpy(hexagon_session): out = rt_mod.get_output(0).numpy() - rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0))) - - rt_mod_ref.set_input("data", data_np) - - rt_mod_ref.run() - - ref = rt_mod_ref.get_output(0).numpy() - np.testing.assert_equal(out, ref) diff --git a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py index a479cb930c3d..2d9c6b3230e8 100644 --- a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py @@ -19,7 +19,6 @@ import tvm.testing from tvm import relay -from tvm.relay.backend import Executor 
@tvm.testing.requires_hexagon @@ -27,29 +26,30 @@ def test_dense_u8u8i32_vrmpy(hexagon_session): target_hexagon = tvm.target.hexagon("v68", link_params=True) target = tvm.target.Target(target_hexagon, host=target_hexagon) - M = 128 - N = 768 - K = 768 + M = 1 + N = 1000 + K = 2048 data_shape = (M, K) weight_shape = (N, K) - dtype = "uint8" - data = relay.var("data", shape=data_shape, dtype=dtype) - weight = relay.var("weight", shape=weight_shape, dtype=dtype) + data_dtype = "uint8" + weight_dtype = "int8" + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) dense = relay.nn.dense(data, weight, out_dtype="int32") use_bias = False - if dtype == "uint8": - data_np = np.random.uniform(1, 255, size=data_shape).astype(dtype) - weight_np = np.random.uniform(1, 255, size=weight_shape).astype(dtype) + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") else: - data_np = np.random.uniform(-128, 127, size=data_shape).astype(dtype) - weight_np = np.random.uniform(-128, 127, size=weight_shape).astype(dtype) + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") - # data_np = np.ones(data_shape).astype(dtype) * 127 - # weight_np = np.ones(weight_shape).astype(dtype) * 127 + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") @@ -66,10 +66,11 @@ def test_dense_u8u8i32_vrmpy(hexagon_session): with tvm.transform.PassContext( opt_level=3, ): - lib = relay.build(mod, target=target, params=params) + executor = relay.backend.Executor("graph", {"link-params": True}) + lib = relay.build(mod, target=target, params=params, executor=executor) asm = lib.lib.get_source("asm") -# assert "vrmpy" in asm + assert "vrmpy" in asm rt_mod = hexagon_session.get_executor_from_factory(lib) From 21bcd55483dbf7fb1e90753b8b8b37af812de753 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 03:20:38 -0700 Subject: [PATCH 04/14] migrate tests to test_launcher.py --- .../contrib/test_hexagon/test_conv2d_vrmpy.py | 100 ----------- .../contrib/test_hexagon/test_dense_vrmpy.py | 66 -------- .../contrib/test_hexagon/test_launcher.py | 157 ++++++++++++++++++ 3 files changed, 157 insertions(+), 166 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py index c380fff16f6d..1dc1055f9a96 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py @@ -19,103 +19,3 @@ import tvm.testing from tvm import relay - - -def get_conv2d_nchw( - d_shape, - w_shape, - padding, - strides=(1, 1), - data_dtype = "int8", - weight_dtype = "int8" -): - out_dtype = "int32" - - data = relay.var("data", shape=d_shape, dtype=data_dtype) - weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) - out_channel = w_shape[0] - return relay.nn.conv2d( - data=data, - weight=weight, - kernel_size=w_shape[2:], - channels=out_channel, - padding=padding, - strides=strides, - out_dtype=out_dtype, - ) - - -@tvm.testing.requires_hexagon -def test_conv2d_u8u8i32_vrmpy(hexagon_session): - target_hexagon = tvm.target.hexagon("v68") - target = tvm.target.Target(target_hexagon, host=target_hexagon) - I = 64 - O = 256 - H = 56 - W = 56 - kH = 3 - kW = 3 - 
padding = (1, 1) - strides = (1, 1) - - data_shape = (1, I, H, W) - weight_shape = (O, I, kH, kW) - bias_shape = (weight_shape[0],) - - bias = relay.var("bias", shape=bias_shape, dtype="int32") - - data_dtype = "uint8" - weight_dtype = "int8" - conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides, data_dtype=data_dtype, weight_dtype=weight_dtype) - bias_add = relay.nn.bias_add(conv2d, bias) - - use_bias = True - - if use_bias: - out = bias_add - else: - out = conv2d - - mod = tvm.IRModule.from_expr(out) - - if data_dtype == "uint8": - data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") - else: - data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") - - if weight_dtype == "uint8": - weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") - else: - weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") - - bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") - params = {"weight": weight_np, "bias": bias_np} - - out_ty = relay.transform.InferType()(mod) - - _, _, P, Q = out_ty["main"].body.checked_type.shape - - ref = ( - relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") - .evaluate()(*[data_np, weight_np, bias_np]) - .numpy() - ) - - with tvm.transform.PassContext( - opt_level=3, - ): - executor = relay.backend.Executor("graph", {"link-params": True}) - lib = relay.build(mod, target=target, params=params, executor=executor) - - asm = lib.lib.get_source("asm") - assert "vrmpy" in asm - - rt_mod = hexagon_session.get_executor_from_factory(lib) - - rt_mod.set_input("data", data_np) - - rt_mod.run() - - out = rt_mod.get_output(0).numpy() - - np.testing.assert_equal(out, ref) diff --git a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py index 2d9c6b3230e8..1dc1055f9a96 100644 --- a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py +++ b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py @@ -19,69 +19,3 @@ import tvm.testing from tvm import relay - - -@tvm.testing.requires_hexagon -def test_dense_u8u8i32_vrmpy(hexagon_session): - target_hexagon = tvm.target.hexagon("v68", link_params=True) - target = tvm.target.Target(target_hexagon, host=target_hexagon) - - M = 1 - N = 1000 - K = 2048 - data_shape = (M, K) - weight_shape = (N, K) - - data_dtype = "uint8" - weight_dtype = "int8" - data = relay.var("data", shape=data_shape, dtype=data_dtype) - weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) - - dense = relay.nn.dense(data, weight, out_dtype="int32") - - use_bias = False - - if data_dtype == "uint8": - data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") - else: - data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") - - if weight_dtype == "uint8": - weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") - else: - weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") - - bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") - - params = {"weight": weight_np, "bias": bias_np} - - if use_bias: - bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") - out = relay.nn.bias_add(dense, bias) - else: - out = dense - - mod = tvm.IRModule.from_expr(out) - - with tvm.transform.PassContext( - opt_level=3, - ): - executor = relay.backend.Executor("graph", {"link-params": True}) - lib = relay.build(mod, target=target, params=params, executor=executor) - - asm = 
lib.lib.get_source("asm") - assert "vrmpy" in asm - - rt_mod = hexagon_session.get_executor_from_factory(lib) - - rt_mod.set_input("data", data_np) - - rt_mod.run() - - out = rt_mod.get_output(0).numpy() - ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32")) - - if use_bias: - ref += bias_np - - np.testing.assert_equal(out, ref) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 9321ddf71d3b..f8bdb94b93a5 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -424,5 +424,162 @@ def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target, tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5) +@tvm.testing.requires_hexagon +def test_conv2d_u8u8i32_vrmpy(hexagon_session): + def get_conv2d_nchw( + d_shape, + w_shape, + padding, + strides=(1, 1), + data_dtype = "int8", + weight_dtype = "int8" + ): + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype=out_dtype, + ) + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + I, O, H, W = 64, 256, 56, 56 + kH = kW = 3 + padding = (1, 1) + strides = (1, 1) + + data_shape = (1, I, H, W) + weight_shape = (O, I, kH, kW) + bias_shape = (weight_shape[0],) + + bias = relay.var("bias", shape=bias_shape, dtype="int32") + + data_dtype = "uint8" + weight_dtype = "int8" + conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides, data_dtype=data_dtype, weight_dtype=weight_dtype) + bias_add = relay.nn.bias_add(conv2d, bias) + + use_bias = True + + if use_bias: + out = bias_add + else: + out = conv2d + + mod = tvm.IRModule.from_expr(out) + + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") + + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + + bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") + params = {"weight": weight_np, "bias": bias_np} + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + with tvm.transform.PassContext( + opt_level=3, + ): + executor = relay.backend.Executor("graph", {"link-params": True}) + lib = relay.build(mod, target=target, params=params, executor=executor) + + asm = lib.lib.get_source("asm") + assert "vrmpy" in asm + + rt_mod = hexagon_session.get_executor_from_factory(lib) + + rt_mod.set_input("data", data_np) + + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + + np.testing.assert_equal(out, ref) + + +@tvm.testing.requires_hexagon +def test_dense_u8u8i32_vrmpy(hexagon_session): + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + M = 128 + N = 1000 + K = 2048 + data_shape = (M, K) + weight_shape = (N, K) + + data_dtype = "uint8" + weight_dtype = "int8" + data = relay.var("data", shape=data_shape, dtype=data_dtype) + 
weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) + + dense = relay.nn.dense(data, weight, out_dtype="int32") + + use_bias = False + + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") + + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + + bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + params = {"weight": weight_np, "bias": bias_np} + + if use_bias: + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + out = relay.nn.bias_add(dense, bias) + else: + out = dense + + mod = tvm.IRModule.from_expr(out) + + with tvm.transform.PassContext( + opt_level=3, + ): + executor = relay.backend.Executor("graph", {"link-params": True}) + lib = relay.build(mod, target=target, params=params, executor=executor) + + asm = lib.lib.get_source("asm") + assert "vrmpy" in asm + + rt_mod = hexagon_session.get_executor_from_factory(lib) + + rt_mod.set_input("data", data_np) + + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32")) + + if use_bias: + ref += bias_np + + np.testing.assert_equal(out, ref) + + if __name__ == "__main__": tvm.testing.main() From 39b17f63525adeffa23db8d4710021f195384dc8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 03:21:03 -0700 Subject: [PATCH 05/14] remove vrmpy test files --- .../contrib/test_hexagon/test_conv2d_vrmpy.py | 21 ------------------- .../contrib/test_hexagon/test_dense_vrmpy.py | 21 ------------------- 2 files changed, 42 deletions(-) delete mode 100644 tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py delete mode 100644 tests/python/contrib/test_hexagon/test_dense_vrmpy.py diff --git a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py b/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py deleted file mode 100644 index 1dc1055f9a96..000000000000 --- a/tests/python/contrib/test_hexagon/test_conv2d_vrmpy.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np - -import tvm.testing -from tvm import relay diff --git a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py b/tests/python/contrib/test_hexagon/test_dense_vrmpy.py deleted file mode 100644 index 1dc1055f9a96..000000000000 --- a/tests/python/contrib/test_hexagon/test_dense_vrmpy.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np - -import tvm.testing -from tvm import relay From e95daaec2ae423e1ee202c785153eaacb9063a33 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 03:31:41 -0700 Subject: [PATCH 06/14] use generic int8 conv2d schedule --- python/tvm/topi/generic/conv2d.py | 11 ++- python/tvm/topi/hexagon/conv2d.py | 114 ++---------------------------- 2 files changed, 15 insertions(+), 110 deletions(-) diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 48b2a2f97146..6467c44faa77 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -139,7 +139,16 @@ def schedule_conv_NCHWc_cpu_common_int8( More details - https://software.intel.com/en-us/articles/ lower-numerical-precision-deep-learning-inference-and-training """ - reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val + if isinstance(cfg["tile_ow"], int): + reg_n = cfg["tile_ow"] + else: + reg_n = cfg["tile_ow"].size[-1] + + if isinstance(cfg["unroll_kw"], (int, bool)): + unroll_kw = cfg["unroll_kw"] + else: + unroll_kw = cfg["unroll_kw"].val + _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index 0289ae110ae2..d56c8f3a53e0 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -19,12 +19,11 @@ import tvm from tvm import te -from tvm.topi.nn.pad import pad from .. 
import nn from ..utils import traverse_inline from tvm.topi.utils import get_const_tuple -from tvm.topi.nn.utils import get_pad_tuple from .tensor_intrin import dot_vrmpy +from ..generic import conv2d as conv2d_generic def schedule_conv2d_nhwc(outs): @@ -128,116 +127,13 @@ def _callback(op): reg_n = n break - args = [s, data_vec, conv_out, outs[0]] - # int8 conv kernel is 7-dim - _, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape) - # assert n_elems == 4 + cfg = {"tile_ow": reg_n, "unroll_kw": False} + args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]] intrin = dot_vrmpy(data.dtype, kernel_vec.dtype) - inline_fused = True - - schedule_conv_NCHWc_cpu_common_int8( - *args, reg_n=reg_n, int32_lanes=32, int8_elems=4, intrin=intrin, inline_fused=inline_fused + conv2d_generic.schedule_conv_NCHWc_cpu_common_int8( + *args, int32_lanes=32, int8_elems=4, intrin=intrin, inline_fused=True, ) traverse_inline(s, outs[0].op, _callback) return s - - -def schedule_conv_NCHWc_cpu_common_int8( - s, - data_vec, - conv_out, - last, - reg_n, - int32_lanes=32, - int8_elems=4, - intrin=None, - inline_fused=True, -): - unroll_kw = False - _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) - _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) - - # schedule pad - if isinstance(s[data_vec].op, te.tensor.ComputeOp) and "pad" in data_vec.op.tag: - batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis - parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) - # s[data_vec].parallel(parallel_axis) - data_vec = data_vec.op.input_tensors[0] - - # schedule 5-D NCHW[x]c conv - C, O = conv_out, last - CC = s.cache_write(C, "global") - - batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=reg_n) - s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[C].fuse(batch, oc_chunk, oh) - s[C].vectorize(oc_block) - - if C == O: - s[C].parallel(parallel_axis) - - s[CC].compute_at(s[C], parallel_axis) - _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - - ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) - - assert oc_bn % int32_lanes == 0, f"oc_bn={oc_bn} % int32_lanes={int32_lanes} != 0" - assert ( - ic_bn % int8_elems == 0 - ), f"ic_bn={ic_bn} % int8_elems={int8_elems} != 0" # (u)int8 elements in (u)int32 - - oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - - if unroll_kw: - s[CC].reorder( - oc_chunk, - oh, - ow_chunk, - ic_outer, - kh, - ic_f_inner, - kw, - ow_block, - oc_f_inner, - oc_s_inner, - ic_s_inner, - ) - s[CC].unroll(kw) - else: - s[CC].reorder( - oc_chunk, - oh, - ow_chunk, - ic_outer, - kh, - kw, - ic_f_inner, - ow_block, - oc_f_inner, - oc_s_inner, - ic_s_inner, - ) - - s[CC].tensorize(oc_s_inner, intrin) - - s[CC].unroll(ow_block) - s[CC].unroll(oc_f_inner) - - if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=reg_n) - s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(batch, oc_chunk, oh) - - if inline_fused: - s[C].compute_at(s[O], ow_block) - else: - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) - - return s From cabe37c1be79d3c9ccaf35bd937de3cba02b6e43 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 03:43:57 -0700 Subject: [PATCH 07/14] clean up --- python/tvm/topi/hexagon/conv2d.py | 12 ++----- python/tvm/topi/hexagon/conv2d_alter_op.py | 40 ---------------------- 2 files changed, 2 
insertions(+), 50 deletions(-) diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index d56c8f3a53e0..3e1de8f04a45 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -111,16 +111,8 @@ def _callback(op): conv_out = op.output(0) kernel_vec = conv_out.op.input_tensors[1] data_vec = conv_out.op.input_tensors[0] - data = ( - data_vec.op.input_tensors[0] - if isinstance(data_vec.op, te.tensor.ComputeOp) and "pad" not in data_vec.op.tag - else data_vec - ) - if isinstance(data.op, te.tensor.ComputeOp) and "pad" in data.op.tag: - data_pad = data - data = data_pad.op.input_tensors[0] - out_width = conv_out.shape[3] + reg_n = 1 for n in range(31, 0, -1): if out_width % n == 0: @@ -129,7 +121,7 @@ def _callback(op): cfg = {"tile_ow": reg_n, "unroll_kw": False} args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]] - intrin = dot_vrmpy(data.dtype, kernel_vec.dtype) + intrin = dot_vrmpy(data_vec.dtype, kernel_vec.dtype) conv2d_generic.schedule_conv_NCHWc_cpu_common_int8( *args, int32_lanes=32, int8_elems=4, intrin=intrin, inline_fused=True, diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py index a0ca5c386711..423680fd04ca 100644 --- a/python/tvm/topi/hexagon/conv2d_alter_op.py +++ b/python/tvm/topi/hexagon/conv2d_alter_op.py @@ -86,45 +86,6 @@ def _conv2d_legalize(attrs, inputs, arg_types): # Collect the input exprs. data, kernel = inputs - if data_layout == "NHWC" and kernel_layout == "HWIO": - # Collect the input tensors. - data_tensor, kernel_tensor = arg_types[0], arg_types[1] - out_channel = kernel_tensor.shape[0] - - # Dilation not supported yet. Return None if dilation is not (1, 1) - dilation = attrs.get_int_tuple("dilation") - if not (dilation[0] == 1 and dilation[1] == 1): - return None - - # No legalization for depthwise convolutions yet. - groups = attrs.get_int("groups") - if groups != 1: - return None - - # Get the conv attrs - new_attrs = {k: attrs[k] for k in attrs.keys()} - - padding = attrs.get_int_tuple("padding") - kh, kw = attrs.get_int_tuple("kernel_size") - pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) - - out_channel_vector_length = 64 if output_tensor.dtype == "float16" else 32 - out_channel = kernel_tensor.shape[3].value - - if out_channel % out_channel_vector_length != 0: - new_out_channel = ( - (out_channel + out_channel_vector_length) // out_channel_vector_length - ) * out_channel_vector_length - diff = new_out_channel - out_channel - kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, diff))) - - new_attrs["channels"] = new_out_channel - out = relay.nn.conv2d(data, kernel, **new_attrs) - original_out_shape = [x.value for x in output_tensor.shape] - return relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) - else: - return relay.nn.conv2d(data, kernel, **new_attrs) - if data_layout != "NCHW" or kernel_layout != "OIHW": return None @@ -134,7 +95,6 @@ def _conv2d_legalize(attrs, inputs, arg_types): if "int8" in data_tensor.dtype and "int8" in data_tensor.dtype and out_channel % 32 == 0: data_dtype = data_tensor.dtype - kernel_dtype = kernel_tensor.dtype # Collect the output tensor. 
output_tensor = arg_types[2] From d25315366928871434f337c46ec2c748210bc130 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Sep 2022 04:56:03 +0900 Subject: [PATCH 08/14] doc update --- python/tvm/relay/op/strategy/hexagon.py | 1 + python/tvm/topi/generic/conv2d.py | 4 +- python/tvm/topi/hexagon/conv2d.py | 9 +- python/tvm/topi/hexagon/conv2d_alter_op.py | 101 +++++++++--------- python/tvm/topi/hexagon/dense.py | 46 ++++---- python/tvm/topi/hexagon/dense_alter_op.py | 43 ++++---- python/tvm/topi/hexagon/tensor_intrin.py | 1 + .../contrib/test_hexagon/test_launcher.py | 16 +-- 8 files changed, 119 insertions(+), 102 deletions(-) diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index 45a3d87e6af4..48fe1485cd57 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -191,6 +191,7 @@ def schedule_reduce_hexagon(attrs, outs, target): @conv2d_NCHWc_strategy.register("hexagon") def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target): + """conv2d_NCHWc_ hexagon strategy""" strategy = _op.OpStrategy() data, kernel = inputs strategy.add_implementation( diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 6467c44faa77..76cd9a7d69d1 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -145,9 +145,9 @@ def schedule_conv_NCHWc_cpu_common_int8( reg_n = cfg["tile_ow"].size[-1] if isinstance(cfg["unroll_kw"], (int, bool)): - unroll_kw = cfg["unroll_kw"] + unroll_kw = cfg["unroll_kw"] else: - unroll_kw = cfg["unroll_kw"].val + unroll_kw = cfg["unroll_kw"].val _, _, _, _, ic_bn = get_const_tuple(data_vec.shape) _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index 3e1de8f04a45..b8e22adb25a9 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -96,6 +96,7 @@ def _callback(op): def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32" ): + """Compute definition for int8 conv2d in NCHWc layout""" n_elems = int(kernel.shape[-1]) return nn.conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems @@ -103,7 +104,7 @@ def conv2d_NCHWc_int8( def schedule_conv2d_NCHWc_int8(outs): - """Create schedule for tensors""" + """Schedule for int8 conv2d in NCHWc layout using vrmpy tensorization""" s = te.create_schedule([x.op for x in outs]) def _callback(op): @@ -124,7 +125,11 @@ def _callback(op): intrin = dot_vrmpy(data_vec.dtype, kernel_vec.dtype) conv2d_generic.schedule_conv_NCHWc_cpu_common_int8( - *args, int32_lanes=32, int8_elems=4, intrin=intrin, inline_fused=True, + *args, + int32_lanes=32, + int8_elems=4, + intrin=intrin, + inline_fused=True, ) traverse_inline(s, outs[0].op, _callback) diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py index 423680fd04ca..02b0a6ae985e 100644 --- a/python/tvm/topi/hexagon/conv2d_alter_op.py +++ b/python/tvm/topi/hexagon/conv2d_alter_op.py @@ -15,69 +15,68 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name,unused-variable,unused-argument,no-member -"""Dense alter op functions for ARM""" +"""Conv2d alter op functions for Hexagon""" -import tvm from tvm import relay -from tvm import autotvm from ..utils import get_const_tuple from .. 
import nn -from ..nn.utils import get_pad_tuple -from ..nn import conv2d_legalize, conv2d_alter_layout +from ..nn import conv2d_alter_layout from ..generic.conv2d import conv2d_alter_int8_common -def check_vrmpy_applicable(x, y): - out_channel, in_channel, _, _ = get_const_tuple(y.shape) - return ( - "int8" in x.dtype and "int8" in y.dtype and out_channel % 32 == 0 and in_channel % 4 == 0 - ) - - @conv2d_alter_layout.register("hexagon") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): - target = tvm.target.Target.current(allow_none=False) - dispatch_ctx = autotvm.task.DispatchContext.current + """Convert nn.conv2d into nn.contrib_conv2d_nchwc if vrmpy is applicable.""" new_attrs = {k: attrs[k] for k in attrs.keys()} - # Parse the attributes. - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - dilation = attrs.get_int_tuple("dilation") data_layout = attrs["data_layout"] kernel_layout = attrs["kernel_layout"] data_tensor, kernel_tensor = tinfos - data_dtype = data_tensor.dtype - kernel_dtype = kernel_tensor.dtype - out_dtype = out_type.dtype - - if not check_vrmpy_applicable(data_tensor, kernel_tensor) or data_layout != "NCHW" or kernel_layout != "OIHW": - return None + out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape) - batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) - out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - data_dtype = data_tensor.dtype - kernel_dtype = kernel_tensor.dtype + if ( + "int8" in data_tensor.dtype + and "int8" in kernel_tensor.dtype + and out_channel % 32 == 0 + and in_channel % 4 == 0 + and data_layout == "NCHW" + and kernel_layout == "OIHW" + ): + out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape) - n_elems = 4 - ic_bn, oc_bn = 32, 32 + n_elems = 4 + oc_bn = 32 + ic_bn = min(in_channel, 32) - if ic_bn > in_channel: - assert in_channel == 4 - ic_bn = in_channel + new_attrs = {k: attrs[k] for k in attrs.keys()} - new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs["channels"] = out_channel + new_attrs["data_layout"] = "NCHW%dc" % ic_bn + new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems) + new_attrs["out_layout"] = "NCHW%dc" % oc_bn - new_attrs["channels"] = out_channel - new_attrs["data_layout"] = "NCHW%dc" % ic_bn - new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems) - new_attrs["out_layout"] = "NCHW%dc" % oc_bn + return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) - return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) + return None @nn.conv2d_legalize.register("hexagon") def _conv2d_legalize(attrs, inputs, arg_types): + """Legalize conv2d op for vrmpy tensorization. + + If the inputs are signed or unsigned int8, the input and output channels are padded to be + a multiple of 4 and 32 respectively. + + If the input data types are (int8, int8), they are converted to (uint8, int8) and + the vector-by-vector variant of vrmpy is applied. + If the input data types are (uint8, uint8), the more efficient vector-by-scalar variant of vrmpy + is applied. + + Unlike the nn.dense case (see dense_alter_op.py), we do not convert (uint8, int8) to + (uint8, uint8). That would introduce another convolution by a constant (128 or 1) filter, + to compensate for the dtype legalization. In the nn.dense case, such compensation factor is + just a sum over the K axis. 
+ """ data_layout = attrs["data_layout"] kernel_layout = attrs["kernel_layout"] @@ -89,23 +88,25 @@ def _conv2d_legalize(attrs, inputs, arg_types): if data_layout != "NCHW" or kernel_layout != "OIHW": return None - # Collect the input tensors. data_tensor, kernel_tensor = arg_types[0], arg_types[1] - out_channel = kernel_tensor.shape[0] - if "int8" in data_tensor.dtype and "int8" in data_tensor.dtype and out_channel % 32 == 0: - data_dtype = data_tensor.dtype - - # Collect the output tensor. + if "int8" in data_tensor.dtype and "int8" in data_tensor.dtype: output_tensor = arg_types[2] - - # Collect the input exprs. data, kernel = inputs - - data_dtype = "uint8" + desired_data_dtype = "uint8" + in_channel_vector_length = 4 + out_channel_vector_length = 32 return conv2d_alter_int8_common( - data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, data_dtype, 4, 32 + data, + data_tensor, + kernel, + kernel_tensor, + output_tensor, + attrs, + desired_data_dtype, + in_channel_vector_length, + out_channel_vector_length, ) return None diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py index 59190ba83efb..0fad0ca778c6 100644 --- a/python/tvm/topi/hexagon/dense.py +++ b/python/tvm/topi/hexagon/dense.py @@ -45,7 +45,7 @@ def schedule_dense(outs): def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype): - """Compute for uint8 x uint8 -> int32 dense""" + """Compute for uint8 x uint8 -> int32 dense using vrmpy""" assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32" m, k = X.shape n_o, _, n_i, _ = packed_w.shape @@ -71,38 +71,38 @@ def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype): return C -def dense_u8u8i32_vrmpy_common(s, C, O): - (a_k,) = C.op.reduce_axis - a_y = C.op.axis[-2] - a_yo, a_yi = s[C].split(a_y, factor=32) - a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32) - a_ko, a_ki = s[C].split(a_k, factor=4) - - s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) +def dense_u8u8i32_vrmpy_schedule(outs): + """Schedule for vrmpy dense""" + s = te.create_schedule([x.op for x in outs]) + # O: The output of the fused op + O = outs[0] - pc = dot_vrmpy("uint8", "uint8") - s[C].tensorize(a_xi, pc) + def _schedule_dense(s, C, O): + (a_k,) = C.op.reduce_axis + a_y = C.op.axis[-2] + a_yo, a_yi = s[C].split(a_y, factor=32) + a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32) + a_ko, a_ki = s[C].split(a_k, factor=4) - if C != O: - a_y = O.op.axis[-2] - a_yo, a_yi = s[O].split(a_y, factor=32) - a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32) + s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) - s[O].reorder(a_yo, a_xo, a_yi, a_xi) - s[O].vectorize(a_xi) - s[C].compute_at(s[O], a_yi) + pc = dot_vrmpy("uint8", "uint8") + s[C].tensorize(a_xi, pc) + if C != O: + a_y = O.op.axis[-2] + a_yo, a_yi = s[O].split(a_y, factor=32) + a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32) -def dense_u8u8i32_vrmpy_schedule(outs): - s = te.create_schedule([x.op for x in outs]) - # O: The output of the fused op - O = outs[0] + s[O].reorder(a_yo, a_xo, a_yi, a_xi) + s[O].vectorize(a_xi) + s[C].compute_at(s[O], a_yi) def _callback(op): if "u8u8i32_vrmpy" in op.tag: # C: The output of GEMM C = op.output(0) - dense_u8u8i32_vrmpy_common(s, C, O) + _schedule_dense(s, C, O) traverse_inline(s, outs[0].op, _callback) diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py index 496b842a6a16..35ddb213cc27 100644 --- a/python/tvm/topi/hexagon/dense_alter_op.py +++ b/python/tvm/topi/hexagon/dense_alter_op.py @@ -32,47 +32,42 
@@ def check_vrmpy_applicable(x, y): ) - @dense_alter_layout.register(["hexagon"]) def _alter_dense_layout(attrs, inputs, tinfos, out_type): data_tensor, weight_tensor = tinfos out_dtype = out_type.dtype - M, K = get_const_tuple(data_tensor.shape) - N, _ = get_const_tuple(weight_tensor.shape) - if check_vrmpy_applicable(data_tensor, weight_tensor): # and data_tensor.dtype == "uint8": + if check_vrmpy_applicable(data_tensor, weight_tensor): weight_layout = "NC32n4c" return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) else: return None -def vrmpy_legalize(x, w, arg_types, op, attrs, is_batched_mm): +def vrmpy_legalize(x, w, arg_types, op, attrs): """ - Legalizes s8, s8 -> s32 GEMM op for VRMPY. + Legalizes int8 inputs to dense for vrmpy. X'_u8 = X_s8 + 128 X_s8 * W_s8 = (X'_u8 - 128) * (W'_u8 - 128) = X'_u8 * W'_u8 - X'_u8 * 128 - 128 * W'_u8 + 128 * 128 X_u8 * W_s8 = X_u8 * (W'_u8 - 128) = X'_u8 * W'_u8 - X_u8 * 128 """ + if not check_vrmpy_applicable(arg_types[0], arg_types[1]): + return None + def cast_to_uint8(x): x = relay.cast(x, "int32") x = relay.add(x, relay.const(128, "int32")) return relay.cast(x, "uint8") - if check_vrmpy_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "int8" and arg_types[1].dtype == "int8": + if arg_types[0].dtype == "int8" and arg_types[1].dtype == "int8": x = cast_to_uint8(x) w = cast_to_uint8(w) W_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(w, "int32"), axis=[-1]) X_u8x128 = relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]) - - if is_batched_mm: - X_u8x128 = relay.expand_dims(X_u8x128, axis=2) - W_u8x128 = relay.expand_dims(W_u8x128, axis=1) - else: - X_u8x128 = relay.expand_dims(X_u8x128, axis=1) + X_u8x128 = relay.expand_dims(X_u8x128, axis=1) out = op(x, w, **attrs) @@ -82,10 +77,12 @@ def cast_to_uint8(x): k_dim = int(arg_types[0].shape[-1]) return out + relay.const(128 * 128 * k_dim, "int32") - if check_vrmpy_applicable(arg_types[0], arg_types[1]) and arg_types[0].dtype == "uint8" and arg_types[1].dtype == "int8": + if arg_types[0].dtype == "uint8" and arg_types[1].dtype == "int8": w = cast_to_uint8(w) - X_u8x128 = relay.expand_dims(relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]), axis=1) + X_u8x128 = relay.expand_dims( + relay.const(-128, "int32") * relay.sum(relay.cast(x, "int32"), axis=[-1]), axis=1 + ) out = op(x, w, **attrs) @@ -96,6 +93,14 @@ def cast_to_uint8(x): @nn.dense_legalize.register("hexagon") def _dense_legalize(attrs, inputs, arg_types): + """Legalize dense op for HVX vectorization and vrmpy tensorization. + + Given a workload with a matrix X of shape (M, K) and a matrix Y of (N, K), + we first pad the N dimension to be a multiple of the output vector length. + + And if the inputs are signed or unsigned int8 and the Y matrix can be packed into the + NK32n4k layout, we convert both inputs to uint8 to apply the most efficient variant of vrmpy. + """ new_attrs = {k: attrs[k] for k in attrs.keys()} # Collect the input tensors. 
x_tensor, y_tensor = arg_types[0], arg_types[1] @@ -127,8 +132,10 @@ def _dense_legalize(attrs, inputs, arg_types): if attrs["units"] is not None: new_attrs["units"] = N + dn - arg_types = [arg_types[0], - tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype)] + arg_types = [ + arg_types[0], + tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype), + ] vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs, False) @@ -137,7 +144,7 @@ def _dense_legalize(attrs, inputs, arg_types): else: out_ = vrmpy_out - out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) + out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) return out return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs, False) diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index 4a5371135778..daf3a5180833 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -73,6 +73,7 @@ def _q_multiply_shift_hexagon(op): def dot_vrmpy(x_ty, y_ty): + """Generates vrmpy instruciton for tensorization.""" int32_lanes = 32 num_int8_elements = 4 # 4 int8 elements in int32 data = te.placeholder((num_int8_elements,), dtype=x_ty, name="data") diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index f8bdb94b93a5..7ccceb040b35 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -427,12 +427,7 @@ def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target, @tvm.testing.requires_hexagon def test_conv2d_u8u8i32_vrmpy(hexagon_session): def get_conv2d_nchw( - d_shape, - w_shape, - padding, - strides=(1, 1), - data_dtype = "int8", - weight_dtype = "int8" + d_shape, w_shape, padding, strides=(1, 1), data_dtype="int8", weight_dtype="int8" ): out_dtype = "int32" @@ -464,7 +459,14 @@ def get_conv2d_nchw( data_dtype = "uint8" weight_dtype = "int8" - conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides, data_dtype=data_dtype, weight_dtype=weight_dtype) + conv2d = get_conv2d_nchw( + data_shape, + weight_shape, + padding, + strides=strides, + data_dtype=data_dtype, + weight_dtype=weight_dtype, + ) bias_add = relay.nn.bias_add(conv2d, bias) use_bias = True From 17bde45dcace888e98422da0120b3a16824dcdf0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Sep 2022 07:53:59 +0900 Subject: [PATCH 09/14] pylint fix --- python/tvm/relay/op/strategy/hexagon.py | 1 - python/tvm/topi/hexagon/conv2d.py | 3 +-- python/tvm/topi/hexagon/dense.py | 4 ++-- python/tvm/topi/hexagon/dense_alter_op.py | 7 ++----- python/tvm/topi/hexagon/tensor_intrin.py | 1 + tests/python/contrib/test_hexagon/test_launcher.py | 2 +- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index 48fe1485cd57..693352d650ba 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -193,7 +193,6 @@ def schedule_reduce_hexagon(attrs, outs, target): def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target): """conv2d_NCHWc_ hexagon strategy""" strategy = _op.OpStrategy() - data, kernel = inputs strategy.add_implementation( wrap_compute_conv2d( topi.hexagon.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True diff 
--git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py index b8e22adb25a9..aa1b7e57e464 100644 --- a/python/tvm/topi/hexagon/conv2d.py +++ b/python/tvm/topi/hexagon/conv2d.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +# pylint: disable=invalid-name """Schedule for conv2d""" import tvm from tvm import te from .. import nn from ..utils import traverse_inline -from tvm.topi.utils import get_const_tuple from .tensor_intrin import dot_vrmpy from ..generic import conv2d as conv2d_generic diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py index 0fad0ca778c6..abb8a1410a8d 100644 --- a/python/tvm/topi/hexagon/dense.py +++ b/python/tvm/topi/hexagon/dense.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +# pylint: disable=invalid-name """Schedule for dense operator""" import tvm -from tvm.topi.utils import get_const_tuple, traverse_inline +from tvm.topi.utils import traverse_inline from tvm import te from .. import tag from .tensor_intrin import dot_vrmpy diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py index 35ddb213cc27..147d62183559 100644 --- a/python/tvm/topi/hexagon/dense_alter_op.py +++ b/python/tvm/topi/hexagon/dense_alter_op.py @@ -18,10 +18,7 @@ """Dense alter op functions for ARM""" import tvm -from tvm import te from tvm import relay -from tvm import autotvm -from ..utils import get_const_tuple from .. import nn from ..nn import dense_alter_layout @@ -137,7 +134,7 @@ def _dense_legalize(attrs, inputs, arg_types): tvm.ir.tensor_type.TensorType([N + dn, arg_types[1].shape[1]], arg_types[1].dtype), ] - vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs, False) + vrmpy_out = vrmpy_legalize(x, y_, arg_types, relay.nn.dense, new_attrs) if vrmpy_out is None: out_ = relay.nn.dense(x, y_, **new_attrs) @@ -147,4 +144,4 @@ def _dense_legalize(attrs, inputs, arg_types): out = relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) return out - return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs, False) + return vrmpy_legalize(inputs[0], inputs[1], arg_types, relay.nn.dense, attrs) diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index daf3a5180833..e8587f571fd6 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Optimized implementation of q_multiply_shift based on LLVM intrinsics""" import tvm diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 7ccceb040b35..13b83e5cab09 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +# pylint: disable=invalid-name,missing-function-docstring """ Test rpc based launcher for hexagon """ import numpy as np From 25374111a2e312889dcb49958254da439e3e8432 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 26 Sep 2022 22:27:58 -0700 Subject: [PATCH 10/14] parametrize dtype in test --- .../contrib/test_hexagon/test_launcher.py | 52 +++++++------------ 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 13b83e5cab09..c22acf8b5ef5 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring """ Test rpc based launcher for hexagon """ +import pytest import numpy as np @@ -424,11 +425,16 @@ def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target, tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5) +data_dtype = tvm.testing.parameter("int8", "uint8") +weight_dtype = tvm.testing.parameter("int8", "uint8") + + @tvm.testing.requires_hexagon -def test_conv2d_u8u8i32_vrmpy(hexagon_session): - def get_conv2d_nchw( - d_shape, w_shape, padding, strides=(1, 1), data_dtype="int8", weight_dtype="int8" - ): +def test_conv2d_relay_vrmpy(hexagon_session, data_dtype, weight_dtype): + if data_dtype == "int8" and weight_dtype == "uint8": + pytest.skip("(i8, u8) input pair is not supported") + + def get_conv2d_nchw(d_shape, w_shape, padding, strides=(1, 1)): out_dtype = "int32" data = relay.var("data", shape=d_shape, dtype=data_dtype) @@ -457,26 +463,14 @@ def get_conv2d_nchw( bias = relay.var("bias", shape=bias_shape, dtype="int32") - data_dtype = "uint8" - weight_dtype = "int8" conv2d = get_conv2d_nchw( data_shape, weight_shape, padding, strides=strides, - data_dtype=data_dtype, - weight_dtype=weight_dtype, ) bias_add = relay.nn.bias_add(conv2d, bias) - - use_bias = True - - if use_bias: - out = bias_add - else: - out = conv2d - - mod = tvm.IRModule.from_expr(out) + mod = tvm.IRModule.from_expr(bias_add) if data_dtype == "uint8": data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") @@ -518,7 +512,10 @@ def get_conv2d_nchw( @tvm.testing.requires_hexagon -def test_dense_u8u8i32_vrmpy(hexagon_session): +def test_dense_relay_vrmpy(hexagon_session, data_dtype, weight_dtype): + if data_dtype == "int8" and weight_dtype == "uint8": + pytest.skip("(i8, u8) input pair is not supported") + target_hexagon = tvm.target.hexagon("v68") target = tvm.target.Target(target_hexagon, host=target_hexagon) @@ -528,15 +525,11 @@ def test_dense_u8u8i32_vrmpy(hexagon_session): data_shape = (M, K) weight_shape = (N, K) - data_dtype = "uint8" - weight_dtype = "int8" data = relay.var("data", shape=data_shape, dtype=data_dtype) weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) dense = relay.nn.dense(data, weight, out_dtype="int32") - use_bias = False - if data_dtype == "uint8": data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") else: @@ -551,13 +544,9 @@ def test_dense_u8u8i32_vrmpy(hexagon_session): params = {"weight": weight_np, "bias": bias_np} - if use_bias: - bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") - out = relay.nn.bias_add(dense, bias) - else: - out = dense - - mod = tvm.IRModule.from_expr(out) + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + bias_add = relay.nn.bias_add(dense, bias) + 
mod = tvm.IRModule.from_expr(bias_add)
 with tvm.transform.PassContext(
 opt_level=3,
@@ -575,10 +564,9 @@ def test_dense_u8u8i32_vrmpy(hexagon_session):
 rt_mod.run()
 out = rt_mod.get_output(0).numpy()
- ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32"))
- if use_bias:
- ref += bias_np
+ ref = np.dot(data_np.astype("int32"), weight_np.transpose().astype("int32"))
+ ref += bias_np
 np.testing.assert_equal(out, ref)

From 978158ddf5fb9fc4d611ccd7b92452dfee2a7e32 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 26 Sep 2022 22:40:48 -0700
Subject: [PATCH 11/14] doc update
---
 python/tvm/topi/hexagon/conv2d_alter_op.py | 1 -
 python/tvm/topi/hexagon/tensor_intrin.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py b/python/tvm/topi/hexagon/conv2d_alter_op.py
index 02b0a6ae985e..201b6f804352 100644
--- a/python/tvm/topi/hexagon/conv2d_alter_op.py
+++ b/python/tvm/topi/hexagon/conv2d_alter_op.py
@@ -82,7 +82,6 @@ def _conv2d_legalize(attrs, inputs, arg_types):
 output_tensor = arg_types[2]
- # Collect the input exprs.
 data, kernel = inputs
 if data_layout != "NCHW" or kernel_layout != "OIHW":
diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py
index e8587f571fd6..adea4690d4a7 100644
--- a/python/tvm/topi/hexagon/tensor_intrin.py
+++ b/python/tvm/topi/hexagon/tensor_intrin.py
@@ -137,7 +137,7 @@ def _instr(index):
 vec_bi32,
 )
 else:
- assert False, "Not supported"
+ raise ValueError("Only (u8, u8) or (u8, i8) dtype pairs are supported by vrmpy.")
 if index == 0:
 ib.emit(outs[0].vstore(0, quad_reduction))

From 4ad3e63fb0ea828764fbd4b21284bed00c5075f6 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 27 Sep 2022 14:57:46 +0900
Subject: [PATCH 12/14] add missing parallelization for dense
---
 python/tvm/topi/hexagon/dense.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/tvm/topi/hexagon/dense.py b/python/tvm/topi/hexagon/dense.py
index abb8a1410a8d..02ad141ecb5a 100644
--- a/python/tvm/topi/hexagon/dense.py
+++ b/python/tvm/topi/hexagon/dense.py
@@ -88,6 +88,7 @@ def _schedule_dense(s, C, O):
 pc = dot_vrmpy("uint8", "uint8")
 s[C].tensorize(a_xi, pc)
+ s[C].parallel(s[C].fuse(a_yo, a_xo))
 if C != O:
 a_y = O.op.axis[-2]
@@ -97,6 +98,7 @@ def _schedule_dense(s, C, O):
 s[O].reorder(a_yo, a_xo, a_yi, a_xi)
 s[O].vectorize(a_xi)
 s[C].compute_at(s[O], a_yi)
+ s[O].parallel(s[O].fuse(a_yo, a_xo))
 def _callback(op):
 if "u8u8i32_vrmpy" in op.tag:

From 849ed3c9bc9f9dc24a70c5b512b57e7ed9596429 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 27 Sep 2022 15:15:14 +0900
Subject: [PATCH 13/14] more pylint
---
 tests/python/contrib/test_hexagon/test_launcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py
index c22acf8b5ef5..7431871524aa 100644
--- a/tests/python/contrib/test_hexagon/test_launcher.py
+++ b/tests/python/contrib/test_hexagon/test_launcher.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
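# NOTE (editorial, not part of the patch): PATCH 12 above adds
# s[C].parallel(s[C].fuse(a_yo, a_xo)) so the fused outer row/column tile loop
# of the vrmpy dense schedule is marked parallel. A minimal, standalone TE
# sketch of the same fuse-then-parallel pattern; the operator, names, and
# shapes below are illustrative only and not taken from the patch.
import tvm
from tvm import te

A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
io, ii = s[B].split(B.op.axis[0], factor=32)
jo, ji = s[B].split(B.op.axis[1], factor=32)
s[B].reorder(io, jo, ii, ji)

# Fuse the two outer tile loops and parallelize the fused loop, mirroring
# s[O].parallel(s[O].fuse(a_yo, a_xo)) in dense_u8u8i32_vrmpy_schedule.
fused = s[B].fuse(io, jo)
s[B].parallel(fused)

print(tvm.lower(s, [A, B], simple_mode=True))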
-# pylint: disable=invalid-name,missing-function-docstring +# pylint: disable=invalid-name,missing-function-docstring,redefined-outer-name """ Test rpc based launcher for hexagon """ import pytest From 208896591ff194a9461a60e52d92f7665fed62db Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Sep 2022 18:16:04 +0900 Subject: [PATCH 14/14] fixed for fp32 dense --- python/tvm/topi/hexagon/dense_alter_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/hexagon/dense_alter_op.py b/python/tvm/topi/hexagon/dense_alter_op.py index 147d62183559..cb5feb56d68e 100644 --- a/python/tvm/topi/hexagon/dense_alter_op.py +++ b/python/tvm/topi/hexagon/dense_alter_op.py @@ -109,14 +109,14 @@ def _dense_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - M, K = x_tensor.shape - N, K = y_tensor.shape + N, _ = y_tensor.shape if dtype == "float16": vec_len = 64 - else: - assert "int8" in dtype + elif "int8" in dtype: vec_len = 32 + else: + return None if N % vec_len != 0: N_padded = ((N + vec_len) // vec_len) * vec_len
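# NOTE (editorial, not part of the patch): PATCH 14 above selects vec_len = 64
# for float16 and vec_len = 32 for int8 inputs (returning None for other dtypes),
# and pads N only when it is not already a multiple of vec_len, in which case
# ((N + vec_len) // vec_len) * vec_len rounds N up to the next multiple.
# A small worked check; pad_to_multiple is a hypothetical helper, not from the patch.
def pad_to_multiple(n, vec_len):
    if n % vec_len != 0:
        return ((n + vec_len) // vec_len) * vec_len
    return n

assert pad_to_multiple(1000, 32) == 1024  # int8 path
assert pad_to_multiple(1000, 64) == 1024  # float16 path
assert pad_to_multiple(1024, 32) == 1024  # already aligned, left unchanged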