242 changes: 242 additions & 0 deletions skyrl-tx/tests/models/test_qwen3_vl_moe.py
@@ -0,0 +1,242 @@
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeForConditionalGeneration,
)

from tx.models.configs import Qwen3VLModelConfig
from tx.models.qwen3_vl_moe import Qwen3VLForCausalLM


def _make_test_mesh() -> jax.sharding.Mesh:
return jax.make_mesh(
(1, 1, 1),
("fsdp", "ep", "tp"),
axis_types=(jax.sharding.AxisType.Auto,) * 3,
)


def _make_tiny_hf_vl_moe_config() -> Qwen3VLMoeConfig:
# Keep dimensions tiny for CI speed while exercising MoE + mRoPE codepaths.
return Qwen3VLMoeConfig(
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
text_config={
"vocab_size": 128,
"hidden_size": 16,
"intermediate_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"head_dim": 8,
"rms_norm_eps": 1e-6,
"attention_bias": False,
"hidden_act": "silu",
"decoder_sparse_step": 1,
"moe_intermediate_size": 8,
"num_experts_per_tok": 2,
"num_experts": 4,
"mlp_only_layers": [],
"rope_parameters": {
"rope_type": "default",
"rope_theta": 10000.0,
"mrope_section": [2, 1, 1],
},
# HF Qwen3VLMoeTextRotaryEmbedding currently reads rope_scaling, so mirror the mRoPE settings there.
"rope_scaling": {
"rope_type": "default",
"mrope_section": [2, 1, 1],
},
},
# Vision tower is unused in these text-only parity tests, but config must exist.
vision_config={
"depth": 0,
"hidden_size": 16,
"intermediate_size": 32,
"num_heads": 2,
"out_hidden_size": 16,
"patch_size": 2,
"spatial_merge_size": 2,
"temporal_patch_size": 1,
"num_position_embeddings": 4,
"deepstack_visual_indexes": [],
},
tie_word_embeddings=False,
)


def _load_text_weights_from_hf(
jax_model: Qwen3VLForCausalLM, hf_model: Qwen3VLMoeForConditionalGeneration
) -> None:
# Embeddings + final norm + lm_head
jax_model.model.embed_tokens.embedding[...] = hf_model.model.language_model.embed_tokens.weight.detach().cpu().numpy()
jax_model.model.norm.weight[...] = hf_model.model.language_model.norm.weight.detach().cpu().numpy()
jax_model.lm_head.kernel[...] = hf_model.lm_head.weight.detach().cpu().numpy().T

# Decoder layers (text-only parity path)
num_layers = len(jax_model.model.layers)
for i in range(num_layers):
jax_layer = jax_model.model.layers[i]
hf_layer = hf_model.model.language_model.layers[i]

# Layer norms
jax_layer.input_norm.weight[...] = hf_layer.input_layernorm.weight.detach().cpu().numpy()
jax_layer.post_norm.weight[...] = hf_layer.post_attention_layernorm.weight.detach().cpu().numpy()

# Attention
jax_layer.attn.q_proj.kernel[...] = hf_layer.self_attn.q_proj.weight.detach().cpu().numpy().T
jax_layer.attn.k_proj.kernel[...] = hf_layer.self_attn.k_proj.weight.detach().cpu().numpy().T
jax_layer.attn.v_proj.kernel[...] = hf_layer.self_attn.v_proj.weight.detach().cpu().numpy().T
jax_layer.attn.o_proj.kernel[...] = hf_layer.self_attn.o_proj.weight.detach().cpu().numpy().T
jax_layer.attn.q_norm.weight[...] = hf_layer.self_attn.q_norm.weight.detach().cpu().numpy()
jax_layer.attn.k_norm.weight[...] = hf_layer.self_attn.k_norm.weight.detach().cpu().numpy()

# MoE (router + experts)
jax_layer.mlp.router.kernel[...] = hf_layer.mlp.gate.weight.detach().cpu().numpy().T
gate_up = hf_layer.mlp.experts.gate_up_proj.detach().cpu().numpy()
inter = jax_layer.mlp.experts.gate_proj.weight.shape[2]
# HF gate_up_proj packs [gate, up] in out_features; split then transpose to [in, out].
jax_layer.mlp.experts.gate_proj.weight[...] = gate_up[:, :inter, :].transpose(0, 2, 1)
jax_layer.mlp.experts.up_proj.weight[...] = gate_up[:, inter:, :].transpose(0, 2, 1)
hf_down = hf_layer.mlp.experts.down_proj.detach().cpu().numpy()
assert hf_down.shape == jax_layer.mlp.experts.down_proj.weight.shape
jax_layer.mlp.experts.down_proj.weight[...] = hf_down
Comment on lines +106 to +108
🔴 Missing transpose for down_proj expert weights in test weight loading

In _load_text_weights_from_hf, the gate_up_proj weights are transposed from HF's (E, out_features, in_features) layout to JAX's (E, in_features, out_features) layout (lines 104-105: .transpose(0, 2, 1)), but the down_proj weights are assigned directly without the same transpose (line 108). Since HF stores all fused expert parameters as (E, out, in) (per the comment at line 103: "HF gate_up_proj packs [gate, up] in out_features"), down_proj would have shape (E, hidden_size, moe_intermediate_size) = (4, 16, 8) while the JAX weight has shape (E, moe_intermediate_size, hidden_size) = (4, 8, 16). The assertion at line 107 would then fail because the shapes don't match, so the parity test could never pass.

Suggested change:
-        hf_down = hf_layer.mlp.experts.down_proj.detach().cpu().numpy()
-        assert hf_down.shape == jax_layer.mlp.experts.down_proj.weight.shape
-        jax_layer.mlp.experts.down_proj.weight[...] = hf_down
+        hf_down = hf_layer.mlp.experts.down_proj.detach().cpu().numpy()
+        assert hf_down.transpose(0, 2, 1).shape == jax_layer.mlp.experts.down_proj.weight.shape
+        jax_layer.mlp.experts.down_proj.weight[...] = hf_down.transpose(0, 2, 1)
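A small NumPy check of the shape argument above, written as a sketch that assumes the (E, out_features, in_features) layout the review describes and the tiny config's E=4, hidden_size=16, moe_intermediate_size=8; it illustrates the claim rather than verifying the HF module:

import numpy as np

E, hidden, inter = 4, 16, 8
hf_down = np.zeros((E, hidden, inter))    # (E, out, in) per the review
jax_down_shape = (E, inter, hidden)       # (E, in, out) in the JAX experts
assert hf_down.shape != jax_down_shape                       # direct assignment would fail
assert hf_down.transpose(0, 2, 1).shape == jax_down_shape    # the suggested transpose fixes it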



def _build_tiny_models() -> tuple[Qwen3VLMoeForConditionalGeneration, Qwen3VLForCausalLM, jax.sharding.Mesh]:
torch.manual_seed(0)
hf_config = _make_tiny_hf_vl_moe_config()
hf_model = Qwen3VLMoeForConditionalGeneration(hf_config).eval()

jax_config = Qwen3VLModelConfig(
hf_config,
max_lora_adapters=0,
max_lora_rank=0,
shard_attention_heads=True,
gradient_checkpointing=False,
)

mesh = _make_test_mesh()
with jax.set_mesh(mesh):
jax_model = Qwen3VLForCausalLM(jax_config, dtype=jnp.float32, rngs=nnx.Rngs(0))
_load_text_weights_from_hf(jax_model, hf_model)

return hf_model, jax_model, mesh


def test_qwen3_vl_moe_text_prefill_parity_with_hf():
hf_model, jax_model, mesh = _build_tiny_models()

input_ids = torch.tensor(
[
[11, 12, 13, 14, 0, 0],
[21, 22, 23, 24, 25, 26],
],
dtype=torch.long,
)
attention_mask = (input_ids != 0).long()

with torch.no_grad():
hf_outputs = hf_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=True,
return_dict=True,
)

with jax.set_mesh(mesh):
jax_outputs = jax_model(
np.asarray(input_ids, dtype=np.int32),
attention_mask=np.asarray(attention_mask, dtype=np.int32),
output_hidden_states=True,
)
assert jax_outputs.hidden_states is not None
assert hf_outputs.hidden_states is not None
jax_logits = jax_model.compute_logits(jax_outputs.last_hidden_state)

np.testing.assert_allclose(
np.asarray(hf_outputs.hidden_states[0], dtype=np.float32),
np.asarray(jax_outputs.hidden_states[0], dtype=np.float32),
rtol=1e-4,
atol=1e-4,
)
np.testing.assert_allclose(
np.asarray(hf_outputs.hidden_states[1], dtype=np.float32),
np.asarray(jax_outputs.hidden_states[1], dtype=np.float32),
rtol=1.5e-2,
atol=1.5e-2,
)
# HF VL-MoE exposes pre-final-norm hidden states here, while JAX includes
# final norm in hidden_states. Align by stage instead of raw index.
hf_last = np.asarray(hf_outputs.hidden_states[-1], dtype=np.float32)
if len(jax_outputs.hidden_states) == len(hf_outputs.hidden_states):
jax_last_aligned = np.asarray(jax_outputs.hidden_states[-1], dtype=np.float32)
else:
jax_last_aligned = np.asarray(jax_outputs.hidden_states[-2], dtype=np.float32)
np.testing.assert_allclose(hf_last, jax_last_aligned, rtol=1.5e-2, atol=1.5e-2)
valid = np.asarray(attention_mask, dtype=bool)
hf_logits = np.asarray(hf_outputs.logits, dtype=np.float32)
jax_logits_np = np.asarray(jax_logits, dtype=np.float32)
np.testing.assert_allclose(
hf_logits[valid],
jax_logits_np[valid],
rtol=1.5e-2,
atol=1e-2,
)
Comment on lines +169 to +191
medium

The tolerances used for np.testing.assert_allclose in the prefill parity test are quite loose (e.g., rtol=1.5e-2). Some divergence between frameworks is expected, but a 1.5% relative tolerance is high for a parity test and can mask subtle implementation differences or numerical-stability issues. It would be worth investigating whether these tolerances can be tightened to ensure closer alignment with the reference implementation.
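One way to ground the tolerance choice is to measure the observed error directly rather than picking round numbers; a minimal sketch, reusing hf_logits and jax_logits_np as computed in the test above:

rel_err = np.abs(hf_logits - jax_logits_np) / (np.abs(hf_logits) + 1e-8)
print("max rel err:", rel_err.max(), "p99:", np.quantile(rel_err, 0.99))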



def test_qwen3_vl_moe_text_decode_step_parity_with_hf():
hf_model, jax_model, mesh = _build_tiny_models()

# Prefill 4 tokens, then decode 1 token.
prefill_ids = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24]], dtype=torch.long)
prefill_mask = torch.ones_like(prefill_ids, dtype=torch.long)
decode_ids = torch.tensor([[15], [25]], dtype=torch.long)
decode_mask = torch.ones((2, 5), dtype=torch.long)

with torch.no_grad():
hf_prefill = hf_model(
input_ids=prefill_ids,
attention_mask=prefill_mask,
use_cache=True,
return_dict=True,
)
hf_decode = hf_model(
input_ids=decode_ids,
attention_mask=decode_mask,
past_key_values=hf_prefill.past_key_values,
use_cache=True,
return_dict=True,
)

with jax.set_mesh(mesh):
jax_prefill = jax_model(
np.asarray(prefill_ids, dtype=np.int32),
attention_mask=np.asarray(prefill_mask, dtype=np.int32),
)
assert jax_prefill.kv_cache is not None
# Match generation runtime behavior: decode updates into a pre-allocated KV cache.
jax_prefill_cache = jax_prefill.kv_cache.pad_to_length(int(decode_mask.shape[1]))

decode_positions = np.asarray(jax_prefill_cache.cache_position[:, None], dtype=np.int32)
jax_decode = jax_model(
np.asarray(decode_ids, dtype=np.int32),
attention_mask=np.asarray(decode_mask, dtype=np.int32),
kv_cache=jax_prefill_cache,
positions=decode_positions,
)
jax_decode_logits = jax_model.compute_logits(jax_decode.last_hidden_state)

np.testing.assert_allclose(
np.asarray(hf_decode.logits, dtype=np.float32),
np.asarray(jax_decode_logits, dtype=np.float32),
rtol=2e-1,
atol=9e-2,
)
Comment on lines +236 to +241
high

The tolerance for the decode step parity test is extremely high (rtol=2e-1, atol=9e-2), allowing up to a 20% relative difference. A gap this large suggests a bug or a significant numerical discrepancy in the decode path compared to the Hugging Face reference; it should be investigated and fixed to ensure the model behaves as expected during generation.
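One hypothetical way to separate accumulated rounding error from a genuine decode-path bug is to rerun the HF reference in float64 and recheck the gap; a sketch, not part of the test:

hf_model64 = hf_model.double()  # then repeat the prefill + decode calls above
# If the float64 reference still differs from the float32 JAX logits by ~20%,
# suspect the decode path itself (e.g., positions or rope_deltas handling)
# rather than low-precision accumulation.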


34 changes: 31 additions & 3 deletions skyrl-tx/tx/layers/stacked.py
@@ -182,9 +182,11 @@ def __call__(
positions: jax.Array,
adapter_indices: jax.Array | None,
kv_cache: KVCache | None,
rope_deltas: jax.Array | None = None,
output_hidden_states: bool,
gradient_checkpointing: bool,
is_training: bool = False,
**layer_kwargs,
) -> tuple[jax.Array, list[jax.Array], KVCache | None]:
"""Forward pass through all layers.

@@ -242,11 +244,19 @@ def __call__(
positions=positions,
adapter_indices=adapter_indices,
kv_cache=layer_kv,
**layer_kwargs,
)
updated_keys.append(k)
updated_values.append(v)

new_kv_cache = KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask)
new_kv_cache = KVCache.update(
kv_cache,
updated_keys,
updated_values,
positions,
attention_mask,
rope_deltas=kv_cache.rope_deltas,
)
return hidden_states, all_hidden_states, new_kv_cache

# Prefill/training mode: use scan for efficiency
@@ -261,6 +271,7 @@ def body_fn(carry, layer_params):
positions=positions,
adapter_indices=adapter_indices,
kv_cache=None,
**layer_kwargs,
)

hs_output = new_hs if output_hidden_states else None
@@ -282,7 +293,14 @@ def body_fn(carry, layer_params):
# Convert stacked scan outputs to list format
keys_list = [all_keys[i] for i in range(self.num_layers)]
values_list = [all_values[i] for i in range(self.num_layers)]
new_kv_cache = KVCache.update(None, keys_list, values_list, positions, attention_mask)
new_kv_cache = KVCache.update(
None,
keys_list,
values_list,
positions,
attention_mask,
rope_deltas=rope_deltas,
)

all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
return final_hs, all_hidden_states, new_kv_cache
@@ -336,6 +354,7 @@ def _split_kv_cache(kv_cache: KVCache, split_points: list[int]) -> tuple[KVCache
keys=kv_cache.keys[start:end],
values=kv_cache.values[start:end],
cache_position=kv_cache.cache_position,
rope_deltas=kv_cache.rope_deltas,
)
for start, end in zip(boundaries[:-1], boundaries[1:])
)
@@ -345,7 +364,12 @@ def _concat_kv_caches(caches: list[KVCache]) -> KVCache:
assert caches, "Expected at least one KV cache."
keys = [key for cache in caches for key in cache.keys]
values = [value for cache in caches for value in cache.values]
return KVCache(keys=keys, values=values, cache_position=caches[-1].cache_position)
return KVCache(
keys=keys,
values=values,
cache_position=caches[-1].cache_position,
rope_deltas=caches[-1].rope_deltas,
)

def __call__(
self,
@@ -355,9 +379,11 @@ def __call__(
positions: jax.Array,
adapter_indices: jax.Array | None,
kv_cache: KVCache | None,
rope_deltas: jax.Array | None = None,
output_hidden_states: bool,
gradient_checkpointing: bool,
is_training: bool = False,
**layer_kwargs,
) -> tuple[jax.Array, list[jax.Array], KVCache | None]:
all_hidden_states: list[jax.Array] = []

@@ -379,9 +405,11 @@ def __call__(
positions=positions,
adapter_indices=adapter_indices,
kv_cache=group_kv_cache,
rope_deltas=rope_deltas,
output_hidden_states=output_hidden_states,
gradient_checkpointing=gradient_checkpointing,
is_training=is_training,
**layer_kwargs,
)
all_hidden_states.extend(layer_hidden_states)
if not is_training:
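For context, a minimal sketch of the KVCache fields these changes assume; the names come from the diff, and the real class defined elsewhere in skyrl-tx may differ:

from dataclasses import dataclass
import jax

@dataclass
class KVCache:
    keys: list[jax.Array]         # one (batch, seq, kv_heads, head_dim) array per layer
    values: list[jax.Array]
    cache_position: jax.Array     # next write offset per sequence
    rope_deltas: jax.Array | None = None  # mRoPE position offsets carried from prefill into decode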
19 changes: 18 additions & 1 deletion skyrl-tx/tx/models/configs.py
@@ -50,10 +50,27 @@ def __init__(
self.mhc_expansion_rate = mhc_expansion_rate

def get_num_experts(self):
return getattr(self, "num_experts", None) or getattr(self, "n_routed_experts", None)
# Most models expose the expert count at the top level of the config.
experts = getattr(self, "num_experts", None) or getattr(
self, "n_routed_experts", None
)
if experts is not None:
return experts

# VL-MoE stores expert config under text_config (object or dict).
text_config = getattr(self, "text_config", None)
if isinstance(text_config, dict):
return text_config.get("num_experts") or text_config.get("n_routed_experts")
if text_config is not None:
return getattr(text_config, "num_experts", None) or getattr(
text_config, "n_routed_experts", None
)
return None


# Model-specific aliases for clarity and backwards compatibility
Llama3Config = ModelConfig
Qwen3Config = ModelConfig
DeepseekV3Config = ModelConfig
Qwen3VLMoeConfig = ModelConfig
Qwen3VLModelConfig = ModelConfig
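
A quick illustration of the fallback order in get_num_experts, using stand-in namespace objects rather than real HF configs (the SimpleNamespace instances are hypothetical):

from types import SimpleNamespace
from tx.models.configs import ModelConfig

flat = SimpleNamespace(num_experts=8)                                     # top-level wins
nested_dict = SimpleNamespace(text_config={"num_experts": 4})             # VL-MoE, dict form
nested_obj = SimpleNamespace(text_config=SimpleNamespace(num_experts=4))  # VL-MoE, object form

assert ModelConfig.get_num_experts(flat) == 8
assert ModelConfig.get_num_experts(nested_dict) == 4
assert ModelConfig.get_num_experts(nested_obj) == 4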