[fx] fix test and algorithm bugs in activation checkpointing. #1451

Merged
Cypher30 merged 19 commits into hpcaitech:main from super-dainiu:feature/more_ckpt
Aug 15, 2022

super-dainiu (Contributor) commented Aug 12, 2022

What's new?

I regret using torch.randn(2, 3, 224, 224) as the test input in #1446, because it takes too much time on CI.
I also modified the search algorithm (mostly the conditions for annotations) to avoid crashes in ActivationCheckpointCodeGen.

What's wrong?

However, I have not figured out why the traced densenet121 raises the following error in loss.backward().

Traceback (most recent call last):
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 104, in <module>
    test_ckpt_solver()
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 73, in test_ckpt_solver
    check_backward_consistency(m, gm, solver, model_cls)
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 46, in check_backward_consistency
    loss.backward()
  File "/home/.local/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/.local/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/home/ColossalAI/colossalai/utils/activation_checkpoint.py", line 114, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "<eval_with_key>.3", line 73, in checkpoint_2
  File "/home/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.local/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/home/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 1455, in relu
    result = torch.relu_(input)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

The generated nn.Module is as follows. The problem occurs in checkpoint_2.

import torch
from torch.nn import *
class DenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # omitted
        self.classifier = Linear(in_features=1024, out_features=5, bias=True)
        self.load_state_dict(torch.load(r'densenet/state_dict.pt'))

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        import colossalai
        features_conv0 = self.features.conv0(x);  x = None
        def checkpoint_0(features_conv0):
            # omitted
            return features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv1
        features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, features_conv0)
        def checkpoint_1(features_denseblock1_denselayer4_conv1, features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2):
            features_denseblock1_denselayer4_norm2 = self.features.denseblock1.denselayer4.norm2(features_denseblock1_denselayer4_conv1);  features_denseblock1_denselayer4_conv1 = None
            features_denseblock1_denselayer4_relu2 = self.features.denseblock1.denselayer4.relu2(features_denseblock1_denselayer4_norm2);  features_denseblock1_denselayer4_norm2 = None
            features_denseblock1_denselayer4_conv2 = self.features.denseblock1.denselayer4.conv2(features_denseblock1_denselayer4_relu2);  features_denseblock1_denselayer4_relu2 = None
            cat_4 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2], 1)
            features_denseblock1_denselayer5_norm1 = self.features.denseblock1.denselayer5.norm1(cat_4);  cat_4 = None
            features_denseblock1_denselayer5_relu1 = self.features.denseblock1.denselayer5.relu1(features_denseblock1_denselayer5_norm1);  features_denseblock1_denselayer5_norm1 = None
            features_denseblock1_denselayer5_conv1 = self.features.denseblock1.denselayer5.conv1(features_denseblock1_denselayer5_relu1);  features_denseblock1_denselayer5_relu1 = None
            features_denseblock1_denselayer5_norm2 = self.features.denseblock1.denselayer5.norm2(features_denseblock1_denselayer5_conv1);  features_denseblock1_denselayer5_conv1 = None
            features_denseblock1_denselayer5_relu2 = self.features.denseblock1.denselayer5.relu2(features_denseblock1_denselayer5_norm2);  features_denseblock1_denselayer5_norm2 = None
            features_denseblock1_denselayer5_conv2 = self.features.denseblock1.denselayer5.conv2(features_denseblock1_denselayer5_relu2);  features_denseblock1_denselayer5_relu2 = None
            cat_5 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2, features_denseblock1_denselayer5_conv2], 1)
            features_denseblock1_denselayer6_norm1 = self.features.denseblock1.denselayer6.norm1(cat_5);  cat_5 = None
            features_denseblock1_denselayer6_relu1 = self.features.denseblock1.denselayer6.relu1(features_denseblock1_denselayer6_norm1);  features_denseblock1_denselayer6_norm1 = None
            features_denseblock1_denselayer6_conv1 = self.features.denseblock1.denselayer6.conv1(features_denseblock1_denselayer6_relu1);  features_denseblock1_denselayer6_relu1 = None
            features_denseblock1_denselayer6_norm2 = self.features.denseblock1.denselayer6.norm2(features_denseblock1_denselayer6_conv1);  features_denseblock1_denselayer6_conv1 = None
            features_denseblock1_denselayer6_relu2 = self.features.denseblock1.denselayer6.relu2(features_denseblock1_denselayer6_norm2);  features_denseblock1_denselayer6_norm2 = None
            features_denseblock1_denselayer6_conv2 = self.features.denseblock1.denselayer6.conv2(features_denseblock1_denselayer6_relu2);  features_denseblock1_denselayer6_relu2 = None
            cat_6 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2, features_denseblock1_denselayer5_conv2, features_denseblock1_denselayer6_conv2], 1);  features_pool0 = features_denseblock1_denselayer1_conv2 = features_denseblock1_denselayer2_conv2 = features_denseblock1_denselayer3_conv2 = features_denseblock1_denselayer4_conv2 = features_denseblock1_denselayer5_conv2 = features_denseblock1_denselayer6_conv2 = None
            features_transition1_norm = self.features.transition1.norm(cat_6);  cat_6 = None
            features_transition1_relu = self.features.transition1.relu(features_transition1_norm);  features_transition1_norm = None
            features_transition1_conv = self.features.transition1.conv(features_transition1_relu);  features_transition1_relu = None
            features_transition1_pool = self.features.transition1.pool(features_transition1_conv);  features_transition1_conv = None
            cat_7 = torch.cat([features_transition1_pool], 1)
            features_denseblock2_denselayer1_norm1 = self.features.denseblock2.denselayer1.norm1(cat_7);  cat_7 = None
            features_denseblock2_denselayer1_relu1 = self.features.denseblock2.denselayer1.relu1(features_denseblock2_denselayer1_norm1);  features_denseblock2_denselayer1_norm1 = None
            features_denseblock2_denselayer1_conv1 = self.features.denseblock2.denselayer1.conv1(features_denseblock2_denselayer1_relu1);  features_denseblock2_denselayer1_relu1 = None
            features_denseblock2_denselayer1_norm2 = self.features.denseblock2.denselayer1.norm2(features_denseblock2_denselayer1_conv1);  features_denseblock2_denselayer1_conv1 = None
            features_denseblock2_denselayer1_relu2 = self.features.denseblock2.denselayer1.relu2(features_denseblock2_denselayer1_norm2);  features_denseblock2_denselayer1_norm2 = None
            features_denseblock2_denselayer1_conv2 = self.features.denseblock2.denselayer1.conv2(features_denseblock2_denselayer1_relu2);  features_denseblock2_denselayer1_relu2 = None
            cat_8 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2], 1)
            features_denseblock2_denselayer2_norm1 = self.features.denseblock2.denselayer2.norm1(cat_8);  cat_8 = None
            return features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_norm1
        features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_norm1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, features_denseblock1_denselayer4_conv1, features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2)
        def checkpoint_2(features_denseblock2_denselayer2_norm1, features_transition1_pool, features_denseblock2_denselayer1_conv2):
            features_denseblock2_denselayer2_relu1 = self.features.denseblock2.denselayer2.relu1(features_denseblock2_denselayer2_norm1);  features_denseblock2_denselayer2_norm1 = None
            features_denseblock2_denselayer2_conv1 = self.features.denseblock2.denselayer2.conv1(features_denseblock2_denselayer2_relu1);  features_denseblock2_denselayer2_relu1 = None
            features_denseblock2_denselayer2_norm2 = self.features.denseblock2.denselayer2.norm2(features_denseblock2_denselayer2_conv1);  features_denseblock2_denselayer2_conv1 = None
            features_denseblock2_denselayer2_relu2 = self.features.denseblock2.denselayer2.relu2(features_denseblock2_denselayer2_norm2);  features_denseblock2_denselayer2_norm2 = None
            features_denseblock2_denselayer2_conv2 = self.features.denseblock2.denselayer2.conv2(features_denseblock2_denselayer2_relu2);  features_denseblock2_denselayer2_relu2 = None
            cat_9 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2], 1)
            features_denseblock2_denselayer3_norm1 = self.features.denseblock2.denselayer3.norm1(cat_9);  cat_9 = None
            features_denseblock2_denselayer3_relu1 = self.features.denseblock2.denselayer3.relu1(features_denseblock2_denselayer3_norm1);  features_denseblock2_denselayer3_norm1 = None
            features_denseblock2_denselayer3_conv1 = self.features.denseblock2.denselayer3.conv1(features_denseblock2_denselayer3_relu1);  features_denseblock2_denselayer3_relu1 = None
            features_denseblock2_denselayer3_norm2 = self.features.denseblock2.denselayer3.norm2(features_denseblock2_denselayer3_conv1);  features_denseblock2_denselayer3_conv1 = None
            features_denseblock2_denselayer3_relu2 = self.features.denseblock2.denselayer3.relu2(features_denseblock2_denselayer3_norm2);  features_denseblock2_denselayer3_norm2 = None
            features_denseblock2_denselayer3_conv2 = self.features.denseblock2.denselayer3.conv2(features_denseblock2_denselayer3_relu2);  features_denseblock2_denselayer3_relu2 = None
            cat_10 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2], 1)
            features_denseblock2_denselayer4_norm1 = self.features.denseblock2.denselayer4.norm1(cat_10);  cat_10 = None
            features_denseblock2_denselayer4_relu1 = self.features.denseblock2.denselayer4.relu1(features_denseblock2_denselayer4_norm1);  features_denseblock2_denselayer4_norm1 = None
            features_denseblock2_denselayer4_conv1 = self.features.denseblock2.denselayer4.conv1(features_denseblock2_denselayer4_relu1);  features_denseblock2_denselayer4_relu1 = None
            features_denseblock2_denselayer4_norm2 = self.features.denseblock2.denselayer4.norm2(features_denseblock2_denselayer4_conv1);  features_denseblock2_denselayer4_conv1 = None
            features_denseblock2_denselayer4_relu2 = self.features.denseblock2.denselayer4.relu2(features_denseblock2_denselayer4_norm2);  features_denseblock2_denselayer4_norm2 = None
            features_denseblock2_denselayer4_conv2 = self.features.denseblock2.denselayer4.conv2(features_denseblock2_denselayer4_relu2);  features_denseblock2_denselayer4_relu2 = None
            cat_11 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2], 1)
            features_denseblock2_denselayer5_norm1 = self.features.denseblock2.denselayer5.norm1(cat_11);  cat_11 = None
            features_denseblock2_denselayer5_relu1 = self.features.denseblock2.denselayer5.relu1(features_denseblock2_denselayer5_norm1);  features_denseblock2_denselayer5_norm1 = None
            features_denseblock2_denselayer5_conv1 = self.features.denseblock2.denselayer5.conv1(features_denseblock2_denselayer5_relu1);  features_denseblock2_denselayer5_relu1 = None
            features_denseblock2_denselayer5_norm2 = self.features.denseblock2.denselayer5.norm2(features_denseblock2_denselayer5_conv1);  features_denseblock2_denselayer5_conv1 = None
            features_denseblock2_denselayer5_relu2 = self.features.denseblock2.denselayer5.relu2(features_denseblock2_denselayer5_norm2);  features_denseblock2_denselayer5_norm2 = None
            features_denseblock2_denselayer5_conv2 = self.features.denseblock2.denselayer5.conv2(features_denseblock2_denselayer5_relu2);  features_denseblock2_denselayer5_relu2 = None
            cat_12 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2], 1)
            features_denseblock2_denselayer6_norm1 = self.features.denseblock2.denselayer6.norm1(cat_12);  cat_12 = None
            features_denseblock2_denselayer6_relu1 = self.features.denseblock2.denselayer6.relu1(features_denseblock2_denselayer6_norm1);  features_denseblock2_denselayer6_norm1 = None
            features_denseblock2_denselayer6_conv1 = self.features.denseblock2.denselayer6.conv1(features_denseblock2_denselayer6_relu1);  features_denseblock2_denselayer6_relu1 = None
            features_denseblock2_denselayer6_norm2 = self.features.denseblock2.denselayer6.norm2(features_denseblock2_denselayer6_conv1);  features_denseblock2_denselayer6_conv1 = None
            features_denseblock2_denselayer6_relu2 = self.features.denseblock2.denselayer6.relu2(features_denseblock2_denselayer6_norm2);  features_denseblock2_denselayer6_norm2 = None
            features_denseblock2_denselayer6_conv2 = self.features.denseblock2.denselayer6.conv2(features_denseblock2_denselayer6_relu2);  features_denseblock2_denselayer6_relu2 = None
            cat_13 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2], 1)
            features_denseblock2_denselayer7_norm1 = self.features.denseblock2.denselayer7.norm1(cat_13);  cat_13 = None
            features_denseblock2_denselayer7_relu1 = self.features.denseblock2.denselayer7.relu1(features_denseblock2_denselayer7_norm1);  features_denseblock2_denselayer7_norm1 = None
            features_denseblock2_denselayer7_conv1 = self.features.denseblock2.denselayer7.conv1(features_denseblock2_denselayer7_relu1);  features_denseblock2_denselayer7_relu1 = None
            features_denseblock2_denselayer7_norm2 = self.features.denseblock2.denselayer7.norm2(features_denseblock2_denselayer7_conv1);  features_denseblock2_denselayer7_conv1 = None
            features_denseblock2_denselayer7_relu2 = self.features.denseblock2.denselayer7.relu2(features_denseblock2_denselayer7_norm2);  features_denseblock2_denselayer7_norm2 = None
            features_denseblock2_denselayer7_conv2 = self.features.denseblock2.denselayer7.conv2(features_denseblock2_denselayer7_relu2);  features_denseblock2_denselayer7_relu2 = None
            cat_14 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2], 1)
            features_denseblock2_denselayer8_norm1 = self.features.denseblock2.denselayer8.norm1(cat_14);  cat_14 = None
            features_denseblock2_denselayer8_relu1 = self.features.denseblock2.denselayer8.relu1(features_denseblock2_denselayer8_norm1);  features_denseblock2_denselayer8_norm1 = None
            features_denseblock2_denselayer8_conv1 = self.features.denseblock2.denselayer8.conv1(features_denseblock2_denselayer8_relu1);  features_denseblock2_denselayer8_relu1 = None
            features_denseblock2_denselayer8_norm2 = self.features.denseblock2.denselayer8.norm2(features_denseblock2_denselayer8_conv1);  features_denseblock2_denselayer8_conv1 = None
            features_denseblock2_denselayer8_relu2 = self.features.denseblock2.denselayer8.relu2(features_denseblock2_denselayer8_norm2);  features_denseblock2_denselayer8_norm2 = None
            features_denseblock2_denselayer8_conv2 = self.features.denseblock2.denselayer8.conv2(features_denseblock2_denselayer8_relu2);  features_denseblock2_denselayer8_relu2 = None
            cat_15 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2], 1)
            features_denseblock2_denselayer9_norm1 = self.features.denseblock2.denselayer9.norm1(cat_15);  cat_15 = None
            features_denseblock2_denselayer9_relu1 = self.features.denseblock2.denselayer9.relu1(features_denseblock2_denselayer9_norm1);  features_denseblock2_denselayer9_norm1 = None
            features_denseblock2_denselayer9_conv1 = self.features.denseblock2.denselayer9.conv1(features_denseblock2_denselayer9_relu1);  features_denseblock2_denselayer9_relu1 = None
            features_denseblock2_denselayer9_norm2 = self.features.denseblock2.denselayer9.norm2(features_denseblock2_denselayer9_conv1);  features_denseblock2_denselayer9_conv1 = None
            features_denseblock2_denselayer9_relu2 = self.features.denseblock2.denselayer9.relu2(features_denseblock2_denselayer9_norm2);  features_denseblock2_denselayer9_norm2 = None
            features_denseblock2_denselayer9_conv2 = self.features.denseblock2.denselayer9.conv2(features_denseblock2_denselayer9_relu2);  features_denseblock2_denselayer9_relu2 = None
            cat_16 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2], 1)
            features_denseblock2_denselayer10_norm1 = self.features.denseblock2.denselayer10.norm1(cat_16);  cat_16 = None
            features_denseblock2_denselayer10_relu1 = self.features.denseblock2.denselayer10.relu1(features_denseblock2_denselayer10_norm1);  features_denseblock2_denselayer10_norm1 = None
            features_denseblock2_denselayer10_conv1 = self.features.denseblock2.denselayer10.conv1(features_denseblock2_denselayer10_relu1);  features_denseblock2_denselayer10_relu1 = None
            features_denseblock2_denselayer10_norm2 = self.features.denseblock2.denselayer10.norm2(features_denseblock2_denselayer10_conv1);  features_denseblock2_denselayer10_conv1 = None
            features_denseblock2_denselayer10_relu2 = self.features.denseblock2.denselayer10.relu2(features_denseblock2_denselayer10_norm2);  features_denseblock2_denselayer10_norm2 = None
            features_denseblock2_denselayer10_conv2 = self.features.denseblock2.denselayer10.conv2(features_denseblock2_denselayer10_relu2);  features_denseblock2_denselayer10_relu2 = None
            cat_17 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2], 1)
            features_denseblock2_denselayer11_norm1 = self.features.denseblock2.denselayer11.norm1(cat_17);  cat_17 = None
            features_denseblock2_denselayer11_relu1 = self.features.denseblock2.denselayer11.relu1(features_denseblock2_denselayer11_norm1);  features_denseblock2_denselayer11_norm1 = None
            features_denseblock2_denselayer11_conv1 = self.features.denseblock2.denselayer11.conv1(features_denseblock2_denselayer11_relu1);  features_denseblock2_denselayer11_relu1 = None
            features_denseblock2_denselayer11_norm2 = self.features.denseblock2.denselayer11.norm2(features_denseblock2_denselayer11_conv1);  features_denseblock2_denselayer11_conv1 = None
            features_denseblock2_denselayer11_relu2 = self.features.denseblock2.denselayer11.relu2(features_denseblock2_denselayer11_norm2);  features_denseblock2_denselayer11_norm2 = None
            features_denseblock2_denselayer11_conv2 = self.features.denseblock2.denselayer11.conv2(features_denseblock2_denselayer11_relu2);  features_denseblock2_denselayer11_relu2 = None
            cat_18 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2], 1)
            features_denseblock2_denselayer12_norm1 = self.features.denseblock2.denselayer12.norm1(cat_18);  cat_18 = None
            features_denseblock2_denselayer12_relu1 = self.features.denseblock2.denselayer12.relu1(features_denseblock2_denselayer12_norm1);  features_denseblock2_denselayer12_norm1 = None
            return features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2, features_denseblock2_denselayer12_relu1
        features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2, features_denseblock2_denselayer12_relu1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, features_denseblock2_denselayer2_norm1, features_transition1_pool, features_denseblock2_denselayer1_conv2)
        # Too many layers after checkpoint_2

Is this because colossalai.utils.activation_checkpoint.checkpoint does not support an in-place operation applied directly to an input node? Should we handle this potential problem during CodeGen, or modify our checkpoint logic?
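
The error message itself can be reproduced outside the checkpoint with a few lines. This is only a minimal sketch, assuming that the checkpoint's backward detaches each input and marks it as requiring grad before rerunning the function (the name `detached_inputs` in the traceback suggests this, but it is not verified against the ColossalAI source). `rerun_on_detached_input` and `disable_inplace` are hypothetical helpers written for illustration, not part of ColossalAI or PyTorch.

```python
import torch
from torch import nn


def rerun_on_detached_input(module: nn.Module, x: torch.Tensor):
    """Rerun `module` on a detached copy of `x` that requires grad.

    Detaching and calling requires_grad_(True) turns the input into a leaf
    tensor, which is the assumed behaviour of the checkpoint's backward.
    Returns the error message if the rerun fails, otherwise None.
    """
    detached = x.detach().requires_grad_(True)
    try:
        with torch.enable_grad():
            module(detached).sum().backward()
    except RuntimeError as err:
        return str(err)
    return None


def disable_inplace(model: nn.Module) -> nn.Module:
    """Hypothetical workaround: switch off `inplace` on every submodule that sets it."""
    for sub in model.modules():
        if getattr(sub, "inplace", False) is True:
            sub.inplace = False
    return model


x = torch.randn(2, 4)

# In-place ReLU applied directly to the leaf input: autograd refuses.
inplace_error = rerun_on_detached_input(nn.ReLU(inplace=True), x)

# Same module with the in-place flag switched off: the rerun goes through.
fixed_error = rerun_on_detached_input(disable_inplace(nn.ReLU(inplace=True)), x)

print(inplace_error)
print(fixed_error)
```

If this assumption holds, either option from the question would avoid the crash: CodeGen (or the solver) could refuse to start a checkpoint region at an in-place op, or the model could be passed through something like `disable_inplace` before tracing, at the cost of the memory that in-place ops save.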

super-dainiu and others added 15 commits August 9, 2022 23:23
* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.
[fx] fix test and algorithm bugs in activation checkpointing.
[fx] fix test and algorithm bugs in activation checkpointing.
[fx] fix test and algorithm bugs in activation checkpointing.
[fx] fix test and algorithm bugs in activation checkpointing.
super-dainiu changed the title from "Feature/more ckpt" to "[fx] fix test and algorithm bugs in activation checkpointing." on Aug 12, 2022
@Cypher30
Contributor

What's new?

I regret using torch.randn(2, 3, 224, 224) as the test input in #1446, because it takes too much time on CI. I also modified the search algorithm (mostly the conditions for annotations) to avoid crashes in ActivationCheckpointCodeGen.

What's wrong

However, I have not figured out why the traced densenet121 raises the following error in loss.backward().

Traceback (most recent call last):
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 104, in <module>
    test_ckpt_solver()
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 73, in test_ckpt_solver
    check_backward_consistency(m, gm, solver, model_cls)
  File "/home/ColossalAI/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py", line 46, in check_backward_consistency
    loss.backward()
  File "/home/.local/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/.local/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/home/ColossalAI/colossalai/utils/activation_checkpoint.py", line 114, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "<eval_with_key>.3", line 73, in checkpoint_2
  File "/home/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/.local/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/home/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 1455, in relu
    result = torch.relu_(input)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

The generated nn.Module is as follows. The problem occurs in checkpoint_2.

import torch
from torch.nn import *
class DenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # omitted
        self.classifier = Linear(in_features=1024, out_features=5, bias=True)
        self.load_state_dict(torch.load(r'densenet/state_dict.pt'))

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        import colossalai
        features_conv0 = self.features.conv0(x);  x = None
        def checkpoint_0(features_conv0):
            # omitted
            return features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv1
        features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, features_conv0)
        def checkpoint_1(features_denseblock1_denselayer4_conv1, features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2):
            features_denseblock1_denselayer4_norm2 = self.features.denseblock1.denselayer4.norm2(features_denseblock1_denselayer4_conv1);  features_denseblock1_denselayer4_conv1 = None
            features_denseblock1_denselayer4_relu2 = self.features.denseblock1.denselayer4.relu2(features_denseblock1_denselayer4_norm2);  features_denseblock1_denselayer4_norm2 = None
            features_denseblock1_denselayer4_conv2 = self.features.denseblock1.denselayer4.conv2(features_denseblock1_denselayer4_relu2);  features_denseblock1_denselayer4_relu2 = None
            cat_4 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2], 1)
            features_denseblock1_denselayer5_norm1 = self.features.denseblock1.denselayer5.norm1(cat_4);  cat_4 = None
            features_denseblock1_denselayer5_relu1 = self.features.denseblock1.denselayer5.relu1(features_denseblock1_denselayer5_norm1);  features_denseblock1_denselayer5_norm1 = None
            features_denseblock1_denselayer5_conv1 = self.features.denseblock1.denselayer5.conv1(features_denseblock1_denselayer5_relu1);  features_denseblock1_denselayer5_relu1 = None
            features_denseblock1_denselayer5_norm2 = self.features.denseblock1.denselayer5.norm2(features_denseblock1_denselayer5_conv1);  features_denseblock1_denselayer5_conv1 = None
            features_denseblock1_denselayer5_relu2 = self.features.denseblock1.denselayer5.relu2(features_denseblock1_denselayer5_norm2);  features_denseblock1_denselayer5_norm2 = None
            features_denseblock1_denselayer5_conv2 = self.features.denseblock1.denselayer5.conv2(features_denseblock1_denselayer5_relu2);  features_denseblock1_denselayer5_relu2 = None
            cat_5 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2, features_denseblock1_denselayer5_conv2], 1)
            features_denseblock1_denselayer6_norm1 = self.features.denseblock1.denselayer6.norm1(cat_5);  cat_5 = None
            features_denseblock1_denselayer6_relu1 = self.features.denseblock1.denselayer6.relu1(features_denseblock1_denselayer6_norm1);  features_denseblock1_denselayer6_norm1 = None
            features_denseblock1_denselayer6_conv1 = self.features.denseblock1.denselayer6.conv1(features_denseblock1_denselayer6_relu1);  features_denseblock1_denselayer6_relu1 = None
            features_denseblock1_denselayer6_norm2 = self.features.denseblock1.denselayer6.norm2(features_denseblock1_denselayer6_conv1);  features_denseblock1_denselayer6_conv1 = None
            features_denseblock1_denselayer6_relu2 = self.features.denseblock1.denselayer6.relu2(features_denseblock1_denselayer6_norm2);  features_denseblock1_denselayer6_norm2 = None
            features_denseblock1_denselayer6_conv2 = self.features.denseblock1.denselayer6.conv2(features_denseblock1_denselayer6_relu2);  features_denseblock1_denselayer6_relu2 = None
            cat_6 = torch.cat([features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2, features_denseblock1_denselayer4_conv2, features_denseblock1_denselayer5_conv2, features_denseblock1_denselayer6_conv2], 1);  features_pool0 = features_denseblock1_denselayer1_conv2 = features_denseblock1_denselayer2_conv2 = features_denseblock1_denselayer3_conv2 = features_denseblock1_denselayer4_conv2 = features_denseblock1_denselayer5_conv2 = features_denseblock1_denselayer6_conv2 = None
            features_transition1_norm = self.features.transition1.norm(cat_6);  cat_6 = None
            features_transition1_relu = self.features.transition1.relu(features_transition1_norm);  features_transition1_norm = None
            features_transition1_conv = self.features.transition1.conv(features_transition1_relu);  features_transition1_relu = None
            features_transition1_pool = self.features.transition1.pool(features_transition1_conv);  features_transition1_conv = None
            cat_7 = torch.cat([features_transition1_pool], 1)
            features_denseblock2_denselayer1_norm1 = self.features.denseblock2.denselayer1.norm1(cat_7);  cat_7 = None
            features_denseblock2_denselayer1_relu1 = self.features.denseblock2.denselayer1.relu1(features_denseblock2_denselayer1_norm1);  features_denseblock2_denselayer1_norm1 = None
            features_denseblock2_denselayer1_conv1 = self.features.denseblock2.denselayer1.conv1(features_denseblock2_denselayer1_relu1);  features_denseblock2_denselayer1_relu1 = None
            features_denseblock2_denselayer1_norm2 = self.features.denseblock2.denselayer1.norm2(features_denseblock2_denselayer1_conv1);  features_denseblock2_denselayer1_conv1 = None
            features_denseblock2_denselayer1_relu2 = self.features.denseblock2.denselayer1.relu2(features_denseblock2_denselayer1_norm2);  features_denseblock2_denselayer1_norm2 = None
            features_denseblock2_denselayer1_conv2 = self.features.denseblock2.denselayer1.conv2(features_denseblock2_denselayer1_relu2);  features_denseblock2_denselayer1_relu2 = None
            cat_8 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2], 1)
            features_denseblock2_denselayer2_norm1 = self.features.denseblock2.denselayer2.norm1(cat_8);  cat_8 = None
            return features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_norm1
        features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_norm1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, features_denseblock1_denselayer4_conv1, features_pool0, features_denseblock1_denselayer1_conv2, features_denseblock1_denselayer2_conv2, features_denseblock1_denselayer3_conv2)
        def checkpoint_2(features_denseblock2_denselayer2_norm1, features_transition1_pool, features_denseblock2_denselayer1_conv2):
            features_denseblock2_denselayer2_relu1 = self.features.denseblock2.denselayer2.relu1(features_denseblock2_denselayer2_norm1);  features_denseblock2_denselayer2_norm1 = None
            features_denseblock2_denselayer2_conv1 = self.features.denseblock2.denselayer2.conv1(features_denseblock2_denselayer2_relu1);  features_denseblock2_denselayer2_relu1 = None
            features_denseblock2_denselayer2_norm2 = self.features.denseblock2.denselayer2.norm2(features_denseblock2_denselayer2_conv1);  features_denseblock2_denselayer2_conv1 = None
            features_denseblock2_denselayer2_relu2 = self.features.denseblock2.denselayer2.relu2(features_denseblock2_denselayer2_norm2);  features_denseblock2_denselayer2_norm2 = None
            features_denseblock2_denselayer2_conv2 = self.features.denseblock2.denselayer2.conv2(features_denseblock2_denselayer2_relu2);  features_denseblock2_denselayer2_relu2 = None
            cat_9 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2], 1)
            features_denseblock2_denselayer3_norm1 = self.features.denseblock2.denselayer3.norm1(cat_9);  cat_9 = None
            features_denseblock2_denselayer3_relu1 = self.features.denseblock2.denselayer3.relu1(features_denseblock2_denselayer3_norm1);  features_denseblock2_denselayer3_norm1 = None
            features_denseblock2_denselayer3_conv1 = self.features.denseblock2.denselayer3.conv1(features_denseblock2_denselayer3_relu1);  features_denseblock2_denselayer3_relu1 = None
            features_denseblock2_denselayer3_norm2 = self.features.denseblock2.denselayer3.norm2(features_denseblock2_denselayer3_conv1);  features_denseblock2_denselayer3_conv1 = None
            features_denseblock2_denselayer3_relu2 = self.features.denseblock2.denselayer3.relu2(features_denseblock2_denselayer3_norm2);  features_denseblock2_denselayer3_norm2 = None
            features_denseblock2_denselayer3_conv2 = self.features.denseblock2.denselayer3.conv2(features_denseblock2_denselayer3_relu2);  features_denseblock2_denselayer3_relu2 = None
            cat_10 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2], 1)
            features_denseblock2_denselayer4_norm1 = self.features.denseblock2.denselayer4.norm1(cat_10);  cat_10 = None
            features_denseblock2_denselayer4_relu1 = self.features.denseblock2.denselayer4.relu1(features_denseblock2_denselayer4_norm1);  features_denseblock2_denselayer4_norm1 = None
            features_denseblock2_denselayer4_conv1 = self.features.denseblock2.denselayer4.conv1(features_denseblock2_denselayer4_relu1);  features_denseblock2_denselayer4_relu1 = None
            features_denseblock2_denselayer4_norm2 = self.features.denseblock2.denselayer4.norm2(features_denseblock2_denselayer4_conv1);  features_denseblock2_denselayer4_conv1 = None
            features_denseblock2_denselayer4_relu2 = self.features.denseblock2.denselayer4.relu2(features_denseblock2_denselayer4_norm2);  features_denseblock2_denselayer4_norm2 = None
            features_denseblock2_denselayer4_conv2 = self.features.denseblock2.denselayer4.conv2(features_denseblock2_denselayer4_relu2);  features_denseblock2_denselayer4_relu2 = None
            cat_11 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2], 1)
            features_denseblock2_denselayer5_norm1 = self.features.denseblock2.denselayer5.norm1(cat_11);  cat_11 = None
            features_denseblock2_denselayer5_relu1 = self.features.denseblock2.denselayer5.relu1(features_denseblock2_denselayer5_norm1);  features_denseblock2_denselayer5_norm1 = None
            features_denseblock2_denselayer5_conv1 = self.features.denseblock2.denselayer5.conv1(features_denseblock2_denselayer5_relu1);  features_denseblock2_denselayer5_relu1 = None
            features_denseblock2_denselayer5_norm2 = self.features.denseblock2.denselayer5.norm2(features_denseblock2_denselayer5_conv1);  features_denseblock2_denselayer5_conv1 = None
            features_denseblock2_denselayer5_relu2 = self.features.denseblock2.denselayer5.relu2(features_denseblock2_denselayer5_norm2);  features_denseblock2_denselayer5_norm2 = None
            features_denseblock2_denselayer5_conv2 = self.features.denseblock2.denselayer5.conv2(features_denseblock2_denselayer5_relu2);  features_denseblock2_denselayer5_relu2 = None
            cat_12 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2], 1)
            features_denseblock2_denselayer6_norm1 = self.features.denseblock2.denselayer6.norm1(cat_12);  cat_12 = None
            features_denseblock2_denselayer6_relu1 = self.features.denseblock2.denselayer6.relu1(features_denseblock2_denselayer6_norm1);  features_denseblock2_denselayer6_norm1 = None
            features_denseblock2_denselayer6_conv1 = self.features.denseblock2.denselayer6.conv1(features_denseblock2_denselayer6_relu1);  features_denseblock2_denselayer6_relu1 = None
            features_denseblock2_denselayer6_norm2 = self.features.denseblock2.denselayer6.norm2(features_denseblock2_denselayer6_conv1);  features_denseblock2_denselayer6_conv1 = None
            features_denseblock2_denselayer6_relu2 = self.features.denseblock2.denselayer6.relu2(features_denseblock2_denselayer6_norm2);  features_denseblock2_denselayer6_norm2 = None
            features_denseblock2_denselayer6_conv2 = self.features.denseblock2.denselayer6.conv2(features_denseblock2_denselayer6_relu2);  features_denseblock2_denselayer6_relu2 = None
            cat_13 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2], 1)
            features_denseblock2_denselayer7_norm1 = self.features.denseblock2.denselayer7.norm1(cat_13);  cat_13 = None
            features_denseblock2_denselayer7_relu1 = self.features.denseblock2.denselayer7.relu1(features_denseblock2_denselayer7_norm1);  features_denseblock2_denselayer7_norm1 = None
            features_denseblock2_denselayer7_conv1 = self.features.denseblock2.denselayer7.conv1(features_denseblock2_denselayer7_relu1);  features_denseblock2_denselayer7_relu1 = None
            features_denseblock2_denselayer7_norm2 = self.features.denseblock2.denselayer7.norm2(features_denseblock2_denselayer7_conv1);  features_denseblock2_denselayer7_conv1 = None
            features_denseblock2_denselayer7_relu2 = self.features.denseblock2.denselayer7.relu2(features_denseblock2_denselayer7_norm2);  features_denseblock2_denselayer7_norm2 = None
            features_denseblock2_denselayer7_conv2 = self.features.denseblock2.denselayer7.conv2(features_denseblock2_denselayer7_relu2);  features_denseblock2_denselayer7_relu2 = None
            cat_14 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2], 1)
            features_denseblock2_denselayer8_norm1 = self.features.denseblock2.denselayer8.norm1(cat_14);  cat_14 = None
            features_denseblock2_denselayer8_relu1 = self.features.denseblock2.denselayer8.relu1(features_denseblock2_denselayer8_norm1);  features_denseblock2_denselayer8_norm1 = None
            features_denseblock2_denselayer8_conv1 = self.features.denseblock2.denselayer8.conv1(features_denseblock2_denselayer8_relu1);  features_denseblock2_denselayer8_relu1 = None
            features_denseblock2_denselayer8_norm2 = self.features.denseblock2.denselayer8.norm2(features_denseblock2_denselayer8_conv1);  features_denseblock2_denselayer8_conv1 = None
            features_denseblock2_denselayer8_relu2 = self.features.denseblock2.denselayer8.relu2(features_denseblock2_denselayer8_norm2);  features_denseblock2_denselayer8_norm2 = None
            features_denseblock2_denselayer8_conv2 = self.features.denseblock2.denselayer8.conv2(features_denseblock2_denselayer8_relu2);  features_denseblock2_denselayer8_relu2 = None
            cat_15 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2], 1)
            features_denseblock2_denselayer9_norm1 = self.features.denseblock2.denselayer9.norm1(cat_15);  cat_15 = None
            features_denseblock2_denselayer9_relu1 = self.features.denseblock2.denselayer9.relu1(features_denseblock2_denselayer9_norm1);  features_denseblock2_denselayer9_norm1 = None
            features_denseblock2_denselayer9_conv1 = self.features.denseblock2.denselayer9.conv1(features_denseblock2_denselayer9_relu1);  features_denseblock2_denselayer9_relu1 = None
            features_denseblock2_denselayer9_norm2 = self.features.denseblock2.denselayer9.norm2(features_denseblock2_denselayer9_conv1);  features_denseblock2_denselayer9_conv1 = None
            features_denseblock2_denselayer9_relu2 = self.features.denseblock2.denselayer9.relu2(features_denseblock2_denselayer9_norm2);  features_denseblock2_denselayer9_norm2 = None
            features_denseblock2_denselayer9_conv2 = self.features.denseblock2.denselayer9.conv2(features_denseblock2_denselayer9_relu2);  features_denseblock2_denselayer9_relu2 = None
            cat_16 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2], 1)
            features_denseblock2_denselayer10_norm1 = self.features.denseblock2.denselayer10.norm1(cat_16);  cat_16 = None
            features_denseblock2_denselayer10_relu1 = self.features.denseblock2.denselayer10.relu1(features_denseblock2_denselayer10_norm1);  features_denseblock2_denselayer10_norm1 = None
            features_denseblock2_denselayer10_conv1 = self.features.denseblock2.denselayer10.conv1(features_denseblock2_denselayer10_relu1);  features_denseblock2_denselayer10_relu1 = None
            features_denseblock2_denselayer10_norm2 = self.features.denseblock2.denselayer10.norm2(features_denseblock2_denselayer10_conv1);  features_denseblock2_denselayer10_conv1 = None
            features_denseblock2_denselayer10_relu2 = self.features.denseblock2.denselayer10.relu2(features_denseblock2_denselayer10_norm2);  features_denseblock2_denselayer10_norm2 = None
            features_denseblock2_denselayer10_conv2 = self.features.denseblock2.denselayer10.conv2(features_denseblock2_denselayer10_relu2);  features_denseblock2_denselayer10_relu2 = None
            cat_17 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2], 1)
            features_denseblock2_denselayer11_norm1 = self.features.denseblock2.denselayer11.norm1(cat_17);  cat_17 = None
            features_denseblock2_denselayer11_relu1 = self.features.denseblock2.denselayer11.relu1(features_denseblock2_denselayer11_norm1);  features_denseblock2_denselayer11_norm1 = None
            features_denseblock2_denselayer11_conv1 = self.features.denseblock2.denselayer11.conv1(features_denseblock2_denselayer11_relu1);  features_denseblock2_denselayer11_relu1 = None
            features_denseblock2_denselayer11_norm2 = self.features.denseblock2.denselayer11.norm2(features_denseblock2_denselayer11_conv1);  features_denseblock2_denselayer11_conv1 = None
            features_denseblock2_denselayer11_relu2 = self.features.denseblock2.denselayer11.relu2(features_denseblock2_denselayer11_norm2);  features_denseblock2_denselayer11_norm2 = None
            features_denseblock2_denselayer11_conv2 = self.features.denseblock2.denselayer11.conv2(features_denseblock2_denselayer11_relu2);  features_denseblock2_denselayer11_relu2 = None
            cat_18 = torch.cat([features_transition1_pool, features_denseblock2_denselayer1_conv2, features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2], 1)
            features_denseblock2_denselayer12_norm1 = self.features.denseblock2.denselayer12.norm1(cat_18);  cat_18 = None
            features_denseblock2_denselayer12_relu1 = self.features.denseblock2.denselayer12.relu1(features_denseblock2_denselayer12_norm1);  features_denseblock2_denselayer12_norm1 = None
            return features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2, features_denseblock2_denselayer12_relu1
        features_denseblock2_denselayer2_conv2, features_denseblock2_denselayer3_conv2, features_denseblock2_denselayer4_conv2, features_denseblock2_denselayer5_conv2, features_denseblock2_denselayer6_conv2, features_denseblock2_denselayer7_conv2, features_denseblock2_denselayer8_conv2, features_denseblock2_denselayer9_conv2, features_denseblock2_denselayer10_conv2, features_denseblock2_denselayer11_conv2, features_denseblock2_denselayer12_relu1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, features_denseblock2_denselayer2_norm1, features_transition1_pool, features_denseblock2_denselayer1_conv2)
        # Too many layers after checkpoint_2

Is that because colossalai.utils.activation_checkpoint.checkpoint does not support an in-place operation right after the input node? Should we intercept this potential problem during CodeGen, or modify our checkpoint logic?
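The hypothesis can be reproduced in isolation with plain PyTorch, outside of colossalai. The sketch below is not the colossalai checkpoint code; it only assumes that the checkpointed function receives a detached input with requires_grad set to True and that its first operation is an in-place ReLU, as relu1 appears to be in checkpoint_2. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def relu_inplace_on_leaf() -> str:
    """Apply an in-place ReLU directly to a leaf tensor that requires grad."""
    x = torch.randn(2, 3).detach()
    x.requires_grad = True  # x is now a leaf tensor that requires grad
    try:
        F.relu(x, inplace=True)  # same call that nn.ReLU(inplace=True) makes
    except RuntimeError as err:
        return f"RuntimeError: {err}"
    return "ok"


def relu_inplace_on_copy() -> str:
    """Apply the same in-place ReLU to a non-leaf copy of the input."""
    x = torch.randn(2, 3).detach()
    x.requires_grad = True
    y = x.clone()  # y is produced by an op, so it is not a leaf
    F.relu(y, inplace=True)
    y.sum().backward()  # gradients still flow back to x
    return "ok"


if __name__ == "__main__":
    print(relu_inplace_on_leaf())
    print(relu_inplace_on_copy())
```

The first function hits PyTorch's guard against in-place operations on leaf tensors that require grad; the second passes because the in-place op acts on an intermediate tensor rather than on the leaf itself.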

I think that's it. Take a look at activation_checkpoint.py, lines 11-17: I think in this case requires_grad is set to True for the input tensors, and since the input tensors are leaf nodes in run_function, PyTorch will not allow this kind of operation on them, see this. It seems our tracer cannot identify those in-place operations?
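One possible way to spot such operations in a traced graph is sketched below. It is an assumption for illustration, not what the colossalai tracer or ActivationCheckpointCodeGen does: it walks a torch.fx graph and reports call_module nodes whose module exposes inplace=True. The helper find_inplace_modules and the toy module Tiny are hypothetical names.

```python
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


class Tiny(nn.Module):
    """Toy model with one in-place and one out-of-place activation."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.act_inplace = nn.ReLU(inplace=True)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.act_inplace(self.conv(x)))


def find_inplace_modules(gm: GraphModule) -> list:
    """Return targets of call_module nodes whose module has inplace=True."""
    found = []
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        module = gm.get_submodule(node.target)
        if getattr(module, "inplace", False):
            found.append(node.target)
    return found


if __name__ == "__main__":
    print(find_inplace_modules(symbolic_trace(Tiny())))
```

This only covers modules that carry an inplace attribute. In-place calls traced as call_function or call_method nodes (for example torch.relu_ or Tensor.add_) would need a separate check.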

@Cypher30 Cypher30 left a comment

Let's hold the PR for now. I think we need to discuss those in-place operations further.

@Cypher30 Cypher30 left a comment

We approve this change and will merge it, but the test is skipped while we wait for the new version of the colossalai checkpoint.
