Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions colossalai/fx/passes/meta_info_prop.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, map_aggregate, Argument, Target
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method


@compatibility(is_backward_compatible=True)
Expand Down Expand Up @@ -75,9 +76,7 @@ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_pr
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
for elem in args:
if isinstance(elem, torch.Tensor):
assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)

@compatibility(is_backward_compatible=True)
Expand All @@ -103,7 +102,7 @@ def extract_tensor_meta(obj):
else:
return TensorMetadata(None, None, False, None, 0, False)

meta = map_aggregate(result, extract_tensor_meta)
meta = tree_map(extract_tensor_meta, result)
Comment thread
FrankLeeeee marked this conversation as resolved.
n.meta['tensor_meta'] = meta

# TODO: the attribute node_size should be removed in the future
Expand Down
8 changes: 7 additions & 1 deletion colossalai/fx/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .registry import *
try:
from ._meta_registrations import *
except:
import torch
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import *
339 changes: 339 additions & 0 deletions colossalai/fx/profiler/_meta_registrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below

from typing import List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map


aten = torch.ops.aten

meta_lib = torch.library.Library("aten", "IMPL", "Meta")

meta_table = {}


def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (
op.__name__
if op._overloadname != "default"
else op.overloadpacket.__name__
)
meta_lib.impl(name, f)

tree_map(add_func, op)
return f

return wrapper


# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
Comment thread
FrankLeeeee marked this conversation as resolved.
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
Returns:
The output length
"""
return (ln + 2 * p - d * (k - 1) - 1) // s + 1

def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
if transposed convolution is used.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
op: output padding in that dim
Returns:
The output length
"""
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

def calc_conv_nd_return_shape(
dims: torch.Size,
kernel_size: torch.Size,
stride: Union[List[int], int],
padding: Union[List[int], int],
dilation: Union[List[int], int],
output_padding: Optional[Union[List[int], int]] = None,
):
ret_shape = []
if isinstance(stride, int):
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)

if isinstance(padding, int):
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)

if isinstance(dilation, int):
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)

output_padding_list: Optional[List[int]] = None
if output_padding:
if isinstance(output_padding, int):
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
else:
output_padding_list = output_padding

for i in range(len(dims)):
# If output_padding is present, we are dealing with a transposed convolution
if output_padding_list:
ret_shape.append(
_formula_transposed(
dims[i],
padding[i],
dilation[i],
kernel_size[i],
stride[i],
output_padding_list[i],
)
)
else:
ret_shape.append(
_formula(
dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
)
)
return ret_shape

def pick_memory_format():
if input_tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
return torch.contiguous_format
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
return torch.preserve_format

kernel_size = weight.shape[2:]
dims = input_tensor.shape[2:]
if is_transposed:
out_channels = groups * weight.shape[1]

shape_out = calc_conv_nd_return_shape(
dims,
kernel_size,
stride,
padding,
dilation,
output_padding,
)

else:
out_channels = weight.shape[0]
if weight.shape[1] != input_tensor.shape[1] / groups:
raise RuntimeError("Invalid channel dimensions")
shape_out = calc_conv_nd_return_shape(
dims, kernel_size, stride, padding, dilation
)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out


@register_meta(aten.convolution_backward.default)
def meta_conv_backward(
grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')


@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
return torch.empty_like(input)


@register_meta(aten.hardswish.default)
def meta_hardswish(input: torch.Tensor):
return torch.empty_like(input)


@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
grad_in = torch.empty_like(input)
return grad_in


@register_meta([aten.roll.default, ])
def meta_roll(input:torch.Tensor, shifts, dims):
return torch.empty_like(input)


@register_meta(aten.native_batch_norm.default)
def meta_bn(
input: torch.Tensor,
weight, bias, running_mean, running_var, training, momentum, eps
):
n_input = input.size(1)

output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
return output, running_mean, running_var


@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(
dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
return dX, dgamma, dbeta


@register_meta(aten.native_layer_norm.default)
def meta_ln(
input: torch.Tensor,
normalized_shape, weight, bias, eps
):
n_input = input.size(1)

output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
return output, running_mean, running_var


@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(
dY: torch.Tensor,
input: torch.Tensor,
normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
return dX, dgamma, dbeta


@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor, input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return torch.empty_like(input)


@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
assert indices, "at least one index must be provided"
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs # avoid import cycle in mypy

indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
while len(indices) < self.ndim:
indices.append(None)

# hasContiguousSubspace
# true if all non-null tensors are adjacent
# See:
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
state = 0
has_contiguous_subspace = False
for index in indices:
if state == 0:
if index is not None:
state = 1
elif state == 1:
if index is None:
state = 2
else:
if index is not None:
break
else:
has_contiguous_subspace = True

# transposeToFront
# This is the logic that causes the newly inserted dimensions to show up
# at the beginning of the tensor, if they're not contiguous
if not has_contiguous_subspace:
dims = []
transposed_indices = []
for i, index in enumerate(indices):
if index is not None:
dims.append(i)
transposed_indices.append(index)
for i, index in enumerate(indices):
if index is None:
dims.append(i)
transposed_indices.append(index)
self = self.permute(dims)
indices = transposed_indices

# AdvancedIndex::AdvancedIndex
# Now we can assume the indices have contiguous subspace
# This is simplified from AdvancedIndex which goes to more effort
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
after_shape.append(self.shape[dim])
else:
before_shape.append(self.shape[dim])
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)
Loading