Register ModelOutput as supported torch pytree nodes #26618
LysandreJik merged 3 commits into huggingface:main
|
This looks important but I don't have the bandwidth to dive into it, could you take a look if you have bandwidth @ydshieh ? |
|
Hi @XuehaiPan, Could you elaborate a bit more on why we need to register |
|
|
Hi @XuehaiPan : I am not familiar with this, so would love to (and need) a bit more information 😅
|
@ydshieh This is correct. Let me elaborate on this. In the PyTorch upstream:
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
named_model_output_classes = inspect.getmembers(
    modeling_outputs,
    lambda x: inspect.isclass(x)
    and issubclass(x, modeling_outputs.ModelOutput),
)
The model output classes are determined by issubclass, which also holds for ModelOutput itself:
>>> issubclass(ModelOutput, ModelOutput)
True
In [1]: from transformers import modeling_outputs
In [2]: import inspect
In [3]: named_model_output_classes = inspect.getmembers(
...: modeling_outputs,
...: lambda x: inspect.isclass(x)
...: and issubclass(x, modeling_outputs.ModelOutput),
...: )
In [4]: dict(named_model_output_classes)
Out[4]:
{'BackboneOutput': transformers.modeling_outputs.BackboneOutput,
...,
'MoEModelOutputWithPastAndCrossAttentions': transformers.modeling_outputs.MoEModelOutputWithPastAndCrossAttentions,
'ModelOutput': transformers.utils.generic.ModelOutput,
'MultipleChoiceModelOutput': transformers.modeling_outputs.MultipleChoiceModelOutput,
...,
'XVectorOutput': transformers.modeling_outputs.XVectorOutput}
In the PyTreeExtensionContext manager, it will ignore types that are already registered in the pytree node registry:
if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
    # PyTree node already registered.
    # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
    # https://github.com/huggingface/transformers/pull/25358.
    return
self._extensions[class_type] = (flatten_func, unflatten_func)
Without this PR, the upstream code behaves like:
# This command will register all subclasses of `ModelOutput` except `ModelOutput`
# itself as pytree node.
import transformers
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` is not pytree node
# This command will add `ModelOutput` into `context._extension` because it is
# not registered as pytree node.
# All subclasses of `ModelOutput` will be ignored because they are already
# registered as pytree node by transformers.
context = PyTreeExtensionContext()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` is not pytree node
# - context._extension = {ModelOutput: (..., ...)}
# This command will register elements in `context._extension` (which is only
# `ModelOutput`) as pytree node.
context.__enter__()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` is pytree node
# - context._extension = {ModelOutput: (..., ...)}
# do something in the context
...
# This command will unregister elements in `context._extension` (which is only
# `ModelOutput`) as pytree node.
context.__exit__()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` is not pytree node
With this PR:
# This command will register all subclasses of `ModelOutput` and `ModelOutput`
# itself as pytree node.
import transformers
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` itself is also pytree node
# All subclasses of `ModelOutput` and `ModelOutput` itself will be ignored
# because they are already registered as pytree node by transformers.
# This is a no-op and `context._extension` is empty.
context = PyTreeExtensionContext()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` itself is also pytree node
# - context._extension = {} # empty
# This command will register elements in `context._extension` (which is empty)
# as pytree node. This is a no-op.
context.__enter__()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` is pytree node
# - context._extension = {} # empty
# do something in the context
...
# This command will unregister elements in `context._extension` (which is empty)
# as pytree node. This is a no-op.
context.__exit__()
# <<< at here
# - all subclasses of `ModelOutput` are pytree node
# - `ModelOutput` itself is also pytree node |
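The two walkthroughs above can be condensed into a small, torch-free sketch. It is only a toy model under stated assumptions: `SUPPORTED_NODES`, `Base`, `Child`, and `ToyExtensionContext` are hypothetical stand-ins for the pytree registry, `ModelOutput`, one of its subclasses, and the upstream context manager, and the real upstream implementation may differ in detail.

```python
# Toy stand-ins; none of these names are PyTorch or transformers APIs.
SUPPORTED_NODES = {}  # plays the role of pytree.SUPPORTED_NODES


class Base(dict):
    """Stand-in for ModelOutput."""


class Child(Base):
    """Stand-in for a ModelOutput subclass."""


class ToyExtensionContext:
    """Registers missing types on enter and removes only those on exit."""

    def __init__(self, candidates):
        self._extensions = {}
        for class_type in candidates:
            if class_type in SUPPORTED_NODES or class_type in self._extensions:
                continue  # already registered elsewhere: leave it alone
            self._extensions[class_type] = ("flatten", "unflatten")

    def __enter__(self):
        SUPPORTED_NODES.update(self._extensions)
        return self

    def __exit__(self, *exc_info):
        for class_type in self._extensions:
            SUPPORTED_NODES.pop(class_type)
        return False


def registered_during_and_after(pre_registered):
    """Return (Base registered inside the context, Base registered after it)."""
    SUPPORTED_NODES.clear()
    for class_type in pre_registered:
        SUPPORTED_NODES[class_type] = ("flatten", "unflatten")
    # issubclass(Base, Base) is True, so the base class is always a candidate.
    candidates = [c for c in (Base, Child) if issubclass(c, Base)]
    with ToyExtensionContext(candidates):
        inside = Base in SUPPORTED_NODES
    return inside, Base in SUPPORTED_NODES


without_pr = registered_during_and_after([Child])       # only subclasses registered
with_pr = registered_during_and_after([Child, Base])    # base class registered too
print(without_pr, with_pr)
```

In this toy model the base class is a node only while the context is active when it was not registered up front, and it stays a node throughout when it was.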
|
@XuehaiPan Thank you a lot. I will take a look next week! This is really new to me 💪 ! |
|
@ydshieh Hi, any update for this? |
|
Hi @XuehaiPan, thank you a lot for your super clean and detailed comment, it helps me a lot to get started. So yes,
I have 2 (general) questions:
Overall, I am fine with the idea of registering Thank you again @XuehaiPan ! |
Currently, the PyTorch upstream registers:
as pytree nodes. We are going to add more container-like classes in the stdlib ( For classes from third-party packages (non-Python stdlib, and non-PyTorch internals), users or developers need to manually register the classes as pytree types themselves. For example, we manually register the subclasses of ModelOutput in the transformers package.
I agree that is a rare case. So we can resolve this issue in two ways:
Both solutions are fine for me. |
|
Let's go for |
|
Could you provide a complete, self-contained (i.e. with all imports and variables defined) code snippet of the following (the one you provided before):
with PyTreeExtensionContext():
    # in the context
    # ModelOutput is a container-like node type, it will live in the treespec object
    leaves, treespec = pytree.tree_flatten(tree)
# out of the context
# ModelOutput is not in the pytree node type registry and it will be considered as leaf
# but in the treespec object, ModelOutput is a node type
pytree.tree_unflatten(leaves, treespec)  # -> KeyError: cannot find ModelOutput in pytree.SUPPORTED_NODES |
|
Here is a self-contained snippet:
import torch
import transformers
import torch.utils._pytree as pytree
from torch.onnx._internal.fx.dynamo_graph_extractor import _PyTreeExtensionContext as PyTreeExtensionContext
from transformers.modeling_utils import ModelOutput
print('===== before entering context =====')
leaves, treespec = pytree.tree_flatten(ModelOutput(a=1, b=2)) # ModelOutput is a leaf
print('leaves:', leaves)
print('treespec:', treespec)
print('reconstructed:', pytree.tree_unflatten(leaves, treespec))
with PyTreeExtensionContext():
    print('===== entered context =====')
    leaves, treespec = pytree.tree_flatten(ModelOutput(a=1, b=2))  # ModelOutput is a node
    print('leaves:', leaves)
    print('treespec:', treespec)
    print('reconstructed:', pytree.tree_unflatten(leaves, treespec))
print('===== exited context =====')
print('leaves:', leaves)
print('treespec:', treespec)
print('reconstructed:', pytree.tree_unflatten(leaves, treespec))
Output:
$ python3 test.py
===== before entering context =====
leaves: [ModelOutput([('a', 1), ('b', 2)])]
treespec: *
reconstructed: ModelOutput([('a', 1), ('b', 2)])
===== entered context =====
leaves: [1, 2]
treespec: TreeSpec(ModelOutput, (<class 'transformers.utils.generic.ModelOutput'>, ['a', 'b']), [*,
*])
reconstructed: ModelOutput([('a', 1), ('b', 2)])
===== exited context =====
leaves: [1, 2]
treespec: TreeSpec(ModelOutput, (<class 'transformers.utils.generic.ModelOutput'>, ['a', 'b']), [*,
*])
Traceback (most recent call last):
File ".../test.py", line 23, in <module>
print('reconstructed:', pytree.tree_unflatten(leaves, treespec))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/torch/utils/_pytree.py", line 268, in tree_unflatten
unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
~~~~~~~~~~~~~~~^^^^^^^^^^^
KeyError: <class 'transformers.utils.generic.ModelOutput'> |
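The failure above does not depend on torch itself; it follows from looking up the node type in a registry at unflatten time. A minimal sketch of that mechanism, assuming a one-level tree and using hypothetical names (`REGISTRY`, `Output`, `flatten`, `unflatten`) that are not part of `torch.utils._pytree`:

```python
from collections import OrderedDict

# Toy registry and helpers; hypothetical names, not the torch.utils._pytree API.
REGISTRY = {}


class Output(OrderedDict):
    """Stand-in for ModelOutput."""


def flatten(tree):
    """Return (leaves, spec); an unregistered type is a single leaf."""
    if type(tree) not in REGISTRY:
        return [tree], None
    return list(tree.values()), (type(tree), list(tree.keys()))


def unflatten(leaves, spec):
    """Rebuild a tree; a node spec needs its type to still be registered."""
    if spec is None:
        return leaves[0]
    node_type, keys = spec
    rebuild = REGISTRY[node_type]  # KeyError once the type is unregistered
    return rebuild(keys, leaves)


REGISTRY[Output] = lambda keys, values: Output(zip(keys, values))  # "enter"
leaves, spec = flatten(Output(a=1, b=2))
inside = unflatten(leaves, spec)
del REGISTRY[Output]  # "exit": the spec still names Output as a node type

try:
    unflatten(leaves, spec)
    failed = False
except KeyError:
    failed = True
print(leaves, inside, failed)
```

The spec outlives the registration, which is why keeping the base class registered permanently avoids the error in this sketch.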
  _torch_pytree._register_pytree_node(
      cls,
-     torch.utils._pytree._dict_flatten,
-     lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
+     _model_output_flatten,
+     _model_output_unflatten,
I am wondering if it is possible to move this to __init__, so we don't need to do it twice: one for subclasses and one for ModelOutput itself.
We need to register the class object rather than the instance object to the pytree registry. The __init__ method takes self, the instance, as the argument. We need to do this with cls (via a classmethod). There are two solutions:
- Register the subclasses in ModelOutput.__init_subclass__() and register the ModelOutput class in the global scope. (This PR)
class Parent:
    def __init_subclass__(cls):
        pytree.register_node(cls, ...)
class Child1(Parent):
    pass
class Child2(Parent):
    pass
class GrandChild(Child1):
    pass
pytree.register_node(Parent, ...)
- Implement a metaclass and register the classes (both ModelOutput and its subclasses) in metaclass.__new__().
class MetaClass(type):
    def __new__(mcls, name, bases, dict, **kwds):
        cls = super().__new__(mcls, name, bases, dict, **kwds)
        pytree.register_node(cls, ...)
        return cls
class Parent(metaclass=MetaClass):
    pass
class Child1(Parent):
    pass
class Child2(Parent):
    pass
class GrandChild(Child1):
    pass
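A runnable sketch of the two options, to show which classes each hook actually sees. The `register_node` helper and the `REGISTERED` list are hypothetical placeholders for a real pytree registration call, and the class names are illustrative only.

```python
# Toy registry; `register_node` is a hypothetical helper, not a pytree API.
REGISTERED = []


def register_node(cls):
    REGISTERED.append(cls.__name__)


class Parent:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        register_node(cls)  # runs for subclasses only, never for Parent


class Child(Parent):
    pass


class GrandChild(Child):
    pass


via_init_subclass = list(REGISTERED)
register_node(Parent)  # option 1 needs this extra module-level call

REGISTERED.clear()


class RegisteringMeta(type):
    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        register_node(cls)  # runs for every class built by this metaclass
        return cls


class MetaParent(metaclass=RegisteringMeta):
    pass


class MetaChild(MetaParent):
    pass


via_metaclass = list(REGISTERED)
print(via_init_subclass, via_metaclass)
```

The `__init_subclass__` hook never fires for the class that defines it, so the base class needs one explicit registration; the metaclass covers the base class as well, at the cost of introducing a metaclass.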
Oh, you are right! Thanks!
|
ydshieh left a comment
Thank you for the work and all the detailed explanations, @XuehaiPan
LysandreJik left a comment
Thanks a lot for the deep dive @XuehaiPan @ydshieh!
* Register ModelOutput as supported torch pytree nodes * Test ModelOutput as supported torch pytree nodes * Update type hints for pytree unflatten functions
What does this PR do?
This registers ModelOutput as supported torch pytree nodes. Currently all subclasses of ModelOutput are already registered by ModelOutput.__init_subclass__(). This PR additionally registers ModelOutput itself as a supported torch pytree node.
See also:
optree optional and populate members from _pytree when not available pytorch/pytorch#109684 (comment)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sgugger