Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .default_reshape_handler import DefaultReshapeHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
Expand All @@ -13,20 +13,24 @@
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler
from .split_handler import SplitHandler
from .sum_handler import SumHandler
from .tensor_constructor_handler import TensorConstructorHandler
from .transpose_handler import TransposeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .view_handler import ViewHandler
from .where_handler import WhereHandler

__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
'TransposeHandler', 'SplitHandler'
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
from .strategy import DefaultReshapeGenerator, StrategyGenerator

__all__ = ['ReshapeHandler']
__all__ = ['DefaultReshapeHandler']


@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(MetaInfoNodeHandler):
class DefaultReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
A DefaultReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""

def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
generators.append(DefaultReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators

def infer_logical_shape(self, data):
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator

__all__ = ['PermuteHandler']

Expand Down
Loading