From 06f8991bd246d64126fad6b1c18e6623283b4484 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:23:12 +0800 Subject: [PATCH 01/11] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 9e370d733e78..e4de6d03f2de 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,6 +22,12 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. +@compatibility(is_backward_compatible=False) +class node_size(NamedTuple): + output_size: int + param_size: int + + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -114,18 +120,27 @@ def extract_tensor_meta(obj): return TensorMetadata(None, None, False, None, 0, False) meta = _map_aggregate(result, extract_tensor_meta) - n.meta['tensor_meta'] = meta - total_node_size = _compute_node_numel(n.meta['tensor_meta']) - # counting the total size of parameters + + # get byte size for each element + size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() + + # compute the total size of output tensors + total_output_size = _compute_node_numel(n.meta['tensor_meta']) + + # compute the total size of model parameters total_param_size = 0 if n.op == 'call_module': target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() + + # compute the total memory cost of output tensors and model parameters + total_output_size *= size_per_elem_bytes + total_param_size *= size_per_elem_bytes - total_node_size += total_param_size - n.node_size = total_node_size + # TODO: node.node_size is not an original attribute + setattr(n, 'node_size', node_size(total_output_size, total_param_size)) n.meta['type'] = type(result) return result From 3cd7d2250759195f59b12da09914e980caacdd5a Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:28:10 +0800 Subject: [PATCH 02/11] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- tests/test_fx/test_meta_info_prop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 84cef23b038d..28c80ea39f97 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,12 +23,16 @@ def test_meta_info_prop(): input_sample = torch.rand(BATCH_SIZE, DIM_IN) orig_output = model(input_sample) gm = symbolic_trace(model) + for node in gm.graph.nodes: + assert not hasattr(node, + 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': meta_check(node.meta['tensor_meta'], input_sample) if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) + assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From 0849b3b7feb7468df236553dd0bef10fc15b0ddd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 14:42:02 +0800 Subject: [PATCH 03/11] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 20 ++++++++------------ tests/test_fx/test_meta_info_prop.py | 8 ++++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e4de6d03f2de..98be1be48ca2 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,12 +22,6 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. -@compatibility(is_backward_compatible=False) -class node_size(NamedTuple): - output_size: int - param_size: int - - def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -125,8 +119,8 @@ def extract_tensor_meta(obj): # get byte size for each element size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() - # compute the total size of output tensors - total_output_size = _compute_node_numel(n.meta['tensor_meta']) + # compute the total size of activation tensors + total_activation_size = _compute_node_numel(n.meta['tensor_meta']) # compute the total size of model parameters total_param_size = 0 @@ -134,13 +128,15 @@ def extract_tensor_meta(obj): target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() - - # compute the total memory cost of output tensors and model parameters - total_output_size *= size_per_elem_bytes + + # compute the total memory cost of activation tensors and model parameters + total_activation_size *= size_per_elem_bytes total_param_size *= size_per_elem_bytes # TODO: node.node_size is not an original attribute - setattr(n, 'node_size', node_size(total_output_size, total_param_size)) + setattr(n, 'node_size', total_activation_size + total_param_size) + setattr(n, 'param_size', total_param_size) + setattr(n, 'activation_size', total_activation_size) n.meta['type'] = type(result) return result diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 28c80ea39f97..1da4f6b3bffd 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -26,6 +26,11 @@ def test_meta_info_prop(): for node in gm.graph.nodes: assert not hasattr(node, 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' + assert not hasattr(node, + 'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure' + assert not hasattr( + node, + 'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': @@ -33,6 +38,9 @@ def test_meta_info_prop(): if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' + assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure' + assert hasattr( + node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From e11db26cce0f525f8628a11f3afe11ba79354772 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 16:44:13 +0800 Subject: [PATCH 04/11] [fx] activation checkpointing using Chen strategies. --- colossalai/fx/passes/algorithms/__init__ | 0 .../fx/passes/algorithms/ckpt_solver_chen.py | 55 +++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 colossalai/fx/passes/algorithms/__init__ create mode 100644 colossalai/fx/passes/algorithms/ckpt_solver_chen.py diff --git a/colossalai/fx/passes/algorithms/__init__ b/colossalai/fx/passes/algorithms/__init__ new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py new file mode 100644 index 000000000000..3b1f61933cd9 --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -0,0 +1,55 @@ +import torch +from torch.fx import GraphModule + + +def chen_greedy(gm: GraphModule, B: int): + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174 + + Usage: + B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB + model = resnet18() + input_sample = torch.rand(4, 3, 224, 224) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + gm = chen_greedy(gm, B) + + Args: + gm (GraphModule): The module to add checkpoints + B (int): The approximate memory budget for this module. + """ + gm.graph.lint() # make sure nodes are in topological order + temp = 0 + x = 0 + idx = 0 + for n in gm.graph.nodes: + temp += getattr(n, 'activation_size') + if temp > B: + x += getattr(n, 'activation_size') + temp = 0 + setattr(n, 'activation_checkpoint', str(idx)) + idx += 1 + return gm + + +def chen_sqrtn(gm: GraphModule): + """ + This is the simple reimplementation of Algorithm 3 in https://arxiv.org/abs/1604.06174 + + Usage: + B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB + model = resnet18() + input_sample = torch.rand(4, 3, 224, 224) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + gm = chen_greedy(gm, B) + + Args: + gm (GraphModule): The module to add checkpoints + B (int): The approximate memory budget for this module. + """ + gm.graph.lint() # make sure nodes are in topological order + k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints + for idx, n in enumerate(gm.graph.nodes): + if (idx + 1) % k == 0: + setattr(n, 'activation checkpoint', str((idx + 1) // k)) From a7d56bde00b0762fb755a5b01e1dba7a046096da Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 17:39:29 +0800 Subject: [PATCH 05/11] [fx] add test for ckpt_solver_chen --- colossalai/fx/passes/algorithms/__init__ | 0 colossalai/fx/passes/algorithms/__init__.py | 1 + .../fx/passes/algorithms/ckpt_solver_chen.py | 18 ++++++--- .../test_ckpt_torchvision.py | 40 +++++++++++++++++++ 4 files changed, 53 insertions(+), 6 deletions(-) delete mode 100644 colossalai/fx/passes/algorithms/__init__ create mode 100644 colossalai/fx/passes/algorithms/__init__.py create mode 100644 tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py diff --git a/colossalai/fx/passes/algorithms/__init__ b/colossalai/fx/passes/algorithms/__init__ deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py new file mode 100644 index 000000000000..943fbd8678bd --- /dev/null +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -0,0 +1 @@ +from .ckpt_solver_chen import chen_greedy, chen_sqrtn diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 3b1f61933cd9..3f56fafa8c66 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -4,7 +4,7 @@ def chen_greedy(gm: GraphModule, B: int): """ - This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174 + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. Usage: B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB @@ -22,6 +22,11 @@ def chen_greedy(gm: GraphModule, B: int): temp = 0 x = 0 idx = 0 + budget = B + for n in gm.graph.nodes: + B -= getattr(n, 'param_size') + assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}' + print(B) for n in gm.graph.nodes: temp += getattr(n, 'activation_size') if temp > B: @@ -29,27 +34,28 @@ def chen_greedy(gm: GraphModule, B: int): temp = 0 setattr(n, 'activation_checkpoint', str(idx)) idx += 1 + gm.recompile() return gm def chen_sqrtn(gm: GraphModule): """ - This is the simple reimplementation of Algorithm 3 in https://arxiv.org/abs/1604.06174 + This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. Usage: - B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB model = resnet18() input_sample = torch.rand(4, 3, 224, 224) gm = symbolic_trace(model) MetaInfoProp(gm).run(input_sample) - gm = chen_greedy(gm, B) + gm = chen_sqrtn(gm) Args: gm (GraphModule): The module to add checkpoints - B (int): The approximate memory budget for this module. """ gm.graph.lint() # make sure nodes are in topological order k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints for idx, n in enumerate(gm.graph.nodes): if (idx + 1) % k == 0: - setattr(n, 'activation checkpoint', str((idx + 1) // k)) + setattr(n, 'activation_checkpoint', str((idx + 1) // k)) + gm.recompile() + return gm diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py new file mode 100644 index 000000000000..6f07dd0a5d35 --- /dev/null +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -0,0 +1,40 @@ +from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +import torch +import torchvision.models as tm +from colossalai.fx import ColoTracer +from torch.fx import GraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from functools import partial +import pytest + +SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn] + + +def _is_activation_checkpoint_available(gm: GraphModule): + for n in gm.graph.nodes: + if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + return True + + +def test_ckpt_solver(): + MODEL_LIST = [tm.resnet18, tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer() + data = torch.rand(1, 3, 224, 224) + + for solver in SOLVERS: + for model_cls in MODEL_LIST: + model = model_cls() + graph = tracer.trace(root=model) + gm = GraphModule(model, graph, model.__class__.__name__) + MetaInfoProp(gm).run(data) + gm = solver(gm) + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm(data) + + +if __name__ == '__main__': + test_ckpt_solver() From f8a28bc252264ffec6d45bdcc16b405c6d5c07f3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 17:44:34 +0800 Subject: [PATCH 06/11] mend --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 6f07dd0a5d35..4bf3128c6c0a 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -33,7 +33,7 @@ def test_ckpt_solver(): gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - gm(data) + assert torch.allclose(gm(data), model(data)) if __name__ == '__main__': From 004bcff973152b490c38c883911a0929c8da309a Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 17:44:34 +0800 Subject: [PATCH 07/11] [fx] add vanilla activation checkpoint search with test on resnet and densenet --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 6f07dd0a5d35..4bf3128c6c0a 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -33,7 +33,7 @@ def test_ckpt_solver(): gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - gm(data) + assert torch.allclose(gm(data), model(data)) if __name__ == '__main__': From afa6178f39e31c50606dfa47d5d7acffd5c4ddd0 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 18:08:41 +0800 Subject: [PATCH 08/11] [fx] add vanilla activation checkpoint search with test on resnet and densenet --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 3f56fafa8c66..8076179e623c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -26,12 +26,11 @@ def chen_greedy(gm: GraphModule, B: int): for n in gm.graph.nodes: B -= getattr(n, 'param_size') assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}' - print(B) for n in gm.graph.nodes: temp += getattr(n, 'activation_size') if temp > B: x += getattr(n, 'activation_size') - temp = 0 + temp = x setattr(n, 'activation_checkpoint', str(idx)) idx += 1 gm.recompile() From acc5184594425aeef456e8eba3fad268591a8c89 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 11 Aug 2022 15:26:18 +0800 Subject: [PATCH 09/11] [fx] add a namespace code for solver_chen. --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 8076179e623c..d28e6fa1af9b 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,6 +1,8 @@ import torch from torch.fx import GraphModule +__all__ = ['chen_greedy', 'chen_sqrtn'] + def chen_greedy(gm: GraphModule, B: int): """ From 9f167431d780c86a824f75dfcff6301f820a84e2 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 11 Aug 2022 23:13:18 +0800 Subject: [PATCH 10/11] [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. --- .../fx/passes/algorithms/ckpt_solver_chen.py | 64 +++++++++++++------ .../test_ckpt_torchvision.py | 20 +++++- 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index d28e6fa1af9b..105caf88db11 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,45 +1,71 @@ +from typing import Set, Tuple import torch from torch.fx import GraphModule +import math __all__ = ['chen_greedy', 'chen_sqrtn'] -def chen_greedy(gm: GraphModule, B: int): +def chen_greedy(gm: GraphModule) -> GraphModule: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + Note that this algorithm targets at memory optimization only, using techniques in appendix A. Usage: - B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB model = resnet18() input_sample = torch.rand(4, 3, 224, 224) gm = symbolic_trace(model) MetaInfoProp(gm).run(input_sample) - gm = chen_greedy(gm, B) + gm = chen_greedy(gm) Args: gm (GraphModule): The module to add checkpoints - B (int): The approximate memory budget for this module. """ + + def grid_search(num_grids: int = 6) -> Set: + """ + Search ckpt strategy with B = 0, then run the allocation algorithm again with B = √xy. + Grid search over [√2/2 B, √2 B] for ckpt_opt over num_grids as in appendix A. + """ + _, B_approx = run_chen_greedy(0) + B_min, B_max = math.floor(B_approx / math.sqrt(2)), math.ceil(B_approx * math.sqrt(2)) + B_opt = math.inf + for B in range(B_min, B_max, (B_max - B_min) // num_grids): + ckpt, B_approx = run_chen_greedy(B) + if B_approx < B_opt: + B_opt = B_approx + ckpt_opt = ckpt + return ckpt_opt + + def run_chen_greedy(B: int = 0) -> Tuple[Set, int]: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + """ + ckpt = set() + temp = 0 + x = 0 + y = 0 + for (idx, n) in enumerate(gm.graph.nodes): + temp += getattr(n, 'activation_size') + y = max(y, temp) + if temp > B: + x += getattr(n, 'activation_size') + temp = 0 + ckpt.add(idx) + return ckpt, math.floor(math.sqrt(x * y)) + gm.graph.lint() # make sure nodes are in topological order - temp = 0 - x = 0 - idx = 0 - budget = B - for n in gm.graph.nodes: - B -= getattr(n, 'param_size') - assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}' - for n in gm.graph.nodes: - temp += getattr(n, 'activation_size') - if temp > B: - x += getattr(n, 'activation_size') - temp = x - setattr(n, 'activation_checkpoint', str(idx)) - idx += 1 + ckpt = grid_search(num_grids=6) + i = 0 + for idx, n in enumerate(gm.graph.nodes): + if idx in ckpt: + setattr(n, 'activation_checkpoint', str(i)) + i += 1 gm.recompile() return gm -def chen_sqrtn(gm: GraphModule): +def chen_sqrtn(gm: GraphModule) -> GraphModule: """ This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 4bf3128c6c0a..169b4bcb6433 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,13 +1,13 @@ +from ctypes import Union from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn import torch import torchvision.models as tm from colossalai.fx import ColoTracer from torch.fx import GraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from functools import partial import pytest -SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn] +SOLVERS = [chen_greedy, chen_sqrtn] def _is_activation_checkpoint_available(gm: GraphModule): @@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule): return True +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p, gm_p): + return False + return True + + def test_ckpt_solver(): MODEL_LIST = [tm.resnet18, tm.densenet121] @@ -23,17 +30,24 @@ def test_ckpt_solver(): tracer = ColoTracer() data = torch.rand(1, 3, 224, 224) + label = torch.rand(1, 1000) for solver in SOLVERS: for model_cls in MODEL_LIST: model = model_cls() + criterion = torch.nn.MSELoss() graph = tracer.trace(root=model) gm = GraphModule(model, graph, model.__class__.__name__) MetaInfoProp(gm).run(data) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - assert torch.allclose(gm(data), model(data)) + loss = criterion(model(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(model, + gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' if __name__ == '__main__': From 041ca60b5eeca5e5c5b3d48f8d2ba0741a23b73e Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 11:01:49 +0800 Subject: [PATCH 11/11] [fx] fix lowercase naming conventions. --- .../fx/passes/algorithms/ckpt_solver_chen.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 105caf88db11..046b165a62de 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -24,20 +24,20 @@ def chen_greedy(gm: GraphModule) -> GraphModule: def grid_search(num_grids: int = 6) -> Set: """ - Search ckpt strategy with B = 0, then run the allocation algorithm again with B = √xy. - Grid search over [√2/2 B, √2 B] for ckpt_opt over num_grids as in appendix A. + Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. + Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A. """ - _, B_approx = run_chen_greedy(0) - B_min, B_max = math.floor(B_approx / math.sqrt(2)), math.ceil(B_approx * math.sqrt(2)) - B_opt = math.inf - for B in range(B_min, B_max, (B_max - B_min) // num_grids): - ckpt, B_approx = run_chen_greedy(B) - if B_approx < B_opt: - B_opt = B_approx + _, b_approx = run_chen_greedy(0) + b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) + b_opt = math.inf + for b in range(b_min, b_max, (b_max - b_min) // num_grids): + ckpt, b_approx = run_chen_greedy(b) + if b_approx < b_opt: + b_opt = b_approx ckpt_opt = ckpt return ckpt_opt - def run_chen_greedy(B: int = 0) -> Tuple[Set, int]: + def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. """ @@ -48,7 +48,7 @@ def run_chen_greedy(B: int = 0) -> Tuple[Set, int]: for (idx, n) in enumerate(gm.graph.nodes): temp += getattr(n, 'activation_size') y = max(y, temp) - if temp > B: + if temp > b: x += getattr(n, 'activation_size') temp = 0 ckpt.add(idx)