51 changes: 22 additions & 29 deletions colossalai/fx/passes/meta_info_prop.py
@@ -1,10 +1,12 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size


@compatibility(is_backward_compatible=True)
@@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with meta tensor and
record the shape, FLOPs, MACs and type of the result
record the memory usage, FLOPs, and type of the result
into the corresponding node.

Usage:
@@ -82,29 +84,28 @@ def run_node(self, n: Node) -> Any:
Returns:
Any: The result of executing ``n``
"""
result, flop_count, mem_stat = super().run_node(n)
result, meta_info = super().run_node(n)

def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
return _extract_tensor_metadata(obj)
else:
return TensorMetadata(None, None, False, None, 0, False)

meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = meta
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`

# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
n.meta['type'] = type(result)

# retain the autograd graph
for param in self.module.parameters():
param.grad = None

return result

# Main Node running APIs
@@ -125,12 +126,9 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Returns:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, (0, 0), (0, activation_size(result), 0, 0)
return super().placeholder(target, args, kwargs), GraphInfo()

@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -147,10 +145,9 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st

Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
return super().get_attr(target, args, kwargs), GraphInfo()

@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -166,8 +163,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di

Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@@ -186,8 +182,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return profile_method(target)(*args, **kwargs)

@@ -205,8 +200,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
@@ -229,10 +223,9 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,

Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], (0, 0), (0, 0, 0, 0)
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))

def propagate(self, *args):
"""
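For orientation, here is a minimal sketch of how the updated pass might be driven after this change; the toy model, the meta-device input, and the printed fields are assumptions for illustration and are not part of this diff.

import torch
import torch.fx
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

# Hypothetical toy model, used only to exercise the pass.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)

# Propagate meta information with a meta-device sample input
# (whether `propagate` expects a raw meta tensor or a MetaTensor wrapper is assumed here).
MetaInfoProp(gm).propagate(torch.empty(4, 16, device='meta'))

for node in gm.graph.nodes:
    # After the pass, the GraphInfo fields are merged into node.meta.
    print(node.name, node.meta.get('fwd_flop', 0), node.meta.get('fwd_mem_out', 0))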
3 changes: 2 additions & 1 deletion colossalai/fx/profiler/__init__.py
@@ -2,8 +2,9 @@
if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module, _profile
from .profiler import profile_function, profile_method, profile_module
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

from .dataflow import GraphInfo
from .memory import parameter_size, activation_size
136 changes: 136 additions & 0 deletions colossalai/fx/profiler/dataflow.py
@@ -0,0 +1,136 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size


class Stage(Enum):
FORWARD = 0
LOSS = 1
BACKWARD = 2
PLACEHOLDER = 3


@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
fwd_mem_tmp: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0


def is_forward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.FORWARD


def is_loss(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.LOSS


def is_placeholder(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.PLACEHOLDER


def is_backward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.BACKWARD


def is_saved(n: Node):
return n.meta.get('saved', False)


def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
The input graph should have every node tagged with one of Stage.FORWARD, Stage.LOSS, Stage.BACKWARD, or Stage.PLACEHOLDER under the meta key `stage`.
Nodes should also carry a meta entry `out` holding the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
|\________ |
↓ ↘|
f --------> b
|\ \_____ ↑
| \ ↘ /
f f ----> b <---- Not every forward result needs to be saved for backward
| \____ ↑
↘ ↘|
f ----> b <---- Backward results can be freed as soon as they are no longer required.
↘ ↗
l
=============================================================================
Args:
graph (Graph): The autograd graph whose nodes are tagged with a Stage value under the meta key `stage`.

Returns:
graph_info (GraphInfo): Meta information for the dataflow.
"""

def _peak_memory(deps: Dict[Node, int]):
bwd_tmp = 0
for k, v in deps.items():
if v > 0:
bwd_tmp += activation_size(k.meta['out'])
return bwd_tmp

# deps is used to track all the memory dependencies of the graph.
deps = {}
graph_info = GraphInfo()

for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(is_loss, n.users)):
# A forward tensor that is marked `saved` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_in`.
# Any `fwd_in` should be kept in memory even if this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
# the node, `fwd_tmp` can be freed.
if is_placeholder(n):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_forward(n):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_backward(n):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
else:
# a backward node without users is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
return graph_info
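To make the expected annotations concrete, here is a minimal, hand-built sketch of a graph that `autograd_graph_analysis` could consume; the node construction and meta values below are assumptions for illustration only, not taken from the PR.

import torch
from torch.fx import Graph
from colossalai.fx.profiler.dataflow import Stage, autograd_graph_analysis

g = Graph()
# A saved placeholder counts towards fwd_mem_in.
x = g.placeholder('x')
x.meta = {'stage': Stage.PLACEHOLDER, 'saved': True, 'out': torch.empty(4, 4, device='meta')}
# A saved forward activation counts towards fwd_mem_tmp.
f = g.call_function(torch.relu, (x,))
f.meta = {'stage': Stage.FORWARD, 'saved': True, 'out': torch.empty(4, 4, device='meta')}
# A backward node with no users is treated as `grad_out` and counts towards bwd_mem_out.
b = g.call_function(torch.relu, (f,))
b.meta = {'stage': Stage.BACKWARD, 'out': torch.empty(4, 4, device='meta')}

info = autograd_graph_analysis(g)
print(info.fwd_mem_in, info.fwd_mem_tmp, info.bwd_mem_out)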
45 changes: 42 additions & 3 deletions colossalai/fx/profiler/experimental/profiler.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
@@ -6,6 +7,44 @@

__all__ = ['profile_function', 'profile_module', 'profile_method']


# this is for compatibility use
@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
fwd_mem_tmp: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0


CALL_FUNCTION_MSG = \
"""
Colossal-AI does not yet support profiling for {}. You may patch it manually with the following code.\n
@@ -59,7 +98,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
else:
profiler = meta_profiler_function.get(target.__name__)
fwd_flop, _ = profiler(*args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

f.__name__ = target.__name__
func = target
@@ -88,7 +127,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# call_method has no parameters, is MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

return f

@@ -118,7 +157,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

f.__name__ = module.__class__.__name__
func = module.forward
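As a rough usage sketch of this compatibility path (assuming `torch.nn.functional.relu` is registered with `meta_profiler_function`; the call below is illustrative, not taken from the PR):

import torch
from colossalai.fx.profiler.experimental import profile_function

# The wrapper now returns the result together with a GraphInfo record
# instead of the old flop/memory tuples.
out, info = profile_function(torch.nn.functional.relu)(torch.rand(2, 8))
print(info.fwd_flop, info.fwd_mem_out)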
2 changes: 0 additions & 2 deletions colossalai/fx/profiler/memory.py
@@ -14,12 +14,10 @@

INPLACE_ATEN = [
aten.add_.Tensor,
aten.add.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.mul.Tensor,
aten.bernoulli_.float,

# inplace reshaping