46 changes: 35 additions & 11 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -5,8 +5,12 @@
import torch
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace

__all__ = ['CheckpointSolverBase']

@@ -31,10 +35,11 @@ def __init__(
free_memory: float = -1.0,
requires_linearize: bool = False,
cnode: List[str] = None,
optim_multiplier: float = 1.0,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
"""``CheckpointSolverBase`` class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for target
computing graph.

Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@@ -45,9 +50,11 @@ def __init__(
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node list; should be a subset of the input. Defaults to None.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Defaults to 1.0.

Warnings:
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
Meta information of the graph is required for any ``CheckpointSolver``.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
@@ -57,13 +64,14 @@ def __init__(
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())

# check if `MetaInfoProp` is done
# check if meta information has been attached
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
)

self.free_memory = free_memory
self.parameter_size = _get_param_size(self.graph.owning_module)
# parameter memory = parameter size + optimizer extra weight storage
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
@@ -93,7 +101,7 @@ def _linearize_graph(self) -> List[List[Node]]:
the actual 'node' in linearized manner.

Remarks:
Do merge the inplace ops into the previous node.
Do merge the inplace ops and shape-consistency ops into the previous node.
"""

# Common nodes are type of nodes that could be seen as attributes and remain
@@ -131,7 +139,23 @@ def _is_sink() -> bool:
bool
"""

return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace

def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]

return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))

# make sure that item in cnode is valid
if self.cnode:
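As a quick aside on the budget arithmetic above, here is a minimal sketch (the standalone function is hypothetical, not part of this PR). With ``optim_multiplier = 1.0``, as with SGD-with-momentum keeping one extra buffer per parameter, the solver reserves twice the parameter size out of the device's free memory before planning checkpoints.

def solver_budget(free_memory: float, param_size: float, optim_multiplier: float = 1.0) -> float:
    # parameter memory = parameter size + optimizer extra weight storage,
    # so the memory available to activations shrinks accordingly
    return free_memory - param_size * (optim_multiplier + 1)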
6 changes: 3 additions & 3 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -19,9 +19,9 @@ def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
Note that this algorithm targets memory optimization only, using techniques in appendix A.

Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
Assume that we have a ``GraphModule``, and we have already applied the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverChen`:
code to find a solution using ``CheckpointSolverChen``:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
@@ -74,7 +74,7 @@ def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
def grid_search(self) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
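For illustration, a minimal sketch of the grid search described above (names are illustrative, not the solver's API): the greedy allocation is rerun at ``num_grids`` evenly spaced budgets in [√2/2 b, √2 b].

import math

def grid_points(b_approx: float, num_grids: int = 6) -> list:
    # bounds follow the solver: floor(b / sqrt(2)) and ceil(b * sqrt(2))
    b_min = math.floor(b_approx / math.sqrt(2))
    b_max = math.ceil(b_approx * math.sqrt(2))
    step = (b_max - b_min) / num_grids
    return [b_min + i * step for i in range(num_grids + 1)]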
41 changes: 28 additions & 13 deletions colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -4,6 +4,7 @@
from torch import Tensor
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
@@ -22,15 +23,20 @@

class CheckpointSolverRotor(CheckpointSolverBase):

def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
def __init__(self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.

Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
Assume that we have a ``GraphModule``, and we have already applied the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
code to find a solution using ``CheckpointSolverRotor``:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
@@ -41,8 +47,10 @@ def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = Non
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node list; should be a subset of the input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Defaults to 1.0.
"""
super().__init__(graph, free_memory, True, cnode)
super().__init__(graph, free_memory, True, cnode, optim_multiplier)
self.memory_slots = memory_slots

# construct chain
@@ -128,16 +136,24 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
xbar = 0
ftime = 0
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))

# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)

x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
ftmp = fwd_mem_peak - xbar
btmp = cls._extract_btmp(node)
return ftime, btime, x, xbar, ftmp, btmp

Expand All @@ -151,10 +167,9 @@ def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
return input_tensors

@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
"""Extract ftmp from a list of nodes"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)

@staticmethod
def _extract_btmp(node: List[Node]) -> int:
@@ -290,8 +305,8 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
cost_table (List[Any]): See `._compute_table()` for definitions
back_ptr (List[Any]): See `._compute_table()` for definitions
cost_table (List[Any]): See ``._compute_table()`` for definitions
back_ptr (List[Any]): See ``._compute_table()`` for definitions

Raises:
ValueError: Can not process the chain.
@@ -332,7 +347,7 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A

@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
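Putting the pieces together, a hedged usage sketch combining the documented ``solve()`` API with the new ``optim_multiplier`` argument (assumes ``gm`` is a traced ``GraphModule`` whose nodes already carry the required meta information):

import torch

free_memory = torch.cuda.mem_get_info(device=0)[0]    # free bytes on device 0
solver = CheckpointSolverRotor(gm.graph, free_memory=free_memory, optim_multiplier=1.0)
gm.graph = solver.solve()    # swap in the annotated graph
gm.recompile()               # regenerate forward() from the new graph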
5 changes: 4 additions & 1 deletion colossalai/auto_parallel/meta_profiler/constants.py
@@ -5,8 +5,11 @@

from ..tensor_shard.constants import *

# list of inplace operations
# list of inplace modules
INPLACE_MODULE = [nn.ReLU]

# list of inplace operations
INPLACE_OPS = [torch.flatten]

# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""

input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

# construct forward args for flop mapping
fwd_in_args = [input_op_data.data, other_op_data.data]
fwd_in_args = [opdata.data for opdata in input_op_data]
fwd_out_args = [output_op_data.data]

# calculate cost

# calculate compute cost
# NOTE: we set bwd_compute_cost to twice fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

# calculate memory cost
param_mem_cost = activation_size(
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size([input_op_data.data, output_op_data.data]),
activation=activation_size(output_op_data.data),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
@@ -60,7 +59,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

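One way to see why storing no forward inputs (``fwd_in = []``) is sound for these ops, as a small self-contained check (an illustration, not part of this PR): the backward of an elementwise add does not depend on its operand values.

import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
(a + b).sum().backward()    # d(a+b)/da == d(a+b)/db == 1, independent of a and b
assert torch.equal(a.grad, torch.ones(4))
assert torch.equal(b.grad, torch.ones(4))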
26 changes: 7 additions & 19 deletions colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -1,6 +1,5 @@
from typing import Callable, List

import numpy as np
import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -13,7 +12,7 @@
)
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register

__all__ = ['MetaInfo']
@@ -71,25 +70,12 @@ def target(self, target: Callable) -> None:
if self._strategy is not None and self._target is not None:
self.compute_metainfo()

def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Compute sharded meta tensor based on the given data and sharding spec.
Compute sharded opdata based on the given data and sharding spec.
"""
shard_sequnce = sharding_spec.sharding_sequence
device_mesh = sharding_spec.device_mesh
shape = operation_data.data.shape

new_shape = []
for dim, shard in zip(shape, shard_sequnce):
if shard.is_replica:
# replica
new_shape.append(dim)
else:
# sharded according to device_mesh shape
new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))

return OperationData(name=operation_data.name,
data=torch.zeros(new_shape, device="meta"),
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)

@@ -113,11 +99,13 @@ def compute_metainfo(self):
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION

# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}

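For reference, a sketch of what ``get_sharded_shape_per_device()`` computes, mirroring the hand-rolled logic this PR deletes from ``compute_sharded_opdata`` (the standalone function below is hypothetical): each sharded dimension is divided by the product of the device-mesh dimensions it is sharded over, while replicated dimensions keep their full size.

import numpy as np

def sharded_shape_per_device(shape, sharding_sequence, mesh_shape):
    new_shape = []
    for dim, shard in zip(shape, sharding_sequence):
        if shard.is_replica:
            new_shape.append(dim)    # replicated: keep full size
        else:
            new_shape.append(dim // int(np.prod([mesh_shape[i] for i in shard.shard_list])))
    return new_shape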