Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import torch
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace

__all___ = ['CheckpointSolverBase']

Expand Down Expand Up @@ -131,7 +135,23 @@ def _is_sink() -> bool:
bool
"""

return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
def _is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace

def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]

return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))

# make sure that item in cnode is valid
if self.cnode:
Expand Down