From ed273b7d008044af36de7e8af9b2ee171af6444d Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 13:01:47 +0800 Subject: [PATCH 1/6] [utils] Add use_reetrant=False into colossalai checkpoint --- colossalai/utils/activation_checkpoint.py | 91 ++++++++++++++++++- .../test_activation_checkpointing.py | 54 ++++++++++- 2 files changed, 141 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index 2edd6b1a5572..b0f253718cc5 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -7,6 +7,8 @@ from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states from .cuda import get_current_device +import weakref + def copy_to_device(obj, device): if torch.is_tensor(obj): @@ -136,7 +138,7 @@ def backward(ctx, *args): return (None, None) + grads -def checkpoint(function, activation_offload, *args): +def checkpoint(function, activation_offload, *args, use_reentrant=True): """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint. Args: @@ -146,4 +148,89 @@ def checkpoint(function, activation_offload, *args): Returns: Output of running function with provided args. """ - return CheckpointFunction.apply(function, activation_offload, *args) + if use_reentrant: + return CheckpointFunction.apply(function, activation_offload, *args) + else: + return _checkpoint_without_reentrant( + function, + activation_offload, + *args, + ) + + +def _checkpoint_without_reentrant(function, activation_offload=False, *args): + + fwd_cpu_state = torch.get_rng_state() + sync_states() + fwd_seed_states = get_states(copy=True) + fwd_current_mode = get_current_mode() + + if hasattr(torch, 'is_autocast_enabled'): + has_autocast_in_fwd = torch.is_autocast_enabled() + else: + has_autocast_in_fwd = False + + storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + weak_holder_list = [] + + class Holder(): + pass + + def pack(x): + res = Holder() + weak_holder_list.append(weakref.ref(res)) + return res + + def unpack(x): + unpack_counter = 0 + if len(storage) == 0: + # print(weak_holder_list) + def inner_pack(inner): + nonlocal unpack_counter + unpack_counter += 1 + + if weak_holder_list[unpack_counter - 1]() is None: + return + + storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() + return + + def inner_unpack(packed): + raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") + + torch.set_rng_state(fwd_cpu_state) + for parallel_mode, state in fwd_seed_states.items(): + set_seed_states(parallel_mode, state) + set_mode(fwd_current_mode) + + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device=device) + + if has_autocast_in_fwd: + with torch.enable_grad(), \ + torch.cuda.amp.autocast(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + else: + with torch.enable_grad(), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + + if x not in storage: + raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this.") + + return storage[x] + + device = get_current_device() + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args) + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device="cpu") + + return output diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index a68644254cfa..f46fbf0b7698 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from colossalai.context.parallel_mode import ParallelMode from colossalai.context.random import add_seed, seed, set_mode, reset_seeds -from colossalai.utils import checkpoint +from colossalai.utils.activation_checkpoint import checkpoint def forward(x, weight): @@ -17,7 +17,6 @@ def forward(x, weight): @pytest.mark.gpu -@pytest.mark.skip("set seed error") @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload): @@ -62,3 +61,54 @@ def test_activation_checkpointing(cpu_offload): # other tests will fail if running together with this test # as other tests can't overwrite the seed set by this test reset_seeds() + + +@pytest.mark.gpu +@pytest.mark.parametrize("cpu_offload", [True, False]) +def test_activation_checkpointing_reentrant_False(cpu_offload): + + # We put initilization here to avoid change cuda rng state below + inputs = torch.rand(2, 2, requires_grad=True, device='cuda') + weight = torch.rand(2, 4, requires_grad=True, device='cuda') + + # Get a copy of input tensors + inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_.data.copy_(inputs.data) + weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_.data.copy_(weight.data) + + add_seed(ParallelMode.GLOBAL, 1024) + add_seed(ParallelMode.DATA, 1026) + set_mode(ParallelMode.GLOBAL) + global_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.DATA) + data_parallel_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.GLOBAL) + + out = forward(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=False) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests will fail if running together with this test + # as other tests can't overwrite the seed set by this test + reset_seeds() + + +if __name__ == "__main__": + test_activation_checkpointing_reentrant_False(False) From c3abee97ad5abc3f61d3c9d7e8e49282f0eeb059 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 13:27:42 +0800 Subject: [PATCH 2/6] [utils] add some annotation in utils.activaion_checkpoint --- colossalai/utils/activation_checkpoint.py | 31 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index b0f253718cc5..fa9ed827a8a7 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -138,12 +138,15 @@ def backward(ctx, *args): return (None, None) + grads -def checkpoint(function, activation_offload, *args, use_reentrant=True): +def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint. Args: function: Describe the forward pass function. It should know how to handle the input tuples. + activation_offload: The variable to check whether we should offload activation to cpu args (list): Tuple containing the parameters of the function + use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there + might be more flexibility for user to define there checkpoint function Returns: Output of running function with provided args. @@ -159,55 +162,69 @@ def checkpoint(function, activation_offload, *args, use_reentrant=True): def _checkpoint_without_reentrant(function, activation_offload=False, *args): - + # store rng_state fwd_cpu_state = torch.get_rng_state() sync_states() fwd_seed_states = get_states(copy=True) fwd_current_mode = get_current_mode() + # check if use autocast if hasattr(torch, 'is_autocast_enabled'): has_autocast_in_fwd = torch.is_autocast_enabled() else: has_autocast_in_fwd = False + # using WeakKeyDictionary to store all the activation the first time we call unpack storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() weak_holder_list = [] + # class for weakref.ref class Holder(): pass + # return a Holder object for later unpack process def pack(x): res = Holder() weak_holder_list.append(weakref.ref(res)) return res + # unpack hook def unpack(x): unpack_counter = 0 + + # re-compute all the activation inside the function when we first call unpack if len(storage) == 0: - # print(weak_holder_list) + def inner_pack(inner): nonlocal unpack_counter unpack_counter += 1 + # If the holder went out of scope, the SavedVariable is dead and so + # the value will never be read from the storage. Skip filling it. if weak_holder_list[unpack_counter - 1]() is None: return + # Use detach here to ensure we don't keep the temporary autograd + # graph created during the second forward storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() return def inner_unpack(packed): raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") + # restore rng state torch.set_rng_state(fwd_cpu_state) for parallel_mode, state in fwd_seed_states.items(): set_seed_states(parallel_mode, state) set_mode(fwd_current_mode) + # reload arg into device if needed if activation_offload: for arg in args: if torch.is_tensor(arg): arg = arg.to(device=device) + # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: with torch.enable_grad(), \ torch.cuda.amp.autocast(), \ @@ -225,9 +242,15 @@ def inner_unpack(packed): return storage[x] - device = get_current_device() + # get device if we need to offload the activation + if activation_offload: + device = get_current_device() + + # run function with pack and unpack as saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): output = function(*args) + + # offload activation if needed if activation_offload: for arg in args: if torch.is_tensor(arg): From 2d3c672431626a2496779bd5f5593edeef060a1a Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 13:46:59 +0800 Subject: [PATCH 3/6] [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py --- tests/test_utils/test_activation_checkpointing.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index f46fbf0b7698..50eb9c5d236d 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -20,6 +20,11 @@ def forward(x, weight): @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload): + # as seed manager is singleton + # if we don't reset seeds here, + # other tests might affect this test + reset_seeds() + # We put initilization here to avoid change cuda rng state below inputs = torch.rand(2, 2, requires_grad=True, device='cuda') weight = torch.rand(2, 4, requires_grad=True, device='cuda') @@ -67,6 +72,11 @@ def test_activation_checkpointing(cpu_offload): @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing_reentrant_False(cpu_offload): + # as seed manager is singleton + # if we don't reset seeds here, + # other tests might affect this test + reset_seeds() + # We put initilization here to avoid change cuda rng state below inputs = torch.rand(2, 2, requires_grad=True, device='cuda') weight = torch.rand(2, 4, requires_grad=True, device='cuda') From 64b9032e1800cff70e583f2316d3130cd52d76a5 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 13:51:52 +0800 Subject: [PATCH 4/6] [test] modify test_activation_checkpoint.py --- .../test_activation_checkpointing.py | 59 ++----------------- 1 file changed, 4 insertions(+), 55 deletions(-) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 50eb9c5d236d..788324aaaa36 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -17,8 +17,9 @@ def forward(x, weight): @pytest.mark.gpu +@pytest.mark.parametrize("use_reentrant", [True, False]) @pytest.mark.parametrize("cpu_offload", [True, False]) -def test_activation_checkpointing(cpu_offload): +def test_activation_checkpointing(cpu_offload, use_reentrant): # as seed manager is singleton # if we don't reset seeds here, @@ -54,59 +55,7 @@ def test_activation_checkpointing(cpu_offload): torch.cuda.set_rng_state(data_parallel_cuda_rng_state) set_mode(ParallelMode.GLOBAL) - out = checkpoint(forward, cpu_offload, inputs_, weight_) - loss = out.sum() - loss.backward() - - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' - torch.cuda.empty_cache() - - # as seed manager is singleton - # if we don't reset seeds here, - # other tests will fail if running together with this test - # as other tests can't overwrite the seed set by this test - reset_seeds() - - -@pytest.mark.gpu -@pytest.mark.parametrize("cpu_offload", [True, False]) -def test_activation_checkpointing_reentrant_False(cpu_offload): - - # as seed manager is singleton - # if we don't reset seeds here, - # other tests might affect this test - reset_seeds() - - # We put initilization here to avoid change cuda rng state below - inputs = torch.rand(2, 2, requires_grad=True, device='cuda') - weight = torch.rand(2, 4, requires_grad=True, device='cuda') - - # Get a copy of input tensors - inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') - inputs_.data.copy_(inputs.data) - weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') - weight_.data.copy_(weight.data) - - add_seed(ParallelMode.GLOBAL, 1024) - add_seed(ParallelMode.DATA, 1026) - set_mode(ParallelMode.GLOBAL) - global_cuda_rng_state = torch.cuda.get_rng_state() - set_mode(ParallelMode.DATA) - data_parallel_cuda_rng_state = torch.cuda.get_rng_state() - set_mode(ParallelMode.GLOBAL) - - out = forward(inputs, weight) - loss = out.sum() - loss.backward() - - # Recover cuda rng states - set_mode(ParallelMode.GLOBAL) - torch.cuda.set_rng_state(global_cuda_rng_state) - set_mode(ParallelMode.DATA) - torch.cuda.set_rng_state(data_parallel_cuda_rng_state) - set_mode(ParallelMode.GLOBAL) - - out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=False) + out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant) loss = out.sum() loss.backward() @@ -121,4 +70,4 @@ def test_activation_checkpointing_reentrant_False(cpu_offload): if __name__ == "__main__": - test_activation_checkpointing_reentrant_False(False) + test_activation_checkpointing(False, False) From bcdba629e9f7e6ec073d7d43854105993d1ccaca Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 14:56:15 +0800 Subject: [PATCH 5/6] [test] modify test for reentrant=False --- .../test_activation_checkpointing.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 788324aaaa36..3ac75fb00c86 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -16,6 +16,28 @@ def forward(x, weight): return out_ +def forward_inplace_ckpt(x, weight, cpu_offload=False): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + + def ckpt0(x): + return F.relu(x, inplace=True) + + out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False) + return out + + +def forward_inplace(x, weight): + out = torch.matmul(x, weight) + bn = torch.nn.BatchNorm1d(4, affine=False) + bn = bn.to(device="cuda") + out = bn(out) + out = F.relu(out, inplace=True) + return out + + @pytest.mark.gpu @pytest.mark.parametrize("use_reentrant", [True, False]) @pytest.mark.parametrize("cpu_offload", [True, False]) @@ -62,6 +84,33 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' torch.cuda.empty_cache() + # Extra test for use_reentrant=False + if use_reentrant == False: + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + # as seed manager is singleton # if we don't reset seeds here, # other tests will fail if running together with this test From f8ea3da62db4473edd4923f86e7b46b057bc815d Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 21:34:53 +0800 Subject: [PATCH 6/6] [fx] Add use_reentrant=False of checkpoint into codegen --- .../codegen/activation_checkpoint_codegen.py | 22 +++++++-- .../test_activation_checkpoint_codegen.py | 49 +++++++++++++------ 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 4a4bbef4cf74..53eb46529113 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str: return f"return {', '.join(output_vars)}" -def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars): +def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): """ Generate the checkpoint function call code text """ outputs = ', '.join(output_vars) inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})' + return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func): @@ -162,8 +162,24 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu else: activation_offload = False + # we need to check if the checkpoint need use_reentrant=False + use_reentrant = True + for var in input_vars[label]: + input_node = [item for item in node_list if item.name == var] + input_node = input_node[0] + for user in input_node.users: + if hasattr(user, "activation_checkpoint"): + if user.activation_checkpoint == label: + if user.op == "call_module": + if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): + use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace + + elif user.op == "call_function": + if "inplace" in user.kwargs: + use_reentrant = not user.kwargs["inplace"] + # generate checkpoint function call in a new line - usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label]) + usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) usage += '\n' body.append(usage) within_ckpt_region = False diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index fe5c638b2e24..9c1bc57a3973 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -1,5 +1,6 @@ from operator import mod import torch +import torch.nn.functional as F import pytest import torch.multiprocessing as mp from torch.utils.checkpoint import checkpoint @@ -26,7 +27,17 @@ def __init__(self): self.linear2 = torch.nn.Linear(4, 4) def forward(self, x): - return self.linear1(x), self.linear1(x) + return self.linear1(x), self.linear2(x) + + +class relu(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(x) class MyModule(torch.nn.Module): @@ -34,12 +45,17 @@ class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.mlp1 = MLP() - self.mlp2 = MLP() + self.relu = relu() self.linear3 = torch.nn.Linear(4, 4) def forward(self, x): y1, y2 = checkpoint(self.mlp1, x) - y3, y4 = checkpoint(self.mlp2, x) + y3 = checkpoint(self.relu, x) + + def ckpt2(x): + return F.relu(x, inplace=True) + + y4 = checkpoint(ckpt2, x) return y1 + y2 + y3 + y4 @@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank): # check ops are annotated with ckpt # also annotate the selected node for offloading - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1'] - offload_starts = ['mlp2_linear1'] + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] for node in graph.nodes: if node.name in ckpt_nodes: assert hasattr(node, 'activation_checkpoint') @@ -75,15 +91,17 @@ def _run_act_ckpt_codegen(rank): if node.name in offload_starts: setattr(node, 'activation_offload', True) + gm = GraphModule(model, graph) + gm.recompile() + # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent - gm = GraphModule(model, graph) - gm.recompile() fx_out = gm(data) assert torch.equal(non_fx_out, fx_out) @@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank): graph._python_code = python_code_with_activation_checkpoint.__get__(graph) # check ops are annotated with ckpt - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1'] - offload_starts = ['mlp2_linear1'] + ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] + offload_starts = ['mlp1_linear1'] for node in graph.nodes: if node.name in ckpt_nodes: assert hasattr(node, 'activation_checkpoint') @@ -127,15 +145,16 @@ def _run_act_ckpt_python_code_torch11(rank): if node.name in offload_starts: setattr(node, 'activation_offload', True) + gm = GraphModule(model, graph) + gm.recompile() # assert checkpoint function will be generated and # the offload option is correct code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code + assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \ + 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code # recompile and verify the outputs are consistent - gm = GraphModule(model, graph) - gm.recompile() fx_out = gm(data) assert torch.equal(non_fx_out, fx_out)