From 283b149a9cf62618046f3d5f2ae9782cc1fe0dec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 16:16:14 +0900 Subject: [PATCH 1/9] [TIR] Transform NDArray by IndexMap --- include/tvm/tir/index_map.h | 5 +++ python/tvm/tir/function.py | 6 +++ src/tir/ir/index_map.cc | 50 +++++++++++++++++++++++++ tests/python/unittest/test_index_map.py | 34 ++++++++++++++++- 4 files changed, 94 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index e1b323462cda..3b1328575d09 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -136,6 +136,11 @@ class IndexMapNode : public Object { */ Array MapShape(const Array& shape, arith::Analyzer* analyzer = nullptr) const; + /* + TODO + */ + runtime::NDArray MapNDArray(runtime::NDArray constant) const; + /*! * \brief Convert to string representation in Python. * \return The stringified lambda expression in Python. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index dd684bc4f1ae..0c4772a378c4 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -515,6 +515,12 @@ def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: """ return _ffi_api.IndexMapMapShape(self, shape) + def map_ndarray(self, arr): + """ + TODO + """ + return _ffi_api.IndexMapMapNDArray(self, arr) + def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": """Return the inverse of the map diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 2c5349ab9941..f42fb48c6ff8 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -208,6 +208,53 @@ Array IndexMapNode::MapShape(const Array& shape, return output; } +runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray constant) const { + auto shape = constant.Shape(); + size_t extent = 1; + Array orig_shape; + for (int i = 0; i < shape.size(); ++i) { + extent *= shape[i]; + orig_shape.push_back(PrimExpr(int(shape[i]))); + } + auto dst_shape = MapShape(orig_shape); + + std::vector dst_shape_int; + for (int i = 0; i < dst_shape.size(); ++i) { + dst_shape_int.push_back(dst_shape[i].as()->value); + } + + auto elem_bytes = (constant->dtype.bits / 8) * constant->dtype.lanes; + std::vector bytes(extent * elem_bytes); + constant.CopyToBytes(bytes.data(), bytes.size()); + + std::vector bytes_rewritten(bytes.size()); + + for (size_t i = 0; i < extent; ++i) { + Array src_indices; + auto div_factor = extent; + auto index = i; + for (auto s : shape) { + div_factor /= s; + src_indices.push_back(PrimExpr(int(index / div_factor))); + index = index % div_factor; + } + auto dst_indices = MapIndices(src_indices); + size_t dst_linear_indices = 0; + auto mul_factor = extent; + for (int j = 0; j < dst_indices.size(); ++j) { + mul_factor /= dst_shape_int[j]; + dst_linear_indices += dst_indices[j].as()->value * mul_factor; + } + std::copy(bytes.begin() + i * elem_bytes, bytes.begin() + (i + 1) * elem_bytes, + bytes_rewritten.begin() + dst_linear_indices * elem_bytes); + } + + auto rewritten_constant = + runtime::NDArray::Empty(dst_shape_int, constant->dtype, constant->device); + rewritten_constant.CopyFromBytes(bytes_rewritten.data(), bytes.size()); + return rewritten_constant; +} + /*! * \brief Auxilarry function to comvert an index map to lambda expression in Python. * \param initial_indices The initial indices in the index map. @@ -289,6 +336,9 @@ TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Arra }); TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse); +TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray") + .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); + TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") .set_body_typed([](IndexMap forward, Array initial_ranges) { auto result = forward.NonSurjectiveInverse(initial_ranges); diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index a86880b0f4a8..4536ded964de 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import pytest import tvm @@ -202,5 +203,36 @@ def expected_inverse(i0, i1, i2, i3): assert expected_map.is_equivalent_to(inverse_map) +def test_map_ndarray(): + # index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + # arr = tvm.nd.array(np.arange(16).astype("int32")) + # print(index_map.map_ndarray(arr)) + + index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (floordiv(i3, 32), i0, floordiv(i2, 8), floordiv(floormod(i3, 32), 16), i1, floormod(i2, 8), floormod(i3, 16))) + # index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (i3, i0, i1, i2)) + # index_map = IndexMap.from_func(lambda i0, i1: (i1, i0)) + + kH = kW = 3 + I = 64 + O = 64 + inp = np.random.randn(kH, kW, I, O).astype("float32") + arr = tvm.nd.array(inp) + out = index_map.map_ndarray(arr).numpy() + + ref = np.zeros(out.shape).astype("float32") + + for i0 in range(kH): + for i1 in range(kW): + for i2 in range(I): + for i3 in range(O): + v = inp[i0, i1, i2, i3] + ref[i3 // 32, i0, i2 // 8, (i3 % 32) // 16, i1, i2 % 8, i3 % 16] = v + # ref[i3, i0, i1, i2] = v + + print(np.max(np.abs(ref - out))) + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_map_ndarray() From 2385ba661c6b33fcf6aea7011197548ea6b77b60 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 16:35:10 +0900 Subject: [PATCH 2/9] add more tests --- tests/python/unittest/test_index_map.py | 44 +++++++++++++++++++------ 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index 4536ded964de..2ee9a064cb0b 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -204,14 +204,40 @@ def expected_inverse(i0, i1, i2, i3): def test_map_ndarray(): - # index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + inp = np.arange(16).astype("int8") + + out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + + ref = np.zeros(out.shape).astype("int8") + + for i in range(16): + ref[i // 4, i % 4] = inp[i] + + np.testing.assert_equal(ref, out) + + index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (i3, i0, i1, i2)) + + inp = np.random.randn(10, 10, 10, 10).astype("float16") + + out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + + ref = np.transpose(inp, (3, 0, 1, 2)) - # arr = tvm.nd.array(np.arange(16).astype("int32")) - # print(index_map.map_ndarray(arr)) + np.testing.assert_equal(ref, out) - index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (floordiv(i3, 32), i0, floordiv(i2, 8), floordiv(floormod(i3, 32), 16), i1, floormod(i2, 8), floormod(i3, 16))) - # index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (i3, i0, i1, i2)) - # index_map = IndexMap.from_func(lambda i0, i1: (i1, i0)) + index_map = IndexMap.from_func( + lambda i0, i1, i2, i3: ( + floordiv(i3, 32), + i0, + floordiv(i2, 8), + floordiv(floormod(i3, 32), 16), + i1, + floormod(i2, 8), + floormod(i3, 16), + ) + ) kH = kW = 3 I = 64 @@ -228,11 +254,9 @@ def test_map_ndarray(): for i3 in range(O): v = inp[i0, i1, i2, i3] ref[i3 // 32, i0, i2 // 8, (i3 % 32) // 16, i1, i2 % 8, i3 % 16] = v - # ref[i3, i0, i1, i2] = v - print(np.max(np.abs(ref - out))) + np.testing.assert_equal(ref, out) if __name__ == "__main__": - # tvm.testing.main() - test_map_ndarray() + tvm.testing.main() From 133a320f293be2787232ebfb1204110730b20fbc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 16:50:48 +0900 Subject: [PATCH 3/9] clean --- src/tir/ir/index_map.cc | 48 +++++++++++++------------ tests/python/unittest/test_index_map.py | 3 +- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index f42fb48c6ff8..ecc9bf8a132d 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -208,12 +208,12 @@ Array IndexMapNode::MapShape(const Array& shape, return output; } -runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray constant) const { - auto shape = constant.Shape(); - size_t extent = 1; +runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { + auto shape = arr_src.Shape(); + size_t size_1d = 1; Array orig_shape; for (int i = 0; i < shape.size(); ++i) { - extent *= shape[i]; + size_1d *= shape[i]; orig_shape.push_back(PrimExpr(int(shape[i]))); } auto dst_shape = MapShape(orig_shape); @@ -223,36 +223,40 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray constant) const { dst_shape_int.push_back(dst_shape[i].as()->value); } - auto elem_bytes = (constant->dtype.bits / 8) * constant->dtype.lanes; - std::vector bytes(extent * elem_bytes); - constant.CopyToBytes(bytes.data(), bytes.size()); + auto elem_bytes = (arr_src->dtype.bits / 8) * arr_src->dtype.lanes; + std::vector bytes_src(size_1d * elem_bytes); + arr_src.CopyToBytes(bytes_src.data(), bytes_src.size()); - std::vector bytes_rewritten(bytes.size()); + std::vector bytes_dst(bytes_src.size()); - for (size_t i = 0; i < extent; ++i) { + for (size_t i = 0; i < size_1d; ++i) { + // Convert a linear coordinate to an N-d coordinate tuple + // z * height * width + y * width + x -> (z, y, x) Array src_indices; - auto div_factor = extent; - auto index = i; + auto div_factor = size_1d; + auto src_linear_index = i; for (auto s : shape) { div_factor /= s; - src_indices.push_back(PrimExpr(int(index / div_factor))); - index = index % div_factor; + src_indices.push_back(PrimExpr(int(src_linear_index / div_factor))); + src_linear_index %= div_factor; } auto dst_indices = MapIndices(src_indices); - size_t dst_linear_indices = 0; - auto mul_factor = extent; + + // Convert an N-d coordinate to a linear coordinate + // (z, y, x) -> z * height * width + y * width + x + size_t dst_linear_index = 0; + auto mul_factor = size_1d; for (int j = 0; j < dst_indices.size(); ++j) { mul_factor /= dst_shape_int[j]; - dst_linear_indices += dst_indices[j].as()->value * mul_factor; + dst_linear_index += dst_indices[j].as()->value * mul_factor; } - std::copy(bytes.begin() + i * elem_bytes, bytes.begin() + (i + 1) * elem_bytes, - bytes_rewritten.begin() + dst_linear_indices * elem_bytes); + std::copy(bytes_src.begin() + i * elem_bytes, bytes_src.begin() + (i + 1) * elem_bytes, + bytes_dst.begin() + dst_linear_index * elem_bytes); } - auto rewritten_constant = - runtime::NDArray::Empty(dst_shape_int, constant->dtype, constant->device); - rewritten_constant.CopyFromBytes(bytes_rewritten.data(), bytes.size()); - return rewritten_constant; + auto arr_dst = runtime::NDArray::Empty(dst_shape_int, arr_src->dtype, arr_src->device); + arr_dst.CopyFromBytes(bytes_dst.data(), bytes_dst.size()); + return arr_dst; } /*! diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index 2ee9a064cb0b..b5bd0fdac65f 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -243,8 +243,7 @@ def test_map_ndarray(): I = 64 O = 64 inp = np.random.randn(kH, kW, I, O).astype("float32") - arr = tvm.nd.array(inp) - out = index_map.map_ndarray(arr).numpy() + out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() ref = np.zeros(out.shape).astype("float32") From 7cb0494923bc7a87d6e54e1bf97709648794c0ad Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 16:56:23 +0900 Subject: [PATCH 4/9] add doc --- include/tvm/tir/index_map.h | 9 ++++++--- python/tvm/tir/function.py | 18 ++++++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 3b1328575d09..0c1fabbf266f 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -136,10 +136,13 @@ class IndexMapNode : public Object { */ Array MapShape(const Array& shape, arith::Analyzer* analyzer = nullptr) const; - /* - TODO + /* \brief Map an NDArray according to this index map + * + * \param arr_src The NDArray whose layout is transformed by this index map. + * + * \returns The transformed NDArray. */ - runtime::NDArray MapNDArray(runtime::NDArray constant) const; + runtime::NDArray MapNDArray(runtime::NDArray arr_src) const; /*! * \brief Convert to string representation in Python. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 0c4772a378c4..4628ae36265f 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -28,6 +28,7 @@ from .buffer import Buffer from .expr import Var, PrimExpr from . import _ffi_api +from ..runtime.ndarray import NDArray @tvm._ffi.register_object("tir.PrimFunc") @@ -515,11 +516,20 @@ def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: """ return _ffi_api.IndexMapMapShape(self, shape) - def map_ndarray(self, arr): - """ - TODO + def map_ndarray(self, arr_src: NDArray) -> NDArray: + """Apply thie index map to transform the layout of the input NDArray + + Parameters + ---------- + arr_src : runtime.NDArray + The NDArray to be transformed + + Returns + ------- + arr_dst : runtime.NDArray + The transformed NDArray """ - return _ffi_api.IndexMapMapNDArray(self, arr) + return _ffi_api.IndexMapMapNDArray(self, arr_src) def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": """Return the inverse of the map From 0bc49f4a1440f27fd77593a41c1b7c85467e6505 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 16:56:44 +0900 Subject: [PATCH 5/9] clang format --- include/tvm/tir/index_map.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 0c1fabbf266f..35a74d294fbb 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -141,7 +141,7 @@ class IndexMapNode : public Object { * \param arr_src The NDArray whose layout is transformed by this index map. * * \returns The transformed NDArray. - */ + */ runtime::NDArray MapNDArray(runtime::NDArray arr_src) const; /*! From c7794484bb18bb5b1ff7d71be05a6138fd44cfbe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 17:07:56 +0900 Subject: [PATCH 6/9] add rank check --- src/tir/ir/index_map.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index ecc9bf8a132d..7af695f694cd 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -210,6 +210,9 @@ Array IndexMapNode::MapShape(const Array& shape, runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { auto shape = arr_src.Shape(); + ICHECK(shape.size() == initial_indices.size()) + << "The rank of the input array should be " << initial_indices.size() << " but got " + << shape.size(); size_t size_1d = 1; Array orig_shape; for (int i = 0; i < shape.size(); ++i) { From 4e36b82c81dbc69c73ce59cc6822bbda253e9945 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 17:16:08 +0900 Subject: [PATCH 7/9] cpplint --- src/tir/ir/index_map.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 7af695f694cd..0569e41c580c 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -217,7 +217,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { Array orig_shape; for (int i = 0; i < shape.size(); ++i) { size_1d *= shape[i]; - orig_shape.push_back(PrimExpr(int(shape[i]))); + orig_shape.push_back(PrimExpr(static_cast((shape[i])))); } auto dst_shape = MapShape(orig_shape); @@ -240,7 +240,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { auto src_linear_index = i; for (auto s : shape) { div_factor /= s; - src_indices.push_back(PrimExpr(int(src_linear_index / div_factor))); + src_indices.push_back(PrimExpr(static_cast((src_linear_index / div_factor)))); src_linear_index %= div_factor; } auto dst_indices = MapIndices(src_indices); From e05ef0fb2f1a3fcb493597aed74bf119f1a79ffc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 19:23:18 +0900 Subject: [PATCH 8/9] fix compile warning --- src/tir/ir/index_map.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 0569e41c580c..6d982b510a26 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -215,14 +215,14 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { << shape.size(); size_t size_1d = 1; Array orig_shape; - for (int i = 0; i < shape.size(); ++i) { + for (size_t i = 0; i < shape.size(); ++i) { size_1d *= shape[i]; orig_shape.push_back(PrimExpr(static_cast((shape[i])))); } auto dst_shape = MapShape(orig_shape); std::vector dst_shape_int; - for (int i = 0; i < dst_shape.size(); ++i) { + for (size_t i = 0; i < dst_shape.size(); ++i) { dst_shape_int.push_back(dst_shape[i].as()->value); } @@ -249,7 +249,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { // (z, y, x) -> z * height * width + y * width + x size_t dst_linear_index = 0; auto mul_factor = size_1d; - for (int j = 0; j < dst_indices.size(); ++j) { + for (size_t j = 0; j < dst_indices.size(); ++j) { mul_factor /= dst_shape_int[j]; dst_linear_index += dst_indices[j].as()->value * mul_factor; } From 745acc1f8b967fafe42bfbba73b5014b4e2b959d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Sep 2022 20:46:26 +0900 Subject: [PATCH 9/9] add test for inverse --- tests/python/unittest/test_index_map.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index b5bd0fdac65f..804d04d0b052 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -243,7 +243,8 @@ def test_map_ndarray(): I = 64 O = 64 inp = np.random.randn(kH, kW, I, O).astype("float32") - out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + arr = tvm.nd.array(inp) + out = index_map.map_ndarray(arr).numpy() ref = np.zeros(out.shape).astype("float32") @@ -256,6 +257,9 @@ def test_map_ndarray(): np.testing.assert_equal(ref, out) + inverse_map = index_map.inverse(inp.shape) + np.testing.assert_equal(inverse_map.map_ndarray(index_map.map_ndarray(arr)).numpy(), inp) + if __name__ == "__main__": tvm.testing.main()