Merged

Changes from all commits (43 commits)
06f8991
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
3cd7d22
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
0849b3b
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 10, 2022
701786c
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
a75e5a2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
c20beb2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
7e87286
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
f027931
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 12, 2022
9b4f460
[fx] merge development into main (#1)
super-dainiu Aug 12, 2022
bea7060
[fx] add rules to linearize computation graphs for searching. (#2)
super-dainiu Aug 16, 2022
86c005d
[fx] merge
super-dainiu Aug 16, 2022
da259cc
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
296b405
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
bf7feea
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
e6c5f70
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
0cbafd8
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 16, 2022
8e14703
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
92e8223
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
3e9531c
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
02c5cae
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
a8616ef
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
083cf7f
[fx] fix inconsistencies.
super-dainiu Aug 17, 2022
9c7441e
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
76f55b7
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
2c8a827
[fx] fix MetaInfoProp.
super-dainiu Aug 17, 2022
b1afd09
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 17, 2022
ff71edc
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
c90d14a
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
ea7250b
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
77406fe
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
0da5d29
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
98ddce6
[fx] consider MetaInfoProp for inplace operands.
super-dainiu Aug 18, 2022
c4dbb99
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
7719e4f
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
3a82865
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
5075e45
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
08e6d73
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
9091c84
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
236c52e
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
7a03047
[fx] add profiler for fx nodes.
super-dainiu Aug 23, 2022
408b1d6
[fx] fix error in tests.
super-dainiu Aug 24, 2022
fe3e098
[fx] unfix bug.
super-dainiu Aug 24, 2022
37fad8c
[fx] unfix bug.
super-dainiu Aug 24, 2022
4 changes: 2 additions & 2 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -73,10 +73,10 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
        y = 0
        prev_idx = 2
        for (idx, n) in enumerate(gm.graph.nodes):
-            temp += getattr(n, 'activation_size')
+            temp += getattr(n, '__activation__')
            y = max(y, temp)
            if temp > b and n in ckpt_nodes:
-                x += getattr(n, 'activation_size')
+                x += getattr(n, '__activation__')
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
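
For orientation (an editor's sketch, not part of the diff): `__activation__` is the per-node activation byte count that the reworked MetaInfoProp attaches (see the meta_info_prop.py changes below), and `run_chen_greedy` consumes it as a running budget. The same greedy rule over a plain list of sizes, with the `ckpt_nodes` membership test omitted for brevity:

```python
from typing import List, Tuple

def chen_greedy_sketch(activation_sizes: List[int], b: int) -> Tuple[List[Tuple[int, int]], int, int]:
    """Greedily cut a checkpoint interval whenever the running activation total exceeds the budget b."""
    ckpt_intv = []    # (start, end) node-index ranges to checkpoint
    x = 0             # bytes held by checkpointed outputs
    y = 0             # peak activation bytes between cuts
    temp = 0          # running total since the last cut
    prev_idx = 0
    for idx, size in enumerate(activation_sizes):
        temp += size
        y = max(y, temp)
        if temp > b:
            x += size
            temp = 0
            ckpt_intv.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return ckpt_intv, x, y

# budget b=10 over sizes [4, 4, 4, 4] cuts once, after the third node
print(chen_greedy_sketch([4, 4, 4, 4], b=10))    # ([(0, 3)], 4, 12)
```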
217 changes: 158 additions & 59 deletions colossalai/fx/passes/meta_info_prop.py
@@ -1,10 +1,12 @@
from operator import add, getitem
import torch
import torch.fx
-from torch.fx.node import Node, map_aggregate
+from torch.fx.node import Node, map_aggregate, Argument, Target
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
+from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method


@compatibility(is_backward_compatible=True)
@@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
    return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)


-def _compute_activation_size(node_metadata: any) -> int:
-    """
-    Compute numel of a node with ``tensor_meta`` attribute.
-    """
-    node_numel = 0
-
-    if isinstance(node_metadata, TensorMetadata):
-        node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
-    elif isinstance(node_metadata, dict):
-        value_list = [v for _, v in node_metadata.items()]
-        node_numel += _compute_activation_size(value_list)
-    else:
-        for element in node_metadata:
-            node_numel += _compute_activation_size(element)
-
-    return node_numel
-
-
-def _map_aggregate(arg, fn):
-    """
-    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
-    """
-    if isinstance(arg, torch.Size):
-        return fn(arg)
-    if isinstance(arg, tuple):
-        return tuple(map_aggregate(elem, fn) for elem in arg)
-    elif isinstance(arg, list):
-        return immutable_list(map_aggregate(elem, fn) for elem in arg)
-    elif isinstance(arg, dict):
-        return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
-    elif isinstance(arg, slice):
-        return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
-    else:
-        return fn(arg)


@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
    """
-    Execute an FX graph Node-by-Node and
-    record the shape and type of the result
+    Execute an FX graph Node-by-Node with meta tensor and
+    record the shape, FLOPs, MACs and type of the result
    into the corresponding node.

    Usage:
@@ -104,39 +70,172 @@ class MetaInfoProp(torch.fx.Interpreter):

"""

@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
for elem in args:
if isinstance(elem, torch.Tensor):
assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
return super().run(*args, initial_env, enable_io_processing)

    @compatibility(is_backward_compatible=True)
    def run_node(self, n: Node) -> Any:
-        # TODO: We might run_node(n) with meta data, and count FLOPS for each node
-        result = super().run_node(n)
+        """
+        Run a specific node ``n`` and return the result.
+        Calls into placeholder, get_attr, call_function,
+        call_method, call_module, or output depending
+        on ``node.op``
+
+        Args:
+            n (Node): The Node to execute
+
+        Returns:
+            Any: The result of executing ``n``
+        """
+        result, profile = super().run_node(n)
+        profile: MetaProfile

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                return _extract_tensor_metadata(obj)
            else:
                return TensorMetadata(None, None, False, None, 0, False)

-        meta = _map_aggregate(result, extract_tensor_meta)
+        meta = map_aggregate(result, extract_tensor_meta)
        n.meta['tensor_meta'] = meta

-        total_activation_size = 0
-        total_param_size = 0
-        if n.op == 'call_module':
-            target_module = n.graph.owning_module.get_submodule(n.target)
-            if not getattr(target_module, 'inplace', False):
-                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-            for param in target_module.parameters():
-                total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-        elif n.op == 'call_function':
-            if 'inplace' not in n.kwargs:
-                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-        else:
-            total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
-
-        setattr(n, 'node_size', total_activation_size + total_param_size)
-        setattr(n, 'param_size', total_param_size)
-        setattr(n, 'activation_size', total_activation_size)
+        # TODO: the attribute node_size should be removed in the future
+        setattr(n, 'node_size', profile.param + profile.activation)
+        setattr(n, '__param__', profile.param)
+        setattr(n, '__activation__', profile.activation)
+        setattr(n, '__flops__', profile.flops)
+        setattr(n, '__macs__', profile.macs)
        n.meta['type'] = type(result)
        return result

+    # Main Node running APIs
+    @compatibility(is_backward_compatible=True)
+    def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``placeholder`` node. Note that this is stateful:
+        ``Interpreter`` maintains an internal iterator over
+        arguments passed to ``run`` and this method returns
+        next() on that iterator.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            result (Any): The argument value that was retrieved
+            profile (MetaProfile): The meta profile of this node
+        """
+        result = super().placeholder(target, args, kwargs)
+        # A placeholder node only has activation
+        return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
+
+    @compatibility(is_backward_compatible=True)
+    def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``get_attr`` node. Will retrieve an attribute
+        value from the ``Module`` hierarchy of ``self.module``.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            result (Any): The argument value that was retrieved
+            profile (MetaProfile): The meta profile of this node
+        """
+        # A get_attr node never has parameters, activations, FLOPs, or MACs
+        return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
+
+    @compatibility(is_backward_compatible=True)
+    def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            result (Any): The argument value that was retrieved
+            profile (MetaProfile): The meta profile of this node
+        """
+        assert not isinstance(target, str)
+        return profile_function(target)(*args, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            result (Any): The argument value that was retrieved
+            profile (MetaProfile): The meta profile of this node
+        """
+        return profile_method(target)(*args, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            result (Any): The argument value that was retrieved
+            profile (MetaProfile): The meta profile of this node
+        """
+        # Retrieve executed args and kwargs values from the environment
+        # Execute the module and return the result
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+        return profile_module(submod)(*args, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute an ``output`` node. This really just retrieves
+        the value referenced by the ``output`` node and returns it.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The return value referenced by the output node
+        """
+        return args[0], MetaProfile(0, 0, 0, 0)

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
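
To see the new node attributes end to end, here is a minimal usage sketch (an editor's illustration, not part of the diff). It assumes that `torch.nn.Linear` and `torch.nn.ReLU` are covered by the new profiler registry, and it uses vanilla `torch.fx.symbolic_trace`; inputs must live on `device='meta'`, as enforced by the `run()` override above:

```python
import torch
import torch.fx
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)

# run() asserts that every input tensor is a meta tensor
MetaInfoProp(gm).run(torch.empty(2, 8, device='meta'))

for node in gm.graph.nodes:
    # the pass attaches __param__, __activation__, __flops__ and __macs__ to each node
    print(node.name, getattr(node, '__activation__', 0), getattr(node, '__flops__', 0))
```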
4 changes: 4 additions & 0 deletions colossalai/fx/profiler/__init__.py
@@ -0,0 +1,4 @@
+from .registry import *
+from .profiler_function import *
+from .profiler_module import *
+from .utils import *
8 changes: 8 additions & 0 deletions colossalai/fx/profiler/profiler_function/__init__.py
@@ -0,0 +1,8 @@
+from .activation_function import *
+from .arithmetic import *
+from .embedding import *
+from .linear import *
+from .normalization import *
+from .pooling import *
+from .python_ops import *
+from .torch_ops import *
29 changes: 29 additions & 0 deletions colossalai/fx/profiler/profiler_function/activation_function.py
@@ -0,0 +1,29 @@
+from typing import Tuple
+import torch
+from ..registry import meta_profiler_function
+
+# TODO: different activations have different FLOPs counts; this table is currently unused.
+_multiplier = {
+    torch.nn.functional.relu: 1,
+    torch.nn.functional.prelu: 4,
+    torch.nn.functional.sigmoid: 4,
+    torch.nn.functional.tanh: 5,
+    torch.nn.functional.leaky_relu: 3,
+    torch.nn.functional.elu: 4,
+    torch.nn.functional.relu6: 2,
+    torch.nn.functional.gelu: 9,
+}
+
+
+@meta_profiler_function.register(torch.nn.functional.leaky_relu)
+@meta_profiler_function.register(torch.nn.functional.elu)
+@meta_profiler_function.register(torch.nn.functional.gelu)
+@meta_profiler_function.register(torch.nn.functional.relu6)
+@meta_profiler_function.register(torch.nn.functional.prelu)
+@meta_profiler_function.register(torch.nn.functional.relu)
+@meta_profiler_function.register(torch.nn.functional.sigmoid)
+@meta_profiler_function.register(torch.nn.functional.tanh)
+def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
+    flops = input.numel()
+    macs = 0
+    return flops, macs
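
As a quick sanity check (an editor's example, not part of the diff, and assuming `meta_profiler_function.register` returns the wrapped function unchanged, as registry decorators typically do), the rule above charges one FLOP per element and zero MACs for every registered activation:

```python
import torch
from colossalai.fx.profiler.profiler_function.activation_function import torch_nn_func_non_linear_act

x = torch.empty(2, 128, 768, device='meta')
flops, macs = torch_nn_func_non_linear_act(x)
print(flops, macs)  # 196608 0, i.e. one FLOP per element (2 * 128 * 768) and no MACs
```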
83 changes: 83 additions & 0 deletions colossalai/fx/profiler/profiler_function/arithmetic.py
@@ -0,0 +1,83 @@
+from typing import Any, Optional, Tuple, Union
+import torch
+from ..registry import meta_profiler_function
+
+
+def _prod(dims):
+    p = 1
+    for v in dims:
+        p *= v
+    return p
+
+
+def _elementwise_flops_compute(input, other):
+    # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
+    if not torch.is_tensor(input):
+        if torch.is_tensor(other):
+            return _prod(other.shape), 0
+        else:
+            return 1, 0
+    elif not torch.is_tensor(other):
+        return _prod(input.shape), 0
+    else:
+        dim_input = len(input.shape)
+        dim_other = len(other.shape)
+        max_dim = max(dim_input, dim_other)
+
+        final_shape = []
+        for i in range(max_dim):
+            in_i = input.shape[i] if i < dim_input else 1
+            ot_i = other.shape[i] if i < dim_other else 1
+            if in_i > ot_i:
+                final_shape.append(in_i)
+            else:
+                final_shape.append(ot_i)
+        flops = _prod(final_shape)
+        return flops, 0
+
+
+@meta_profiler_function.register(torch.add)
+@meta_profiler_function.register('add')    # for built-in op +
+@meta_profiler_function.register('iadd')    # for built-in op +=
+@meta_profiler_function.register('sub')    # for built-in op -
+@meta_profiler_function.register('isub')    # for built-in op -=
+@meta_profiler_function.register('mul')    # for built-in op *
+@meta_profiler_function.register('imul')    # for built-in op *=
+def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    return _elementwise_flops_compute(input, other)
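
A worked example of the elementwise rule (an editor's illustration, not part of the diff): the FLOP count is the element count of the broadcast result.

```python
import torch
from colossalai.fx.profiler.profiler_function.arithmetic import _elementwise_flops_compute

a = torch.empty(4, 1, 8, device='meta')
b = torch.empty(1, 6, 8, device='meta')
flops, macs = _elementwise_flops_compute(a, b)
print(flops, macs)  # 192 0: the broadcast shape is (4, 6, 8)
```

Note that the helper aligns shapes from the left (index `i`), while PyTorch broadcasting aligns from the right, so the estimate can differ for operands of unequal rank; for equal-rank operands, as here, the two agree.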


+@meta_profiler_function.register(torch.abs)
+def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    flops = input.numel()
+    macs = 0
+    return flops, macs
+
+
+@meta_profiler_function.register(torch.matmul)
+@meta_profiler_function.register('matmul')    # for built-in op @
+@meta_profiler_function.register(torch.Tensor.matmul)
+def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    macs = _prod(input.shape) * other.shape[-1]
+    flops = 2 * macs
+    return flops, macs
+
+
+@meta_profiler_function.register(torch.bmm)
+def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    macs = _prod(input.shape) * other.shape[-1]
+    flops = 2 * macs
+    return flops, macs
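
To make the MAC rule concrete (an editor's example, with the same registry-decorator assumption as above): every output element of a matmul costs `K` multiply-accumulate operations, which the code counts as `_prod(input.shape) * other.shape[-1]`; `torch_bmm` applies the identical formula.

```python
import torch
from colossalai.fx.profiler.profiler_function.arithmetic import torch_matmul

a = torch.empty(2, 3, 4, device='meta')   # (..., K) with K = 4
b = torch.empty(4, 5, device='meta')      # (K, N) with N = 5
flops, macs = torch_matmul(a, b)
print(flops, macs)  # 240 120: macs = (2 * 3 * 4) * 5, flops = 2 * macs
```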


+@meta_profiler_function.register(torch.var_mean)
+def torch_var_mean(input: torch.Tensor,
+                   dim: Union[int, Tuple[int, ...]],
+                   unbiased: Optional[bool] = True,
+                   keepdim: Optional[bool] = False,
+                   *,
+                   out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    assert out is None, 'saving to out is not supported yet'
+    flops = input.numel() * 3
+    macs = 0
+    return flops, macs
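
One more worked check (an editor's example, not part of the diff): `torch_var_mean` charges a flat three FLOPs per input element, regardless of the reduction dim.

```python
import torch
from colossalai.fx.profiler.profiler_function.arithmetic import torch_var_mean

x = torch.empty(32, 64, device='meta')
flops, macs = torch_var_mean(x, dim=1)
print(flops, macs)  # 6144 0: 3 FLOPs per element over 32 * 64 elements
```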