23 changes: 22 additions & 1 deletion colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes

# There are three cases: 1) gemm, 2) gemv, 3) dot
if all(len(shape) == 2 for shape in input_shapes):
# gemm
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
elif all(len(shape) == 1 for shape in input_shapes):
# dot
assert input_shapes[0][0] == input_shapes[1][0], input_shapes

# expand shape
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
else:
# gemv
if len(input_shapes[0]) == 1:
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
input_shapes.reverse()
else:
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes

# expand the shape of the vector to [batch size, 1]
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops
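
To sanity-check the three branches, here is a minimal standalone sketch of the same counting rule (illustrative only, not part of the PR; it covers only the 1-D/2-D cases handled above, counting one multiply-accumulate as one FLOP):

    import operator
    from functools import reduce

    import torch

    def count_matmul_flops(a: torch.Tensor, b: torch.Tensor) -> int:
        # Mirror of matmul_flop_jit's shape handling above.
        shapes = [list(a.shape), list(b.shape)]
        if all(len(s) == 1 for s in shapes):        # dot: (k,) x (k,)
            shapes = [[1, shapes[0][0]], [shapes[1][0], 1]]
        elif any(len(s) == 1 for s in shapes):      # gemv: vector on either side
            if len(shapes[0]) == 1:                 # vector first: swap so the matrix leads
                shapes.reverse()
            shapes[-1] = [shapes[-1][-1], 1]        # expand the vector to (k, 1)
        return reduce(operator.mul, shapes[0]) * shapes[-1][-1]

    assert count_matmul_flops(torch.randn(8, 4), torch.randn(4, 2)) == 64    # gemm: 8*4*2
    assert count_matmul_flops(torch.randn(4), torch.randn(4)) == 4           # dot: 4 MACs
    assert count_matmul_flops(torch.randn(8, 4), torch.randn(4)) == 32       # gemv: 8*4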

19 changes: 11 additions & 8 deletions colossalai/_analyzer/fx/codegen.py
@@ -1,8 +1,12 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch

try:
from torch.fx.graph import CodeGen
except:
pass
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_format_target,
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
if len(node.meta['info'].to_recompute) > ckpt_level:
return node.meta['info'].to_recompute[ckpt_level] is not None
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
return True


@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None

for idx, node in enumerate(node_list):
if len(node.meta['info'].to_recompute) > ckpt_level:
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]

# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,

# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')

# if there are more levels to fetch
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
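
To make the nested-label scheme concrete, here is a small illustrative sketch (the values are hypothetical) of how emit_ckpt_func derives the generated checkpoint function's name from a node's activation_checkpoint tuple:

    # A node inside region 0 at the top level, region 1 nested within it,
    # and region 1 nested within that, carries the tuple below.
    activation_checkpoint = (0, 1, 1)

    # At ckpt_level = 2 the label joins the first ckpt_level + 1 region indices,
    # so the emitted wrapper is named e.g. checkpoint_0_1_1.
    ckpt_level = 2
    label = "_".join(str(idx) for idx in activation_checkpoint[:ckpt_level + 1])
    assert label == "0_1_1"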
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]

node_list = list(nodes)

node_idx = 0
2 changes: 1 addition & 1 deletion colossalai/_analyzer/fx/node_util.py
@@ -112,7 +112,7 @@ class MetaInfo:

# should keep the same whenever manipulated
# ============================= Invariant ==================================
to_recompute: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) supports nested codegen
activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) supports nested codegen
to_offload: Optional[bool] = False
sharding_spec: str = 'RR'

9 changes: 8 additions & 1 deletion colossalai/_analyzer/fx/passes/shape_prop.py
@@ -237,7 +237,14 @@ def propagate(self, *args, device=None):
Returns:
Any: The value returned from executing the Module
"""
wrap_fn = lambda elem: MetaTensor(elem, device=device)

# wrap_fn = lambda elem: MetaTensor(elem, device=device)
def wrap_fn(elem, device=device):
if isinstance(elem, torch.Tensor):
return MetaTensor(elem, device=device)
else:
return elem

with self._mode:
return super().run(*tree_map(wrap_fn, args))
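
The guard matters because tree_map applies wrap_fn to every leaf of args, including non-tensor leaves such as ints or None, which the old lambda would have handed straight to MetaTensor. A minimal sketch of the behavior, using .to('meta') as an illustrative stand-in for MetaTensor:

    import torch
    from torch.utils._pytree import tree_map

    def wrap_fn(elem):
        # Only tensors are wrapped; ints, strings, None pass through untouched.
        return elem.to('meta') if isinstance(elem, torch.Tensor) else elem

    args = (torch.randn(2, 2), 3, None)    # mixed leaves, as in Module.forward(*args)
    wrapped = tree_map(wrap_fn, args)
    assert wrapped[0].is_meta and wrapped[1] == 3 and wrapped[2] is None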

114 changes: 78 additions & 36 deletions colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):


@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, **kwargs)
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
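
The rewrite replaces the old **kwargs handling, where getattr(kwargs, 'bias', None) always returned None because getattr does not look up dict keys, with explicit signatures, and splits the bias off into a standalone add so the tracer records it as its own node. Numerically, the decomposition matches the fused call (a quick illustrative check):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8)    # (N, C_in, L)
    w = torch.randn(6, 3, 3)    # (C_out, C_in, kernel_size)
    b = torch.randn(6)

    fused = F.conv1d(x, w, bias=b)
    decomposed = F.conv1d(x, w) + b.reshape((-1, 1))    # bias broadcast over the length dim
    assert torch.allclose(fused, decomposed)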


@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, **kwargs)
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))


@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, **kwargs)
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))


@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None:
return F.conv_transpose1d(input, weight, **kwargs)
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))


@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None:
return F.conv_transpose2d(input, weight, **kwargs)
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))


@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None:
return F.conv_transpose3d(input, weight, **kwargs)
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))


@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
2 changes: 1 addition & 1 deletion colossalai/_analyzer/fx/tracer/tracer.py
@@ -155,7 +155,7 @@ def create_proxy(self,

def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node

def trace(self,
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/meta_profiler/__init__.py
@@ -1,3 +1,3 @@
from .meta_registry import *
from .metainfo import *
from .registry import meta_register
from .shard_metainfo import *
@@ -2,9 +2,9 @@

import torch

from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import elementwise_flop_counter

from ..registry import meta_register

@@ -2,9 +2,9 @@

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
"""Meta information generator for binary elementwise operations
NOTE: Some binary elementwise operations discard their input activations after computation, as those
tensors are not needed for back propagation; for example, the two tensors passed to `torch.add`
are discarded as soon as the add is done. We provide a simple API in the `MetaInfo` class to identify
are discarded as soon as the add is done. We provide a simple API in the `ShardMetaInfo` class to identify
this behavior; it is critical for accurate memory estimation.

Returns:
28 changes: 14 additions & 14 deletions colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -2,6 +2,8 @@

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
@@ -10,8 +12,6 @@
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate memory cost
# TODO: use the profiler to check conv's temp memory
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
fwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)

bwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, weight_tensor, bias_tensor])
if has_bias else activation_size([input_tensor, weight_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)

bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)

# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
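
Both memory-cost blocks now route through compute_size_in_bytes; assuming it sums numel * element_size over the tensors it is given, like the activation_size helper it replaces, an equivalent sketch is:

    import torch

    def size_in_bytes(tensors):
        # Accepts a tensor or a list of tensors; sums numel * per-element size.
        if isinstance(tensors, torch.Tensor):
            tensors = [tensors]
        return sum(t.numel() * t.element_size() for t in tensors)

    x = torch.randn(4, 4)                 # 16 fp32 elements
    assert size_in_bytes(x) == 64         # 16 * 4 bytes
    assert size_in_bytes([x, x]) == 128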
@@ -2,9 +2,9 @@

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register

@@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, a large enough input seems to incur some
# temporary memory whose origin we have not yet identified, so for now we assume there is no temp
# memory, as it is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)

total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
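
A worked instance of this accounting (illustrative numbers; fp32 weights, int64 indices):

    import torch

    idx = torch.randint(0, 1000, (32,))    # input: 32 int64 indices -> 32 * 8 = 256 B
    emb = torch.nn.Embedding(1000, 64)     # weight: 1000 * 64 fp32   -> 256000 B
    out = emb(idx)                         # output: 32 * 64 fp32     -> 8192 B

    bytes_of = lambda t: t.numel() * t.element_size()
    fwd_activation = bytes_of(idx) + bytes_of(out)    # as in fwd_memory_cost
    bwd_activation = bytes_of(emb.weight)             # the weight-gradient buffer
    assert (fwd_activation, bwd_activation) == (8448, 256000)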
