-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[fx] provide a stable but not accurate enough version of profiler. #1547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e0edb21
579b70b
7f3a532
17be5a5
36c93ca
d6dcd80
01de3f1
e51f32a
a9193b5
97cfba9
c992b25
62d7096
b2e8f6a
f80735f
0d4a030
e7a33a6
9223090
1e49efe
77045db
1c127c6
6f486d1
c1eb532
b3d144e
15a1bb1
d011148
1a45a12
cf06d84
359f750
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -181,6 +181,12 @@ def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): | |
| return grad_in | ||
|
|
||
|
|
||
@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
    """Meta kernel for ``aten.hardtanh_backward.default``.

    No gradient values are computed: the function only allocates an
    uninitialised tensor with the same properties as ``input``.
    ``grad_out``, ``min_val`` and ``max_val`` are accepted to match the
    operator signature but are not read.

    Returns:
        torch.Tensor: an empty tensor created with ``torch.empty_like(input)``.
    """
    return torch.empty_like(input)
|
Comment on lines
+184
to
+187
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This file only receives some extra meta registrations. |
||
|
|
||
|
|
||
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    """Meta kernel for ``aten.roll.default``.

    ``shifts`` and ``dims`` are accepted to match the operator signature but
    are not read; the result is an uninitialised tensor created with
    ``torch.empty_like(input)``.
    """
    rolled = torch.empty_like(input)
    return rolled
|
|
@@ -321,3 +327,17 @@ def meta_index_Tensor(self, indices): | |
| else: | ||
| replacement_shape = list(index.shape) | ||
| return self.new_empty(before_shape + replacement_shape + after_shape) | ||
|
|
||
|
|
||
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                  scale_grad_by_freq):
    """Meta kernel for ``aten.embedding_dense_backward.default``.

    Allocates an uninitialised tensor of shape
    ``(num_weights, grad_output.size(-1))`` that inherits the dtype, device and
    layout of ``grad_output``. ``indices``, ``padding_idx`` and
    ``scale_grad_by_freq`` are accepted to match the operator signature but are
    not read.
    """
    grad_weight_shape = (num_weights, grad_output.size(-1))
    tensor_options = dict(dtype=grad_output.dtype, device=grad_output.device, layout=grad_output.layout)
    return torch.empty(grad_weight_shape, **tensor_options)
|
|
||
|
|
||
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
    """Meta kernel for ``aten.where.self``.

    The output of ``where`` takes its values from ``self`` / ``other`` and its
    shape from broadcasting all three operands together. Allocating with
    ``torch.empty_like(condition)`` would therefore be wrong twice over: it
    yields a ``bool`` tensor, and it ignores the shapes of ``self`` and
    ``other``.

    Returns:
        torch.Tensor: an uninitialised tensor with the broadcast shape of
        ``condition``, ``self`` and ``other``, the promoted dtype of ``self``
        and ``other``, placed on the device of ``self``.
    """
    out_shape = torch.broadcast_shapes(condition.shape, self.shape, other.shape)
    out_dtype = torch.result_type(self, other)
    return torch.empty(out_shape, dtype=out_dtype, device=self.device)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,10 @@ | ||
