From cd7f08c65fc2de93ba732cf1576b40615f999354 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Thu, 22 Sep 2022 08:23:18 -0700 Subject: [PATCH 01/19] Float and quantized dense operators with schedules --- python/tvm/topi/hexagon/qnn/__init__.py | 18 +- python/tvm/topi/hexagon/qnn/qdense.py | 169 +++++++++ python/tvm/topi/hexagon/slice_ops/__init__.py | 1 + python/tvm/topi/hexagon/slice_ops/dense.py | 143 +++++++ python/tvm/topi/hexagon/utils.py | 21 ++ .../contrib/test_hexagon/infrastructure.py | 6 + .../test_hexagon/topi/test_dense_slice.py | 354 ++++++++++++++++++ 7 files changed, 702 insertions(+), 10 deletions(-) create mode 100644 python/tvm/topi/hexagon/qnn/qdense.py create mode 100644 python/tvm/topi/hexagon/slice_ops/dense.py create mode 100644 tests/python/contrib/test_hexagon/topi/test_dense_slice.py diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index b8cdc7a26d96..d6f140fcd63c 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -17,16 +17,14 @@ """ Computes and schedules for Hexagon quantized ops """ +from .adaptive_avg_pool1d import * from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule -from .qadd_qsub_qmul import * -from .dequantize import ( - dequantize_compute, - dequantize_schedule, -) - -from .quantize import quantize_compute, tir_quantize_schedule +from .conv2d_alter_op import * +from .dequantize import dequantize_compute, dequantize_schedule +from .global_avg_pool2d import * from .nn import * +from .qadd_qsub_qmul import * +from .qdense import * from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule -from .adaptive_avg_pool1d import * -from .global_avg_pool2d import * -from .conv2d_alter_op import * +from .quantize import quantize_compute, tir_quantize_schedule + diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py new file mode 100644 index 000000000000..7b427b465f65 --- /dev/null +++ b/python/tvm/topi/hexagon/qnn/qdense.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Schedule for dense operator""" + +from tvm import te, tir +from tvm.topi import tag +from ..utils import get_layout_transform_fn + + +def qdense_compute(tensor_a, tensor_b, + zero_A, scale_A, + zero_B, scale_B, + zero_out=None, scale_out=None, + bias=None, + q_dtype=None): + """Hexagon's implementation of a sliced dense operator in Topi. + Uses matmul. + + Parameters + ---------- + tensor_a : tvm.te.Tensor + data 2-D with shape [batch, in_dim] + + tensor_b : tvm.te.Tensor + weight 2-D with shape [out_dim, in_dim] + + bias : Optional[tvm.te.Tensor] + 1-D with shape [out_dim] + + out_dtype : Optional[str] + The output type. This is used for mixed precision. + + Returns + ------- + mat : tvm.te.Tensor + 2-D with shape [batch, out_dim] + + """ + if bias is not None: + assert len(bias.shape) == 1 + if q_dtype is None: + q_dtype = tensor_a.dtype + + batch, in_dim = tensor_a.shape + out_dim, red_dim = tensor_b.shape + + # cmp should be done by values + assert int(in_dim) == int(red_dim) + + k = te.reduce_axis((0, in_dim), name="k") + compute_lambda = lambda n, m: te.sum( + scale_A + * (tensor_a[n, k].astype("float32") - zero_A) + * scale_B + * (tensor_b[m, k].astype("float32") - zero_B), + axis=k, + ) + compute_name = "qmatmul_sliced" + + out = te.compute( + (batch, out_dim), + compute_lambda, + name=compute_name, + attrs={"layout_free_placeholders": [tensor_b]}, + ) + + if bias is not None: + out = te.compute( + (batch, out_dim), + lambda i, j: out[i, j] + bias[j], + tag=tag.BROADCAST, + name="bias", + ) + + # Requantization of dense + if scale_out is not None: + out = te.compute( + (batch, out_dim), + lambda *i: (out[i] / scale_out + zero_out).astype(q_dtype), + name="requantize", + ) + + return out + + +def qdense_schedule(outs, ins, output_layout: str, input_layout: str): + """Schedule for dense op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of dense in the format + of an array of tensors. + + ins: Array of Tensor + Input tensors into graph. + + output_layout: str + Descriptor string for physical layout + + input_layout: str + Descriptor string for physical layout + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + if not isinstance(ins, list): + ins = [ins] + if not isinstance(outs, list): + outs = [outs] + + func = te.create_prim_func([*ins, *outs]) + s = tir.Schedule(func) + + matmul = s.get_block("qmatmul_sliced") + try: + requantize = s.get_block("requantize") + except tir.schedule.schedule.ScheduleError: + requantize = None + try: + bias = s.get_block("bias") + except tir.schedule.schedule.ScheduleError: + bias = None + + input_transform_fn = get_layout_transform_fn(input_layout) + output_transform_fn = get_layout_transform_fn(output_layout) + + # Transform input and output buffer + s.transform_layout(matmul, ("read", 0), input_transform_fn) + if requantize is not None: + s.transform_layout(requantize, ("write", 0), output_transform_fn) + elif bias is not None: + s.transform_layout(bias, ("write", 0), output_transform_fn) + else: + s.transform_layout(matmul, ("write", 0), output_transform_fn) + + # Vectorize + _, matmul_c, _ = s.get_loops(matmul) + _, matmul_c_inner = s.split(matmul_c, [None, 1024]) + s.vectorize(matmul_c_inner) + + # Compute everything inline + if bias is not None and requantize is not None: + _, bias_c = s.get_loops(bias) + s.compute_at(matmul, bias_c) + _, out_c = s.get_loops(requantize) + s.compute_at(bias, out_c) + elif bias is not None and requantize is None: + _, out_c = s.get_loops(bias) + s.compute_at(matmul, out_c) + + return s diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index 6b17b64489a9..46ae0c53200f 100644 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -37,3 +37,4 @@ from .dwconv2d import * from .depth_to_space import d2s_compute, d2s_schedule from .global_avg_pool2d import * +from .dense import * diff --git a/python/tvm/topi/hexagon/slice_ops/dense.py b/python/tvm/topi/hexagon/slice_ops/dense.py new file mode 100644 index 000000000000..599b408cb849 --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/dense.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Schedule for dense operator""" + +from tvm import te, tir +from tvm.topi import tag +from ..utils import get_layout_transform_fn + + +def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): + """Hexagon's implementation of a sliced dense operator in Topi. + Uses matmul. + + Parameters + ---------- + tensor_a : tvm.te.Tensor + data 2-D with shape [batch, in_dim] + + tensor_b : tvm.te.Tensor + weight 2-D with shape [out_dim, in_dim] + + bias : Optional[tvm.te.Tensor] + 1-D with shape [out_dim] + + out_dtype : Optional[str] + The output type. This is used for mixed precision. + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [batch, out_dim] + + """ + if bias is not None: + assert len(bias.shape) == 1 + if out_dtype is None: + out_dtype = tensor_a.dtype + + batch, in_dim = tensor_a.shape + out_dim, red_dim = tensor_b.shape + + # cmp should be done by values + assert int(in_dim) == int(red_dim) + + k = te.reduce_axis((0, in_dim), name="k") + compute_lambda = lambda n, m: te.sum( + tensor_a[n, k].astype(out_dtype) * tensor_b[m, k].astype(out_dtype), axis=k + ) + compute_name = "matmul_sliced" + compute_tag = "matmul" + + mat = te.compute( + (batch, out_dim), + compute_lambda, + name=compute_name, + tag=compute_tag, + attrs={"layout_free_placeholders": [tensor_b]}, + ) + + if bias is not None: + mat = te.compute( + (batch, out_dim), + lambda i, j: mat[i, j] + bias[j].astype(out_dtype), + tag=tag.BROADCAST, + name="bias", + ) + + return mat + + +def dense_schedule(outs, ins, output_layout: str, input_layout: str): + """Schedule for dense op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of dense in the format + of an array of tensors. + + ins: Array of Tensor + Input tensors into graph. + + output_layout: str + Descriptor string for physical layout + + input_layout: str + Descriptor string for physical layout + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + if not isinstance(ins, list): + ins = [ins] + if not isinstance(outs, list): + outs = [outs] + + func = te.create_prim_func([*ins, *outs]) + s = tir.Schedule(func) + + matmul = s.get_block("matmul_sliced") + try: + bias = s.get_block("bias") + except tir.schedule.schedule.ScheduleError: + bias = None + + input_transform_fn = get_layout_transform_fn(input_layout) + output_transform_fn = get_layout_transform_fn(output_layout) + + # No bias + if bias is None: + s.transform_layout(matmul, ("read", 0), input_transform_fn) + # s.transform_layout(matmul, ("read", 1), input_transform_fn) + s.transform_layout(matmul, ("write", 0), output_transform_fn) + else: + s.transform_layout(matmul, ("read", 0), input_transform_fn) + s.transform_layout(bias, ("write", 0), output_transform_fn) + + _, matmul_c, _ = s.get_loops(matmul) + _, matmul_c_inner = s.split(matmul_c, [None, 1024]) + s.vectorize(matmul_c_inner) + + if bias is not None: + _, bias_c = s.get_loops(bias) + s.compute_at(matmul, bias_c) + + return s diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 78ed21e8a13b..ed169a27f6d5 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -75,6 +75,21 @@ def nc_1024c_2d(n, c): return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] +def nc_2048c_1d(n, c): + """Return index map for nc_2024c 1d layout""" + return [n, c // 2048, c % 2048] + + +def nc_2048c_2d(n, c): + """Return index map for nc_2024c 2d layout""" + return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] + + +def nc_1024c_1d(n, c): + """Return index map for nc_1024c 1d layout""" + return [n, c // 1024, c % 1024] + + def nhwc_4h2w32c2w_2d(n, h, w, c): """Return index map for nhwc_4h2w32c2w 2d layout""" return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2] @@ -170,8 +185,14 @@ def get_layout_transform_fn(layout): return nc_512c_1d if layout == "nhwc-4h2w32c2w-2d": return nhwc_4h2w32c2w_2d + if layout == "nc-2048c-1d": + return nc_2048c_1d + if layout == "nc-2048c-2d": + return nc_2048c_2d if layout == "nc-1024c-2d": return nc_1024c_2d + if layout == "nc-1024c-1d": + return nc_1024c_1d if layout == "iohw-16i32o2i-1d": return iohw_16i32o2i_1d if layout == "nhwc-2048c-2d": diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index e81c24694ef9..735b3f2b94b5 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -253,8 +253,14 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): if current_layout == "nc": n, c = arr_np.shape + if new_layout in ["nc-2048c-1d"]: + return arr_np.reshape([n, c // 2048, 2048]) + if new_layout in ["nc-2048c-2d"]: + return arr_np.reshape([n, c // 2048, 2048]) if new_layout in ["nc-1024c-2d"]: return arr_np.reshape([n, c // 1024, 1024]) + if new_layout in ["nc-1024c-1d"]: + return arr_np.reshape([n, c // 1024, 1024]) if new_layout in ["nc-512c-2d"]: return arr_np.reshape([n, c // 512, 512]) if new_layout in ["nc-2048c-2d"]: diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py new file mode 100644 index 000000000000..51a0f6a61b4e --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -0,0 +1,354 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import numpy as np + +from tvm import te, topi + +import tvm.testing +from tvm.topi import testing +from tvm.contrib.hexagon.build import HexagonLauncher +from tvm.contrib.hexagon.session import Session +import tvm.topi.hexagon.qnn as qnn +import tvm.topi.hexagon.slice_ops as sl +from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np + +@tvm.testing.fixture +def input_np(input_shape, dtype): + if "int" in dtype: + data = np.random.random(input_shape).astype("float32") + elif "float" in dtype: + data = np.random.random(input_shape).astype(dtype) + return data + + +@tvm.testing.fixture +def weight_np(weight_shape, dtype): + if "int" in dtype: + weight = np.random.random(weight_shape).astype("float32") + elif "float" in dtype: + weight = np.random.random(weight_shape).astype(dtype) + return weight + + +@tvm.testing.fixture +def input_quant(input_np, dtype): + if "float" in dtype: + return None + quant, scale, zp = quantize_np(input_np, dtype) + return {"zero": zp, + "scale": scale, + "data": quant} + + +@tvm.testing.fixture +def weight_quant(weight_np, dtype): + if "float" in dtype: + return None + quant, scale, zp = quantize_np(weight_np, "int8") + return {"zero": zp, + "scale": scale, + "data": quant} + +@tvm.testing.fixture +def bias_np(bias_shape, bias, dtype): + if bias: + return np.random.randint(-128, 127, size=bias_shape).astype("int32") + else: + return None + + +@tvm.testing.fixture +def quant_arr(input_quant, weight_quant): + if input_quant is None: + return None + arr = np.empty((6,), dtype="float32") + arr[0] = input_quant["zero"] + arr[1] = input_quant["scale"] + arr[2] = weight_quant["zero"] + arr[3] = weight_quant["scale"] + return arr + + +@tvm.testing.fixture +def transformed_expected_output_np(expected_output_np, layout): + return transform_numpy(expected_output_np, "nc", layout) + +@tvm.testing.fixture +def transformed_input_np(input_np, layout): + return transform_numpy(input_np, "nc", layout) + + +# TODO(joshherr-quic): transforming weight forces us to put it in vtcm. Crashes at runtime in vtcm +# @tvm.testing.fixture +# def transformed_weight_np(weight_np, layout): +# return transform_numpy(weight_np, "nc", layout) + + +@tvm.testing.fixture +def transformed_input_quant(input_quant, layout): + if input_quant is None: + return None + input_quant["data"] = transform_numpy(input_quant["data"], "nc", layout) + return input_quant + + +# @tvm.testing.fixture +# def transformed_weight_quant(weight_quant, layout): +# weight_quant["data"] = transform_numpy(weight_quant["data"], "nc", layout) +# return weight_quant + +# Test combinations of the following: +# dtype in (float16, uint8, int8) +# num_croutons in (1, 2) +# bias_enabled in (true, false) +class TestDenseSlice: + (input_shape, output_shape, layout, bias, dtype,) = tvm.testing.parameters( + ( # Float 16 + [1, 1024], + [1, 1024], + "nc-1024c-2d", + False, + "float16", + ), + ( + [1, 1024], + [1, 1024], + "nc-1024c-2d", + True, + "float16", + ), + ( + [1, 2048], + [1, 2048], + "nc-1024c-2d", + False, + "float16", + ), + ( + [1, 2048], + [1, 2048], + "nc-1024c-2d", + True, + "float16", + ), + ( # Uint 8 + [1, 2048], + [1, 2048], + "nc-2048c-2d", + False, + "uint8", + ), + ( + [1, 2048], + [1, 2048], + "nc-2048c-2d", + True, + "uint8", + ), + ( + [1, 4096], + [1, 4096], + "nc-2048c-2d", + False, + "uint8", + ), + ( + [1, 4096], + [1, 4096], + "nc-2048c-2d", + True, + "uint8", + ), + ( # Int 8 + [1, 2048], + [1, 2048], + "nc-2048c-2d", + False, + "int8", + ), + ( + [1, 2048], + [1, 2048], + "nc-2048c-2d", + True, + "int8", + ), + ( + [1, 4096], + [1, 4096], + "nc-2048c-2d", + False, + "int8", + ), + ( + [1, 4096], + [1, 4096], + "nc-2048c-2d", + True, + "int8", + ), + ) + + @tvm.testing.fixture + def expected_output_np(self, input_np, weight_np, bias_np, bias): + ref_np = tvm.topi.testing.dense( + np.reshape(input_np, (input_np.shape[0], input_np.shape[-1])), + weight_np, + bias_np, + use_bias=bias, + out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype + ) + return ref_np + + @tvm.testing.fixture + def weight_shape(self, input_shape, output_shape): + return (output_shape[-1], input_shape[-1]) + + @tvm.testing.fixture + def bias_shape(self, output_shape): + return (output_shape[-1],) + + @tvm.testing.requires_hexagon + def test_dense_slice( + self, + dtype, + bias_np, + layout, + output_shape, + input_shape, + input_np, + input_quant, + transformed_input_np, + transformed_input_quant, + weight_np, + # transformed_weight_np, + weight_quant, + # transformed_weight_quant, + transformed_expected_output_np, + expected_output_np, + quant_arr, + hexagon_session: Session, + ): + if hexagon_session._launcher._serial_number != "simulator": + pytest.skip(msg="Due to https://github.com/apache/tvm/issues/11928") + + target_hexagon = tvm.target.hexagon("v69") + A = te.placeholder(input_shape, name="A", dtype=dtype) + W = te.placeholder( + (output_shape[-1], input_shape[-1]), + name="W", + dtype="int8" if dtype=="uint8" else dtype) + args = [A, W] + tensors = [A, W] + + # If quantized, append the quantization params + if "int" in dtype: + args.append(quant_arr[0].astype("int32")) + args.append(quant_arr[1]) + args.append(quant_arr[2].astype("int32")) + args.append(quant_arr[3]) + + if bias_np is not None: + B = te.placeholder((output_shape[-1],), name="B", dtype="int32") + args.append(B) + tensors.append(B) + else: + B = None + + # Different compute and schedule for quant and float + if "float" in dtype: + M = sl.dense_compute(*args) + tir_schedule = sl.dense_schedule([M], tensors, layout, layout) + elif "int" in dtype: + M = qnn.qdense_compute(*args, bias=B) + tir_schedule = qnn.qdense_schedule([M], tensors, layout, layout) + else: + print("Unsupported dtype {}".format(dtype)) + exit(-1) + + sch = tir_schedule.mod + + input_axis_separator = [2] + output_axis_separator = [2] + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build( + sch, + args, + target=tvm.target.Target(target_hexagon, host=target_hexagon), + name="dense", + ) + + input_arr = allocate_hexagon_array( + hexagon_session.device, + data=transformed_input_np if "float" in dtype else transformed_input_quant["data"], + axis_separators=input_axis_separator, + mem_scope="global.vtcm", + ) + weight_arr = allocate_hexagon_array( + hexagon_session.device, + data=weight_np if "float" in dtype else weight_quant["data"], + axis_separators=None, + mem_scope="global", + ) + output_arr = allocate_hexagon_array( + hexagon_session.device, + transformed_expected_output_np.shape, + "float32" if "int" in dtype else dtype, + axis_separators=output_axis_separator, + mem_scope="global.vtcm", + ) + arrs = [input_arr, weight_arr] + + if bias_np is not None: + bias_arr = allocate_hexagon_array( + hexagon_session.device, + data=bias_np, + axis_separators=None, + mem_scope="global.vtcm", + ) + arrs.append(bias_arr) + + arrs.append(output_arr) + + mod = hexagon_session.load_module(func) + mod(*arrs) + + # Reshape for comparison + b, c = output_shape + if layout == "nc-1024c-2d": + output_np = output_arr.numpy().reshape([b, c // 1024, 1024]) + elif layout == "nc-2048c-2d": + output_np = output_arr.numpy().reshape([b, c // 2048, 2048]) + else: + raise RuntimeError(f"Unexpected layout '{layout}'") + + # TODO(joshherr-quic): Investigate ways to improve accuracy + if "int" in dtype: + np.testing.assert_allclose( + output_np, transformed_expected_output_np, rtol=1e-1, atol=0 + ) + elif "float" in dtype: + np.testing.assert_allclose( + output_np, transformed_expected_output_np, rtol=1e-1, atol=0 + ) + + +if __name__ == "__main__": + tvm.testing.main() From afd8193dbd2257f9e11fb7d091abd2e600394684 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Tue, 27 Sep 2022 23:46:23 -0700 Subject: [PATCH 02/19] Formatting --- python/tvm/topi/hexagon/qnn/qdense.py | 18 +++++++---- .../test_hexagon/topi/test_dense_slice.py | 30 ++++++++----------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py index 7b427b465f65..4aacab23ce14 100644 --- a/python/tvm/topi/hexagon/qnn/qdense.py +++ b/python/tvm/topi/hexagon/qnn/qdense.py @@ -22,12 +22,18 @@ from ..utils import get_layout_transform_fn -def qdense_compute(tensor_a, tensor_b, - zero_A, scale_A, - zero_B, scale_B, - zero_out=None, scale_out=None, - bias=None, - q_dtype=None): +def qdense_compute( + tensor_a, + tensor_b, + zero_A, + scale_A, + zero_B, + scale_B, + zero_out=None, + scale_out=None, + bias=None, + q_dtype=None, +): """Hexagon's implementation of a sliced dense operator in Topi. Uses matmul. diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 51a0f6a61b4e..13e57ebfaf76 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -28,6 +28,7 @@ import tvm.topi.hexagon.slice_ops as sl from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np + @tvm.testing.fixture def input_np(input_shape, dtype): if "int" in dtype: @@ -51,9 +52,7 @@ def input_quant(input_np, dtype): if "float" in dtype: return None quant, scale, zp = quantize_np(input_np, dtype) - return {"zero": zp, - "scale": scale, - "data": quant} + return {"zero": zp, "scale": scale, "data": quant} @tvm.testing.fixture @@ -61,9 +60,8 @@ def weight_quant(weight_np, dtype): if "float" in dtype: return None quant, scale, zp = quantize_np(weight_np, "int8") - return {"zero": zp, - "scale": scale, - "data": quant} + return {"zero": zp, "scale": scale, "data": quant} + @tvm.testing.fixture def bias_np(bias_shape, bias, dtype): @@ -89,6 +87,7 @@ def quant_arr(input_quant, weight_quant): def transformed_expected_output_np(expected_output_np, layout): return transform_numpy(expected_output_np, "nc", layout) + @tvm.testing.fixture def transformed_input_np(input_np, layout): return transform_numpy(input_np, "nc", layout) @@ -119,7 +118,7 @@ def transformed_input_quant(input_quant, layout): # bias_enabled in (true, false) class TestDenseSlice: (input_shape, output_shape, layout, bias, dtype,) = tvm.testing.parameters( - ( # Float 16 + ( # Float 16 [1, 1024], [1, 1024], "nc-1024c-2d", @@ -147,7 +146,7 @@ class TestDenseSlice: True, "float16", ), - ( # Uint 8 + ( # Uint 8 [1, 2048], [1, 2048], "nc-2048c-2d", @@ -175,7 +174,7 @@ class TestDenseSlice: True, "uint8", ), - ( # Int 8 + ( # Int 8 [1, 2048], [1, 2048], "nc-2048c-2d", @@ -212,7 +211,7 @@ def expected_output_np(self, input_np, weight_np, bias_np, bias): weight_np, bias_np, use_bias=bias, - out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype + out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype, ) return ref_np @@ -253,7 +252,8 @@ def test_dense_slice( W = te.placeholder( (output_shape[-1], input_shape[-1]), name="W", - dtype="int8" if dtype=="uint8" else dtype) + dtype="int8" if dtype == "uint8" else dtype, + ) args = [A, W] tensors = [A, W] @@ -341,13 +341,9 @@ def test_dense_slice( # TODO(joshherr-quic): Investigate ways to improve accuracy if "int" in dtype: - np.testing.assert_allclose( - output_np, transformed_expected_output_np, rtol=1e-1, atol=0 - ) + np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-1, atol=0) elif "float" in dtype: - np.testing.assert_allclose( - output_np, transformed_expected_output_np, rtol=1e-1, atol=0 - ) + np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-1, atol=0) if __name__ == "__main__": From 3c1613af990be0575550600b9e3b85b01c6ed4a6 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Wed, 28 Sep 2022 02:12:06 -0700 Subject: [PATCH 03/19] Change var name to conform --- python/tvm/topi/hexagon/qnn/qdense.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py index 4aacab23ce14..dd6c7cef31e4 100644 --- a/python/tvm/topi/hexagon/qnn/qdense.py +++ b/python/tvm/topi/hexagon/qnn/qdense.py @@ -25,10 +25,10 @@ def qdense_compute( tensor_a, tensor_b, - zero_A, - scale_A, - zero_B, - scale_B, + zero_a, + scale_a, + zero_b, + scale_b, zero_out=None, scale_out=None, bias=None, @@ -70,10 +70,10 @@ def qdense_compute( k = te.reduce_axis((0, in_dim), name="k") compute_lambda = lambda n, m: te.sum( - scale_A - * (tensor_a[n, k].astype("float32") - zero_A) - * scale_B - * (tensor_b[m, k].astype("float32") - zero_B), + scale_a + * (tensor_a[n, k].astype("float32") - zero_a) + * scale_b + * (tensor_b[m, k].astype("float32") - zero_b), axis=k, ) compute_name = "qmatmul_sliced" From 82c28c32b3b58c5cf1fea3d8a8bf305f50bf447c Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Wed, 28 Sep 2022 02:20:41 -0700 Subject: [PATCH 04/19] Remove redundant function dec --- python/tvm/topi/hexagon/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index ed169a27f6d5..86aa87adf319 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -115,11 +115,6 @@ def nc_2048_2d(n, c): return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] -def nc_2048c_2d(n, c): - """Return index map for nc_2048 2d layout""" - return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048] - - def nhwc_8h8w32c_2d(n, h, w, c): """Return index map for nhwc_8h8w32c 2d layout""" return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32] From 6b0fa3f951a0dd3b9e9c317ac7c0617014ecc3db Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Wed, 5 Oct 2022 17:23:01 -0700 Subject: [PATCH 05/19] Use transposed weight matrix --- python/tvm/topi/hexagon/qnn/qdense.py | 6 ++-- python/tvm/topi/hexagon/slice_ops/dense.py | 11 +++--- .../test_hexagon/topi/test_dense_slice.py | 36 +++++++++---------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py index dd6c7cef31e4..ed4ea56b1a58 100644 --- a/python/tvm/topi/hexagon/qnn/qdense.py +++ b/python/tvm/topi/hexagon/qnn/qdense.py @@ -43,7 +43,7 @@ def qdense_compute( data 2-D with shape [batch, in_dim] tensor_b : tvm.te.Tensor - weight 2-D with shape [out_dim, in_dim] + weight 2-D with shape [in_dim, out_dim] bias : Optional[tvm.te.Tensor] 1-D with shape [out_dim] @@ -73,7 +73,7 @@ def qdense_compute( scale_a * (tensor_a[n, k].astype("float32") - zero_a) * scale_b - * (tensor_b[m, k].astype("float32") - zero_b), + * (tensor_b[k, m].astype("float32") - zero_b), axis=k, ) compute_name = "qmatmul_sliced" @@ -159,7 +159,7 @@ def qdense_schedule(outs, ins, output_layout: str, input_layout: str): # Vectorize _, matmul_c, _ = s.get_loops(matmul) - _, matmul_c_inner = s.split(matmul_c, [None, 1024]) + _, matmul_c_inner = s.split(matmul_c, [None, 128]) s.vectorize(matmul_c_inner) # Compute everything inline diff --git a/python/tvm/topi/hexagon/slice_ops/dense.py b/python/tvm/topi/hexagon/slice_ops/dense.py index 599b408cb849..a298ff4bc98e 100644 --- a/python/tvm/topi/hexagon/slice_ops/dense.py +++ b/python/tvm/topi/hexagon/slice_ops/dense.py @@ -32,7 +32,7 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): data 2-D with shape [batch, in_dim] tensor_b : tvm.te.Tensor - weight 2-D with shape [out_dim, in_dim] + weight 2-D with shape [in_dim, out_dim] bias : Optional[tvm.te.Tensor] 1-D with shape [out_dim] @@ -59,7 +59,7 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): k = te.reduce_axis((0, in_dim), name="k") compute_lambda = lambda n, m: te.sum( - tensor_a[n, k].astype(out_dtype) * tensor_b[m, k].astype(out_dtype), axis=k + tensor_a[n, k].astype(out_dtype) * tensor_b[k, m].astype(out_dtype), axis=k ) compute_name = "matmul_sliced" compute_tag = "matmul" @@ -75,7 +75,7 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): if bias is not None: mat = te.compute( (batch, out_dim), - lambda i, j: mat[i, j] + bias[j].astype(out_dtype), + lambda i, j: mat[i, j] + bias[j], tag=tag.BROADCAST, name="bias", ) @@ -133,11 +133,12 @@ def dense_schedule(outs, ins, output_layout: str, input_layout: str): s.transform_layout(bias, ("write", 0), output_transform_fn) _, matmul_c, _ = s.get_loops(matmul) - _, matmul_c_inner = s.split(matmul_c, [None, 1024]) + _, matmul_c_inner = s.split(matmul_c, [None, 64]) s.vectorize(matmul_c_inner) if bias is not None: _, bias_c = s.get_loops(bias) - s.compute_at(matmul, bias_c) + _, bias_c_inner = s.split(bias_c, [None, 64]) + s.vectorize(bias_c_inner) return s diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 13e57ebfaf76..06843aeb02c6 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -125,13 +125,13 @@ class TestDenseSlice: False, "float16", ), - ( - [1, 1024], - [1, 1024], - "nc-1024c-2d", - True, - "float16", - ), + # ( + # [1, 1024], + # [1, 1024], + # "nc-1024c-2d", + # True, + # "float16", + # ), ( [1, 2048], [1, 2048], @@ -139,13 +139,13 @@ class TestDenseSlice: False, "float16", ), - ( - [1, 2048], - [1, 2048], - "nc-1024c-2d", - True, - "float16", - ), + # ( + # [1, 2048], + # [1, 2048], + # "nc-1024c-2d", + # True, + # "float16", + # ), ( # Uint 8 [1, 2048], [1, 2048], @@ -208,7 +208,7 @@ class TestDenseSlice: def expected_output_np(self, input_np, weight_np, bias_np, bias): ref_np = tvm.topi.testing.dense( np.reshape(input_np, (input_np.shape[0], input_np.shape[-1])), - weight_np, + weight_np.T, # Function expects [in_dim, out_dim] bias_np, use_bias=bias, out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype, @@ -294,6 +294,7 @@ def test_dense_slice( target=tvm.target.Target(target_hexagon, host=target_hexagon), name="dense", ) + func.save("dense.s" if bias_np is None else "dense_bias.s") input_arr = allocate_hexagon_array( hexagon_session.device, @@ -339,11 +340,10 @@ def test_dense_slice( else: raise RuntimeError(f"Unexpected layout '{layout}'") - # TODO(joshherr-quic): Investigate ways to improve accuracy if "int" in dtype: - np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-1, atol=0) + np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-2, atol=0) elif "float" in dtype: - np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-1, atol=0) + np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-2, atol=0) if __name__ == "__main__": From 926044b71f9750638917c39ab8057bd3d88b22ea Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Wed, 5 Oct 2022 17:57:59 -0700 Subject: [PATCH 06/19] Use fp32 for dense intermediate computation for accuracy purposes --- python/tvm/topi/hexagon/slice_ops/dense.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/dense.py b/python/tvm/topi/hexagon/slice_ops/dense.py index a298ff4bc98e..ce67d86e9728 100644 --- a/python/tvm/topi/hexagon/slice_ops/dense.py +++ b/python/tvm/topi/hexagon/slice_ops/dense.py @@ -59,7 +59,7 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): k = te.reduce_axis((0, in_dim), name="k") compute_lambda = lambda n, m: te.sum( - tensor_a[n, k].astype(out_dtype) * tensor_b[k, m].astype(out_dtype), axis=k + tensor_a[n, k].astype("float32") * tensor_b[k, m].astype("float32"), axis=k ) compute_name = "matmul_sliced" compute_tag = "matmul" @@ -75,10 +75,16 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): if bias is not None: mat = te.compute( (batch, out_dim), - lambda i, j: mat[i, j] + bias[j], + lambda i, j: mat[i, j].astype(out_dtype) + bias[j].astype(out_dtype), tag=tag.BROADCAST, name="bias", ) + else: + mat = te.compute( + (batch, out_dim), + lambda i, j: mat[i, j].astype(out_dtype), + tag=tag.BROADCAST, + name="cast",) return mat @@ -119,6 +125,7 @@ def dense_schedule(outs, ins, output_layout: str, input_layout: str): bias = s.get_block("bias") except tir.schedule.schedule.ScheduleError: bias = None + cast = s.get_block("cast") input_transform_fn = get_layout_transform_fn(input_layout) output_transform_fn = get_layout_transform_fn(output_layout) @@ -127,7 +134,7 @@ def dense_schedule(outs, ins, output_layout: str, input_layout: str): if bias is None: s.transform_layout(matmul, ("read", 0), input_transform_fn) # s.transform_layout(matmul, ("read", 1), input_transform_fn) - s.transform_layout(matmul, ("write", 0), output_transform_fn) + s.transform_layout(cast, ("write", 0), output_transform_fn) else: s.transform_layout(matmul, ("read", 0), input_transform_fn) s.transform_layout(bias, ("write", 0), output_transform_fn) From 9ca4cb82ba226be266a818c221a137ab4da47021 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Wed, 5 Oct 2022 17:58:25 -0700 Subject: [PATCH 07/19] TODO comment --- tests/python/contrib/test_hexagon/topi/test_dense_slice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 06843aeb02c6..e6054d808faa 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -126,7 +126,7 @@ class TestDenseSlice: "float16", ), # ( - # [1, 1024], + # [1, 1024], # TODO(joshherr-quic): Fix assertion in LLVM when bias is enabled. # [1, 1024], # "nc-1024c-2d", # True, From 016c6a57980962ad42d4d4f358680400215fdac8 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Thu, 6 Oct 2022 12:36:46 -0700 Subject: [PATCH 08/19] Revert "Use fp32 for dense intermediate computation for accuracy purposes" This reverts commit 8fa3ff3e58d5e3d606e3e761d339a0f7ea17e9ca. --- python/tvm/topi/hexagon/slice_ops/dense.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/dense.py b/python/tvm/topi/hexagon/slice_ops/dense.py index ce67d86e9728..a298ff4bc98e 100644 --- a/python/tvm/topi/hexagon/slice_ops/dense.py +++ b/python/tvm/topi/hexagon/slice_ops/dense.py @@ -59,7 +59,7 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): k = te.reduce_axis((0, in_dim), name="k") compute_lambda = lambda n, m: te.sum( - tensor_a[n, k].astype("float32") * tensor_b[k, m].astype("float32"), axis=k + tensor_a[n, k].astype(out_dtype) * tensor_b[k, m].astype(out_dtype), axis=k ) compute_name = "matmul_sliced" compute_tag = "matmul" @@ -75,16 +75,10 @@ def dense_compute(tensor_a, tensor_b, bias=None, out_dtype=None): if bias is not None: mat = te.compute( (batch, out_dim), - lambda i, j: mat[i, j].astype(out_dtype) + bias[j].astype(out_dtype), + lambda i, j: mat[i, j] + bias[j], tag=tag.BROADCAST, name="bias", ) - else: - mat = te.compute( - (batch, out_dim), - lambda i, j: mat[i, j].astype(out_dtype), - tag=tag.BROADCAST, - name="cast",) return mat @@ -125,7 +119,6 @@ def dense_schedule(outs, ins, output_layout: str, input_layout: str): bias = s.get_block("bias") except tir.schedule.schedule.ScheduleError: bias = None - cast = s.get_block("cast") input_transform_fn = get_layout_transform_fn(input_layout) output_transform_fn = get_layout_transform_fn(output_layout) @@ -134,7 +127,7 @@ def dense_schedule(outs, ins, output_layout: str, input_layout: str): if bias is None: s.transform_layout(matmul, ("read", 0), input_transform_fn) # s.transform_layout(matmul, ("read", 1), input_transform_fn) - s.transform_layout(cast, ("write", 0), output_transform_fn) + s.transform_layout(matmul, ("write", 0), output_transform_fn) else: s.transform_layout(matmul, ("read", 0), input_transform_fn) s.transform_layout(bias, ("write", 0), output_transform_fn) From 1b0433c84198e7ece37c0378cd5b624eb2a6a2cd Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Thu, 6 Oct 2022 12:48:53 -0700 Subject: [PATCH 09/19] Use correct type for float bias --- .../test_hexagon/topi/test_dense_slice.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index e6054d808faa..5c978a2af960 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -66,7 +66,11 @@ def weight_quant(weight_np, dtype): @tvm.testing.fixture def bias_np(bias_shape, bias, dtype): if bias: - return np.random.randint(-128, 127, size=bias_shape).astype("int32") + if "int" in dtype: + data = np.random.randint(-128, 127, size=bias_shape).astype("int32") + elif "float" in dtype: + data = np.random.random(bias_shape).astype(dtype) + return data else: return None @@ -125,13 +129,13 @@ class TestDenseSlice: False, "float16", ), - # ( - # [1, 1024], # TODO(joshherr-quic): Fix assertion in LLVM when bias is enabled. - # [1, 1024], - # "nc-1024c-2d", - # True, - # "float16", - # ), + ( + [1, 1024], + [1, 1024], + "nc-1024c-2d", + True, + "float16", + ), ( [1, 2048], [1, 2048], @@ -139,13 +143,13 @@ class TestDenseSlice: False, "float16", ), - # ( - # [1, 2048], - # [1, 2048], - # "nc-1024c-2d", - # True, - # "float16", - # ), + ( + [1, 2048], + [1, 2048], + "nc-1024c-2d", + True, + "float16", + ), ( # Uint 8 [1, 2048], [1, 2048], @@ -265,7 +269,7 @@ def test_dense_slice( args.append(quant_arr[3]) if bias_np is not None: - B = te.placeholder((output_shape[-1],), name="B", dtype="int32") + B = te.placeholder((output_shape[-1],), name="B", dtype=str(bias_np.dtype)) args.append(B) tensors.append(B) else: From c4a81225c91cd775893d3b5461967ea80e9ae928 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Tue, 18 Oct 2022 07:48:06 -0700 Subject: [PATCH 10/19] Add synr to conda dependencies --- conda/build-environment.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index a1b43eb6ef0c..400e7b287d7c 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -37,3 +37,5 @@ dependencies: - make - scipy - pillow + - synr + \ No newline at end of file From 03d8a21576c44187647661e1da7494864f77f6c1 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Tue, 18 Oct 2022 07:48:45 -0700 Subject: [PATCH 11/19] Revert "Add synr to conda dependencies" This reverts commit 9c21e6a839f16ed8af08267a0c07155680bd47a4. --- conda/build-environment.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 400e7b287d7c..a1b43eb6ef0c 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -37,5 +37,3 @@ dependencies: - make - scipy - pillow - - synr - \ No newline at end of file From 45fdc73c65d9957100097e275c62610181d4b8a6 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Mon, 24 Oct 2022 12:27:18 -0700 Subject: [PATCH 12/19] Formatting --- tests/python/contrib/test_hexagon/topi/test_dense_slice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 5c978a2af960..348e8597f071 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -212,7 +212,7 @@ class TestDenseSlice: def expected_output_np(self, input_np, weight_np, bias_np, bias): ref_np = tvm.topi.testing.dense( np.reshape(input_np, (input_np.shape[0], input_np.shape[-1])), - weight_np.T, # Function expects [in_dim, out_dim] + weight_np.T, # Function expects [in_dim, out_dim] bias_np, use_bias=bias, out_dtype="float32" if "int" in str(input_np.dtype) else input_np.dtype, From de853b2cf7fc0b1551372456e114a09d7f17994a Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Sat, 29 Oct 2022 12:23:55 -0700 Subject: [PATCH 13/19] Relax float16 test constraints --- tests/python/contrib/test_hexagon/topi/test_dense_slice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 348e8597f071..33a4ab760b68 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -347,7 +347,7 @@ def test_dense_slice( if "int" in dtype: np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-2, atol=0) elif "float" in dtype: - np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-2, atol=0) + np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-1, atol=0) if __name__ == "__main__": From 01bf3980ff09c5daf51ba180b27c6c6f4b9a0247 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Sun, 30 Oct 2022 14:13:08 -0700 Subject: [PATCH 14/19] Remove old assert and reduce number of tests --- .../test_hexagon/topi/test_dense_slice.py | 60 +------------------ 1 file changed, 1 insertion(+), 59 deletions(-) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py index 33a4ab760b68..6680e2c65b2b 100644 --- a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_dense_slice.py @@ -117,7 +117,7 @@ def transformed_input_quant(input_quant, layout): # return weight_quant # Test combinations of the following: -# dtype in (float16, uint8, int8) +# dtype in (float16, uint8) # num_croutons in (1, 2) # bias_enabled in (true, false) class TestDenseSlice: @@ -129,20 +129,6 @@ class TestDenseSlice: False, "float16", ), - ( - [1, 1024], - [1, 1024], - "nc-1024c-2d", - True, - "float16", - ), - ( - [1, 2048], - [1, 2048], - "nc-1024c-2d", - False, - "float16", - ), ( [1, 2048], [1, 2048], @@ -157,20 +143,6 @@ class TestDenseSlice: False, "uint8", ), - ( - [1, 2048], - [1, 2048], - "nc-2048c-2d", - True, - "uint8", - ), - ( - [1, 4096], - [1, 4096], - "nc-2048c-2d", - False, - "uint8", - ), ( [1, 4096], [1, 4096], @@ -178,34 +150,6 @@ class TestDenseSlice: True, "uint8", ), - ( # Int 8 - [1, 2048], - [1, 2048], - "nc-2048c-2d", - False, - "int8", - ), - ( - [1, 2048], - [1, 2048], - "nc-2048c-2d", - True, - "int8", - ), - ( - [1, 4096], - [1, 4096], - "nc-2048c-2d", - False, - "int8", - ), - ( - [1, 4096], - [1, 4096], - "nc-2048c-2d", - True, - "int8", - ), ) @tvm.testing.fixture @@ -248,8 +192,6 @@ def test_dense_slice( quant_arr, hexagon_session: Session, ): - if hexagon_session._launcher._serial_number != "simulator": - pytest.skip(msg="Due to https://github.com/apache/tvm/issues/11928") target_hexagon = tvm.target.hexagon("v69") A = te.placeholder(input_shape, name="A", dtype=dtype) From e29088e1e30ae2f7135e57d5c11cb766965f8f11 Mon Sep 17 00:00:00 2001 From: joshherr-quic <95375797+joshherr-quic@users.noreply.github.com> Date: Mon, 31 Oct 2022 16:18:16 -0500 Subject: [PATCH 15/19] Update comments --- python/tvm/topi/hexagon/qnn/qdense.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdense.py b/python/tvm/topi/hexagon/qnn/qdense.py index ed4ea56b1a58..53f9077e56ba 100644 --- a/python/tvm/topi/hexagon/qnn/qdense.py +++ b/python/tvm/topi/hexagon/qnn/qdense.py @@ -45,11 +45,29 @@ def qdense_compute( tensor_b : tvm.te.Tensor weight 2-D with shape [in_dim, out_dim] + zero_a : integer + quantization zero point for tensor a. + + scale_a : float + quantization scale for tensor a. + + zero_b : integer + quantization zero point for tensor b. + + scale_b : float + quantization scale for tensor b. + + zero_out : Optional[integer] + quantization zero point for output. + + scale_out : Optional[float] + quantization scale for output. + bias : Optional[tvm.te.Tensor] 1-D with shape [out_dim] - out_dtype : Optional[str] - The output type. This is used for mixed precision. + q_dtype : Optional[str] + The output type. Returns ------- From bc0446cf7786b71f61847644de3e752d165d1d08 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Mon, 7 Nov 2022 11:56:28 -0800 Subject: [PATCH 16/19] Move test file --- .../contrib/test_hexagon/topi/{ => slice_op}/test_dense_slice.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/contrib/test_hexagon/topi/{ => slice_op}/test_dense_slice.py (100%) diff --git a/tests/python/contrib/test_hexagon/topi/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py similarity index 100% rename from tests/python/contrib/test_hexagon/topi/test_dense_slice.py rename to tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py From 4749107cd602687844f955248188b23da7133b87 Mon Sep 17 00:00:00 2001 From: Josh Herrera Date: Mon, 14 Nov 2022 09:28:12 -0800 Subject: [PATCH 17/19] Update imports --- .../contrib/test_hexagon/topi/slice_op/test_dense_slice.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py index 6680e2c65b2b..33518d8eb8d4 100644 --- a/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py @@ -26,7 +26,8 @@ from tvm.contrib.hexagon.session import Session import tvm.topi.hexagon.qnn as qnn import tvm.topi.hexagon.slice_ops as sl -from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np +from ...infrastructure import transform_numpy, quantize_np +from tvm.contrib.hexagon import allocate_hexagon_array @tvm.testing.fixture From 76e8c01d4f5d04327716130967b05fb7e770fbb1 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 23 Jan 2023 10:29:48 -0800 Subject: [PATCH 18/19] Formatting --- python/tvm/topi/hexagon/qnn/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index d6f140fcd63c..022a552c9d54 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -27,4 +27,3 @@ from .qdense import * from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule from .quantize import quantize_compute, tir_quantize_schedule - From 849b84286fa56cdab4cd1d8617f4a7944af8369d Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 24 Jan 2023 05:38:35 -0800 Subject: [PATCH 19/19] Remove extra comments --- .../topi/slice_op/test_dense_slice.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py index 33518d8eb8d4..e616c384fb40 100644 --- a/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py +++ b/tests/python/contrib/test_hexagon/topi/slice_op/test_dense_slice.py @@ -98,12 +98,6 @@ def transformed_input_np(input_np, layout): return transform_numpy(input_np, "nc", layout) -# TODO(joshherr-quic): transforming weight forces us to put it in vtcm. Crashes at runtime in vtcm -# @tvm.testing.fixture -# def transformed_weight_np(weight_np, layout): -# return transform_numpy(weight_np, "nc", layout) - - @tvm.testing.fixture def transformed_input_quant(input_quant, layout): if input_quant is None: @@ -112,15 +106,6 @@ def transformed_input_quant(input_quant, layout): return input_quant -# @tvm.testing.fixture -# def transformed_weight_quant(weight_quant, layout): -# weight_quant["data"] = transform_numpy(weight_quant["data"], "nc", layout) -# return weight_quant - -# Test combinations of the following: -# dtype in (float16, uint8) -# num_croutons in (1, 2) -# bias_enabled in (true, false) class TestDenseSlice: (input_shape, output_shape, layout, bias, dtype,) = tvm.testing.parameters( ( # Float 16