[fx] hack __torch_dispatch__ for meta tensor and autograd.#1515

Merged
FrankLeeeee merged 19 commits into hpcaitech:main from super-dainiu:feature/meta_profiler
Aug 31, 2022

Conversation

@super-dainiu
Contributor

@super-dainiu super-dainiu commented Aug 29, 2022

What's new?

When I tried to run autograd with meta tensor inputs on vit_b_16, I found that some aten ops are not registered for the meta backend. Following the suggestions in "Function to automatically calculate Conv shape" (pytorch/pytorch issue #79512), I tried to patch native_layer_norm.default for the meta backend:

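import torch

aten = torch.ops.aten
# register_meta: decorator that registers a Python kernel for the meta backend (defined elsewhere in this PR)

@register_meta(aten.native_layer_norm.default)
def meta_ln(
    input: torch.Tensor,
    normalized_shape, weight, bias, eps
):
    # Shape inference only. native_layer_norm returns (output, mean, rstd); the
    # statistics keep the leading dimensions and use size 1 for the normalized ones.
    axis = input.dim() - len(normalized_shape)
    stat_shape = list(input.shape[:axis]) + [1] * len(normalized_shape)

    output = torch.empty_like(input)
    mean = torch.empty(stat_shape, device='meta')
    rstd = torch.empty(stat_shape, device='meta')
    return output, mean, rstd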

@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(
    dY: torch.Tensor,
    input: torch.Tensor, 
    normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(bias)
    return dX, dgamma, dbeta
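
A meta kernel never touches data; its only job is to return outputs with the right shapes. The following is a minimal, PyTorch-free sketch of that shape rule for layer norm. The helper name `layer_norm_meta_shapes` is hypothetical, and the sketch assumes the statistics keep the normalized dimensions as size 1, which should be checked against the PyTorch version in use.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def layer_norm_meta_shapes(
    input_shape: Sequence[int], normalized_shape: Sequence[int]
) -> Tuple[Shape, Shape, Shape]:
    """Return the (output, mean, rstd) shapes for a layer norm call.

    Hypothetical helper for illustration; assumes mean/rstd keep the
    normalized dimensions with size 1.
    """
    input_shape = tuple(input_shape)
    normalized_shape = tuple(normalized_shape)
    # Index of the first normalized dimension.
    axis = len(input_shape) - len(normalized_shape)
    if axis < 0 or input_shape[axis:] != normalized_shape:
        raise ValueError("normalized_shape must match the trailing input dims")
    stat_shape = input_shape[:axis] + (1,) * len(normalized_shape)
    return input_shape, stat_shape, stat_shape


# ViT-B/16 style activation: batch 8, 197 tokens, hidden size 768.
out, mean, rstd = layer_norm_meta_shapes((8, 197, 768), (768,))
```

Here `out` has the input's shape, while `mean` and `rstd` hold one entry per normalized row.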

However, even when the registration succeeds, the autograd dispatcher refuses to use my patched op on the meta backend:

RuntimeError: 0 INTERNAL ASSERT FAILED at "../aten/src/ATen/core/boxing/KernelFunction.cpp":23, please report a bug to PyTorch. aten::native_layer_norm has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering (see Note [Ambiguity in AutogradOther kernel]). If you want to override CompositeImplicitAutograd, please open an issue to request a dedicated Autograd dispatch key for the backend.   
If you only want to run inference instead of training, add `c10::InferenceMode mode;` before model.forward(). Note this guard is only available in C++ but not Python at present.  

As discussed in "CompositeImplicitAutograd operators should not perform operations that do not dispatch" (pytorch/pytorch issue #61669), this CompositeImplicitAutograd failure is unavoidable on PyTorch 1.12.0 and below. I worked around it with a different way of running autograd on meta tensors:

import torch
from torch.utils._pytree import tree_map


class MetaTensor(torch.Tensor):
    """Wrapper subclass that reports device='cpu' but stores a meta tensor."""

    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            dtype=elem.dtype, layout=elem.layout,
            device='cpu', requires_grad=elem.requires_grad
        )    # deceive the frontend for aten selections
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}    # kwargs may be None; func(**None) would fail

        def unwrap(x):
            # Hand the underlying meta tensor to the aten op.
            return x.elem.to('meta') if isinstance(x, MetaTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        out = func(*args, **kwargs)

        def wrap(x):
            # Re-wrap tensor outputs so later ops keep dispatching here.
            return MetaTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)

Since the PyTorch team has already implemented many aten ops for the meta backend, we can simply hack the autograd dispatcher by making it believe the tensors live on CPU. The dispatcher then no longer picks the CompositeImplicitAutograd path, and our patched ops are used for the meta backend. We can now run forward and backward with meta tensors only, and trace a large model with batch_size=1e10 in milliseconds.

from torchvision.models import vit_b_16    # assuming torchvision's ViT-B/16

model = vit_b_16()
data = MetaTensor(torch.rand(int(1e10), 3, 224, 224, device='meta'))
model.to('meta')(data).sum().backward()
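
The unwrap, call, re-wrap flow of __torch_dispatch__ can be illustrated without PyTorch. The sketch below is a toy model, not the PR's implementation: `ShapeOnly`, `Wrapper`, `dispatch`, and the local `tree_map` and `matmul` are hypothetical stand-ins for a meta tensor, MetaTensor, the dispatch hook, `torch.utils._pytree.tree_map`, and an aten op. It shows why a batch size of 1e10 costs nothing: only shapes flow through the call.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ShapeOnly:
    """Toy stand-in for a meta tensor: it carries a shape and no data."""
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Wrapper:
    """Toy stand-in for MetaTensor: reports 'cpu' but holds a ShapeOnly."""
    elem: ShapeOnly
    device: str = "cpu"


def tree_map(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every leaf of nested lists, tuples and dicts."""
    if isinstance(obj, (list, tuple)):
        return type(obj)(tree_map(fn, o) for o in obj)
    if isinstance(obj, dict):
        return {k: tree_map(fn, v) for k, v in obj.items()}
    return fn(obj)


def dispatch(func: Callable[..., Any], args: tuple = (), kwargs: Any = None) -> Any:
    """Unwrap inputs, run func on shape-only values, re-wrap the outputs."""
    kwargs = kwargs or {}

    def unwrap(x: Any) -> Any:
        return x.elem if isinstance(x, Wrapper) else x

    def wrap(x: Any) -> Any:
        return Wrapper(x) if isinstance(x, ShapeOnly) else x

    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
    return tree_map(wrap, out)


def matmul(a: ShapeOnly, b: ShapeOnly) -> ShapeOnly:
    """Shape rule for a 2-D matrix product; no element arithmetic happens."""
    if a.shape[1] != b.shape[0]:
        raise ValueError("inner dimensions do not match")
    return ShapeOnly((a.shape[0], b.shape[1]))


# A "batch" of 1e10 rows is free because only the shape is stored.
x = Wrapper(ShapeOnly((10**10, 768)))
w = Wrapper(ShapeOnly((768, 1000)))
y = dispatch(matmul, (x, w))
```

The result `y` is again a `Wrapper`, so any later call keeps going through `dispatch`, which mirrors how re-wrapping outputs as MetaTensor keeps subsequent ops on the hooked path.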

With __torch_dispatch__, I replaced the previous patch-based tracing with real meta tracing.

Concerns

Note that __torch_dispatch__ is not compatible with PyTorch 1.11.0 and below.

Comment thread colossalai/fx/profiler/__init__.py
Comment thread colossalai/fx/passes/meta_info_prop.py
Contributor

@Cypher30 Cypher30 left a comment

Great work! As we discussed, you could attach the MetaTensor to the proxy and try to record the aten ops inside the MetaTensor data structure~

Contributor

@Cypher30 Cypher30 left a comment

We could pass this PR first~

Contributor

@FrankLeeeee FrankLeeeee left a comment

You should fix the broken unit tests before merging this PR.

Comment thread colossalai/fx/profiler/meta_tensor.py Outdated
Comment thread colossalai/fx/profiler/__init__.py Outdated
Comment thread colossalai/fx/profiler/_meta_registrations.py
Comment thread colossalai/fx/profiler/meta_tensor.py

3 participants