diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py
index b8cdc7a26d96..022a552c9d54 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -17,16 +17,13 @@
 """ Computes and schedules for Hexagon quantized ops """
 
+from .adaptive_avg_pool1d import *
 from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule
-from .qadd_qsub_qmul import *
-from .dequantize import (
-    dequantize_compute,
-    dequantize_schedule,
-)
-
-from .quantize import quantize_compute, tir_quantize_schedule
+from .conv2d_alter_op import *
+from .dequantize import dequantize_compute, dequantize_schedule
+from .global_avg_pool2d import *
 from .nn import *
+from .qadd_qsub_qmul import *
+from .qdense import *
 from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
-from .adaptive_avg_pool1d import *
-from .global_avg_pool2d import *
-from .conv2d_alter_op import *
+from .quantize import quantize_compute, tir_quantize_schedule
diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py
new file mode 100644
index 000000000000..53f9077e56ba
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/qdense.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Compute and schedule for quantized dense operator"""
+
+from tvm import te, tir
+from tvm.topi import tag
+from ..utils import get_layout_transform_fn
+
+
+def qdense_compute(
+    tensor_a,
+    tensor_b,
+    zero_a,
+    scale_a,
+    zero_b,
+    scale_b,
+    zero_out=None,
+    scale_out=None,
+    bias=None,
+    q_dtype=None,
+):
+    """Hexagon's implementation of a sliced quantized dense operator in Topi.
+    Uses matmul.
+
+    Parameters
+    ----------
+    tensor_a : tvm.te.Tensor
+        data 2-D with shape [batch, in_dim]
+
+    tensor_b : tvm.te.Tensor
+        weight 2-D with shape [in_dim, out_dim]
+
+    zero_a : int
+        quantization zero point for tensor a.
+
+    scale_a : float
+        quantization scale for tensor a.
+
+    zero_b : int
+        quantization zero point for tensor b.
+
+    scale_b : float
+        quantization scale for tensor b.
+
+    zero_out : Optional[int]
+        quantization zero point for output.
+
+    scale_out : Optional[float]
+        quantization scale for output.
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [out_dim]
+
+    q_dtype : Optional[str]
+        The output type. Defaults to the dtype of tensor_a.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        2-D with shape [batch, out_dim]
+
+    """
+    if bias is not None:
+        assert len(bias.shape) == 1
+    if q_dtype is None:
+        q_dtype = tensor_a.dtype
+
+    batch, in_dim = tensor_a.shape
+    red_dim, out_dim = tensor_b.shape
+
+    # The comparison must be done by value, since in_dim and red_dim are
+    # distinct IR nodes even when their extents match.
+    assert int(in_dim) == int(red_dim)
+
+    k = te.reduce_axis((0, in_dim), name="k")
+    compute_lambda = lambda n, m: te.sum(
+        scale_a
+        * (tensor_a[n, k].astype("float32") - zero_a)
+        * scale_b
+        * (tensor_b[k, m].astype("float32") - zero_b),
+        axis=k,
+    )
+    compute_name = "qmatmul_sliced"
+
+    out = te.compute(
+        (batch, out_dim),
+        compute_lambda,
+        name=compute_name,
+        attrs={"layout_free_placeholders": [tensor_b]},
+    )
+
+    if bias is not None:
+        out = te.compute(
+            (batch, out_dim),
+            lambda i, j: out[i, j] + bias[j],
+            tag=tag.BROADCAST,
+            name="bias",
+        )
+
+    # Requantization of dense
+    if scale_out is not None:
+        out = te.compute(
+            (batch, out_dim),
+            lambda *i: (out[i] / scale_out + zero_out).astype(q_dtype),
+            name="requantize",
+        )
+
+    return out
+
+
+def qdense_schedule(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for quantized dense op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of dense in the format
+        of an array of tensors.
+
+    ins: Array of Tensor
+        Input tensors into graph.
+
+    output_layout: str
+        Descriptor string for physical layout
+
+    input_layout: str
+        Descriptor string for physical layout
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    if not isinstance(ins, list):
+        ins = [ins]
+    if not isinstance(outs, list):
+        outs = [outs]
+
+    func = te.create_prim_func([*ins, *outs])
+    s = tir.Schedule(func)
+
+    matmul = s.get_block("qmatmul_sliced")
+    try:
+        requantize = s.get_block("requantize")
+    except tir.schedule.ScheduleError:
+        requantize = None
+    try:
+        bias = s.get_block("bias")
+    except tir.schedule.ScheduleError:
+        bias = None
+
+    input_transform_fn = get_layout_transform_fn(input_layout)
+    output_transform_fn = get_layout_transform_fn(output_layout)
+
+    # Transform input and output buffers; the output layout is applied to
+    # the last block in the chain (requantize, else bias, else matmul).
+    s.transform_layout(matmul, ("read", 0), input_transform_fn)
+    if requantize is not None:
+        s.transform_layout(requantize, ("write", 0), output_transform_fn)
+    elif bias is not None:
+        s.transform_layout(bias, ("write", 0), output_transform_fn)
+    else:
+        s.transform_layout(matmul, ("write", 0), output_transform_fn)
+
+    # Vectorize
+    _, matmul_c, _ = s.get_loops(matmul)
+    _, matmul_c_inner = s.split(matmul_c, [None, 128])
+    s.vectorize(matmul_c_inner)
+
+    # Compute each producer block at its consumer's channel loop
+    if bias is not None and requantize is not None:
+        _, bias_c = s.get_loops(bias)
+        s.compute_at(matmul, bias_c)
+        _, out_c = s.get_loops(requantize)
+        s.compute_at(bias, out_c)
+    elif bias is not None and requantize is None:
+        _, out_c = s.get_loops(bias)
+        s.compute_at(matmul, out_c)
+
+    return s
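For reference, the dequantize-multiply-accumulate that qdense_compute expresses in TE, followed by the optional requantization step, behaves like the NumPy sketch below (illustrative only; qdense_ref is a hypothetical helper, not part of the patch):

    import numpy as np

    def qdense_ref(a_q, b_q, zero_a, scale_a, zero_b, scale_b,
                   zero_out=None, scale_out=None, bias=None, q_dtype="uint8"):
        # Dequantize both operands to float32, then accumulate over in_dim:
        # out[n, m] = sum_k scale_a*(a[n, k]-zero_a) * scale_b*(b[k, m]-zero_b)
        out = (scale_a * (a_q.astype("float32") - zero_a)) @ (
            scale_b * (b_q.astype("float32") - zero_b)
        )
        if bias is not None:
            out = out + bias  # bias has shape [out_dim], added in float32
        if scale_out is not None:
            # Requantize: rescale to the output scale, shift by its zero point
            out = (out / scale_out + zero_out).astype(q_dtype)
        return out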
diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
index 6b17b64489a9..46ae0c53200f 100644
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -37,3 +37,4 @@
 from .dwconv2d import *
 from .depth_to_space import d2s_compute, d2s_schedule
 from .global_avg_pool2d import *
+from .dense import *
diff --git a/python/tvm/topi/hexagon/slice_ops/dense.py b/python/tvm/topi/hexagon/slice_ops/dense.py
new file mode 100644
index 000000000000..a298ff4bc98e
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/dense.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Compute and schedule for dense operator"""
+
+from tvm import te, tir
+from tvm.topi import tag
+from ..utils import get_layout_transform_fn
+
+
+def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None):
+    """Hexagon's implementation of a sliced dense operator in Topi.
+    Uses matmul.
+
+    Parameters
+    ----------
+    tensor_a : tvm.te.Tensor
+        data 2-D with shape [batch, in_dim]
+
+    tensor_b : tvm.te.Tensor
+        weight 2-D with shape [in_dim, out_dim]
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [out_dim]
+
+    out_dtype : Optional[str]
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [batch, out_dim]
+
+    """
+    if bias is not None:
+        assert len(bias.shape) == 1
+    if out_dtype is None:
+        out_dtype = tensor_a.dtype
+
+    batch, in_dim = tensor_a.shape
+    red_dim, out_dim = tensor_b.shape
+
+    # The comparison must be done by value, since in_dim and red_dim are
+    # distinct IR nodes even when their extents match.
+    assert int(in_dim) == int(red_dim)
+
+    k = te.reduce_axis((0, in_dim), name="k")
+    compute_lambda = lambda n, m: te.sum(
+        tensor_a[n, k].astype(out_dtype) * tensor_b[k, m].astype(out_dtype), axis=k
+    )
+    compute_name = "matmul_sliced"
+    compute_tag = "matmul"
+
+    mat = te.compute(
+        (batch, out_dim),
+        compute_lambda,
+        name=compute_name,
+        tag=compute_tag,
+        attrs={"layout_free_placeholders": [tensor_b]},
+    )
+
+    if bias is not None:
+        mat = te.compute(
+            (batch, out_dim),
+            lambda i, j: mat[i, j] + bias[j],
+            tag=tag.BROADCAST,
+            name="bias",
+        )
+
+    return mat
+
+
+def dense_schedule(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for dense op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of dense in the format
+        of an array of tensors.
+
+    ins: Array of Tensor
+        Input tensors into graph.
+
+    output_layout: str
+        Descriptor string for physical layout
+
+    input_layout: str
+        Descriptor string for physical layout
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+ """ + if not isinstance(ins, list): + ins = [ins] + if not isinstance(outs, list): + outs = [outs] + + func = te.create_prim_func([*ins, *outs]) + s = tir.Schedule(func) + + matmul = s.get_block("matmul_sliced") + try: + bias = s.get_block("bias") + except tir.schedule.schedule.ScheduleError: + bias = None + + input_transform_fn = get_layout_transform_fn(input_layout) + output_transform_fn = get_layout_transform_fn(output_layout) + + # No bias + if bias is None: + s.transform_layout(matmul, ("read", 0), input_transform_fn) + # s.transform_layout(matmul, ("read", 1), input_transform_fn) + s.transform_layout(matmul, ("write", 0), output_transform_fn) + else: + s.transform_layout(matmul, ("read", 0), input_transform_fn) + s.transform_layout(bias, ("write", 0), output_transform_fn) + + _, matmul_c, _ = s.get_loops(matmul) + _, matmul_c_inner = s.split(matmul_c, [None, 64]) + s.vectorize(matmul_c_inner) + + if bias is not None: + _, bias_c = s.get_loops(bias) + _, bias_c_inner = s.split(bias_c, [None, 64]) + s.vectorize(bias_c_inner) + + return s diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 78ed21e8a13b..86aa87adf319 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -75,6 +75,21 @@ def nc_1024c_2d(n, c): return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] +def nc_2048c_1d(n, c): + """Return index map for nc_2024c 1d layout""" + return [n, c // 2048, c % 2048] + + +def nc_2048c_2d(n, c): + """Return index map for nc_2024c 2d layout""" + return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] + + +def nc_1024c_1d(n, c): + """Return index map for nc_1024c 1d layout""" + return [n, c // 1024, c % 1024] + + def nhwc_4h2w32c2w_2d(n, h, w, c): """Return index map for nhwc_4h2w32c2w 2d layout""" return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2] @@ -100,11 +115,6 @@ def nc_2048_2d(n, c): return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] -def nc_2048c_2d(n, c): - """Return index map for nc_2048 2d layout""" - return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] - - def nhwc_8h8w32c_2d(n, h, w, c): """Return index map for nhwc_8h8w32c 2d layout""" return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32] @@ -170,8 +180,14 @@ def get_layout_transform_fn(layout): return nc_512c_1d if layout == "nhwc-4h2w32c2w-2d": return nhwc_4h2w32c2w_2d + if layout == "nc-2048c-1d": + return nc_2048c_1d + if layout == "nc-2048c-2d": + return nc_2048c_2d if layout == "nc-1024c-2d": return nc_1024c_2d + if layout == "nc-1024c-1d": + return nc_1024c_1d if layout == "iohw-16i32o2i-1d": return iohw_16i32o2i_1d if layout == "nhwc-2048c-2d": diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index e81c24694ef9..735b3f2b94b5 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -253,8 +253,14 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): if current_layout == "nc": n, c = arr_np.shape + if new_layout in ["nc-2048c-1d"]: + return arr_np.reshape([n, c // 2048, 2048]) + if new_layout in ["nc-2048c-2d"]: + return arr_np.reshape([n, c // 2048, 2048]) if new_layout in ["nc-1024c-2d"]: return arr_np.reshape([n, c // 1024, 1024]) + if new_layout in ["nc-1024c-1d"]: + return arr_np.reshape([n, c // 1024, 1024]) if new_layout in ["nc-512c-2d"]: return arr_np.reshape([n, c // 512, 512]) if new_layout in ["nc-2048c-2d"]: diff --git 
diff --git a/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py
new file mode 100644
index 000000000000..e616c384fb40
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for Hexagon dense and quantized dense slice operators."""
+
+import numpy as np
+
+from tvm import te
+
+import tvm.testing
+from tvm.topi import testing
+from tvm.contrib.hexagon.session import Session
+import tvm.topi.hexagon.qnn as qnn
+import tvm.topi.hexagon.slice_ops as sl
+from tvm.contrib.hexagon import allocate_hexagon_array
+from ...infrastructure import transform_numpy, quantize_np
+
+
+@tvm.testing.fixture
+def input_np(input_shape, dtype):
+    if "int" in dtype:
+        data = np.random.random(input_shape).astype("float32")
+    elif "float" in dtype:
+        data = np.random.random(input_shape).astype(dtype)
+    return data
+
+
+@tvm.testing.fixture
+def weight_np(weight_shape, dtype):
+    if "int" in dtype:
+        weight = np.random.random(weight_shape).astype("float32")
+    elif "float" in dtype:
+        weight = np.random.random(weight_shape).astype(dtype)
+    return weight
+
+
+@tvm.testing.fixture
+def input_quant(input_np, dtype):
+    if "float" in dtype:
+        return None
+    quant, scale, zp = quantize_np(input_np, dtype)
+    return {"zero": zp, "scale": scale, "data": quant}
+
+
+@tvm.testing.fixture
+def weight_quant(weight_np, dtype):
+    if "float" in dtype:
+        return None
+    quant, scale, zp = quantize_np(weight_np, "int8")
+    return {"zero": zp, "scale": scale, "data": quant}
+
+
+@tvm.testing.fixture
+def bias_np(bias_shape, bias, dtype):
+    if bias:
+        if "int" in dtype:
+            data = np.random.randint(-128, 127, size=bias_shape).astype("int32")
+        elif "float" in dtype:
+            data = np.random.random(bias_shape).astype(dtype)
+        return data
+    else:
+        return None
+
+
+@tvm.testing.fixture
+def quant_arr(input_quant, weight_quant):
+    if input_quant is None:
+        return None
+    # Slots 4 and 5 are reserved for the output zero point and scale,
+    # which stay unset until requantization is exercised.
+    arr = np.empty((6,), dtype="float32")
+    arr[0] = input_quant["zero"]
+    arr[1] = input_quant["scale"]
+    arr[2] = weight_quant["zero"]
+    arr[3] = weight_quant["scale"]
+    return arr
+
+
+@tvm.testing.fixture
+def transformed_expected_output_np(expected_output_np, layout):
+    return transform_numpy(expected_output_np, "nc", layout)
+
+
+@tvm.testing.fixture
+def transformed_input_np(input_np, layout):
+    return transform_numpy(input_np, "nc", layout)
+
+
+@tvm.testing.fixture
+def transformed_input_quant(input_quant, layout):
+    if input_quant is None:
+        return None
+    input_quant["data"] = transform_numpy(input_quant["data"], "nc", layout)
+    return input_quant
+
+
+class TestDenseSlice:
+    (input_shape, output_shape, layout, bias, dtype,) = tvm.testing.parameters(
+        (  # Float 16
+            [1, 1024],
+            [1, 1024],
+            "nc-1024c-2d",
+            False,
+            "float16",
+        ),
+        (
+            [1, 2048],
+            [1, 2048],
+            "nc-1024c-2d",
+            True,
+            "float16",
+        ),
+        (  # Uint 8
+            [1, 2048],
+            [1, 2048],
+            "nc-2048c-2d",
+            False,
+            "uint8",
+        ),
+        (
+            [1, 4096],
+            [1, 4096],
+            "nc-2048c-2d",
+            True,
+            "uint8",
+        ),
+    )
+
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np, weight_np, bias_np, bias):
+        ref_np = tvm.topi.testing.dense(
+            np.reshape(input_np, (input_np.shape[0], input_np.shape[-1])),
+            weight_np.T,  # Function expects [in_dim, out_dim]
+            bias_np,
+            use_bias=bias,
+            out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype,
+        )
+        return ref_np
+
+    @tvm.testing.fixture
+    def weight_shape(self, input_shape, output_shape):
+        return (output_shape[-1], input_shape[-1])
+
+    @tvm.testing.fixture
+    def bias_shape(self, output_shape):
+        return (output_shape[-1],)
+
+    @tvm.testing.requires_hexagon
+    def test_dense_slice(
+        self,
+        dtype,
+        bias_np,
+        layout,
+        output_shape,
+        input_shape,
+        input_np,
+        input_quant,
+        transformed_input_np,
+        transformed_input_quant,
+        weight_np,
+        weight_quant,
+        transformed_expected_output_np,
+        expected_output_np,
+        quant_arr,
+        hexagon_session: Session,
+    ):
+        target_hexagon = tvm.target.hexagon("v69")
+        A = te.placeholder(input_shape, name="A", dtype=dtype)
+        W = te.placeholder(
+            (output_shape[-1], input_shape[-1]),
+            name="W",
+            dtype="int8" if dtype == "uint8" else dtype,
+        )
+        args = [A, W]
+        tensors = [A, W]
+
+        # If quantized, append the quantization params
+        if "int" in dtype:
+            args.append(quant_arr[0].astype("int32"))
+            args.append(quant_arr[1])
+            args.append(quant_arr[2].astype("int32"))
+            args.append(quant_arr[3])
+
+        if bias_np is not None:
+            B = te.placeholder((output_shape[-1],), name="B", dtype=str(bias_np.dtype))
+            args.append(B)
+            tensors.append(B)
+        else:
+            B = None
+
+        # Different compute and schedule for quant and float
+        if "float" in dtype:
+            M = sl.dense_compute(*args)
+            tir_schedule = sl.dense_schedule([M], tensors, layout, layout)
+        elif "int" in dtype:
+            # Only the placeholders and the four quantization params are
+            # positional; the bias must be passed by keyword so that it does
+            # not bind to zero_out.
+            M = qnn.qdense_compute(*args[:6], bias=B)
+            tir_schedule = qnn.qdense_schedule([M], tensors, layout, layout)
+        else:
+            raise RuntimeError(f"Unsupported dtype '{dtype}'")
+
+        sch = tir_schedule.mod
+
+        input_axis_separator = [2]
+        output_axis_separator = [2]
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(
+                sch,
+                args,
+                target=tvm.target.Target(target_hexagon, host=target_hexagon),
+                name="dense",
+            )
+
+        input_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np if "float" in dtype else transformed_input_quant["data"],
+            axis_separators=input_axis_separator,
+            mem_scope="global.vtcm",
+        )
+        weight_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            data=weight_np if "float" in dtype else weight_quant["data"],
+            axis_separators=None,
+            mem_scope="global",
+        )
+        output_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            transformed_expected_output_np.shape,
+            "float32" if "int" in dtype else dtype,
+            axis_separators=output_axis_separator,
+            mem_scope="global.vtcm",
+        )
+        arrs = [input_arr, weight_arr]
+
+        if bias_np is not None:
+            bias_arr = allocate_hexagon_array(
+                hexagon_session.device,
+                data=bias_np,
+                axis_separators=None,
+                mem_scope="global.vtcm",
+            )
+            arrs.append(bias_arr)
+
+        arrs.append(output_arr)
+
+        mod = hexagon_session.load_module(func)
+        mod(*arrs)
+
+        # Reshape for comparison
+        b, c = output_shape
+        if layout == "nc-1024c-2d":
+            output_np = output_arr.numpy().reshape([b, c // 1024, 1024])
+        elif layout == "nc-2048c-2d":
+            output_np = output_arr.numpy().reshape([b, c // 2048, 2048])
+        else:
+            raise RuntimeError(f"Unexpected layout '{layout}'")
+
+        if "int" in dtype:
+            np.testing.assert_allclose(
+                output_np, transformed_expected_output_np, rtol=1e-2, atol=0
+            )
+        elif "float" in dtype:
+            np.testing.assert_allclose(
+                output_np, transformed_expected_output_np, rtol=1e-1, atol=0
+            )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
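The schedules introduced here can also be inspected host-side without a Hexagon device; a minimal sketch for the float16 slice op, assuming the same imports as the test above (illustrative only):

    import tvm.topi.hexagon.slice_ops as sl
    from tvm import te

    A = te.placeholder((1, 1024), name="A", dtype="float16")
    W = te.placeholder((1024, 1024), name="W", dtype="float16")
    M = sl.dense_compute(A, W)

    # Apply the 1024-channel crouton layout to both input and output,
    # then print the scheduled PrimFunc.
    sch = sl.dense_schedule([M], [A, W], "nc-1024c-2d", "nc-1024c-2d")
    print(sch.mod.script())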