Skip to content
21 changes: 14 additions & 7 deletions colossalai/tensor/shape_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int,
if discard_input:
alloc_numel -= input_numel

return alloc_numel, peak_numel

def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze split memory footprint
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
Expand Down Expand Up @@ -478,11 +480,13 @@ def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, p
# kind of weird, and I think we could ignore it for now.
pass

return alloc_numel, peak_numel

def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
    """Dummy memory footprint analysis for the reduce action.

    The reduce action performs its communication in place and therefore
    allocates no extra memory, so the running totals are passed through
    unchanged.

    Args:
        comm_spec (CommSpec): communication spec of this action (unused here).
        discard_input (bool): whether the input tensor may be discarded (unused here).
        alloc_numel (int): running total of currently allocated elements.
        peak_numel (int): running total of peak allocated elements.

    Returns:
        Tuple[int, int]: ``(alloc_numel, peak_numel)``, unchanged.
    """
    # NOTE: this must return the totals rather than fall through with a bare
    # `pass` (implicitly returning None) — callers unpack the result as
    # `alloc_numel, peak_numel = action(...)`, which would raise TypeError
    # on a None return.
    return alloc_numel, peak_numel

def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_to_all memory footprint
Expand All @@ -508,11 +512,13 @@ def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int,
if discard_input:
alloc_numel -= input_numel

return alloc_numel, peak_numel

def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
    """Dummy memory footprint analysis for the identity action.

    The identity action moves no data and allocates no extra memory, so the
    running totals are passed through unchanged.

    Args:
        comm_spec (CommSpec): communication spec of this action (unused here).
        discard_input (bool): whether the input tensor may be discarded (unused here).
        alloc_numel (int): running total of currently allocated elements.
        peak_numel (int): running total of peak allocated elements.

    Returns:
        Tuple[int, int]: ``(alloc_numel, peak_numel)``, unchanged.
    """
    # NOTE: returning the totals (instead of a bare `pass`, i.e. None) is
    # required so callers can unpack the result like every other *_analysis
    # function in this module.
    return alloc_numel, peak_numel

pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
Expand All @@ -539,17 +545,18 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
if idx == 0:
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
else:
fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
fwd_peak_numel) if idx == 0 else fwd_action(
comm_spec, True, fwd_alloc_numel, fwd_peak_numel)

# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
bwd_action, comm_spec = action_spec_pair
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
bwd_peak_numel) if idx == 0 else bwd_action(
comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
Expand Down