From a3aaaab7b6a1511316a851ebfbbcef9818e84c74 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 11 Apr 2023 08:00:52 +0100 Subject: [PATCH 1/2] fixes #6326 Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 13 ++++++++++--- tests/test_meta_tensor.py | 1 + 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 48b9320f99..59de3d9369 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -510,9 +510,16 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False): self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad) ) - def clone(self): - """returns a copy of the MetaTensor instance.""" - new_inst = MetaTensor(self.as_tensor().clone()) + def clone(self, *_args, **kwargs): + """ + returns a copy of the MetaTensor instance. + + Args: + kwargs: additional keyword arguments to `torch.clone`. + + See also: https://pytorch.org/docs/stable/generated/torch.clone.html + """ + new_inst = MetaTensor(self.as_tensor().clone(**kwargs)) new_inst.__dict__ = deepcopy(self.__dict__) return new_inst diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index a6607a3ccd..e547675a0e 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -177,6 +177,7 @@ def test_copy(self, device, dtype): a = deepcopy(m) self.check(a, m, ids=False) # clone + a = m.clone(memory_format=torch.preserve_format) a = m.clone() self.check(a, m, ids=False) a = MetaTensor([[]], device=device, dtype=dtype) From 83eae90b64f00530c1bdf5159e35a2d3a7040711 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 11 Apr 2023 08:30:12 +0100 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 59de3d9369..e3aacb95ee 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -510,9 +510,9 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False): self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad) ) - def clone(self, *_args, **kwargs): + def clone(self, **kwargs): """ - returns a copy of the MetaTensor instance. + Returns a copy of the MetaTensor instance. Args: kwargs: additional keyword arguments to `torch.clone`.