| from operator import add, getitem | ||
| import torch | ||
| import torch.fx | ||
| from torch.fx.node import Node, Argument, Target | ||
| from torch.utils._pytree import tree_map | ||
| from typing import Any, Tuple, NamedTuple, Optional, Dict | ||
| from functools import reduce | ||
| from typing import Any, Tuple, NamedTuple, Dict | ||
| from torch.fx._compatibility import compatibility | ||
| from torch.fx.immutable_collections import immutable_dict, immutable_list | ||
| from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method | ||
| from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size | ||
|
|
||
|
|
||
| @compatibility(is_backward_compatible=True) | ||
|
|
@@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter): | |
|
|
||
| """ | ||
|
|
||
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
    """
    Run the interpreter after converting every tensor argument to a ``MetaTensor``
    on ``device='meta'``.

    Args:
        *args: The positional inputs of the traced module. Tensors are moved to the
            meta device and wrapped in ``MetaTensor``; other values pass through.
        initial_env (Optional[Dict[Node, Any]]): Forwarded unchanged to the parent ``run``.
        enable_io_processing (bool): Forwarded unchanged to the parent ``run``.

    Returns:
        Any: Whatever the parent ``run`` returns.
    """
    args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
    # Both options are keyword-only on this signature (they follow ``*args``); passing
    # them positionally would append them to ``args`` as if they were extra inputs.
    return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
|
|
||
| @compatibility(is_backward_compatible=True) | ||
| def run_node(self, n: Node) -> Any: | ||
| """ | ||
|
|
@@ -93,8 +82,7 @@ def run_node(self, n: Node) -> Any: | |
| Returns: | ||
| Any: The result of executing ``n`` | ||
| """ | ||
| result, profile = super().run_node(n) | ||
| profile: MetaProfile | ||
| result, flop_count, mem_stat = super().run_node(n) | ||
|
|
||
| def extract_tensor_meta(obj): | ||
| if isinstance(obj, torch.Tensor): | ||
|
|
@@ -106,12 +94,17 @@ def extract_tensor_meta(obj): | |
| n.meta['tensor_meta'] = meta | ||
|
|
||
| # TODO: the attribute node_size should be removed in the future | ||
| setattr(n, 'node_size', profile.param + profile.activation) | ||
| setattr(n, '__param__', profile.param) | ||
| setattr(n, '__activation__', profile.activation) | ||
| setattr(n, '__flops__', profile.flops) | ||
| setattr(n, '__macs__', profile.macs) | ||
| setattr(n, 'node_size', mem_stat[1]) | ||
| setattr(n, 'fwd_flop', flop_count[0]) | ||
| setattr(n, 'bwd_flop', flop_count[1]) | ||
| setattr(n, 'fwd_tmp', mem_stat[0]) | ||
| setattr(n, 'fwd_out', mem_stat[1]) | ||
| setattr(n, 'bwd_tmp', mem_stat[2]) | ||
| setattr(n, 'bwd_out', mem_stat[3]) | ||
| n.meta['type'] = type(result) | ||
|
|
||
| for param in self.module.parameters(): | ||
| param.grad = None | ||
|
Comment on lines
+106
to
+107
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Obviously, we need to clear grad of the parameter, because these grads are |
||
| return result | ||
|
|
||
| # Main Node running APIs | ||
|
|
@@ -132,11 +125,12 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict | |
|
|
||
| Returns: | ||
| result (Any): The argument value that was retrieved | ||
| profile (MetaProfile): The meta profile of this node | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| result = super().placeholder(target, args, kwargs) | ||
| # A placeholder node only has activation | ||
| return result, MetaProfile(0, calculate_activation_size(result), 0, 0) | ||
| return result, (0, 0), (0, activation_size(result), 0, 0) | ||
|
|
||
| @compatibility(is_backward_compatible=True) | ||
| def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: | ||
|
|
@@ -153,10 +147,10 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st | |
|
|
||
| Return: | ||
| result (Any): The argument value that was retrieved | ||
| profile (MetaProfile): The meta profile of this node | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| # A get_attr node never has parameters, activations, FLOPs, or MACs | ||
| return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0) | ||
| return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0) | ||
|
|
||
| @compatibility(is_backward_compatible=True) | ||
| def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: | ||
|
|
@@ -172,7 +166,8 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di | |
|
|
||
| Return | ||
| result (Any): The argument value that was retrieved | ||
| profile (MetaProfile): The meta profile of this node | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| assert not isinstance(target, str) | ||
| return profile_function(target)(*args, **kwargs) | ||
|
|
@@ -191,7 +186,8 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict | |
|
|
||
| Return | ||
| result (Any): The argument value that was retrieved | ||
| profile (MetaProfile): The meta profile of this node | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| return profile_method(target)(*args, **kwargs) | ||
|
|
||
|
|
@@ -209,7 +205,8 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict | |
|
|
||
| Return | ||
| result (Any): The argument value that was retrieved | ||
| profile (MetaProfile): The meta profile of this node | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| # Retrieve executed args and kwargs values from the environment | ||
| # Execute the method and return the result | ||
|
|
@@ -231,9 +228,11 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, | |
| kwargs (Dict): Dict of keyword arguments for this invocation | ||
|
|
||
| Return: | ||
| Any: The return value referenced by the output node | ||
| result (Any): The argument value that was retrieved | ||
| flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). | ||
| mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) | ||
| """ | ||
| return args[0], MetaProfile(0, 0, 0, 0) | ||
| return args[0], (0, 0), (0, 0, 0, 0) | ||
|
|
||
| def propagate(self, *args): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,9 @@ | ||
| from .meta_tensor import MetaTensor | ||
| from .registry import meta_profiler_function, meta_profiler_module | ||
| from .profiler_function import * | ||
| from .profiler_module import * | ||
| from .profiler import * | ||
| from ... import META_COMPATIBILITY | ||
| if META_COMPATIBILITY: | ||
| from .opcount import flop_mapping | ||
| from .tensor import MetaTensor | ||
| from .profiler import profile_function, profile_method, profile_module, _profile | ||
| else: | ||
| from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module | ||
|
|
||
| from .memory import parameter_size, activation_size |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .registry import meta_profiler_function, meta_profiler_module | ||
| from .profiler_function import * | ||
| from .profiler_module import * | ||
| from .profiler import profile_function, profile_method, profile_module |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| from typing import Callable, Any, Dict, Tuple | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the old profiler, so I did not modify anything except the output format.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the old implementation, kept for PyTorch 1.11. |
||
| import torch | ||
| from torch.fx.node import Argument, Target | ||
| from . import meta_profiler_function, meta_profiler_module | ||
| from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS | ||
|
|
||
| __all__ = ['profile_function', 'profile_module', 'profile_method'] | ||
|
|
||
| CALL_FUNCTION_MSG = \ | ||
| """ | ||
| Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n | ||
| from colossalai.fx.profiler.experimental import meta_profiler_function | ||
| @meta_profiler_function.register(YOUR_FUNCTION) | ||
| def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: | ||
| flops = ... | ||
| macs = ... | ||
| return flops, macs | ||
| """ | ||
| CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' | ||
| CALL_MODULE_MSG = \ | ||
| """ | ||
| Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n | ||
| from colossalai.fx.profiler.experimental import meta_profiler_module | ||
| @meta_profiler_module.register(YOUR_MODULE) | ||
| def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: | ||
| flops = ... | ||
| macs = ... | ||
| return flops, macs | ||
| """ | ||
|
|
||
|
|
||
def profile_function(target: 'Target') -> Callable:
    """
    Build a wrapper around a `call_function` target (or a `torch.nn.functional`
    callable) that also reports FLOPs and memory statistics for the call.
    The backward numbers are estimates, not measurements.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Returns:
        Callable: a function that returns
        ``(output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out))``.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # A profiler may be registered under the callable itself or under its name.
        assert meta_profiler_function.has(target) or meta_profiler_function.has(
            target.__name__), CALL_FUNCTION_MSG.format(target)

        out = target(*args, **kwargs)

        # In-place calls produce no new output tensor, so nothing is charged for them.
        runs_inplace = target in INPLACE_OPS or kwargs.get('inplace', False)
        fwd_tmp = 0
        fwd_out = 0 if runs_inplace else activation_size(out)

        registry_key = target if meta_profiler_function.has(target) else target.__name__
        profiler = meta_profiler_function.get(registry_key)

        # The second value returned by the registered profiler is not used here.
        fwd_flop, _ = profiler(*args, **kwargs)

        # Backward FLOPs are estimated as twice the forward FLOPs.
        flop_count = (fwd_flop, fwd_flop * 2)
        mem_stat = (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
        return out, flop_count, mem_stat

    f.__name__ = target.__name__
    return f
|
|
||
|
|
||
def profile_method(target: 'Target') -> Callable:
    """
    Build a wrapper around a `call_method` target that also reports FLOPs and
    memory statistics for the call. FLOPs are always reported as zero.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.

    Returns:
        Callable: a function that returns
        ``(output, (0, 0), (fwd_tmp, fwd_out, bwd_tmp, bwd_out))``.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # The first positional argument is the object the method is invoked on.
        receiver, *method_args = args

        assert isinstance(target, str), f'{target} instance is not str.'

        # The method is executed before its name is checked against the known lists.
        out = getattr(receiver, target)(*method_args, **kwargs)
        known_methods = INPLACE_METHOD + NON_INPLACE_METHOD
        assert target in known_methods, CALL_METHOD_MSG.format(target, INPLACE_METHOD, NON_INPLACE_METHOD)

        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
        if target in INPLACE_METHOD:
            fwd_tmp, fwd_out = 0, activation_size(out)
        else:
            fwd_tmp, fwd_out = activation_size(out), 0

        mem_stat = (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
        return out, (0, 0), mem_stat

    return f
|
|
||
|
|
||
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.
    Backward FLOPs are estimated as twice the forward FLOPs.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Returns:
        Callable: a function that returns
        ``(output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out))``.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))

        fwd_tmp = 0
        fwd_out = 0
        out = func(*args, **kwargs)
        # Only a module that is NOT in-place produces a new output tensor. This mirrors
        # the rule used by `profile_function`; the previous condition was inverted and
        # reported zero output memory for every ordinary (non-in-place) module.
        if not getattr(module, 'inplace', False):
            fwd_out = activation_size(out)
        profiler = meta_profiler_module.get(type(module))
        # The second value returned by the registered profiler is not used here.
        fwd_flop, _ = profiler(module, *args, **kwargs)
        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
META_COMPATIBILITY is checked when Colossal-AI initializes.