Skip to content
Merged
829 changes: 406 additions & 423 deletions colossalai/_analyzer/_subclasses/_meta_registration.py

Large diffs are not rendered by default.

76 changes: 41 additions & 35 deletions colossalai/_analyzer/_subclasses/_monkey_patch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
from packaging import version

# Short alias for the ATen operator namespace; the op tables below are written as ``aten.<op>.<overload>``.
aten = torch.ops.aten

Expand Down Expand Up @@ -49,40 +50,45 @@
"scatter",
]

# The ATen overload tables are only populated on torch >= 1.12.0. On older
# releases every table falls back to an empty list so the names are always
# defined at import time.
# NOTE(review): presumably some of these overloads are missing on older torch
# releases -- confirm before lowering the version bound.
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # TODO: dive deep here
    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    _InplaceATen = [
        aten.add_.Tensor,
        aten.add_.Scalar,
        aten.sub_.Tensor,
        aten.sub_.Scalar,
        aten.mul_.Tensor,
        aten.mul_.Scalar,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.pow_.Tensor,
        aten.pow_.Scalar,
    ]

    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]
else:
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []
236 changes: 121 additions & 115 deletions colossalai/_analyzer/_subclasses/flop_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Callable, List, Optional, Union

import torch
from packaging import version
from torch.utils._pytree import tree_map

from .meta_tensor import MetaTensor
Expand Down Expand Up @@ -403,134 +404,139 @@ def zero_flop_jit(*args):
return 0


# Registry mapping ATen overloads to the callable that counts their FLOPs.
# The tables are only populated on torch >= 1.12.0; on older releases the
# fallbacks below keep every public name defined (and of the same container
# type) so that importing this module never fails.
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    flop_mapping = {
        # gemm
        aten.mm.default: matmul_flop_jit,
        aten.matmul.default: matmul_flop_jit,
        aten.addmm.default: addmm_flop_jit,
        aten.bmm.default: bmm_flop_jit,

        # convolution
        aten.convolution.default: conv_flop_jit,
        aten._convolution.default: conv_flop_jit,
        aten.convolution_backward.default: conv_backward_flop_jit,

        # normalization
        aten.native_batch_norm.default: batchnorm_flop_jit,
        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
        aten.native_layer_norm.default: norm_flop_counter(2, 0),
        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

        # pooling
        aten.avg_pool1d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten.avg_pool3d.default: ewise_flop_counter(1, 0),
        aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool1d.default: ewise_flop_counter(1, 0),
        aten.max_pool2d.default: ewise_flop_counter(1, 0),
        aten.max_pool3d.default: ewise_flop_counter(1, 0),
        aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
        aten.embedding.default: ewise_flop_counter(1, 0),
    }

    # Ops registered below with ``ewise_flop_counter(1, 0)``.
    ewise_flop_aten = [
        # basic op
        aten.add.Tensor,
        aten.add_.Tensor,
        aten.div.Tensor,
        aten.div_.Tensor,
        aten.div.Scalar,
        aten.div_.Scalar,
        aten.mul.Tensor,
        aten.mul.Scalar,
        aten.mul_.Tensor,
        aten.neg.default,
        aten.pow.Tensor_Scalar,
        aten.rsub.Scalar,
        aten.sum.default,
        aten.sum.dim_IntList,
        aten.mean.dim,

        # activation op
        aten.hardswish.default,
        aten.hardswish_.default,
        aten.hardswish_backward.default,
        aten.hardtanh.default,
        aten.hardtanh_.default,
        aten.hardtanh_backward.default,
        aten.hardsigmoid_backward.default,
        aten.hardsigmoid.default,
        aten.gelu.default,
        aten.gelu_backward.default,
        aten.silu.default,
        aten.silu_.default,
        aten.silu_backward.default,
        aten.sigmoid.default,
        aten.sigmoid_backward.default,
        aten._softmax.default,
        aten._softmax_backward_data.default,
        aten.relu_.default,
        aten.relu.default,
        aten.tanh.default,
        aten.tanh_backward.default,
        aten.threshold_backward.default,

        # dropout
        aten.native_dropout.default,
        aten.native_dropout_backward.default,

        # distribution
        aten.bernoulli_.float,

        # where
        aten.where.self,
    ]
    for op in ewise_flop_aten:
        flop_mapping[op] = ewise_flop_counter(1, 0)

    # fix-me: this will be removed in future
    # Ops registered below with ``zero_flop_jit``.
    zero_flop_aten = [
        aten.as_strided.default,
        aten.as_strided_.default,
        aten.cat.default,
        aten.clone.default,
        aten.copy_.default,
        aten.detach.default,
        aten.expand.default,
        aten.empty_like.default,
        aten.new_empty.default,
        aten.new_empty_strided.default,
        aten.ones_like.default,
        aten._reshape_alias.default,
        aten.select.int,
        aten.select_backward.default,
        aten.squeeze.dim,
        aten.slice.Tensor,
        aten.slice_backward.default,
        aten.split.Tensor,
        aten.permute.default,
        aten.t.default,
        aten.transpose.int,
        aten._to_copy.default,
        aten.unsqueeze.default,
        aten.unbind.int,
        aten._unsafe_view.default,
        aten.view.default,
        aten.zero_.default,
        aten.zeros_like.default,
    ]

    for op in zero_flop_aten:
        flop_mapping[op] = zero_flop_jit
else:
    flop_mapping = {}
    # Same names and container types as the branch above, so code that reads
    # ``ewise_flop_aten`` / ``zero_flop_aten`` behaves the same on old torch.
    ewise_flop_aten = []
    zero_flop_aten = []
    # Legacy misspelling of ``ewise_flop_aten``; kept unchanged so that any
    # existing reference to this name keeps working.
    elementwise_flop_aten = {}
3 changes: 1 addition & 2 deletions colossalai/_analyzer/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .bias_addition import *
from .node_util import MetaInfo
from .symbolic_profile import symbolic_profile
from .symbolic_trace import symbolic_trace
from .tracer.symbolic_trace import symbolic_trace
Loading