Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,71 @@
from typing import Set, Tuple
import torch
from torch.fx import GraphModule
import math

__all__ = ['chen_greedy', 'chen_sqrtn']


def chen_greedy(gm: GraphModule, B: int):
def chen_greedy(gm: GraphModule) -> GraphModule:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets at memory optimization only, using techniques in appendix A.

Usage:
B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB
model = resnet18()
input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
gm = chen_greedy(gm, B)
gm = chen_greedy(gm)

Args:
gm (GraphModule): The module to add checkpoints
B (int): The approximate memory budget for this module.
"""

def grid_search(num_grids: int = 6) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
"""
_, b_approx = run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
ckpt, b_approx = run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt
return ckpt_opt

def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt = set()
temp = 0
x = 0
y = 0
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'activation_size')
y = max(y, temp)
if temp > b:
x += getattr(n, 'activation_size')
temp = 0
ckpt.add(idx)
return ckpt, math.floor(math.sqrt(x * y))

gm.graph.lint() # make sure nodes are in topological order
temp = 0
x = 0
idx = 0
budget = B
for n in gm.graph.nodes:
B -= getattr(n, 'param_size')
assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}'
for n in gm.graph.nodes:
temp += getattr(n, 'activation_size')
if temp > B:
x += getattr(n, 'activation_size')
temp = x
setattr(n, 'activation_checkpoint', str(idx))
idx += 1
ckpt = grid_search(num_grids=6)
i = 0
for idx, n in enumerate(gm.graph.nodes):
if idx in ckpt:
setattr(n, 'activation_checkpoint', str(i))
i += 1
gm.recompile()
return gm


def chen_sqrtn(gm: GraphModule):
def chen_sqrtn(gm: GraphModule) -> GraphModule:
"""
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

Expand Down
20 changes: 17 additions & 3 deletions tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from ctypes import Union
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from functools import partial
import pytest

SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
SOLVERS = [chen_greedy, chen_sqrtn]


def _is_activation_checkpoint_available(gm: GraphModule):
Expand All @@ -16,24 +16,38 @@ def _is_activation_checkpoint_available(gm: GraphModule):
return True


def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if not torch.allclose(m_p, gm_p):
return False
return True


def test_ckpt_solver():
MODEL_LIST = [tm.resnet18, tm.densenet121]

torch.backends.cudnn.deterministic = True

tracer = ColoTracer()
data = torch.rand(1, 3, 224, 224)
label = torch.rand(1, 1000)

for solver in SOLVERS:
for model_cls in MODEL_LIST:
model = model_cls()
criterion = torch.nn.MSELoss()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(data)
gm = solver(gm)
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
assert torch.allclose(gm(data), model(data))
loss = criterion(model(data), label)
loss.backward()
loss = criterion(gm(data), label)
loss.backward()
assert _is_all_gradient_close(model,
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'


if __name__ == '__main__':
Expand Down