Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions colossalai/fx/passes/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
62 changes: 62 additions & 0 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from torch.fx import GraphModule

# Public API of this module: the two activation-checkpoint solvers defined below.
__all__ = ['chen_greedy', 'chen_sqrtn']


Comment thread
super-dainiu marked this conversation as resolved.
def chen_greedy(gm: GraphModule, B: int) -> GraphModule:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

    Every node in ``gm.graph`` must already carry ``param_size`` and
    ``activation_size`` attributes (the usage below obtains them by running
    ``MetaInfoProp``). The parameter sizes are subtracted from ``B`` first; the
    remainder is the budget available for activations. Nodes are then walked in
    graph order while a running activation total is kept; each node that pushes
    the total above the remaining budget is tagged with an
    ``activation_checkpoint`` attribute holding a string label ("0", "1", ...).

    Usage:
        B = 5 * 1024 * 1024 * 1024    # An approximate memory budget of 5GB
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm, B)

    Args:
        gm (GraphModule): The module to add checkpoints
        B (int): The approximate memory budget for this module.

    Returns:
        GraphModule: The same ``gm`` object, annotated in place and recompiled.

    Raises:
        AssertionError: If ``B`` does not exceed the summed ``param_size`` of
            all nodes, i.e. nothing is left for activations.
    """
    gm.graph.lint()    # make sure nodes are in topological order

    # Memory left for activations once every node's parameters are accounted for.
    activation_budget = B - sum(n.param_size for n in gm.graph.nodes)
    if activation_budget <= 0:
        # Raised explicitly (instead of a bare ``assert``) so the check still runs
        # under ``python -O``; the exception type is kept for backward compatibility.
        raise AssertionError(
            f'The memory budget {B / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}')

    running_total = 0    # activation memory counted against the budget so far
    checkpointed = 0    # activation memory of the nodes already chosen as checkpoints
    next_label = 0
    for n in gm.graph.nodes:
        running_total += n.activation_size
        if running_total > activation_budget:
            checkpointed += n.activation_size
            # Restart the running total from the memory held by the checkpoints.
            # NOTE(review): Algorithm 3 in the paper appears to reset this counter
            # to 0 instead -- confirm that restarting from the checkpointed total
            # is the intended variant.
            running_total = checkpointed
            n.activation_checkpoint = str(next_label)
            next_label += 1

    gm.recompile()
    return gm


def chen_sqrtn(gm: GraphModule):
    """
    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

    With ``n`` nodes in the graph and ``k = int(sqrt(n))``, every k-th node
    (counting from 1) receives an ``activation_checkpoint`` attribute whose
    value is the string form of its 1-based segment number ("1", "2", ...).
    The module is annotated in place, recompiled and returned.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_sqrtn(gm)

    Args:
        gm (GraphModule): The module to add checkpoints
    """
    graph = gm.graph
    graph.lint()    # make sure nodes are in topological order

    # take approximately sqrt(n) checkpoints; the stride is 0 only for an empty
    # graph, in which case the loop body below never executes
    stride = int(len(graph.nodes)**0.5)
    for position, node in enumerate(graph.nodes, start=1):
        segment, leftover = divmod(position, stride)
        if leftover == 0:
            setattr(node, 'activation_checkpoint', str(segment))

    gm.recompile()
    return gm
40 changes: 40 additions & 0 deletions tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from functools import partial
import pytest

# Solvers exercised by the test below; each entry is callable as solver(gm).
# chen_greedy is bound to a budget B of 1024 * 1024 * 64 (64 MiB when read as bytes).
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]


def _is_activation_checkpoint_available(gm: GraphModule) -> bool:
    """
    Tell whether any node of ``gm.graph`` has been annotated with a checkpoint.

    Args:
        gm (GraphModule): The module whose graph nodes are inspected.

    Returns:
        bool: ``True`` if at least one node has an ``activation_checkpoint``
        attribute that is not ``None``; ``False`` otherwise (including for an
        empty graph).
    """
    # A missing attribute is treated the same as an attribute set to None.
    return any(getattr(n, 'activation_checkpoint', None) is not None for n in gm.graph.nodes)


def test_ckpt_solver():
    """
    Run every solver in ``SOLVERS`` on freshly built torchvision models.

    For each (solver, model) pair the model is traced with ``ColoTracer``,
    wrapped in a ``GraphModule``, passed through ``MetaInfoProp`` and then the
    solver. The test asserts that at least one node received an activation
    checkpoint annotation and that the annotated module still produces the same
    output as the original model on the same input.
    """
    MODEL_LIST = [tm.resnet18, tm.densenet121]

    torch.backends.cudnn.deterministic = True

    tracer = ColoTracer()
    data = torch.rand(1, 3, 224, 224)

    # Same visiting order as a nested loop: solvers outermost, models innermost.
    cases = [(solver, model_cls) for solver in SOLVERS for model_cls in MODEL_LIST]
    for solver, model_cls in cases:
        model = model_cls()
        gm = GraphModule(model, tracer.trace(root=model), model.__class__.__name__)
        MetaInfoProp(gm).run(data)
        gm = solver(gm)

        annotated = _is_activation_checkpoint_available(gm)
        assert annotated, f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"

        # Annotating checkpoints must not change the forward result.
        assert torch.allclose(gm(data), model(data))


# Allow running this test file directly as a script, without the pytest runner.
if __name__ == '__main__':
test_ckpt_solver()