feat: chunked logprob calculation with deferred fp32 cast to help with OOM #856
pjin-nvidia wants to merge 49 commits into NVIDIA-NeMo:main
Conversation
Signed-off-by: Peter Jin <pjin@nvidia.com>
Based on NeMo commit: 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 Signed-off-by: Peter Jin <pjin@nvidia.com>
Force-pushed from 2a985fe to 6a445bc
from nemo.collections.llm.t5.model.t5 import T5Config
...
def get_model_from_config_no_float32(
**Reviewer:**
was this function copied from somewhere? if so, what changes were made?
**pjin-nvidia:**
it's a copy of nemo/tron/model.py:
https://github.com/NVIDIA/NeMo/blob/8ddf4387344c6423763ec9ee0c9a755cbb5d8d35/nemo/tron/model.py
the main change is removing the Float16Module wrapper (which is what originally casts the model logits output to float32):
https://github.com/NVIDIA-NeMo/RL/pull/856/files/a020289609cfa0d7a695a175eed009fdb4695088#diff-37539801eab6c58172c5cf85be33a1f9eac04c096a8e23170550ddf3bff8e3b3R125-R128
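The effect of dropping that wrapper can be sketched as follows. This is a minimal illustration only: the actual change lives in the NeMo/Megatron model wrapper, and `logprobs_deferred_cast` is a hypothetical helper name, assuming `[seq, vocab]` half-precision logits.

```python
import torch
import torch.nn.functional as F

def logprobs_deferred_cast(logits: torch.Tensor, tokens: torch.Tensor,
                           chunk_size: int = 1024) -> torch.Tensor:
    # Hypothetical sketch: instead of Float16Module eagerly casting the whole
    # logits tensor to fp32, defer the cast and do it one chunk at a time.
    # logits: [seq, vocab] in bf16/fp16; tokens: [seq] token ids.
    out = torch.empty(logits.shape[0], dtype=torch.float32)
    for start in range(0, logits.shape[0], chunk_size):
        end = min(start + chunk_size, logits.shape[0])
        # Cast only this chunk to fp32: peak extra memory is one fp32 chunk,
        # not a full fp32 copy of the logits.
        lp = F.log_softmax(logits[start:end].float(), dim=-1)
        out[start:end] = lp.gather(-1, tokens[start:end].unsqueeze(-1)).squeeze(-1)
    return out
```

Since each chunk sees the same bf16 rows the full cast would, the per-chunk result matches the eager-cast path exactly, just with bounded peak memory.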
**Reviewer:**
If it's only a one-line change, I'd prefer the change be reflected in the submodule (you can branch where the submodule is at to update)
also, if you expect the model coming back to not be an FP16 module but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test will ensure we don't miss the typing fix you're adding
**pjin-nvidia:**
updated NeMo submodule
branch: https://github.com/NVIDIA/NeMo/tree/pjin/nemorl-logprob
commit: NVIDIA-NeMo/NeMo@0bf0dbc
**pjin-nvidia:**
> also, if you expect the model coming back to not be a FP16 but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test would ensure we don't miss this typing fix you're adding
what I did was add a float32 dtype check to the existing megatron logprobs test, and run that test on more combinations of (logprob chunk size, deferred float32 logits):
https://github.com/NVIDIA-NeMo/RL/pull/856/files#diff-9556cb57e37308923c54e7a6df8982afafef5e36544f350af3324db43f74bdbeR703
one caveat is that the policy model worker mainly exposes the model output through get_logprobs, and there is no other interface for getting at the underlying torch model logits. but I think just checking that the returned logprobs are float32 should be sufficient?
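A minimal sketch of that dtype check, with a hypothetical `check_logprobs_dtype` helper and a stand-in tensor in place of a real `get_logprobs` result:

```python
import torch

def check_logprobs_dtype(logprobs: torch.Tensor) -> None:
    # The worker only exposes logprobs, not the raw model logits, so the
    # returned dtype is the observable contract we can assert on.
    assert logprobs.dtype == torch.float32, f"expected float32, got {logprobs.dtype}"

# Stand-in tensor in place of an actual get_logprobs() call:
lp = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
check_logprobs_dtype(lp)
```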
return grad_input, None, None, None, None, None, None
...
class ChunkedDistributedLogprob(torch.autograd.Function):
**Reviewer:**
can we add a unit test for this function so we make sure the non-chunked version matches the chunked one?
**pjin-nvidia:**
added a chunk_size parameter to DistributedLogprobTestActor:
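The parity such a test checks can be sketched like this, using a hypothetical `token_logprobs` stand-in rather than the real `ChunkedDistributedLogprob`; it checks both forward and gradient agreement across chunk sizes:

```python
import torch
import torch.nn.functional as F

def token_logprobs(logits, tokens, chunk_size=None):
    # Hypothetical stand-in for the real autograd.Function: chunk_size=None
    # computes log-softmax in one pass; otherwise rows are processed in chunks.
    if chunk_size is None:
        lp = F.log_softmax(logits.float(), dim=-1)
        return lp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    parts = [token_logprobs(logits[s:s + chunk_size], tokens[s:s + chunk_size])
             for s in range(0, logits.shape[0], chunk_size)]
    return torch.cat(parts)

# Forward parity across several chunk sizes, plus gradient parity through
# backward, mirroring what a chunk_size-parameterized test actor would check.
logits1 = torch.randn(9, 32, requires_grad=True)
logits2 = logits1.detach().clone().requires_grad_(True)
tokens = torch.randint(0, 32, (9,))
ref = token_logprobs(logits1, tokens)
for cs in (1, 2, 4, 9):
    assert torch.allclose(token_logprobs(logits1, tokens, cs), ref.detach(), atol=1e-6)

ref.sum().backward()
token_logprobs(logits2, tokens, chunk_size=4).sum().backward()
assert torch.allclose(logits1.grad, logits2.grad, atol=1e-6)
```

The gradient check matters because the chunked path has its own custom backward in the PR; matching `grad` tensors is the behavior the unit test is meant to pin down.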
Force-pushed from 58a202e to df70715
moe_router_bias_update_rate: 0.0  # by default, disable bias updates for grpo
apply_rope_fusion: True
activation_checkpointing: True
defer_fp32_logits: True
**Reviewer:**
what would be the reason to set this to False?
**pjin-nvidia:**
mostly for strict backward compat, but we could instead enable it by default (i.e. make it an opt-out config like no_defer_fp32_logits or similar)
wdyt?
**Reviewer:**
I see. How about the following:
- this PR introduces it, default off
- follow up PR where we run all our nightly tests to see if defaulting to true is ok, if so, remove the arg
wdyt? If the feature is broadly applicable we should probably switch it to true so no one else runs into the same issue (assuming no accuracy penalty)
**pjin-nvidia:**
yup, (1) and then (2) SGTM!
Force-pushed from 1c8186a to 81fb8e1
There's a permission issue with that
Closing in favor of #918