From 2caa7ed301a529392c5443662f892b22e3a07e37 Mon Sep 17 00:00:00 2001 From: Yuanchen Xu Date: Tue, 28 Mar 2023 09:41:00 +0800 Subject: [PATCH] [NFC] polish code style --- colossalai/nn/_ops/_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py index 56bb5f465184..24877bbb552f 100644 --- a/colossalai/nn/_ops/_utils.py +++ b/colossalai/nn/_ops/_utils.py @@ -1,12 +1,11 @@ -import torch -from typing import Union, Optional, List -from colossalai.tensor import ColoTensor +from typing import List, Optional, Union + import torch import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn.layer.utils import divide -from colossalai.tensor import ProcessGroup, ColoTensorSpec +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup GeneralTensor = Union[ColoTensor, torch.Tensor] Number = Union[int, float] @@ -135,7 +134,7 @@ def backward(ctx, grad_output): class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. - + Args: input_: input matrix. process_group: parallel mode.