diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 99594ee50f93..ac12472a903e 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -208,6 +208,10 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { // fill up content. data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; + // update shape_ + data->shape_.resize(data->dl_tensor.ndim); + data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); + data->dl_tensor.shape = data->shape_.data(); return NDArray(GetObjectPtr(data)); }