From 06235d21298778ca3f61ac08915171576f736092 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 16 Aug 2022 11:23:19 +0800 Subject: [PATCH 1/3] [test] recovered activation checkpointig test --- tests/test_utils/test_activation_checkpointing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index a68644254cfa..5a5d5deab8b9 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -17,7 +17,6 @@ def forward(x, weight): @pytest.mark.gpu -@pytest.mark.skip("set seed error") @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload): @@ -62,3 +61,8 @@ def test_activation_checkpointing(cpu_offload): # other tests will fail if running together with this test # as other tests can't overwrite the seed set by this test reset_seeds() + + +if __name__ == '__main__': + test_activation_checkpointing(cpu_offload=False) + test_activation_checkpointing(cpu_offload=True) From 0b82ac373cbd40fe5658aabba963e7938eb18a56 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 16 Aug 2022 11:48:26 +0800 Subject: [PATCH 2/3] polish code --- tests/test_utils/test_activation_checkpointing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 5a5d5deab8b9..f2b5290508d6 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -19,6 +19,8 @@ def forward(x, weight): @pytest.mark.gpu @pytest.mark.parametrize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload): + # clear all previous seeds possibly set by other tests + reset_seeds() # We put initilization here to avoid change cuda rng state below inputs = torch.rand(2, 2, requires_grad=True, device='cuda') From 367ff8e6950685ed0542dec3c0e82290bea089ba Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 16 Aug 2022 13:28:36 +0800 Subject: [PATCH 3/3] polish code --- tests/test_utils/test_activation_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index f2b5290508d6..9ba600170d28 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from colossalai.context.parallel_mode import ParallelMode from colossalai.context.random import add_seed, seed, set_mode, reset_seeds -from colossalai.utils import checkpoint +from colossalai.utils.activation_checkpoint import checkpoint def forward(x, weight):