Pass device in Logits Processor's init#29804
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
gante
left a comment
There was a problem hiding this comment.
Overall notes before going to details:
- In the processors that take
eos_token_idas input: see #29788. In this PR, the special tokens are treated as tensors by default, solving most of the needed changes. I would rebase this PR onmainafter that PR is merged, as some of the changes here will become redundant :) - On the processors that don't need to use
device, such asTemperatureLogitsWarper-- let's not add unused arguments. Clean interfaces are important 🧼 (unless there are significant benefits from standardizing them) - Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A
.tooperation is not that expensive :)
|
|
Not stale |
|
This PR now can be reviewed. Rebased main and updated the changes. All the tests from |
gante
left a comment
There was a problem hiding this comment.
LGTM, thank you for improving generate :D
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
|
@gante Ah I forgot whisper is encoder-decoder. Oke, now it infers device from one of the inputs passed by the user. |
|
How could the bot come 🤣 anyways on it! |
ArthurZucker
left a comment
There was a problem hiding this comment.
Overall LGTM, not sure input_ids device is the always the best, and we need a small test to see which feature is enable by this potentially!
| if device is None: | ||
| device = "cpu" | ||
|
|
There was a problem hiding this comment.
I'd argue that we can just set it to "cpu" in the arg no?
There was a problem hiding this comment.
This is mostly for users who use/pass LogitsProcessor as a standalone kwarg, because 'generate()' takes care that device is not None.
I think we should raise warning for BC saying users to pass-in the device, but let's ask @gante if he's okay with it. If I am not misunderstanding, we shouldn't raise warnings 🤔
Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A .to operation is not that expensive :)
There was a problem hiding this comment.
yeah, don't think it's a problem to silently do this
There was a problem hiding this comment.
Down to just default to CPU which was already the behaviour by default before this PR no?
There was a problem hiding this comment.
ahh my bad, didn't read carefully the first comment. Setting in the arg as default is better, right
My concern is that before this PR, we were placing these on scores.device during "_ [call]_ " , but anyway I still get lost at when to do BC deprecation and when to not do 😄
| prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, | ||
| logits_processor=logits_processor, | ||
| device=inputs_tensor.device, | ||
| device=input_ids.device, |
There was a problem hiding this comment.
why is this required ?
There was a problem hiding this comment.
Right! I thought that it was me who changed to inputs_tensor and was trying to revert 😆 I'll revert it back, no difference whichever tensor we use here
There was a problem hiding this comment.
should be use self.device? or lm_head.device? (which is not always there but still)
There was a problem hiding this comment.
I think we need to make sure dive placement on multi GPU works, might already be tested !
There was a problem hiding this comment.
got it. Any how LGTM
ArthurZucker
left a comment
There was a problem hiding this comment.
could you rebase your branch ? (format changes seems unrelated?)
|
Oke, rebased main and the unnecessary formatting is removed. Will merge as I guess we don't need to add warnings :) |
What does this PR do?
This PR adds the ability to pass in device when initializing
LogitsProcessorsand is one more step towardscompilecompatibility.Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante