Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
df7e650
[CLI] add CLI launcher
YuliangLiu0306 Apr 13, 2022
73753aa
Merge branch 'feature/cli' into main
YuliangLiu0306 Apr 13, 2022
80da77a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 15, 2022
551359c
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 18, 2022
a25697a
Revert "[CLI] add CLI launcher"
YuliangLiu0306 Apr 19, 2022
77b5704
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 19, 2022
e23d33e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 20, 2022
997c625
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 23, 2022
961d950
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
2deaa40
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
9ff217f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 28, 2022
501dc1a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 12, 2022
21e43fd
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 21, 2022
cbd4579
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 23, 2022
1443291
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 30, 2022
e627cf5
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 10, 2022
289316e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
689e047
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
0a83919
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 17, 2022
98c1ef9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 20, 2022
9a3af67
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 21, 2022
7700793
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 28, 2022
3c77d1f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 30, 2022
7c10323
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 4, 2022
11711d1
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 6, 2022
cee6276
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 8, 2022
8d00be0
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
af2a8f9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
9745566
[fx] add balanced policy v2
YuliangLiu0306 Jul 12, 2022
8a7e281
add unittest
YuliangLiu0306 Jul 14, 2022
e444bc5
Merge branch 'main' into feature/balanced_policy_v2
YuliangLiu0306 Jul 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion colossalai/fx/passes/adding_split_node_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def pipe_split():


def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# TODO(lyl): balanced policy V2, split module by node size(weight+bias+output)
"""
In balanced_split_pass, we split module by the size of parameters(weights+bias).
"""
mod_graph = gm.graph
total_param_amount = 0
for param in mod_graph.owning_module.parameters():
Expand Down Expand Up @@ -39,6 +41,36 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm


def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
    """
    Split ``gm`` into ``pp_size`` pipeline stages balanced by node size
    (weights + bias + outputs), inserting ``pipe_split`` marker nodes.

    Requires the MetaInfoProp interpreter to have been run first so that
    every node carries ``meta['tensor_meta']`` and a ``node_size`` attribute.
    If the first node lacks meta info, this pass falls back to the
    parameter-size-based ``balanced_split_pass``.

    Args:
        gm: the traced module to partition; mutated in place.
        pp_size: number of pipeline stages to produce.

    Returns:
        The same ``gm`` with ``pipe_split`` call_function nodes inserted
        and the module recompiled.
    """
    mod_graph = gm.graph

    # Fall back to the normal balanced split pass when meta info is absent.
    check_node = list(mod_graph.nodes)[0]
    if 'tensor_meta' not in check_node.meta:
        return balanced_split_pass(gm, pp_size)

    # Guard: zero/one stage means no split points are needed; also avoids
    # a ZeroDivisionError below when pp_size == 0.
    if pp_size <= 1:
        gm.recompile()
        return gm

    total_element_size = 0
    for node in mod_graph.nodes:
        total_element_size += node.node_size

    # Target size per stage; the last stage absorbs any remainder.
    partition_size = total_element_size // pp_size
    accumulate_node_size = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            # Remaining nodes all belong to the final stage.
            break
        accumulate_node_size += node.node_size
        if accumulate_node_size >= partition_size:
            accumulate_node_size = 0
            pp_size -= 1
            # Mark a stage boundary right after the node that filled the bucket.
            with mod_graph.inserting_after(node):
                mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm


def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
mod_graph = gm.graph
valid_children_size = 0
Expand Down
21 changes: 19 additions & 2 deletions colossalai/fx/passes/meta_info_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class MetaInfoProp(torch.fx.Interpreter):

def run_node(self, n: Node) -> Any:
result = super().run_node(n)

found_tensor = False

def extract_tensor_meta(obj):
Expand All @@ -83,7 +82,25 @@ def extract_tensor_meta(obj):
n.meta['tensor_meta'] = meta
else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)

# counting the total size of node outputs
total_node_size = 0
if isinstance(n.meta['tensor_meta'], TensorMetadata):
total_node_size += n.meta['tensor_meta'].numel
else:
for element in n.meta['tensor_meta']:
assert isinstance(
element, TensorMetadata
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
total_node_size += element.numel
# counting the total size of parameters
total_param_size = 0
if n.op == 'call_module':
target_module = n.graph.owning_module.get_submodule(n.target)
for param in target_module.parameters():
total_param_size += param.numel()

total_node_size += total_param_size
n.node_size = total_node_size
n.meta['type'] = type(result)
return result

Expand Down
4 changes: 3 additions & 1 deletion tests/test_fx/test_pipeline_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \
uniform_split_pass
uniform_split_pass, balanced_split_pass_v2

import pytest

MODEL_DIM = 16
Expand Down Expand Up @@ -43,6 +44,7 @@ def test_pipeline_passes():
model = MLP(MODEL_DIM)
data = torch.rand(BATCH_SIZE, MODEL_DIM)
pipeline_pass_test_helper(model, data, balanced_split_pass)
pipeline_pass_test_helper(model, data, balanced_split_pass_v2)
pipeline_pass_test_helper(model, data, uniform_split_pass)


Expand Down