From b7f83c49cedd9676aeafba19bd6a2c2cfcaf97eb Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 12 Jul 2022 20:44:30 +0800 Subject: [PATCH] fix colo-tensor --- colossalai/tensor/colo_tensor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 8cfb316b2f9d..7297a688cbff 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -40,7 +40,7 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup: pg = _scan_for_pg_from_args(elem, {}) if pg is not None: return pg - for k, v in kwargs: + for k, v in kwargs.items(): if isinstance(v, ColoTensor): pg = v.get_process_group() return pg @@ -52,7 +52,7 @@ class ColoTensor(torch.Tensor): Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). - + The signature of the function has to be consistent with the __new__ except for the 1st arg. The class should be initialized with a torch tensor in the following ways. 1. directly init. @@ -202,7 +202,6 @@ def to_replicate(self) -> 'ColoTensor': """ return self.redistribute(ReplicaSpec()) - @staticmethod def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor': tensor = tensor.as_subclass(ColoTensor)