[hotfix/rotor] fix variable names #1597
Changes from all commits: 044adf8, a42ab22, f8e1c1c, 0d55f26, 42c6e8c, dafbfcf, 5f25d6e, 9739876, 3745c5f, 504c607, 5d72a52, 6bdeb29, fafb7d0, d3c3690, de2be8f, bc727c2, 3d284af, 13b1d58, 7df14a2, 8025fb8, ce7b1a9, 4dd1b27, 5277b72, 594ec0b, 3144f8a, 31d6bbf, aead7e1
```diff
@@ -94,12 +94,11 @@ def extract_tensor_meta(obj):
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
+        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
```
Contributor (Author) commented: Avoid doubled …
```diff
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
         for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)
         # retain the autograd graph
```
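The second change here swaps accumulation for `max`: several consumer nodes can report the same parent output through `fwd_mem_in`, and summing those reports counts one tensor multiple times. A minimal sketch of the difference, using plain dicts with made-up sizes rather than real `torch.fx` node metadata:

```python
# Minimal sketch: why max() replaces += when propagating `fwd_mem_out`.
# Plain dicts stand in for torch.fx node metadata; sizes are made up.
MB = 1024**2
parent = {'fwd_mem_out': 0}
consumers = [{'fwd_mem_in': 4 * MB}, {'fwd_mem_in': 4 * MB}]

# Old behaviour: the shared 4 MB output is counted once per consumer.
for n in consumers:
    parent['fwd_mem_out'] = parent['fwd_mem_out'] + n['fwd_mem_in']
print(parent['fwd_mem_out'])    # 8388608 -- double-counted

# New behaviour: the tensor exists once, so keep the largest requirement.
parent['fwd_mem_out'] = 0
for n in consumers:
    parent['fwd_mem_out'] = max(parent['fwd_mem_out'], n['fwd_mem_in'])
print(parent['fwd_mem_out'])    # 4194304 -- counted once
```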
```diff
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size


-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
```
```diff
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0


-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-
-
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-
-
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-
-
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase


 def is_saved(n: Node):
```
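The four near-duplicate predicates collapse into one `is_phase(n, phase)`; `functools.partial` (the import added above) recovers the single-argument form that `map` and `filter` expect. A self-contained sketch, with a hypothetical `FakeNode` standing in for `torch.fx.Node`:

```python
from enum import Enum
from functools import partial


class Phase(Enum):
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; only `meta` matters here."""

    def __init__(self, phase):
        self.meta = {'phase': phase}


def is_phase(n, phase: Phase) -> bool:
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase


users = [FakeNode(Phase.FORWARD), FakeNode(Phase.LOSS)]
# partial() pins the phase argument, recovering the one-argument
# predicate shape that map() expects, just like the old is_loss did.
print(any(map(partial(is_phase, phase=Phase.LOSS), users)))    # True
```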
```diff
@@ -74,7 +60,7 @@ def is_saved(n: Node):
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ---->  p           o     <---- We need to keep track of grad out
```
@@ -91,38 +77,38 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: | |
| l | ||
| ============================================================================= | ||
| Args: | ||
| graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`. | ||
| graph (Graph): The autograd graph with nodes marked for keyword `phase`. | ||
|
|
||
| Returns: | ||
| graph_info (GraphInfo): Meta information for the dataflow. | ||
| """ | ||
|
|
||
| def _peak_memory(deps: Dict[Node, int]): | ||
| bwd_tmp = 0 | ||
| peak_mem = 0 | ||
| for k, v in deps.items(): | ||
| if v > 0: | ||
| bwd_tmp += activation_size(k.meta['out']) | ||
| return bwd_tmp | ||
| peak_mem += activation_size(k.meta['out']) | ||
| return peak_mem | ||
|
|
||
| # deps is used to track all the memory dependencies of the graph. | ||
| deps = {} | ||
| graph_info = GraphInfo() | ||
|
|
||
| for n in graph.nodes: | ||
| n: Node | ||
| if is_saved(n) and not any(map(is_loss, n.users)): | ||
| if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)): | ||
| # A forward tensor who is marked `save` but is not | ||
| # an input to `loss` should be saved during forward. | ||
| # If the tensor is a placeholder, then it belongs to `fwd_in`. | ||
| # Any `fwd_in` should be kept in memory even this function | ||
| # If the tensor is a placeholder, then it belongs to `fwd_mem_in`. | ||
| # Any `fwd_mem_in` should be kept in memory even this function | ||
| # is checkpointed. | ||
| # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint | ||
| # the node, `fwd_tmp` can be freed. | ||
| if is_placeholder(n): | ||
| # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint | ||
| # the node, `fwd_mem_tmp` can be freed. | ||
| if is_phase(n, Phase.PLACEHOLDER): | ||
| graph_info.fwd_mem_in += activation_size(n.meta['out']) | ||
| if is_forward(n): | ||
| if is_phase(n, Phase.FORWARD): | ||
| graph_info.fwd_mem_tmp += activation_size(n.meta['out']) | ||
| elif is_backward(n): | ||
| elif is_phase(n, Phase.BACKWARD): | ||
| if len(n.users): | ||
| # liveness analysis is only used in backward | ||
| deps[n] = len(n.users) | ||
|
|
||
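The `bwd_tmp` to `peak_mem` rename matches what `_peak_memory` actually returns: the total size of saved activations whose dependency counts are still positive, i.e. the set still live during backward. A toy illustration, with made-up byte counts in place of `activation_size(k.meta['out'])`:

```python
# Toy version of _peak_memory: deps maps each saved activation to the
# number of backward users that still need it; memory whose count is
# still positive is live. Sizes below are made-up byte counts.
deps = {'conv1_out': 2, 'conv2_out': 1, 'relu_out': 0}
sizes = {'conv1_out': 4096, 'conv2_out': 2048, 'relu_out': 1024}


def peak_memory(deps, sizes):
    peak_mem = 0
    for node, remaining_users in deps.items():
        if remaining_users > 0:    # still pending in backward -> live
            peak_mem += sizes[node]
    return peak_mem


print(peak_memory(deps, sizes))    # 6144: relu_out has already been freed
```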
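For context on the renamed fields: as the updated comments note, `fwd_mem_in` (placeholder inputs) must stay resident even inside a checkpointed segment, while `fwd_mem_tmp` (saved forward temporaries) can be freed and recomputed. A minimal sketch with a hypothetical `SegmentInfo` and made-up byte counts:

```python
from dataclasses import dataclass


@dataclass
class SegmentInfo:
    """Trimmed-down analogue of GraphInfo with made-up byte counts."""
    fwd_mem_in: int = 0     # placeholder inputs: resident even if checkpointed
    fwd_mem_tmp: int = 0    # saved temporaries: freed when checkpointed


info = SegmentInfo(fwd_mem_in=1024, fwd_mem_tmp=8192)

# Without checkpointing, inputs and saved temporaries both stay resident.
print(info.fwd_mem_in + info.fwd_mem_tmp)    # 9216
# With checkpointing, only the inputs are kept; temporaries are recomputed.
print(info.fwd_mem_in)                       # 1024
```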