[fx] add vanilla activation checkpoint search with test on resnet and densenet #1433
Merged: super-dainiu merged 12 commits into hpcaitech:main from super-dainiu:feature/fx_ckpt_chen on Aug 11, 2022.
Changes from all commits (12 commits, all by super-dainiu):

- 06f8991 [fx] modify the calculation of node_size in MetaInfoProp for activati…
- 3cd7d22 [fx] modify the calculation of node_size in MetaInfoProp for activati…
- 0849b3b [fx] modify the calculation of node_size in MetaInfoProp for activati…
- 701786c Merge branch 'hpcaitech:main' into main
- e11db26 [fx] activation checkpointing using Chen strategies.
- 7b77f56 Merge remote-tracking branch 'upstream/main' into feature/fx_ckpt_chen
- a7d56bd [fx] add test for ckpt_solver_chen
- f8a28bc mend
- 004bcff [fx] add vanilla activation checkpoint search with test on resnet and…
- e83b8c3 Merge branch 'feature/fx_ckpt_chen' of https://github.com/super-daini…
- afa6178 [fx] add vanilla activation checkpoint search with test on resnet and…
- acc5184 [fx] add a namespace code for solver_chen.
Package `__init__.py` (new file, one line) re-exporting the two solvers:

```python
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
```
`ckpt_solver_chen.py` (new file):

```python
import torch
from torch.fx import GraphModule

__all__ = ['chen_greedy', 'chen_sqrtn']


def chen_greedy(gm: GraphModule, B: int):
    """
    This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

    Usage:
        B = 5 * 1024 * 1024 * 1024  # An approximate memory budget of 5GB
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm, B)

    Args:
        gm (GraphModule): The module to add checkpoints to.
        B (int): The approximate memory budget for this module.
    """
    gm.graph.lint()    # make sure nodes are in topological order
    temp = 0
    x = 0
    idx = 0
    budget = B
    # the budget must cover the model parameters before any activations
    for n in gm.graph.nodes:
        B -= getattr(n, 'param_size')
    assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for the model parameters of {gm}'
    for n in gm.graph.nodes:
        temp += getattr(n, 'activation_size')
        if temp > B:
            x += getattr(n, 'activation_size')
            temp = x
            setattr(n, 'activation_checkpoint', str(idx))
            idx += 1
    gm.recompile()
    return gm


def chen_sqrtn(gm: GraphModule):
    """
    This is the theoretically optimal sqrt(n) strategy in https://arxiv.org/abs/1604.06174.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_sqrtn(gm)

    Args:
        gm (GraphModule): The module to add checkpoints to.
    """
    gm.graph.lint()    # make sure nodes are in topological order
    k = int(len(gm.graph.nodes) ** 0.5)    # take approximately sqrt(n) checkpoints
    for idx, n in enumerate(gm.graph.nodes):
        if (idx + 1) % k == 0:
            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
    gm.recompile()
    return gm
```
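The greedy budget rule in `chen_greedy` can be sketched without torch.fx: accumulate per-node activation sizes, and whenever the running total exceeds the budget, mark that node as a checkpoint boundary and restart the accumulation from the sizes carried over at earlier boundaries. A minimal sketch with toy sizes (`greedy_segments` is a hypothetical helper name, not part of the PR's API):

```python
def greedy_segments(activation_sizes, budget):
    """Mirror the accumulation logic of chen_greedy (Algorithm 3 in
    arXiv:1604.06174): walk the nodes in topological order and mark a
    checkpoint boundary whenever the running activation total exceeds
    the budget. Returns one entry per node: the boundary's segment
    index, or None for nodes that are not boundaries."""
    marks = []
    temp = 0    # running activation total within the current segment
    x = 0       # activation cost carried over from earlier boundaries
    idx = 0     # next segment index to assign
    for size in activation_sizes:
        temp += size
        if temp > budget:
            x += size      # the boundary node's output is kept
            temp = x       # restart accumulation from the kept outputs
            marks.append(idx)
            idx += 1
        else:
            marks.append(None)
    return marks


# With a budget of 3 and three nodes of size 2, the second and third
# nodes become checkpoint boundaries:
print(greedy_segments([2, 2, 2], 3))    # → [None, 0, 1]
```

Note that the carried-over cost `x` is cumulative, so a tight budget pushes later nodes toward being marked more frequently, just as in the loop above.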
Test on resnet18 and densenet121 (new file):

```python
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from functools import partial
import pytest

SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]


def _is_activation_checkpoint_available(gm: GraphModule):
    for n in gm.graph.nodes:
        if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None:
            return True


def test_ckpt_solver():
    MODEL_LIST = [tm.resnet18, tm.densenet121]

    torch.backends.cudnn.deterministic = True

    tracer = ColoTracer()
    data = torch.rand(1, 3, 224, 224)

    for solver in SOLVERS:
        for model_cls in MODEL_LIST:
            model = model_cls()
            graph = tracer.trace(root=model)
            gm = GraphModule(model, graph, model.__class__.__name__)
            MetaInfoProp(gm).run(data)
            gm = solver(gm)
            assert _is_activation_checkpoint_available(
                gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
            assert torch.allclose(gm(data), model(data))


if __name__ == '__main__':
    test_ckpt_solver()
```
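The second solver the test exercises, `chen_sqrtn`, applies a rule that can likewise be sketched in isolation: with n nodes and k = floor(sqrt(n)), every k-th node becomes a checkpoint boundary, yielding roughly sqrt(n) segments of length k. A small sketch (`sqrtn_marks` is a hypothetical helper name, not part of the PR's API):

```python
def sqrtn_marks(num_nodes):
    """Mirror chen_sqrtn's segmentation rule (cf. arXiv:1604.06174):
    with k = floor(sqrt(n)), mark every k-th node as a checkpoint
    boundary. Returns {node_index: segment_label} for marked nodes,
    matching the str(...) labels chen_sqrtn attaches as integers."""
    k = int(num_nodes ** 0.5)    # k >= 1 whenever num_nodes >= 1
    return {idx: (idx + 1) // k for idx in range(num_nodes) if (idx + 1) % k == 0}


# For 9 nodes, k = 3, so nodes 2, 5, and 8 are marked:
print(sqrtn_marks(9))    # → {2: 1, 5: 2, 8: 3}
```

This makes the trade-off between the two strategies visible: `chen_sqrtn` ignores actual tensor sizes and only counts nodes, while `chen_greedy` consumes the per-node `activation_size` metadata that `MetaInfoProp` attaches during the run above.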