From ed273b7d008044af36de7e8af9b2ee171af6444d Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Tue, 16 Aug 2022 13:01:47 +0800
Subject: [PATCH 1/5] [utils] Add use_reentrant=False into colossalai checkpoint

---
 colossalai/utils/activation_checkpoint.py     | 91 ++++++++++++++++++-
 .../test_activation_checkpointing.py          | 54 ++++++++++-
 2 files changed, 141 insertions(+), 4 deletions(-)

diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py
index 2edd6b1a5572..b0f253718cc5 100644
--- a/colossalai/utils/activation_checkpoint.py
+++ b/colossalai/utils/activation_checkpoint.py
@@ -7,6 +7,8 @@ from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device
 
+import weakref
+
 
 def copy_to_device(obj, device):
     if torch.is_tensor(obj):
@@ -136,7 +138,7 @@ def backward(ctx, *args):
         return (None, None) + grads
 
 
-def checkpoint(function, activation_offload, *args):
+def checkpoint(function, activation_offload, *args, use_reentrant=True):
     """Checkpoint the computation while preserving the rng states, modified from PyTorch torch.utils.checkpoint.
 
     Args:
@@ -146,4 +148,89 @@ def checkpoint(function, activation_offload, *args):
     Returns:
         Output of running function with provided args.
     """
-    return CheckpointFunction.apply(function, activation_offload, *args)
+    if use_reentrant:
+        return CheckpointFunction.apply(function, activation_offload, *args)
+    else:
+        return _checkpoint_without_reentrant(
+            function,
+            activation_offload,
+            *args,
+        )
+
+
+def _checkpoint_without_reentrant(function, activation_offload=False, *args):
+
+    fwd_cpu_state = torch.get_rng_state()
+    sync_states()
+    fwd_seed_states = get_states(copy=True)
+    fwd_current_mode = get_current_mode()
+
+    if hasattr(torch, 'is_autocast_enabled'):
+        has_autocast_in_fwd = torch.is_autocast_enabled()
+    else:
+        has_autocast_in_fwd = False
+
+    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+    weak_holder_list = []
+
+    class Holder():
+        pass
+
+    def pack(x):
+        res = Holder()
+        weak_holder_list.append(weakref.ref(res))
+        return res
+
+    def unpack(x):
+        unpack_counter = 0
+        if len(storage) == 0:
+            # print(weak_holder_list)
+            def inner_pack(inner):
+                nonlocal unpack_counter
+                unpack_counter += 1
+
+                if weak_holder_list[unpack_counter - 1]() is None:
+                    return
+
+                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
+                return
+
+            def inner_unpack(packed):
+                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
+
+            torch.set_rng_state(fwd_cpu_state)
+            for parallel_mode, state in fwd_seed_states.items():
+                set_seed_states(parallel_mode, state)
+            set_mode(fwd_current_mode)
+
+            if activation_offload:
+                for arg in args:
+                    if torch.is_tensor(arg):
+                        arg = arg.to(device=device)
+
+            if has_autocast_in_fwd:
+                with torch.enable_grad(), \
+                     torch.cuda.amp.autocast(), \
+                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                    _unused = function(*args)
+            else:
+                with torch.enable_grad(), \
+                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                    _unused = function(*args)
+
+        if x not in storage:
+            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
+                               " recomputation being triggered in between, this is not currently supported. Please"
+                               " open an issue with details on your use case so that we can prioritize adding this.")
+
+        return storage[x]
+
+    device = get_current_device()
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        output = function(*args)
+        if activation_offload:
+            for arg in args:
+                if torch.is_tensor(arg):
+                    arg = arg.to(device="cpu")
+
+    return output
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index a68644254cfa..f46fbf0b7698 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
-from colossalai.utils import checkpoint
+from colossalai.utils.activation_checkpoint import checkpoint
 
 
 def forward(x, weight):
@@ -17,7 +17,6 @@ def forward(x, weight):
 
 
 @pytest.mark.gpu
-@pytest.mark.skip("set seed error")
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
 
@@ -62,3 +61,54 @@ def test_activation_checkpointing(cpu_offload):
     # other tests will fail if running together with this test
     # as other tests can't overwrite the seed set by this test
     reset_seeds()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_activation_checkpointing_reentrant_False(cpu_offload):
+
+    # We put initialization here to avoid changing cuda rng state below
+    inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
+    weight = torch.rand(2, 4, requires_grad=True, device='cuda')
+
+    # Get a copy of input tensors
+    inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
+    inputs_.data.copy_(inputs.data)
+    weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
+    weight_.data.copy_(weight.data)
+
+    add_seed(ParallelMode.GLOBAL, 1024)
+    add_seed(ParallelMode.DATA, 1026)
+    set_mode(ParallelMode.GLOBAL)
+    global_cuda_rng_state = torch.cuda.get_rng_state()
+    set_mode(ParallelMode.DATA)
+    data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
+    set_mode(ParallelMode.GLOBAL)
+
+    out = forward(inputs, weight)
+    loss = out.sum()
+    loss.backward()
+
+    # Recover cuda rng states
+    set_mode(ParallelMode.GLOBAL)
+    torch.cuda.set_rng_state(global_cuda_rng_state)
+    set_mode(ParallelMode.DATA)
+    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+    set_mode(ParallelMode.GLOBAL)
+
+    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=False)
+    loss = out.sum()
+    loss.backward()
+
+    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+    torch.cuda.empty_cache()
+
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests will fail if running together with this test
+    # as other tests can't overwrite the seed set by this test
+    reset_seeds()
+
+
+if __name__ == "__main__":
+    test_activation_checkpointing_reentrant_False(False)
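
Note: the non-reentrant path above hinges on torch.autograd.graph.saved_tensors_hooks. pack() hands autograd a lightweight Holder in place of each saved activation, and unpack() re-runs the forward once, under inner_pack/inner_unpack, to refill the storage the first time backward asks for a tensor. A minimal, self-contained sketch of the underlying hook mechanism in plain PyTorch (a toy for illustration, not code from this patch; the dict-based pack/unpack stands in for the Holder/weakref scheme):

    import torch

    storage = {}

    def pack(tensor):
        # hand autograd a cheap key instead of the tensor itself
        key = len(storage)
        storage[key] = tensor.detach()
        return key

    def unpack(key):
        # called lazily during backward, when the saved tensor is first needed
        return storage[key]

    a = torch.randn(4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        loss = (a * a).sum()
    loss.backward()
    print(a.grad)    # == 2 * a
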
Please" + " open an issue with details on your use case so that we can prioritize adding this.") + + return storage[x] + + device = get_current_device() + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args) + if activation_offload: + for arg in args: + if torch.is_tensor(arg): + arg = arg.to(device="cpu") + + return output diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index a68644254cfa..f46fbf0b7698 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from colossalai.context.parallel_mode import ParallelMode from colossalai.context.random import add_seed, seed, set_mode, reset_seeds -from colossalai.utils import checkpoint +from colossalai.utils.activation_checkpoint import checkpoint def forward(x, weight): @@ -17,7 +17,6 @@ def forward(x, weight): @pytest.mark.gpu -@pytest.mark.skip("set seed error") @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload): @@ -62,3 +61,54 @@ def test_activation_checkpointing(cpu_offload): # other tests will fail if running together with this test # as other tests can't overwrite the seed set by this test reset_seeds() + + +@pytest.mark.gpu +@pytest.mark.parametrize("cpu_offload", [True, False]) +def test_activation_checkpointing_reentrant_False(cpu_offload): + + # We put initilization here to avoid change cuda rng state below + inputs = torch.rand(2, 2, requires_grad=True, device='cuda') + weight = torch.rand(2, 4, requires_grad=True, device='cuda') + + # Get a copy of input tensors + inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_.data.copy_(inputs.data) + weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_.data.copy_(weight.data) + + add_seed(ParallelMode.GLOBAL, 1024) + add_seed(ParallelMode.DATA, 1026) + set_mode(ParallelMode.GLOBAL) + global_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.DATA) + data_parallel_cuda_rng_state = torch.cuda.get_rng_state() + set_mode(ParallelMode.GLOBAL) + + out = forward(inputs, weight) + loss = out.sum() + loss.backward() + + # Recover cuda rng states + set_mode(ParallelMode.GLOBAL) + torch.cuda.set_rng_state(global_cuda_rng_state) + set_mode(ParallelMode.DATA) + torch.cuda.set_rng_state(data_parallel_cuda_rng_state) + set_mode(ParallelMode.GLOBAL) + + out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=False) + loss = out.sum() + loss.backward() + + assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + torch.cuda.empty_cache() + + # as seed manager is singleton + # if we don't reset seeds here, + # other tests will fail if running together with this test + # as other tests can't overwrite the seed set by this test + reset_seeds() + + +if __name__ == "__main__": + test_activation_checkpointing_reentrant_False(False) From c3abee97ad5abc3f61d3c9d7e8e49282f0eeb059 Mon Sep 17 00:00:00 2001 From: Cypher30 <1529318642@qq.com> Date: Tue, 16 Aug 2022 13:27:42 +0800 Subject: [PATCH 2/5] [utils] add some annotation in utils.activaion_checkpoint --- colossalai/utils/activation_checkpoint.py | 31 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index b0f253718cc5..fa9ed827a8a7 100644 --- a/colossalai/utils/activation_checkpoint.py 
From 2d3c672431626a2496779bd5f5593edeef060a1a Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Tue, 16 Aug 2022 13:46:59 +0800
Subject: [PATCH 3/5] [test] add reset_seeds at the beginning of tests in test_activation_checkpointing.py

---
 tests/test_utils/test_activation_checkpointing.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index f46fbf0b7698..50eb9c5d236d 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -20,6 +20,11 @@ def forward(x, weight):
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
 
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests might affect this test
+    reset_seeds()
+
     # We put initialization here to avoid changing cuda rng state below
     inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
     weight = torch.rand(2, 4, requires_grad=True, device='cuda')
@@ -67,6 +72,11 @@ def test_activation_checkpointing(cpu_offload):
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing_reentrant_False(cpu_offload):
 
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests might affect this test
+    reset_seeds()
+
     # We put initialization here to avoid changing cuda rng state below
     inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
     weight = torch.rand(2, 4, requires_grad=True, device='cuda')
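
Because the seed manager is a process-wide singleton, each test has to clear state left behind by earlier tests before registering its own seeds. The same guard could also be expressed as an autouse pytest fixture (a sketch of an alternative, not what the patch does; the fixture name is hypothetical):

    import pytest
    from colossalai.context.random import reset_seeds

    @pytest.fixture(autouse=True)
    def isolated_seed_manager():
        reset_seeds()    # drop seeds registered by earlier tests
        yield
        reset_seeds()    # don't leak this test's seeds to the next one
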
From 64b9032e1800cff70e583f2316d3130cd52d76a5 Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Tue, 16 Aug 2022 13:51:52 +0800
Subject: [PATCH 4/5] [test] modify test_activation_checkpointing.py

---
 .../test_activation_checkpointing.py          | 59 ++-----------------
 1 file changed, 4 insertions(+), 55 deletions(-)

diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index 50eb9c5d236d..788324aaaa36 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -17,8 +17,9 @@ def forward(x, weight):
 
 
 @pytest.mark.gpu
+@pytest.mark.parametrize("use_reentrant", [True, False])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_activation_checkpointing(cpu_offload):
+def test_activation_checkpointing(cpu_offload, use_reentrant):
 
     # as seed manager is singleton
     # if we don't reset seeds here,
@@ -54,59 +55,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
 
-    out = checkpoint(forward, cpu_offload, inputs_, weight_)
-    loss = out.sum()
-    loss.backward()
-
-    assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
-    torch.cuda.empty_cache()
-
-    # as seed manager is singleton
-    # if we don't reset seeds here,
-    # other tests will fail if running together with this test
-    # as other tests can't overwrite the seed set by this test
-    reset_seeds()
-
-
-@pytest.mark.gpu
-@pytest.mark.parametrize("cpu_offload", [True, False])
-def test_activation_checkpointing_reentrant_False(cpu_offload):
-
-    # as seed manager is singleton
-    # if we don't reset seeds here,
-    # other tests might affect this test
-    reset_seeds()
-
-    # We put initialization here to avoid changing cuda rng state below
-    inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
-    weight = torch.rand(2, 4, requires_grad=True, device='cuda')
-
-    # Get a copy of input tensors
-    inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda')
-    inputs_.data.copy_(inputs.data)
-    weight_ = torch.empty(2, 4, requires_grad=True, device='cuda')
-    weight_.data.copy_(weight.data)
-
-    add_seed(ParallelMode.GLOBAL, 1024)
-    add_seed(ParallelMode.DATA, 1026)
-    set_mode(ParallelMode.GLOBAL)
-    global_cuda_rng_state = torch.cuda.get_rng_state()
-    set_mode(ParallelMode.DATA)
-    data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
-    set_mode(ParallelMode.GLOBAL)
-
-    out = forward(inputs, weight)
-    loss = out.sum()
-    loss.backward()
-
-    # Recover cuda rng states
-    set_mode(ParallelMode.GLOBAL)
-    torch.cuda.set_rng_state(global_cuda_rng_state)
-    set_mode(ParallelMode.DATA)
-    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
-    set_mode(ParallelMode.GLOBAL)
-
-    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=False)
+    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
     loss = out.sum()
     loss.backward()
 
@@ -121,4 +70,4 @@
 
 
 if __name__ == "__main__":
-    test_activation_checkpointing_reentrant_False(False)
+    test_activation_checkpointing(False, False)
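
Stacking the two parametrize decorators makes pytest run the merged test over the cartesian product of the flags, so one function now covers all four (cpu_offload, use_reentrant) combinations. A standalone illustration of the expansion (a hypothetical toy test, not from the patch):

    import pytest

    @pytest.mark.parametrize("use_reentrant", [True, False])
    @pytest.mark.parametrize("cpu_offload", [True, False])
    def test_grid(cpu_offload, use_reentrant):
        # pytest generates four test items, one per combination of the flags
        assert isinstance(cpu_offload, bool) and isinstance(use_reentrant, bool)
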
From bcdba629e9f7e6ec073d7d43854105993d1ccaca Mon Sep 17 00:00:00 2001
From: Cypher30 <1529318642@qq.com>
Date: Tue, 16 Aug 2022 14:56:15 +0800
Subject: [PATCH 5/5] [test] modify test for reentrant=False

---
 .../test_activation_checkpointing.py          | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index 788324aaaa36..3ac75fb00c86 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -16,6 +16,28 @@ def forward(x, weight):
     return out_
 
 
+def forward_inplace_ckpt(x, weight, cpu_offload=False):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+
+    def ckpt0(x):
+        return F.relu(x, inplace=True)
+
+    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
+    return out
+
+
+def forward_inplace(x, weight):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+    out = F.relu(out, inplace=True)
+    return out
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize("use_reentrant", [True, False])
 @pytest.mark.parametrize("cpu_offload", [True, False])
@@ -62,6 +84,33 @@ def test_activation_checkpointing(cpu_offload, use_reentrant):
     assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
+
+    # Extra test for use_reentrant=False
+    if not use_reentrant:
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace(inputs, weight)
+        loss = out.sum()
+        loss.backward()
+
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
+        loss = out.sum()
+        loss.backward()
+
+        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+        torch.cuda.empty_cache()
+
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
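
The forward_inplace_ckpt path exercises the case that motivates use_reentrant=False: an in-place op (F.relu(inplace=True)) inside the checkpointed region, which relu's idempotence keeps consistent when the forward is recomputed during backward. For comparison, upstream torch.utils.checkpoint exposes a similar switch (an assumption here: a PyTorch version where its use_reentrant flag exists, 1.11 or later); a minimal CPU-only sketch of the same pattern:

    import torch
    import torch.nn.functional as F
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    def ckpt0(x):
        # in-place op inside the checkpointed region; relu is idempotent,
        # so the recomputed forward sees consistent values
        return F.relu(x, inplace=True)

    inputs = torch.rand(2, 2, requires_grad=True)
    weight = torch.rand(2, 4, requires_grad=True)

    out = torch.matmul(inputs, weight)    # non-leaf intermediate, as in the test
    out = torch_checkpoint(ckpt0, out, use_reentrant=False)
    out.sum().backward()
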