Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature #34883
… improvements. Defaults to 'false', thus no behavioral changes.
qubvel
left a comment
Hi @daniel-bogdoll, thanks for adding this! It looks great to me. Do you think it might be worth extending the same option to BatchFeature to ensure consistent capabilities?
Thanks @qubvel, sure thing! Which tests would I need to run to make sure modifications in the to() function of BatchFeature get tested? Just to make sure, I assume you refer to ?
Yes, I refer to this one, but not sure it's properly tested anywhere, I was able to find only
Maybe we can do it as simply as:

    non_blocking = kwargs.get("non_blocking", False)
    ...
    elif isinstance(v, torch.Tensor) and device is not None:
        new_data[k] = v.to(device=device, non_blocking=non_blocking)
    ...
That's how I would have tried it as well. But what about this block? Here device is derived from
I don't think so, maybe at some moment, it is worth refactoring this method for more explicit args and kwargs. For now, we can add a note in docstring that
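The suggestion in this exchange (read `non_blocking` from `kwargs` and forward it to each tensor) can be sketched roughly as follows. This is a minimal illustration, not the library code: `TinyBatch` is a hypothetical stand-in for BatchFeature, and the way `device` is picked from the arguments is a deliberate simplification of the real method.

```python
import torch


class TinyBatch:
    """Hypothetical, simplified stand-in for BatchFeature (not the library class)."""

    def __init__(self, data):
        self.data = dict(data)

    def to(self, *args, **kwargs):
        # Read the optional flag from kwargs; if absent, keep the old blocking behavior.
        non_blocking = kwargs.get("non_blocking", False)
        # Simplification: take the device from the keyword or the first positional argument.
        device = kwargs.get("device", args[0] if args else None)

        new_data = {}
        for k, v in self.data.items():
            if isinstance(v, torch.Tensor) and device is not None:
                new_data[k] = v.to(device=device, non_blocking=non_blocking)
            else:
                # Non-tensor values (and the no-device case) pass through unchanged.
                new_data[k] = v
        self.data = new_data
        return self


batch = TinyBatch({"pixel_values": torch.zeros(2, 3), "labels": ["cat", "dog"]})
batch.to("cpu", non_blocking=True)
```

Because the flag is looked up with a default of `False`, callers that never pass it see the same behavior as before.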
@qubvel Done! Thanks for the super-fast replies, was a pleasure! Tests fail now, though: For the first one, as you stated here (#34826 (comment)), it does not seem to be related. As the second one is a timeout issue, it also seems unrelated:
qubvel
left a comment
Thanks for updates! Looks great, just a small suggestion
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
@ArthurZucker or @LysandreJik please review when you have bandwidth
ArthurZucker
left a comment
Yeah, sounds super good!
         return self

-    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
@qubvel suggested this to enforce it as a keyword-only argument for future backwards compatibility. All arguments after the * are forced to be passed as keyword arguments: #34883 (comment)
Yes, with * introduced, only device can be passed as a positional argument. This way, we will prevent anyone from using batch_feature.to("cuda", True) instead of batch_feature.to("cuda", non_blocking=True). This would be useful in case we introduce more positional arguments in the future or need to change the order, for example, when adding dtype.
Thanks for explaining, good decision @qubvel ! 🤗
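The keyword-only behavior discussed in this thread can be checked with a few lines of plain Python. The `Batch` class below is hypothetical and does no device transfer; only the shape of its `to()` signature follows the one under review.

```python
class Batch:
    """Hypothetical minimal class; only the signature shape follows the reviewed change."""

    def to(self, device, *, non_blocking=False):
        # Return the parsed arguments so the effect of each call is visible.
        return device, non_blocking


batch = Batch()
result = batch.to("cuda", non_blocking=True)  # keyword form is accepted

try:
    batch.to("cuda", True)  # positional form is rejected because of the bare *
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

Since the second positional slot is never claimed by `non_blocking`, a later parameter such as `dtype` could be added there without silently changing the meaning of existing calls.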
Option to set 'non_blocking' for to(device) operation in BatchEncoding for performance improvements. Defaults to 'False', thus no behavioral changes.
What does this PR do?
This minor PR adds the non_blocking option to the to() function.
Previous: def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
New: def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding":
Since non_blocking defaults to 'False', this PR does not introduce behavioral changes.
I realized, when utilizing Zero Shot Object Detection models, that it was not possible to set this option, leading to sub-optimal performance during inference.
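As a rough illustration of the intended use, the sketch below moves a batch of tensors to the target device with `non_blocking=True`. It uses a plain dict of tensors as a stand-in for the processor output, so the tensor names and shapes are assumptions made for the example; it also falls back to CPU when no GPU is present, in which case the flag makes no practical difference.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A plain dict of tensors stands in for the processor output (hypothetical data).
inputs = {
    "input_ids": torch.randint(0, 100, (4, 16)),
    "pixel_values": torch.rand(4, 3, 32, 32),
}

if device.type == "cuda":
    # Pinned (page-locked) host memory is what lets the copy run asynchronously
    # with respect to the host.
    inputs = {k: v.pin_memory() for k, v in inputs.items()}

# Request a non-blocking copy for every tensor in the batch.
moved = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}

if device.type == "cuda":
    # Only needed when the host must know the copies have finished (e.g. for timing);
    # kernels launched on the same stream are already ordered after the copy.
    torch.cuda.synchronize()
```

With the option added in this PR, the same flag can be passed through the batch object itself, e.g. `inputs.to(device, non_blocking=True)` when `inputs` is a BatchEncoding or BatchFeature.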