Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion colossalai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
try:
from ._meta_registrations import *
except:
import torch
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)

__version__ = '0.0.1'
__version__ = '0.1.9'
Comment thread
YuliangLiu0306 marked this conversation as resolved.
2 changes: 1 addition & 1 deletion colossalai/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .tracer import ColoTracer
from .tracer import ColoTracer, meta_trace
from .graph_module import ColoGraphModule
5 changes: 0 additions & 5 deletions colossalai/fx/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
try:
from ._meta_registrations import *
except:
import torch
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
Expand Down
1 change: 1 addition & 0 deletions colossalai/fx/tracer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .tracer import ColoTracer
from ._meta_trace import meta_trace
99 changes: 99 additions & 0 deletions colossalai/fx/tracer/_meta_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map


def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
"""Trace forward and backward graph with MetaTensor

Args:
module (torch.nn.Module): The target module for tracing.

Returns:
graph (torch.fx.Graph): The computation graph.

Usage:
>>> import torchvision.models as tm
>>> model = tm.alexnet()
>>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))
>>> graph.print_tabular()
"""
graph = Graph()
namespace = _Namespace()

class MetaProxy(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
"""

_tensor: torch.Tensor
_node: Node

__slots__ = ['_tensor', '_node']

@staticmethod
def __new__(cls, tensor, placeholder=False, name=None):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
device='cpu',
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
name = 'input'
r._node = graph.create_node('placeholder',
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
# ...the real tensor is held as an element on the tensor.
return r

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaProxy(x)
return x._tensor.to('meta') if isinstance(x, MetaProxy) else x

def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
x = MetaProxy(x, placeholder=True, name='weight')
return x if not hasattr(x, '_node') else x._node

args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)

# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)

# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

def set_node(x):
x._node = node

out = tree_map(wrap, out)
tree_map(set_node, out)

return out

def wrap(x):
return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x

args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)

module(*args, **kwargs).sum().backward()
return graph