diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c7bfefa9a8e7..b7153ab50f1f 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -181,7 +181,30 @@ struct NDArray::Internal { NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ICHECK(data_ != nullptr); - ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; + + const DLTensor& orig = get_mutable()->dl_tensor; + ICHECK(IsContiguous()) << "Can only create view for compact tensor, but found strides " << + [&orig]() { + std::stringstream ss; + ss << "["; + for (int i = 0; i < orig.ndim; i++) { + if (i) ss << ", "; + ss << orig.strides[i]; + } + ss << "]"; + return ss.str(); + }() << ", for shape " + << [&]() { + std::stringstream ss; + ss << "["; + for (int i = 0; i < orig.ndim; i++) { + if (i) ss << ", "; + ss << orig.shape[i]; + } + ss << "]"; + return ss.str(); + }(); + NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device); ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset; size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); diff --git a/tests/python/unittest/test_runtime_dlpack.py b/tests/python/unittest/test_runtime_dlpack.py index 3f13e2e5fed5..cf12c89cdd51 100644 --- a/tests/python/unittest/test_runtime_dlpack.py +++ b/tests/python/unittest/test_runtime_dlpack.py @@ -48,5 +48,18 @@ def test_from_dlpack_shape_one(): tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) +@tvm.testing.requires_package("torch") +def test_from_dlpack_strided(): + import torch + from torch.utils.dlpack import to_dlpack + + rows = 1 + inp = torch.randn(rows, 16) + a = tvm.runtime.ndarray.from_dlpack(to_dlpack(inp)) + view = a._create_view((2, 8)) + + np.testing.assert_equal(inp.numpy().reshape(2, 8), view.numpy()) + + if __name__ == "__main__": tvm.testing.main()