Merged
colossalai/fx/profiler/_meta_registrations.py (74 changes: 29 additions, 45 deletions)
@@ -5,7 +5,6 @@
import torch
from torch.utils._pytree import tree_map

-
aten = torch.ops.aten

meta_lib = torch.library.Library("aten", "IMPL", "Meta")
@@ -14,16 +13,17 @@


def register_meta(op, register_dispatcher=True):
+
    def wrapper(f):
+
        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
-                name = (
-                    op.__name__
-                    if op._overloadname != "default"
-                    else op.overloadpacket.__name__
-                )
-                meta_lib.impl(name, f)
+                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                try:
+                    meta_lib.impl(name, f)
+                except:
+                    pass
Comment on lines +23 to +26

Contributor (Author):
The change happens only here. I don't know why, but the pre-commit hook reformatted all of my code...

Contributor:
No worries, the pre-commit hook just standardized your code format.
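The try/except around `meta_lib.impl` is the one functional change in this hunk: `torch.library.Library.impl` raises a `RuntimeError` when an implementation is already registered for the same operator and dispatch key, and the bare `except` lets registration continue past such conflicts. A minimal, self-contained sketch of that behavior, using a hypothetical `demo_ns::my_op` instead of a real aten op:

```python
import torch

# Define a throwaway operator so no real aten registration is touched.
lib = torch.library.Library("demo_ns", "DEF")
lib.define("my_op(Tensor x) -> Tensor")

def my_op_meta(x):
    # A meta kernel computes only output metadata (shape/dtype), never values.
    return torch.empty_like(x)

lib.impl("my_op", my_op_meta, "Meta")        # first registration succeeds
try:
    lib.impl("my_op", my_op_meta, "Meta")    # duplicate registration raises RuntimeError
except RuntimeError:
    pass                                     # the PR's guard is a bare except
```

Catching `RuntimeError` explicitly, as in the sketch, would be a narrower guard than the PR's bare `except:`, which also swallows unrelated failures.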


        tree_map(add_func, op)
        return f
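Since `tree_map` walks whatever structure `op` is, the decorator accepts a single overload or a list of overloads (the pre-PR `@register_meta([aten.roll.default, ])` below used the list form). A usage sketch with a hypothetical `meta_relu`, mirroring the registrations in this file:

```python
# Registers meta_relu both in meta_table and under the "Meta" dispatch key.
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
    return torch.empty_like(input)
```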
@@ -44,6 +44,7 @@ def meta_conv(
    output_padding: List[int],
    groups: int,
):
+
    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
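For reference, the standard output-length formula for a non-transposed convolution, which is presumably what the truncated body of `_formula` computes (it matches the shape rule documented for `torch.nn.Conv2d`):

```python
def _formula_sketch(ln: int, p: int, d: int, k: int, s: int) -> int:
    # ln: input length, p: padding, d: dilation, k: kernel size, s: stride
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

# A 3x3 kernel with padding 1, dilation 1, stride 1 preserves the length.
assert _formula_sketch(32, 1, 1, 3, 1) == 32
```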
@@ -120,14 +121,9 @@ def calc_conv_nd_return_shape(
                        kernel_size[i],
                        stride[i],
                        output_padding_list[i],
-                    )
-                )
+                    ))
            else:
-                ret_shape.append(
-                    _formula(
-                        dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
-                    )
-                )
+                ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
        return ret_shape

    def pick_memory_format():
@@ -156,20 +152,16 @@ def pick_memory_format():
        out_channels = weight.shape[0]
        if weight.shape[1] != input_tensor.shape[1] / groups:
            raise RuntimeError("Invalid channel dimensions")
-    shape_out = calc_conv_nd_return_shape(
-        dims, kernel_size, stride, padding, dilation
-    )
+    shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
    mem_fmt = pick_memory_format()
-    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
+    out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
    return out
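As an illustration of what these registrations enable (a sketch, not code from this PR): with a meta kernel in place, convolution on `device="meta"` tensors yields correctly shaped outputs without allocating memory or running a kernel, which is what the FX profiler relies on for shape propagation.

```python
import torch

x = torch.empty(8, 3, 32, 32, device="meta")
w = torch.empty(16, 3, 3, 3, device="meta")
# torch.convolution(input, weight, bias, stride, padding, dilation,
#                   transposed, output_padding, groups)
out = torch.convolution(x, w, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
print(out.shape)    # torch.Size([8, 16, 32, 32]); shape only, no compute
```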


@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(
-    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
-):
+def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                       padding, dilation, transposed, output_padding, groups, output_mask):
    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')


@@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor):


@register_meta(aten.hardswish_backward.default)
-def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
+def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
    grad_in = torch.empty_like(input)
    return grad_in


-@register_meta([aten.roll.default, ])
-def meta_roll(input:torch.Tensor, shifts, dims):
+@register_meta(aten.roll.default)
+def meta_roll(input: torch.Tensor, shifts, dims):
    return torch.empty_like(input)


@register_meta(aten.native_batch_norm.default)
-def meta_bn(
-    input: torch.Tensor,
-    weight, bias, running_mean, running_var, training, momentum, eps
-):
+def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    n_input = input.size(1)

    output = torch.empty_like(input)
@@ -208,21 +197,16 @@ def meta_bn(


@register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(
-    dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
-):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
+                     save_invstd, train, eps, output_mask):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(weight)
    return dX, dgamma, dbeta


@register_meta(aten.native_layer_norm.default)
-def meta_ln(
-    input: torch.Tensor,
-    normalized_shape, weight, bias, eps
-):
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    n_input = input.size(1)

    output = torch.empty_like(input)
@@ -232,11 +216,8 @@ def meta_ln(


@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(
-    dY: torch.Tensor,
-    input: torch.Tensor,
-    normalized_shape, mean, rstd, weight, bias, grad_input_mask
-):
+def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                     grad_input_mask):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(bias)
@@ -245,7 +226,8 @@ def meta_ln_backward(

@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor, input: torch.Tensor,
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
):
    grad_input = torch.empty_like(input)
    return torch.empty_like(input)
@@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
            k = len(result)
            assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
            for j in range(index.ndim):
-                assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                assert index.shape[j] == self.shape[
+                    k +
+                    j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                result.append(nonzero.select(1, j))
        else:
            result.append(index)
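The branch above handles boolean-mask indices: each mask dimension is validated against the indexed tensor, then `nonzero` turns the mask into one long-index column per mask dimension. A small equivalence sketch (not from this file):

```python
import torch

x = torch.randn(4, 5)
mask = x[0] > 0                       # boolean mask over dim 1, shape (5,)
cols = mask.nonzero().select(1, 0)    # column j of nonzero(), as in the loop above
assert torch.equal(x[:, mask], x[:, cols])
```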
@@ -275,7 +259,7 @@ def meta_index_Tensor(self, indices):
    indices = result
    assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
    # expand_outplace
-    import torch._refs as refs  # avoid import cycle in mypy
+    import torch._refs as refs    # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors