From 06f8991bd246d64126fad6b1c18e6623283b4484 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:23:12 +0800 Subject: [PATCH 01/10] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 9e370d733e78..e4de6d03f2de 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,6 +22,12 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. +@compatibility(is_backward_compatible=False) +class node_size(NamedTuple): + output_size: int + param_size: int + + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -114,18 +120,27 @@ def extract_tensor_meta(obj): return TensorMetadata(None, None, False, None, 0, False) meta = _map_aggregate(result, extract_tensor_meta) - n.meta['tensor_meta'] = meta - total_node_size = _compute_node_numel(n.meta['tensor_meta']) - # counting the total size of parameters + + # get byte size for each element + size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() + + # compute the total size of output tensors + total_output_size = _compute_node_numel(n.meta['tensor_meta']) + + # compute the total size of model parameters total_param_size = 0 if n.op == 'call_module': target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() + + # compute the total memory cost of output tensors and model parameters + total_output_size *= size_per_elem_bytes + total_param_size *= size_per_elem_bytes - total_node_size += total_param_size - n.node_size = total_node_size + # TODO: node.node_size is not an original attribute + setattr(n, 'node_size', node_size(total_output_size, total_param_size)) n.meta['type'] = type(result) return result From 3cd7d2250759195f59b12da09914e980caacdd5a Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 9 Aug 2022 23:28:10 +0800 Subject: [PATCH 02/10] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- tests/test_fx/test_meta_info_prop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 84cef23b038d..28c80ea39f97 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,12 +23,16 @@ def test_meta_info_prop(): input_sample = torch.rand(BATCH_SIZE, DIM_IN) orig_output = model(input_sample) gm = symbolic_trace(model) + for node in gm.graph.nodes: + assert not hasattr(node, + 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': meta_check(node.meta['tensor_meta'], input_sample) if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) + assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From 0849b3b7feb7468df236553dd0bef10fc15b0ddd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 10 Aug 2022 14:42:02 +0800 Subject: [PATCH 03/10] [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages --- colossalai/fx/passes/meta_info_prop.py | 20 ++++++++------------ tests/test_fx/test_meta_info_prop.py | 8 ++++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e4de6d03f2de..98be1be48ca2 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -22,12 +22,6 @@ class TensorMetadata(NamedTuple): # behaviour by appending sharding spec into list. -@compatibility(is_backward_compatible=False) -class node_size(NamedTuple): - output_size: int - param_size: int - - def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. @@ -125,8 +119,8 @@ def extract_tensor_meta(obj): # get byte size for each element size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size() - # compute the total size of output tensors - total_output_size = _compute_node_numel(n.meta['tensor_meta']) + # compute the total size of activation tensors + total_activation_size = _compute_node_numel(n.meta['tensor_meta']) # compute the total size of model parameters total_param_size = 0 @@ -134,13 +128,15 @@ def extract_tensor_meta(obj): target_module = n.graph.owning_module.get_submodule(n.target) for param in target_module.parameters(): total_param_size += param.numel() - - # compute the total memory cost of output tensors and model parameters - total_output_size *= size_per_elem_bytes + + # compute the total memory cost of activation tensors and model parameters + total_activation_size *= size_per_elem_bytes total_param_size *= size_per_elem_bytes # TODO: node.node_size is not an original attribute - setattr(n, 'node_size', node_size(total_output_size, total_param_size)) + setattr(n, 'node_size', total_activation_size + total_param_size) + setattr(n, 'param_size', total_param_size) + setattr(n, 'activation_size', total_activation_size) n.meta['type'] = type(result) return result diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 28c80ea39f97..1da4f6b3bffd 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -26,6 +26,11 @@ def test_meta_info_prop(): for node in gm.graph.nodes: assert not hasattr(node, 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' + assert not hasattr(node, + 'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure' + assert not hasattr( + node, + 'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': @@ -33,6 +38,9 @@ def test_meta_info_prop(): if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' + assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure' + assert hasattr( + node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure' if __name__ == '__main__': From 9b4f46049e70e9e9b800490ecddcdd913e4b4b54 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Fri, 12 Aug 2022 12:33:07 +0800 Subject: [PATCH 04/10] [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. --- .../test_ckpt_torchvision.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 169b4bcb6433..a5de03bec2e7 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,4 +1,5 @@ from ctypes import Union +from typing import Callable from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn import torch import torchvision.models as tm @@ -18,36 +19,41 @@ def _is_activation_checkpoint_available(gm: GraphModule): def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if not torch.allclose(m_p, gm_p): + if not torch.allclose(m_p.grad, gm_p.grad): return False return True +def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module]): + criterion = torch.nn.MSELoss() + data = torch.rand(2, 3, 24, 24) + label = torch.rand(2, 5) + loss = criterion(m(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + + def test_ckpt_solver(): - MODEL_LIST = [tm.resnet18, tm.densenet121] + MODEL_LIST = [tm.resnet18] torch.backends.cudnn.deterministic = True tracer = ColoTracer() - data = torch.rand(1, 3, 224, 224) - label = torch.rand(1, 1000) + data = torch.rand(2, 3, 24, 24) for solver in SOLVERS: for model_cls in MODEL_LIST: - model = model_cls() - criterion = torch.nn.MSELoss() + model = model_cls(num_classes=5) graph = tracer.trace(root=model) gm = GraphModule(model, graph, model.__class__.__name__) MetaInfoProp(gm).run(data) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - loss = criterion(model(data), label) - loss.backward() - loss = criterion(gm(data), label) - loss.backward() - assert _is_all_gradient_close(model, - gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + check_backward_consistency(model, gm, solver, model_cls) if __name__ == '__main__': From bea7060ab53414af39d2f2b038b3852b07a070d1 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Tue, 16 Aug 2022 16:13:51 +0800 Subject: [PATCH 05/10] [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. --- .../codegen/activation_checkpoint_codegen.py | 2 +- .../fx/passes/algorithms/ckpt_solver_chen.py | 23 ++++++++++++++++--- .../test_ckpt_torchvision.py | 17 +++++++++++++- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 4a4bbef4cf74..07ce43815de1 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -105,7 +105,7 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars): """ outputs = ', '.join(output_vars) inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})' + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant=False)' def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func): diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 8b404e3a65ee..9f854703308c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,16 +1,33 @@ from typing import List, Set, Tuple import torch -from torch.fx import GraphModule +from torch.fx import GraphModule, Node import math __all__ = ['chen_greedy', 'chen_sqrtn'] +CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr'] def _all_potential_ckpt_nodes(gm: GraphModule) -> List: + """ + In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized. + """ + + def is_sink(): + """ + If we can free all memories when executing a certain node, it is a sink. + """ + return not sum((v for k, v in deps.items())) + + deps = {} ckpt_nodes = [] for n in gm.graph.nodes: - if n.op == 'call_module': + for n_par in n._input_nodes: + deps[n_par] -= 1 # free memory and dependencies + + # We can only put act_ckpt on these nodes + if n.op in CKPT_OP and is_sink(): ckpt_nodes.append(n) + deps[n] = len(n.users) # add dependencies for future graph return ckpt_nodes @@ -71,7 +88,7 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: for i, seg in enumerate(ckpt): for idx in range(*seg): n = node_list[idx] - if n.op in ['call_module', 'call_method', 'call_function']: + if n.op in CKPT_OP: setattr(n, 'activation_checkpoint', str(i)) gm.recompile() return gm diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 1772c2840535..479727677721 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,5 +1,6 @@ from typing import Callable import copy +import re import torch import torch.multiprocessing as mp import torchvision.models as tm @@ -20,7 +21,7 @@ from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False -SOLVERS = [chen_greedy, chen_sqrtn] +SOLVERS = [chen_greedy] def _is_activation_checkpoint_available(gm: GraphModule): @@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): return True +def _is_graph_linearized(gm: GraphModule): + code = gm.code + # find patterns like r' return output_1, output_2' + pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + if pattern.findall(code): + return False + else: + return True + + def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], model_cls: Callable[[], torch.nn.Module]): criterion = torch.nn.MSELoss() @@ -66,9 +77,11 @@ def _run_ckpt_solver(rank): codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() @pytest.mark.skip @@ -94,9 +107,11 @@ def _run_ckpt_solver_torch11(rank): MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) gm = solver(gm) + assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) + gpc.destroy() @pytest.mark.skip From da259ccda388ca826993926157fbbfeb317db4dc Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 16 Aug 2022 16:32:04 +0800 Subject: [PATCH 06/10] [fx] remove chen_sqrt for sake of simplicity --- .../fx/passes/algorithms/ckpt_solver_chen.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 9f854703308c..f2313d92e732 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -92,28 +92,3 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: setattr(n, 'activation_checkpoint', str(i)) gm.recompile() return gm - - -def chen_sqrtn(gm: GraphModule) -> GraphModule: - """ - This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. - - Usage: - model = resnet18() - input_sample = torch.rand(4, 3, 224, 224) - gm = symbolic_trace(model) - MetaInfoProp(gm).run(input_sample) - gm = chen_sqrtn(gm) - - Args: - gm (GraphModule): The module to add checkpoints - """ - gm.graph.lint() # make sure nodes are in topological order - k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints - for idx, n in enumerate(gm.graph.nodes): - # We should not add act_ckpt to the placeholder - # The last segment should not be checkpointed - if n.op != 'placeholder' and (idx + 1) // k < k: - setattr(n, 'activation_checkpoint', str((idx + 1) // k)) - gm.recompile() - return gm From 296b405e1bfd5e701009f527a90030518d8c9cbc Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 16 Aug 2022 16:32:04 +0800 Subject: [PATCH 07/10] [fx] remove chen_sqrt for sake of simplicity --- .../fx/passes/algorithms/ckpt_solver_chen.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 9f854703308c..f2313d92e732 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -92,28 +92,3 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: setattr(n, 'activation_checkpoint', str(i)) gm.recompile() return gm - - -def chen_sqrtn(gm: GraphModule) -> GraphModule: - """ - This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. - - Usage: - model = resnet18() - input_sample = torch.rand(4, 3, 224, 224) - gm = symbolic_trace(model) - MetaInfoProp(gm).run(input_sample) - gm = chen_sqrtn(gm) - - Args: - gm (GraphModule): The module to add checkpoints - """ - gm.graph.lint() # make sure nodes are in topological order - k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints - for idx, n in enumerate(gm.graph.nodes): - # We should not add act_ckpt to the placeholder - # The last segment should not be checkpointed - if n.op != 'placeholder' and (idx + 1) // k < k: - setattr(n, 'activation_checkpoint', str((idx + 1) // k)) - gm.recompile() - return gm From 3e9531cb25a1ad7f68452d6d5959790c9039a962 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 16 Aug 2022 16:55:19 +0800 Subject: [PATCH 08/10] [fx] remove chen_sqrt for sake of simplicity --- colossalai/fx/codegen/activation_checkpoint_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 07ce43815de1..4a4bbef4cf74 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -105,7 +105,7 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars): """ outputs = ', '.join(output_vars) inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant=False)' + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})' def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func): From 02c5cae0c699719d81a60ae0f2635e32decbaca8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 16 Aug 2022 16:57:18 +0800 Subject: [PATCH 09/10] [fx] remove chen_sqrt for sake of simplicity --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 4aa36f429aca..4a1601685956 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -3,7 +3,7 @@ from torch.fx import GraphModule, Node import math -__all__ = ['chen_greedy', 'chen_sqrtn'] +__all__ = ['chen_greedy'] CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr'] From 083cf7f40369a78f9d1baec10b593047c6080acd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Wed, 17 Aug 2022 14:10:53 +0800 Subject: [PATCH 10/10] [fx] fix inconsistencies. --- colossalai/fx/passes/algorithms/__init__.py | 2 +- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 +- tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py index 943fbd8678bd..bf6f9eb28017 100644 --- a/colossalai/fx/passes/algorithms/__init__.py +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -1 +1 @@ -from .ckpt_solver_chen import chen_greedy, chen_sqrtn +from .ckpt_solver_chen import chen_greedy diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 4a1601685956..5f665aae525d 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -89,6 +89,6 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: for idx in range(*seg): n = node_list[idx] if n.op in CKPT_OP: - setattr(n, 'activation_checkpoint', str(i)) + setattr(n, 'activation_checkpoint', i) gm.recompile() return gm diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 3bb5e48efe71..e57fa5f12921 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -8,7 +8,7 @@ import colossalai from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +from colossalai.fx.passes.algorithms import chen_greedy from colossalai.utils import free_port from colossalai.core import global_context as gpc import pytest @@ -84,7 +84,6 @@ def _run_ckpt_solver(rank): gpc.destroy() -@pytest.mark.skip @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) @@ -114,7 +113,6 @@ def _run_ckpt_solver_torch11(rank): gpc.destroy() -@pytest.mark.skip @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') def test_ckpt_solver_torch11(): mp.spawn(_run_ckpt_solver_torch11, nprocs=1)