Generate: replace breaks by a loop condition #29662
Conversation
```diff
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
```
The previous version was also data-dependent control flow, so this change is for `torch.compile` readiness :)
FYI @zucchini-nlp (the stopping criteria solution did not preserve ZeRO stage 3 support)
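For context, a simplified sketch of the refactor being discussed (reconstructed from the diff fragments above; the helper name `_has_unfinished_sequences` comes from this PR, and the elided bodies stand in for a decoding step):

```python
# Before (simplified): an endless loop exited by a data-dependent break
while True:
    # ... one decoding step ...
    if this_peer_finished and not synced_gpus:
        break

# After (simplified): the stop decision moves into the loop condition
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
    # ... one decoding step ...
    this_peer_finished = unfinished_sequences.max() == 0
```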
amyeroberts left a comment
Thanks for working on this - so much cleaner 🤩
```python
else:
    if this_peer_finished:
        return False
return True
```
Or actually, we can just do
```diff
-else:
-    if this_peer_finished:
-        return False
-return True
+return not this_peer_finished
```
This solution can return `False` when `synced_gpus` is `True` and `this_peer_finished` is `True`, which is not intended -- `this_peer_finished` has to be `True` in all distributed devices when `synced_gpus` is `True` 🤗
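To make the objection concrete, here is a sketch of the helper's full shape (reconstructed from the diff fragment above plus the `synced_gpus` branch it implies; details may differ from the merged code). Under ZeRO stage 3, every device must run the same number of forward passes, so a peer's own flag cannot decide the loop on its own:

```python
import torch
import torch.distributed as dist

def _has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
    if synced_gpus:
        # Sum a 0.0/1.0 "still running" flag across all devices: only stop
        # once *every* peer has finished, so ZeRO stage 3 peers stay in sync.
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        if this_peer_finished_flag.item() == 0.0:
            return False
    else:
        if this_peer_finished:
            return False
    return True
```

With the suggested `return not this_peer_finished`, a peer whose own sequences are done would return `False` even when `synced_gpus` is `True` and other peers are still generating, so it would drop out of the loop and break the synchronization.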
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
Pulled from the `torch.compile(..., fullgraph=True)` draft PR: #29374

It replaces the `break`s that exit the endless generation loop with an equivalent function that returns `False` when generation should stop, while preserving ZeRO stage 3 support. It is not only an improvement in terms of code reuse, but also a hard requirement to enable `torch.compile(..., fullgraph=True)`: `break` and other data-dependent control flow are not supported.
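As a toy illustration (not code from the PR) of why the `break` pattern is a blocker: a loop exit that branches on tensor *values* forces a graph break, which `fullgraph=True` disallows, so compilation is expected to raise a `torch._dynamo` error rather than succeed:

```python
import torch

@torch.compile(fullgraph=True)
def countdown(x):
    while True:
        x = x - 1
        if x.max() <= 0:  # branch on tensor values: requires a graph break
            break
    return x

# Expected to raise a torch._dynamo "Unsupported" error under fullgraph=True.
countdown(torch.tensor([3.0]))
```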