From ac39fddd2c84b038e03697d33700d53e8a3a8679 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 12 Sep 2023 15:38:07 +0800
Subject: [PATCH 1/2] [legacy] remove outdated colo tensor

---
 colossalai/context/moe_context.py | 2 +-
 colossalai/fx/passes/shard_1d_pass.py | 16 +-
 colossalai/legacy/nn/__init__.py | 1 -
 colossalai/legacy/nn/_ops/__init__.py | 10 +-
 colossalai/legacy/nn/_ops/_utils.py | 3 +-
 colossalai/legacy/nn/_ops/addmm.py | 90 ------
 colossalai/legacy/nn/_ops/batch_norm.py | 33 ---
 colossalai/legacy/nn/_ops/element_wise.py | 250 ------------------
 colossalai/legacy/nn/_ops/embedding.py | 142 ----------
 colossalai/legacy/nn/_ops/embedding_bag.py | 127 ---------
 colossalai/legacy/nn/_ops/layernorm.py | 28 --
 colossalai/legacy/nn/_ops/linear.py | 171 ------------
 colossalai/legacy/nn/_ops/loss.py | 51 ----
 colossalai/legacy/nn/_ops/view.py | 96 -------
 .../legacy/nn/parallel/data_parallel.py | 2 +-
 .../parallel_cached_embedding.py | 3 +-
 .../parallel_cached_embedding_tablewise.py | 2 +-
 ..._cached_embedding_tablewise_split_cache.py | 2 +-
 .../legacy/nn/parallel/layers/colo_module.py | 4 +-
 .../legacy/nn/parallel/layers/embedding.py | 2 +-
 .../legacy/nn/parallel/layers/linear.py | 2 +-
 .../legacy/nn/parallel/layers/module_utils.py | 3 +-
 .../legacy/pipeline/pipeline_process_group.py | 2 +-
 colossalai/legacy/tensor/__init__.py | 17 ++
 .../{ => legacy}/tensor/compute_spec.py | 0
 colossalai/{ => legacy}/tensor/const.py | 0
 .../{ => legacy}/tensor/dist_spec_mgr.py | 6 +-
 colossalai/{ => legacy}/tensor/distspec.py | 0
 colossalai/{ => legacy}/tensor/op_wrapper.py | 5 +-
 .../{ => legacy}/tensor/process_group.py | 0
 colossalai/{ => legacy}/tensor/tensor_spec.py | 4 +-
 colossalai/tensor/__init__.py | 11 +-
 colossalai/utils/common.py | 3 +-
 colossalai/zero/gemini/colo_init_context.py | 3 +-
 .../test_compatibility_with_gemini.py | 5 +-
 .../test_layers/test_cache_embedding.py | 3 +-
 .../test_tensor/common_utils/__init__.py | 2 +-
 .../test_tensor/common_utils/_utils.py | 2 +-
 .../test_tensor/core/test_dist_spec_mgr.py | 2 +-
 .../test_tensor/test_parameter.py | 2 +-
 40 files changed, 61 insertions(+), 1046 deletions(-)
 delete mode 100644 colossalai/legacy/nn/_ops/addmm.py
 delete mode 100644 colossalai/legacy/nn/_ops/batch_norm.py
 delete mode 100644 colossalai/legacy/nn/_ops/element_wise.py
 delete mode 100644 colossalai/legacy/nn/_ops/embedding.py
 delete mode 100644 colossalai/legacy/nn/_ops/embedding_bag.py
 delete mode 100644 colossalai/legacy/nn/_ops/layernorm.py
 delete mode 100644 colossalai/legacy/nn/_ops/linear.py
 delete mode 100644 colossalai/legacy/nn/_ops/loss.py
 delete mode 100644 colossalai/legacy/nn/_ops/view.py
 create mode 100644 colossalai/legacy/tensor/__init__.py
 rename colossalai/{ => legacy}/tensor/compute_spec.py (100%)
 rename colossalai/{ => legacy}/tensor/const.py (100%)
 rename colossalai/{ => legacy}/tensor/dist_spec_mgr.py (97%)
 rename colossalai/{ => legacy}/tensor/distspec.py (100%)
 rename colossalai/{ => legacy}/tensor/op_wrapper.py (97%)
 rename colossalai/{ => legacy}/tensor/process_group.py (100%)
 rename colossalai/{ => legacy}/tensor/tensor_spec.py (79%)
 rename tests/{ => test_legacy}/test_tensor/common_utils/__init__.py (95%)
 rename tests/{ => test_legacy}/test_tensor/common_utils/_utils.py (97%)
 rename tests/{ => test_legacy}/test_tensor/core/test_dist_spec_mgr.py (96%)
 rename tests/{ => test_legacy}/test_tensor/test_parameter.py (92%)
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b41f4072a405..547a0c6646ee 100644 ---
a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -5,7 +5,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup def _check_sanity(): diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index d2bad06bb45a..ccbab0c38a29 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -1,9 +1,11 @@ +import operator + import torch import torch.nn as nn -import operator -from colossalai.tensor import ProcessGroup -from colossalai.tensor.distspec import ShardSpec -from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec + +from colossalai.legacy.tensor import ProcessGroup +from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec +from colossalai.legacy.tensor.distspec import ShardSpec ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ @@ -13,7 +15,7 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: - """weight_split + """weight_split split a nn.Parameter Args: @@ -60,9 +62,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule): def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): """ - This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. """ - #TODO: Needs to handle special cases, like x = linear(x) + linear(x) + # TODO: Needs to handle special cases, like x = linear(x) + linear(x) graph = graph_module.graph world_size = process_group.world_size() diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py index 500162901905..d30ebf8d5406 100644 --- a/colossalai/legacy/nn/__init__.py +++ b/colossalai/legacy/nn/__init__.py @@ -1,4 +1,3 @@ -from ._ops import * from .layer import * from .loss import * from .metric import * diff --git a/colossalai/legacy/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py index 4991ad9a2217..9a35d02ce5ed 100644 --- a/colossalai/legacy/nn/_ops/__init__.py +++ b/colossalai/legacy/nn/_ops/__init__.py @@ -1,9 +1 @@ -from .addmm import colo_addmm -from .batch_norm import colo_batch_norm -from .element_wise import * -from .embedding import colo_embedding -from .embedding_bag import colo_embedding_bag -from .layernorm import colo_layernorm -from .linear import colo_linear -from .loss import colo_cross_entropy -from .view import colo_view +from ._utils import * diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py index 131c2154771b..dd4fe76fd54a 100644 --- a/colossalai/legacy/nn/_ops/_utils.py +++ b/colossalai/legacy/nn/_ops/_utils.py @@ -5,7 +5,8 @@ from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.nn.layer.utils import divide -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup +from colossalai.tensor import ColoTensor GeneralTensor = Union[ColoTensor, torch.Tensor] Number = Union[int, float] diff --git a/colossalai/legacy/nn/_ops/addmm.py b/colossalai/legacy/nn/_ops/addmm.py deleted file mode 100644 index 660b48a71d57..000000000000 --- a/colossalai/legacy/nn/_ops/addmm.py +++ 
/dev/null @@ -1,90 +0,0 @@ -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input - - -def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:S[1] x mat2:S[0] = Output:P - # beta * input + alpha * All-Reduce(Output) = res - - mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group()) - - # Output:P - partial_output = torch.mm(mat1, mat2) - # Reduce(Output) - output = reduce_input(partial_output, mat2.get_process_group()) - # input - assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' - output = beta * input_tensor + alpha * output - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group())) - return output - - -def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] - compute_spec = mat2.compute_spec - mat1 = mat1.redistribute(ReplicaSpec()) - mat1 = reduce_grad(mat1, mat1.get_process_group()) - - output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) - output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol} - return funcs[mode](input_tensor, mat1, mat2, beta, alpha) - - -@colo_op_impl(torch.addmm) -def colo_addmm(input_tensor: GeneralTensor, - mat1: ColoTensor, - mat2: ColoTensor, - beta: Number = 1, - alpha: Number = 1, - **kargs) -> ColoTensor: - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - # At least one of the tensor should be ColoTensor - assert isinstance(mat2, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group()) - mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group()) - - # Add communication logic before and after linear call. 
- ret_tensor = None - if not mat2.has_compute_spec(): # No Model Parallel Applied - assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' - assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' - ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, - mat1, - mat2, - beta=beta, - alpha=alpha, - **kargs), - spec=ColoTensorSpec(mat2.get_process_group())) - elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if mat2.is_shard_1drow() and input_tensor.is_replicate(): - mode = 'row' - elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()): - mode = 'col' - else: - raise NotImplementedError - ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha) - else: - raise NotImplementedError - - return ret_tensor diff --git a/colossalai/legacy/nn/_ops/batch_norm.py b/colossalai/legacy/nn/_ops/batch_norm.py deleted file mode 100644 index 54ecc88f420a..000000000000 --- a/colossalai/legacy/nn/_ops/batch_norm.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.batch_norm) -def colo_batch_norm( - input: GeneralTensor, - running_mean: Optional[GeneralTensor], - running_var: Optional[GeneralTensor], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - training: bool = False, - momentum: float = 0.1, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - running_mean = running_mean.detach() - running_var = running_var.detach() - - input = convert_to_colo_tensor(input, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input = input.redistribute(ReplicaSpec()) - bias = bias.redistribute(ReplicaSpec()) - - output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) - output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) - return output diff --git a/colossalai/legacy/nn/_ops/element_wise.py b/colossalai/legacy/nn/_ops/element_wise.py deleted file mode 100644 index 2de51e24a6dd..000000000000 --- a/colossalai/legacy/nn/_ops/element_wise.py +++ /dev/null @@ -1,250 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def register_elementwise_op(op): - - @colo_op_impl(op) - def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs): - """ - Handles ``__torch_function__`` dispatch for the elementwise op such - as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. - This method computes on either a normal tensor or a sharded tensor. 
- """ - if 'inplace' in kwargs: - # TODO(jiaruifang) inplace will cause bugs - input_tensor = input_tensor.clone() - return op(input_tensor, *args, **kwargs) - else: - output = op(input_tensor, *args, **kwargs) - # return output - if isinstance(input_tensor, ColoTensor): - if isinstance(output, str): - return output - if not isinstance(output, torch.Tensor): - raise NotImplementedError - return ColoTensor.from_torch_tensor(output, - spec=ColoTensorSpec(input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - - -# @colo_op_impl(torch.relu_) -# def elementwise_op(input_tensor): -# torch.relu_(input_tensor.data) -# return input_tensor - -# @colo_op_impl(Tensor.add_) -# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): -# input_tensor = input_tensor.data.add_(*args, **kwargs) -# return input_tensor - -# Tensor op -register_elementwise_op(Tensor.abs) -register_elementwise_op(Tensor.absolute) -register_elementwise_op(Tensor.acos) -register_elementwise_op(Tensor.arccos) -register_elementwise_op(Tensor.angle) -register_elementwise_op(Tensor.asin) -register_elementwise_op(Tensor.arcsin) -register_elementwise_op(Tensor.atan) -register_elementwise_op(Tensor.arctan) -register_elementwise_op(Tensor.all) -register_elementwise_op(Tensor.any) -register_elementwise_op(Tensor.bernoulli) -register_elementwise_op(Tensor.bfloat16) -register_elementwise_op(Tensor.bitwise_not) -register_elementwise_op(Tensor.bool) -register_elementwise_op(Tensor.byte) -register_elementwise_op(Tensor.ceil) -register_elementwise_op(Tensor.char) -register_elementwise_op(Tensor.clamp) -register_elementwise_op(Tensor.clamp_max) -register_elementwise_op(Tensor.clamp_min) -register_elementwise_op(Tensor.clip) -register_elementwise_op(Tensor.clone) -register_elementwise_op(Tensor.contiguous) -register_elementwise_op(Tensor.copysign) -register_elementwise_op(Tensor.cos) -register_elementwise_op(Tensor.cosh) -register_elementwise_op(Tensor.acosh) -register_elementwise_op(Tensor.arccosh) -register_elementwise_op(Tensor.cpu) -register_elementwise_op(Tensor.cuda) -register_elementwise_op(Tensor.deg2rad) -register_elementwise_op(Tensor.detach) -register_elementwise_op(Tensor.digamma) -register_elementwise_op(Tensor.double) -register_elementwise_op(Tensor.erf) -register_elementwise_op(Tensor.erfc) -register_elementwise_op(Tensor.erfinv) -register_elementwise_op(Tensor.exp) -register_elementwise_op(Tensor.expm1) -register_elementwise_op(Tensor.fix) -register_elementwise_op(Tensor.trunc) -register_elementwise_op(Tensor.float) -register_elementwise_op(Tensor.float_power) -register_elementwise_op(Tensor.floor) -register_elementwise_op(Tensor.frac) -register_elementwise_op(Tensor.half) -register_elementwise_op(Tensor.hardshrink) -register_elementwise_op(Tensor.heaviside) -register_elementwise_op(Tensor.i0) -register_elementwise_op(Tensor.int) -register_elementwise_op(Tensor.isfinite) -register_elementwise_op(Tensor.isinf) -register_elementwise_op(Tensor.isposinf) -register_elementwise_op(Tensor.isneginf) -register_elementwise_op(Tensor.isnan) -register_elementwise_op(Tensor.lgamma) -register_elementwise_op(Tensor.log) -register_elementwise_op(Tensor.log10) -register_elementwise_op(Tensor.log1p) -register_elementwise_op(Tensor.log2) -register_elementwise_op(Tensor.logical_not) -register_elementwise_op(Tensor.logit) -register_elementwise_op(Tensor.long) -register_elementwise_op(Tensor.nan_to_num) -register_elementwise_op(Tensor.neg) -register_elementwise_op(Tensor.negative) -register_elementwise_op(Tensor.positive) 
-register_elementwise_op(Tensor.pow) -register_elementwise_op(Tensor.rad2deg) -register_elementwise_op(Tensor.reciprocal) -register_elementwise_op(Tensor.round) -register_elementwise_op(Tensor.rsqrt) -register_elementwise_op(Tensor.short) -register_elementwise_op(Tensor.sigmoid) -register_elementwise_op(Tensor.sign) -register_elementwise_op(Tensor.signbit) -register_elementwise_op(Tensor.sgn) -register_elementwise_op(Tensor.sin) -register_elementwise_op(Tensor.sinc) -register_elementwise_op(Tensor.sinh) -register_elementwise_op(Tensor.asinh) -register_elementwise_op(Tensor.arcsinh) -register_elementwise_op(Tensor.sqrt) -register_elementwise_op(Tensor.square) -register_elementwise_op(Tensor.to) -register_elementwise_op(Tensor.tan) -register_elementwise_op(Tensor.tanh) -register_elementwise_op(Tensor.atanh) -register_elementwise_op(Tensor.arctanh) -register_elementwise_op(Tensor.type) -register_elementwise_op(Tensor.type_as) - -# torch OP -register_elementwise_op(torch.abs) -register_elementwise_op(torch.absolute) -register_elementwise_op(torch.acos) -register_elementwise_op(torch.arccos) -register_elementwise_op(torch.angle) -register_elementwise_op(torch.asin) -register_elementwise_op(torch.arcsin) -register_elementwise_op(torch.atan) -register_elementwise_op(torch.arctan) -register_elementwise_op(torch.all) -register_elementwise_op(torch.any) -register_elementwise_op(torch.bernoulli) -register_elementwise_op(torch.bitwise_not) -register_elementwise_op(torch.ceil) -register_elementwise_op(torch.clamp) -register_elementwise_op(torch.clamp_max) -register_elementwise_op(torch.clamp_min) -register_elementwise_op(torch.clip) -register_elementwise_op(torch.clone) -register_elementwise_op(torch.copysign) -register_elementwise_op(torch.cos) -register_elementwise_op(torch.cosh) -register_elementwise_op(torch.acosh) -register_elementwise_op(torch.arccosh) -register_elementwise_op(torch.deg2rad) -register_elementwise_op(torch.digamma) -register_elementwise_op(torch.erf) -register_elementwise_op(torch.erfc) -register_elementwise_op(torch.erfinv) -register_elementwise_op(torch.exp) -register_elementwise_op(torch.expm1) -register_elementwise_op(torch.fix) -register_elementwise_op(torch.trunc) -register_elementwise_op(torch.float_power) -register_elementwise_op(torch.floor) -register_elementwise_op(torch.frac) -register_elementwise_op(torch.hardshrink) -register_elementwise_op(torch.heaviside) -register_elementwise_op(torch.i0) -register_elementwise_op(torch.isfinite) -register_elementwise_op(torch.isinf) -register_elementwise_op(torch.isposinf) -register_elementwise_op(torch.isneginf) -register_elementwise_op(torch.isnan) -register_elementwise_op(torch.lgamma) -register_elementwise_op(torch.log) -register_elementwise_op(torch.log10) -register_elementwise_op(torch.log1p) -register_elementwise_op(torch.log2) -register_elementwise_op(torch.logical_not) -register_elementwise_op(torch.logit) -register_elementwise_op(torch.nan_to_num) -register_elementwise_op(torch.neg) -register_elementwise_op(torch.negative) -register_elementwise_op(torch.positive) -register_elementwise_op(torch.pow) -register_elementwise_op(torch.rad2deg) -register_elementwise_op(torch.reciprocal) -register_elementwise_op(torch.round) -register_elementwise_op(torch.rsqrt) -register_elementwise_op(torch.sigmoid) -register_elementwise_op(torch.sign) -register_elementwise_op(torch.signbit) -register_elementwise_op(torch.sgn) -register_elementwise_op(torch.sin) -register_elementwise_op(torch.sinc) -register_elementwise_op(torch.sinh) 
-register_elementwise_op(torch.asinh) -register_elementwise_op(torch.arcsinh) -register_elementwise_op(torch.sqrt) -register_elementwise_op(torch.square) -register_elementwise_op(torch.tan) -register_elementwise_op(torch.tanh) -register_elementwise_op(torch.atanh) -register_elementwise_op(torch.arctanh) -register_elementwise_op(torch.zeros_like) - -# nn.functional OP -register_elementwise_op(F.threshold) -register_elementwise_op(F.relu) -register_elementwise_op(F.hardtanh) -register_elementwise_op(F.hardswish) -register_elementwise_op(F.relu6) -register_elementwise_op(F.elu) -register_elementwise_op(F.selu) -register_elementwise_op(F.celu) -register_elementwise_op(F.leaky_relu) -register_elementwise_op(F.prelu) -register_elementwise_op(F.rrelu) -register_elementwise_op(F.gelu) -register_elementwise_op(F.logsigmoid) -register_elementwise_op(F.hardshrink) -register_elementwise_op(F.tanhshrink) -register_elementwise_op(F.softsign) -register_elementwise_op(F.softplus) -register_elementwise_op(F.softmin) -register_elementwise_op(F.softmax) -register_elementwise_op(F.softshrink) -register_elementwise_op(F.gumbel_softmax) -register_elementwise_op(F.log_softmax) -register_elementwise_op(F.tanh) -register_elementwise_op(F.sigmoid) -register_elementwise_op(F.hardsigmoid) -register_elementwise_op(F.silu) -register_elementwise_op(F.mish) -# TODO(ver217): dropout handles seed -register_elementwise_op(F.dropout) -register_elementwise_op(F.alpha_dropout) -register_elementwise_op(F.feature_alpha_dropout) diff --git a/colossalai/legacy/nn/_ops/embedding.py b/colossalai/legacy/nn/_ops/embedding.py deleted file mode 100644 index b145d1763380..000000000000 --- a/colossalai/legacy/nn/_ops/embedding.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input - - -def colo_embedding_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - compute_spec = weight.compute_spec - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_1Drow(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim] - # get the index of current segment and mask other segments with 0 - - # get complete input tensor through all-gather - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - # 
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = weight.get_process_group().tp_local_rank() - num_embeddings_per_partition = weight.size_local(0) - vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition - vocab_end_index = vocab_start_index + num_embeddings_per_partition - - # build the mask. - input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index) - # mask the input. - # TODO(jzy) masked_input may be an activation managed by ColoTensor. - masked_input = input_tensor - vocab_start_index - masked_input[input_mask] = 0 - - partial_output = F.embedding(masked_input, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - # Mask the output embedding. - partial_output[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(partial_output, weight.get_process_group()) - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec())) - return output - - -def colo_embedding_1d(mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol} - return funcs[mode](input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - -@colo_op_impl(F.embedding) -def colo_embedding(input_tensor: GeneralTensor, - weight: GeneralTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. - This method looks up an embedding table. 
- """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1drow(): - mode = 'row' - elif weight.is_shard_1dcol(): - mode = 'col' - else: - raise NotImplementedError - return colo_embedding_1d(mode, - input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/embedding_bag.py b/colossalai/legacy/nn/_ops/embedding_bag.py deleted file mode 100644 index 9a656d5871a3..000000000000 --- a/colossalai/legacy/nn/_ops/embedding_bag.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F -from torch import Tensor - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if weight.compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_bag_1d(tp_mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - assert tp_mode in ('col',) - funcs = {'col': colo_embedding_bag_1Dcol} - return funcs[tp_mode](input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - - 
-@colo_op_impl(F.embedding_bag) -def colo_embedding_bag(input_tensor: GeneralTensor, - weight: GeneralTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. - This method looks up an embedding table. - """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - # Handle different parallel actions. - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol(): - tp_mode = 'col' - else: - raise NotImplementedError - return colo_embedding_bag_1d(tp_mode, - input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/layernorm.py b/colossalai/legacy/nn/_ops/layernorm.py deleted file mode 100644 index 9960c5d48096..000000000000 --- a/colossalai/legacy/nn/_ops/layernorm.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import List, Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.layer_norm) -def colo_layernorm( - input_tensor: GeneralTensor, - normalized_shape: List[int], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) - output = ColoTensor.from_torch_tensor(tensor=output, - spec=ColoTensorSpec(pg=input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - return output diff --git a/colossalai/legacy/nn/_ops/linear.py b/colossalai/legacy/nn/_ops/linear.py deleted file mode 100644 index 2f2088c61fa8..000000000000 --- a/colossalai/legacy/nn/_ops/linear.py +++ /dev/null @@ -1,171 +0,0 @@ -from copy import deepcopy -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor.sharding_spec import ShardingSpec - -from ._utils import GeneralTensor, 
convert_to_colo_tensor, reduce_grad, reduce_input - - -def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:S[1] x Weight:S[0] = Output:P - # All-Reduce(Output) + bias = res - # Input:S[1] - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg) - - # Output:P - partial_output = F.linear(input_tensor, weight) - # Reduce(Output) - - output = reduce_input(partial_output, pg) - # Bias - if bias is not None: - assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' - output = output + bias - - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec())) - return output - - -def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] - # All-Gather(Output) - # Input:B - compute_spec = weight.compute_spec - input_tensor = input_tensor.redistribute(ReplicaSpec()) - input_parallel = reduce_grad(input_tensor, weight.get_process_group()) - - output_parallel = F.linear(input_parallel, weight, bias) - output = ColoTensor.from_torch_tensor(output_parallel, - spec=ColoTensorSpec(weight.get_process_group(), - ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D))) - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - assert mode in ('row', 'col') - funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol} - return funcs[mode](input_tensor, weight, bias) - - -# @register_colo_graph(input_pos=[1], param_pos=[2, 3]) -def colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - assert isinstance(weight, ColoTensor) - pg = weight.get_process_group() - assert pg - input_tensor = convert_to_colo_tensor(input_tensor, pg) - bias = convert_to_colo_tensor(bias, pg) - # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias))) - - # Add communication logic before and after linear call. - ret_tensor = None - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native Linear op' - assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op' - ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg)) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()): - mode = 'row' - elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()): - mode = 'col' - else: - raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}") - ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias) - else: - raise NotImplementedError - - return ret_tensor - - -def _new_colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """ - A tentative function to compute the distributed linear layer with the latest sharding spec. 
- This function is subject to future change as the current sharding API is not stable. - """ - # get mesh info - input_sharding_seq = input_tensor.sharding_spec.sharding_sequence - weight_sharding_seq = weight.sharding_spec.sharding_sequence - if bias is not None: - bias_sharding_seq = bias.sharding_spec.sharding_sequence - device_mesh = weight.sharding_spec.device_mesh - pg_axis0 = weight.pg_axis0 - pg_axis1 = weight.pg_axis1 - - # the last dim of input should have the same spec as the first dim of weight - # the weight is transposed, so we look at the second dimension - assert input_sharding_seq[-1] == weight_sharding_seq[1] - - if bias is not None: - assert bias_sharding_seq[0] == weight_sharding_seq[0] - - # compute the output sharding sequence - # as weight is transposed, so we look at the first dimension - output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1] - output_shard_seq = deepcopy(output_shard_seq) - - # TODO: add reduce grad logic - - # handle column and row parallel linear - # by reusing the implementation above - out = F.linear(input_tensor, weight) - - # run all reduce if necessary - last_dim_spec = input_sharding_seq[-1] - if last_dim_spec.is_replica: - pass - elif last_dim_spec.shard_list is not None: - for dim in last_dim_spec.shard_list: - if dim == 0: - reduce_input(out, pg_axis0) - elif dim == 1: - reduce_input(out, pg_axis1) - else: - raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected") - # add bias - if bias is not None: - out += bias - - # convert shard seq to partition dict - output_partition_dict = {} - for index, dim_spec in enumerate(output_shard_seq): - if not dim_spec.is_replica: - if index not in output_partition_dict: - output_partition_dict[index] = [] - output_partition_dict[index].extend(dim_spec.shard_list) - - entire_shape = out.shape - output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict) - ret_tensor = ColoTensor.from_torch_tensor(out) - setattr(ret_tensor, 'sharding_spec', output_sharding_spec) - return ret_tensor - - -def _has_sharding_spec(tensor): - """ - A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is - set as the attribute `sharding_spec` on a tensor. 
- """ - return hasattr(tensor, 'sharding_spec') - - -@colo_op_impl(F.linear) -def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - if _has_sharding_spec(weight): - return _new_colo_linear_imp(input, weight, bias) - else: - return colo_linear_imp(input, weight, bias) diff --git a/colossalai/legacy/nn/_ops/loss.py b/colossalai/legacy/nn/_ops/loss.py deleted file mode 100644 index 90efbfa36f2a..000000000000 --- a/colossalai/legacy/nn/_ops/loss.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional - -import torch -import torch.nn.functional as F - -from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.cross_entropy) -def colo_cross_entropy(input_tensor: GeneralTensor, - target: GeneralTensor, - weight: Optional[GeneralTensor] = None, - size_average: Optional[bool] = None, - ignore_index: int = -100, - reduce: Optional[bool] = None, - reduction: str = "mean", - label_smoothing: float = 0.0): - assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor) - pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor) - weight = convert_to_colo_tensor(weight, pg) - target = convert_to_colo_tensor(target, pg) - input_tensor = convert_to_colo_tensor(input_tensor, pg) - - if input_tensor.is_replicate(): # Input is gathered - assert target.is_replicate() and (weight is None or weight.is_replicate()), \ - "Target tensor and weight tensor both should be complete" - output = F.cross_entropy(input_tensor, - target, - weight=weight, - size_average=size_average, - ignore_index=ignore_index, - reduce=reduce, - reduction=reduction, - label_smoothing=label_smoothing) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - elif input_tensor.has_compute_spec(): # Single Model Parallel Applied - if input_tensor.is_shard_1dcol(): - assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in" - assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function" - output = VocabParallelCrossEntropyLoss1D()(input_tensor, - target, - process_group=input_tensor.process_group.tp_process_group()) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - else: - raise NotImplementedError - else: - raise NotImplementedError diff --git a/colossalai/legacy/nn/_ops/view.py b/colossalai/legacy/nn/_ops/view.py deleted file mode 100644 index 3c0bc52337ce..000000000000 --- a/colossalai/legacy/nn/_ops/view.py +++ /dev/null @@ -1,96 +0,0 @@ -import operator -from functools import reduce -from typing import Optional, Union - -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - - -def _all_int(my_iter): - return all(isinstance(i, int) for i in my_iter) - - -def _get_valid_shape(shape): - if isinstance(shape, list): - if _all_int(shape): - return tuple(shape) - else: - raise RuntimeError("expects type(int) but finds an other type") - elif isinstance(shape, tuple): - if _all_int(shape): - return shape - else: - return _get_valid_shape(shape[0]) - else: - raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape))) - - -def 
_shape_infer(org_sp, tgt_sp): - cnt = 0 - pos = 0 - for idx, dim in enumerate(tgt_sp): - if dim < -1: - raise RuntimeError("invalid shape dimension {}".format(dim)) - elif dim == -1: - cnt += 1 - pos = idx - - if cnt > 1: - raise RuntimeError("only one dimension can be inferred") - - org_prod = reduce(operator.mul, org_sp, 1) - tgt_prod = reduce(operator.mul, tgt_sp, 1) - - if cnt == 0: - if org_prod != tgt_prod: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - else: - return tgt_sp - elif org_prod % tgt_prod != 0: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - - infer_dim = -(org_prod // tgt_prod) - return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:] - - -@colo_op_impl(torch.Tensor.view) -def colo_view(self: ColoTensor, *shape) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``. - Changes the shape of the current tensor. - """ - assert isinstance(self, ColoTensor) - # apply original `view` function for replicated colo tensors - if self.is_replicate(): - return self.view(*shape) - - cur_sp = self.size() - org_sp = self.size_global() - # parse the passed arguments - tgt_sp = _get_valid_shape(shape) - # get the correct shape from inference - inf_sp = _shape_infer(org_sp, tgt_sp) - - if self.is_shard_1drow() and org_sp[0] == inf_sp[0]: - new_shape = (cur_sp[0],) + tgt_sp[1:] - res = self.view(*new_shape) - elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]: - new_shape = tgt_sp[:-1] + (cur_sp[-1],) - res = self.view(*new_shape) - else: - replicated_t = self.redistribute(dist_spec=ReplicaSpec()) - return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape), - spec=ColoTensorSpec(self.get_process_group())) - - return ColoTensor.from_torch_tensor(tensor=res, - spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec)) - - -@colo_op_impl(torch.Tensor.size) -def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]: - size = self.size_global() - if dim is None: - return size - else: - return size[dim] diff --git a/colossalai/legacy/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py index f839d6b28444..328c6cc01de8 100644 --- a/colossalai/legacy/nn/parallel/data_parallel.py +++ b/colossalai/legacy/nn/parallel/data_parallel.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist -from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.legacy.tensor import ProcessGroup as ColoProcessGroup from colossalai.utils import is_ddp_ignored from .reducer import Reducer diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index 79d7672b26bc..522fb4f4497f 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -4,7 +4,8 @@ import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec +from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec +from colossalai.tensor import ColoParameter, ColoTensor from .cache_mgr import CachedParamMgr, EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git 
a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index 116d836b7139..a1feda2bdb0e 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index 0014c784fba1..8017ee72b0b4 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -7,7 +7,7 @@ from torch.profiler import record_function from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py index a0a3eb40cf08..69d92afaaa94 100644 --- a/colossalai/legacy/nn/parallel/layers/colo_module.py +++ b/colossalai/legacy/nn/parallel/layers/colo_module.py @@ -1,7 +1,7 @@ from typing import Dict, List -from colossalai.tensor import ComputePattern -from colossalai.tensor.distspec import _DistSpec +from colossalai.legacy.tensor import ComputePattern +from colossalai.legacy.tensor.distspec import _DistSpec class ColoModule(object): diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py index 3e4e7ffd8de7..4796699fc57f 100644 --- a/colossalai/legacy/nn/parallel/layers/embedding.py +++ b/colossalai/legacy/nn/parallel/layers/embedding.py @@ -1,4 +1,4 @@ -from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec from .colo_module import ColoModule diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py index e391cf808933..51a8d4c976a6 100644 --- a/colossalai/legacy/nn/parallel/layers/linear.py +++ b/colossalai/legacy/nn/parallel/layers/linear.py @@ -1,4 +1,4 @@ -from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec from .colo_module import ColoModule diff --git a/colossalai/legacy/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py index 191266fa70fd..09326d2d6f9a 100644 --- a/colossalai/legacy/nn/parallel/layers/module_utils.py +++ b/colossalai/legacy/nn/parallel/layers/module_utils.py @@ -2,7 +2,8 @@ import torch -from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec +from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec +from colossalai.tensor import ColoParameter from 
. import ColoModule diff --git a/colossalai/legacy/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py index c0ee0286787f..1168158defaf 100644 --- a/colossalai/legacy/pipeline/pipeline_process_group.py +++ b/colossalai/legacy/pipeline/pipeline_process_group.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.distributed import rpc -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup class PipelineProcessGroup: diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py new file mode 100644 index 000000000000..d3278bf1e420 --- /dev/null +++ b/colossalai/legacy/tensor/__init__.py @@ -0,0 +1,17 @@ +from . import distspec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .process_group import ProcessGroup +from .tensor_spec import ColoTensorSpec + +__all__ = [ + 'ComputePattern', + 'ComputeSpec', + 'distspec', + 'DistSpecManager', + 'ProcessGroup', + 'ColoTensorSpec', + 'ShardSpec', + 'ReplicaSpec', +] diff --git a/colossalai/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py similarity index 100% rename from colossalai/tensor/compute_spec.py rename to colossalai/legacy/tensor/compute_spec.py diff --git a/colossalai/tensor/const.py b/colossalai/legacy/tensor/const.py similarity index 100% rename from colossalai/tensor/const.py rename to colossalai/legacy/tensor/const.py diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py similarity index 97% rename from colossalai/tensor/dist_spec_mgr.py rename to colossalai/legacy/tensor/dist_spec_mgr.py index 4740a316b7f5..d97308b04bef 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/legacy/tensor/dist_spec_mgr.py @@ -4,12 +4,12 @@ import torch.distributed as dist from numpy import prod -from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec -from colossalai.tensor.process_group import ProcessGroup +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup # TODO(jiaruifang) circle import, move the divide to colossalai.commons. -# colossalai.tensor shall not import any submodule from colossal.nn +# colossalai.legacy.tensor shall not import any submodule from colossal.nn def divide(numerator, denominator): """Only allow exact division. 
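[Note on the renames above: for downstream code the net effect is an import-path split. ColoTensor and ColoParameter stay under colossalai.tensor, while the process-group and spec types now come from colossalai.legacy.tensor. A rough sketch of the post-patch import pattern, assuming a distributed context has already been launched; the degrees and shard layout are illustrative, echoing values seen in the removed test code, and every call mirrors usage visible elsewhere in this patch:

import torch

from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.tensor import ColoTensor

pg = ProcessGroup(tp_degree=2, dp_degree=2)    # assumes 4 launched ranks
# Shard the last dim across the 2 TP ranks, computing in TP1D pattern.
spec = ColoTensorSpec(pg, ShardSpec([-1], [2]), ComputeSpec(ComputePattern.TP1D))
t = ColoTensor.from_torch_tensor(torch.randn(4, 8), spec=spec)

The positional ColoTensorSpec(pg, dist_spec, compute_spec) form matches the calls in the ops deleted earlier in this patch.]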
diff --git a/colossalai/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py similarity index 100% rename from colossalai/tensor/distspec.py rename to colossalai/legacy/tensor/distspec.py diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/legacy/tensor/op_wrapper.py similarity index 97% rename from colossalai/tensor/op_wrapper.py rename to colossalai/legacy/tensor/op_wrapper.py index 1c00066f7465..63ebaa264279 100644 --- a/colossalai/tensor/op_wrapper.py +++ b/colossalai/legacy/tensor/op_wrapper.py @@ -1,8 +1,5 @@ -from typing import ( - Callable, - Dict, -) import functools +from typing import Callable, Dict # Custom sharded ops _COLOSSAL_OPS: Dict[str, Callable] = {} diff --git a/colossalai/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py similarity index 100% rename from colossalai/tensor/process_group.py rename to colossalai/legacy/tensor/process_group.py diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py similarity index 79% rename from colossalai/tensor/tensor_spec.py rename to colossalai/legacy/tensor/tensor_spec.py index 580df9f8f310..aa792e507639 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -1,8 +1,8 @@ from dataclasses import dataclass from typing import Optional -from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec -from colossalai.tensor.process_group import ProcessGroup +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup from .compute_spec import ComputeSpec diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index b2da64e6c33a..099376d931e8 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,18 +1,11 @@ -from . 
import distspec
 from .colo_parameter import ColoParameter
 from .colo_tensor import ColoTensor
 from .comm_spec import CollectiveCommPattern, CommSpec
-from .compute_spec import ComputePattern, ComputeSpec
-from .dist_spec_mgr import DistSpecManager
-from .distspec import ReplicaSpec, ShardSpec
 from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
-    'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
-    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook',
+    'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
     'merge_same_dim_mesh_list'
 ]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 998901708239..5e5dd7224dfa 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -18,7 +18,8 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.tensor import ColoParameter, ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter

 from .multi_tensor_apply import multi_tensor_applier

diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py
index dad852a34a71..549635af4332 100644
--- a/colossalai/zero/gemini/colo_init_context.py
+++ b/colossalai/zero/gemini/colo_init_context.py
@@ -3,7 +3,8 @@
 import torch
 from torch import nn

-from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses

 # find named_params includes replica
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
index 4e3c26c1ba9c..715f62358e2d 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -13,10 +13,9 @@
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor.process_group import ProcessGroup
 from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from colossalai.utils import get_current_device
-from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import zero_model_wrapper, zero_optim_wrapper


 class MLP(torch.nn.Module):
@@ -70,14 +69,12 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
         print(strategy)
         print('=' * msg_length)

-    dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
     gemini_config = dict(strict_ddp_mode=False,
                          device=get_current_device(),
                          placement_policy='cpu',
                          pin_memory=True,
                          search_range_m=128)
-    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
     gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)

     optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
     optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py
index 0760a3f1ec38..3b1bb1f96eec 100644
--- a/tests/test_legacy/test_layers/test_cache_embedding.py
+++ b/tests/test_legacy/test_layers/test_cache_embedding.py
@@ -14,7 +14,8 @@
     ParallelCachedEmbeddingBagTablewise,
     TablewiseEmbeddingBagConfig,
 )
-from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoTensor
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 NUM_EMBED, EMBED_DIM = 10, 8
diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_legacy/test_tensor/common_utils/__init__.py
similarity index 95%
rename from tests/test_tensor/common_utils/__init__.py
rename to tests/test_legacy/test_tensor/common_utils/__init__.py
index 5387db70445f..9a35d02ce5ed 100644
--- a/tests/test_tensor/common_utils/__init__.py
+++ b/tests/test_legacy/test_tensor/common_utils/__init__.py
@@ -1 +1 @@
-from ._utils import *
+from ._utils import *
diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py
similarity index 97%
rename from tests/test_tensor/common_utils/_utils.py
rename to tests/test_legacy/test_tensor/common_utils/_utils.py
index b405f8cd2108..b793851aef2b 100644
--- a/tests/test_tensor/common_utils/_utils.py
+++ b/tests/test_legacy/test_tensor/common_utils/_utils.py
@@ -8,7 +8,7 @@
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec
+from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ShardSpec


 def set_seed(seed):
diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
similarity index 96%
rename from tests/test_tensor/core/test_dist_spec_mgr.py
rename to tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
index 89476a35b63a..3102d6f0aece 100644
--- a/tests/test_tensor/core/test_dist_spec_mgr.py
+++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
@@ -5,7 +5,7 @@
 import torch.distributed as dist

 import colossalai
-from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.legacy.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn


diff --git a/tests/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py
similarity index 92%
rename from tests/test_tensor/test_parameter.py
rename to tests/test_legacy/test_tensor/test_parameter.py
index 9c3f05da1ffa..68508df6df45 100644
--- a/tests/test_tensor/test_parameter.py
+++ b/tests/test_legacy/test_tensor/test_parameter.py
@@ -3,7 +3,7 @@
 from common_utils import tensor_equal

 import colossalai
-from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.testing import free_port



From 86653f3fb16dd6ac0573904dddc755ec4f997f89 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 12 Sep 2023 15:48:05 +0800
Subject: [PATCH 2/2] [test] fix test import

---
 colossalai/utils/__init__.py                              | 2 ++
 colossalai/utils/common.py                                | 7 +++++++
 tests/test_zero/test_gemini/test_chunk_mgrv2.py           | 2 --
 tests/test_zero/test_gemini/test_fwd_bwd.py               | 2 +-
 tests/test_zero/test_gemini/test_grad_clip.py             | 2 +-
 tests/test_zero/test_gemini/test_inference.py             | 2 +-
 tests/test_zero/test_gemini/test_optim.py                 | 2 +-
 tests/test_zero/test_gemini/test_zeroddp_state_dict.py    | 2 +-
 tests/test_zero/test_gemini/test_zerooptim_state_dict.py  | 2 +-
 9 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 6f9717d353e6..ba4c0423ee3b 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -20,6 +20,7 @@
     multi_tensor_applier,
     param_is_not_tensor_parallel_duplicate,
     print_rank_0,
+    set_seed,
     switch_virtual_pipeline_parallel_rank,
     sync_model_param,
 )
@@ -76,4 +77,5 @@
     'colo_get_cpu_memory_capacity',
     '_cast_float',
     'free_storage',
+    'set_seed',
 ]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 5e5dd7224dfa..92159b57eb45 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch import inf
@@ -490,3 +491,9 @@ def _cast_float(args, dtype: torch.dtype):
     elif isinstance(args, dict):
         args = {k: _cast_float(v, dtype) for k, v in args.items()}
     return args
+
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
index d6c4f8bd8aac..f05ccfdbd41b 100644
--- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py
+++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
@@ -6,7 +6,6 @@
 from colossalai.tensor import ColoTensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero.gemini.chunk import ChunkManager
-from tests.test_tensor.common_utils import debug_print

 CUDA_MEM_0 = {False: 512, True: 1024}
 CUDA_MEM_1 = {False: 0, True: 1024}
@@ -16,7 +15,6 @@
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
 def exam_chunk_memory(keep_gathered, pin_memory):
-    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))

     params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
     config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 4cbf564ecfb9..f1d2656db791 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -8,12 +8,12 @@
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed

 PLACEMENT_CONFIGS = [
     {
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 82b9133b89c1..90af1af93e87 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -8,11 +8,11 @@
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed

 PLACEMENT_CONFIGS = [
     {
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index 20d145f9661f..d316ec0db114 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -10,12 +10,12 @@
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed

 PLACEMENT_CONFIGS = [
     {
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index edcbada0acbb..998d84385f95 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -8,12 +8,12 @@
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed

 PLACEMENT_CONFIGS = [
     {
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
index 656bd709e2a1..602e3ad3519d 100644
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -4,10 +4,10 @@
 import colossalai
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed


 PLACEMENT_CONFIGS = [
     {
diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
index 09725e11ec0c..5f7b51510d58 100644
--- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -5,10 +5,10 @@
 import colossalai
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed

 PLACEMENT_CONFIGS = [
     {
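
For downstream code, a minimal sketch of the import surface this series leaves behind, assuming both patches are applied. Every module path and symbol name below is taken from the diffs above; only the seed value 42 is an arbitrary example.

# The outdated colo tensor specs now resolve from the legacy namespace
# (moved there by PATCH 1/2).
from colossalai.legacy.tensor import (
    ComputePattern,
    ComputeSpec,
    DistSpecManager,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)

# The core tensor types keep their original home.
from colossalai.tensor import ColoParameter, ColoTensor

# set_seed is promoted from a test-only helper to a public utility by
# PATCH 2/2; it seeds Python's random, NumPy and torch in one call.
from colossalai.utils import set_seed

set_seed(42)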