From e0edb2103b4f1c3511aa284eae492fba9a952d2b Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Thu, 1 Sep 2022 18:04:11 +0800 Subject: [PATCH 01/24] [fx] compute memory stat and flop count for MetaInfoProp. --- colossalai/fx/passes/meta_info_prop.py | 47 ++- colossalai/fx/profiler/__init__.py | 8 +- colossalai/fx/profiler/memory.py | 40 +++ colossalai/fx/profiler/opcount.py | 294 ++++++++++++++++++ colossalai/fx/profiler/profiler.py | 246 ++++++--------- .../fx/profiler/{meta_tensor.py => tensor.py} | 32 +- .../test_ckpt_torchvision.py | 1 + tests/test_fx/test_meta_info_prop.py | 19 +- 8 files changed, 479 insertions(+), 208 deletions(-) create mode 100644 colossalai/fx/profiler/memory.py create mode 100644 colossalai/fx/profiler/opcount.py rename colossalai/fx/profiler/{meta_tensor.py => tensor.py} (73%) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 803519332754..75cf9acda49e 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -7,7 +7,7 @@ from functools import reduce from torch.fx._compatibility import compatibility from torch.fx.immutable_collections import immutable_dict, immutable_list -from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method +from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size @compatibility(is_backward_compatible=True) @@ -71,14 +71,6 @@ class MetaInfoProp(torch.fx.Interpreter): """ - @compatibility(is_backward_compatible=True) - def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any: - """ - Add additional check for initial args to ensure all the tensor appears with `device='meta'` - """ - args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args) - return super().run(*args, initial_env, enable_io_processing) - @compatibility(is_backward_compatible=True) def run_node(self, n: Node) -> Any: """ @@ -93,8 +85,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ - result, profile = super().run_node(n) - profile: MetaProfile + result, flop_count, mem_stat = super().run_node(n) def extract_tensor_meta(obj): if isinstance(obj, torch.Tensor): @@ -106,11 +97,9 @@ def extract_tensor_meta(obj): n.meta['tensor_meta'] = meta # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', profile.param + profile.activation) - setattr(n, '__param__', profile.param) - setattr(n, '__activation__', profile.activation) - setattr(n, '__flops__', profile.flops) - setattr(n, '__macs__', profile.macs) + setattr(n, 'node_size', mem_stat[1]) + setattr(n, 'flop_count', flop_count) + setattr(n, 'mem_stat', mem_stat) n.meta['type'] = type(result) return result @@ -132,11 +121,12 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Returns: result (Any): The argument value that was retrieved - profile (MetaProfile): The meta profile of this node + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ result = super().placeholder(target, args, kwargs) # A placeholder node only has activation - return result, MetaProfile(0, calculate_activation_size(result), 0, 0) + return result, (0, 0), (0, activation_size(result), 0, 0) @compatibility(is_backward_compatible=True) def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -153,10 +143,10 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st Return: result (Any): The argument value that was retrieved - profile (MetaProfile): The meta profile of this node + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - # A get_attr node never has parameters, activations, FLOPs, or MACs - return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0) + return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0) @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -172,7 +162,8 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di Return result (Any): The argument value that was retrieved - profile (MetaProfile): The meta profile of this node + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ assert not isinstance(target, str) return profile_function(target)(*args, **kwargs) @@ -191,7 +182,8 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - profile (MetaProfile): The meta profile of this node + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ return profile_method(target)(*args, **kwargs) @@ -209,7 +201,8 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict Return result (Any): The argument value that was retrieved - profile (MetaProfile): The meta profile of this node + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ # Retrieve executed args and kwargs values from the environment # Execute the method and return the result @@ -231,9 +224,11 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, kwargs (Dict): Dict of keyword arguments for this invocation Return: - Any: The return value referenced by the output node + result (Any): The argument value that was retrieved + flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - return args[0], MetaProfile(0, 0, 0, 0) + return args[0], (0, 0), (0, 0, 0, 0) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 4b90bcb30b7a..d1e4a664670c 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -3,8 +3,6 @@ except: import torch print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') -from .meta_tensor import MetaTensor -from .registry import meta_profiler_function, meta_profiler_module -from .profiler_function import * -from .profiler_module import * -from .profiler import * +from .tensor import MetaTensor +from .memory import parameter_size, activation_size +from .profiler import profile_function, profile_method, profile_module, _profile diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py new file mode 100644 index 000000000000..f64674f480bf --- /dev/null +++ b/colossalai/fx/profiler/memory.py @@ -0,0 +1,40 @@ +import torch +from typing import Union, Dict, List, Tuple + +__all__ = ['activation_size', 'parameter_size'] + + +def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: + """Calculate activation size of a node. + + Args: + activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional` + + Returns: + int: The activation size + """ + act_size = 0 + if isinstance(out, torch.Tensor): + act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size() + elif isinstance(out, dict): + value_list = [v for _, v in out.items()] + act_size += activation_size(value_list) + elif isinstance(out, tuple) or isinstance(out, list): + for element in out: + act_size += activation_size(element) + return act_size + + +def parameter_size(mod: torch.nn.Module) -> int: + """Calculate param size of a node. + + Args: + mod (torch.nn.Module): The target `torch.nn.Module` + + Returns: + int: The param size + """ + param_size = 0 + for param in mod.parameters(): + param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() + return param_size diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py new file mode 100644 index 000000000000..b26f79d11f93 --- /dev/null +++ b/colossalai/fx/profiler/opcount.py @@ -0,0 +1,294 @@ +# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py +# ideas from https://pastebin.com/AkvAyJBw + +from functools import reduce +import operator +from typing import Callable, List, Any +from numbers import Number +import torch + +aten = torch.ops.aten + + +def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for matmul. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two matrices. + input_shapes = [v.shape for v in inputs] + assert len(input_shapes) == 2, input_shapes + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] + return flops + + +def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for fully connected layers. + """ + # Count flop for nn.Linear + # inputs is a list of length 3. + input_shapes = [v.shape for v in inputs[1:3]] + # input_shapes[0]: [batch size, input feature dimension] + # input_shapes[1]: [batch size, output feature dimension] + assert len(input_shapes[0]) == 2, input_shapes[0] + assert len(input_shapes[1]) == 2, input_shapes[1] + batch_size, input_dim = input_shapes[0] + output_dim = input_shapes[1][1] + flops = batch_size * input_dim * output_dim + return flops + + +def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the aten::linear operator. + """ + # Inputs is a list of length 3; unlike aten::addmm, it is the first + # two elements that are relevant. + input_shapes = [v.shape for v in inputs[0:2]] + # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] + # input_shapes[1]: [output_feature_dim, input_feature_dim] + assert input_shapes[0][-1] == input_shapes[1][-1] + flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0] + return flops + + +def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for the bmm operation. + """ + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor. + assert len(inputs) == 2, len(inputs) + input_shapes = [v.shape for v in inputs] + n, c, t = input_shapes[0] + d = input_shapes[-1][-1] + flops = n * c * t * d + return flops + + +def conv_flop_count( + x_shape: List[int], + w_shape: List[int], + out_shape: List[int], + transposed: bool = False, +) -> Number: + """ + Count flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape) + return flops + + +def conv_flop_jit(inputs: List[Any], outputs: List[Any]): + """ + Count flops for convolution. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape) + transposed = inputs[6] + + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +def transpose_shape(shape): + return [shape[1], shape[0]] + list(shape[2:]) + + +def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]): + grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]] + output_mask = inputs[-1] + fwd_transposed = inputs[7] + flop_count = 0 + + if output_mask[0]: + grad_input_shape = outputs[0].shape + flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed) + if output_mask[1]: + grad_weight_shape = outputs[1].shape + flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed) + + return flop_count + + +def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: + """ + Args: + affine_arg_index: index of the affine argument in inputs + """ + + def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + """ + Count flops for norm layers. + """ + # Inputs[0] contains the shape of the input. + input_shape = inputs[input_arg_index].shape + + has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], + 'shape') else inputs[affine_arg_index] + assert 2 <= len(input_shape) <= 5, input_shape + # 5 is just a rough estimate + flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) + return flop + + return norm_flop_jit + + +def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: + training = inputs[-3] + assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" + if training: + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + has_affine = inputs[1].shape is not None + input_shape = reduce(operator.mul, inputs[0].shape) + return input_shape * (2 if has_affine else 1) + + +def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable: + """ + Count flops by + input_tensor.numel() * input_scale + output_tensor.numel() * output_scale + Args: + input_scale: scale of the input tensor (first argument) + output_scale: scale of the output tensor (first element in outputs) + """ + + def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: + ret = 0 + if input_scale != 0: + shape = inputs[0].shape + ret += input_scale * reduce(operator.mul, shape) + if output_scale != 0: + shape = outputs[0].shape + ret += output_scale * reduce(operator.mul, shape) + return ret + + return elementwise_flop + + +def zero_flop_jit(*args): + """ + Count flops for zero flop layers. + """ + return 0 + + +flop_mapping = { + # gemm + aten.mm.default: matmul_flop_jit, + aten.matmul.default: matmul_flop_jit, + aten.addmm.default: addmm_flop_jit, + aten.bmm.default: bmm_flop_jit, + + # convolution + aten.convolution.default: conv_flop_jit, + aten._convolution.default: conv_flop_jit, + aten.convolution_backward.default: conv_backward_flop_jit, + + # normalization + aten.native_batch_norm.default: batchnorm_flop_jit, + aten.native_batch_norm_backward.default: batchnorm_flop_jit, + aten.native_layer_norm.default: norm_flop_counter(2, 0), + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + + # pooling + aten.avg_pool1d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten.avg_pool3d.default: elementwise_flop_counter(1, 0), + aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool1d.default: elementwise_flop_counter(1, 0), + aten.max_pool2d.default: elementwise_flop_counter(1, 0), + aten.max_pool3d.default: elementwise_flop_counter(1, 0), + aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), +} + +elementwise_flop_aten = [ + # basic op + aten.add.Tensor, + aten.add_.Tensor, + aten.div.Tensor, + aten.div_.Tensor, + aten.div.Scalar, + aten.div_.Scalar, + aten.mul.Tensor, + aten.mul_.Tensor, + aten.sum.default, + aten.sum.dim_IntList, + aten.mean.dim, + + # activation op + aten.hardswish_.default, + aten.hardswish_backward.default, + aten.hardsigmoid_backward.default, + aten.hardsigmoid.default, + aten.gelu.default, + aten.gelu_backward.default, + aten.silu_.default, + aten.silu_backward.default, + aten.sigmoid.default, + aten.sigmoid_backward.default, + aten._softmax.default, + aten._softmax_backward_data.default, + aten.relu_.default, + aten.relu.default, + aten.threshold_backward.default, +] + +for op in elementwise_flop_aten: + flop_mapping[op] = elementwise_flop_counter(1, 0) + +# TODO: this will be removed in future +zero_flop_aten = [ + aten.as_strided.default, + aten.as_strided_.default, + aten.bernoulli_.float, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten._unsafe_view.default, + aten.view.default, + aten.zero_.default, +] + +for op in zero_flop_aten: + flop_mapping[op] = zero_flop_jit diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index c11ef20f0557..61551ae71458 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -3,118 +3,105 @@ import torch from torch.fx.node import Argument, Target from torch.fx._compatibility import compatibility -from . import meta_profiler_function, meta_profiler_module - -__all__ = [ - 'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size', - 'calculate_param_size' -] - -CALL_FUNCTION_MSG = \ -""" -Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n -from colossalai.fx.profiler import meta_profiler_function - -@meta_profiler_function.register(YOUR_FUNCTION) -def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: - flops = ... - macs = ... - return flops, macs -""" -CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' -CALL_MODULE_MSG = \ -""" -Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n -from colossalai.fx.profiler import meta_profiler_module - -@meta_profiler_module.register(YOUR_MODULE) -def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: - flops = ... - macs = ... - return flops, macs -""" - -# TODO fill out the inplace ops -INPLACE_OPS = [ - add, - sub, - mul, - floordiv, - neg, - pos, - getitem, - setitem, - getattr, - torch.Tensor.cpu, -] - -# TODO: list all call_methods that are inplace here -INPLACE_METHOD = [ - 'transpose', - 'permute', - # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', -] - -# TODO: list all call_methods that are not inplace here -NON_INPLACE_METHOD = [ - 'expand', - 'mean', -] - - -@compatibility(is_backward_compatible=True) -class MetaProfile(NamedTuple): - - # MetaProfile is a structure containing pertinent information - # about a node within a torch.fx GraphModule. - - param: int - activation: int - flops: int - macs: int - - -def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: - """Calculate activation size of a node. +from torch.utils._pytree import tree_map, tree_flatten +from .tensor import MetaTensor +from .opcount import flop_mapping +from .memory import activation_size - Args: - activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional` +__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] - Returns: - int: The activation size - """ - activation_size = 0 - if isinstance(activation, torch.Tensor): - activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size() - elif isinstance(activation, dict): - value_list = [v for _, v in activation.items()] - activation_size += calculate_activation_size(value_list) - elif isinstance(activation, tuple) or isinstance(activation, list): - for element in activation: - activation_size += calculate_activation_size(element) - return activation_size +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x -def calculate_param_size(mod: torch.nn.Module) -> int: - """Calculate param size of a node. + +def _profile(target: Callable, args, kwargs) -> Tuple[Any, ...]: + """Profile a Callable function with args and kwargs. Args: - mod (torch.nn.Module): The target `torch.nn.Module` + target (Callable): A Callable function + args (Any): Argument + kwargs (Any): Argument Returns: - int: The param size + out (Tuple[Any, ...]): The argument value that was retrieved + flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop). + mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) """ - param_size = 0 - for param in mod.parameters(): - param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() - return param_size + + flop_count = { + 'f': 0, + 'l': 0, + 'b': 0, + } + temp = { + 'f': [], + 'l': [], + 'b': [], + } + stage = 'f' + + class FlopTensor(MetaTensor): + + def __repr__(self): + if self.grad_fn: + return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})" + return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def unwrap(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): + x = FlopTensor(x.to('meta')) + return x._tensor.to('meta') if isinstance(x, FlopTensor) else x + + def to_meta(x): + return x.to('meta') + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # TODO: this will be, but we should examine all aten ops first + # if func in flop_mapping: + # flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) + flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) + temp[stage].append(tree_map(to_meta, normalize_tuple(out))) + + def wrap(x): + return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x + + return tree_map(wrap, out) + + def wrap(x): + return FlopTensor( + x.detach().requires_grad_(True)) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + out = target(*args, **kwargs) + stage = 'l' + loss = out.sum() + stage = 'b' + loss.backward() + + fwd_flop = flop_count['f'] + bwd_flop = flop_count['b'] + + fwd_tmp = activation_size(temp['f'][:-1]) + fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 + bwd_tmp = activation_size(temp['b']) + + def unwrap(x): + return x._tensor.to('meta') if isinstance(x, FlopTensor) else x + + return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0) def profile_function(target: 'Target') -> Callable: @@ -129,29 +116,12 @@ def profile_function(target: 'Target') -> Callable: Examples: >> input = torch.rand(100, 100, 100, 100, device='meta') >> func = torch.nn.functional.relu - >> output, profile = profile_function(func)(input, inplace=False) - >> print(f"Profiling function {func},") - >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs") - Profiling function , - Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs + >> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - assert meta_profiler_function.has(target) or meta_profiler_function.has( - target.__name__), CALL_FUNCTION_MSG.format(target) - - # call_function has no parameters - param_size = 0 - activation_size = 0 - result = func(*args, **kwargs) - if target not in INPLACE_OPS and not kwargs.get('inplace', False): - activation_size += calculate_activation_size(result) - if meta_profiler_function.has(target): - profiler = meta_profiler_function.get(target) - else: - profiler = meta_profiler_function.get(target.__name__) - flops, macs = profiler(*args, **kwargs) - return result, MetaProfile(param_size, activation_size, flops, macs) + out, flop_count, mem_stat = _profile(func, args, kwargs) + return out, flop_count, mem_stat f.__name__ = target.__name__ func = target @@ -174,15 +144,8 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - result = getattr(self_obj, target)(*args_tail, **kwargs) - assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( - target, INPLACE_METHOD, NON_INPLACE_METHOD) - # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. - param_size = 0 - activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result) - flops = 0 - macs = 0 - return result, MetaProfile(param_size, activation_size, flops, macs) + out = getattr(self_obj, target)(args_tail, kwargs) + return out, (0, 0), (0, activation_size(out), activation_size(out), 0) return f @@ -199,25 +162,12 @@ def profile_module(module: torch.nn.Module) -> Callable: Example: >> input = torch.rand(4, 3, 224, 224, device='meta') >> mod = torch.nn.Conv2d(3, 128, 3) - >> output, profile = profile_module(mod)(input) - >> print(f"Profiling function {mod},") - >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs") - Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)), - Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs + >> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module)) - - # only `nn.Module` has parameters - param_size = calculate_param_size(module) - activation_size = 0 - result = func(*args, **kwargs) - if not getattr(module, 'inplace', False): - activation_size += calculate_activation_size(result) - profiler = meta_profiler_module.get(type(module)) - flops, macs = profiler(module, *args, **kwargs) - return result, MetaProfile(param_size, activation_size, flops, macs) + out, flop_count, mem_stat = _profile(func, args, kwargs) + return out, flop_count, mem_stat f.__name__ = module.__class__.__name__ func = module.forward diff --git a/colossalai/fx/profiler/meta_tensor.py b/colossalai/fx/profiler/tensor.py similarity index 73% rename from colossalai/fx/profiler/meta_tensor.py rename to colossalai/fx/profiler/tensor.py index 67493f7c538c..5956a104686c 100644 --- a/colossalai/fx/profiler/meta_tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,7 +1,6 @@ import torch from torch.utils._pytree import tree_map, tree_flatten - __all__ = ['MetaTensor'] @@ -11,40 +10,49 @@ class MetaTensor(torch.Tensor): """ _tensor: torch.Tensor - + __slots__ = ['_tensor'] - + @staticmethod def __new__(cls, elem): # The wrapping tensor (MetaTensor) shouldn't hold any # memory for the class in question, but it should still # advertise the same device as before r = torch.Tensor._make_wrapper_subclass( - cls, elem.size(), - strides=elem.stride(), storage_offset=elem.storage_offset(), - dtype=elem.dtype, layout=elem.layout, - device='cpu', requires_grad=elem.requires_grad - ) # deceive the frontend for aten selections + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device='cpu', + requires_grad=elem.requires_grad) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. return r - @ classmethod + def __repr__(self): + if self.grad_fn: + return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})" + return f"MetaTensor({self._tensor})" + + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(x): if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): x = MetaTensor(x) return x._tensor.to('meta') if isinstance(x, MetaTensor) else x - + args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) - + # Now, we want to continue propagating this tensor, so we rewrap Tensors in # our custom tensor subclass def wrap(x): return MetaTensor(x) if isinstance(x, torch.Tensor) else x - + return tree_map(wrap, out) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index ea9aec43dec2..4dc1cdc2d9d6 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -89,6 +89,7 @@ def _run_ckpt_solver(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skip('TODO: refactor ckpt solvers') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index ae827bf4f2c4..aed0407b24c9 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -13,7 +13,6 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.shape == orig_tensor.shape assert meta_info_spec.dtype == orig_tensor.dtype - assert meta_info_spec.requires_grad == orig_tensor.requires_grad assert meta_info_spec.stride == orig_tensor.stride() assert meta_info_spec.numel == orig_tensor.numel() @@ -23,17 +22,6 @@ def test_meta_info_prop(): input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') orig_output = model(input_sample) gm = symbolic_trace(model) - for node in gm.graph.nodes: - assert not hasattr(node, - 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' - assert not hasattr(node, - '__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure' - assert not hasattr( - node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure' - assert not hasattr(node, - '__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure' - assert not hasattr(node, - '__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': @@ -41,11 +29,8 @@ def test_meta_info_prop(): if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' - assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp procedure' - assert hasattr(node, - '__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure' - assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure' - assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure' + assert hasattr(node, 'flop_count'), 'The attribute Node.flop_count should exist after MetaInfoProp procedure' + assert hasattr(node, 'mem_stat'), 'The attribute Node.mem_stat should exist after MetaInfoProp procedure' if __name__ == '__main__': From 17be5a52e3e985b2b02d817d2ea8c6bd43ebc40e Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 13:21:36 +0800 Subject: [PATCH 02/24] [fx] modify node attribute. --- colossalai/fx/passes/meta_info_prop.py | 8 ++++++-- tests/test_fx/test_meta_info_prop.py | 3 --- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 75cf9acda49e..10bfdbde8d1a 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -98,8 +98,12 @@ def extract_tensor_meta(obj): # TODO: the attribute node_size should be removed in the future setattr(n, 'node_size', mem_stat[1]) - setattr(n, 'flop_count', flop_count) - setattr(n, 'mem_stat', mem_stat) + setattr(n, 'fwd_flop', flop_count[0]) + setattr(n, 'bwd_flop', flop_count[1]) + setattr(n, 'fwd_tmp', mem_stat[0]) + setattr(n, 'fwd_out', mem_stat[1]) + setattr(n, 'bwd_tmp', mem_stat[2]) + setattr(n, 'bwd_out', mem_stat[3]) n.meta['type'] = type(result) return result diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index aed0407b24c9..af594de4f6de 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -28,9 +28,6 @@ def test_meta_info_prop(): meta_check(node.meta['tensor_meta'], input_sample) if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) - assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' - assert hasattr(node, 'flop_count'), 'The attribute Node.flop_count should exist after MetaInfoProp procedure' - assert hasattr(node, 'mem_stat'), 'The attribute Node.mem_stat should exist after MetaInfoProp procedure' if __name__ == '__main__': From 36c93ca2cc098ee973d0df6732bbe89686cc2abb Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 13:24:56 +0800 Subject: [PATCH 03/24] [fx] modify ckpt_chen. --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 4 ++-- tests/test_fx/test_ckpt_solvers/test_linearize.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 54d22a538107..9830f822ff98 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -73,10 +73,10 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: y = 0 prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): - temp += getattr(n, '__activation__') + temp += getattr(n, 'fwd_out') + getattr(n, 'fwd_tmp') y = max(y, temp) if temp > b and n in ckpt_nodes: - x += getattr(n, '__activation__') + x += getattr(n, 'fwd_out') temp = 0 ckpt_intv.append((prev_idx, idx + 1)) prev_idx = idx + 1 diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index 36bd87b42d22..1f4d4a0bc1a5 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -15,6 +15,7 @@ with_codegen = False +@pytest.mark.skip(reason='TODO: modify calculations in rotor') @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_linearize(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} From d6dcd80badf03a5a9e128671b88f1474b809f9a4 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 14:22:31 +0800 Subject: [PATCH 04/24] [fx] fix compatibility. --- colossalai/fx/profiler/__init__.py | 1 + colossalai/fx/profiler/profiler.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index d1e4a664670c..5469431b0fa2 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,5 +1,6 @@ try: from ._meta_registrations import * + from .opcount import flop_mapping except: import torch print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 61551ae71458..4d642669da35 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -4,9 +4,7 @@ from torch.fx.node import Argument, Target from torch.fx._compatibility import compatibility from torch.utils._pytree import tree_map, tree_flatten -from .tensor import MetaTensor -from .opcount import flop_mapping -from .memory import activation_size +from . import flop_mapping, MetaTensor, activation_size __all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] From 01de3f16264699d99aedc8862eba4b49cefdef02 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 22:05:06 +0800 Subject: [PATCH 05/24] [fx] fix import error. --- colossalai/fx/profiler/profiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 4d642669da35..18a6fbf8ae5d 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -4,7 +4,11 @@ from torch.fx.node import Argument, Target from torch.fx._compatibility import compatibility from torch.utils._pytree import tree_map, tree_flatten -from . import flop_mapping, MetaTensor, activation_size +from . import MetaTensor, activation_size +try: + from . import flop_mapping +except: + pass __all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] From e51f32a8ac7c624c66f49356190ed101e3b132b1 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 22:26:14 +0800 Subject: [PATCH 06/24] [fx] skip test for MetaInfoProp. --- tests/test_fx/test_meta_info_prop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index af594de4f6de..672ca48c1135 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -5,6 +5,8 @@ from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +import pytest + BATCH_SIZE = 2 DIM_IN = 4 DIM_OUT = 16 @@ -17,6 +19,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() +@pytest.skip(reason='Not compatible with torch < 1.12.0.') def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') From a9193b546bb55d2f44098db67b041abf85ce2dc6 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 22:26:18 +0800 Subject: [PATCH 07/24] [fx] skip test for MetaInfoProp. --- tests/test_fx/test_meta_info_prop.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 672ca48c1135..6263987b72a0 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -6,6 +6,11 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata import pytest +try: + meta_lib = torch.library.Library("aten", "IMPL", "Meta") + INCOMPATIBLE = False # version > 1.12.0 +except: + INCOMPATIBLE = True BATCH_SIZE = 2 DIM_IN = 4 @@ -19,7 +24,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() -@pytest.skip(reason='Not compatible with torch < 1.12.0.') +@pytest.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') From 97cfba9a806238ad3de560324ebe024bfa5fcb2a Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 22:48:21 +0800 Subject: [PATCH 08/24] [fx] skip test for MetaInfoProp. --- tests/test_fx/test_meta_info_prop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 6263987b72a0..1bc1e9231c95 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -24,7 +24,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() -@pytest.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') +@pytest.skip.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') From c992b25e18a23871519a51fc5016529bcef636ec Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 23:01:18 +0800 Subject: [PATCH 09/24] [fx] skip test for MetaInfoProp. --- tests/test_fx/test_meta_info_prop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 1bc1e9231c95..5783994ac9d6 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -24,7 +24,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() -@pytest.skip.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') From 62d70966c514b868ff223a7ffd16e07d7b7493ca Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 23:24:43 +0800 Subject: [PATCH 10/24] [fx] skip if torch 1.11.0. --- tests/test_fx/test_comm_size_compute.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index 69fb6ca9536c..a31878d66606 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -8,6 +8,12 @@ from colossalai.fx.passes.utils import get_comm_size import pytest +try: + meta_lib = torch.library.Library("aten", "IMPL", "Meta") + INCOMPATIBLE = False # version > 1.12.0 +except: + INCOMPATIBLE = True + MODEL_DIM = 16 BATCH_SIZE = 8 PIPELINE_SIZE = 2 @@ -30,6 +36,7 @@ def forward(self, x): return x +@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') def test_comm_size_compute(): model = MLP(MODEL_DIM) input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') From f80735f00dc568a13a781b8dd3f4f1376c2f31f8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 5 Sep 2022 16:55:25 +0800 Subject: [PATCH 11/24] [fx] recover MetaInfoProp support for PyTorch 1.11. --- colossalai/__init__.py | 4 +- colossalai/fx/passes/meta_info_prop.py | 9 +- colossalai/fx/profiler/__init__.py | 16 +-- .../fx/profiler/experimental/__init__.py | 4 + .../fx/profiler/experimental/profiler.py | 125 ++++++++++++++++++ .../profiler_function/__init__.py | 0 .../profiler_function/activation_function.py | 0 .../profiler_function/arithmetic.py | 0 .../profiler_function/embedding.py | 0 .../profiler_function/linear.py | 0 .../profiler_function/normalization.py | 0 .../profiler_function/pooling.py | 0 .../profiler_function/python_ops.py | 0 .../profiler_function/torch_ops.py | 0 .../profiler_module/__init__.py | 0 .../profiler_module/activation_function.py | 0 .../profiler_module/attention.py | 0 .../profiler_module/convolution.py | 0 .../profiler_module/dropout.py | 0 .../profiler_module/embedding.py | 0 .../profiler_module/linear.py | 0 .../profiler_module/normalization.py | 0 .../profiler_module/pooling.py | 0 .../{ => experimental}/profiler_module/rnn.py | 0 .../profiler_module/torch_op.py | 0 .../profiler/{ => experimental}/registry.py | 0 colossalai/fx/profiler/memory.py | 37 +++++- colossalai/fx/profiler/profiler.py | 48 ++++--- tests/test_fx/test_meta_info_prop.py | 6 - 29 files changed, 212 insertions(+), 37 deletions(-) create mode 100644 colossalai/fx/profiler/experimental/__init__.py create mode 100644 colossalai/fx/profiler/experimental/profiler.py rename colossalai/fx/profiler/{ => experimental}/profiler_function/__init__.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/activation_function.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/arithmetic.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/embedding.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/linear.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/normalization.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/pooling.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/python_ops.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_function/torch_ops.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/__init__.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/activation_function.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/attention.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/convolution.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/dropout.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/embedding.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/linear.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/normalization.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/pooling.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/rnn.py (100%) rename colossalai/fx/profiler/{ => experimental}/profiler_module/torch_op.py (100%) rename colossalai/fx/profiler/{ => experimental}/registry.py (100%) diff --git a/colossalai/__init__.py b/colossalai/__init__.py index b5fff7469a62..1cecbd43af4e 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,7 +1,9 @@ try: - from ._meta_registrations import * + from . import _meta_registrations + META_COMPATIBILITY = True except: import torch + META_COMPATIBILITY = False print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 10bfdbde8d1a..813fec3a74f8 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,12 +1,10 @@ -from operator import add, getitem +from re import L import torch import torch.fx from torch.fx.node import Node, Argument, Target from torch.utils._pytree import tree_map -from typing import Any, Tuple, NamedTuple, Optional, Dict -from functools import reduce +from typing import Any, Tuple, NamedTuple, Dict from torch.fx._compatibility import compatibility -from torch.fx.immutable_collections import immutable_dict, immutable_list from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size @@ -105,6 +103,9 @@ def extract_tensor_meta(obj): setattr(n, 'bwd_tmp', mem_stat[2]) setattr(n, 'bwd_out', mem_stat[3]) n.meta['type'] = type(result) + + for param in self.module.parameters(): + param.grad = None return result # Main Node running APIs diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 5469431b0fa2..c21fde5358dd 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,9 +1,9 @@ -try: - from ._meta_registrations import * +from ... import META_COMPATIBILITY +if META_COMPATIBILITY: from .opcount import flop_mapping -except: - import torch - print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') -from .tensor import MetaTensor -from .memory import parameter_size, activation_size -from .profiler import profile_function, profile_method, profile_module, _profile + from .tensor import MetaTensor + from .profiler import profile_function, profile_method, profile_module, _profile +else: + from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module + +from .memory import parameter_size, activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py new file mode 100644 index 000000000000..522d1324e7c7 --- /dev/null +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -0,0 +1,4 @@ +from .registry import meta_profiler_function, meta_profiler_module +from .profiler_function import * +from .profiler_module import * +from .profiler import profile_function, profile_method, profile_module \ No newline at end of file diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py new file mode 100644 index 000000000000..95f61adbb23e --- /dev/null +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -0,0 +1,125 @@ +from typing import Callable, Any, Dict, Tuple +import torch +from torch.fx.node import Argument, Target +from . import meta_profiler_function, meta_profiler_module +from .. import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS + +__all__ = ['profile_function', 'profile_module', 'profile_method'] + +CALL_FUNCTION_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler import meta_profiler_function +@meta_profiler_function.register(YOUR_FUNCTION) +def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" +CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' +CALL_MODULE_MSG = \ +""" +Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n +from colossalai.fx.profiler import meta_profiler_module +@meta_profiler_module.register(YOUR_MODULE) +def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: + flops = ... + macs = ... + return flops, macs +""" + + +def profile_function(target: 'Target') -> Callable: + """ + Wrap a `call_function` node or `torch.nn.functional` in order to + record the memory cost and FLOPs of the execution. + Unfortunately, backward memory cost and FLOPs are estimated results. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn.functional` are available. + + Examples: + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_function.has(target) or meta_profiler_function.has( + target.__name__), CALL_FUNCTION_MSG.format(target) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if target not in INPLACE_OPS and not kwargs.get('inplace', False): + fwd_out = activation_size(out) + if meta_profiler_function.has(target): + profiler = meta_profiler_function.get(target) + else: + profiler = meta_profiler_function.get(target.__name__) + fwd_flop, _ = profiler(*args, **kwargs) + return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = target.__name__ + func = target + return f + + +def profile_method(target: 'Target') -> Callable: + """ + Wrap a `call_method` node + record the memory cost and FLOPs of the execution. + + Warnings: + This is not fully implemented and you may follow the error message to debug. + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # execute the method and return the result + assert isinstance(target, str), f'{target} instance is not str.' + + out = getattr(self_obj, target)(*args_tail, **kwargs) + assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( + target, INPLACE_METHOD, NON_INPLACE_METHOD) + # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. + fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) + fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) + return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + return f + + +def profile_module(module: torch.nn.Module) -> Callable: + """ + Wrap a `call_module` node or `torch.nn` in order to + record the memory cost and FLOPs of the execution. + + Warnings: + You may only use tensors with `device=meta` for this wrapped function. + Only original `torch.nn` are available. + + Example: + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + """ + + def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module)) + + fwd_tmp = 0 + fwd_out = 0 + out = func(*args, **kwargs) + if getattr(module, 'inplace', False): + fwd_out = activation_size(out) + profiler = meta_profiler_module.get(type(module)) + fwd_flop, _ = profiler(module, *args, **kwargs) + return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + + f.__name__ = module.__class__.__name__ + func = module.forward + return f diff --git a/colossalai/fx/profiler/profiler_function/__init__.py b/colossalai/fx/profiler/experimental/profiler_function/__init__.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/__init__.py rename to colossalai/fx/profiler/experimental/profiler_function/__init__.py diff --git a/colossalai/fx/profiler/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/activation_function.py rename to colossalai/fx/profiler/experimental/profiler_function/activation_function.py diff --git a/colossalai/fx/profiler/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/arithmetic.py rename to colossalai/fx/profiler/experimental/profiler_function/arithmetic.py diff --git a/colossalai/fx/profiler/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/embedding.py rename to colossalai/fx/profiler/experimental/profiler_function/embedding.py diff --git a/colossalai/fx/profiler/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/linear.py rename to colossalai/fx/profiler/experimental/profiler_function/linear.py diff --git a/colossalai/fx/profiler/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/normalization.py rename to colossalai/fx/profiler/experimental/profiler_function/normalization.py diff --git a/colossalai/fx/profiler/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/pooling.py rename to colossalai/fx/profiler/experimental/profiler_function/pooling.py diff --git a/colossalai/fx/profiler/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/python_ops.py rename to colossalai/fx/profiler/experimental/profiler_function/python_ops.py diff --git a/colossalai/fx/profiler/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py similarity index 100% rename from colossalai/fx/profiler/profiler_function/torch_ops.py rename to colossalai/fx/profiler/experimental/profiler_function/torch_ops.py diff --git a/colossalai/fx/profiler/profiler_module/__init__.py b/colossalai/fx/profiler/experimental/profiler_module/__init__.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/__init__.py rename to colossalai/fx/profiler/experimental/profiler_module/__init__.py diff --git a/colossalai/fx/profiler/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/activation_function.py rename to colossalai/fx/profiler/experimental/profiler_module/activation_function.py diff --git a/colossalai/fx/profiler/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/attention.py rename to colossalai/fx/profiler/experimental/profiler_module/attention.py diff --git a/colossalai/fx/profiler/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/convolution.py rename to colossalai/fx/profiler/experimental/profiler_module/convolution.py diff --git a/colossalai/fx/profiler/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/dropout.py rename to colossalai/fx/profiler/experimental/profiler_module/dropout.py diff --git a/colossalai/fx/profiler/profiler_module/embedding.py b/colossalai/fx/profiler/experimental/profiler_module/embedding.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/embedding.py rename to colossalai/fx/profiler/experimental/profiler_module/embedding.py diff --git a/colossalai/fx/profiler/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/linear.py rename to colossalai/fx/profiler/experimental/profiler_module/linear.py diff --git a/colossalai/fx/profiler/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/normalization.py rename to colossalai/fx/profiler/experimental/profiler_module/normalization.py diff --git a/colossalai/fx/profiler/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/pooling.py rename to colossalai/fx/profiler/experimental/profiler_module/pooling.py diff --git a/colossalai/fx/profiler/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/rnn.py rename to colossalai/fx/profiler/experimental/profiler_module/rnn.py diff --git a/colossalai/fx/profiler/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py similarity index 100% rename from colossalai/fx/profiler/profiler_module/torch_op.py rename to colossalai/fx/profiler/experimental/profiler_module/torch_op.py diff --git a/colossalai/fx/profiler/registry.py b/colossalai/fx/profiler/experimental/registry.py similarity index 100% rename from colossalai/fx/profiler/registry.py rename to colossalai/fx/profiler/experimental/registry.py diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index f64674f480bf..eb02be3e5e22 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -1,7 +1,42 @@ import torch from typing import Union, Dict, List, Tuple +from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos -__all__ = ['activation_size', 'parameter_size'] +__all__ = ['activation_size', 'parameter_size', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] + +# TODO fill out the inplace ops +INPLACE_OPS = [ + add, + sub, + mul, + floordiv, + neg, + pos, + getitem, + setitem, + getattr, + torch.Tensor.cpu, +] + +# TODO: list all call_methods that are inplace here +INPLACE_METHOD = [ + 'transpose', + 'permute', + # TODO: reshape may return a copy of the data if the data is not contiguous + 'reshape', + 'dim', + 'flatten', + 'size', + 'view', + 'unsqueeze', + 'to', +] + +# TODO: list all call_methods that are not inplace here +NON_INPLACE_METHOD = [ + 'expand', + 'mean', +] def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 18a6fbf8ae5d..3b0601c77927 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,17 +1,15 @@ -from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos -from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union +from typing import Callable, Any, Dict, Tuple import torch from torch.fx.node import Argument, Target -from torch.fx._compatibility import compatibility -from torch.utils._pytree import tree_map, tree_flatten -from . import MetaTensor, activation_size -try: - from . import flop_mapping -except: - pass +from torch.utils._pytree import tree_map +from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS +from .tensor import MetaTensor +from .opcount import flop_mapping __all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] +CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' + def normalize_tuple(x): if not isinstance(x, tuple): @@ -116,12 +114,17 @@ def profile_function(target: 'Target') -> Callable: Only original `torch.nn.functional` are available. Examples: - >> input = torch.rand(100, 100, 100, 100, device='meta') - >> func = torch.nn.functional.relu - >> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) + >>> input = torch.rand(100, 100, 100, 100, device='meta') + >>> func = torch.nn.functional.relu + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + if target in INPLACE_OPS or kwargs.get('inplace', False): + args = tree_map(lambda x: x.to('meta'), args) + kwargs = tree_map(lambda x: x.to('meta'), kwargs) + out = func(*args, **kwargs) + return out, (out.numel(), out.numel()), (0, 0, 0, 0) out, flop_count, mem_stat = _profile(func, args, kwargs) return out, flop_count, mem_stat @@ -136,7 +139,7 @@ def profile_method(target: 'Target') -> Callable: record the memory cost and FLOPs of the execution. Warnings: - This is not fully implemented and you may follow the error message to debug. + Not all `call_method` nodes are inplace. But for sake of simplicity, we mark all of them as inplace. """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @@ -147,7 +150,13 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: assert isinstance(target, str), f'{target} instance is not str.' out = getattr(self_obj, target)(args_tail, kwargs) - return out, (0, 0), (0, activation_size(out), activation_size(out), 0) + + assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( + target, INPLACE_METHOD, NON_INPLACE_METHOD) + # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. + fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) + fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) + return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) return f @@ -162,12 +171,17 @@ def profile_module(module: torch.nn.Module) -> Callable: Only original `torch.nn` are available. Example: - >> input = torch.rand(4, 3, 224, 224, device='meta') - >> mod = torch.nn.Conv2d(3, 128, 3) - >> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) + >>> input = torch.rand(4, 3, 224, 224, device='meta') + >>> mod = torch.nn.Conv2d(3, 128, 3) + >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: + if getattr(module, 'inplace', False): + args = tree_map(lambda x: x.to('meta'), args) + kwargs = tree_map(lambda x: x.to('meta'), kwargs) + out = func(*args, **kwargs) + return out, (out.numel(), out.numel()), (0, 0, 0, 0) out, flop_count, mem_stat = _profile(func, args, kwargs) return out, flop_count, mem_stat diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 5783994ac9d6..fa9067ae3b38 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -6,11 +6,6 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata import pytest -try: - meta_lib = torch.library.Library("aten", "IMPL", "Meta") - INCOMPATIBLE = False # version > 1.12.0 -except: - INCOMPATIBLE = True BATCH_SIZE = 2 DIM_IN = 4 @@ -24,7 +19,6 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() -@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') From 0d4a030ad4f0387378797869e94c2eac6074fbab Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 15:58:42 +0800 Subject: [PATCH 12/24] [fx] provide a stable but not accurate enough version of profiler. --- colossalai/_meta_registrations.py | 20 ++++++++ colossalai/fx/profiler/memory.py | 31 ++++++++++++ colossalai/fx/profiler/opcount.py | 14 +++++- colossalai/fx/profiler/profiler.py | 78 ++++++++++++++++-------------- 4 files changed, 105 insertions(+), 38 deletions(-) diff --git a/colossalai/_meta_registrations.py b/colossalai/_meta_registrations.py index 94f559f382d5..802150ded585 100644 --- a/colossalai/_meta_registrations.py +++ b/colossalai/_meta_registrations.py @@ -181,6 +181,12 @@ def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): return grad_in +@register_meta(aten.hardtanh_backward.default) +def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int): + grad_in = torch.empty_like(input) + return grad_in + + @register_meta(aten.roll.default) def meta_roll(input: torch.Tensor, shifts, dims): return torch.empty_like(input) @@ -321,3 +327,17 @@ def meta_index_Tensor(self, indices): else: replacement_shape = list(index.shape) return self.new_empty(before_shape + replacement_shape + after_shape) + + +@register_meta(aten.embedding_dense_backward.default) +def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, + scale_grad_by_freq): + return torch.empty((num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout) + + +@register_meta(aten.where.self) +def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): + return torch.empty_like(condition) diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index eb02be3e5e22..8ab815f0eadb 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -1,6 +1,7 @@ import torch from typing import Union, Dict, List, Tuple from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos +from . import META_COMPATIBILITY __all__ = ['activation_size', 'parameter_size', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] @@ -18,6 +19,31 @@ torch.Tensor.cpu, ] +if META_COMPATIBILITY: + aten = torch.ops.aten + + WEIRD_OP = [ + torch.where, + ] + + INPLACE_ATEN = [ + aten.add_.Tensor, + aten.add.Tensor, + aten.sub_.Tensor, + aten.div_.Tensor, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.mul.Tensor, + aten.bernoulli_.float, + + # inplace reshaping + aten.detach.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + ] + # TODO: list all call_methods that are inplace here INPLACE_METHOD = [ 'transpose', @@ -30,12 +56,17 @@ 'view', 'unsqueeze', 'to', + 'type', + 'flatten', ] # TODO: list all call_methods that are not inplace here NON_INPLACE_METHOD = [ + 'chunk', + 'contiguous', 'expand', 'mean', + 'split', ] diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index b26f79d11f93..3489f00be24c 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -170,10 +170,10 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: ret = 0 if input_scale != 0: shape = inputs[0].shape - ret += input_scale * reduce(operator.mul, shape) + ret += input_scale * reduce(operator.mul, shape) if shape else 0 if output_scale != 0: shape = outputs[0].shape - ret += output_scale * reduce(operator.mul, shape) + ret += output_scale * reduce(operator.mul, shape) if shape else 0 return ret return elementwise_flop @@ -233,14 +233,21 @@ def zero_flop_jit(*args): aten.div.Scalar, aten.div_.Scalar, aten.mul.Tensor, + aten.mul.Scalar, aten.mul_.Tensor, + aten.neg.default, + aten.pow.Tensor_Scalar, + aten.rsub.Scalar, aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, # activation op + aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, + aten.hardtanh_.default, + aten.hardtanh_backward.default, aten.hardsigmoid_backward.default, aten.hardsigmoid.default, aten.gelu.default, @@ -253,6 +260,8 @@ def zero_flop_jit(*args): aten._softmax_backward_data.default, aten.relu_.default, aten.relu.default, + aten.tanh.default, + aten.tanh_backward.default, aten.threshold_backward.default, ] @@ -287,6 +296,7 @@ def zero_flop_jit(*args): aten.unsqueeze.default, aten._unsafe_view.default, aten.view.default, + aten.where.self, aten.zero_.default, ] diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 3b0601c77927..40c686ca854f 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,8 +1,9 @@ from typing import Callable, Any, Dict, Tuple import torch +from torch.fx import Graph from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS +from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS, INPLACE_ATEN, WEIRD_OP from .tensor import MetaTensor from .opcount import flop_mapping @@ -17,7 +18,11 @@ def normalize_tuple(x): return x -def _profile(target: Callable, args, kwargs) -> Tuple[Any, ...]: +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]: """Profile a Callable function with args and kwargs. Args: @@ -59,44 +64,55 @@ def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x def to_meta(x): - return x.to('meta') + return x.to('meta') if isinstance(x, torch.Tensor) else x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) - - # TODO: this will be, but we should examine all aten ops first - # if func in flop_mapping: - # flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) flop_count[stage] += flop_mapping[func](args, normalize_tuple(out)) - temp[stage].append(tree_map(to_meta, normalize_tuple(out))) + if func not in INPLACE_ATEN: + temp[stage].append(tree_map(to_meta, normalize_tuple(out))) def wrap(x): return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x return tree_map(wrap, out) - def wrap(x): - return FlopTensor( - x.detach().requires_grad_(True)) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + if target not in WEIRD_OP: + + def wrap(x): + return FlopTensor( + x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x + else: + + def wrap(x): + return FlopTensor( + x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) - out = target(*args, **kwargs) - stage = 'l' - loss = out.sum() - stage = 'b' - loss.backward() + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + out = getattr(self_obj, target)(*args_tail, **kwargs) + else: + out = target(*args, **kwargs) + + if is_autogradable(out) and out.requires_grad: + stage = 'l' + loss = out.sum() + stage = 'b' + loss.backward() fwd_flop = flop_count['f'] bwd_flop = flop_count['b'] - fwd_tmp = activation_size(temp['f'][:-1]) + fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0 fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0 - bwd_tmp = activation_size(temp['b']) + bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0 def unwrap(x): return x._tensor.to('meta') if isinstance(x, FlopTensor) else x @@ -120,12 +136,12 @@ def profile_function(target: 'Target') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - if target in INPLACE_OPS or kwargs.get('inplace', False): - args = tree_map(lambda x: x.to('meta'), args) - kwargs = tree_map(lambda x: x.to('meta'), kwargs) + if kwargs.get('inplace', False): + args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args) + kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs) out = func(*args, **kwargs) - return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, args, kwargs) + return out, (0, 0), (0, 0, 0, 0) + out, flop_count, mem_stat = _profile(func, *args, **kwargs) return out, flop_count, mem_stat f.__name__ = target.__name__ @@ -143,20 +159,10 @@ def profile_method(target: 'Target') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' - - out = getattr(self_obj, target)(args_tail, kwargs) - - assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( - target, INPLACE_METHOD, NON_INPLACE_METHOD) - # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. - fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) - fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) - return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0) + out, flop_count, mem_stat = _profile(target, *args, **kwargs) + return out, flop_count, mem_stat return f @@ -182,7 +188,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: kwargs = tree_map(lambda x: x.to('meta'), kwargs) out = func(*args, **kwargs) return out, (out.numel(), out.numel()), (0, 0, 0, 0) - out, flop_count, mem_stat = _profile(func, args, kwargs) + out, flop_count, mem_stat = _profile(func, *args, **kwargs) return out, flop_count, mem_stat f.__name__ = module.__class__.__name__ From 9223090809f02f566f981d04f26884ad8a7dba24 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:31:50 +0800 Subject: [PATCH 13/24] [fx] provide a stable but not accurate enough version of profiler. --- colossalai/fx/profiler/experimental/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 95f61adbb23e..eb88c05e2fca 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -2,7 +2,7 @@ import torch from torch.fx.node import Argument, Target from . import meta_profiler_function, meta_profiler_module -from .. import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS +from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS __all__ = ['profile_function', 'profile_module', 'profile_method'] From 1e49efe94aad0cec03dcab2ccf97bbc5fdf97873 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:41:21 +0800 Subject: [PATCH 14/24] [fx] fix compatibility in tests. --- tests/test_fx/test_meta/test_aten.py | 11 +++--- tests/test_fx/test_meta/test_backward.py | 45 ++++++++---------------- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index 991130376498..4794ea8c61c5 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -2,15 +2,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.fx.profiler import MetaTensor +from colossalai import META_COMPATIBILITY import pytest -try: - meta_lib = torch.library.Library("aten", "IMPL", "Meta") - INCOMPATIBLE = False # version > 1.12.0 -except: - INCOMPATIBLE = True +if META_COMPATIBILITY: + from colossalai.fx.profiler import MetaTensor aten = torch.ops.aten @@ -77,7 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 98b3b464f8fc..e497792af78f 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -1,48 +1,33 @@ import torchvision.models as tm import timm.models as tmm import torch -from colossalai.fx.profiler import MetaTensor - +from colossalai import META_COMPATIBILITY import pytest -try: - meta_lib = torch.library.Library("aten", "IMPL", "Meta") - incompatible = False # version > 1.12.0 -except: - incompatible = True - +if META_COMPATIBILITY: + from colossalai.fx.profiler import MetaTensor tm_models = [ - tm.vgg11, - tm.resnet18, - tm.densenet121, - tm.mobilenet_v3_small, - tm.resnext50_32x4d, + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, tm.wide_resnet50_2, - tm.regnet_x_16gf, - tm.mnasnet0_5, + tm.regnet_x_16gf, + tm.mnasnet0_5, tm.efficientnet_b0, ] - tmm_models = [ - tmm.resnest.resnest50d, - tmm.beit.beit_base_patch16_224, - tmm.cait.cait_s24_224, - tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, - tmm.vision_transformer.vit_base_patch16_224, - tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, - tmm.vgg.vgg11, - tmm.dpn.dpn68, - tmm.densenet.densenet121, - tmm.rexnet.rexnet_100, + tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, tmm.swin_transformer.swin_base_patch4_window7_224 ] -@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') def test_torchvision_models(): for m in tm_models: model = m().to('meta') @@ -50,7 +35,7 @@ def test_torchvision_models(): model(MetaTensor(data)).sum().backward() -@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') def test_timm_models(): for m in tmm_models: model = m().to('meta') From 77045dbcaf4e8b96f142ef1748ee484c7969d0a9 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:44:10 +0800 Subject: [PATCH 15/24] [fx] fix compatibility in tests. --- colossalai/fx/passes/algorithms/ckpt_solver_chen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 9830f822ff98..9ebbd48c75b1 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -73,7 +73,7 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: y = 0 prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): - temp += getattr(n, 'fwd_out') + getattr(n, 'fwd_tmp') + temp += getattr(n, 'fwd_out') y = max(y, temp) if temp > b and n in ckpt_nodes: x += getattr(n, 'fwd_out') From 1c127c698ddd7f59853f173d9963476a32a122d2 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:44:52 +0800 Subject: [PATCH 16/24] [fx] fix compatibility in tests. --- colossalai/fx/passes/meta_info_prop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 813fec3a74f8..1a1e149577c4 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,4 +1,3 @@ -from re import L import torch import torch.fx from torch.fx.node import Node, Argument, Target From 6f486d1c0f5ac1a9d9ee480114124cc6a3b0e785 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:58:42 +0800 Subject: [PATCH 17/24] [fx] fix compatibility in tests. --- colossalai/fx/profiler/experimental/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py index 522d1324e7c7..b6beb76091b0 100644 --- a/colossalai/fx/profiler/experimental/__init__.py +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -1,4 +1,4 @@ from .registry import meta_profiler_function, meta_profiler_module from .profiler_function import * from .profiler_module import * -from .profiler import profile_function, profile_method, profile_module \ No newline at end of file +from .profiler import profile_function, profile_method, profile_module From c1eb532477f05334c3f70ac8164b5ab993fc5ef3 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 16:59:54 +0800 Subject: [PATCH 18/24] [fx] fix compatibility in tests. --- colossalai/fx/profiler/experimental/profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index eb88c05e2fca..46d4add3c5e9 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -9,7 +9,7 @@ CALL_FUNCTION_MSG = \ """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n -from colossalai.fx.profiler import meta_profiler_function +from colossalai.fx.profiler.experimental import meta_profiler_function @meta_profiler_function.register(YOUR_FUNCTION) def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: flops = ... @@ -20,7 +20,7 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: CALL_MODULE_MSG = \ """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n -from colossalai.fx.profiler import meta_profiler_module +from colossalai.fx.profiler.experimental import meta_profiler_module @meta_profiler_module.register(YOUR_MODULE) def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]: flops = ... From b3d144e254e7ae1a1634479d7d77998c4302076e Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:04:50 +0800 Subject: [PATCH 19/24] [fx] fix compatibility in tests. --- colossalai/fx/profiler/memory.py | 82 ++++++++++++++++-------------- colossalai/fx/profiler/profiler.py | 4 +- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index 8ab815f0eadb..be51064220e0 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -3,26 +3,12 @@ from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos from . import META_COMPATIBILITY -__all__ = ['activation_size', 'parameter_size', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] - -# TODO fill out the inplace ops -INPLACE_OPS = [ - add, - sub, - mul, - floordiv, - neg, - pos, - getitem, - setitem, - getattr, - torch.Tensor.cpu, -] +__all__ = ['activation_size', 'parameter_size'] if META_COMPATIBILITY: aten = torch.ops.aten - WEIRD_OP = [ + WEIRD_OPS = [ torch.where, ] @@ -44,30 +30,48 @@ aten._unsafe_view.default, ] -# TODO: list all call_methods that are inplace here -INPLACE_METHOD = [ - 'transpose', - 'permute', + __all__ += ['INPLACE_ATEN', 'WEIRD_OPS'] + +else: + # TODO fill out the inplace ops + INPLACE_OPS = [ + add, + sub, + mul, + floordiv, + neg, + pos, + getitem, + setitem, + getattr, + torch.Tensor.cpu, + ] + + # TODO: list all call_methods that are inplace here + INPLACE_METHOD = [ + 'transpose', + 'permute', # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', - 'type', - 'flatten', -] - -# TODO: list all call_methods that are not inplace here -NON_INPLACE_METHOD = [ - 'chunk', - 'contiguous', - 'expand', - 'mean', - 'split', -] + 'reshape', + 'dim', + 'flatten', + 'size', + 'view', + 'unsqueeze', + 'to', + 'type', + 'flatten', + ] + + # TODO: list all call_methods that are not inplace here + NON_INPLACE_METHOD = [ + 'chunk', + 'contiguous', + 'expand', + 'mean', + 'split', + ] + __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 40c686ca854f..2e5eb7f02fac 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -3,7 +3,7 @@ from torch.fx import Graph from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS, INPLACE_ATEN, WEIRD_OP +from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS, INPLACE_ATEN, WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping @@ -80,7 +80,7 @@ def wrap(x): return tree_map(wrap, out) - if target not in WEIRD_OP: + if target not in WEIRD_OPS: def wrap(x): return FlopTensor( From 15a1bb194c73d6d0f06f9974be5f4fed124c43b6 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:07:58 +0800 Subject: [PATCH 20/24] [fx] fix compatibility in tests. --- colossalai/fx/profiler/profiler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 2e5eb7f02fac..324550c2fe6d 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -153,9 +153,6 @@ def profile_method(target: 'Target') -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. - - Warnings: - Not all `call_method` nodes are inplace. But for sake of simplicity, we mark all of them as inplace. """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: From d011148db007f99f4506656668a87c4bc15526cf Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:08:01 +0800 Subject: [PATCH 21/24] [fx] fix compatibility in tests. --- colossalai/fx/profiler/profiler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 324550c2fe6d..8f94be81bc9a 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -9,8 +9,6 @@ __all__ = ['profile_function', 'profile_module', 'profile_method', '_profile'] -CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' - def normalize_tuple(x): if not isinstance(x, tuple): From 1a45a1203f95fd75b89d6b904b6645d7e137045b Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:09:03 +0800 Subject: [PATCH 22/24] [fx] fix compatibility in tests. --- tests/test_fx/test_comm_size_compute.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index a31878d66606..e4d1ff32ba3a 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -6,14 +6,9 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass from colossalai.fx.passes.utils import get_comm_size +from colossalai import META_COMPATIBILITY import pytest -try: - meta_lib = torch.library.Library("aten", "IMPL", "Meta") - INCOMPATIBLE = False # version > 1.12.0 -except: - INCOMPATIBLE = True - MODEL_DIM = 16 BATCH_SIZE = 8 PIPELINE_SIZE = 2 @@ -36,7 +31,7 @@ def forward(self, x): return x -@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') def test_comm_size_compute(): model = MLP(MODEL_DIM) input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') From cf06d842201151be4afeb320982e041fbd25eb2e Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:25:08 +0800 Subject: [PATCH 23/24] [fx] fix compatibility in tests. --- tests/test_fx/test_meta/test_aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index 4794ea8c61c5..49b97827042a 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -53,7 +53,7 @@ } -def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any: +def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' assert tensor.stride() == meta_tensor.stride( From 359f750c2de5f0819db5426b7a5a8c4f165c74b9 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 6 Sep 2022 17:30:50 +0800 Subject: [PATCH 24/24] [fx] fix import error. --- colossalai/fx/profiler/__init__.py | 2 +- colossalai/fx/profiler/profiler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index c21fde5358dd..1b46bd494a98 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -6,4 +6,4 @@ else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module -from .memory import parameter_size, activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS +from .memory import parameter_size, activation_size diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 8f94be81bc9a..8f9fb92e0ae4 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -3,7 +3,7 @@ from torch.fx import Graph from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map -from .memory import activation_size, NON_INPLACE_METHOD, INPLACE_METHOD, INPLACE_OPS, INPLACE_ATEN, WEIRD_OPS +from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS from .tensor import MetaTensor from .opcount import flop_mapping