
fix: remove tie weight check #700

Merged

terrykong merged 5 commits into main from ruit/remove_tie_weight_check on Aug 9, 2025

Conversation

@RayenTian (Contributor) commented Jul 21, 2025

What does this PR do?

This pull request removes the tie weight check in the DTensor worker, addressing a previous limitation that blocked training models that use tie_word_embeddings when the tensor parallel size (tp_size) is greater than 1.

Issues

closes #684

Test Result

Removing the NRL_SKIP_TIED_WEIGHT_CHECK environment variable

Experiments were run with meta-llama/Llama-3.2-1B, Qwen/Qwen2.5-1.5B-Instruct, and google/gemma-2-2b-it. With the truncation rate kept within a reasonable range, both tp_size=1 and tp_size=2 produced comparable reward results, indicating that training proceeds as expected. This consistency across tensor parallel configurations confirms that tied word embeddings are handled correctly during training.
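For reference, the removed guard followed a common escape-hatch pattern; a minimal sketch (illustrative only, not the repo's actual code, with only the env var name taken from this PR):

```python
import os

def check_tied_weights(model_config, tp_size: int) -> None:
    """Sketch of the kind of check this PR removes."""
    if os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK") == "1":
        return  # explicit user opt-out
    if tp_size > 1 and getattr(model_config, "tie_word_embeddings", False):
        raise ValueError(
            "tie_word_embeddings is not supported with tp_size > 1; "
            "set NRL_SKIP_TIED_WEIGHT_CHECK=1 to bypass this check"
        )
```

With this fix, tied embeddings are handled directly under tensor parallelism, so neither the check nor the escape hatch is needed.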

Main setup:

  • policy.generation.temperature=1.0
  • policy.max_total_sequence_length=2048
  • cluster.gpus_per_node=8

Llama-3.2-1B

Qwen2.5-1.5B-Instruct

google/gemma-2-2b-it

Test with:

  • policy.dynamic_batching.enabled=true
  • policy.sequence_packing.enabled=false
  • policy.train_global_batch_size=128

[Plots: gemma_reward_40, gemma_trancate_rate_40, gemma_token_mult_prob_error_40]

Removing the model.config.tie_word_embeddings check in nemo_rl/models/dtensor/parallelize.py

Experiments were conducted with both Qwen/Qwen2.5-1.5B-Instruct and Qwen/Qwen2.5-7B-Instruct, representing the original tied and untied weight configurations, respectively. I disabled Qwen's optimized parallel plan in PARALLIZE_FUNCTIONS, forcing both models to use the generic HF plan for testing.

  • The results indicate that, after removing the model.config.tie_word_embeddings check, the HF plan and the optimized plan produce identical outcomes (a sketch of the removed guard follows).
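For reference, the removed branch in the generic HF TP plan looked roughly like the commented guard below. This is an illustrative sketch: apply_hf_tp_plan and the plan entries are hypothetical, not the repo's actual code.

```python
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_hf_tp_plan(model, mesh: DeviceMesh):
    # Before this PR, a guard here refused tied embeddings outright:
    # if model.config.tie_word_embeddings:
    #     raise NotImplementedError("tied embeddings unsupported with TP")
    plan = {
        "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
        "lm_head": ColwiseParallel(),
    }
    return parallelize_module(model, mesh, plan)
```

The change means the same plan is now applied whether or not the embeddings are tied.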

Qwen/Qwen2.5-1.5B-Instruct

  • This model uses tied word embeddings by default.
  • Current results are from a short test of 10 steps.

[Plot: 10-step run, HF plan vs. optimized plan]

Qwen/Qwen2.5-7B-Instruct

  • This model uses untied word embeddings by default.
  • Current results are from a short test of 10 steps.

[Plot: 10-step run, HF plan vs. optimized plan]

More Custom Tests

Some additional tests were also conducted here.
For models like Qwen/Qwen2.5-1.5B-Instruct, which use tied word embeddings by default, we forcibly disabled the tie_word_embeddings setting. The results showed significant anomalies in both token_mult_prob_error and reward. We suspect this is because vLLM still keeps the embeddings tied on its side, so calling update_weights can go wrong: in practice, vLLM may serve either the embed or the lm_head weight rather than truly untying them as in training, which causes token_mult_prob_error to behave abnormally.
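The aliasing behind this can be checked directly with HF transformers (assuming network access to pull the checkpoint):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
emb = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
# Tied by default for this model: both names resolve to one underlying
# tensor, which is why untying only on the training side can desync vLLM.
print(emb.data_ptr() == head.data_ptr())  # expected: True
```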

[Plots: qwen_1.5b_reward, qwen_1.5b_token_mult_prob_error]

For models like Qwen/Qwen2.5-7B-Instruct, which do not use tied embeddings by default, we forcibly enabled tie_word_embeddings. The results showed large differences in reward, but token_mult_prob_error remained stable. This is likely because the embed and lm_head weights originally differed (being untied), and forcing a tie effectively discards the original lm_head weights. When update_weights streams the result to vLLM, its lm_head is also overwritten with the embed weights, so token_mult_prob_error stays normal; however, since the original weights were forcibly replaced, the reward becomes abnormal.
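The overwrite direction can be reproduced with transformers' tie_weights(), which re-points the output embedding at the input embedding (a sketch; config handling can vary across transformers versions):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
# Untied by default for this model. Forcing a tie discards the original
# lm_head parameters in favor of the embedding matrix:
model.config.tie_word_embeddings = True
model.tie_weights()
emb = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
assert emb.data_ptr() == head.data_ptr()  # lm_head now aliases embed_tokens
```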

[Plots: qwen_7b_reward, qwen_7b_token_mult_prob_error]

Additional Notes on Gemma Model Testing

  • gemma-3-1b-it: This model sets kv_head=1 (a single key-value head), which is currently incompatible with tensor parallelism at TP=2, so it could not be tested in a TP=2 configuration (see the divisibility sketch below).
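The limitation is the usual head-divisibility constraint for sharded attention; a minimal sketch of the check (illustrative, not the repo's actual validation):

```python
def can_shard_kv_heads(num_kv_heads: int, tp_size: int) -> bool:
    # Column-wise sharding splits the KV projection across TP ranks,
    # so the KV head count must divide evenly by the TP size.
    return num_kv_heads % tp_size == 0

# gemma-3-1b-it reportedly configures a single KV head:
assert not can_shard_kv_heads(num_kv_heads=1, tp_size=2)
```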

RayenTian added the CI:L1 (Run doctests, unit tests, and functional tests) label on Jul 21, 2025
RayenTian added the CI:L0 (Run doctests and unit tests) label and removed the CI:L1 label on Jul 21, 2025
RayenTian requested review from joyang-nv and yuki-97 on July 21, 2025 09:17
RayenTian removed and re-added the CI:L0 label on Jul 22, 2025
RayenTian force-pushed the ruit/remove_tie_weight_check branch from e610b0e to a4dbcec on July 22, 2025 04:32
RayenTian added the CI:L1 label and removed the CI:L0 label on Jul 22, 2025
Comment thread on nemo_rl/models/huggingface/common.py (outdated)
Comment thread on nemo_rl/models/policy/dtensor_policy_worker.py (outdated)
@yuki-97 (Contributor) commented Jul 22, 2025

Thanks @RayenTian for verifying and removing these things!

Can you also check whether we can remove model.config.tie_word_embeddings from the HF TP plan in nemo_rl/models/dtensor/parallelize.py?

@yuki-97 (Contributor) commented Jul 22, 2025

One more thing: can you also search for issues/227 in the repo and remove (or update) those references?

RayenTian force-pushed the ruit/remove_tie_weight_check branch from a4dbcec to b1cd9b6 on July 23, 2025 01:50
@SahilJain314 (Contributor) commented:

can you please attach the token_mult_prob_error plots to the description? (we can't always see the errors when we just look at convergence plots)

RayenTian force-pushed the ruit/remove_tie_weight_check branch from b1cd9b6 to fea99a2 on July 24, 2025 03:21
github-actions bot added the Documentation (Improvements or additions to documentation) label on Jul 24, 2025
Comment thread on docs/model-quirks.md
@RayenTian (Contributor, Author) commented:

> can you please attach the token_mult_prob_error plots to the description? (we can't always see the errors when we just look at convergence plots)

Thank you for the suggestion. I have added the plots as requested.

RayenTian force-pushed the ruit/remove_tie_weight_check branch from c6386b5 to d741557 on July 29, 2025 03:26
RayenTian added the CI:L0 (Run doctests and unit tests) label on Aug 5, 2025
RayenTian removed and re-added the CI:L0 label on Aug 6, 2025
RayenTian force-pushed the ruit/remove_tie_weight_check branch from c9f41b3 to 292cc83 on August 6, 2025 07:15
RayenTian removed and re-added the CI:L0 label on Aug 6, 2025
terrykong previously approved these changes Aug 8, 2025
SahilJain314 previously approved these changes Aug 8, 2025
@SahilJain314 (Contributor) left a comment:

thanks for rebasing, lgtm now.

Commits signed-off-by: ruit <ruit@nvidia.com>
terrykong dismissed stale reviews from SahilJain314, yuki-97, and themself via a5d70d8 on August 8, 2025 22:40
terrykong force-pushed the ruit/remove_tie_weight_check branch from 292cc83 to a5d70d8 on August 8, 2025 22:40
terrykong enabled auto-merge on August 8, 2025 22:40
terrykong added this pull request to the merge queue on Aug 8, 2025
Merged via the queue into main with commit fecf71e on Aug 9, 2025
23 checks passed

Labels

CI:L0 (Run doctests and unit tests), Documentation (Improvements or additions to documentation)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Remove tie weights check in DTensor worker

5 participants