Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
06f8991
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
3cd7d22
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
0849b3b
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 10, 2022
701786c
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
a75e5a2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
c20beb2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
7e87286
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
f027931
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 12, 2022
9b4f460
[fx] merge development into main (#1)
super-dainiu Aug 12, 2022
bea7060
[fx] add rules to linearize computation graphs for searching. (#2)
super-dainiu Aug 16, 2022
86c005d
[fx] merge
super-dainiu Aug 16, 2022
da259cc
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
296b405
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
bf7feea
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
e6c5f70
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
0cbafd8
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 16, 2022
8e14703
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
92e8223
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
3e9531c
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
02c5cae
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
a8616ef
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
083cf7f
[fx] fix inconsistencies.
super-dainiu Aug 17, 2022
9c7441e
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
76f55b7
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
2c8a827
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
b1afd09
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 17, 2022
ff71edc
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
c90d14a
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
ea7250b
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
77406fe
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
0da5d29
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
98ddce6
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
c4dbb99
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
7719e4f
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
3a82865
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
5075e45
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
08e6d73
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
9091c84
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
236c52e
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
7a03047
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
408b1d6
[fx] fix error in tests.
super-dainiu Aug 24, 2022
fe3e098
[fx] unfix bug.
super-dainiu Aug 24, 2022
37fad8c
[fx] unfix bug.
super-dainiu Aug 24, 2022
0d6e467
Merge branch 'hpcaitech:main' into feature/profiler
super-dainiu Aug 24, 2022
2249598
[fx] patch more modules and functions.
super-dainiu Aug 24, 2022
b5cc75a
[fx] change name of utils.py to profiler.py
super-dainiu Aug 24, 2022
bb8c272
[fx] add profiler for rnn.
super-dainiu Aug 25, 2022
c628b1f
[fx] add profiler for rnn.
super-dainiu Aug 25, 2022
80be7f7
Merge branch 'feature/profiler' of http://github.com/super-dainiu/Col…
super-dainiu Aug 25, 2022
45516cf
[fx] merge
super-dainiu Aug 25, 2022
e09ccb8
[fx] polish and add more patch for profiler.
super-dainiu Aug 25, 2022
f3b6464
[fx] polish and add more patch for profiler.
super-dainiu Aug 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion colossalai/fx/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .registry import *
from .profiler_function import *
from .profiler_module import *
from .utils import *
from .profiler import *
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import partial
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, NamedTuple, Any, Dict, Tuple
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
import torch
from torch.fx.node import Argument, Target
from torch.fx.node import Argument, Target, map_aggregate
from torch.fx._compatibility import compatibility
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
from . import meta_profiler_function, meta_profiler_module
Expand All @@ -12,6 +12,30 @@
'calculate_param_size'
]

CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_function

@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_module

@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""

# TODO fill out the inplace ops
INPLACE_OPS = [
add,
Expand All @@ -22,18 +46,30 @@
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]

# TODO check that call_methods are indeed inplace
# TODO: list all call_methods that are inplace here
# TODO: list all call_methods that are inplace here
# NOTE(review): methods listed here are treated by the profiler as producing no
# new activation memory (views of the input). 'reshape' may actually return a
# copy when the tensor is non-contiguous — see the TODO below; confirm before
# relying on its activation size being zero.
INPLACE_METHOD = [
    'transpose',
    'permute',
    # TODO: reshape may return a copy of the data if the data is not contiguous
    'reshape',
    'dim',
    'flatten',
]

# TODO: list all call_methods that are not inplace here
# NOTE(review): methods listed here allocate a new tensor, so the profiler
# charges the full activation size of their result.
NON_INPLACE_METHOD = [
    'expand',
    'mean',
]


@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):

# MetaProfile is a structure containing pertinent information
# about a node within a torch.fx GraphModule.

Expand All @@ -43,25 +79,35 @@ class MetaProfile(NamedTuple):
macs: int


def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.

    Recursively sums the byte size of every ``torch.Tensor`` found inside
    ``activation``, descending into dicts (values only), lists and tuples.
    Non-tensor leaves (e.g. ``int``) contribute zero bytes.

    Args:
        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation
            of a `torch.nn.Module` or `torch.nn.functional`.

    Returns:
        int: The activation size in bytes.
    """
    activation_size = 0
    if isinstance(activation, torch.Tensor):
        # element_size() of an empty tensor of the same dtype gives the
        # per-element byte width without materializing any data
        activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
    elif isinstance(activation, dict):
        # only the values can hold tensors; keys are assumed to be plain metadata
        activation_size += calculate_activation_size(list(activation.values()))
    elif isinstance(activation, (tuple, list)):
        for element in activation:
            activation_size += calculate_activation_size(element)
    # any other leaf type (int, None, str, ...) contributes nothing
    return activation_size


def calculate_param_size(mod: torch.nn.Module) -> int:
"""
Calculate param size of a node.
"""Calculate param size of a node.

Args:
mod (torch.nn.Module): The target `torch.nn.Module`

Returns:
int: The param size
"""
param_size = 0
for param in mod.parameters():
Expand All @@ -78,17 +124,21 @@ def profile_function(target: 'Target') -> Callable:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.

