From 06f8991bd246d64126fad6b1c18e6623283b4484 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:23:12 +0800 Subject: [PATCH 01/12] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 9e370d733e78..e4de6d03f2de 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,6 +22,12 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. +@compatibility(is_backward_compatible=False) +class node_size(NamedTuple): + output_size: int + param_size: int + + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -114,18 +120,27 @@ def extract_tensor_meta(obj): return TensorMetadata(None, None, False, None, 0, False) meta = _map_aggregate(result, extract_tensor_meta) - n.meta['tensor_meta'] = meta - total_node_size = _compute_node_numel(n.meta['tensor_meta']) - # counting the total size of parameters + + # get byte size for each element + size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() + + # compute the total size of output tensors + total_output_size = _compute_node_numel(n.meta['tensor_meta']) + + # compute the total size of model parameters total_param_size = 0 if n.op == 'call_module': target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() + + # compute the total memory cost of output tensors and model parameters + total_output_size *= size_per_elem_bytes + total_param_size *= size_per_elem_bytes - total_node_size += total_param_size - n.node_size = total_node_size + # TODO: node.node_size is not an original attribute + setattr(n, 'node_size', node_size(total_output_size, total_param_size)) n.meta['type'] = type(result) return result From 3cd7d2250759195f59b12da09914e980caacdd5a Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:28:10 +0800 Subject: [PATCH 02/12] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- tests/test_fx/test_meta_info_prop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 84cef23b038d..28c80ea39f97 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,12 +23,16 @@ def test_meta_info_prop(): input_sample = torch.rand(BATCH_SIZE, DIM_IN) orig_output = model(input_sample) gm = symbolic_trace(model) + for node in gm.graph.nodes: + assert not hasattr(node, + 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': meta_check(node.meta['tensor_meta'], input_sample) if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) + assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From 0849b3b7feb7468df236553dd0bef10fc15b0ddd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 14:42:02 +0800 Subject: [PATCH 03/12] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 20 ++++++++------------ tests/test_fx/test_meta_info_prop.py | 8 ++++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e4de6d03f2de..98be1be48ca2 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,12 +22,6 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. -@compatibility(is_backward_compatible=False) -class node_size(NamedTuple): - output_size: int - param_size: int - - def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -125,8 +119,8 @@ def extract_tensor_meta(obj): # get byte size for each element size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() - # compute the total size of output tensors - total_output_size = _compute_node_numel(n.meta['tensor_meta']) + # compute the total size of activation tensors + total_activation_size = _compute_node_numel(n.meta['tensor_meta']) # compute the total size of model parameters total_param_size = 0 @@ -134,13 +128,15 @@ def extract_tensor_meta(obj): target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() - - # compute the total memory cost of output tensors and model parameters - total_output_size *= size_per_elem_bytes + + # compute the total memory cost of activation tensors and model parameters + total_activation_size *= size_per_elem_bytes total_param_size *= size_per_elem_bytes # TODO: node.node_size is not an original attribute - setattr(n, 'node_size', node_size(total_output_size, total_param_size)) + setattr(n, 'node_size', total_activation_size + total_param_size) + setattr(n, 'param_size', total_param_size) + setattr(n, 'activation_size', total_activation_size) n.meta['type'] = type(result) return result diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 28c80ea39f97..1da4f6b3bffd 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -26,6 +26,11 @@ def test_meta_info_prop(): for node in gm.graph.nodes: assert not hasattr(node, 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' + assert not hasattr(node, + 'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure' + assert not hasattr( + node, + 'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': @@ -33,6 +38,9 @@ def test_meta_info_prop(): if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' + assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure' + assert hasattr( + node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From 9b4f46049e70e9e9b800490ecddcdd913e4b4b54 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Fri, 12 Aug 2022 12:33:07 +0800 Subject: [PATCH 04/12] [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. --- .../test_ckpt_torchvision.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 169b4bcb6433..a5de03bec2e7 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,4 +1,5 @@ from ctypes import Union +from typing import Callable from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn import torch import torchvision.models as tm @@ -18,36 +19,41 @@ def _is_activation_checkpoint_available(gm: GraphModule): def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if not torch.allclose(m_p, gm_p): + if not torch.allclose(m_p.grad, gm_p.grad): return False return True +def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module]): + criterion = torch.nn.MSELoss() + data = torch.rand(2, 3, 24, 24) + label = torch.rand(2, 5) + loss = criterion(m(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + + def test_ckpt_solver(): - MODEL_LIST = [tm.resnet18, tm.densenet121] + MODEL_LIST = [tm.resnet18] torch.backends.cudnn.deterministic = True tracer = ColoTracer() - data = torch.rand(1, 3, 224, 224) - label = torch.rand(1, 1000) + data = torch.rand(2, 3, 24, 24) for solver in SOLVERS: for model_cls in MODEL_LIST: - model = model_cls() - criterion = torch.nn.MSELoss() + model = model_cls(num_classes=5) graph = tracer.trace(root=model) gm = GraphModule(model, graph, model.__class__.__name__) MetaInfoProp(gm).run(data) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - loss = criterion(model(data), label) - loss.backward() - loss = criterion(gm(data), label) - loss.backward() - assert _is_all_gradient_close(model, - gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + check_backward_consistency(model, gm, solver, model_cls) if __name__ == '__main__': From ca968c0b0e7c570a410deb444bbb094327c21f2f Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 19:40:59 +0800 Subject: [PATCH 05/12] [fx] fix test and algorithm bugs in activation checkpointing. --- .../fx/passes/algorithms/ckpt_solver_chen.py | 40 +++++++---- .../test_ckpt_torchvision.py | 66 +++++++++++++++---- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 046b165a62de..8b404e3a65ee 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,4 +1,4 @@ -from typing import Set, Tuple +from typing import List, Set, Tuple import torch from torch.fx import GraphModule import math @@ -6,6 +6,14 @@ __all__ = ['chen_greedy', 'chen_sqrtn'] +def _all_potential_ckpt_nodes(gm: GraphModule) -> List: + ckpt_nodes = [] + for n in gm.graph.nodes: + if n.op == 'call_module': + ckpt_nodes.append(n) + return ckpt_nodes + + def chen_greedy(gm: GraphModule) -> GraphModule: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. @@ -31,36 +39,40 @@ def grid_search(num_grids: int = 6) -> Set: b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) b_opt = math.inf for b in range(b_min, b_max, (b_max - b_min) // num_grids): - ckpt, b_approx = run_chen_greedy(b) + ckpt_intv, b_approx = run_chen_greedy(b) if b_approx < b_opt: b_opt = b_approx - ckpt_opt = ckpt + ckpt_opt = ckpt_intv return ckpt_opt def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. """ - ckpt = set() + ckpt_nodes = _all_potential_ckpt_nodes(gm) + ckpt_intv = [] temp = 0 x = 0 y = 0 + prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): temp += getattr(n, 'activation_size') y = max(y, temp) - if temp > b: + if temp > b and n in ckpt_nodes: x += getattr(n, 'activation_size') temp = 0 - ckpt.add(idx) - return ckpt, math.floor(math.sqrt(x * y)) + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) gm.graph.lint() # make sure nodes are in topological order ckpt = grid_search(num_grids=6) - i = 0 - for idx, n in enumerate(gm.graph.nodes): - if idx in ckpt: - setattr(n, 'activation_checkpoint', str(i)) - i += 1 + node_list = list(gm.graph.nodes) + for i, seg in enumerate(ckpt): + for idx in range(*seg): + n = node_list[idx] + if n.op in ['call_module', 'call_method', 'call_function']: + setattr(n, 'activation_checkpoint', str(i)) gm.recompile() return gm @@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule: gm.graph.lint() # make sure nodes are in topological order k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints for idx, n in enumerate(gm.graph.nodes): - if (idx + 1) % k == 0: + # We should not add act_ckpt to the placeholder + # The last segment should not be checkpointed + if n.op != 'placeholder' and (idx + 1) // k < k: setattr(n, 'activation_checkpoint', str((idx + 1) // k)) gm.recompile() return gm diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index a5de03bec2e7..8a74035f21a1 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,13 +1,24 @@ -from ctypes import Union from typing import Callable -from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +import copy import torch import torchvision.models as tm -from colossalai.fx import ColoTracer from torch.fx import GraphModule +import colossalai +from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +from colossalai.utils import free_port +from colossalai.core import global_context as gpc import pytest +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + SOLVERS = [chen_greedy, chen_sqrtn] @@ -27,7 +38,7 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], model_cls: Callable[[], torch.nn.Module]): criterion = torch.nn.MSELoss() - data = torch.rand(2, 3, 24, 24) + data = torch.rand(2, 3, 32, 32) label = torch.rand(2, 5) loss = criterion(m(data), label) loss.backward() @@ -36,25 +47,58 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' +@pytest.mark.skip +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_ckpt_solver(): - MODEL_LIST = [tm.resnet18] + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + MODEL_LIST = [tm.resnet18, tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(2, 3, 32, 32) + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm).run(data) + codegen = ActivationCheckpointCodeGen() + gm.graph.set_codegen(codegen) + gm = solver(gm) + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm.to_folder("foo", "Bar") + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() + + +@pytest.mark.skip +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +def test_ckpt_solver_torch11(): + colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + MODEL_LIST = [tm.resnet18, tm.densenet121] torch.backends.cudnn.deterministic = True - tracer = ColoTracer() + tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(2, 3, 24, 24) + data = torch.rand(2, 3, 32, 32) for solver in SOLVERS: for model_cls in MODEL_LIST: - model = model_cls(num_classes=5) - graph = tracer.trace(root=model) - gm = GraphModule(model, graph, model.__class__.__name__) + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - check_backward_consistency(model, gm, solver, model_cls) + check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() if __name__ == '__main__': test_ckpt_solver() + test_ckpt_solver_torch11() From d5c1e4eb77669b3e5d9abb1fddf72dc0982179f0 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 20:05:21 +0800 Subject: [PATCH 06/12] mend [fx] fix test and algorithm bugs in activation checkpointing. --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 8a74035f21a1..b64d606a1079 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -69,7 +69,6 @@ def test_ckpt_solver(): gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - gm.to_folder("foo", "Bar") check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() From 195a8fadfa602f5382af9af94cfa834292de83c9 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 20:05:45 +0800 Subject: [PATCH 07/12] mend [fx] fix test and algorithm bugs in activation checkpointing. --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 8b404e3a65ee..3fbb1801d4e1 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -54,7 +54,7 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: temp = 0 x = 0 y = 0 - prev_idx = 2 + prev_idx = 1 for (idx, n) in enumerate(gm.graph.nodes): temp += getattr(n, 'activation_size') y = max(y, temp) From 47f6502e5fa3f94af66158c6569503bd3f541fc6 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 20:14:31 +0800 Subject: [PATCH 08/12] mend [fx] fix test and algorithm bugs in activation checkpointing. --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 +- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 3fbb1801d4e1..8b404e3a65ee 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -54,7 +54,7 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: temp = 0 x = 0 y = 0 - prev_idx = 1 + prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): temp += getattr(n, 'activation_size') y = max(y, temp) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index b64d606a1079..8a74035f21a1 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -69,6 +69,7 @@ def test_ckpt_solver(): gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm.to_folder("foo", "Bar") check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() From 75c22348322fd426b057c8d68dfdbd16dc3b71ae Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 12 Aug 2022 20:14:46 +0800 Subject: [PATCH 09/12] mend [fx] fix test and algorithm bugs in activation checkpointing. --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 8a74035f21a1..b64d606a1079 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -69,7 +69,6 @@ def test_ckpt_solver(): gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - gm.to_folder("foo", "Bar") check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() From abc52d199ada4dd41cd88f2ccce633813b20b8c7 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 15 Aug 2022 17:49:44 +0800 Subject: [PATCH 10/12] [fx] polish ckpt_test. --- .../test_ckpt_torchvision.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index b64d606a1079..a15a92db26c8 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,6 +1,7 @@ from typing import Callable import copy import torch +import torch.multiprocessing as mp import torchvision.models as tm from torch.fx import GraphModule import colossalai @@ -47,10 +48,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' -@pytest.mark.skip -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') -def test_ckpt_solver(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_ckpt_solver(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') MODEL_LIST = [tm.resnet18, tm.densenet121] torch.backends.cudnn.deterministic = True @@ -70,13 +69,15 @@ def test_ckpt_solver(): assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) - gpc.destroy() -@pytest.mark.skip -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') -def test_ckpt_solver_torch11(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_ckpt_solver(): + mp.spawn(_run_ckpt_solver, nprocs=1) + + +def _run_ckpt_solver_torch11(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') MODEL_LIST = [tm.resnet18, tm.densenet121] torch.backends.cudnn.deterministic = True @@ -95,7 +96,12 @@ def test_ckpt_solver_torch11(): assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) - gpc.destroy() + + +@pytest.mark.skip +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +def test_ckpt_solver_torch11(): + mp.spawn(_run_ckpt_solver_torch11, nprocs=1) if __name__ == '__main__': From aba817abf350e858b1f328e191fe28536270fe63 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 15 Aug 2022 17:51:26 +0800 Subject: [PATCH 11/12] [fx] polish ckpt_test. --- .../test_fx/test_codegen/test_activation_checkpoint_codegen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index fe5c638b2e24..2fe55b1c5351 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -90,6 +90,7 @@ def _run_act_ckpt_codegen(rank): gpc.destroy() +@pytest.mark.skip @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_act_ckpt_codegen(): mp.spawn(_run_act_ckpt_codegen, nprocs=1) @@ -142,6 +143,7 @@ def _run_act_ckpt_python_code_torch11(rank): gpc.destroy() +@pytest.mark.skip @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') def test_act_ckpt_python_code_torch11(): mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) From ded09b0ee772ef57dea60b12e4395bf433c4d2a6 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 15 Aug 2022 17:55:15 +0800 Subject: [PATCH 12/12] [fx] polish ckpt_test. --- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 1 + .../test_fx/test_codegen/test_activation_checkpoint_codegen.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index a15a92db26c8..1772c2840535 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -71,6 +71,7 @@ def _run_ckpt_solver(rank): check_backward_consistency(m, gm, solver, model_cls) +@pytest.mark.skip @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 2fe55b1c5351..fe5c638b2e24 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -90,7 +90,6 @@ def _run_act_ckpt_codegen(rank): gpc.destroy() -@pytest.mark.skip @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_act_ckpt_codegen(): mp.spawn(_run_act_ckpt_codegen, nprocs=1) @@ -143,7 +142,6 @@ def _run_act_ckpt_python_code_torch11(rank): gpc.destroy() -@pytest.mark.skip @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') def test_act_ckpt_python_code_torch11(): mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)