2 changes: 1 addition & 1 deletion colossalai/context/moe_context.py
@@ -5,7 +5,7 @@
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
 
 
 def _check_sanity():
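
Note: the substantive change throughout this PR is the relocation of ProcessGroup and related ColoTensor sharding types from colossalai.tensor to colossalai.legacy.tensor. Downstream code would presumably need the same one-line rewrite; a minimal sketch, assuming the legacy module re-exports the class under the same name (as this hunk suggests):

# from colossalai.tensor import ProcessGroup        # pre-PR path, removed above
from colossalai.legacy.tensor import ProcessGroup   # post-PR path, added above

pg = ProcessGroup()  # hypothetical default construction; real usage may need rank/degree arguments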
16 changes: 9 additions & 7 deletions colossalai/fx/passes/shard_1d_pass.py
@@ -1,9 +1,11 @@
+import operator
+
 import torch
 import torch.nn as nn
-import operator
-from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import ShardSpec
-from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
+
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec
+from colossalai.legacy.tensor.distspec import ShardSpec
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 ELEMENTWISE_FUNC_OP = [
@@ -13,7 +15,7 @@
 
 
 def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
-    """weight_split 
+    """weight_split
     split a nn.Parameter
 
     Args:
@@ -60,9 +62,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
 
 def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
     """
-    This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. 
+    This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
     """
-    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+    # TODO: Needs to handle special cases, like x = linear(x) + linear(x)
     graph = graph_module.graph
     world_size = process_group.world_size()
 
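
For context on the pass touched above: transformer_mlp_pass operates on a torch.fx.GraphModule, so callers trace the model first and then hand the traced module plus a ProcessGroup to the pass. A minimal usage sketch under stated assumptions: the MLP module below is illustrative, ProcessGroup() default construction is hypothetical, and the pass is assumed to annotate the module in place (the diff does not show a return value):

import torch
import torch.fx

from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.legacy.tensor import ProcessGroup

class MLP(torch.nn.Module):
    # A transformer-style MLP block: Linear -> ReLU -> Linear.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.up = torch.nn.Linear(dim, 4 * dim)
        self.act = torch.nn.ReLU()
        self.down = torch.nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.down(self.act(self.up(x)))

gm = torch.fx.symbolic_trace(MLP())   # build the FX IR the pass inspects
pg = ProcessGroup()                   # hypothetical; real setup may pass ranks/degrees
transformer_mlp_pass(gm, process_group=pg)  # annotates column/row sharding on the linears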
1 change: 0 additions & 1 deletion colossalai/legacy/nn/__init__.py
@@ -1,4 +1,3 @@
-from ._ops import *
 from .layer import *
 from .loss import *
 from .metric import *
10 changes: 1 addition & 9 deletions colossalai/legacy/nn/_ops/__init__.py
@@ -1,9 +1 @@
-from .addmm import colo_addmm
-from .batch_norm import colo_batch_norm
-from .element_wise import *
-from .embedding import colo_embedding
-from .embedding_bag import colo_embedding_bag
-from .layernorm import colo_layernorm
-from .linear import colo_linear
-from .loss import colo_cross_entropy
-from .view import colo_view
+from ._utils import *
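
One consequence of the __init__.py rewrite above: the per-op wrappers (colo_addmm, colo_linear, and the rest) are no longer re-exported at package level, so old package-level imports break while the shared utilities remain. A hedged before/after sketch, with colo_addmm standing in for any of the removed names:

# from colossalai.legacy.nn._ops import colo_addmm  # pre-PR surface; would now raise ImportError
from colossalai.legacy.nn import _ops                # still imports; only _utils names remain re-exported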
3 changes: 2 additions & 1 deletion colossalai/legacy/nn/_ops/_utils.py
@@ -5,7 +5,8 @@
 
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.legacy.nn.layer.utils import divide
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup
+from colossalai.tensor import ColoTensor
 
 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
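
The _utils.py hunk is the mixed case worth flagging: ColoTensorSpec and ProcessGroup move to the legacy namespace while ColoTensor stays put, and the GeneralTensor alias keeps accepting either tensor type. A small sketch of code written against the post-PR import split (the describe helper is hypothetical, added only to exercise the alias):

from typing import Union

import torch

from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup  # moved in this PR
from colossalai.tensor import ColoTensor                           # unchanged location

GeneralTensor = Union[ColoTensor, torch.Tensor]

def describe(t: GeneralTensor) -> str:
    # Hypothetical helper: distinguish the two members of the alias.
    kind = "colo" if isinstance(t, ColoTensor) else "torch"
    return f"{kind} tensor of shape {tuple(t.shape)}"

print(describe(torch.randn(2, 3)))  # -> "torch tensor of shape (2, 3)"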
90 changes: 0 additions & 90 deletions colossalai/legacy/nn/_ops/addmm.py

This file was deleted.

33 changes: 0 additions & 33 deletions colossalai/legacy/nn/_ops/batch_norm.py

This file was deleted.
