From 94d769e2fd353da9ca9d0217cfa61a8077df025f Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Tue, 12 Jul 2022 17:32:57 +0800 Subject: [PATCH] [NFC] polish colossalai/communication/collective.py --- colossalai/communication/collective.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py index 50fd7dcc2d55..2c9e9927c7d9 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/communication/collective.py @@ -10,10 +10,7 @@ from colossalai.core import global_context as gpc -def all_gather(tensor: Tensor, - dim: int, - parallel_mode: ParallelMode, - async_op: bool = False) -> Tensor: +def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: r"""Gathers all tensors from the parallel group and concatenates them in a specific dimension. @@ -163,11 +160,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b return out -def reduce(tensor: Tensor, - dst: int, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False): +def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): r"""Reduce tensors across whole parallel group. Only the process with rank ``dst`` is going to receive the final result.