From 0c301dee6c7cc5bd9ccf8a3e1a43ab9d20fbb3af Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Wed, 1 Jun 2022 05:46:22 -0500 Subject: [PATCH 1/8] [TOPI] [Hexagon] Batch flatten slice op initial version --- python/tvm/topi/hexagon/slice_ops/__init__.py | 22 +++ .../topi/hexagon/slice_ops/batch_flatten.py | 78 +++++++++++ .../contrib/test_hexagon/infrastructure.py | 2 +- .../test_hexagon/test_batch_flatten.py | 130 ++++++++++++++++++ 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 python/tvm/topi/hexagon/slice_ops/__init__.py create mode 100644 python/tvm/topi/hexagon/slice_ops/batch_flatten.py create mode 100644 tests/python/contrib/test_hexagon/test_batch_flatten.py diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py new file mode 100644 index 000000000000..20652f7d616e --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Computes and Schedules for Hexagon slice ops. """ + +# pylint: disable=wildcard-import + +from .batch_flatten import * diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py new file mode 100644 index 000000000000..cf342e2e2b0b --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Hexagon slice batch flatten compute and schedule""" +import typing + +from tvm import te, tir, topi +from tvm.script import tir as T + + +def batch_flatten_compute(A: te.Tensor) -> te.Tensor: + """Compute for slice batch flatten op for hexagon. + This op makes the following assumptions: + 1. This op is written for a sliced batch flatten operation. + 2. The input is assumed to be in NHWC layout. 
+ + Parameters + ---------- + Input : te.Tensor + Input activations padded for inner dimension size + Returns + ------- + Output : te.Tensor + Output of applying batch flatten operation on input + """ + return topi.nn.flatten(A) + + +def batch_flatten_STIR_schedule( + outputs: te.Tensor, + input: te.Tensor, + out_layout: typing.Callable, + in_layout: typing.Callable, +) -> tir.Schedule: + """STIR schedule definition for the compute of batch flatten compute. + Parameters + ---------- + outputs : te.Tensor + The output tensor as returned by a call to batch_flatten_compute + input : te.Tensor + Input tensor to batch_flatten + out_layout: typing.Callable + The transformation function definition for the expected output layout + in_layout: typing.Callable + The transformation function definition for the input layout + Returns + ------- + sch : tvm.tir.Schedule + The STIR schedule for slice batch flatten compute + """ + + batch_flatten_func = te.create_prim_func([input, outputs]) + sch = tir.Schedule(batch_flatten_func, debug_mask="all") + compute = sch.get_block("compute") + + sch.transform_layout(compute, input.name, in_layout) + sch.transform_layout(compute, outputs.name, out_layout) + i, j = sch.get_loops(compute) + jo, c = sch.split(j, [None, input.shape[3]]) + h, w = sch.split(jo, [input.shape[1], input.shape[2]]) + co, ci = sch.split(c, [None, 1024]) + ci_1, ci_2 = sch.split(ci, [None, 64]) + sch.vectorize(ci_2) + return sch diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 0c9a9478c870..01eef86e6b5b 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -48,7 +48,7 @@ def allocate_hexagon_array( for dim_i, dim_f in zip(boundaries[:-1], boundaries[1:]) ] - arr = tvm.nd.empty(physical_shape, dtype=dtype, device=dev) + arr = tvm.nd.empty(physical_shape, dtype=dtype, device=dev, mem_scope=mem_scope) if data is not None: arr.copyfrom(data.reshape(physical_shape)) diff --git a/tests/python/contrib/test_hexagon/test_batch_flatten.py b/tests/python/contrib/test_hexagon/test_batch_flatten.py new file mode 100644 index 000000000000..918362101512 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_batch_flatten.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.topi.hexagon.slice_ops as sl +from tvm import te, topi +from tvm.contrib.hexagon.build import HexagonLauncher +from tvm.topi import testing + +from .infrastructure import allocate_hexagon_array + + +def n11c_1024c_1d(n, h, w, c): + return [n, h, w, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] + + +def nc_1024_1d(n, c): + return [n, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] + + +def transform_numpy(arr_np, layout): + if layout == "nhwc": + return arr_np + elif layout == "n11c-1024c-1d": + N, H, W, C = arr_np.shape + return arr_np.reshape([N, H, W, C // 1024, 1024]) + elif layout == "nc-1d": + N, C = arr_np.shape + return arr_np.reshape([N, C // 1024, 1024]) + + +@tvm.testing.fixture +def transformed_expected_output_np(expected_output_np, output_layout): + return transform_numpy(expected_output_np, output_layout) + + +class BaseTestBatchFlatten: + ( + input_shape, + input_layout, + output_layout, + input_axis_sep, + output_axis_sep, + ) = tvm.testing.parameters( + ((1, 1, 1, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), + ((1, 2, 4, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), + ((1, 8, 8, 1024), "n11c-1024c-1d", "nc-1d", [4], [2]), + ((2, 4, 8, 1024), "n11c-1024c-1d", "nc-1d", [4], [2]), + ((2, 3, 5, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), + ) + data_type = tvm.testing.parameter("float16") + + +class TestBatchFlatten(BaseTestBatchFlatten): + @tvm.testing.fixture + def output_shape(self, input_shape): + return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3] + + @tvm.testing.requires_hexagon + def test_batch_flatten( + self, + data_type, + input_shape, + input_layout, + input_axis_sep, + output_shape, + output_layout, + output_axis_sep, + hexagon_session, + ): + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + A = te.placeholder(input_shape, name="A", dtype=data_type) + D = sl.batch_flatten_compute(A) + tir_s = sl.batch_flatten_STIR_schedule( + D, + A, + nc_1024_1d, + n11c_1024c_1d, + ) + func_name = "batch_flatten" + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): + tir_irm = tvm.lower(tir_s.mod, [A, D], name=func_name) + runtime_module = tvm.build(tir_irm, [A, D], target=target, name=func_name) + + mod = hexagon_session.load_module(runtime_module) + + a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type) + ref = np.reshape(a_numpy, output_shape) + + input_np_transformed = transform_numpy(a_numpy, input_layout) + ref_np_transformed = transform_numpy(ref, output_layout) + + a_tvm = allocate_hexagon_array( + hexagon_session.device, + data=input_np_transformed, + axis_separators=input_axis_sep, + mem_scope="global.vtcm", + ) + output = allocate_hexagon_array( + hexagon_session.device, + ref_np_transformed.shape, + data_type, + axis_separators=output_axis_sep, + mem_scope="global.vtcm", + ) + mod(a_tvm, output) + np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) From 7108e529feb9031debc82fbc8b8b11e5de503625 Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Wed, 1 Jun 2022 09:47:29 -0500 Subject: [PATCH 2/8] Fix lint errors --- .../topi/hexagon/slice_ops/batch_flatten.py | 24 +++++++++---------- .../test_hexagon/test_batch_flatten.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py 
b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index cf342e2e2b0b..3de07f4d3152 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -22,7 +22,7 @@ from tvm.script import tir as T -def batch_flatten_compute(A: te.Tensor) -> te.Tensor: +def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: """Compute for slice batch flatten op for hexagon. This op makes the following assumptions: 1. This op is written for a sliced batch flatten operation. @@ -37,12 +37,12 @@ def batch_flatten_compute(A: te.Tensor) -> te.Tensor: Output : te.Tensor Output of applying batch flatten operation on input """ - return topi.nn.flatten(A) + return topi.nn.flatten(inp) -def batch_flatten_STIR_schedule( - outputs: te.Tensor, - input: te.Tensor, +def batch_flatten_stir_schedule( + out: te.Tensor, + inp: te.Tensor, out_layout: typing.Callable, in_layout: typing.Callable, ) -> tir.Schedule: @@ -63,16 +63,16 @@ def batch_flatten_STIR_schedule( The STIR schedule for slice batch flatten compute """ - batch_flatten_func = te.create_prim_func([input, outputs]) + batch_flatten_func = te.create_prim_func([inp, out]) sch = tir.Schedule(batch_flatten_func, debug_mask="all") compute = sch.get_block("compute") - sch.transform_layout(compute, input.name, in_layout) - sch.transform_layout(compute, outputs.name, out_layout) + sch.transform_layout(compute, inp.name, in_layout) + sch.transform_layout(compute, out.name, out_layout) i, j = sch.get_loops(compute) - jo, c = sch.split(j, [None, input.shape[3]]) - h, w = sch.split(jo, [input.shape[1], input.shape[2]]) + jo, c = sch.split(j, [None, inp.shape[3]]) + h, w = sch.split(jo, [inp.shape[1], inp.shape[2]]) co, ci = sch.split(c, [None, 1024]) - ci_1, ci_2 = sch.split(ci, [None, 64]) - sch.vectorize(ci_2) + cio, cii = sch.split(ci, [None, 64]) + sch.vectorize(cii) return sch diff --git a/tests/python/contrib/test_hexagon/test_batch_flatten.py b/tests/python/contrib/test_hexagon/test_batch_flatten.py index 918362101512..6295826531ac 100644 --- a/tests/python/contrib/test_hexagon/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/test_batch_flatten.py @@ -90,7 +90,7 @@ def test_batch_flatten( target = tvm.target.Target(target_hexagon, host=target_hexagon) A = te.placeholder(input_shape, name="A", dtype=data_type) D = sl.batch_flatten_compute(A) - tir_s = sl.batch_flatten_STIR_schedule( + tir_s = sl.batch_flatten_stir_schedule( D, A, nc_1024_1d, From fdb9dbceb3f4d1e4933ac9b09846d7c6711adf8e Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Thu, 2 Jun 2022 00:33:53 -0500 Subject: [PATCH 3/8] Fix more lint errors --- python/tvm/topi/hexagon/slice_ops/batch_flatten.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 3de07f4d3152..35b542868a52 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -70,9 +70,9 @@ def batch_flatten_stir_schedule( sch.transform_layout(compute, inp.name, in_layout) sch.transform_layout(compute, out.name, out_layout) i, j = sch.get_loops(compute) - jo, c = sch.split(j, [None, inp.shape[3]]) - h, w = sch.split(jo, [inp.shape[1], inp.shape[2]]) - co, ci = sch.split(c, [None, 1024]) - cio, cii = sch.split(ci, [None, 64]) - sch.vectorize(cii) + jout, channel = sch.split(j, [None, inp.shape[3]]) + height, width = sch.split(jout, [inp.shape[1], inp.shape[2]]) + channelo, 
channeli = sch.split(channel, [None, 1024]) + channelio, channelii = sch.split(channeli, [None, 64]) + sch.vectorize(channelii) return sch From 6839726a33d7e83228bbb16c860e2aa7f0966f7d Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Thu, 2 Jun 2022 02:42:50 -0500 Subject: [PATCH 4/8] Fix lint warnings --- python/tvm/topi/hexagon/slice_ops/batch_flatten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 35b542868a52..58022290f7ba 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -19,7 +19,6 @@ import typing from tvm import te, tir, topi -from tvm.script import tir as T def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: @@ -74,5 +73,6 @@ def batch_flatten_stir_schedule( height, width = sch.split(jout, [inp.shape[1], inp.shape[2]]) channelo, channeli = sch.split(channel, [None, 1024]) channelio, channelii = sch.split(channeli, [None, 64]) + sch.reorder(i, height, width, channelo, channelio, channelii) sch.vectorize(channelii) return sch From f4495bf355eb910436adc49429383090d3153da9 Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Fri, 3 Jun 2022 02:35:47 -0500 Subject: [PATCH 5/8] Fix review comments --- .../test_hexagon/test_batch_flatten.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_batch_flatten.py b/tests/python/contrib/test_hexagon/test_batch_flatten.py index 6295826531ac..d1e7c8143caa 100644 --- a/tests/python/contrib/test_hexagon/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/test_batch_flatten.py @@ -53,19 +53,15 @@ def transformed_expected_output_np(expected_output_np, output_layout): class BaseTestBatchFlatten: - ( - input_shape, - input_layout, - output_layout, - input_axis_sep, - output_axis_sep, - ) = tvm.testing.parameters( - ((1, 1, 1, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), - ((1, 2, 4, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), - ((1, 8, 8, 1024), "n11c-1024c-1d", "nc-1d", [4], [2]), - ((2, 4, 8, 1024), "n11c-1024c-1d", "nc-1d", [4], [2]), - ((2, 3, 5, 2048), "n11c-1024c-1d", "nc-1d", [4], [2]), + input_shape = tvm.testing.parameter( + (1, 1, 1, 2048), + (1, 2, 4, 2048), + (1, 8, 8, 1024), + (2, 4, 8, 1024), + (2, 3, 5, 2048), ) + input_layout, input_axis_sep = tvm.testing.parameters(("n11c-1024c-1d", [4])) + output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2])) data_type = tvm.testing.parameter("float16") @@ -98,8 +94,7 @@ def test_batch_flatten( ) func_name = "batch_flatten" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): - tir_irm = tvm.lower(tir_s.mod, [A, D], name=func_name) - runtime_module = tvm.build(tir_irm, [A, D], target=target, name=func_name) + runtime_module = tvm.build(tir_s.mod, target=target, name=func_name) mod = hexagon_session.load_module(runtime_module) From 89479ba3b134f59f96c47b6052c347dbbd14bc5c Mon Sep 17 00:00:00 2001 From: Abhikrant Sharma Date: Thu, 16 Jun 2022 05:44:18 -0500 Subject: [PATCH 6/8] Update tests to use util functions --- .../topi/hexagon/slice_ops/batch_flatten.py | 5 +-- python/tvm/topi/hexagon/utils.py | 14 ++++++++ .../contrib/test_hexagon/infrastructure.py | 6 ++++ .../{ => topi}/test_batch_flatten.py | 36 ++++--------------- 4 files changed, 29 insertions(+), 32 deletions(-) rename tests/python/contrib/test_hexagon/{ => topi}/test_batch_flatten.py (76%) diff --git 
a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 58022290f7ba..07230296412e 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -19,6 +19,7 @@ import typing from tvm import te, tir, topi +from ..utils import get_layout_transform_fn def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: @@ -66,8 +67,8 @@ def batch_flatten_stir_schedule( sch = tir.Schedule(batch_flatten_func, debug_mask="all") compute = sch.get_block("compute") - sch.transform_layout(compute, inp.name, in_layout) - sch.transform_layout(compute, out.name, out_layout) + sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout)) + sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout)) i, j = sch.get_loops(compute) jout, channel = sch.split(j, [None, inp.shape[3]]) height, width = sch.split(jout, [inp.shape[1], inp.shape[2]]) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index af6e3de9c350..1ceeb186ab87 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -39,6 +39,16 @@ def nhwc_8h2w32c2w_1d(n, h, w, c): return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2] +def nhwc_1024c_1d(n, h, w, c): + """Return index map for nhwc_1024 1d layout""" + return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024] + + +def nc_1024_1d(n, c): + """Return index map for nc_1024 1d layout""" + return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] + + def get_layout_transform_fn(layout): """Return index map function as per the layout string""" if layout == "nhwc-8h2w32c2w-2d": @@ -49,4 +59,8 @@ def get_layout_transform_fn(layout): return n11c_1024c_2d if layout == "n11c-1024c-1d": return n11c_1024c_1d + if layout == "nhwc-1024c-1d": + return nhwc_1024c_1d + if layout == "nc-1d": + return nc_1024_1d raise RuntimeError(f"Unexpected layout '{layout}'") diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 57a9dff8b424..5d031871509b 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): n, h, w, c = arr_np.shape assert h == 1 and w == 1, "The size of h and w must be 1" return arr_np.reshape([n, 1, 1, c // 1024, 1024]) + if new_layout == "nc-1d": + N, C = arr_np.shape + return arr_np.reshape([N, C // 1024, 1024]) + if new_layout == "nhwc-1024c-1d": + N, H, W, C = arr_np.shape + return arr_np.reshape([N, H, W, C // 1024, 1024]) raise RuntimeError(f"Unexpected new_layout '{new_layout}'") raise RuntimeError(f"Unexpected current_layout '{current_layout}'") diff --git a/tests/python/contrib/test_hexagon/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py similarity index 76% rename from tests/python/contrib/test_hexagon/test_batch_flatten.py rename to tests/python/contrib/test_hexagon/topi/test_batch_flatten.py index d1e7c8143caa..cd7a9ec51591 100644 --- a/tests/python/contrib/test_hexagon/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py @@ -25,31 +25,7 @@ from tvm.contrib.hexagon.build import HexagonLauncher from tvm.topi import testing -from .infrastructure import allocate_hexagon_array - - -def n11c_1024c_1d(n, h, w, c): - return [n, h, w, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] - - -def nc_1024_1d(n, c): - 
return [n, c // 1024, tvm.te.AXIS_SEPARATOR, c % 1024] - - -def transform_numpy(arr_np, layout): - if layout == "nhwc": - return arr_np - elif layout == "n11c-1024c-1d": - N, H, W, C = arr_np.shape - return arr_np.reshape([N, H, W, C // 1024, 1024]) - elif layout == "nc-1d": - N, C = arr_np.shape - return arr_np.reshape([N, C // 1024, 1024]) - - -@tvm.testing.fixture -def transformed_expected_output_np(expected_output_np, output_layout): - return transform_numpy(expected_output_np, output_layout) +from ..infrastructure import allocate_hexagon_array, transform_numpy class BaseTestBatchFlatten: @@ -60,7 +36,7 @@ class BaseTestBatchFlatten: (2, 4, 8, 1024), (2, 3, 5, 2048), ) - input_layout, input_axis_sep = tvm.testing.parameters(("n11c-1024c-1d", [4])) + input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-1d", [4])) output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2])) data_type = tvm.testing.parameter("float16") @@ -89,8 +65,8 @@ def test_batch_flatten( tir_s = sl.batch_flatten_stir_schedule( D, A, - nc_1024_1d, - n11c_1024c_1d, + output_layout, + input_layout, ) func_name = "batch_flatten" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): @@ -101,8 +77,8 @@ def test_batch_flatten( a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type) ref = np.reshape(a_numpy, output_shape) - input_np_transformed = transform_numpy(a_numpy, input_layout) - ref_np_transformed = transform_numpy(ref, output_layout) + input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout) + ref_np_transformed = transform_numpy(ref, "nhwc", output_layout) a_tvm = allocate_hexagon_array( hexagon_session.device, From c4db7e3f6bb6d83617b268441fd90762098de8b5 Mon Sep 17 00:00:00 2001 From: abhikran-quic <63697863+abhikran-quic@users.noreply.github.com> Date: Wed, 22 Jun 2022 10:51:08 +0530 Subject: [PATCH 7/8] Update __init__.py --- python/tvm/topi/hexagon/slice_ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index ee0adf757d5c..df59e6301f07 100644 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -19,4 +19,4 @@ from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule from .add_subtract_multiply import * -from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule \ No newline at end of file +from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule From 05140979efe78e1cda9a932c7fdc7c5641b8ca3b Mon Sep 17 00:00:00 2001 From: abhikran Date: Fri, 24 Jun 2022 12:33:58 +0530 Subject: [PATCH 8/8] Fix review comments --- .../tvm/topi/hexagon/slice_ops/batch_flatten.py | 6 ++---- python/tvm/topi/hexagon/utils.py | 16 ++++++++-------- .../contrib/test_hexagon/infrastructure.py | 4 ++-- .../test_hexagon/topi/test_batch_flatten.py | 8 ++++---- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 07230296412e..6dc0914e91b4 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -16,8 +16,6 @@ # under the License. 
"""Hexagon slice batch flatten compute and schedule""" -import typing - from tvm import te, tir, topi from ..utils import get_layout_transform_fn @@ -43,8 +41,8 @@ def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: def batch_flatten_stir_schedule( out: te.Tensor, inp: te.Tensor, - out_layout: typing.Callable, - in_layout: typing.Callable, + out_layout: str, + in_layout: str, ) -> tir.Schedule: """STIR schedule definition for the compute of batch flatten compute. Parameters diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 1ceeb186ab87..f4fbeebc522a 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -39,13 +39,13 @@ def nhwc_8h2w32c2w_1d(n, h, w, c): return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2] -def nhwc_1024c_1d(n, h, w, c): - """Return index map for nhwc_1024 1d layout""" +def nhwc_1024c_2d(n, h, w, c): + """Return index map for nhwc_1024 2d layout""" return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024] -def nc_1024_1d(n, c): - """Return index map for nc_1024 1d layout""" +def nc_1024_2d(n, c): + """Return index map for nc_1024 2d layout""" return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] @@ -59,8 +59,8 @@ def get_layout_transform_fn(layout): return n11c_1024c_2d if layout == "n11c-1024c-1d": return n11c_1024c_1d - if layout == "nhwc-1024c-1d": - return nhwc_1024c_1d - if layout == "nc-1d": - return nc_1024_1d + if layout == "nhwc-1024c-2d": + return nhwc_1024c_2d + if layout == "nc-1024-2d": + return nc_1024_2d raise RuntimeError(f"Unexpected layout '{layout}'") diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 5d031871509b..34ba6243d7f9 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -245,10 +245,10 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): n, h, w, c = arr_np.shape assert h == 1 and w == 1, "The size of h and w must be 1" return arr_np.reshape([n, 1, 1, c // 1024, 1024]) - if new_layout == "nc-1d": + if new_layout == "nc-1024-2d": N, C = arr_np.shape return arr_np.reshape([N, C // 1024, 1024]) - if new_layout == "nhwc-1024c-1d": + if new_layout == "nhwc-1024c-2d": N, H, W, C = arr_np.shape return arr_np.reshape([N, H, W, C // 1024, 1024]) diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py index cd7a9ec51591..3a056116d45c 100644 --- a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py @@ -36,8 +36,8 @@ class BaseTestBatchFlatten: (2, 4, 8, 1024), (2, 3, 5, 2048), ) - input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-1d", [4])) - output_layout, output_axis_sep = tvm.testing.parameters(("nc-1d", [2])) + input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4])) + output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2])) data_type = tvm.testing.parameter("float16") @@ -69,7 +69,7 @@ def test_batch_flatten( input_layout, ) func_name = "batch_flatten" - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): + with tvm.transform.PassContext(opt_level=3): runtime_module = tvm.build(tir_s.mod, target=target, name=func_name) mod = hexagon_session.load_module(runtime_module) @@ -98,4 +98,4 @@ def test_batch_flatten( if __name__ == "__main__": - 
sys.exit(pytest.main(sys.argv))
+    tvm.testing.main()
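
For reference, a minimal NumPy sketch of the layout reasoning this series relies on, assuming one of the test shapes above: batch flatten on NHWC data is a pure reshape, so the "nhwc-1024c-2d" input layout and "nc-1024-2d" output layout describe the same row-major bytes, and the schedule only has to retile loops and vectorize the innermost 64 float16 lanes (64 x 2 bytes = one 128-byte HVX vector). The helper names below mirror the index maps in utils.py but are plain Python for illustration, not the TVM API.

import numpy as np

n, h, w, c = 2, 4, 8, 2048                       # one of the test shapes above
a = np.random.uniform(-1, 1, (n, h, w, c)).astype("float16")

# Logical op, as topi.nn.flatten computes it: NHWC -> (N, H*W*C).
flat = a.reshape(n, h * w * c)

# Physical layouts from the patches: crop the channel axis into 1024-wide
# chunks; the axis after te.AXIS_SEPARATOR becomes the innermost, contiguous
# dimension on Hexagon.
in_phys = a.reshape(n, h, w, c // 1024, 1024)          # nhwc-1024c-2d
out_phys = flat.reshape(n, (h * w * c) // 1024, 1024)  # nc-1024-2d

# The index maps in plain Python: logical -> physical coordinates.
def nhwc_1024c_2d(n0, h0, w0, c0):
    return (n0, h0, w0, c0 // 1024, c0 % 1024)

def nc_1024_2d(n0, j):
    return (n0, j // 1024, j % 1024)

# Spot-check one element through both layouts.
n0, h0, w0, c0 = 1, 2, 3, 1500
j = (h0 * w + w0) * c + c0                       # column index after flatten
assert a[n0, h0, w0, c0] == in_phys[nhwc_1024c_2d(n0, h0, w0, c0)]
assert a[n0, h0, w0, c0] == out_phys[nc_1024_2d(n0, j)]

# Same bytes in the same order: no data movement, only loop retiling plus
# vectorization of the innermost 64-lane axis, as in the STIR schedule.
assert np.array_equal(in_phys.reshape(-1), out_phys.reshape(-1))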