Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e0edb21
[fx] compute memory stat and flop count for MetaInfoProp.
super-dainiu Sep 1, 2022
579b70b
Merge branch 'hpcaitech:main' into feature/flop_tensor
super-dainiu Sep 1, 2022
7f3a532
Merge branch 'hpcaitech:main' into feature/flop_tensor
super-dainiu Sep 2, 2022
17be5a5
[fx] modify node attribute.
super-dainiu Sep 2, 2022
36c93ca
[fx] modify ckpt_chen.
super-dainiu Sep 2, 2022
d6dcd80
[fx] fix compatibility.
super-dainiu Sep 2, 2022
01de3f1
[fx] fix import error.
super-dainiu Sep 2, 2022
e51f32a
[fx] skip test for MetaInfoProp.
super-dainiu Sep 2, 2022
a9193b5
[fx] skip test for MetaInfoProp.
super-dainiu Sep 2, 2022
97cfba9
[fx] skip test for MetaInfoProp.
super-dainiu Sep 2, 2022
c992b25
[fx] skip test for MetaInfoProp.
super-dainiu Sep 2, 2022
62d7096
[fx] skip if torch 1.11.0.
super-dainiu Sep 2, 2022
b2e8f6a
[fx] seek to solve incompatibilities.
super-dainiu Sep 5, 2022
f80735f
[fx] recover MetaInfoProp support for PyTorch 1.11.
super-dainiu Sep 5, 2022
0d4a030
[fx] provide a stable but not accurate enough version of profiler.
super-dainiu Sep 6, 2022
e7a33a6
Merge branch 'hpcaitech:main' into feature/flop_tensor
super-dainiu Sep 6, 2022
9223090
[fx] provide a stable but not accurate enough version of profiler.
super-dainiu Sep 6, 2022
1e49efe
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
77045db
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
1c127c6
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
6f486d1
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
c1eb532
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
b3d144e
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
15a1bb1
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
d011148
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
1a45a12
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
cf06d84
[fx] fix compatibility in tests.
super-dainiu Sep 6, 2022
359f750
[fx] fix import error.
super-dainiu Sep 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion colossalai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
try:
from ._meta_registrations import *
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False
Comment on lines +2 to +6
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

META_COMPATIBILITY is checked when Colossal-AI initializes.

print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)
Expand Down
20 changes: 20 additions & 0 deletions colossalai/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
return grad_in


@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
    # Meta kernel: hardtanh's input gradient matches the input's shape/dtype.
    return torch.empty_like(input)
Comment on lines +184 to +187
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only some extra registrations in this file.



@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    # Meta kernel: rolling only permutes elements, so the output mirrors the input.
    out = torch.empty_like(input)
    return out
Expand Down Expand Up @@ -321,3 +327,17 @@ def meta_index_Tensor(self, indices):
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)


@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                  scale_grad_by_freq):
    # Meta kernel: the weight gradient is (num_weights, embedding_dim) and
    # inherits dtype/device/layout from the incoming gradient.
    embedding_dim = grad_output.size(-1)
    return torch.empty((num_weights, embedding_dim),
                       dtype=grad_output.dtype,
                       device=grad_output.device,
                       layout=grad_output.layout)


@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
    """Meta kernel for ``aten.where.self``.

    ``torch.where(condition, self, other)`` broadcasts all three operands and
    takes the promoted dtype of ``self`` and ``other`` — NOT the (typically
    bool) dtype of ``condition``.  Mirroring ``condition`` via
    ``empty_like(condition)`` gave meta tensors a wrong dtype and ignored
    broadcasting, which would corrupt every downstream size estimate.
    """
    result_type = torch.result_type(self, other)
    shape = torch.broadcast_shapes(condition.shape, self.shape, other.shape)
    return torch.empty(shape, dtype=result_type, device=self.device, layout=self.layout)
4 changes: 2 additions & 2 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, '__activation__')
temp += getattr(n, 'fwd_out')
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, '__activation__')
x += getattr(n, 'fwd_out')
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
Expand Down
59 changes: 29 additions & 30 deletions colossalai/fx/passes/meta_info_prop.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size


@compatibility(is_backward_compatible=True)
Expand Down Expand Up @@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter):

"""

@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)

@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
Expand All @@ -93,8 +82,7 @@ def run_node(self, n: Node) -> Any:
Returns:
Any: The result of executing ``n``
"""
result, profile = super().run_node(n)
profile: MetaProfile
result, flop_count, mem_stat = super().run_node(n)