Usage:
input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
output, profile = profile_function(func)(input, inplace=False)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Examples:
>> input = torch.rand(100, 100, 100, 100, device='meta')
>> func = torch.nn.functional.relu
>> output, profile = profile_function(func)(input, inplace=False)
>> print(f"Profiling function {func},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function <function relu at 0x7fcdd0258d30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
"""

def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."
target.__name__), CALL_FUNCTION_MSG.format(target)
# ensure all arguments satisfy `device='meta'`
args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)

# call_function has no parameters
param_size = 0
Expand Down Expand Up @@ -127,14 +177,17 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args

# Execute the method and return the result
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'

# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
param_size = 0
activation_size = 0
activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
flops = 0
macs = 0
return result, MetaProfile(param_size, activation_size, flops, macs)
Expand All @@ -151,17 +204,20 @@ def profile_module(module: torch.nn.Module) -> Callable:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.

Usage:
input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)
output, profile = profile_module(mod)(input)
print(f"Profiling function {mod},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Example:
>> input = torch.rand(4, 3, 224, 224, device='meta')
>> mod = torch.nn.Conv2d(3, 128, 3)
>> output, profile = profile_module(mod)(input)
>> print(f"Profiling function {mod},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
"""

def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(
type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
torch.nn.functional.elu: 4,
torch.nn.functional.relu6: 2,
torch.nn.functional.gelu: 9,
torch.nn.functional.hardswish: 5,
torch.nn.functional.hardsigmoid: 4,
}


Expand All @@ -23,6 +25,8 @@
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
@meta_profiler_function.register(torch.nn.functional.hardswish)
@meta_profiler_function.register(torch.nn.functional.hardsigmoid)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
flops = input.numel()
macs = 0
Expand Down
26 changes: 14 additions & 12 deletions colossalai/fx/profiler/profiler_function/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
import operator
from functools import reduce
from typing import Any, Optional, Tuple, Union
import torch
from ..registry import meta_profiler_function


def _prod(dims):
p = 1
for v in dims:
p *= v
return p


def _elementwise_flops_compute(input, other):
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
if not torch.is_tensor(input):
if torch.is_tensor(other):
return _prod(other.shape), 0
return reduce(operator.mul, other.shape), 0
else:
return 1, 0
elif not torch.is_tensor(other):
return _prod(input.shape), 0
return reduce(operator.mul, input.shape), 0
else:
dim_input = len(input.shape)
dim_other = len(other.shape)
Expand All @@ -32,17 +27,24 @@ def _elementwise_flops_compute(input, other):
final_shape.append(in_i)
else:
final_shape.append(ot_i)
flops = _prod(final_shape)
flops = reduce(operator.mul, final_shape)
return flops, 0


@meta_profiler_function.register(torch.add)
@meta_profiler_function.register(torch.eq)
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add')    # for built-in op +
@meta_profiler_function.register('iadd')    # for built-in op +=
@meta_profiler_function.register('eq')    # for built-in op ==
@meta_profiler_function.register('sub')    # for built-in op -
@meta_profiler_function.register('isub')    # for built-in op -=
@meta_profiler_function.register('mul')    # for built-in op *
@meta_profiler_function.register('imul')    # for built-in op *=
@meta_profiler_function.register('floordiv')    # for built-in op //
@meta_profiler_function.register('ifloordiv')    # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """Profile elementwise binary ops (add/sub/mul/floordiv/eq and their
    in-place variants): delegates to the broadcast-aware elementwise FLOPs
    estimator and returns ``(flops, macs)``; MACs are always 0 here.
    """
    return _elementwise_flops_compute(input, other)

Expand All @@ -58,14 +60,14 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
macs = reduce(operator.mul, input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs


@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """Profile ``torch.bmm``: batched matrix multiply of (b, n, m) @ (b, m, p).

    MACs are the total number of multiply-accumulates, i.e. the product of the
    input's dimensions times the last dimension of ``other``; each MAC counts
    as two FLOPs (one multiply, one add).

    Returns:
        Tuple[int, int]: ``(flops, macs)``.
    """
    # NOTE(review): the scraped diff carried both the old `_prod(...)` line and
    # its `reduce(operator.mul, ...)` replacement; only the replacement is kept.
    macs = reduce(operator.mul, input.shape) * other.shape[-1]
    flops = 2 * macs
    return flops, macs

Expand Down
7 changes: 7 additions & 0 deletions colossalai/fx/profiler/profiler_function/python_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs


@meta_profiler_function.register(getattr)
def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
    """Profile the built-in ``getattr``: plain attribute access performs no
    arithmetic, so it contributes zero FLOPs and zero MACs.
    """
    return 0, 0
12 changes: 4 additions & 8 deletions colossalai/fx/profiler/profiler_function/torch_ops.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from functools import reduce
import operator
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function


def _prod(dims):
p = 1
for v in dims:
p *= v
return p


@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
Expand All @@ -31,6 +26,7 @@ def _prod(dims):
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
@meta_profiler_function.register(torch._assert)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
flops = 0
macs = 0
Expand All @@ -57,7 +53,7 @@ def torch_max(input: torch.Tensor,
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
flops = _prod(shape), macs
flops = reduce(operator.mul, shape), macs
return flops, macs
else:
flops = input.numel()
Expand Down
3 changes: 3 additions & 0 deletions colossalai/fx/profiler/profiler_module/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from .activation_function import *
from .attention import *
from .convolution import *
from .dropout import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
from .torch_op import *
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
torch.nn.ELU: 4,
torch.nn.ReLU6: 2,
torch.nn.GELU: 9,
torch.nn.Hardswish: 5,
torch.nn.Hardsigmoid: 4,
}


Expand All @@ -23,6 +25,8 @@
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
@meta_profiler_module.register(torch.nn.Hardswish)
@meta_profiler_module.register(torch.nn.Hardsigmoid)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = input.numel()
macs = 0
Expand Down
Loading