diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index e1b323462cda..35a74d294fbb 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -136,6 +136,14 @@ class IndexMapNode : public Object { */ Array MapShape(const Array& shape, arith::Analyzer* analyzer = nullptr) const; + /*! \brief Map an NDArray according to this index map + * + * \param arr_src The NDArray whose layout is transformed by this index map. + * + * \return The transformed NDArray. + */ + runtime::NDArray MapNDArray(runtime::NDArray arr_src) const; + /*! * \brief Convert to string representation in Python. * \return The stringified lambda expression in Python. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index dd684bc4f1ae..4628ae36265f 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -28,6 +28,7 @@ from .buffer import Buffer from .expr import Var, PrimExpr from . import _ffi_api +from ..runtime.ndarray import NDArray @tvm._ffi.register_object("tir.PrimFunc") @@ -515,6 +516,21 @@ def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: """ return _ffi_api.IndexMapMapShape(self, shape) + def map_ndarray(self, arr_src: NDArray) -> NDArray: + """Apply this index map to transform the layout of the input NDArray + + Parameters + ---------- + arr_src : runtime.NDArray + The NDArray to be transformed + + Returns + ------- + arr_dst : runtime.NDArray + The transformed NDArray + """ + return _ffi_api.IndexMapMapNDArray(self, arr_src) + def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": """Return the inverse of the map diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 2c5349ab9941..6d982b510a26 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -208,6 +208,60 @@ Array IndexMapNode::MapShape(const Array& shape, return output; } +/* Applies this index map to every element of arr_src, staging the data through byte vectors. */ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { + auto shape = arr_src.Shape(); + 
ICHECK(shape.size() == initial_indices.size()) + << "The rank of the input array should be " << initial_indices.size() << " but got " + << shape.size(); + size_t size_1d = 1; + Array orig_shape; + for (size_t i = 0; i < shape.size(); ++i) { + size_1d *= shape[i]; + orig_shape.push_back(PrimExpr(static_cast((shape[i])))); + } + auto dst_shape = MapShape(orig_shape); + + std::vector dst_shape_int; + for (size_t i = 0; i < dst_shape.size(); ++i) { + dst_shape_int.push_back(dst_shape[i].as()->value); /* NOTE(review): the as() result is dereferenced unchecked - presumably every mapped extent is a constant integer; confirm */ + } + + auto elem_bytes = (arr_src->dtype.bits / 8) * arr_src->dtype.lanes; /* bytes per element; NOTE(review): bits / 8 looks like it yields 0 for dtypes narrower than 8 bits - confirm */ + std::vector bytes_src(size_1d * elem_bytes); + arr_src.CopyToBytes(bytes_src.data(), bytes_src.size()); + + std::vector bytes_dst(bytes_src.size()); /* NOTE(review): sized like the source, so this assumes the mapped shape has the same element count - confirm for maps that pad */ + + for (size_t i = 0; i < size_1d; ++i) { + // Convert a linear coordinate to an N-d coordinate tuple + // z * height * width + y * width + x -> (z, y, x) + Array src_indices; + auto div_factor = size_1d; + auto src_linear_index = i; + for (auto s : shape) { + div_factor /= s; + src_indices.push_back(PrimExpr(static_cast((src_linear_index / div_factor)))); + src_linear_index %= div_factor; + } + auto dst_indices = MapIndices(src_indices); + + // Convert an N-d coordinate to a linear coordinate + // (z, y, x) -> z * height * width + y * width + x + size_t dst_linear_index = 0; + auto mul_factor = size_1d; + for (size_t j = 0; j < dst_indices.size(); ++j) { + mul_factor /= dst_shape_int[j]; + dst_linear_index += dst_indices[j].as()->value * mul_factor; + } + std::copy(bytes_src.begin() + i * elem_bytes, bytes_src.begin() + (i + 1) * elem_bytes, + bytes_dst.begin() + dst_linear_index * elem_bytes); + } + + auto arr_dst = runtime::NDArray::Empty(dst_shape_int, arr_src->dtype, arr_src->device); + arr_dst.CopyFromBytes(bytes_dst.data(), bytes_dst.size()); + return arr_dst; +} + /*! * \brief Auxilarry function to comvert an index map to lambda expression in Python. * \param initial_indices The initial indices in the index map. 
@@ -289,6 +343,9 @@ TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Arra }); TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse); +/* Registers a global function named "tir.IndexMapMapNDArray" that forwards to IndexMapNode::MapNDArray. */ TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray") + .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); + TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") .set_body_typed([](IndexMap forward, Array initial_ranges) { auto result = forward.NonSurjectiveInverse(initial_ranges); diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index a86880b0f4a8..804d04d0b052 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import pytest import tvm @@ -202,5 +203,63 @@ def expected_inverse(i0, i1, i2, i3): assert expected_map.is_equivalent_to(inverse_map) +def test_map_ndarray(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + inp = np.arange(16).astype("int8") + + out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + + ref = np.zeros(out.shape).astype("int8") + + for i in range(16): + ref[i // 4, i % 4] = inp[i] + + np.testing.assert_equal(ref, out) + + index_map = IndexMap.from_func(lambda i0, i1, i2, i3: (i3, i0, i1, i2)) + + inp = np.random.randn(10, 10, 10, 10).astype("float16") + + out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + + ref = np.transpose(inp, (3, 0, 1, 2)) + + np.testing.assert_equal(ref, out) + + index_map = IndexMap.from_func( + lambda i0, i1, i2, i3: ( + floordiv(i3, 32), + i0, + floordiv(i2, 8), + floordiv(floormod(i3, 32), 16), + i1, + floormod(i2, 8), + floormod(i3, 16), + ) + ) + + kH = kW = 3 + I = 64 + O = 64 + inp = np.random.randn(kH, kW, I, O).astype("float32") + arr = tvm.nd.array(inp) + out = index_map.map_ndarray(arr).numpy() + + ref = 
np.zeros(out.shape).astype("float32") + + for i0 in range(kH): + for i1 in range(kW): + for i2 in range(I): + for i3 in range(O): + v = inp[i0, i1, i2, i3] + ref[i3 // 32, i0, i2 // 8, (i3 % 32) // 16, i1, i2 % 8, i3 % 16] = v + + np.testing.assert_equal(ref, out) + + inverse_map = index_map.inverse(inp.shape) + np.testing.assert_equal(inverse_map.map_ndarray(index_map.map_ndarray(arr)).numpy(), inp) + + if __name__ == "__main__": tvm.testing.main()