From 21ae90473606bc699cc3373d8818fef457b2a4e0 Mon Sep 17 00:00:00 2001 From: Gayatri Panchapakesan Kumari Date: Mon, 17 Oct 2022 16:38:22 +0530 Subject: [PATCH 1/4] [TOPI][Hexagon] Implement quantized depthwise conv2d --- python/tvm/topi/hexagon/qnn/__init__.py | 1 + .../hexagon/qnn/qdepthwise_conv2d_slice.py | 218 +++++++++++ python/tvm/topi/hexagon/slice_ops/dwconv2d.py | 5 +- python/tvm/topi/hexagon/utils.py | 20 ++ .../topi/test_depthwise_conv2d_slice.py | 338 ++++++++++++++++++ .../test_hexagon/topi/test_dwconv2d_slice.py | 314 ---------------- 6 files changed, 580 insertions(+), 316 deletions(-) create mode 100644 python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py create mode 100644 tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py delete mode 100644 tests/python/contrib/test_hexagon/topi/test_dwconv2d_slice.py diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index bafc6846b6fb..f7a018d2257a 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -26,3 +26,4 @@ from .quantize import quantize_compute, tir_quantize_schedule from .nn import * +from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule diff --git a/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py new file mode 100644 index 000000000000..99a4d3319e86 --- /dev/null +++ b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals +""" +Please note the following assumptions made by the implementation: +1) The input must be padded in advance to account for 'padding'. In addition, + both input and output must be padded as per the physical buffer layout. +2) 'padding' is ignored. It must be handled outside of the sliced op. +3) The weights are expected to be as per physical layout + +The initial compute for quantized depthwise conv2d is as follows +where cm = channel_multiplier; assumed to be 1, +zp_a = Activation_zero_point, +zp_w = Weight_zero_point, +Qa = Quantized Activation, +Qw = Quantized Weights. + + a) Qc(n, oh, ow, oc) = (Sigma(r, s) (Qw(r, s, oc%cm, oc/cm) - zp_w) + * (Qa(n, oh + r, ow + s, oc/cm) - zp_a)) + * scale_value + where scale_value = (activation_scale * weight_scale) / output_scale + + This can be written as + + b) Qc(n, oh, ow, oc) = (t1 - t2 - t3 + t4) * scale_value + + where t1 = Sigma(r, s) Qw(r, s, oc%cm, oc/cm) * Qa(n, oh + r, ow + s, oc/cm) + t2 = Sigma(r, s) zp_w * Qa(n, oh + r, ow + s, oc/cm) + t3 = Sigma(r, s) zp_a * Qw(r, s, oc%cm, oc/cm) + t4 = Sigma(r, s) zp_a * zp_w + + c) Qc(n, oh, ow, oc) = saturate(((t1 - t2 - t3 + t4) * fixed_scale_value)) >> rsh) + + where fixed_scale_value, rsh are fixed point values for scale_value. + + +Compute and schedule for quantized depthwise conv2d slice op""" + +import typing +import tvm +from tvm import te +from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate + + +def qdepthwise_conv2d_compute( + activations: te.Tensor, + weights: te.Tensor, + out_shape: typing.Tuple, + stride: typing.Tuple, + dilation: typing.Tuple, + dtype: str, + # quantization params: + activation_zero_point, + activation_scale, + weight_zero_point, + weight_scale, + output_zero_point, + output_scale, +): + """Compute for quantized depthwise conv2d""" + filt_shape = weights.shape + ob, oh, ow, oc = out_shape + + if dtype == "uint8": + temp_dtype = "int32" + big_dtype = "int64" + elif dtype == "int8": + temp_dtype = "int32" + big_dtype = "int64" + else: + raise RuntimeError(f"Unsupported output dtype, {odtype}'") + + reduce_height = tvm.te.reduce_axis((0, filt_shape[0]), name="reduce_height") + reduce_width = tvm.te.reduce_axis((0, filt_shape[1]), name="reduce_width") + stride_height, stride_width = stride + dilation_height, dilation_width = dilation + + scale_value = (activation_scale * weight_scale) / output_scale + fixed_scale_value, rsh = get_fixed_point_value(scale_value, "int16") + + t1 = tvm.te.compute( + out_shape, + lambda n, h, w, c: tvm.te.sum( + ( + ( + activations[ + n, + h * stride_height + reduce_height * dilation_height, + w * stride_width + reduce_width * dilation_width, + c, + ].astype(temp_dtype) + ) + * (weights[reduce_height, reduce_width, 0, c].astype(temp_dtype)) + ).astype(temp_dtype), + axis=[reduce_height, reduce_width], + ), + name="t1", + ) + + t2 = tvm.te.compute( + out_shape, + lambda n, h, w, c: tvm.te.sum( + ( + ( + activations[ + n, + h * stride_height + reduce_height * dilation_height, + w * stride_width + reduce_width * dilation_width, + c, + ].astype(temp_dtype) + ) + * weight_zero_point + ).astype(temp_dtype), + axis=[reduce_height, reduce_width], + ), + name="t2", + ) + + t3 = tvm.te.compute( + (oc,), + lambda c: tvm.te.sum( + ( + ((weights[reduce_height, reduce_width, 0, c].astype(temp_dtype))) + * activation_zero_point + ).astype(temp_dtype), + axis=[reduce_height, reduce_width], + ), + name="t3", + ) + + t4 = activation_zero_point * weight_zero_point * reduce_height * reduce_width + + output = tvm.te.compute( + out_shape, + lambda n, h, w, c: saturate( + ( + ( + ( + ((t1[n, h, w, c]).astype(big_dtype) - t2[n, h, w, c] - t3[c] + t4) + * fixed_scale_value + ) + >> rsh + ) + + (output_zero_point).astype(big_dtype) + ), + dtype, + ).astype(dtype), + name="output", + ) + + return output + + +def qdepthwise_conv2d_schedule( + outs: te.Tensor, + ins: typing.List[te.Tensor], + transform_activation_layout: str, + transform_weights: str, +): + """ + Schedule for quantized depthwise conv2d for input layout nhwc-8h8w32c + assert len(ins) == 2, "This schedule expects only 2 inputs - Activations and Weights + """ + source_expr = ins + [outs] + prim_func = tvm.te.create_prim_func(source_expr) + sch = tvm.tir.Schedule(prim_func) + + compute = sch.get_block("output") + compute1 = sch.get_block("t1") + + transform_layout_fn = get_layout_transform_fn(transform_activation_layout) + transform_layout_weights = get_layout_transform_fn(transform_weights) + + # Apply layout_transform for activation + sch.transform_layout(compute1, ins[0].name, transform_layout_fn) + + # Apply layout_transform for weights + sch.transform_layout(compute1, ins[1].name, transform_layout_weights) + + # Apply layout_transform for output + sch.transform_layout(compute, outs.name, transform_layout_fn) + + # This returns the original 6d loop + batch, height, width, channel, reduce_height, reduce_width = sch.get_loops(compute1) + h_outer, h_inner = sch.split(height, [None, 8]) + w_outer, w_inner = sch.split(width, [None, 8]) + c_outer, c_inner = sch.split(channel, [None, 32]) + sch.reorder( + batch, + h_outer, + w_outer, + c_outer, + h_inner, + reduce_height, + reduce_width, + w_inner, + c_inner, + ) + + sch.decompose_reduction(compute1, reduce_height) + # wi_ci = sch.fuse(w_inner,c_inner) + # sch.vectorize(wi_ci) + return sch + diff --git a/python/tvm/topi/hexagon/slice_ops/dwconv2d.py b/python/tvm/topi/hexagon/slice_ops/dwconv2d.py index 698495daf1b7..d22dc02a5c1b 100644 --- a/python/tvm/topi/hexagon/slice_ops/dwconv2d.py +++ b/python/tvm/topi/hexagon/slice_ops/dwconv2d.py @@ -85,7 +85,7 @@ def dwconv2d_schedule( outs: te.Tensor, ins: typing.List[te.Tensor], transform_activation_layout: str, - transform_weights: typing.Callable, + transform_weights: str, ) -> tvm.tir.Schedule: """STIR schedule definition for the compute defined above by dwconv2d_compute. - Auto-generated prim_func before applying schedule primitives for reference @@ -128,11 +128,12 @@ def main(InputTensor: T.Buffer[(1, 16, 8, 32), "float16"], Weights: T.Buffer[(3, sch = tvm.tir.Schedule(prim_func) compute = sch.get_block("Output") transform_layout_fn = get_layout_transform_fn(transform_activation_layout) + transform_layout_weights = get_layout_transform_fn(transform_weights) # Apply layout_transform for activation sch.transform_layout(compute, ins[0].name, transform_layout_fn) # Apply layout_transform for weights - sch.transform_layout(compute, ins[1].name, transform_weights) + sch.transform_layout(compute, ins[1].name, transform_layout_weights) # Apply layout_transform for output sch.transform_layout(compute, outs.name, transform_layout_fn) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index dab9aa3f74ab..1bf4b5ef6af6 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -127,6 +127,10 @@ def iohw_16i32o2i_1d(height, width, in_channel, out_channel): ] +def ohwi32o_1d(height, width, in_channel, out_channel): + return [out_channel // 32, height, width, in_channel, out_channel % 32] + + def get_layout_transform_fn(layout): """Return index map function as per the layout string""" if layout == "nhwc-8h2w32c2w-2d": @@ -167,6 +171,8 @@ def get_layout_transform_fn(layout): return nhwc_8h8w32c_2d if layout == "n11c-2048c-2d": return n11c_2048c_2d + if layout == "ohwi32o-1d": + return ohwi32o_1d raise RuntimeError(f"Unexpected layout '{layout}'") @@ -235,6 +241,19 @@ def get_fixed_point_value(flp: float, dtype: str = "int16") -> Tuple[int, int]: best scaling factor for 'int16' type that can be used to convert the floating-point value to fixed-point with the least amount of precision loss. + + Here is a more rigorous explanation of the above, for non-negative scale values, which are of interest. + M < 2, so M * 2^(E-Bias+x) < 2 ^ (E-Bias+x+1) [Note: LHS is a fraction, RHS int] + => round(M * 2^(E-Bias+x)) <= 2 ^ (E-Bias+x+1) [Note the "<=", not "<"] + We want x s.t. round(M * 2^(E-Bias+x)) <= 2^15 - 1 + We know round(M * 2^(E-Bias+x)) <= 2^(E-Bias+x+1) + It will be sufficient to choose x s.t. 2^(E-Bias+x+1) <= 2^15 - 1 + That is, max x. s.t. 2^(E-Bias+x+1) < 2^15 + E-Bias+x+1 < 15 + E-Bias+x+1 <= 14 + Max x will make E-Bias+x+1 = 14 + x = 13 - E + Bias + Additonal notes on various floating-point values: ------------------------------------------------ 1) Denormalized values: causes assertion failure. The problem with the denormalized values @@ -299,3 +318,4 @@ def within_range(val, dtype): def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) + diff --git a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py new file mode 100644 index 000000000000..0072c4b51006 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py @@ -0,0 +1,338 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument, disable=line-too-long, redefined-outer-name + +"""Test depthwise_conv2d slice op for hexagon.""" + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.hexagon.qnn as qn +from tvm.topi.testing import depthwise_conv2d_python_nhwc +from tvm.topi.hexagon.slice_ops.dwconv2d import dwconv2d_compute, dwconv2d_schedule +from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np + + +@tvm.testing.fixture +def input_np(in_shape, dtype, low, high): + if dtype in ("uint8"): + return np.random.uniform(low=low, high=high, size=in_shape).astype("float32") + if dtype in ("int8"): + return np.random.uniform(low=-low, high=high, size=in_shape).astype("float32") + return np.random.uniform(size=in_shape).astype(dtype) + + +@tvm.testing.fixture +def input_np_padded(input_np, in_shape, padded_in_shape): + pad_height = padded_in_shape[1] - in_shape[1] + pad_width = padded_in_shape[2] - in_shape[2] + pad_channel = padded_in_shape[3] - in_shape[3] + input_padded = np.pad( + input_np, ((0, 0), (0, pad_height), (0, pad_width), (0, pad_channel)), "constant" + ) + return input_padded + + +@tvm.testing.fixture +def in_out_layout(dtype): + if dtype == "float16": + return "nhwc-8h2w32c2w-2d" + elif dtype in ("uint8", "int8"): + return "nhwc-8h8w32c-2d" + else: + raise RuntimeError(f"Unsupported quantized data type '{dtype}'") + + +@tvm.testing.fixture +def expected_output_np(input_np, dilated_weights_np, stride, dtype): + dilated_weights_np_t = dilated_weights_np.transpose(0, 1, 3, 2) + ref_type = dtype + if dtype in ("uint8", "int8"): + # for quantized versions, return float32 output + ref_type = "float32" + ref_np = depthwise_conv2d_python_nhwc( + input_np.astype("float32"), dilated_weights_np_t.astype("float32"), stride, padding=0 + ).astype(ref_type) + return ref_np + + +@tvm.testing.fixture +def transformed_expected_output_np(expected_output_np, in_out_layout, dtype): + if dtype == "float16": + return transform_numpy(expected_output_np, "nhwc", in_out_layout) + elif dtype in ("uint8", "int8"): + quant_arr, scale, zero_point = quantize_np(expected_output_np, dtype) + return [transform_numpy(quant_arr, "nhwc", in_out_layout), scale, zero_point] + else: + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +@tvm.testing.fixture +def transformed_input_np_padded(input_np_padded, in_out_layout, dtype): + if dtype == "float16": + return transform_numpy(input_np_padded, "nhwc", in_out_layout) + if dtype in ("uint8", "int8"): + quant_arr, scale, zero_point = quantize_np(input_np_padded, dtype) + return [transform_numpy(quant_arr, "nhwc", in_out_layout), scale, zero_point] + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +@tvm.testing.fixture +def weights_np(filt_shape, dtype): + if dtype == "float16": + return np.random.uniform(size=filt_shape).astype(dtype) + elif dtype in ("uint8", "int8"): + weight_arr = np.random.uniform(low=-5, high=5, size=filt_shape).astype("float32") + return weight_arr + else: + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +@tvm.testing.fixture +def dilated_filt_shape(filt_shape, dilation): + """Compute the dilated filter shape when dilation > 1""" + filt_height, filt_width, in_channel, out_channel = filt_shape + dilation_height, dilation_width = dilation + if dilation_height == 1 and dilation_width == 1: + return filt_shape + dilated_height = dilation_height * (filt_height - 1) + 1 + dilated_width = dilation_width * (filt_width - 1) + 1 + return dilated_height, dilated_width, in_channel, out_channel + + +@tvm.testing.fixture +def dilated_weights_np(weights_np, dilation, dilated_filt_shape, dtype): + """Get dilated weights from original weights for testing""" + if dtype in ["int8", "uint8"]: + dtype = "float32" + filt_height, filt_width, in_channels, out_channels = weights_np.shape + dilated_weights = np.zeros(dilated_filt_shape) + dilation_height, dilation_width = dilation + if dilation_height == 1 and dilation_width == 1: + return weights_np + dilated_height, dilated_width = dilated_filt_shape[0], dilated_filt_shape[1] + for in_channel in range(in_channels): + for out_channel in range(out_channels): + for dilation_i, height_i in zip( + range(0, dilated_height, dilation_height), range(filt_height) + ): + for dilation_j, width_j in zip( + range(0, dilated_width, dilation_width), range(filt_width) + ): + dilated_weights[dilation_i, dilation_j, in_channel, out_channel] = weights_np[ + height_i, width_j, in_channel, out_channel + ] + return dilated_weights + + +@tvm.testing.fixture +def transformed_weights_np(weights_np, dtype): + height, width, in_channel, out_channel = weights_np.shape + t = weights_np.reshape([height, width, in_channel, out_channel // 32, 32]).transpose( + 3, 0, 1, 2, 4 + ) + if dtype == "float16": + return t + if dtype in ("uint8", "int8"): + quant_arr, scale, zero_point = quantize_np(t, dtype) + return [quant_arr, scale, zero_point] + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +def generate_test_config(test_params): + """Utility function to generate test config with meaningful ids""" + test_config = {} + + dims = lambda vals: "x".join(map(str, vals)) + + for param in test_params: + in_shape, filt_shape, stride, dilation = param[:4] + test_name = f"nhwc{dims(in_shape)}-hwio{dims(filt_shape)}-stride{dims(stride)}-dilation{dims(dilation)}" + test_config[test_name] = param + + return test_config + + +class Testdwconv2dSlice: + """Test class that defines the dwconv2d slice test""" + + test_params = [ + [(1, 10, 10, 32), (3, 3, 1, 32), (1, 1), (1, 1), 0.0, 10.0], + [(1, 10, 10, 64), (3, 3, 1, 64), (1, 1), (1, 1), 0.0, 10.0], + [(1, 12, 12, 32), (5, 5, 1, 32), (1, 1), (1, 1), 0.0, 20.0], + [(1, 16, 16, 32), (5, 5, 1, 32), (1, 1), (2, 2), 0.0, 1.0], + [(1, 18, 10, 32), (3, 3, 1, 32), (1, 1), (1, 1), 0.0, 10.0], + [(1, 18, 18, 32), (3, 3, 1, 32), (2, 2), (1, 1), 0.0, 10.0], + [(1, 18, 10, 96), (3, 3, 1, 96), (1, 1), (1, 1), 0.0, 10.0], + [(1, 21, 21, 32), (7, 7, 1, 32), (2, 2), (1, 1), 0.0, 10.0], + [(1, 28, 28, 32), (7, 7, 1, 32), (2, 2), (2, 2), 0.0, 10.0], + [(1, 28, 28, 96), (7, 7, 1, 96), (2, 2), (2, 2), 0.0, 10.0], + [(1, 10, 16, 32), (3, 1, 1, 32), (1, 1), (1, 1), 0.0, 10.0], + ] + + test_config = generate_test_config(test_params) + + in_shape, filt_shape, stride, dilation, low, high = tvm.testing.parameters( + *test_config.values(), ids=test_config.keys() + ) + dtype = tvm.testing.parameter("float16", "uint8") + working_scope = tvm.testing.parameter("global.vtcm") + weights_layout = tvm.testing.parameter("ohwi32o-1d") + + @tvm.testing.fixture + def padded_in_shape(self, in_shape, dtype): + """Padding the input shape according to layout""" + # NOTE: For float16, the input layout is always assumed to be nhwc-8h2w32c2w-2d and + # for int8/uint8, it's nhwc-8h8w32c-2d. + # For both nhwc-8h2w32c2w-2d and nhwc-8h8w32c-2d, the height should be a multiple + # of 8. However, the width should be a multiple of 4 for the first case and 8 for + # the second case. + in_batch, in_height, in_width, in_channel = in_shape + in_height = ((in_height + 7) // 8) * 8 + + if dtype == "float16": + in_width = ((in_width + 3) // 4) * 4 + elif dtype in ("uint8", "int8"): + in_width = ((in_width + 7) // 8) * 8 + + in_channel = ((in_channel + 31) // 32) * 32 + + return in_batch, in_height, in_width, in_channel + + @tvm.testing.fixture + def out_shape(self, in_shape, dilated_filt_shape, stride): + in_batch, in_height, in_width, _ = in_shape + filt_height, filt_width, _, num_filt = dilated_filt_shape + out_height = (in_height - filt_height) // stride[0] + 1 + out_width = (in_width - filt_width) // stride[1] + 1 + out_channel = num_filt + return in_batch, out_height, out_width, out_channel + + @tvm.testing.requires_hexagon + def test_dwconv2d( + self, + dtype, + in_out_layout, + weights_layout, + padded_in_shape, + weights_np, + filt_shape, + stride, + dilation, + out_shape, + input_np, + input_np_padded, + transformed_weights_np, + expected_output_np, + target, + working_scope, + transformed_input_np_padded, + transformed_expected_output_np, + hexagon_session, + ): + """Main test function that tests the dwconv2d slice op""" + input_tensor = tvm.te.placeholder(padded_in_shape, name="InputTensor", dtype=dtype) + weights = tvm.te.placeholder(filt_shape, name="Weights", dtype=dtype) + + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + # Construct compute and schedule based on dtype + if dtype in ("uint8", "int8"): + in_data_np, activation_scale, activation_zero_point = transformed_input_np_padded + ( + weights_data_np, + weight_scale, + weight_zero_point, + ) = transformed_weights_np + out_data_np, output_scale, output_zero_point = transformed_expected_output_np + + output_tensor = qn.qdepthwise_conv2d_compute( + input_tensor, + weights, + out_shape, + stride, + dilation, + dtype, + activation_zero_point, + activation_scale, + weight_zero_point, + weight_scale, + output_zero_point, + output_scale, + ) + + tir_schedule = qn.qdepthwise_conv2d_schedule( + output_tensor, [input_tensor, weights], in_out_layout, weights_layout + ) + + elif dtype == "float16": + in_data_np = transformed_input_np_padded + out_data_np = transformed_expected_output_np + weights_data_np = transformed_weights_np + output_tensor = dwconv2d_compute( + input_tensor, weights, out_shape, stride, dilation, dtype + ) + + tir_schedule = dwconv2d_schedule( + output_tensor, [input_tensor, weights], in_out_layout, weights_layout + ) + else: + raise RuntimeError(f"Unsupport dtype '{dtype}'") + + func_name = "depthwise_conv2d_slice" + with tvm.transform.PassContext(opt_level=3): + runtime_module = tvm.build( + tir_schedule.mod, + [input_tensor, output_tensor], + target=target, + name=func_name, + ) + + input_arr = allocate_hexagon_array( + hexagon_session.device, + data=in_data_np, + axis_separators=[4], + mem_scope=working_scope, + ) + + weights_arr = allocate_hexagon_array( + hexagon_session.device, data=weights_data_np, mem_scope=working_scope + ) + + output_arr = allocate_hexagon_array( + hexagon_session.device, + out_data_np.shape, + dtype=dtype, + axis_separators=[4], + mem_scope=working_scope, + ) + + mod = hexagon_session.load_module(runtime_module) + mod(input_arr, weights_arr, output_arr) + n, h, w, c = out_shape + + if dtype in ("uint8", "int8"): + output_np = output_arr.numpy().reshape([n, h // 8, w // 8, c // 32, 8, 8, 32]) + np.testing.assert_allclose(output_np, out_data_np, atol=3, rtol=0.02) + elif dtype == "float16": + output_np = output_arr.numpy() + np.testing.assert_allclose(output_np, out_data_np, atol=0.01, rtol=0.01) + + +if __name__ == "__main__": + tvm.testing.main() + diff --git a/tests/python/contrib/test_hexagon/topi/test_dwconv2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_dwconv2d_slice.py deleted file mode 100644 index 3e43718afd8d..000000000000 --- a/tests/python/contrib/test_hexagon/topi/test_dwconv2d_slice.py +++ /dev/null @@ -1,314 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=line-too-long, redefined-outer-name - -"""Test dwconv2d slice op for hexagon. Input layout is always nhwc""" - -import numpy as np - -import tvm -import tvm.testing - -from tvm.topi.testing import depthwise_conv2d_python_nhwc -from tvm.topi.hexagon.slice_ops.dwconv2d import dwconv2d_compute, dwconv2d_schedule - -from ..infrastructure import allocate_hexagon_array, transform_numpy, get_hexagon_target - - -@tvm.testing.fixture -def input_np(in_shape, dtype): - return np.random.uniform(size=in_shape).astype(dtype) - - -@tvm.testing.fixture -def weights_np(filt_shape, dtype): - return (np.random.uniform(size=filt_shape)).astype(dtype) - - -@tvm.testing.fixture -def dilated_filt_shape(filt_shape, dilation): - """Compute the dilated filter shape when dilation > 1""" - filt_height, filt_width, in_channel, out_channel = filt_shape - dilation_height, dilation_width = dilation - if dilation_height == 1 and dilation_width == 1: - return filt_shape - dilated_height, dilated_width = ( - dilation_height * (filt_height - 1) + 1, - dilation_width * (filt_width - 1) + 1, - ) - return dilated_height, dilated_width, in_channel, out_channel - - -@tvm.testing.fixture -def dilated_weights_np(weights_np, dilation, dilated_filt_shape): - """Get dilated weights from original weights for testing""" - filt_height, filt_width, in_channels, out_channels = weights_np.shape - dilation_height, dilation_width = dilation - if dilation_height == 1 and dilation_width == 1: - return weights_np - dilated_height, dilated_width = dilated_filt_shape[0], dilated_filt_shape[1] - dilated_weights = np.zeros(dilated_filt_shape, dtype="float16") - for in_channel in range(in_channels): - for out_channel in range(out_channels): - for dilation_i, height_i in zip( - range(0, dilated_height, dilation_height), range(filt_height) - ): - for dilation_j, width_j in zip( - range(0, dilated_width, dilation_width), range(filt_width) - ): - dilated_weights[dilation_i, dilation_j, in_channel, out_channel] = weights_np[ - height_i, width_j, in_channel, out_channel - ] - - return dilated_weights - - -@tvm.testing.fixture -def input_np_padded(input_np, in_shape, padded_in_shape): - pad_height = padded_in_shape[1] - in_shape[1] - pad_width = padded_in_shape[2] - in_shape[2] - pad_channel = padded_in_shape[3] - in_shape[3] - input_padded = np.pad( - input_np, ((0, 0), (0, pad_height), (0, pad_width), (0, pad_channel)), "constant" - ) - return input_padded - - -@tvm.testing.fixture -def weights_np_transformed(weights_np): - height, width, in_channel, out_channel = weights_np.shape - return weights_np.reshape([height, width, in_channel, out_channel // 32, 32]).transpose( - 3, 0, 1, 2, 4 - ) - - -def generate_test_config(test_params): - """Utility function to generate test config with meaningful ids""" - test_config = {} - - dims = lambda vals: "x".join(map(str, vals)) - - for param in test_params: - in_shape, filt_shape, stride, dilation = param - test_name = f"nhwc{dims(in_shape)}-hwio{dims(filt_shape)}-stride{dims(stride)}-dilation{dims(dilation)}" - test_config[test_name] = param - - return test_config - - -class Testdwconv2dSlice: - """Test class that defines the dwconv2d slice test""" - - test_params = [ - [ - (1, 10, 6, 32), - (3, 3, 1, 32), - (1, 1), - (1, 1), - ], - [ - (1, 18, 10, 32), - (3, 3, 1, 32), - (1, 1), - (1, 1), - ], - [ - (1, 10, 6, 64), - (3, 3, 1, 64), - (1, 1), - (1, 1), - ], - [ - (1, 12, 8, 32), - (3, 3, 1, 32), - (1, 1), - (2, 2), - ], - [ - (1, 12, 8, 32), - (5, 5, 1, 32), - (1, 1), - (1, 1), - ], - [ - (1, 16, 12, 32), - (5, 5, 1, 32), - (1, 1), - (2, 2), - ], - [ - (1, 13, 9, 32), - (6, 6, 1, 32), - (1, 1), - (1, 1), - ], - [ - (1, 18, 10, 32), - (3, 3, 1, 32), - (2, 2), - (1, 1), - ], - [ - (1, 18, 10, 96), - (3, 3, 1, 96), - (2, 2), - (1, 1), - ], - [ - (1, 20, 12, 32), - (5, 5, 1, 32), - (2, 2), - (1, 1), - ], - [ - (1, 22, 14, 32), - (7, 7, 1, 32), - (2, 2), - (1, 1), - ], - [ - (1, 28, 20, 32), - (7, 7, 1, 32), - (2, 2), - (2, 2), - ], - [ - (1, 28, 20, 96), - (7, 7, 1, 96), - (2, 2), - (2, 2), - ], - [ - (1, 10, 4, 32), - (3, 1, 1, 32), - (1, 1), - (1, 1), - ], - [ - (1, 18, 8, 32), - (3, 1, 1, 32), - (2, 2), - (1, 1), - ], - [ - (1, 20, 8, 32), - (3, 1, 1, 32), - (2, 2), - (2, 2), - ], - ] - test_config = generate_test_config(test_params) - - in_shape, filt_shape, stride, dilation = tvm.testing.parameters( - *test_config.values(), ids=test_config.keys() - ) - dtype = tvm.testing.parameter("float16") - working_scope = tvm.testing.parameter("global.vtcm") - in_out_layout = tvm.testing.parameter("nhwc-8h2w32c2w-2d") - - @tvm.testing.fixture - def padded_in_shape(self, in_shape): - in_batch, in_height, in_width, in_channel = in_shape - in_height = ((in_height + 7) // 8) * 8 - in_width = ((in_width + 3) // 4) * 4 - in_channel = ((in_channel + 31) // 32) * 32 - return in_batch, in_height, in_width, in_channel - - @tvm.testing.fixture - def out_shape(self, in_shape, dilated_filt_shape, stride): - in_batch, in_height, in_width, _ = in_shape - filt_height, filt_width, _, num_filt = dilated_filt_shape - out_height = (in_height - filt_height) // stride[0] + 1 - out_width = (in_width - filt_width) // stride[1] + 1 - out_channel = num_filt - return in_batch, out_height, out_width, out_channel - - @tvm.testing.fixture - def expected_output_np(self, input_np, dilated_weights_np, stride): - dilated_weights_np_t = dilated_weights_np.transpose(0, 1, 3, 2) - ref_np = depthwise_conv2d_python_nhwc( - input_np.astype("float32"), dilated_weights_np_t.astype("float32"), stride, padding=0 - ).astype("float16") - return ref_np - - @tvm.testing.requires_hexagon - def test_dwconv2d( - self, - padded_in_shape, - filt_shape, - stride, - dilation, - dtype, - out_shape, - in_out_layout, - input_np_padded, - weights_np_transformed, - expected_output_np, - working_scope, - hexagon_session, - ): - """Main test function that tests the dwconv2d slice op""" - input_tensor = tvm.te.placeholder(padded_in_shape, name="InputTensor", dtype=dtype) - weights = tvm.te.placeholder(filt_shape, name="Weights", dtype=dtype) - - output_tensor = dwconv2d_compute(input_tensor, weights, out_shape, stride, dilation, dtype) - - def transform_weights(height, width, in_channel, out_channel): - return [out_channel // 32, height, width, in_channel, out_channel % 32] - - tir_schedule = dwconv2d_schedule( - output_tensor, [input_tensor, weights], in_out_layout, transform_weights - ) - - func_name = f"fdwconv2d_{dtype}" - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): - runtime_module = tvm.build( - tir_schedule.mod, - target=get_hexagon_target("v69"), - name=func_name, - ) - - input_np_transformed = transform_numpy(input_np_padded, "nhwc", in_out_layout) - output_np_transformed = transform_numpy(expected_output_np, "nhwc", in_out_layout) - - input_arr = allocate_hexagon_array( - hexagon_session.device, - data=input_np_transformed, - axis_separators=[4], - mem_scope=working_scope, - ) - - weights_arr = allocate_hexagon_array( - hexagon_session.device, data=weights_np_transformed, mem_scope=working_scope - ) - - output_arr = allocate_hexagon_array( - hexagon_session.device, - tensor_shape=output_np_transformed.shape, - dtype=output_np_transformed.dtype, - axis_separators=[4], - mem_scope=working_scope, - ) - - mod = hexagon_session.load_module(runtime_module) - mod(input_arr, weights_arr, output_arr) - output_np = output_arr.numpy() - np.testing.assert_allclose(output_np, output_np_transformed, atol=0.01, rtol=0.01) - - -if __name__ == "__main__": - tvm.testing.main() From 7fb5182c37c5322460b846108eb0b3360c421f9b Mon Sep 17 00:00:00 2001 From: Gayatri Panchapakesan Kumari Date: Mon, 17 Oct 2022 21:51:56 +0530 Subject: [PATCH 2/4] Fix lint errors --- python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py | 1 - python/tvm/topi/hexagon/utils.py | 1 - .../contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py | 1 - 3 files changed, 3 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py index 99a4d3319e86..d509e7e13433 100644 --- a/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py +++ b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py @@ -215,4 +215,3 @@ def qdepthwise_conv2d_schedule( # wi_ci = sch.fuse(w_inner,c_inner) # sch.vectorize(wi_ci) return sch - diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 1bf4b5ef6af6..76205f5e9de8 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -318,4 +318,3 @@ def within_range(val, dtype): def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) - diff --git a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py index 0072c4b51006..840a462917ae 100644 --- a/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_depthwise_conv2d_slice.py @@ -335,4 +335,3 @@ def test_dwconv2d( if __name__ == "__main__": tvm.testing.main() - From fee2cfdc0e61440634c998434697f137a1cfcc5c Mon Sep 17 00:00:00 2001 From: Gayatri Panchapakesan Kumari Date: Tue, 18 Oct 2022 15:39:18 +0530 Subject: [PATCH 3/4] Fix lint error --- python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py index d509e7e13433..9a275c1cc370 100644 --- a/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py +++ b/python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py @@ -32,10 +32,10 @@ a) Qc(n, oh, ow, oc) = (Sigma(r, s) (Qw(r, s, oc%cm, oc/cm) - zp_w) * (Qa(n, oh + r, ow + s, oc/cm) - zp_a)) * scale_value - where scale_value = (activation_scale * weight_scale) / output_scale + where scale_value = (activation_scale * weight_scale) / output_scale This can be written as - + b) Qc(n, oh, ow, oc) = (t1 - t2 - t3 + t4) * scale_value where t1 = Sigma(r, s) Qw(r, s, oc%cm, oc/cm) * Qa(n, oh + r, ow + s, oc/cm) From 6690e8ac459b908bd9002c942968c6da0a2ad2b0 Mon Sep 17 00:00:00 2001 From: Gayatri Panchapakesan Kumari Date: Tue, 18 Oct 2022 22:48:47 +0530 Subject: [PATCH 4/4] Fix lint errors --- python/tvm/topi/hexagon/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 76205f5e9de8..890ebeb9fd11 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -242,8 +242,8 @@ def get_fixed_point_value(flp: float, dtype: str = "int16") -> Tuple[int, int]: fixed-point with the least amount of precision loss. - Here is a more rigorous explanation of the above, for non-negative scale values, which are of interest. - M < 2, so M * 2^(E-Bias+x) < 2 ^ (E-Bias+x+1) [Note: LHS is a fraction, RHS int] + Here is a more rigorous explanation of the above, for non-negative scale values, which are of + interest. M < 2, so M * 2^(E-Bias+x) < 2 ^ (E-Bias+x+1) [Note: LHS is a fraction, RHS int] => round(M * 2^(E-Bias+x)) <= 2 ^ (E-Bias+x+1) [Note the "<=", not "<"] We want x s.t. round(M * 2^(E-Bias+x)) <= 2^15 - 1 We know round(M * 2^(E-Bias+x)) <= 2^(E-Bias+x+1)