Merged

Commits (57)
04e5272  Merge pull request #1 from hpcaitech/main (Cypher30, Jul 14, 2022)
75618b3  Merge pull request #2 from hpcaitech/main (Cypher30, Jul 15, 2022)
3e4620c  Merge pull request #3 from hpcaitech/main (Cypher30, Jul 20, 2022)
cf24049  Merge remote-tracking branch 'upstream/main' into main (Jul 20, 2022)
3d223b6  Merge remote-tracking branch 'upstream/main' into main (Jul 21, 2022)
644115c  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 22, 2022)
d995ade  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 25, 2022)
bba2dbe  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 26, 2022)
05ca628  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 26, 2022)
0a967da  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 6, 2022)
0637c0d  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 8, 2022)
74a6227  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 10, 2022)
e550490  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 10, 2022)
2d7f5d9  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 11, 2022)
b62e870  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 12, 2022)
b4b0974  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 15, 2022)
65c20de  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 16, 2022)
1660bfc  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 17, 2022)
6eb0ad0  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 20, 2022)
56df059  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 26, 2022)
480e932  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 30, 2022)
0fa66ee  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 30, 2022)
1d013b0  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 31, 2022)
5774db2  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 5, 2022)
e8ff699  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 6, 2022)
855c728  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 7, 2022)
2c113ea  Merge branch 'main' of github.com:Cypher30/ColossalAI into main (Sep 8, 2022)
838ba70  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 13, 2022)
cacec2b  Merge branch 'main' of github.com:Cypher30/ColossalAI into main (Sep 13, 2022)
5ed6ef0  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 14, 2022)
668af30  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 14, 2022)
df79772  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 15, 2022)
7b6a0fc  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 20, 2022)
c30022e  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 23, 2022)
df20f4d  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 26, 2022)
2d5a6a0  Merge branch 'hpcaitech:main' into main (Cypher30, Sep 29, 2022)
07d27a6  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 3, 2022)
dc68ba9  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 5, 2022)
929e7d3  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 6, 2022)
90aa46a  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 11, 2022)
40363da  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 13, 2022)
fe3fca5  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 25, 2022)
956156e  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 26, 2022)
cb20212  Merge branch 'hpcaitech:main' into main (Cypher30, Oct 26, 2022)
744a775  Merge branch 'hpcaitech:main' into main (Cypher30, Nov 10, 2022)
1629a90  Merge branch 'hpcaitech:main' into main (Cypher30, Nov 10, 2022)
f0558e3  Merge branch 'hpcaitech:main' into main (Cypher30, Dec 1, 2022)
bb7bd4a  Merge branch 'hpcaitech:main' into main (Cypher30, Dec 21, 2022)
26de8e5  Merge branch 'hpcaitech:main' into main (Cypher30, Dec 23, 2022)
83a1418  Merge branch 'hpcaitech:main' into main (Cypher30, Dec 30, 2022)
0802b94  Merge branch 'hpcaitech:main' into main (Cypher30, Dec 30, 2022)
24709eb  [autoparallel] hook node meta on graph nodes for checkpoint solver (Dec 31, 2022)
b1caa03  [autoparallel] polish code (Dec 31, 2022)
b161a87  [autoparallel] restore some node handlers (Dec 31, 2022)
ec8150f  colossalai/auto_parallel/passes/meta_info_prop.py (Jan 2, 2023)
d5e3d6e  [autoparallel] remove some unused import (Jan 2, 2023)
5a44d24  [autoparallel] hook bwd_mem_out (Jan 2, 2023)
@@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
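Note on the device='meta' tensors used throughout this diff: meta tensors carry shape and dtype but allocate no storage, which is what lets the profiler record would-be saved activations for free. A minimal standalone sketch (illustration only, not code from this PR):

import torch

# Meta tensors track only shape/dtype metadata; no memory is allocated,
# so they can stand in for activations when estimating memory cost.
x = torch.zeros(1024, 1024, device='meta')
print(x.shape)           # torch.Size([1024, 1024])
print(x.element_size())  # 4 bytes for float32, i.e. this stands in for a 4 MiB activation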
22 changes: 4 additions & 18 deletions colossalai/auto_parallel/meta_profiler/metainfo.py

@@ -1,6 +1,5 @@
 from typing import Callable, List

-import numpy as np
 import torch

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -71,25 +70,12 @@ def target(self, target: Callable) -> None:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()

-    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
         """
-        Compute sharded meta tensor based on the given data and sharding spec.
+        Compute sharded opdata based on the given data and sharding spec.
         """
-        shard_sequnce = sharding_spec.sharding_sequence
-        device_mesh = sharding_spec.device_mesh
-        shape = operation_data.data.shape
-
-        new_shape = []
-        for dim, shard in zip(shape, shard_sequnce):
-            if shard.is_replica:
-                # replica
-                new_shape.append(dim)
-            else:
-                # sharded according to device_mesh shape
-                new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
-
         return OperationData(name=operation_data.name,
-                             data=torch.zeros(new_shape, device="meta"),
+                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                              type=operation_data.type,
                              logical_shape=operation_data.logical_shape)

@@ -113,7 +99,7 @@ def compute_metainfo(self):
         save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION

         # construct args for meta_func
-        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
+        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

         # construct kwargs
         if self.target in INPLACE_MODULE:
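The rewrite above delegates the shard-shape arithmetic to ShardingSpec.get_sharded_shape_per_device(). Per the deleted loop, the rule is that a dimension sharded over several mesh axes shrinks by the product of those axes' sizes. An illustrative sketch with assumed numbers (not the real API):

import operator
from functools import reduce

# Assumed example: a (64, 32) tensor on a 2x4 device mesh, with dim 0
# sharded over both mesh axes and dim 1 replicated.
global_shape = (64, 32)
mesh_shape = (2, 4)
shard_axes = {0: [0, 1]}    # dim 0 is split across mesh axes 0 and 1

sharded_shape = tuple(
    size // reduce(operator.mul, (mesh_shape[i] for i in shard_axes.get(d, [])), 1)
    for d, size in enumerate(global_shape))
print(sharded_shape)  # (8, 32): 64 // (2 * 4) = 8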
113 changes: 113 additions & 0 deletions colossalai/auto_parallel/passes/comm_metainfo_pass.py

@@ -0,0 +1,113 @@
+from typing import Dict
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.node import Node
+
+from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
+from colossalai.tensor.comm_spec import CommSpec
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+shape_consistency_manager = ShapeConsistencyManager()
+
+
+def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+    # get comm_action_sequence and total_cost from shape_consistency_manager
+    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
+        origin_sharding_spec, target_sharding_spec)
+
+    meta_info = MetaInfo()
+    # NOTE: the costs from shape_consistency_manager.mem_cost are counted in number
+    # of elements (numel), so we scale them by the element size to get bytes
+    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
+    # extract the input node that carries _meta_data and get its element size
+    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    element_length = input_node._meta_data.element_size()
+
+    mem_cost.fwd.activation *= element_length
+    mem_cost.fwd.temp *= element_length
+    mem_cost.bwd.activation *= element_length
+    mem_cost.bwd.temp *= element_length
+    mem_cost.total.activation *= element_length
+
+    meta_info.memory_cost = mem_cost
+
+    # get computation cost for MetaInfo
+    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                            total_cost['backward'] * element_length,
+                                            total_cost['total'] * element_length)
+
+    # get tensor shape for MetaInfo
+    origin_sharding_spec: ShardingSpec
+    target_sharding_spec: ShardingSpec
+    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
+    output_shape = target_sharding_spec.get_sharded_shape_per_device()
+
+    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_buffer = []
+    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+
+    return meta_info
+
+
+def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
+    """
+    This method is used to construct `MetaInfo` for a shape consistency node.
+    """
+
+    # extract node index and user node index
+    args = node.args
+    node_index, user_node_index = args[3], args[4]
+    origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
+        node_index][user_node_index]
+
+    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+    # extract node_index and op_data_name
+    node_index, op_data_name = node.args[2], node.args[3]
+
+    comm_action = comm_actions_dict[node_index][op_data_name]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        # this case is for all_reduce; there is no memory cost
+        meta_info = MetaInfo()
+        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
+        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        element_length = output_node._meta_data.element_size()
+
+        total_cost = comm_action.comm_spec.get_comm_cost()
+        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                                total_cost['backward'] * element_length,
+                                                total_cost['total'] * element_length)
+
+        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
+        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_buffer = []
+        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    else:
+        # this case will be handled by the shape consistency manager
+        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
+            'tgt_spec']
+        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+    return meta_info
+
+
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
+                       comm_actions_dict: Dict):
+    """
+    This pass attaches metainfo to every communication node (runtime_apply, runtime_comm_spec_apply) in the graph.
+    """
+    for node in gm.graph.nodes:
+        if node.target == runtime_apply:
+            setattr(node, 'best_metainfo',
+                    _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
+        elif node.target == runtime_comm_spec_apply:
+            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+        else:
+            pass
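A hypothetical call site for the new pass, showing where it slots in after the runtime passes have rewritten the graph; the dict variable names below are assumptions about the caller's bookkeeping, not code from this PR:

from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass

# sharding_spec_dict, origin_spec_dict and comm_actions_dict are the lookup
# tables produced when the runtime apply pass rewrites the graph (assumed names).
comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)

# The pass mutates gm in place: every runtime_apply / runtime_comm_spec_apply
# node now carries a MetaInfo whose costs are scaled from element counts to bytes.
for node in gm.graph.nodes:
    if hasattr(node, 'best_metainfo'):
        print(node.name, node.best_metainfo.compute_cost.total)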
19 changes: 13 additions & 6 deletions colossalai/auto_parallel/passes/meta_info_prop.py

@@ -1,15 +1,14 @@
 import uuid
 from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Tuple
+from typing import List

 import torch
 import torch.fx
 from torch.fx import GraphModule
-from torch.fx.node import Argument, Node, Target
-from torch.utils._pytree import tree_map
+from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler import MetaInfo
-from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
 from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

@@ -68,7 +67,7 @@ def placeholder_handler(self, node: Node) -> None:
         """
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
-        graph_info.fwd_out = list(out)
+        graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}

     @compatibility(is_backward_compatible=False)
@@ -97,7 +96,7 @@ def node_handler(self, node: Node) -> None:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
+        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
         graph_info = GraphInfo()
         meta_info = node.best_metainfo
         meta_info: MetaInfo
@@ -158,5 +157,13 @@ def node_handler(self, node: Node) -> None:
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
+        graph_info.bwd_mem_out = memory_cost.bwd.activation
+
+        # fetch flop information
+        # here we use fwd_time and bwd_time to deal with the case that
+        # communication cost is a float
+        compute_cost = meta_info.compute_cost
+        graph_info.fwd_time = compute_cost.fwd
+        graph_info.bwd_time = compute_cost.bwd

         node.meta = {**asdict(graph_info)}
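Since node.meta is rebuilt from asdict(graph_info), downstream consumers such as the checkpoint solver can read the new fields as plain dict keys. A minimal sketch (field names follow the GraphInfo assignments above):

# After the propagation pass runs, each fx node exposes its costs in node.meta.
for node in gm.graph.nodes:
    fwd_time = node.meta.get('fwd_time', 0)
    bwd_time = node.meta.get('bwd_time', 0)
    bwd_mem_out = node.meta.get('bwd_mem_out', 0)
    print(f'{node.name}: fwd_time={fwd_time}, bwd_time={bwd_time}, bwd_mem_out={bwd_mem_out}')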
49 changes: 0 additions & 49 deletions colossalai/auto_parallel/passes/runtime_apply_pass.py

@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst


-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInto` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shape for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                           runtime_apply,
                                                           args=(node, origin_dict_node, input_dict_node,
                                                                 node_to_index_dict[node], user_node_index))
-            meta_info = construct_meta_info(node, user_node)
-            setattr(shape_consistency_node, 'best_metainfo', meta_info)

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
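For context on why the deleted helper gave way to get_sharded_shape_per_device(): compute_shape divided a sharded dimension by len(shard), the number of mesh axes sharding it, rather than by the number of devices along those axes. A small illustration with assumed numbers:

# Assumed example: a 64-element dim sharded over mesh axis 0 of a (4, 2) device mesh.
mesh_shape = (4, 2)
shard_list = [0]                        # dim_partition_dict value for that dim
print(64 // len(shard_list))            # 64 -- the deleted helper's estimate
print(64 // mesh_shape[shard_list[0]])  # 16 -- the actual per-device shard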
@@ -138,8 +138,7 @@ def get_target_function(self) -> callable:
            return None

        if self.node.op == 'call_module':
-            submod = self.node.graph.owning_module.get_submodule(self.node.target)
-            target = type(submod)
+            target = self.node.graph.owning_module.get_submodule(self.node.target)
        elif self.node.op == 'call_function':
            target = self.node.target
        elif self.node.op == 'call_method':
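The change above makes get_target_function return the bound submodule instance for call_module nodes rather than its class. A standalone sketch of the difference, using an assumed toy module:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
submod = model.get_submodule('0')

print(type(submod))  # <class 'torch.nn.modules.linear.Linear'> -- old: target = type(submod)
print(submod)        # Linear(in_features=4, out_features=4, bias=True) -- new: target = submod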