Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
06f8991
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
3cd7d22
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 9, 2022
0849b3b
[fx] modify the calculation of node_size in MetaInfoProp for activati…
super-dainiu Aug 10, 2022
701786c
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
a75e5a2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 10, 2022
c20beb2
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
7e87286
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 11, 2022
f027931
Merge branch 'hpcaitech:main' into main
super-dainiu Aug 12, 2022
9b4f460
[fx] merge development into main (#1)
super-dainiu Aug 12, 2022
bea7060
[fx] add rules to linearize computation graphs for searching. (#2)
super-dainiu Aug 16, 2022
86c005d
[fx] merge
super-dainiu Aug 16, 2022
da259cc
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
296b405
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
bf7feea
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
e6c5f70
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
0cbafd8
Merge branch 'feature/linear_ckpt' of http://github.com/super-dainiu/…
super-dainiu Aug 16, 2022
8e14703
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
92e8223
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
3e9531c
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
02c5cae
[fx] remove chen_sqrt for sake of simplicity
super-dainiu Aug 16, 2022
a8616ef
Merge branch 'hpcaitech:main' into feature/linear_ckpt
super-dainiu Aug 17, 2022
083cf7f
[fx] fix inconsistencies.
super-dainiu Aug 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion colossalai/fx/passes/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
from .ckpt_solver_chen import chen_greedy
52 changes: 22 additions & 30 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule
from torch.fx import GraphModule, Node
import math

__all__ = ['chen_greedy', 'chen_sqrtn']
__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']


def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
"""
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
"""

def is_sink():
"""
If we can free all memories when executing a certain node, it is a sink.
"""
return not sum((v for k, v in deps.items()))

deps = {}
ckpt_nodes = []
for n in gm.graph.nodes:
if n.op == 'call_module':
for n_par in n._input_nodes:
deps[n_par] -= 1 # free memory and dependencies

# We can only put act_ckpt on these nodes
if n.op in CKPT_OP and is_sink():
ckpt_nodes.append(n)
deps[n] = len(n.users) # add dependencies for future executions
return ckpt_nodes


Expand Down Expand Up @@ -71,32 +88,7 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
for i, seg in enumerate(ckpt):
for idx in range(*seg):
n = node_list[idx]
if n.op in ['call_module', 'call_method', 'call_function']:
setattr(n, 'activation_checkpoint', str(i))
gm.recompile()
return gm


def chen_sqrtn(gm: GraphModule) -> GraphModule:
"""
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

Usage:
model = resnet18()
input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
gm = chen_sqrtn(gm)

Args:
gm (GraphModule): The module to add checkpoints
"""
gm.graph.lint() # make sure nodes are in topological order
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
for idx, n in enumerate(gm.graph.nodes):
# We should not add act_ckpt to the placeholder
# The last segment should not be checkpointed
if n.op != 'placeholder' and (idx + 1) // k < k:
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
if n.op in CKPT_OP:
setattr(n, 'activation_checkpoint', i)
gm.recompile()
return gm
21 changes: 17 additions & 4 deletions tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Callable
import copy
import re
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
from colossalai.fx.passes.algorithms import chen_greedy
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
Expand All @@ -20,7 +21,7 @@
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False

SOLVERS = [chen_greedy, chen_sqrtn]
SOLVERS = [chen_greedy]


def _is_activation_checkpoint_available(gm: GraphModule):
Expand All @@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
return True


def _is_graph_linearized(gm: GraphModule):
code = gm.code
# find patterns like r' return output_1, output_2', which is not expected on a linearized graph
pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
if pattern.findall(code):
return False
else:
return True


def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
model_cls: Callable[[], torch.nn.Module]):
criterion = torch.nn.MSELoss()
Expand Down Expand Up @@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
gm = solver(gm)
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
check_backward_consistency(m, gm, solver, model_cls)
gpc.destroy()


@pytest.mark.skip
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
Expand All @@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
gm = solver(gm)
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
check_backward_consistency(m, gm, solver, model_cls)
gpc.destroy()


@pytest.mark.skip
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
Expand Down