40 changes: 27 additions & 13 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -1,11 +1,19 @@
from typing import Set, Tuple
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule
import math

__all__ = ['chen_greedy', 'chen_sqrtn']


def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
ckpt_nodes = []
for n in gm.graph.nodes:
if n.op == 'call_module':
ckpt_nodes.append(n)
return ckpt_nodes


def chen_greedy(gm: GraphModule) -> GraphModule:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -31,36 +39,40 @@ def grid_search(num_grids: int = 6) -> Set:
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
ckpt, b_approx = run_chen_greedy(b)
ckpt_intv, b_approx = run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt
ckpt_opt = ckpt_intv
return ckpt_opt

def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt = set()
ckpt_nodes = _all_potential_ckpt_nodes(gm)
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'activation_size')
y = max(y, temp)
if temp > b:
if temp > b and n in ckpt_nodes:
x += getattr(n, 'activation_size')
temp = 0
ckpt.add(idx)
return ckpt, math.floor(math.sqrt(x * y))
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))

gm.graph.lint() # make sure nodes are in topological order
ckpt = grid_search(num_grids=6)
i = 0
for idx, n in enumerate(gm.graph.nodes):
if idx in ckpt:
setattr(n, 'activation_checkpoint', str(i))
i += 1
node_list = list(gm.graph.nodes)
for i, seg in enumerate(ckpt):
for idx in range(*seg):
n = node_list[idx]
if n.op in ['call_module', 'call_method', 'call_function']:
setattr(n, 'activation_checkpoint', str(i))
gm.recompile()
return gm

@@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule:
gm.graph.lint() # make sure nodes are in topological order
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
for idx, n in enumerate(gm.graph.nodes):
if (idx + 1) % k == 0:
# We should not add act_ckpt to the placeholder
# The last segment should not be checkpointed
if n.op != 'placeholder' and (idx + 1) // k < k:
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
gm.recompile()
return gm
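
As a quick illustration of the new chen_sqrtn segmentation rule (n.op != 'placeholder' and (idx + 1) // k < k), here is a minimal standalone sketch; the node list and op strings are illustrative stand-ins for gm.graph.nodes and are not part of this PR:

# Standalone sketch of the segmentation arithmetic used in chen_sqrtn above.
# The fake op strings below stand in for real torch.fx nodes.
nodes = ['placeholder'] + [f'call_module_{i}' for i in range(16)]    # 17 nodes total
k = int(len(nodes) ** 0.5)                                           # k = 4, ~sqrt(n) checkpoints

segments = {}
for idx, op in enumerate(nodes):
    # Skip the placeholder and leave the last segment ((idx + 1) // k >= k) unannotated,
    # mirroring the two conditions added in this PR.
    if op != 'placeholder' and (idx + 1) // k < k:
        segments.setdefault((idx + 1) // k, []).append(idx)

print(segments)    # {0: [1, 2], 1: [3, 4, 5, 6], 2: [7, 8, 9, 10], 3: [11, 12, 13, 14]}

Nodes 15 and 16 fall into the would-be segment 4 and stay unannotated, which is exactly what the (idx + 1) // k < k guard enforces.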
90 changes: 73 additions & 17 deletions tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -1,12 +1,25 @@
from ctypes import Union
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
from typing import Callable
import copy
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest

try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
with_codegen = True
except:
# fall back to older pytorch version
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False

SOLVERS = [chen_greedy, chen_sqrtn]


@@ -18,37 +31,80 @@ def _is_activation_checkpoint_available(gm: GraphModule):

def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if not torch.allclose(m_p, gm_p):
if not torch.allclose(m_p.grad, gm_p.grad):
return False
return True


def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
model_cls: Callable[[], torch.nn.Module]):
criterion = torch.nn.MSELoss()
data = torch.rand(2, 3, 32, 32)
label = torch.rand(2, 5)
loss = criterion(m(data), label)
loss.backward()
loss = criterion(gm(data), label)
loss.backward()
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'


def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]

torch.backends.cudnn.deterministic = True

tracer = ColoTracer(trace_act_ckpt=False)

data = torch.rand(2, 3, 32, 32)
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
gm = solver(gm)
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
check_backward_consistency(m, gm, solver, model_cls)


@pytest.mark.skip
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)


def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]

torch.backends.cudnn.deterministic = True

tracer = ColoTracer()
data = torch.rand(1, 3, 224, 224)
label = torch.rand(1, 1000)
tracer = ColoTracer(trace_act_ckpt=False)

data = torch.rand(2, 3, 32, 32)
for solver in SOLVERS:
for model_cls in MODEL_LIST:
model = model_cls()
criterion = torch.nn.MSELoss()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
gm = solver(gm)
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
loss = criterion(model(data), label)
loss.backward()
loss = criterion(gm(data), label)
loss.backward()
assert _is_all_gradient_close(model,
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
check_backward_consistency(m, gm, solver, model_cls)


@pytest.mark.skip
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)


if __name__ == '__main__':
test_ckpt_solver()
test_ckpt_solver_torch11()
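
For reference, a minimal single-process sketch of how these solvers are driven, following the same call sequence as _run_ckpt_solver above but without colossalai.launch; it assumes torch >= 1.12 so ActivationCheckpointCodeGen is importable, reuses the model, input shape, and tracer settings from the test, and whether the solver can run outside a launched context may depend on the environment:

import copy
import torch
import torchvision.models as tm
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy
from colossalai.fx.codegen import ActivationCheckpointCodeGen

m = tm.resnet18(num_classes=5)
data = torch.rand(2, 3, 32, 32)

tracer = ColoTracer(trace_act_ckpt=False)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)

MetaInfoProp(gm).run(data)                      # attach activation_size metadata to each node
gm.graph.set_codegen(ActivationCheckpointCodeGen())
gm = chen_greedy(gm)                            # annotate checkpoint segments and recompile

annotated = [n for n in gm.graph.nodes if hasattr(n, 'activation_checkpoint')]
print(f'{len(annotated)} nodes annotated for activation checkpointing')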