Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions colossalai/fx/passes/adding_split_node_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,40 @@ def pipe_split():
pass


def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """
    Split the graph into pipeline stages by accumulated forward FLOPs.

    Nodes are walked in graph order while their ``fwd_flop`` values are
    accumulated. Whenever the running sum reaches the current per-stage
    budget, a ``pipe_split`` call node is inserted after the current node and
    the budget is recomputed from the FLOPs and stages that are still left.

    This pass reads ``node.fwd_flop``, so the meta_info_prop interpreter has to
    be run first. If the first node carries no ``'tensor_meta'`` entry in its
    ``meta``, the pass falls back to ``balanced_split_pass``.

    Args:
        gm: graph module whose graph is modified in place.
        pp_size: number of pipeline stages. Values ``<= 1`` insert no split.

    Returns:
        The result of ``balanced_split_pass`` on fallback, otherwise ``gm``
        after recompilation.
    """
    mod_graph = gm.graph
    # To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
    # If nodes don't have meta info, this pass falls back to the normal balanced split pass.
    check_node = list(mod_graph.nodes)[0]
    if 'tensor_meta' not in check_node.meta:
        return balanced_split_pass(gm, pp_size)

    # A single stage (or a non-positive stage count) needs no split point.
    # Checking this up front also keeps the budget division below from
    # dividing by zero when pp_size is 0.
    if pp_size > 1:
        total_fwd_flop = sum(node.fwd_flop for node in mod_graph.nodes)
        partition_flop = total_fwd_flop // pp_size
        accumulate_fwd_flop = 0
        for node in mod_graph.nodes:
            # pp_size counts the stages still to be delimited; once only one
            # is left, the rest of the graph belongs to it.
            if pp_size <= 1:
                break
            # Split markers inserted by this loop are visited on later
            # iterations; they are not counted towards any stage.
            if 'pipe_split' in node.name:
                continue
            accumulate_fwd_flop += node.fwd_flop
            if accumulate_fwd_flop >= partition_flop:
                # Re-balance: spread the remaining FLOPs over the remaining stages.
                total_fwd_flop -= accumulate_fwd_flop
                accumulate_fwd_flop = 0
                pp_size -= 1
                partition_flop = total_fwd_flop // pp_size
                with mod_graph.inserting_after(node):
                    mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm


def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simply split the graph by node number.
Expand Down Expand Up @@ -104,8 +138,10 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
total_element_size = total_element_size - accumulate_node_size
accumulate_node_size = 0
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
Expand Down
3 changes: 2 additions & 1 deletion colossalai/fx/passes/meta_info_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def extract_tensor_meta(obj):
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
n.meta['type'] = type(result)

# retain the autograd graph
Expand Down