From 0d25111ac12559e8404c40b46de90a25b4c3b7c5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Jun 2023 10:46:18 +0900 Subject: [PATCH 1/2] Allow creating a view from a stride array --- src/runtime/ndarray.cc | 42 +++++++++++++++++++- tests/python/unittest/test_runtime_dlpack.py | 13 ++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c7bfefa9a8e7..3f7577b76f69 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -181,7 +181,47 @@ struct NDArray::Internal { NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ICHECK(data_ != nullptr); - ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; + + const DLTensor& orig = get_mutable()->dl_tensor; + bool is_compact = [&orig]() { + if (orig.strides == nullptr) { + return true; + } + + int compact_stride = 1; + for (int i = orig.ndim; i > 0; i--) { + int shape_i = orig.shape[i - 1]; + int stride_i = orig.strides[i - 1]; + if (compact_stride != stride_i && shape_i != 1) { + return false; + } + compact_stride *= shape_i; + } + return true; + }(); + + ICHECK(is_compact) << "Can only create view for compact tensor, but found strides " << + [&]() { + std::stringstream ss; + ss << "["; + for (int i = 0; i < orig.ndim; i++) { + if (i) ss << ", "; + ss << orig.strides[i]; + } + ss << "]"; + return ss.str(); + }() << ", for shape " + << [&]() { + std::stringstream ss; + ss << "["; + for (int i = 0; i < orig.ndim; i++) { + if (i) ss << ", "; + ss << orig.shape[i]; + } + ss << "]"; + return ss.str(); + }(); + NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device); ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset; size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); diff --git a/tests/python/unittest/test_runtime_dlpack.py b/tests/python/unittest/test_runtime_dlpack.py index 3f13e2e5fed5..cf12c89cdd51 100644 --- a/tests/python/unittest/test_runtime_dlpack.py +++ b/tests/python/unittest/test_runtime_dlpack.py @@ -48,5 +48,18 @@ def test_from_dlpack_shape_one(): tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) +@tvm.testing.requires_package("torch") +def test_from_dlpack_strided(): + import torch + from torch.utils.dlpack import to_dlpack + + rows = 1 + inp = torch.randn(rows, 16) + a = tvm.runtime.ndarray.from_dlpack(to_dlpack(inp)) + view = a._create_view((2, 8)) + + np.testing.assert_equal(inp.numpy().reshape(2, 8), view.numpy()) + + if __name__ == "__main__": tvm.testing.main() From e40c3443b1c24cf2c1b3ce10a63c03ce7a940c7f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Jun 2023 04:28:13 +0900 Subject: [PATCH 2/2] use IsContiguous --- src/runtime/ndarray.cc | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 3f7577b76f69..b7153ab50f1f 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -183,25 +183,8 @@ NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ICHECK(data_ != nullptr); const DLTensor& orig = get_mutable()->dl_tensor; - bool is_compact = [&orig]() { - if (orig.strides == nullptr) { - return true; - } - - int compact_stride = 1; - for (int i = orig.ndim; i > 0; i--) { - int shape_i = orig.shape[i - 1]; - int stride_i = orig.strides[i - 1]; - if (compact_stride != stride_i && shape_i != 1) { - return false; - } - compact_stride *= shape_i; - } - return true; - }(); - - ICHECK(is_compact) << "Can only create view for compact tensor, but found strides " << - [&]() { + ICHECK(IsContiguous()) << "Can only create view for compact tensor, but found strides " << + [&orig]() { std::stringstream ss; ss << "["; for (int i = 0; i < orig.ndim; i++) { @@ -211,16 +194,16 @@ NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ss << "]"; return ss.str(); }() << ", for shape " - << [&]() { - std::stringstream ss; - ss << "["; - for (int i = 0; i < orig.ndim; i++) { - if (i) ss << ", "; - ss << orig.shape[i]; - } - ss << "]"; - return ss.str(); - }(); + << [&]() { + std::stringstream ss; + ss << "["; + for (int i = 0; i < orig.ndim; i++) { + if (i) ss << ", "; + ss << orig.shape[i]; + } + ss << "]"; + return ss.str(); + }(); NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device); ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset;