diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py index 51b7bfb918ec..f1dc241a8968 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/tensor/dist_spec_mgr.py @@ -88,11 +88,13 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: a replicated tensor. """ assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" - if version.parse(torch.__version__) < version.parse("1.11.0"): + is_cpu_tensor = False + if tensor.device.type == 'cpu': # pytorch lower than 1.11 dose not support gather a cpu tensor. # Therefore, we transfer tensor to GPU before gather. saved_dev = tensor.device tensor.data = tensor.data.cuda() + is_cpu_tensor = True buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] assert tensor.device.type == 'cuda' @@ -106,7 +108,7 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> buffer = new_buffer assert len(buffer) == 1 - if version.parse(torch.__version__) < version.parse("1.11.0"): + if is_cpu_tensor: buffer[0].data = buffer[0].data.to(saved_dev) return buffer[0]