Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 102 additions & 95 deletions colossalai/_analyzer/_subclasses/_meta_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,15 @@ def meta_cudnn_rnn_backward(input: torch.Tensor,
aten.prelu.default,
aten.hardswish.default,
aten.hardtanh.default,
aten.prelu_backward.default,
aten.hardswish_backward.default,
aten.hardtanh_backward.default,
]

if version.parse(torch.__version__) < version.parse('2.0.0'):
_unregistered_ewise += [
aten.prelu_backward.default,
]

@register_meta(_unregistered_ewise)
def meta_unregistered_ewise(input: torch.Tensor, *args):
return new_like(input)
Expand Down Expand Up @@ -331,11 +335,6 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
return new_like(input)

# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
return out

# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
Expand All @@ -352,97 +351,9 @@ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Te
result_type = torch.result_type(self, other)
return new_like(condition + self + other, dtype=result_type)

@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
assert indices, "at least one index must be provided"
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs

indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
while len(indices) < self.ndim:
indices.append(None)

# hasContiguousSubspace
# true if all non-null tensors are adjacent
# See:
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
state = 0
has_contiguous_subspace = False
for index in indices:
if state == 0:
if index is not None:
state = 1
elif state == 1:
if index is None:
state = 2
else:
if index is not None:
break
else:
has_contiguous_subspace = True

# transposeToFront
# This is the logic that causes the newly inserted dimensions to show up
# at the beginning of the tensor, if they're not contiguous
if not has_contiguous_subspace:
dims = []
transposed_indices = []
for i, index in enumerate(indices):
if index is not None:
dims.append(i)
transposed_indices.append(index)
for i, index in enumerate(indices):
if index is None:
dims.append(i)
transposed_indices.append(index)
self = self.permute(dims)
indices = transposed_indices

# AdvancedIndex::AdvancedIndex
# Now we can assume the indices have contiguous subspace
# This is simplified from AdvancedIndex which goes to more effort
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
after_shape.append(self.shape[dim])
else:
before_shape.append(self.shape[dim])
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)

# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp

@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
Expand All @@ -459,3 +370,99 @@ def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = Fal
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return new_like(grad) # (grad_in)

if version.parse(torch.__version__) < version.parse('1.13.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
return out

@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
assert indices, "at least one index must be provided"
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs

indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
while len(indices) < self.ndim:
indices.append(None)

# hasContiguousSubspace
# true if all non-null tensors are adjacent
# See:
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
state = 0
has_contiguous_subspace = False
for index in indices:
if state == 0:
if index is not None:
state = 1
elif state == 1:
if index is None:
state = 2
else:
if index is not None:
break
else:
has_contiguous_subspace = True

# transposeToFront
# This is the logic that causes the newly inserted dimensions to show up
# at the beginning of the tensor, if they're not contiguous
if not has_contiguous_subspace:
dims = []
transposed_indices = []
for i, index in enumerate(indices):
if index is not None:
dims.append(i)
transposed_indices.append(index)
for i, index in enumerate(indices):
if index is None:
dims.append(i)
transposed_indices.append(index)
self = self.permute(dims)
indices = transposed_indices

# AdvancedIndex::AdvancedIndex
# Now we can assume the indices have contiguous subspace
# This is simplified from AdvancedIndex which goes to more effort
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
after_shape.append(self.shape[dim])
else:
before_shape.append(self.shape[dim])
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)