def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
Expand All @@ -106,12 +94,17 @@ def extract_tensor_meta(obj):
n.meta['tensor_meta'] = meta

# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', profile.param + profile.activation)
setattr(n, '__param__', profile.param)
setattr(n, '__activation__', profile.activation)
setattr(n, '__flops__', profile.flops)
setattr(n, '__macs__', profile.macs)
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
n.meta['type'] = type(result)

for param in self.module.parameters():
param.grad = None
Comment on lines +106 to +107
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously, we need to clear grad of the parameter, because these grads are meta

return result

# Main Node running APIs
Expand All @@ -132,11 +125,12 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
return result, (0, 0), (0, activation_size(result), 0, 0)

@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
Expand All @@ -153,10 +147,10 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st

Return:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)

@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
Expand All @@ -172,7 +166,8 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di

Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
Expand All @@ -191,7 +186,8 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return profile_method(target)(*args, **kwargs)

Expand All @@ -209,7 +205,8 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict

Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
Expand All @@ -231,9 +228,11 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
Any: The return value referenced by the output node
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return args[0], MetaProfile(0, 0, 0, 0)
return args[0], (0, 0), (0, 0, 0, 0)

def propagate(self, *args):
"""
Expand Down
14 changes: 9 additions & 5 deletions colossalai/fx/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import *
from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module, _profile
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

from .memory import parameter_size, activation_size
4 changes: 4 additions & 0 deletions colossalai/fx/profiler/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module
125 changes: 125 additions & 0 deletions colossalai/fx/profiler/experimental/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Callable, Any, Dict, Tuple
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the old one, so I did not modify anything except for the output format.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

old for PyTorch 1.11

import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

__all__ = ['profile_function', 'profile_module', 'profile_method']

CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""


def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.
    Unfortunately, backward memory cost and FLOPs are estimated results.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
    """
    func = target

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # A profiler must be registered either for the callable or for its name.
        assert meta_profiler_function.has(target) or meta_profiler_function.has(
            target.__name__), CALL_FUNCTION_MSG.format(target)

        out = func(*args, **kwargs)
        # Inplace calls write into their input and allocate no new activation;
        # otherwise the output tensor(s) are the activation to keep.
        is_inplace = target in INPLACE_OPS or kwargs.get('inplace', False)
        fwd_out = 0 if is_inplace else activation_size(out)
        fwd_tmp = 0
        key = target if meta_profiler_function.has(target) else target.__name__
        fwd_flop, _ = meta_profiler_function.get(key)(*args, **kwargs)
        # Backward FLOPs are estimated at twice the forward FLOPs.
        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = target.__name__
    return f


def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node
    record the memory cost and FLOPs of the execution.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # execute the method and return the result
        assert isinstance(target, str), f'{target} instance is not str.'

        out = getattr(self_obj, target)(*args_tail, **kwargs)
        assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
            target, INPLACE_METHOD, NON_INPLACE_METHOD)
        # call_method has no parameters and contributes no FLOPs or MACs.
        # An inplace method aliases its input and allocates nothing; a
        # non-inplace method's result is the activation to keep (`fwd_out`),
        # matching the convention used by `profile_function`.
        # NOTE(review): the previous code had these two conditions inverted,
        # charging inplace methods for their output and non-inplace ones for a
        # temporary — `ckpt_solver_chen` reads `fwd_out` as the checkpointable
        # activation, so the inversion skewed checkpoint decisions.
        fwd_tmp = 0
        fwd_out = 0 if target in INPLACE_METHOD else activation_size(out)
        return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    return f


def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # A profiler must be registered for this module class.
        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))

        fwd_tmp = 0
        fwd_out = 0
        out = func(*args, **kwargs)
        # An inplace module (e.g. ReLU(inplace=True)) writes into its input and
        # allocates no new activation; a non-inplace module's output is the
        # activation to keep — same convention as `profile_function`.
        # NOTE(review): the previous condition charged `fwd_out` only for
        # inplace modules, i.e. it was inverted relative to `profile_function`.
        if not getattr(module, 'inplace', False):
            fwd_out = activation_size(out)
        profiler = meta_profiler_module.get(type(module))
        fwd_flop, _ = profiler(module, *args, **kwargs)
        # Backward FLOPs are estimated at twice the forward FLOPs.
        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f
Loading