
Conversation

@tanmaysachan commented Jan 16, 2026

Addresses #865

  • Model outline from PyTorch -> JAX
  • Parity checks
  • Tests
  • Address drift
  • Train end to end

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the JAX implementation for the DeepseekV3 model. The implementation is comprehensive and covers the model's unique features like Multi-Head Latent Attention and Mixture of Experts with shared experts. The code is well-structured.

My review focuses on a critical bug that will prevent the model from running, along with some suggestions to improve maintainability by reducing code duplication and avoiding magic numbers. Addressing these points will make the implementation more robust and easier to maintain.

Comment on lines 527 to 543
# Precompute RoPE frequencies
# qk_rope_head_dim = config.qk_rope_head_dim
# original_seq_len = getattr(config, "original_seq_len", config.max_position_embeddings)
# rope_factor = getattr(config, "rope_factor", 1.0)
# beta_fast = getattr(config, "beta_fast", 32)
# beta_slow = getattr(config, "beta_slow", 1)

# TODO: Swap out like llama's rope?
# self.freqs_cis = precompute_freqs_cis(
#     dim=qk_rope_head_dim,
#     max_seq_len=config.max_position_embeddings,
#     rope_theta=config.rope_theta,
#     original_seq_len=original_seq_len,
#     rope_factor=rope_factor,
#     beta_fast=beta_fast,
#     beta_slow=beta_slow,
# )
@gemini-code-assist bot

critical

This block for precomputing RoPE frequencies is commented out, but self.freqs_cis is used in DeepseekV3Model.__call__ at line 571. This will raise an AttributeError at runtime.

Looking at the DeepseekV3MLA implementation, the freqs_cis parameter is not used. Instead, apply_rope is called, which computes the frequencies on the fly.

To fix this, you should remove the freqs_cis parameter from the entire call chain, as it appears to be unused. This involves:

  1. Removing freqs_cis: jax.Array from the signature of DeepseekV3MLA.__call__.
  2. Removing freqs_cis: jax.Array from the signature of DeepseekV3DecoderLayer.__call__.
  3. Removing the freqs_cis=self.freqs_cis argument from the layer() call within DeepseekV3Model.__call__.

This will resolve the crash and align the code with the current apply_rope implementation. You can then address the TODO about swapping the RoPE implementation in a separate change.
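
As a sketch, the resulting call chain would look something like this (illustrative only; the `positions` argument and the placeholder bodies are assumptions, not the PR's actual code):

```python
import jax

class DeepseekV3MLA:
    def __call__(self, hidden_states: jax.Array, positions: jax.Array) -> jax.Array:
        # freqs_cis removed: apply_rope computes rotary frequencies on the fly.
        return hidden_states  # placeholder for the attention computation

class DeepseekV3DecoderLayer:
    def __init__(self) -> None:
        self.self_attn = DeepseekV3MLA()

    def __call__(self, hidden_states: jax.Array, positions: jax.Array) -> jax.Array:
        # freqs_cis removed from this signature too; correspondingly,
        # DeepseekV3Model.__call__ no longer passes freqs_cis=self.freqs_cis
        # into the layer() call.
        return self.self_attn(hidden_states, positions)
```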

@pcmoritz added the tx label Jan 17, 2026
@tanmaysachan commented Jan 18, 2026

@pcmoritz The PR is open for review now.

In the first test case I've added a TODO: there seems to be some kind of numerical drift that requires an absolute tolerance of around 6e-3 for the tests to pass. I'll investigate a little more; nothing has caught my eye so far.
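
A parity check of roughly this shape is what the tolerance applies to (a minimal sketch; the array values below are made up for illustration and are not the test's actual outputs):

```python
import numpy as np

# Made-up stand-ins for the HF PyTorch and JAX forward-pass outputs.
hf_hidden = np.array([-0.0549, -0.6639, -0.4138])
jax_hidden = np.array([-0.0533, -0.6611, -0.4185])

# With the drift present, this only passes at an absolute tolerance of ~6e-3.
np.testing.assert_allclose(jax_hidden, hf_hidden, rtol=0, atol=6e-3)
```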

@tanmaysachan commented
Fixed the source of the drift; there was a default config mismatch.

@pcmoritz commented
This is awesome! Have you already gotten some end-to-end training working with it? It would be great to add one to https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-tx/README.md. If you haven't I'm also more than happy to help with it :)

@tanmaysachan commented Jan 21, 2026

Looks like the tests are failing; I'm somehow unable to replicate this on my machine. Having a look.

Some Qwen tests also seem to be failing; is this expected?

I haven't been able to train end-to-end yet, but will give it a shot over the weekend (along with any further fixes required). Added a task for it in the PR description.

@tanmaysachan commented Jan 23, 2026

Failing tests root cause: Hugging Face outputs are not consistent between macOS and Ubuntu (Accelerate vs MKL BLAS backends).

Linux

OS: Linux 6.8.0-90-generic
Machine: x86_64
Python: 3.12.12
PyTorch: 2.10.0+cu128
PyTorch BLAS: mkl
CUDA available: False
Transformers: 4.57.6

DEEPSEEK V3 TEST

HF hidden_states[-1] first 10 values (sample 0, pos 0):
[-0.05490041896700859, -0.6639361381530762, -0.4137983024120331, 0.19858041405677795, 0.4002900719642639, -1.8006019592285156, -0.7636783123016357, -0.6883448958396912, 0.39694416522979736, 2.5040738582611084]

macOS

OS: Darwin 25.2.0
Machine: arm64
PyTorch BLAS: accelerate
Transformers: 4.57.6

DEEPSEEK V3 TEST
HF hidden_states[-1] first 10 values (sample 0, pos 0):
[-0.0496, -0.6667, -0.4240, 0.1903, 0.4095, -1.8056, -0.7479, -0.6778, 0.3872, 2.5022]

Bumping thresholds
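
A quick check on the numbers quoted above shows the largest element-wise gap between the two platforms is about 1.6e-2, which is why the thresholds need bumping:

```python
import numpy as np

linux = np.array([-0.05490041896700859, -0.6639361381530762, -0.4137983024120331,
                  0.19858041405677795, 0.4002900719642639, -1.8006019592285156,
                  -0.7636783123016357, -0.6883448958396912, 0.39694416522979736,
                  2.5040738582611084])
macos = np.array([-0.0496, -0.6667, -0.4240, 0.1903, 0.4095,
                  -1.8056, -0.7479, -0.6778, 0.3872, 2.5022])

# Largest element-wise gap between the two BLAS backends:
print(np.max(np.abs(linux - macos)))  # ~1.58e-2
```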

@vercel bot commented Jan 25, 2026

@tanmaysachan is attempting to deploy a commit to the Tyler's projects Team on Vercel.

A member of the Team first needs to authorize it.

- Add LogitsProcessorMixin to DeepseekV3ForCausalLM
- Add get_lm_head() method for logits computation
- Fix broken compute_positions import
- Fix init_lora_adapter to handle n_routed_experts attribute
- Add test_deepseekv3_lora_training.py with MoE rank normalization tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

@tanmaysachan commented Jan 25, 2026

Accidental deployment attempt due to rebase ^

@tanmaysachan commented Jan 25, 2026

End-to-end training successful on an A100.

/api/v1/healthz -> {"status":"ok"}

Added GPU tests (these need Anyscale credentials to run).

GPU tests on A100:

(skyrl-tx) (main) root@C.30490604:/workspace/SkyRL/skyrl-tx$ uv run --extra gpu python -m pytest tests/models/test_deepseekv3_lora_training.py -v
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.11.14, pytest-9.0.2, pluggy-1.6.0 -- /workspace/SkyRL/skyrl-tx/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /workspace/SkyRL/skyrl-tx
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 2 items

tests/models/test_deepseekv3_lora_training.py::test_lora_training_moe_rank_normalized PASSED [ 50%]
tests/models/test_deepseekv3_lora_training.py::test_lora_training_high_rank PASSED [100%]

@pcmoritz commented Jan 27, 2026

@tanmaysachan Thanks a lot for implementing this; it's excellent work! I'd like to get something like the following working:

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model zai-org/GLM-4.7-Flash  --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

(I think we should be able to support zai-org/GLM-4.7-Flash since it has a DeepseekV3-like architecture.) This can be run on 8xH100, I think, so it will be great to add to the next release notes for people to try out :)

I can look a little more into what is needed to make this work :)

@pcmoritz commented
Let me first merge this PR, and then we can implement zai-org/GLM-4.7-Flash on top of it (since I think that will be a little more work)

@pcmoritz commented
/gemini review

@tanmaysachan commented
Sure, I can add GLM in a follow-up (this one is quite bloated already 😅)

top_k_weights, top_k_index = self._compute_routing(router_logits)

expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat)
shared_output = self.shared_experts(
@pcmoritz commented

Reminder to self: It would be great if we could make MLP layers also support flattened states going forward, so this reshaping can be removed. It would also make the layer chunking in #902 nicer. Shouldn't be hard to implement; it mainly needs some refactoring of the LoRAMixin.
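
For illustration, the reshape pattern being referred to looks roughly like this (a sketch under assumed shapes; the 3D (batch, seq, hidden) layout and the function arguments are assumptions, not the PR's exact code):

```python
import jax.numpy as jnp

def moe_block_sketch(hidden_states, adapter_indices, experts, shared_experts, router):
    # Experts consume flattened (tokens, hidden) states, so the 3D
    # (batch, seq, hidden) activations are reshaped on the way in and out.
    batch, seq, hidden = hidden_states.shape
    hidden_states_flat = hidden_states.reshape(-1, hidden)
    adapter_indices_flat = jnp.repeat(adapter_indices, seq)

    top_k_weights, top_k_index = router(hidden_states_flat)
    expert_output = experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat)
    shared_output = shared_experts(hidden_states_flat, adapter_indices_flat)

    # If the MLP layers accepted flattened states directly, this final
    # reshape (and the flattening above) could be dropped.
    return (expert_output + shared_output).reshape(batch, seq, hidden)
```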

continue
if "experts" in path:
tensors[key] = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.num_experts)], axis=0)
num_experts = getattr(config, "num_experts", None) or getattr(config, "n_routed_experts")
@pcmoritz commented

TODO: I think this can be handled in a unified way by ModelConfig, e.g. by exposing a get_num_experts method.
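
A minimal sketch of what that helper could look like (get_num_experts is the method name suggested above; the fallback chain mirrors the getattr logic in the snippet, and the class body is a placeholder, not the project's actual ModelConfig):

```python
class ModelConfig:  # sketch only
    def get_num_experts(self) -> int:
        # Some MoE configs expose num_experts; DeepseekV3-style configs
        # expose n_routed_experts. Fall back from one to the other.
        num_experts = getattr(self, "num_experts", None)
        if num_experts is None:
            num_experts = self.n_routed_experts
        return num_experts
```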

@tanmaysachan commented

Addressed

@pcmoritz commented
I'll need a little more time tomorrow to finish reviewing the PR, in particular the attention part and the tests, but so far it looks great! I pushed some small fixes and cleanups (if you end up working on it some more, don't forget to pull first). Hopefully we can get it merged tomorrow!
