[fx] provide an accurate estimation of memory. #1587
Merged: FrankLeeeee merged 18 commits into hpcaitech:main from super-dainiu:feature/better_flop_tensor on Sep 14, 2022.
Commits (18, all by super-dainiu):

- 044adf8 [fx] add some comment and docstrings.
- a42ab22 [fx] add dataflow analysis for an autograd graph.
- f8e1c1c Merge branch 'main' of https://github.com/super-dainiu/ColossalAI int…
- 0d55f26 add intepretation for graph analysis.
- 42c6e8c [fx] before doing save_tensor_hooks.
- dafbfcf Merge branch 'hpcaitech:main' into feature/better_flop_tensor
- 5f25d6e [fx] provide an accurate estimation of memory except for GPT-2.
- 9739876 Merge branch 'hpcaitech:main' into feature/better_flop_tensor
- 3745c5f [fx] provide an accurate estimation of memory except for GPT-2.
- 504c607 [fx] provide an accurate estimation of memory except for GPT-2.
- 5d72a52 [fx] a very accurate version on GPT-2.
- 6bdeb29 [fx] refactor code.
- fafb7d0 [fx] remove redundant inplace=True.
- d3c3690 [fx] refactor code.
- de2be8f [fx] refactor code.
- bc727c2 [fx] refactor code.
- db4483a [fx] dive into backward memory.
- 88d884f Merge branch 'hpcaitech:main' into feature/better_flop_tensor
The diff adds a new 136-line module (per the hunk header `@@ -0,0 +1,136 @@`) implementing the dataflow analysis. It opens with the stage enum and the `GraphInfo` dataclass:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict

from torch.fx import Graph, Node

from .memory import activation_size


class Stage(Enum):
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3


@dataclass
class GraphInfo:
    """
    GraphInfo is a dataclass for MetaInfo, which measures
    the execution memory cost and FLOPs with `MetaTensor`.
    The dataflow analysis is conducted on a single node of the FX graph.
    ============================================================================
                            -------------------------------
                            |            Node             |
    [fwd_in] are       ---> | [fwd_in]       [bwd_out]    | <----- [bwd_out] marks the memory for `grad_out`.
    placeholders saved for  |    | \__________      |     |
    backward.               |    |            \     |     |
                            | [fwd_tmp] ------> [bwd_tmp] | <-----
                            |    |  \_________      |     |        [bwd_tmp] marks the peak memory
                            |    /           \  \   |     |        of the backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] | <-----
    in [fwd_tmp] because    |    |        | \_____  |     |
    it is not saved for     |    |        |       \ |     |
    backward.               -------------------------------
    ============================================================================
    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node.
        bwd_flop (int): The backward FLOPs of a certain node.
        fwd_mem_in (int): See the above illustration.
        fwd_mem_tmp (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """
    fwd_flop: int = 0
    bwd_flop: int = 0
    fwd_mem_in: int = 0
    fwd_mem_tmp: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
```
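To make the four memory fields concrete, here is a minimal sketch (not part of the PR) of how per-node `GraphInfo` records could be rolled up into a whole-graph estimate. The roll-up policy, summing the forward terms while taking a running peak of the backward temporaries, is an illustrative assumption, not the PR's method:

```python
from typing import Iterable


def summarize(node_infos: Iterable[GraphInfo]) -> GraphInfo:
    """Hypothetical roll-up of per-node GraphInfo records (illustration only)."""
    total = GraphInfo()
    for info in node_infos:
        total.fwd_flop += info.fwd_flop        # FLOPs simply accumulate
        total.bwd_flop += info.bwd_flop
        total.fwd_mem_in += info.fwd_mem_in    # saved inputs coexist ...
        total.fwd_mem_tmp += info.fwd_mem_tmp  # ... as do saved activations
        total.bwd_mem_out += info.bwd_mem_out  # gradient outputs accumulate
        # backward temporaries are transient, so track their peak instead
        total.bwd_mem_tmp = max(total.bwd_mem_tmp, info.bwd_mem_tmp)
    return total
```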
Next come small predicates that read a node's stage annotation from `node.meta`:

```python
def is_forward(n: Node):
    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
    return n.meta['stage'] == Stage.FORWARD


def is_loss(n: Node):
    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
    return n.meta['stage'] == Stage.LOSS


def is_placeholder(n: Node):
    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
    return n.meta['stage'] == Stage.PLACEHOLDER


def is_backward(n: Node):
    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
    return n.meta['stage'] == Stage.BACKWARD


def is_saved(n: Node):
    return n.meta.get('saved', False)
```
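These predicates assume every node has already been annotated. The PR produces those marks elsewhere, during its autograd-aware tracing; purely as an illustration of the expected invariant, a stand-in annotator might look like this (`mark_stages` and its topological-order heuristic are assumptions, not the PR's code):

```python
def mark_stages(graph: Graph, loss_node: Node) -> None:
    """Illustrative stand-in: tag placeholders, the loss node, and every
    node after the loss; everything else defaults to FORWARD."""
    seen_loss = False
    for n in graph.nodes:
        if n.op == 'placeholder':
            n.meta['stage'] = Stage.PLACEHOLDER
        elif n is loss_node:
            n.meta['stage'] = Stage.LOSS
            seen_loss = True
        elif seen_loss:
            # nodes topologically after the loss are treated as backward
            n.meta['stage'] = Stage.BACKWARD
        else:
            n.meta['stage'] = Stage.FORWARD
```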
Finally, the analysis itself: a single pass over the graph that accumulates saved forward memory and runs a reference-count liveness analysis over backward tensors:

```python
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    """Analyze the autograd node dependencies and find out the memory usage.
    Every node of the input graph should be marked 'f' (forward), 'l' (loss)
    or 'b' (backward) under the keyword `stage`, and should carry an `out`
    entry in its `meta` holding the node's output tensors.
    ============================================================================
    Placeholder ---->  p            o  <---- We need to keep track of grad out.
                       |\________   |
                       ↓         ↘  |
                       f --------> b
                       |\ \_____   ↑
                       | \      ↘ /
                       f  f ----> b  <---- Not every forward result needs to be
                       |   \____  ↑        saved for backward.
                       ↘        ↘ |
                         f ----> b   <---- A backward tensor can be freed as
                           ↘   ↗           soon as it is no longer required.
                             l
    ============================================================================
    Args:
        graph (Graph): The autograd graph with nodes marked 'f' (forward),
            'l' (loss), 'b' (backward) for keyword `stage`.

    Returns:
        graph_info (GraphInfo): Meta information for the dataflow.
    """

    def _peak_memory(deps: Dict[Node, int]):
        # Sum the sizes of all backward tensors that are still alive,
        # i.e. that still have at least one unconsumed user.
        bwd_tmp = 0
        for k, v in deps.items():
            if v > 0:
                bwd_tmp += activation_size(k.meta['out'])
        return bwd_tmp

    # `deps` tracks, for each live backward tensor, how many of its users
    # have not been executed yet (a reference-count liveness analysis).
    deps = {}
    graph_info = GraphInfo()

    for n in graph.nodes:
        n: Node
        if is_saved(n) and not any(map(is_loss, n.users)):
            # A forward tensor that is marked `saved` but is not an input
            # to `loss` must be kept in memory during the forward pass.
            # If the tensor is a placeholder, it belongs to `fwd_in`, which
            # stays resident even if this function is checkpointed.
            # Otherwise it belongs to `fwd_tmp`, which can be freed if we
            # checkpoint the node.
            if is_placeholder(n):
                graph_info.fwd_mem_in += activation_size(n.meta['out'])
            if is_forward(n):
                graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
        elif is_backward(n):
            if len(n.users):
                # Liveness analysis is only needed in the backward pass.
                deps[n] = len(n.users)
                graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
                for input_n in n.all_input_nodes:
                    if input_n in deps:
                        deps[input_n] -= 1
            else:
                # A backward node with no users is a `grad_out` node.
                graph_info.bwd_mem_out += activation_size(n.meta['out'])
    return graph_info
```
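To show how the analysis is meant to be driven, here is a hedged end-to-end sketch. It assumes the module above and the `mark_stages` stand-in are in scope; the `meta['out']` and `meta['saved']` annotations below are fabricated for illustration (the PR obtains them from its `MetaTensor`-based tracing), and since `torch.fx.symbolic_trace` records no backward nodes, the backward fields will simply come out as zero here:

```python
import torch
import torch.fx


class TinyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu().sum()


gm = torch.fx.symbolic_trace(TinyNet())

# Fabricated annotations, for illustration only.
loss_node = list(gm.graph.nodes)[-1]  # treat the output node as the loss
mark_stages(gm.graph, loss_node)
for n in gm.graph.nodes:
    n.meta.setdefault('out', torch.empty(4, 4))   # stand-in output record
    n.meta.setdefault('saved', n.op != 'output')  # pretend all are saved

info = autograd_graph_analysis(gm.graph)
print(info.fwd_mem_in, info.fwd_mem_tmp, info.bwd_mem_tmp, info.bwd_mem_out)
```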