11 changes: 9 additions & 2 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -4,6 +4,7 @@
 from torch import Tensor
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,
@@ -131,8 +132,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
         fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
-            fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we calculate the memory usage directly from the statistics hooked into node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+
             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)
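Note: the branch added above distinguishes nodes inserted by the runtime apply pass (runtime_apply, runtime_comm_spec_apply), whose memory usage is read straight from the statistics hooked into node.meta, from ordinary profiled nodes. A minimal standalone sketch of that accounting logic, using plain dicts as stand-ins for torch.fx Node.meta (the field names mirror the diff; the helper and the is_runtime_apply flag are illustrative, not the repo's API):

    from typing import Dict, List, Tuple

    Meta = Dict[str, float]  # hypothetical stand-in for a profiled Node.meta

    def extract_region_memory(metas: List[Meta]) -> Tuple[float, float]:
        """Accumulate saved-activation memory (xbar) and track the forward peak."""
        xbar, fwd_mem_peak = 0.0, 0.0
        for meta in metas:
            if meta.get("is_runtime_apply"):
                # Runtime-apply nodes: take the output size recorded in meta directly.
                xbar += meta["fwd_mem_out"]
                fwd_mem_peak = max(fwd_mem_peak, xbar + meta["fwd_mem_tmp"])
            else:
                # Ordinary nodes: profiled temporary + output activation sizes.
                xbar += meta["fwd_tmp"] + meta["fwd_out"]
                fwd_mem_peak = max(fwd_mem_peak, xbar + meta["fwd_mem_tmp"] + meta.get("unused_out", 0.0))
        return xbar, fwd_mem_peak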
1 change: 1 addition & 0 deletions colossalai/auto_parallel/passes/meta_info_prop.py
@@ -151,6 +151,7 @@ def node_handler(self, node: Node) -> None:
         # fetch other memory information
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
+        graph_info.fwd_mem_out = memory_cost.fwd.activation
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
         graph_info.bwd_mem_out = memory_cost.bwd.activation

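Note: the new line records the forward activation size on graph_info so the rotor solver's runtime-apply branch above can read it as fwd_mem_out. A hedged sketch of the same mapping with simplified dataclasses (field names follow the diff; the classes themselves are illustrative):

    from dataclasses import dataclass, field

    @dataclass
    class PhaseCost:
        temp: float = 0.0        # transient workspace, freed after the op
        activation: float = 0.0  # activation output kept alive for backward

    @dataclass
    class MemoryCost:
        fwd: PhaseCost = field(default_factory=PhaseCost)
        bwd: PhaseCost = field(default_factory=PhaseCost)

    @dataclass
    class GraphInfo:
        fwd_mem_tmp: float = 0.0
        fwd_mem_out: float = 0.0
        bwd_mem_tmp: float = 0.0
        bwd_mem_out: float = 0.0

    def fill_graph_info(memory_cost: MemoryCost) -> GraphInfo:
        info = GraphInfo()
        info.fwd_mem_tmp = memory_cost.fwd.temp
        info.fwd_mem_out = memory_cost.fwd.activation  # newly recorded field
        info.bwd_mem_tmp = memory_cost.bwd.temp
        info.bwd_mem_out = memory_cost.bwd.activation
        return info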
4 changes: 2 additions & 2 deletions colossalai/fx/profiler/shard_utils.py
@@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
         fwd_time (float): the result of `fwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["fwd_flop"]
+    return n.meta["fwd_time"]
 
 
 def calculate_bwd_time(n: Node) -> float:
@@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
         bwd_time (float): the result of `bwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["bwd_flop"]
+    return n.meta["bwd_time"]