@@ -11,6 +11,7 @@
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
@@ -27,5 +28,5 @@
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
]
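
With the export above, ShardOption should also be importable directly from the node_handler package rather than only from its defining module. A minimal sketch (illustrative only, based on the import paths that appear elsewhere in this diff):

# Minimal sketch: both import paths should resolve to the same enum.
from colossalai.auto_parallel.tensor_shard.node_handler import ShardOption
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption as _ShardOption

assert ShardOption is _ShardOption
print(list(ShardOption))  # STANDARD, SHARD, FULL_SHARD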
@@ -5,6 +5,7 @@
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@@ -35,12 +36,14 @@ def __init__(
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shard_option = shard_option

def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
@@ -181,6 +184,21 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
check_sharding_spec_validity(sharding_spec, op_data.data)

remove_strategy_list = []
for strategy in self.strategies_vector:
shard_level = 0
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis)
if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy)

for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy)

return self.strategies_vector

def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
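The filtering added to register_strategy above computes a per-strategy shard_level: for every tensor operand in the strategy, it adds up how many device mesh axes appear in that operand's dim_partition_dict, then drops strategies that fall below the level requested by shard_option. A minimal sketch of the counting rule, assuming the dim_partition_dict format used in the diff (tensor dimension mapped to the list of mesh axes it is sharded along); the example values are illustrative only:

# Illustrative sketch of the shard_level counting in register_strategy.
# Assumed spec format: dim_partition_dict = {tensor_dim: [mesh_axis, ...]}.
operand_specs = [
    {},           # replicated operand ("RR"): contributes 0
    {0: [0]},     # dim 0 sharded along mesh axis 0 ("S0R"): contributes 1
    {0: [0, 1]},  # dim 0 sharded along both axes ("S01R"): contributes 2
]
shard_level = sum(len(axes) for spec in operand_specs for axes in spec.values())
print(shard_level)  # 3 -> this strategy would survive both SHARD and FULL_SHARD filtering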
17 changes: 17 additions & 0 deletions colossalai/auto_parallel/tensor_shard/node_handler/option.py
@@ -0,0 +1,17 @@
from enum import Enum

__all__ = ['ShardOption']


class ShardOption(Enum):
"""
This enum class defines the shard level required by node strategies.

Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be sharded along at least one device mesh axis.
FULL_SHARD: We require the node to be sharded along all device mesh axes.
"""
STANDARD = 0
SHARD = 1
FULL_SHARD = 2
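
Read together with the filtering rule in register_strategy, the three options translate into simple thresholds on a strategy's total shard_level. The helper below is a hypothetical illustration of that rule, not part of the PR:

from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption

def survives(shard_level: int, option: ShardOption) -> bool:
    # Hypothetical helper mirroring the pruning in register_strategy.
    if option == ShardOption.SHARD:
        return shard_level > 0     # fully replicated strategies are dropped
    if option == ShardOption.FULL_SHARD:
        return shard_level > 1     # strategies using at most one mesh axis are dropped
    return True                    # STANDARD keeps every strategy

print(survives(0, ShardOption.STANDARD))    # True
print(survives(0, ShardOption.SHARD))       # False
print(survives(1, ShardOption.FULL_SHARD))  # False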
@@ -0,0 +1,112 @@
from functools import partial

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class LinearModel(nn.Module):

def __init__(self):
super().__init__()

def forward(self, input, others, bias=None):
x = nn.functional.linear(input, others, bias=bias)
return x


def check_shard_option(shard_option):
model = LinearModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 4, 16).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph)
linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node)

# build handler
handler = LinearFunctionHandler(node=linear_func_node,
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shard_option=shard_option)

strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]

# SS = SR x RS
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
assert 'S1S0 = S1R x RS0_2' in strategy_name_list

# SR = SS x SR
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
assert 'S1R = S1S0 x S0R_2' in strategy_name_list

# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list

# S01R = S01R x RR
assert 'S01R = S01R x RR_0' in strategy_name_list
assert 'S01R = S01R x RR_1' in strategy_name_list
assert 'S01R = S01R x RR_2' in strategy_name_list

# RR = RS01 x S01R
assert 'RR = RS01 x S01R' in strategy_name_list

# RS01 = RR x RS01
assert 'RS01 = RR x RS01' in strategy_name_list

if shard_option == ShardOption.SHARD:
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list

# RS = RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list

if shard_option == ShardOption.STANDARD:
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list

# RS = RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list

# RR = RR x RR
assert 'RR = RR x RR' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_shard_option():
for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
check_shard_option(shard_option)


if __name__ == '__main__':
test_shard_option()