[fx] Add activation checkpoint solver rotor #1496
Merged: Cypher30 merged 37 commits into hpcaitech:main from Cypher30:feature/add_ckpt_solver_rotor on Aug 26, 2022.
Changes from all commits (37 commits):
- 04e5272 Merge pull request #1 from hpcaitech/main
- 75618b3 Merge pull request #2 from hpcaitech/main
- 3e4620c Merge pull request #3 from hpcaitech/main
- cf24049 Merge remote-tracking branch 'upstream/main' into main
- 3d223b6 Merge remote-tracking branch 'upstream/main' into main
- 644115c Merge branch 'hpcaitech:main' into main
- d995ade Merge branch 'hpcaitech:main' into main
- bba2dbe Merge branch 'hpcaitech:main' into main
- 05ca628 Merge branch 'hpcaitech:main' into main
- 0a967da Merge branch 'hpcaitech:main' into main
- 0637c0d Merge branch 'hpcaitech:main' into main
- 74a6227 Merge branch 'hpcaitech:main' into main
- e550490 Merge branch 'hpcaitech:main' into main
- 2d7f5d9 Merge branch 'hpcaitech:main' into main
- b62e870 Merge branch 'hpcaitech:main' into main
- b4b0974 Merge branch 'hpcaitech:main' into main
- 65c20de Merge branch 'hpcaitech:main' into main
- 1660bfc Merge branch 'hpcaitech:main' into main
- 6eb0ad0 Merge branch 'hpcaitech:main' into main
- 84734b1 [fx] fix defining ckpt functions inside forward
- e96b8d4 [fx] Modify activation checkpoint codegen and add ColoGraphModule
- ae562dd Merge branch 'hpcaitech:main' into hotfix/fix_the_ckpt_func_def
- 08b4690 [fx] some modification
- bf0ee86 Merge branch 'hotfix/fix_the_ckpt_func_def' of github.com:Cypher30/Co…
- b854597 some modifications
- 0d329c3 some modifications
- a6205df some modifications
- 3885361 some modifications
- 3fa5b56 some code modifications
- 09f7985 [automatic_parallel] ckpt solver rotor
- bd88d01 Merge branch 'hpcaitech:main' into feature/add_ckpt_solver_rotor
- 6246e0b [fx] add ckpt_solver_rotor
- 9e9d4f9 Merge branch 'hpcaitech:main' into feature/add_ckpt_solver_rotor
- 80dd403 [fx] modification
- e776d34 Merge branch 'feature/add_ckpt_solver_rotor' of github.com:Cypher30/C…
- fd6472c code refactor
- e1d591e code refactor
```diff
@@ -1 +1,2 @@
 from .tracer import ColoTracer
+from .graph_module import ColoGraphModule
```
```diff
@@ -1 +1,3 @@
 from .ckpt_solver_chen import chen_greedy
+from .linearize import linearize
+from .ckpt_solver_rotor import solver_rotor
```
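The package now exposes two activation-checkpoint solvers: `chen_greedy` (presumably the sqrt(n) heuristic from Chen et al., "Training Deep Nets with Sublinear Memory Cost") and the optimal `solver_rotor` added by this PR. As a stdlib-only illustration of the sqrt(n) idea only — this is not the actual `chen_greedy` implementation, and `sqrt_segments` is a hypothetical name — a chain of n layers can be cut into roughly sqrt(n) segments whose boundaries are checkpointed:

```python
import math

def sqrt_segments(n: int) -> list:
    """Toy sketch of the sqrt(n) heuristic: split n layers into segments
    of length ~sqrt(n). Only the segment boundaries are kept in memory;
    activations inside a segment are recomputed during the backward pass."""
    k = max(1, round(math.sqrt(n)))
    return [list(range(i, min(i + k, n))) for i in range(0, n, k)]

print(sqrt_segments(9))   # -> [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

This trades O(sqrt(n)) extra memory for one extra forward pass, whereas the rotor solver below searches for the schedule that is optimal under an explicit memory budget.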
New file (@@ -0,0 +1,198 @@):

```python
from typing import List, Set, Tuple, Dict
import math

import torch
from torch.fx import GraphModule, Node

from .linearize import linearize
from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
    and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
    (False, j) if the optimal choice is a leaf checkpoint of length j
    The computation uses dynamic programming"""

    fw = chain.fweight + [0]    # forward time
    bw = chain.bweight          # backward time, not used
    cw = chain.cweight + [0]    # size of x (and of y)
    cbw = chain.cbweight + [0]  # size of xbar
    fwd_tmp = chain.fwd_tmp + [0]
    bwd_tmp = chain.bwd_tmp + [0]

    # Build table
    opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    # The innermost one is a dict because its indices go from i to l.
    # Renumbering will wait for the C implementation.

    # Initialize borders of the tables for lmax - lmin = 0
    for m in range(mmax + 1):
        for i in range(chain.length + 1):
            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
            if m >= limit:    # Equation (1)
                opt[m][i][i] = fw[i] + bw[i]
            else:
                opt[m][i][i] = float("inf")

    # Compute everything
    for m in range(mmax + 1):
        for d in range(1, chain.length + 1):
            for i in range(chain.length + 1 - d):
                idx = i + d
                mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
                if idx > i + 1:
                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
                if m < mmin:
                    opt[m][i][idx] = float("inf")
                else:
                    leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
                                        for j in range(i + 1, idx + 1)
                                        if m >= cw[j]]
                    if leaf_checkpoints:
                        best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                    else:
                        best_leaf = None
                    if m >= cbw[i + 1]:
                        chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
                    else:
                        chain_checkpoint = float("inf")
                    if best_leaf and best_leaf[1] <= chain_checkpoint:
                        opt[m][i][idx] = best_leaf[1]
                        what[m][i][idx] = (False, best_leaf[0])
                    else:
                        opt[m][i][idx] = chain_checkpoint
                        what[m][i][idx] = (True,)
    return (opt, what)


def _rec(chain, lmin, lmax, cmem, opt_table):
    """ chain : the class describing the AC graph
        lmin : index of the first forward to execute
        lmax : upper bound index of the last forward to execute (not included)
        cmem : number of available memory slots
        Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
    if cmem <= 0:
        raise ValueError("Cannot process a chain with non-positive memory {cmem}".format(cmem=cmem))
    opt, what = opt_table
    sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
    if opt[cmem][lmin][lmax] == float("inf"):
        raise ValueError("Cannot process this chain from index {lmin} to {lmax} with memory {cmem}".format(
            lmin=lmin, lmax=lmax, cmem=cmem))
    if lmin == lmax:
        if lmin == chain.length:
            sequence.insert(Loss())
        else:
            sequence.insert(ForwardEnable(lmin))
            sequence.insert(Backward(lmin))
        return sequence

    if what[cmem][lmin][lmax][0]:
        sequence.insert(ForwardEnable(lmin))
        sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
        sequence.insert(Backward(lmin))
    else:
        j = what[cmem][lmin][lmax][1]
        sequence.insert(ForwardCheck(lmin))
        for k in range(lmin + 1, j):
            sequence.insert(ForwardNograd(k))
        sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
        sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
    return sequence


def _discretize(mem_unit, values):
    return [math.ceil(value / mem_unit) for value in values]


def _construct_chain(node_dict: Dict[int, List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:
    fwd_time = []
    bwd_time = []
    xbar_sizes = [data.numel() * data.element_size()]
    x_sizes = [data.numel() * data.element_size()]

    # currently we can't get the temp memory needed in fwd and bwd
    tmp_fwd = [0] * len(node_dict)
    tmp_bwd = [0] * (len(node_dict) + 1)

    for key in node_dict.keys():
        fwd_time.append(0)
        bwd_time.append(0)
        xbar_sizes.append(0)
        x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
                       torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
        for node in node_dict[key]:
            fwd_time[-1] += node.__flops__

            # currently we haven't patched the backward flops count
            bwd_time[-1] += node.__flops__ * 2

            xbar_sizes[-1] += node.__activation__

        xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])

    bwd_time.append(0)

    fwd_time = _discretize(mem_unit, fwd_time)
    bwd_time = _discretize(mem_unit, bwd_time)
    xbar_sizes = _discretize(mem_unit, xbar_sizes)
    x_sizes = _discretize(mem_unit, x_sizes)
    tmp_fwd = _discretize(mem_unit, tmp_fwd)
    tmp_bwd = _discretize(mem_unit, tmp_bwd)

    return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)


def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, List[Node]]) -> None:
    op_list = sequence.list_operations()
    loss_op = [op for op in op_list if isinstance(op, Loss)][0]
    op_list = op_list[:op_list.index(loss_op)]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []
    for idx, op in enumerate(op_list, 1):
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(idx)

            elif isinstance(op, ForwardEnable):
                in_ckpt = False
                for region_idx in ckpt_region:
                    for node in node_dict[region_idx]:
                        setattr(node, "activation_checkpoint", ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for region_idx in ckpt_region:
                    for node in node_dict[region_idx]:
                        setattr(node, "activation_checkpoint", ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [idx]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(idx)


def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
    node_dict = linearize(gm)
    mem_unit = mem_limit // mem_slots
    MetaInfoProp(gm).run(data)
    chain: Chain = _construct_chain(node_dict, data, mem_unit)
    opt_table = _compute_table(chain, mem_slots)
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
    _annotate_from_sequence(sequence, node_dict)
    return gm
```
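`solver_rotor` divides the byte budget `mem_limit` into `mem_slots` equal units, and `_discretize` rounds every cost up to whole slots, which bounds the DP table at `mem_slots + 1` memory rows. A stdlib-only sketch of that rounding (the tensor sizes are made up for illustration):

```python
import math

def discretize(mem_unit: int, values):
    # same rounding as the solver's _discretize: ceil each byte count to whole slots
    return [math.ceil(v / mem_unit) for v in values]

mem_limit = 1024 ** 3          # 1 GiB budget
mem_unit = mem_limit // 500    # default of 500 slots -> ~2 MiB per slot
sizes = [4 * 1024 ** 2, 1, 0]  # a 4 MiB tensor, a 1-byte value, an empty one
print(discretize(mem_unit, sizes))   # -> [2, 1, 0]
```

Note that ceiling rounding means even a 1-byte activation occupies a full slot, so a larger `mem_slots` gives a finer (but slower-to-solve) discretization.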
New file (@@ -0,0 +1,89 @@):

```python
from collections import OrderedDict

from torch.fx import GraphModule


def linearize(gm: GraphModule) -> OrderedDict:
    status_dict = {}
    node_dict = OrderedDict()
    node_idx = 0
    for node in gm.graph.nodes:
        last_dict_len = len(status_dict)
        # remove node from the users lists in status_dict
        for item in status_dict.values():
            if node in item:
                item.remove(node)

        # pop a node from status_dict once it is fully used
        for key in list(status_dict):
            if len(status_dict[key]) == 0:
                status_dict.pop(key)

        # first node in the graph; it should be of n0-n1 type,
        # where n0 contains only the input op, i.e. placeholder
        if last_dict_len == 0:
            node_dict[node_idx] = [node]
            status_dict[node.name] = list(node.users)
            node_idx += 1
            node_dict[node_idx] = []
            continue

        # boundary case
        if len(status_dict) == 0:
            # current region end point = next region start point,
            # i.e. n1-n2-n3-... type nodes, each containing only one op
            if last_dict_len == 1:
                if len(node_dict[node_idx]) > 0:
                    node_idx += 1
                    node_dict[node_idx] = []
                node_dict[node_idx].append(node)
                status_dict[node.name] = list(node.users)
                continue

            # n1-n2-n3: if n1 has multiple ops, the last op in n1 is the one
            # able to clear all the others from status_dict; since
            # last_dict_len > 1, multiple ops are used by this node, so we
            # view it as the end of one region and start a new one
            else:
                node_dict[node_idx].append(node)
                status_dict[node.name] = list(node.users)
                node_idx += 1
                node_dict[node_idx] = []
                continue

        else:
            # currently we use the bigger region structure;
            # if the following block is activated, the regions become smaller
            #################################################
            # if last_dict_len == 1:
            #     if len(node_dict[node_idx]) > 0:
            #         node_idx += 1
            #     node_dict[node_idx] = [node]
            #     status_dict[node.name] = list(node.users)
            #
            #     continue
            #################################################

            # in-region case: as the current node cannot clear status_dict,
            # we view it as in-region status and append it to the
            # current node_idx
            node_dict[node_idx].append(node)
            status_dict[node.name] = list(node.users)
            continue

    # if the output node uses multiple nodes, there might be an
    # empty region after the output node
    if len(node_dict[node_idx]) == 0:
        node_dict.pop(node_idx)
        node_idx -= 1

    # pop the first region (placeholder) and the last region (output)
    node_dict.pop(0)
    node_dict.pop(node_idx)
    return node_dict
```
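The pass above tracks, for each produced value, the consumers that have not yet been visited, and cuts a region boundary where almost nothing stays live across the cut. The following is a stdlib-only toy of that idea on plain name lists (illustrative names and a simplified cut rule, not the fx-based implementation above):

```python
def linearize_toy(nodes, users):
    """nodes: topologically ordered op names; users: name -> consumer names.
    Cut a region boundary wherever at most one value is still live, which
    mirrors the chain model where a single x_i crosses each boundary."""
    live = {}        # producer -> consumers not yet visited
    regions = [[]]
    for node in nodes:
        for pending in live.values():
            pending.discard(node)                     # node consumes its inputs
        live = {k: v for k, v in live.items() if v}   # drop fully-used producers
        regions[-1].append(node)
        live[node] = set(users.get(node, []))
        live = {k: v for k, v in live.items() if v}
        if len(live) <= 1:                            # legal cut point
            regions.append([])
    if not regions[-1]:
        regions.pop()
    return regions

# straight chain: every op becomes its own region
print(linearize_toy("abcd", {"a": ["b"], "b": ["c"], "c": ["d"]}))
# diamond a->(b,c)->d: two values are live inside, so it collapses into one region
print(linearize_toy("abcd", {"a": ["b", "c"], "b": ["d"], "c": ["d"]}))
```

The real pass additionally has to handle placeholder and output nodes, which is why `linearize` pops the first and last regions before returning.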