[utils] Add use_reentrant=False in utils.activation_checkpoint #1460

Merged
Cypher30 merged 21 commits into hpcaitech:main from Cypher30:feature/add_ckpt_reentrant_False on Aug 16, 2022

Conversation

@Cypher30
Contributor

We ran into a case where a checkpointed function whose first operation is in-place raises an error. In the original checkpoint process, recomputation calls run_function with detached_input, which autograd treats as a leaf node with requires_grad=True, and autograd does not allow an in-place operation on such a tensor.

PyTorch's own checkpoint has a use_reentrant=False option that addresses this problem: it uses torch.autograd.graph.saved_tensors_hooks, so recomputation is no longer called with detached_input. This PR adds that option to our colossalai checkpoint. It also adds our activation offload process for the use_reentrant=False case, since the original torch checkpoint doesn't provide this.

I also modified the activation_checkpoint test, which failed in the previous PR [test] recovered activation checkpointing test #1459.
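For readers who want to see the failure mode in isolation, here is a minimal sketch using only PyTorch, not the colossalai wrapper. The first part imitates what reentrant recomputation is described as doing above (detach the input, then mark it as requiring grad) and shows that an in-place operation on the resulting leaf is rejected. The second part only shows how `use_reentrant=False` is passed to `torch.utils.checkpoint.checkpoint` with an ordinary out-of-place function; it does not reproduce the in-place case from this PR. The tensor shapes and the `grads` helper are illustrative assumptions.

```python
import torch
from torch.utils.checkpoint import checkpoint

# 1) Imitate the reentrant recomputation input: a detached tensor that is
#    marked as requiring grad is a leaf tensor.
h = torch.randn(4, 4, requires_grad=True) * 2  # a non-leaf activation
detached_input = h.detach()
detached_input.requires_grad_(True)

inplace_error = None
try:
    detached_input.relu_()  # in-place op on a leaf that requires grad
except RuntimeError as err:
    inplace_error = err
print("in-place on detached leaf:", inplace_error)


# 2) Non-reentrant checkpointing with a plain, out-of-place function.
def run_function(x, weight):
    return torch.relu(x @ weight).sum()


def grads(use_checkpoint):
    """Illustrative helper: return input gradients with or without checkpointing."""
    torch.manual_seed(0)
    x = torch.randn(4, 4, requires_grad=True)
    weight = torch.randn(4, 4, requires_grad=True)
    if use_checkpoint:
        out = checkpoint(run_function, x, weight, use_reentrant=False)
    else:
        out = run_function(x, weight)
    out.backward()
    return x.grad, weight.grad


plain = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)
grads_match = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print("gradients match:", grads_match)
```

The `use_reentrant` keyword requires a PyTorch version that supports it in `torch.utils.checkpoint.checkpoint`.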

Comment on lines +68 to +69
def test_activation_checkpointing_reentrant_False(cpu_offload):

Contributor

Add a reset_seed at the start of the function so that other tests will not affect this one.

Contributor

This applies to the test function above as well.

Contributor Author

oooo! I will modify it
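
As a rough illustration of the reset_seed suggestion in this thread: the helper below is a hypothetical stand-in, not the function from the repository, and the seed value and the `make_inputs` helper are assumptions. It shows why reseeding at the start of a test makes its random inputs independent of whatever ran before.

```python
import random

import torch


def reset_seed(seed: int = 0) -> None:
    """Hypothetical helper: reseed the RNGs that a test depends on."""
    random.seed(seed)
    torch.manual_seed(seed)


def make_inputs():
    # Reseed first, so the values do not depend on earlier RNG usage.
    reset_seed()
    return torch.rand(4, 4)


first = make_inputs()
_ = torch.rand(10)  # simulate another test consuming random numbers
second = make_inputs()
print("inputs identical:", torch.equal(first, second))
```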



@pytest.mark.gpu
@pytest.mark.parametrize("cpu_offload", [True, False])
Contributor

If these two test functions differ only in the use_reentrant variable, you can add another parametrize here instead of creating a duplicate function.

Contributor Author

okay I will modify this
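
A sketch of what the parametrize suggestion might look like. Stacking two `pytest.mark.parametrize` decorators runs the test once per combination, so a single function covers both `use_reentrant` values. `fake_checkpoint` and `double` are placeholders invented for this example so that it runs without colossalai; the placeholder's argument order is an assumption, not the library's confirmed signature, and the `gpu` marker is left out to keep the sketch free of custom markers.

```python
import pytest


def fake_checkpoint(function, activation_offload, *args, use_reentrant=True):
    """Placeholder for the real checkpoint call; it only runs the function."""
    return function(*args)


def double(x):
    return 2 * x


# Stacked decorators yield all four (cpu_offload, use_reentrant) combinations.
@pytest.mark.parametrize("use_reentrant", [True, False])
@pytest.mark.parametrize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
    out = fake_checkpoint(double, cpu_offload, 3, use_reentrant=use_reentrant)
    assert out == 6
```

Running pytest on a file containing this function collects four test cases, one per combination.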

Comment thread colossalai/utils/activation_checkpoint.py
Contributor

@super-dainiu super-dainiu left a comment

Approve

@Cypher30 Cypher30 deleted the feature/add_ckpt_reentrant_False branch August 26, 2022 06:08