From 62c0a49f0213c0f5e6d414fdcbe56c6dbf4dcefd Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 13:54:56 +0800 Subject: [PATCH 1/4] [fx] support meta tracing for aten level computation graphs like functorch. --- colossalai/__init__.py | 5 + .../{fx/profiler => }/_meta_registrations.py | 71 +++++--------- colossalai/fx/profiler/__init__.py | 5 - colossalai/fx/tracer/__init__.py | 1 + colossalai/fx/tracer/_meta_trace.py | 98 +++++++++++++++++++ 5 files changed, 131 insertions(+), 49 deletions(-) rename colossalai/{fx/profiler => }/_meta_registrations.py (84%) create mode 100644 colossalai/fx/tracer/_meta_trace.py diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 697b73a74a9c..a28d8b607d0f 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,3 +1,8 @@ +try: + from ._meta_registrations import * +except: + import torch + print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) diff --git a/colossalai/fx/profiler/_meta_registrations.py b/colossalai/_meta_registrations.py similarity index 84% rename from colossalai/fx/profiler/_meta_registrations.py rename to colossalai/_meta_registrations.py index 7dd3a21c9631..e645ef87a220 100644 --- a/colossalai/fx/profiler/_meta_registrations.py +++ b/colossalai/_meta_registrations.py @@ -5,7 +5,6 @@ import torch from torch.utils._pytree import tree_map - aten = torch.ops.aten meta_lib = torch.library.Library("aten", "IMPL", "Meta") @@ -14,15 +13,13 @@ def register_meta(op, register_dispatcher=True): + def wrapper(f): + def add_func(op): meta_table[op] = f if register_dispatcher: - name = ( - op.__name__ - if op._overloadname != "default" - else op.overloadpacket.__name__ - ) + name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) meta_lib.impl(name, f) tree_map(add_func, op) @@ -44,6 +41,7 @@ def meta_conv( output_padding: List[int], groups: int, ): + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -120,14 +118,9 @@ def calc_conv_nd_return_shape( kernel_size[i], stride[i], output_padding_list[i], - ) - ) + )) else: - ret_shape.append( - _formula( - dims[i], padding[i], dilation[i], kernel_size[i], stride[i] - ) - ) + ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape def pick_memory_format(): @@ -156,20 +149,16 @@ def pick_memory_format(): out_channels = weight.shape[0] if weight.shape[1] != input_tensor.shape[1] / groups: raise RuntimeError("Invalid channel dimensions") - shape_out = calc_conv_nd_return_shape( - dims, kernel_size, stride, padding, dilation - ) + shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten.convolution_backward.default) -def meta_conv_backward( - grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask -): +def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, + padding, dilation, transposed, output_padding, groups, output_mask): return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') @@ -184,21 +173,20 @@ def meta_hardswish(input: torch.Tensor): @register_meta(aten.hardswish_backward.default) -def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor): +def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): grad_in = torch.empty_like(input) return grad_in -@register_meta([aten.roll.default, ]) -def meta_roll(input:torch.Tensor, shifts, dims): +@register_meta([ + aten.roll.default, +]) +def meta_roll(input: torch.Tensor, shifts, dims): return torch.empty_like(input) @register_meta(aten.native_batch_norm.default) -def meta_bn( - input: torch.Tensor, - weight, bias, running_mean, running_var, training, momentum, eps -): +def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) output = torch.empty_like(input) @@ -208,10 +196,8 @@ def meta_bn( @register_meta(aten.native_batch_norm_backward.default) -def meta_bn_backward( - dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - running_mean, running_var, save_mean, save_invstd, train, eps, output_mask -): +def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, + save_invstd, train, eps, output_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -219,10 +205,7 @@ def meta_bn_backward( @register_meta(aten.native_layer_norm.default) -def meta_ln( - input: torch.Tensor, - normalized_shape, weight, bias, eps -): +def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): n_input = input.size(1) output = torch.empty_like(input) @@ -232,11 +215,8 @@ def meta_ln( @register_meta(aten.native_layer_norm_backward.default) -def meta_ln_backward( - dY: torch.Tensor, - input: torch.Tensor, - normalized_shape, mean, rstd, weight, bias, grad_input_mask -): +def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, + grad_input_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) @@ -245,7 +225,8 @@ def meta_ln_backward( @register_meta(aten._adaptive_avg_pool2d_backward.default) def meta_adaptive_avg_pool2d_backward( - grad_output: torch.Tensor, input: torch.Tensor, + grad_output: torch.Tensor, + input: torch.Tensor, ): grad_input = torch.empty_like(input) return torch.empty_like(input) @@ -266,7 +247,9 @@ def meta_index_Tensor(self, indices): k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert index.shape[j] == self.shape[ + k + + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) @@ -275,7 +258,7 @@ def meta_index_Tensor(self, indices): indices = result assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace - import torch._refs as refs # avoid import cycle in mypy + import torch._refs as refs # avoid import cycle in mypy indices = list(refs._maybe_broadcast(*indices)) # add missing null tensors diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 4b90bcb30b7a..9d657ad22030 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,8 +1,3 @@ -try: - from ._meta_registrations import * -except: - import torch - print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .meta_tensor import MetaTensor from .registry import meta_profiler_function, meta_profiler_module from .profiler_function import * diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py index ec6508a3040e..327e1510e0b5 100644 --- a/colossalai/fx/tracer/__init__.py +++ b/colossalai/fx/tracer/__init__.py @@ -1 +1,2 @@ from .tracer import ColoTracer +from ._meta_trace import meta_trace diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py new file mode 100644 index 000000000000..6f4fdd6f1c70 --- /dev/null +++ b/colossalai/fx/tracer/_meta_trace.py @@ -0,0 +1,98 @@ +import torch +from torch.fx import Node, Graph +from torch.fx.graph import _Namespace +from torch.utils._pytree import tree_map +import colossalai.fx.profiler + +from typing import Tuple, Dict, Any, Optional + +import torchvision.models as tm + + +def meta_trace(module: torch.nn.Module, *args, **kwargs): + """Trace forward and backward graph with MetaTensor + + Args: + module (torch.nn.Module): The target module for tracing. + + Returns: + graph (torch.fx.Graph): The computation graph. + """ + graph = Graph() + namespace = _Namespace() + + class MetaProxy(torch.Tensor): + """ + A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. + """ + + _tensor: torch.Tensor + _node: Node + + __slots__ = ['_tensor', '_node'] + + @staticmethod + def __new__(cls, tensor, placeholder=False, name=None): + r = torch.Tensor._make_wrapper_subclass( + cls, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + device='cpu', + requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + r._tensor = tensor + if placeholder: + if name is None: + name = 'input' + r._node = graph.create_node('placeholder', + 'placeholder', (graph._root,), + name=namespace.create_name(name, tensor)) + # ...the real tensor is held as an element on the tensor. + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def unwrap(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): + x = MetaProxy(x) + return x._tensor.to('meta') if isinstance(x, MetaProxy) else x + + def get_node(x): + if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): + x = MetaProxy(x, placeholder=True, name='weight') + return x if not hasattr(x, '_node') else x._node + + args_node = tree_map(get_node, args) + kwargs_node = tree_map(get_node, kwargs) + node = graph.create_node('call_function', func, args_node, kwargs_node) + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + + # run aten for backend=CPU but actually on backend=Meta + out = func(*args, **kwargs) + + # Now, we want to continue propagating this tensor, so we rewrap Tensors in + # our custom tensor subclass + def wrap(x): + return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + + def set_node(x): + x._node = node + + out = tree_map(wrap, out) + tree_map(set_node, out) + + return out + + def wrap(x): + return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + + module(*args, **kwargs).sum().backward() + return graph From 11df5d86398997be38a18b969bdf8251c33d7322 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 13:56:15 +0800 Subject: [PATCH 2/4] [fx] support meta tracing for aten level computation graphs like functorch. --- colossalai/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/__init__.py b/colossalai/__init__.py index a28d8b607d0f..b5fff7469a62 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -6,4 +6,4 @@ from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) -__version__ = '0.0.1' +__version__ = '0.1.9' From cf6621ff008b10b3cf7ba1c3c269d955b67d61a8 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 2 Sep 2022 14:19:53 +0800 Subject: [PATCH 3/4] [fx] remove redundant import. --- colossalai/fx/__init__.py | 2 +- colossalai/fx/tracer/_meta_trace.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index 6513f6d03180..6d0475f70f60 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1,2 +1,2 @@ -from .tracer import ColoTracer +from .tracer import ColoTracer, meta_trace from .graph_module import ColoGraphModule diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 6f4fdd6f1c70..5bbe1aba5668 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -2,14 +2,9 @@ from torch.fx import Node, Graph from torch.fx.graph import _Namespace from torch.utils._pytree import tree_map -import colossalai.fx.profiler -from typing import Tuple, Dict, Any, Optional -import torchvision.models as tm - - -def meta_trace(module: torch.nn.Module, *args, **kwargs): +def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: """Trace forward and backward graph with MetaTensor Args: From a151e393c234c4df96c27f0f45e046d23ea18062 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Sat, 3 Sep 2022 09:36:28 +0800 Subject: [PATCH 4/4] [fx] add docstring. --- colossalai/fx/tracer/_meta_trace.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 5bbe1aba5668..48b3e2debe24 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -12,6 +12,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: Returns: graph (torch.fx.Graph): The computation graph. + + Usage: + >>> import torchvision.models as tm + >>> model = tm.alexnet() + >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224)) + >>> graph.print_tabular() """ graph = Graph() namespace = _Namespace()