Merged

Changes from all commits (23 commits):
04e5272  Merge pull request #1 from hpcaitech/main (Cypher30, Jul 14, 2022)
75618b3  Merge pull request #2 from hpcaitech/main (Cypher30, Jul 15, 2022)
3e4620c  Merge pull request #3 from hpcaitech/main (Cypher30, Jul 20, 2022)
cf24049  Merge remote-tracking branch 'upstream/main' into main (Jul 20, 2022)
3d223b6  Merge remote-tracking branch 'upstream/main' into main (Jul 21, 2022)
644115c  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 22, 2022)
d995ade  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 25, 2022)
bba2dbe  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 26, 2022)
05ca628  Merge branch 'hpcaitech:main' into main (Cypher30, Jul 26, 2022)
0a967da  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 6, 2022)
0637c0d  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 8, 2022)
74a6227  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 10, 2022)
e550490  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 10, 2022)
2d7f5d9  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 11, 2022)
b62e870  Merge branch 'hpcaitech:main' into main (Cypher30, Aug 12, 2022)
54e5815  Merge branch 'hpcaitech:main' into feature/add_ckpt_reentrant_False (Cypher30, Aug 16, 2022)
ed273b7  [utils] Add use_reetrant=False into colossalai checkpoint (Aug 16, 2022)
c3abee9  [utils] add some annotation in utils.activaion_checkpoint (Aug 16, 2022)
2d3c672  [test] add reset_seed at the beginning of tests in test_actiavion_che… (Aug 16, 2022)
64b9032  [test] modify test_activation_checkpoint.py (Aug 16, 2022)
bcdba62  [test] modify test for reentrant=False (Aug 16, 2022)
74bd65e  Merge branch 'hpcaitech:main' into feature/add_ckpt_reentrant_False_t… (Cypher30, Aug 16, 2022)
f8ea3da  [fx] Add use_reentrant=False of checkpoint into codegen (Aug 16, 2022)
22 changes: 19 additions & 3 deletions colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"
 
 
-def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
 
 
 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
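
For concreteness, here is roughly what the updated helper emits. The invocation below is illustrative (the label, offload flag, and variable names are invented), not taken from the PR:

    _gen_ckpt_usage('0', False, ['x'], ['out'], use_reentrant=False)
    # -> "out = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x, use_reentrant=False)"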
@@ -162,8 +162,24 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
             else:
                 activation_offload = False
 
+            # we need to check if the checkpoint needs use_reentrant=False
+            use_reentrant = True
+            for var in input_vars[label]:
+                input_node = [item for item in node_list if item.name == var]
+                input_node = input_node[0]
+                for user in input_node.users:
+                    if hasattr(user, "activation_checkpoint"):
+                        if user.activation_checkpoint == label:
+                            if user.op == "call_module":
+                                if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
+                                    use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
+
+                            elif user.op == "call_function":
+                                if "inplace" in user.kwargs:
+                                    use_reentrant = not user.kwargs["inplace"]
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
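In plain terms, the block added above walks every user of a checkpoint region's input variables; if any user inside that region is an in-place call_module (a submodule whose inplace attribute is set, e.g. nn.ReLU(inplace=True)) or a call_function invoked with inplace=True, the generated call falls back to use_reentrant=False. A minimal standalone sketch of the same inspection, assuming nothing beyond torch.fx (the Net module and the printout are hypothetical, not part of the PR):

    import torch
    from torch.fx import symbolic_trace

    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU(inplace=True)

        def forward(self, x):
            return self.relu(x)

    gm = symbolic_trace(Net())
    for node in gm.graph.nodes:
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            # mirrors the codegen check: an in-place user inside a checkpoint
            # region would force use_reentrant=False for that region
            if getattr(submod, "inplace", False):
                print(f"{node.name} uses inplace=True")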
49 changes: 34 additions & 15 deletions tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -1,5 +1,6 @@
 from operator import mod
 import torch
+import torch.nn.functional as F
 import pytest
 import torch.multiprocessing as mp
 from torch.utils.checkpoint import checkpoint
@@ -26,20 +27,35 @@ def __init__(self):
         self.linear2 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
-        return self.linear1(x), self.linear1(x)
+        return self.linear1(x), self.linear2(x)
 
 
+class relu(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.relu(x)
+
+
 class MyModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
         self.mlp1 = MLP()
         self.mlp2 = MLP()
+        self.relu = relu()
         self.linear3 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
         y1, y2 = checkpoint(self.mlp1, x)
-        y3, y4 = checkpoint(self.mlp2, x)
+        y3 = checkpoint(self.relu, x)
+
+        def ckpt2(x):
+            return F.relu(x, inplace=True)
+
+        y4 = checkpoint(ckpt2, x)
         return y1 + y2 + y3 + y4

@@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank):
 
     # check ops are annotated with ckpt
     # also annotate the selected node for offloading
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -75,15 +91,17 @@
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
 
-    gm = GraphModule(model, graph)
-    gm.recompile()
-
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
 
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
@@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank):
     graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
 
     # check ops are annotated with ckpt
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -127,15 +145,16 @@
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
 
-    gm = GraphModule(model, graph)
-    gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
 
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
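Taken together with MyModule above, the asserted substrings map out the expected codegen result: checkpoint_0 is the mlp1 region (activation_offload=True, no in-place users, so the reentrant default survives), while checkpoint_1 (the inplace nn.ReLU submodule) and checkpoint_2 (the ckpt2 closure around F.relu(x, inplace=True)) both fall back to use_reentrant=False. A hypothetical reconstruction of the three calls inside the recompiled forward, pieced together from the asserts and the _gen_ckpt_usage format rather than copied from real codegen output:

    y1, y2 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)
    y3 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)
    y4 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)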