
Register ModelOutput as supported torch pytree nodes #26618

Merged
LysandreJik merged 3 commits into huggingface:main from XuehaiPan:model-output-torch-pytree-node
Oct 24, 2023

Conversation

@XuehaiPan
Contributor

@XuehaiPan XuehaiPan commented Oct 5, 2023

What does this PR do?

This registers ModelOutput as a supported torch pytree node. Currently, all subclasses of ModelOutput are already registered by ModelOutput.__init_subclass__(); this PR additionally registers ModelOutput itself as a supported torch pytree node.
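For readers unfamiliar with the mechanism, here is a hedged pure-Python sketch of subclass-time plus module-level registration. SUPPORTED_NODES and register_pytree_node below are illustrative stand-ins, not the real torch.utils._pytree API:

```python
from collections import OrderedDict

# Illustrative stand-ins for the torch.utils._pytree registry and its
# registration helper (the real names differ).
SUPPORTED_NODES = {}

def register_pytree_node(cls, flatten_fn, unflatten_fn):
    SUPPORTED_NODES[cls] = (flatten_fn, unflatten_fn)

class ModelOutput(OrderedDict):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass registers itself as a pytree node on definition.
        register_pytree_node(
            cls,
            lambda d: (list(d.values()), list(d.keys())),
            lambda values, keys: cls(**dict(zip(keys, values))),
        )

# __init_subclass__ never fires for ModelOutput itself, so the base class
# needs an explicit registration -- the gap this PR closes.
register_pytree_node(
    ModelOutput,
    lambda d: (list(d.values()), list(d.keys())),
    lambda values, keys: ModelOutput(**dict(zip(keys, values))),
)

class BaseModelOutput(ModelOutput):
    pass

# Flatten/unflatten roundtrip through the toy registry.
flatten_fn, unflatten_fn = SUPPORTED_NODES[ModelOutput]
values, keys = flatten_fn(ModelOutput(a=1, b=2))  # [1, 2], ['a', 'b']
assert unflatten_fn(values, keys) == ModelOutput(a=1, b=2)
```

With this shape, both `ModelOutput` itself and every subclass end up in the registry, which is the behavior the PR establishes in transformers.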

See also:

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sgugger

@XuehaiPan XuehaiPan force-pushed the model-output-torch-pytree-node branch from 551a9bc to 44a73ce on October 5, 2023 17:53
@LysandreJik
Member

This looks important but I don't have the bandwidth to dive into it, could you take a look if you have bandwidth @ydshieh ?

@ydshieh ydshieh self-requested a review October 12, 2023 15:42
@ydshieh
Collaborator

ydshieh commented Oct 12, 2023

Hi @XuehaiPan,

Could you elaborate a bit more on why we need to register ModelOutput itself: why isn't #25358 enough?

@XuehaiPan
Contributor Author

Could you elaborate a bit more on why we need to register ModelOutput itself: why isn't #25358 enough?

@ydshieh

  1. ModelOutput is a container-like type, a subclass of OrderedDict. This fits the definition of a pytree node rather than a leaf.

  2. The PyTorch upstream provides a context manager to register ModelOutput itself and its subclasses as pytree nodes. When we exit the context manager, its __exit__ method unregisters the types that were registered in the __enter__ method.
    https://github.com/pytorch/pytorch/blob/53a9ac534c777047abc2b6b293f238865592291e/torch/onnx/_internal/fx/dynamo_graph_extractor.py#L31-L103

    We have already registered ModelOutput's subclasses as pytree node types. So in the __enter__ method, it will additionally register ModelOutput as a pytree node type, and unregister it in the __exit__ method. All subclasses of ModelOutput in the pytree node type registry remain untouched. See my comment in [pytree] Make optree optional and populate members from _pytree when not available pytorch/pytorch#109684 (comment).

    I think deleting or updating the pytree node registry is dangerous. It could lead to bugs when the developer passes the treespec object out of the scope of the context manager.

    with PyTreeExtensionContext():
        # in the context:
        # ModelOutput is a container-like node type; it will live in the treespec object
        leaves, treespec = pytree.tree_flatten(tree)
    
    # out of the context:
    # ModelOutput is no longer in the pytree node type registry, so it would be treated as a leaf,
    # but the treespec object still records ModelOutput as a node type
    pytree.tree_unflatten(leaves, treespec)  # -> KeyError: ModelOutput not found in pytree.SUPPORTED_NODES

    I could either not register the ModelOutput type in the context manager in the PyTorch upstream, or register it as a pytree node in the first place in transformers. I think it is more practical to register both ModelOutput and its subclasses as node types (e.g., for the case isinstance(obj, ModelOutput)).
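The dangling-treespec hazard described above can be demonstrated without torch at all. The following is a toy (non-recursive) pytree implementation with illustrative names only, not the real torch.utils._pytree API, showing why unregistering on exit strands any treespec that escaped the context:

```python
from collections import OrderedDict

# Toy registry and flatten/unflatten, standing in for torch.utils._pytree;
# real pytrees also recurse into nested containers.
SUPPORTED_NODES = {}

class ModelOutput(OrderedDict):
    pass

def tree_flatten(tree):
    if type(tree) not in SUPPORTED_NODES:
        return [tree], ("leaf",)  # unregistered types are leaves
    flatten_fn, _ = SUPPORTED_NODES[type(tree)]
    values, context = flatten_fn(tree)
    return values, (type(tree), context)

def tree_unflatten(leaves, treespec):
    if treespec == ("leaf",):
        return leaves[0]
    node_type, context = treespec
    # Raises KeyError if node_type was unregistered after flattening.
    _, unflatten_fn = SUPPORTED_NODES[node_type]
    return unflatten_fn(leaves, context)

# __enter__: register ModelOutput, then flatten inside the "context".
SUPPORTED_NODES[ModelOutput] = (
    lambda d: (list(d.values()), list(d.keys())),
    lambda values, keys: ModelOutput(zip(keys, values)),
)
leaves, treespec = tree_flatten(ModelOutput(a=1, b=2))  # leaves == [1, 2]

# __exit__: unregister -- but treespec still records ModelOutput as a node.
del SUPPORTED_NODES[ModelOutput]

try:
    tree_unflatten(leaves, treespec)
except KeyError as exc:
    print("KeyError:", exc)
```

The `KeyError` at the end mirrors the failure mode in the quoted snippet: the treespec outlives the registration it depends on.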

@ydshieh
Collaborator

ydshieh commented Oct 13, 2023

Hi @XuehaiPan : I am not familiar with this, so would love to (and need) a bit more information 😅

  • In the PyTorch upstream, ModelOutput is already registered as a pytree node?

    • and if so, why do we need to do this in both PyTorch and transformers?
    • if they are not doing exactly the same job, what is the difference between registering in torch and registering in transformers?
  • ... in the __enter__ method, it will additionally register ModelOutput as pytree node type:

@XuehaiPan
Contributor Author

XuehaiPan commented Oct 13, 2023

... in the __enter__ method, it will additionally register ModelOutput as pytree node type:

@ydshieh This is correct. Let me elaborate.

In the PyTorch upstream:

https://github.com/pytorch/pytorch/blob/359336e3e9a0f67974e53805b5207fbbbc149490/torch/onnx/_internal/fx/dynamo_graph_extractor.py#L93-L98

        # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
        named_model_output_classes = inspect.getmembers(
            modeling_outputs,
            lambda x: inspect.isclass(x)
            and issubclass(x, modeling_outputs.ModelOutput),
        )

The model output classes are determined by issubclass(cls, ModelOutput), which treats ModelOutput itself as a model output class.

>>> issubclass(ModelOutput, ModelOutput)
True
In [1]: from transformers import modeling_outputs

In [2]: import inspect

In [3]: named_model_output_classes = inspect.getmembers(
   ...:     modeling_outputs,
   ...:     lambda x: inspect.isclass(x)
   ...:     and issubclass(x, modeling_outputs.ModelOutput),
   ...: )

In [4]: dict(named_model_output_classes)
Out[4]: 
{'BackboneOutput': transformers.modeling_outputs.BackboneOutput,
 ...,
 'MoEModelOutputWithPastAndCrossAttentions': transformers.modeling_outputs.MoEModelOutputWithPastAndCrossAttentions,
 'ModelOutput': transformers.utils.generic.ModelOutput,
 'MultipleChoiceModelOutput': transformers.modeling_outputs.MultipleChoiceModelOutput,
 ...,
 'XVectorOutput': transformers.modeling_outputs.XVectorOutput}

The PyTreeExtensionContext manager ignores types that are already registered in pytree.SUPPORTED_NODES:

https://github.com/pytorch/pytorch/blob/359336e3e9a0f67974e53805b5207fbbbc149490/torch/onnx/_internal/fx/dynamo_graph_extractor.py#L67-L72

        if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
            # PyTree node already registered.
            # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
            # https://github.com/huggingface/transformers/pull/25358.
            return
        self._extensions[class_type] = (flatten_func, unflatten_func)

Without this PR, the upstream code behaves like:

# This command will register all subclasses of `ModelOutput` except `ModelOutput`
# itself as pytree node.
import transformers

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` is not pytree node

# This command will add `ModelOutput` into `context._extension` because it is
# not registered as pytree node.
# All subclasses of `ModelOutput` will be ignored because they are already
# registered as pytree node by transformers.
context = PyTreeExtensionContext()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` is not pytree node
#     - context._extension = {ModelOutput: (..., ...)}

# This command will register elements in `context._extension` (which is only
# `ModelOutput`) as pytree node.
context.__enter__()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` is pytree node
#     - context._extension = {ModelOutput: (..., ...)}

# do something in the context
...

# This command will unregister elements in `context._extension` (which is only
# `ModelOutput`) as pytree node.
context.__exit__()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` is not pytree node

With this PR:

# This command will register all subclasses of `ModelOutput` and `ModelOutput`
# itself as pytree node.
import transformers

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` itself is also pytree node

# All subclasses of `ModelOutput` and `ModelOutput` itself will be ignored
# because they are already registered as pytree node by transformers.
# This is a no-op and `context._extension` is empty.
context = PyTreeExtensionContext()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` itself is also pytree node
#     - context._extension = {}  # empty

# This command will register elements in `context._extension` (which is empty)
# as pytree node. This is a no-op.
context.__enter__()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` is pytree node
#     - context._extension = {}  # empty

# do something in the context
...

# This command will unregister elements in `context._extension` (which is empty)
# as pytree node. This is a no-op.
context.__exit__()

# <<< at here
#     - all subclasses of `ModelOutput` are pytree node
#     - `ModelOutput` itself is also pytree node

@ydshieh
Collaborator

ydshieh commented Oct 13, 2023

@XuehaiPan Thank you a lot. I will take a look next week! This is really new to me 💪 !

@XuehaiPan
Contributor Author

@ydshieh Hi, any update on this?

@ydshieh
Collaborator

ydshieh commented Oct 20, 2023

Hi @XuehaiPan, thank you a lot for your super clean and detailed comment; it helps me a lot to get started. So yes, ModelOutput itself will behave differently depending on whether it is inside the context manager or outside it.

I think deleting or updating the pytree node registry is dangerous

I have 2 (general) questions:

  • I understand what you mentioned. But does this mean the PyTorch upstream has to register all standard container-like classes (list, dict, etc.) as pytree nodes?
    • (I guess so, but want to hear from you 😄 )
  • ModelOutput is not designed to be used directly (well, to be fair, within the transformers code base).
    • All usage should inherit from ModelOutput and specify the attributes an output has; see, for example:

      class BaseModelOutput(ModelOutput):
          """
          Base class for model's outputs, with potential hidden states and attentions.

          Args:
              last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                  Sequence of hidden-states at the output of the last layer of the model.
              hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
                  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
                  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
                  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
              attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                  sequence_length)`.
                  Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                  heads.
          """

          last_hidden_state: torch.FloatTensor = None
          hidden_states: Optional[Tuple[torch.FloatTensor]] = None
          attentions: Optional[Tuple[torch.FloatTensor]] = None

    • (it's kind of an abstract class; we can still use it directly, but that usage is outside the transformers scope.)
    • Moreover, I would say it is super rare that users will use ModelOutput directly (manually) to store their model output (without creating a subclass first).
      • (and even more rare, this usage along with the PyTreeExtensionContext context manager)

Overall, I am fine with the idea of registering ModelOutput as a torch pytree node, but would like the above 2 questions addressed. Then I can review the changes themselves and request a core maintainer's review.

Thank you again @XuehaiPan !

@XuehaiPan
Contributor Author

XuehaiPan commented Oct 20, 2023

@ydshieh

  • I understand what you mentioned. But does this mean the PyTorch upstream has to register all standard container-like classes (list, dict, etc.) as pytree nodes?

    • (I guess so, but want to hear from you 😄 )

Currently, the PyTorch upstream registers:

  • tuple
  • list
  • dict
  • OrderedDict
  • classes created by collections.namedtuple
  • torch.return_types.* (PyStructSequence, e.g., torch.return_types.sort for function torch.sort())
  • torch.fx.immutable_collections.immutable_list
  • torch.fx.immutable_collections.immutable_dict

as pytree nodes.
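To make concrete what node treatment means for these builtin containers, here is a minimal leaf-collection sketch; iter_leaves is a made-up name, and torch's real entry point is torch.utils._pytree.tree_flatten:

```python
from collections import OrderedDict, namedtuple

def iter_leaves(tree):
    """Yield the leaves of a nested structure of the builtin node types."""
    if isinstance(tree, (list, tuple)):  # tuple also covers namedtuples
        for item in tree:
            yield from iter_leaves(item)
    elif isinstance(tree, dict):  # dict also covers OrderedDict
        for value in tree.values():
            yield from iter_leaves(value)
    else:
        yield tree  # everything else is treated as a leaf

Point = namedtuple("Point", ["x", "y"])
tree = {"a": [1, 2], "b": (Point(3, 4), OrderedDict(c=5))}
print(list(iter_leaves(tree)))  # -> [1, 2, 3, 4, 5]
```

Registering a class as a pytree node effectively adds a branch like the ones above for that class; unregistered classes fall through to the leaf case.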

We are going to register more container-like classes from the stdlib (collections.defaultdict and collections.deque) as pytree node types in follow-up PRs in the PyTorch upstream.

For classes from third-party packages (neither the Python stdlib nor PyTorch internals), users or developers need to register the classes as pytree node types themselves. For example, we manually register the subclasses of ModelOutput in the transformers package.


  • ModelOutput is not designed to be used directly (well, to be fair, within transformers code base).

    • All usage should inherit from ModelOutput and specify the attributes an output has

Moreover, I would say it is super rare that users will use ModelOutput directly (manually) to store their model output (without creating a subclass first).

  • (and even more rare, this usage along with PyTreeExtensionContext context manager)

I agree that is a rare case. So we can resolve this issue in two ways:

  • register ModelOutput as a pytree node in transformers (this is what this PR does).
  • not register ModelOutput as a pytree node in the PyTreeExtensionContext context manager in the PyTorch upstream.

Both solutions are fine for me.

@ydshieh
Collaborator

ydshieh commented Oct 20, 2023

Let's go for "register ModelOutput as a pytree node in transformers (this is what this PR does)" to align it with OrderedDict (its parent). I will review the changes 🤗 today or next Monday!

@ydshieh
Collaborator

ydshieh commented Oct 20, 2023

@XuehaiPan

Could you provide a complete, self-contained (i.e. with all imports and variables defined) code snippet of the following (the one you provided before):

with PyTreeExtensionContext():
    # in the context:
    # ModelOutput is a container-like node type; it will live in the treespec object
    leaves, treespec = pytree.tree_flatten(tree)

# out of the context:
# ModelOutput is no longer in the pytree node type registry, so it would be treated as a leaf,
# but the treespec object still records ModelOutput as a node type
pytree.tree_unflatten(leaves, treespec)  # -> KeyError: ModelOutput not found in pytree.SUPPORTED_NODES

@XuehaiPan
Contributor Author

Here is a self-contained snippet with torch==2.1.0 and transformers==4.34.1.

import torch
import transformers
import torch.utils._pytree as pytree
from torch.onnx._internal.fx.dynamo_graph_extractor import _PyTreeExtensionContext as PyTreeExtensionContext
from transformers.modeling_utils import ModelOutput

print('===== before entering context =====')
leaves, treespec = pytree.tree_flatten(ModelOutput(a=1, b=2))  # ModelOutput is a leaf
print('leaves:', leaves)
print('treespec:', treespec)
print('reconstructed:', pytree.tree_unflatten(leaves, treespec))

with PyTreeExtensionContext():
    print('===== entered context =====')
    leaves, treespec = pytree.tree_flatten(ModelOutput(a=1, b=2))  # ModelOutput is a node
    print('leaves:', leaves)
    print('treespec:', treespec)
    print('reconstructed:', pytree.tree_unflatten(leaves, treespec))

print('===== exited context =====')
print('leaves:', leaves)
print('treespec:', treespec)
print('reconstructed:', pytree.tree_unflatten(leaves, treespec))

Output:

$ python3 test.py             
===== before entering context =====
leaves: [ModelOutput([('a', 1), ('b', 2)])]
treespec: *
reconstructed: ModelOutput([('a', 1), ('b', 2)])
===== entered context =====
leaves: [1, 2]
treespec: TreeSpec(ModelOutput, (<class 'transformers.utils.generic.ModelOutput'>, ['a', 'b']), [*,
  *])
reconstructed: ModelOutput([('a', 1), ('b', 2)])
===== exited context =====
leaves: [1, 2]
treespec: TreeSpec(ModelOutput, (<class 'transformers.utils.generic.ModelOutput'>, ['a', 'b']), [*,
  *])
Traceback (most recent call last):
  File ".../test.py", line 23, in <module>
    print('reconstructed:', pytree.tree_unflatten(leaves, treespec))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/torch/utils/_pytree.py", line 268, in tree_unflatten
    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
                   ~~~~~~~~~~~~~~~^^^^^^^^^^^
KeyError: <class 'transformers.utils.generic.ModelOutput'>

Comment on lines +309 to +312
_torch_pytree._register_pytree_node(
    cls,
-   torch.utils._pytree._dict_flatten,
-   lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
+   _model_output_flatten,
+   _model_output_unflatten,
Collaborator

I am wondering if it is possible to move this to __init__, so we don't need to do it twice: once for the subclasses and once for ModelOutput itself.

Contributor Author

@XuehaiPan XuehaiPan Oct 24, 2023

We need to register the class object rather than the instance object in the pytree registry. The __init__ method takes self, the instance, as its argument, so we need to do this with cls instead. There are two solutions:

  1. Register the subclasses in ModelOutput.__init_subclass__() and register the ModelOutput class itself at module scope. (This PR)

     class Parent:
         def __init_subclass__(cls):
             pytree.register_node(cls, ...)

     class Child1(Parent):
         pass

     class Child2(Parent):
         pass

     class GrandChild(Child1):
         pass

     # __init_subclass__ does not fire for Parent itself, so register it explicitly
     pytree.register_node(Parent, ...)

  2. Implement a metaclass and register the classes (both ModelOutput and its subclasses) in metaclass.__new__().

     class MetaClass(type):
         def __new__(mcls, name, bases, namespace, **kwds):
             cls = super().__new__(mcls, name, bases, namespace, **kwds)
             pytree.register_node(cls, ...)
             return cls

     class Parent(metaclass=MetaClass):
         pass

     class Child1(Parent):
         pass

     class Child2(Parent):
         pass

     class GrandChild(Child1):
         pass
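The difference between the two approaches is easy to verify in plain Python. In this sketch, register_node and the registry lists are illustrative placeholders for the real pytree registration call:

```python
registered = []

def register_node(cls):
    # Placeholder for the real pytree registration call.
    registered.append(cls)

# Solution 1: __init_subclass__ fires for every subclass, but never for the
# class that defines it, so the base class needs an explicit call.
class Parent:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        register_node(cls)

class Child(Parent):
    pass

assert registered == [Child]  # Parent itself was not registered
register_node(Parent)         # the explicit module-level registration

# Solution 2: a metaclass registers the base class and subclasses alike.
meta_registered = []

class Meta(type):
    def __new__(mcls, name, bases, namespace, **kwds):
        cls = super().__new__(mcls, name, bases, namespace, **kwds)
        meta_registered.append(cls)
        return cls

class MetaParent(metaclass=Meta):
    pass

class MetaChild(MetaParent):
    pass

assert meta_registered == [MetaParent, MetaChild]
```

Both end with every class registered; solution 1 just needs the one extra module-level call for the base class, which is exactly what this PR adds.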

Collaborator

Oh, you are right! Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Collaborator

@ydshieh ydshieh left a comment

Thank you for the work and all the detailed explanations, @XuehaiPan

Member

@LysandreJik LysandreJik left a comment

Thanks a lot for the deep dive @XuehaiPan @ydshieh!

@LysandreJik LysandreJik merged commit cc7803c into huggingface:main Oct 24, 2023
@XuehaiPan XuehaiPan deleted the model-output-torch-pytree-node branch October 24, 2023 09:27
i4never pushed a commit to i4never/transformers that referenced this pull request Oct 25, 2023
* Register ModelOutput as supported torch pytree nodes

* Test ModelOutput as supported torch pytree nodes

* Update type hints for pytree unflatten functions
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Register ModelOutput as supported torch pytree nodes

* Test ModelOutput as supported torch pytree nodes

* Update type hints for pytree unflatten functions