Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions colossalai/auto_parallel/passes/comm_metainfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
return meta_info


def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
"""
This method is used to construct `MetaInto` for shape consistency node
"""

# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
node_index][user_node_index]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
user_node_index]

return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)

Expand Down Expand Up @@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
return meta_info


def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
comm_actions_dict: Dict):
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict) -> GraphModule:
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_metainfo',
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(NodeHandler):
class BinaryElementwiseHandler(MetaInfoNodeHandler):
"""
An BinaryBcastOpHandler is a node handler which deals with operations which have two
operands and broadcasting occurs such as torch.add.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator

Expand All @@ -13,7 +13,7 @@
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator

Expand All @@ -19,7 +19,7 @@
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler):
class UnaryElementwiseHandler(MetaInfoNodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""
Expand Down