From 874e25cd5ce0491ed8567a5841b4ce5524ef1909 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 24 Feb 2026 12:47:00 -0800 Subject: [PATCH 01/12] implement qwen3_next model architecture --- skyrl-tx/tests/models/test_qwen3_next.py | 165 +++++ skyrl-tx/tx/models/configs.py | 1 + skyrl-tx/tx/models/qwen3_next.py | 885 +++++++++++++++++++++++ skyrl-tx/tx/utils/generator.py | 34 +- skyrl-tx/tx/utils/models.py | 5 + 5 files changed, 1084 insertions(+), 6 deletions(-) create mode 100644 skyrl-tx/tests/models/test_qwen3_next.py create mode 100644 skyrl-tx/tx/models/qwen3_next.py diff --git a/skyrl-tx/tests/models/test_qwen3_next.py b/skyrl-tx/tests/models/test_qwen3_next.py new file mode 100644 index 0000000000..5fbeffe317 --- /dev/null +++ b/skyrl-tx/tests/models/test_qwen3_next.py @@ -0,0 +1,165 @@ +import tempfile + +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig as HFQwen3NextConfig +from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM as HFQwen3NextForCausalLM + +from tx.models.configs import Qwen3NextConfig +from tx.models.qwen3_next import Qwen3NextForCausalLM +from tx.tinker.types import SamplingParams +from tx.utils.models import load_safetensors + + +def make_small_hf_config() -> HFQwen3NextConfig: + return HFQwen3NextConfig( + vocab_size=128, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + max_position_embeddings=128, + tie_word_embeddings=False, + linear_conv_kernel_dim=3, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=2, + layer_types=["linear_attention", "full_attention", "linear_attention", "full_attention"], + num_experts=0, + num_experts_per_tok=1, + decoder_sparse_step=1, + ) + + +def make_small_tx_config(base_config: HFQwen3NextConfig, *, shard_attention_heads: bool = False) -> Qwen3NextConfig: + return Qwen3NextConfig( + base_config, + max_lora_adapters=2, + max_lora_rank=8, + shard_attention_heads=shard_attention_heads, + ) + + +@pytest.mark.parametrize("tp", [1, 2]) +def test_qwen3_next_end_to_end(tp: int): + if jax.device_count() < tp: + pytest.skip(f"Need at least {tp} JAX devices for tp={tp}, found {jax.device_count()}") + + hf_config = make_small_hf_config() + hf_model = HFQwen3NextForCausalLM(hf_config) + hf_model.eval() + + input_ids = torch.tensor([[1, 2, 3, 4, 0], [5, 6, 7, 0, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype=torch.long) + + with torch.no_grad(): + hf_outputs = hf_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + use_cache=False, + ) + + with tempfile.TemporaryDirectory() as tmp: + hf_model.save_pretrained(tmp, safe_serialization=True) + + config = make_small_tx_config(hf_config, shard_attention_heads=tp > 1) + mesh = jax.make_mesh((1, 1, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + with jax.set_mesh(mesh): + model = Qwen3NextForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_safetensors(tmp, config, model) + outputs = model( + input_ids.numpy(), + attention_mask=attention_mask.numpy(), + output_hidden_states=True, + ) + logits = model.compute_logits(outputs.last_hidden_state) + + assert outputs.hidden_states is not None + assert np.allclose(hf_outputs.hidden_states[0], 
outputs.hidden_states[0], rtol=1e-6, atol=1e-6) + assert np.allclose(hf_outputs.hidden_states[1], outputs.hidden_states[1], rtol=1e-3, atol=1e-3) + assert np.allclose(hf_outputs.hidden_states[-1], outputs.hidden_states[-1], rtol=8e-2, atol=8e-2) + assert np.allclose(hf_outputs.logits, logits, rtol=1e-1, atol=1e-1) + + +def test_qwen3_next_prefill_cache_shapes(): + config = make_small_tx_config(make_small_hf_config()) + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + with jax.set_mesh(mesh): + model = Qwen3NextForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + input_ids = jnp.array([[1, 2, 3, 4, 0], [5, 6, 7, 0, 0]], dtype=jnp.int32) + attention_mask = jnp.array([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype=jnp.int32) + outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) + + assert outputs.last_hidden_state.shape == (2, 5, config.hidden_size) + assert outputs.hidden_states is not None + assert len(outputs.hidden_states) == config.num_hidden_layers + 1 + assert outputs.kv_cache is not None + assert len(outputs.kv_cache.keys) == config.num_hidden_layers + assert outputs.kv_cache.conv_states is not None + assert outputs.kv_cache.recurrent_states is not None + assert outputs.kv_cache.keys[0].shape[1] == 0 + assert outputs.kv_cache.keys[1].shape[1] == 5 + assert outputs.kv_cache.keys[3].shape[1] == 5 + assert outputs.kv_cache.conv_states[0].shape[-1] == config.linear_conv_kernel_dim + assert outputs.kv_cache.recurrent_states[0].shape[1] == config.linear_num_value_heads + + +def test_qwen3_next_decode_updates_cache_position(): + config = make_small_tx_config(make_small_hf_config()) + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + with jax.set_mesh(mesh): + model = Qwen3NextForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + input_ids = jnp.array([[1, 2, 3], [4, 5, 0]], dtype=jnp.int32) + attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]], dtype=jnp.int32) + prefill = model(input_ids, attention_mask=attention_mask) + assert prefill.kv_cache is not None + + cache = prefill.kv_cache.pad_to_length(8) + decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 5))) + batch_idx = jnp.arange(decode_attention_mask.shape[0]) + decode_attention_mask = decode_attention_mask.at[batch_idx, cache.cache_position].set(1) + + next_token = jnp.array([[9], [10]], dtype=jnp.int32) + positions = cache.cache_position[:, None] + decode_out = model( + next_token, + attention_mask=decode_attention_mask, + positions=positions, + kv_cache=cache, + ) + + assert decode_out.kv_cache is not None + assert jnp.all(decode_out.kv_cache.cache_position == cache.cache_position + 1) + assert decode_out.kv_cache.keys[1].shape[1] == 8 + assert decode_out.kv_cache.conv_states is not None + assert decode_out.kv_cache.recurrent_states is not None + + +def test_qwen3_next_generate(): + config = make_small_tx_config(make_small_hf_config()) + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + with jax.set_mesh(mesh): + model = Qwen3NextForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + input_ids = jnp.array([[1, 2, 3]], dtype=jnp.int32) + attention_mask = jnp.array([[1, 1, 1]], dtype=jnp.int32) + out = model.generate( + input_ids, + attention_mask, + sampling_params=[SamplingParams(max_tokens=2, temperature=0.0, seed=0)], + ) + + assert len(out.generated_ids) == 1 + assert len(out.logprobs) == 1 + assert 
len(out.generated_ids[0]) == 2 diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index 9945335e13..f3fadb9c19 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -56,4 +56,5 @@ def get_num_experts(self): # Model-specific aliases for clarity and backwards compatibility Llama3Config = ModelConfig Qwen3Config = ModelConfig +Qwen3NextConfig = ModelConfig DeepseekV3Config = ModelConfig diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py new file mode 100644 index 0000000000..1c08c18645 --- /dev/null +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -0,0 +1,885 @@ +from __future__ import annotations + +import math + +from flax import nnx +import jax +from jax import numpy as jnp +from jax.sharding import get_abstract_mesh + +from tx.layers.attention import dot_product_attention +from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear +from tx.layers.rotary_embedding import apply_rope +from tx.layers.util import Param, prepare_routing, shard_map_ep +from tx.models.configs import Qwen3NextConfig +from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput +from tx.utils.generator import GeneratorMixin, KVCache +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead + + +def apply_partial_rope( + q: jax.Array, + k: jax.Array, + positions: jax.Array, + rotary_dim: int, + rope_theta: float, +) -> tuple[jax.Array, jax.Array]: + if rotary_dim <= 0: + return q, k + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_rot = apply_rope(q_rot, positions, rotary_dim, rope_theta) + k_rot = apply_rope(k_rot, positions, rotary_dim, rope_theta) + return jnp.concatenate([q_rot, q_pass], axis=-1), jnp.concatenate([k_rot, k_pass], axis=-1) + + +def l2norm(x: jax.Array, axis: int = -1, eps: float = 1e-6) -> jax.Array: + inv_norm = jax.lax.rsqrt(jnp.sum(x * x, axis=axis, keepdims=True) + eps) + return x * inv_norm + + +def apply_mask_to_padding_states(hidden_states: jax.Array, attention_mask: jax.Array | None) -> jax.Array: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + hidden_states = hidden_states * attention_mask[..., None].astype(hidden_states.dtype) + return hidden_states + + +def recurrent_gated_delta_rule( + query: jax.Array, + key: jax.Array, + value: jax.Array, + g: jax.Array, + beta: jax.Array, + initial_state: jax.Array | None = None, +) -> tuple[jax.Array, jax.Array]: + dtype = query.dtype + query = l2norm(query.astype(jnp.float32), axis=-1) + key = l2norm(key.astype(jnp.float32), axis=-1) + value = value.astype(jnp.float32) + g = g.astype(jnp.float32) + beta = beta.astype(jnp.float32) + + query = query * (1.0 / math.sqrt(query.shape[-1])) + + # [B, T, H, D] -> [T, B, H, D] + query = jnp.swapaxes(query, 0, 1) + key = jnp.swapaxes(key, 0, 1) + value = jnp.swapaxes(value, 0, 1) + g = jnp.swapaxes(g, 0, 1) + beta = jnp.swapaxes(beta, 0, 1) + + batch_size = query.shape[1] + num_heads = query.shape[2] + k_head_dim = query.shape[3] + v_head_dim = value.shape[3] + + if initial_state is None: + initial_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=jnp.float32) + else: + initial_state = initial_state.astype(jnp.float32) + + def step_fn( + state: jax.Array, + inputs: tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], + ) -> tuple[jax.Array, jax.Array]: + q_t, k_t, v_t, g_t, beta_t = inputs + decay = jnp.exp(g_t)[..., None, None] + state = state * decay + 
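# gated delta rule step: read the state's current estimate of v_t along k_t, then write back a beta-scaled rank-1 correction +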
kv_mem = jnp.sum(state * k_t[..., :, None], axis=-2) + delta = (v_t - kv_mem) * beta_t[..., None] + state = state + k_t[..., :, None] * delta[..., None, :] + out_t = jnp.sum(state * q_t[..., :, None], axis=-2) + return state, out_t + + final_state, outputs = jax.lax.scan(step_fn, initial_state, (query, key, value, g, beta)) + outputs = jnp.swapaxes(outputs, 0, 1).astype(dtype) + return outputs, final_state.astype(dtype) + + +class Qwen3NextRMSNorm(nnx.Module): + + def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.eps = eps + self.weight = Param( + dim, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.zeros_init(), (None,)), + rngs=rngs, + ) + + def __call__(self, x: jax.Array) -> jax.Array: + out = x.astype(jnp.float32) + out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) + out = out * (1.0 + self.weight[...].astype(jnp.float32)) + return out.astype(x.dtype) + + +class Qwen3NextRMSNormGated(nnx.Module): + + def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.eps = eps + self.weight = Param( + dim, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)), + rngs=rngs, + ) + + def __call__(self, hidden_states: jax.Array, gate: jax.Array) -> jax.Array: + input_dtype = hidden_states.dtype + out = hidden_states.astype(jnp.float32) + out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) + out = out * self.weight[...].astype(jnp.float32) + out = out * nnx.silu(gate.astype(jnp.float32)) + return out.astype(input_dtype) + + +class Qwen3NextAttention(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + tp = get_abstract_mesh().shape.get("tp", 1) + shard_attention_heads = config.shard_attention_heads + if shard_attention_heads: + assert self.num_heads % tp == 0, f"num_heads={self.num_heads} must be divisible by tp={tp}" + assert self.num_kv_heads % tp == 0, f"num_kv_heads={self.num_kv_heads} must be divisible by tp={tp}" + tp_shard = "tp" if shard_attention_heads else None + + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // self.num_heads + rotary_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) + rotary_dim = min(self.head_dim, rotary_dim) + self.rotary_dim = rotary_dim - (rotary_dim % 2) + + self.q_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_heads * self.head_dim * 2, + sharding=("fsdp", tp_shard), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=getattr(config, "attention_bias", False), + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.k_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_kv_heads * self.head_dim, + sharding=("fsdp", tp_shard), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=getattr(config, "attention_bias", False), + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.v_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_kv_heads * self.head_dim, + sharding=("fsdp", tp_shard), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + 
param_dtype=dtype, + use_bias=getattr(config, "attention_bias", False), + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.o_proj = LoRALinear( + in_features=self.num_heads * self.head_dim, + out_features=config.hidden_size, + sharding=(tp_shard, "fsdp"), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=getattr(config, "attention_bias", False), + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + def __call__( + self, + x: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + bsz, seq_len, _ = x.shape + + q_all = self.q_proj(x, adapter_indices=adapter_indices).reshape(bsz, seq_len, self.num_heads, self.head_dim * 2) + q, gate = jnp.split(q_all, 2, axis=-1) + gate = gate.reshape(bsz, seq_len, self.num_heads * self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(self.k_proj(x, adapter_indices=adapter_indices).reshape(bsz, seq_len, self.num_kv_heads, self.head_dim)) + v = self.v_proj(x, adapter_indices=adapter_indices).reshape(bsz, seq_len, self.num_kv_heads, self.head_dim) + + q, k = apply_partial_rope(q, k, positions, self.rotary_dim, self.config.rope_theta) + + if kv_cache is not None: + k, v = KVCache.update_layer(kv_cache, k, v, positions) + + updated_cache = (k, v) + is_causal = kv_cache is None + attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim) + attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim) + attn_output = attn_output * nnx.sigmoid(gate) + return self.o_proj(attn_output, adapter_indices=adapter_indices), updated_cache + + +class Qwen3NextGatedDeltaNet(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_size = config.linear_conv_kernel_dim + self.conv_dim = self.key_dim * 2 + self.value_dim + + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + + # Keep linear-attention projections replicated across TP for simplicity/stability. + self.in_proj_qkvz = LoRALinear( + self.hidden_size, + projection_size_qkvz, + sharding=("fsdp", None), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.in_proj_ba = LoRALinear( + self.hidden_size, + projection_size_ba, + sharding=("fsdp", None), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + + # Stored as [kernel, 1, channels] so existing safetensors transpose logic round-trips with HF Conv1d. 
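+ # (HF stores this depthwise Conv1d weight as [channels, 1, kernel]; the loader's transpose-on-load reverses the axes to the [kernel, 1, channels] layout assumed here.)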
+ self.conv1d_weight = Param( + self.conv_kernel_size, + 1, + self.conv_dim, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, None, None)), + rngs=rngs, + ) + self.dt_bias = Param( + self.num_v_heads, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), (None,)), + rngs=rngs, + ) + self.A_log = Param( + self.num_v_heads, + dtype=dtype, + kernel_init=nnx.with_partitioning( + lambda key, shape, dtype: jnp.log( + jax.random.uniform(key, shape, dtype=dtype, minval=1e-3, maxval=16.0) + ), + (None,), + ), + rngs=rngs, + ) + + self.norm = Qwen3NextRMSNormGated(self.head_v_dim, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.out_proj = LoRALinear( + self.value_dim, + self.hidden_size, + sharding=(None, "fsdp"), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + + def _get_conv_kernel(self) -> jax.Array: + # [kernel, 1, channels] -> [channels, 1, kernel] + return self.conv1d_weight[...].transpose((2, 1, 0)) + + def _causal_conv_prefill(self, x: jax.Array) -> tuple[jax.Array, jax.Array]: + # x: [B, C, T] + kernel = self._get_conv_kernel() + seq_len = x.shape[-1] + left_pad = self.conv_kernel_size - 1 + x_padded = jnp.pad(x, ((0, 0), (0, 0), (left_pad, 0))) + out = jax.lax.conv_general_dilated( + x_padded, + kernel, + window_strides=(1,), + padding="VALID", + feature_group_count=self.conv_dim, + dimension_numbers=("NCH", "OIH", "NCH"), + ) + out = nnx.silu(out[..., :seq_len]) + + state_pad = max(self.conv_kernel_size - seq_len, 0) + conv_state = jnp.pad(x, ((0, 0), (0, 0), (state_pad, 0)))[..., -self.conv_kernel_size :] + return out, conv_state + + def _causal_conv_decode(self, x: jax.Array, conv_state: jax.Array) -> tuple[jax.Array, jax.Array]: + # x: [B, C, T], conv_state: [B, C, K] + kernel = self._get_conv_kernel() + seq_len = x.shape[-1] + x_full = jnp.concatenate([conv_state, x], axis=-1) + new_state = x_full[..., -self.conv_kernel_size :] + out_full = jax.lax.conv_general_dilated( + x_full, + kernel, + window_strides=(1,), + padding="VALID", + feature_group_count=self.conv_dim, + dimension_numbers=("NCH", "OIH", "NCH"), + ) + out = nnx.silu(out_full[..., -seq_len:]) + return out, new_state + + def fix_query_key_value_ordering( + self, + mixed_qkvz: jax.Array, + mixed_ba: jax.Array, + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + qkvz_shape = mixed_qkvz.shape[:-1] + ( + self.num_k_heads, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, + ) + ba_shape = mixed_ba.shape[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) + mixed_qkvz = mixed_qkvz.reshape(qkvz_shape) + mixed_ba = mixed_ba.reshape(ba_shape) + + split_qkvz = [ + self.head_k_dim, + self.head_k_dim, + self.num_v_heads // self.num_k_heads * self.head_v_dim, + self.num_v_heads // self.num_k_heads * self.head_v_dim, + ] + split_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] + + split_qkvz_idx = [split_qkvz[0], split_qkvz[0] + split_qkvz[1], sum(split_qkvz[:-1])] + split_ba_idx = [split_ba[0]] + query, key, value, z = jnp.split(mixed_qkvz, split_qkvz_idx, axis=3) + b, a = jnp.split(mixed_ba, split_ba_idx, axis=3) + + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim) + b = b.reshape(b.shape[0], b.shape[1], 
self.num_v_heads) + a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads) + return query, key, value, z, b, a + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array | None, + adapter_indices: jax.Array | None = None, + conv_state: jax.Array | None = None, + recurrent_state: jax.Array | None = None, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + batch_size, seq_len, _ = hidden_states.shape + + projected_qkvz = self.in_proj_qkvz(hidden_states, adapter_indices=adapter_indices) + projected_ba = self.in_proj_ba(hidden_states, adapter_indices=adapter_indices) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_qkvz, projected_ba) + + query_flat = query.reshape(batch_size, seq_len, -1) + key_flat = key.reshape(batch_size, seq_len, -1) + value_flat = value.reshape(batch_size, seq_len, -1) + mixed_qkv = jnp.concatenate([query_flat, key_flat, value_flat], axis=-1).transpose((0, 2, 1)) + + use_precomputed = conv_state is not None and recurrent_state is not None and seq_len == 1 + if use_precomputed: + mixed_qkv, new_conv_state = self._causal_conv_decode(mixed_qkv, conv_state) + else: + mixed_qkv, new_conv_state = self._causal_conv_prefill(mixed_qkv) + + mixed_qkv = mixed_qkv.transpose((0, 2, 1)) + q_end = self.key_dim + k_end = self.key_dim * 2 + query_flat = mixed_qkv[..., :q_end] + key_flat = mixed_qkv[..., q_end:k_end] + value_flat = mixed_qkv[..., k_end:] + + query = query_flat.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key_flat.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value_flat.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = nnx.sigmoid(b) + g = -jnp.exp(self.A_log[...].astype(jnp.float32)) * jax.nn.softplus( + a.astype(jnp.float32) + self.dt_bias[...].astype(jnp.float32) + ) + + if self.num_v_heads // self.num_k_heads > 1: + repeats = self.num_v_heads // self.num_k_heads + query = jnp.repeat(query, repeats, axis=2) + key = jnp.repeat(key, repeats, axis=2) + + core_out, new_recurrent_state = recurrent_gated_delta_rule(query, key, value, g, beta, recurrent_state) + + z_shape = z.shape + core_out = self.norm(core_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)) + core_out = core_out.reshape(z_shape).reshape(batch_size, seq_len, -1) + out = self.out_proj(core_out, adapter_indices=adapter_indices) + return out, new_conv_state, new_recurrent_state + + +class Qwen3NextMLP(nnx.Module): + + def __init__( + self, + config: Qwen3NextConfig, + *, + dtype: jnp.dtype, + rngs: nnx.Rngs, + intermediate_size: int | None = None, + ) -> None: + hidden_size = config.hidden_size + intermediate_size = intermediate_size or config.intermediate_size + + self.gate_proj = LoRALinear( + hidden_size, + intermediate_size, + sharding=("fsdp", "tp"), + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.up_proj = LoRALinear( + hidden_size, + intermediate_size, + sharding=("fsdp", "tp"), + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.down_proj = LoRALinear( + intermediate_size, + hidden_size, + sharding=("tp", "fsdp"), + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), 
+ max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array: + return self.down_proj(nnx.silu(self.gate_proj(x, adapter_indices)) * self.up_proj(x, adapter_indices), adapter_indices) + + +class Qwen3NextExperts(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.gate_proj = LoRAExpert( + config.num_experts, + config.hidden_size, + config.moe_intermediate_size, + sharding=("ep", "fsdp", "tp"), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.up_proj = LoRAExpert( + config.num_experts, + config.hidden_size, + config.moe_intermediate_size, + sharding=("ep", "fsdp", "tp"), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.down_proj = LoRAExpert( + config.num_experts, + config.moe_intermediate_size, + config.hidden_size, + sharding=("ep", "tp", "fsdp"), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jax.Array, + selected_experts: jax.Array, + routing_weights: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + num_experts = self.config.num_experts + top_k = self.config.num_experts_per_tok + hidden_size = self.config.hidden_size + + ep = get_abstract_mesh().shape.get("ep", 1) + assert num_experts % ep == 0, f"num_experts={num_experts} must be divisible by ep={ep}" + + hidden_expanded = jnp.repeat(hidden_states, top_k, axis=0) + adapter_expanded = jnp.repeat(adapter_indices, top_k) if adapter_indices is not None else None + hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing( + hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded + ) + + def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights): + ep_rank = jax.lax.axis_index("ep") + experts_per_rank = num_experts // jax.lax.axis_size("ep") + group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32) + + gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset) + up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset) + down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset) + + out = down[unsort_indices].reshape(-1, top_k, hidden_size) + local_out = jnp.sum(out * routing_weights[..., None], axis=1) + return jax.lax.psum(local_out, axis_name="ep") + + return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights) + + +class Qwen3NextSparseMoeBlock(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.gate = nnx.Linear( + config.hidden_size, + config.num_experts, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, None)), + rngs=rngs, + ) + self.experts = Qwen3NextExperts(config, dtype=dtype, rngs=rngs) + self.shared_expert = Qwen3NextMLP( + config, + dtype=dtype, + rngs=rngs, + 
intermediate_size=config.shared_expert_intermediate_size, + ) + self.shared_expert_gate = LoRALinear( + config.hidden_size, + 1, + sharding=("fsdp", None), + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def __call__(self, hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.reshape(-1, hidden_dim) + adapter_flat = jnp.repeat(adapter_indices, seq_len) if adapter_indices is not None else None + + router_logits = self.gate(hidden_states_flat) + routing_weights = nnx.softmax(router_logits, axis=-1) + routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=self.config.num_experts_per_tok) + if self.config.norm_topk_prob: + routing_weights = routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True) + routing_weights = routing_weights.astype(hidden_states_flat.dtype) + + expert_output = self.experts(hidden_states_flat, selected_experts, routing_weights, adapter_flat) + shared_output = self.shared_expert(hidden_states_flat, adapter_indices=adapter_flat) + shared_gate = nnx.sigmoid(self.shared_expert_gate(hidden_states_flat, adapter_indices=adapter_flat)) + + final_hidden_states = expert_output + shared_gate * shared_output + return final_hidden_states.reshape(batch_size, seq_len, hidden_dim) + + +class Qwen3NextDecoderLayer(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.layer_type = config.layer_types[layer_idx] + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs + ) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx, dtype=dtype, rngs=rngs) + else: + self.self_attn = Qwen3NextAttention(config, dtype=dtype, rngs=rngs) + + use_moe = ( + layer_idx not in getattr(config, "mlp_only_layers", []) + and getattr(config, "num_experts", 0) > 0 + and (layer_idx + 1) % getattr(config, "decoder_sparse_step", 1) == 0 + ) + if use_moe: + self.mlp = Qwen3NextSparseMoeBlock(config, dtype=dtype, rngs=rngs) + else: + self.mlp = Qwen3NextMLP(config, dtype=dtype, rngs=rngs) + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array | None, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + conv_state: jax.Array | None = None, + recurrent_state: jax.Array | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array] | None, jax.Array | None, jax.Array | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + hidden_states, new_conv_state, new_recurrent_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + adapter_indices=adapter_indices, + conv_state=conv_state, + recurrent_state=recurrent_state, + ) + updated_kv = None + else: + hidden_states, updated_kv = self.self_attn( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + new_conv_state = conv_state + new_recurrent_state = recurrent_state + + hidden_states = residual 
+ hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, adapter_indices=adapter_indices) + hidden_states = residual + hidden_states + + return hidden_states, updated_kv, new_conv_state, new_recurrent_state + + +class Qwen3NextModel(nnx.Module): + + def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.embed_tokens = LoRAEmbed( + num_embeddings=config.vocab_size, + features=config.hidden_size, + sharding=("tp", None), + dtype=dtype, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + param_dtype=dtype, + embedding_init=nnx.initializers.normal(), + rngs=rngs, + ) + + layer_types = getattr(config, "layer_types", None) + if layer_types is None: + interval = getattr(config, "full_attention_interval", 4) + layer_types = [ + "linear_attention" if (i + 1) % interval else "full_attention" for i in range(config.num_hidden_layers) + ] + config.layer_types = layer_types + + assert len(config.layer_types) == config.num_hidden_layers + self.layer_types = tuple(config.layer_types) + self.layers = nnx.List( + [Qwen3NextDecoderLayer(config, i, dtype=dtype, rngs=rngs) for i in range(config.num_hidden_layers)] + ) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + is_training: bool = False, + ) -> ModelOutput: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) + all_hidden_states: list[jax.Array] = [] + updated_keys: list[jax.Array] = [] + updated_values: list[jax.Array] = [] + updated_conv_states: list[jax.Array] = [] + updated_recurrent_states: list[jax.Array] = [] + + batch_size = input_ids.shape[0] + dtype = hidden_states.dtype + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + layer_type = self.layer_types[layer_idx] + if layer_type == "full_attention": + layer_kv = (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]) if kv_cache is not None else None + hidden_states, updated_kv, new_conv_state, new_recurrent_state = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + assert updated_kv is not None + updated_keys.append(updated_kv[0]) + updated_values.append(updated_kv[1]) + + if kv_cache is not None and kv_cache.conv_states is not None and kv_cache.recurrent_states is not None: + updated_conv_states.append(kv_cache.conv_states[layer_idx]) + updated_recurrent_states.append(kv_cache.recurrent_states[layer_idx]) + else: + updated_conv_states.append(jnp.zeros((batch_size, 0, 0), dtype=dtype)) + updated_recurrent_states.append(jnp.zeros((batch_size, 0, 0, 0), dtype=dtype)) + else: + layer_conv_state = None + layer_recurrent_state = None + if kv_cache is not None and kv_cache.conv_states is not None and kv_cache.recurrent_states is not None: + layer_conv_state = kv_cache.conv_states[layer_idx] + layer_recurrent_state = kv_cache.recurrent_states[layer_idx] + + linear_mask = None if kv_cache is not None else attention_mask + 
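# with a cache present (decode), the attention mask spans the padded cache rather than the current tokens, so the linear-attention path skips padding masking +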
hidden_states, _, new_conv_state, new_recurrent_state = layer( + hidden_states, + attention_mask=linear_mask, + positions=positions, + adapter_indices=adapter_indices, + conv_state=layer_conv_state, + recurrent_state=layer_recurrent_state, + ) + assert new_conv_state is not None and new_recurrent_state is not None + updated_conv_states.append(new_conv_state) + updated_recurrent_states.append(new_recurrent_state) + + if kv_cache is not None: + updated_keys.append(kv_cache.keys[layer_idx]) + updated_values.append(kv_cache.values[layer_idx]) + else: + updated_keys.append(jnp.zeros((batch_size, 0, 0, 0), dtype=dtype)) + updated_values.append(jnp.zeros((batch_size, 0, 0, 0), dtype=dtype)) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + if is_training: + new_kv_cache = None + else: + new_kv_cache = KVCache.update( + kv_cache, + updated_keys, + updated_values, + positions, + attention_mask, + conv_states=updated_conv_states, + recurrent_states=updated_recurrent_states, + ) + + return ModelOutput( + last_hidden_state=hidden_states, + kv_cache=new_kv_cache, + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + +class Qwen3NextForCausalLM(nnx.Module, ModelForCausalLM, GeneratorMixin, LogitsProcessorMixin): + + def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.model = Qwen3NextModel(config, dtype=dtype, rngs=rngs) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens.T + else: + self.lm_head = LoRALinear( + config.hidden_size, + config.vocab_size, + sharding=(None, "tp"), + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.initializers.lecun_normal(), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def get_lm_head(self) -> LMHead: + return self.lm_head + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array | None = None, + output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + is_training: bool = False, + ) -> CausalLMOutput: + if positions is None: + positions = jnp.arange(attention_mask.shape[1])[None, :] + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + positions=positions, + output_hidden_states=output_hidden_states, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + is_training=is_training, + ) + + return CausalLMOutput( + last_hidden_state=outputs.last_hidden_state, + kv_cache=outputs.kv_cache, + hidden_states=outputs.hidden_states, + ) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index b407140c6f..60d3bc581c 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -19,6 +19,8 @@ class KVCache: keys: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer values: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer cache_position: jax.Array # Per-sequence positions of shape (batch,) + conv_states: list[jax.Array] | None = None + recurrent_states: list[jax.Array] | None = None @staticmethod def update( @@ -27,6 +29,9 @@ def update( values: list[jax.Array], positions: jax.Array, attention_mask: jax.Array, + *, + conv_states: list[jax.Array] | None = None, + recurrent_states: list[jax.Array] | None = None, ) -> KVCache: """Create an updated KVCache with computed cache positions for left-aligned 
decoding. @@ -46,7 +51,15 @@ def update( else: # Prefill: next position is the sequence length (number of real tokens) cache_position = attention_mask.sum(axis=1) - return KVCache(keys=keys, values=values, cache_position=cache_position) + return KVCache( + keys=keys, + values=values, + cache_position=cache_position, + conv_states=conv_states if conv_states is not None else (kv_cache.conv_states if kv_cache is not None else None), + recurrent_states=( + recurrent_states if recurrent_states is not None else (kv_cache.recurrent_states if kv_cache is not None else None) + ), + ) @staticmethod def update_layer(kv_cache, k, v, positions): @@ -76,13 +89,22 @@ def pad_to_length(self, max_length: int) -> KVCache: Returns: New KVCache with padded keys and values. """ - # k and v have shape [B, T, num_heads, head_dim] - cache_pad_length = max_length - self.keys[0].shape[1] - pad_spec = ((0, 0), (0, cache_pad_length), (0, 0), (0, 0)) return KVCache( - keys=[jnp.pad(k, pad_spec) for k in self.keys], - values=[jnp.pad(v, pad_spec) for v in self.values], + keys=[ + jnp.pad(k, ((0, 0), (0, max_length - k.shape[1]), (0, 0), (0, 0))) + if k.shape[1] < max_length + else k + for k in self.keys + ], + values=[ + jnp.pad(v, ((0, 0), (0, max_length - v.shape[1]), (0, 0), (0, 0))) + if v.shape[1] < max_length + else v + for v in self.values + ], cache_position=self.cache_position, + conv_states=self.conv_states, + recurrent_states=self.recurrent_states, ) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 2b5afe153a..a584a21bee 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -66,6 +66,7 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: "Get the correct model class based on the config." import tx.models.llama3 import tx.models.qwen3 + import tx.models.qwen3_next import tx.models.deepseekv3 for architecture in config.architectures or []: @@ -73,6 +74,8 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: return getattr(tx.models.llama3, architecture) if hasattr(tx.models.qwen3, architecture): return getattr(tx.models.qwen3, architecture) + if hasattr(tx.models.qwen3_next, architecture): + return getattr(tx.models.qwen3_next, architecture) if hasattr(tx.models.deepseekv3, architecture): return getattr(tx.models.deepseekv3, architecture) @@ -122,6 +125,8 @@ def get_param_key(path: tuple, prefix: str = "") -> str: "Get the safetensors key for a given model path." if path[-1] in {"embedding", "kernel"}: path = (*path[:-1], "weight") + elif path[-1] == "conv1d_weight": + path = (*path[:-1], "conv1d", "weight") elif path[-1] in {"lora_A", "lora_B"}: path = (*path, "weight") return prefix + ".".join(map(str, path)) From db87dc009914cc303c42388657157cde357aa7d0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 00:05:55 -0800 Subject: [PATCH 02/12] add hack --- skyrl-tx/tx/models/configs.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index f3fadb9c19..61917d3eac 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -38,8 +38,18 @@ def __init__( gradient_checkpointing: bool = False, mhc_expansion_rate: int = 1, ): - # Copy all attributes from the base config - super().__init__(**config.to_dict()) + # Copy attributes from the base config. + # Some configs (especially multimodal wrappers) keep language-model fields + # under nested dicts like "text_config". 
Merge these as fallbacks so + # model code can consistently access top-level attributes. + config_dict = config.to_dict() + for nested_key in ("text_config", "language_config"): + nested = config_dict.get(nested_key) + if isinstance(nested, dict): + for key, value in nested.items(): + config_dict.setdefault(key, value) + + super().__init__(**config_dict) # Add LoRA-specific parameters self.max_lora_adapters = max_lora_adapters From 3a49b9f8c97ba5a937159496d44ebb757085423c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 01:16:28 -0800 Subject: [PATCH 03/12] fix rope --- skyrl-tx/tests/models/test_qwen3_next.py | 55 ++++++++++++++++++++++++ skyrl-tx/tx/models/qwen3_next.py | 16 ++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3_next.py b/skyrl-tx/tests/models/test_qwen3_next.py index 5fbeffe317..197635d6c3 100644 --- a/skyrl-tx/tests/models/test_qwen3_next.py +++ b/skyrl-tx/tests/models/test_qwen3_next.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch +from transformers import PretrainedConfig from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig as HFQwen3NextConfig from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM as HFQwen3NextForCausalLM @@ -35,6 +36,11 @@ def make_small_hf_config() -> HFQwen3NextConfig: num_experts=0, num_experts_per_tok=1, decoder_sparse_step=1, + rope_parameters={ + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 0.25, + }, ) @@ -163,3 +169,52 @@ def test_qwen3_next_generate(): assert len(out.generated_ids) == 1 assert len(out.logprobs) == 1 assert len(out.generated_ids[0]) == 2 + + +def test_qwen3_next_nested_rope_parameters_without_top_level_rope_theta(): + text_config = { + "vocab_size": 128, + "hidden_size": 32, + "intermediate_size": 64, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 8, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 128, + "tie_word_embeddings": False, + "linear_conv_kernel_dim": 3, + "linear_key_head_dim": 4, + "linear_value_head_dim": 4, + "linear_num_key_heads": 2, + "linear_num_value_heads": 2, + "layer_types": ["linear_attention", "full_attention", "linear_attention", "full_attention"], + "num_experts": 0, + "num_experts_per_tok": 1, + "decoder_sparse_step": 1, + "rope_parameters": { + "rope_type": "default", + "rope_theta": 10_000_000, + "partial_rotary_factor": 0.25, + }, + } + base_config = PretrainedConfig( + architectures=["Qwen3NextForCausalLM"], + model_type="qwen3_5_moe", + text_config=text_config, + ) + config = Qwen3NextConfig( + base_config, + max_lora_adapters=2, + max_lora_rank=8, + shard_attention_heads=False, + ) + + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + with jax.set_mesh(mesh): + model = Qwen3NextForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + input_ids = jnp.array([[1, 2, 3]], dtype=jnp.int32) + attention_mask = jnp.array([[1, 1, 1]], dtype=jnp.int32) + outputs = model(input_ids, attention_mask=attention_mask) + + assert outputs.last_hidden_state.shape == (1, 3, config.hidden_size) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 1c08c18645..39cfb92f80 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -148,9 +148,21 @@ def __init__(self, config: Qwen3NextConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) tp_shard = "tp" if shard_attention_heads else 
None self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // self.num_heads - rotary_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) + rope_parameters = getattr(config, "rope_parameters", None) + assert isinstance( + rope_parameters, dict + ), "Qwen3NextAttention requires config.rope_parameters to be a dict." + assert "partial_rotary_factor" in rope_parameters, ( + "Qwen3NextAttention requires rope_parameters['partial_rotary_factor']." + ) + assert "rope_theta" in rope_parameters, "Qwen3NextAttention requires rope_parameters['rope_theta']." + partial_rotary_factor = rope_parameters["partial_rotary_factor"] + rope_theta = rope_parameters["rope_theta"] + + rotary_dim = int(self.head_dim * partial_rotary_factor) rotary_dim = min(self.head_dim, rotary_dim) self.rotary_dim = rotary_dim - (rotary_dim % 2) + self.rope_theta = rope_theta self.q_proj = LoRALinear( in_features=config.hidden_size, @@ -222,7 +234,7 @@ def __call__( k = self.k_norm(self.k_proj(x, adapter_indices=adapter_indices).reshape(bsz, seq_len, self.num_kv_heads, self.head_dim)) v = self.v_proj(x, adapter_indices=adapter_indices).reshape(bsz, seq_len, self.num_kv_heads, self.head_dim) - q, k = apply_partial_rope(q, k, positions, self.rotary_dim, self.config.rope_theta) + q, k = apply_partial_rope(q, k, positions, self.rotary_dim, self.rope_theta) if kv_cache is not None: k, v = KVCache.update_layer(kv_cache, k, v, positions) From a928e22bb120c7de3d78cec6b1cca242f6715f08 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 10:32:40 -0800 Subject: [PATCH 04/12] use bf16 --- skyrl-tx/tx/models/qwen3_next.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 39cfb92f80..c1d9d977fc 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -53,11 +53,12 @@ def recurrent_gated_delta_rule( initial_state: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array]: dtype = query.dtype - query = l2norm(query.astype(jnp.float32), axis=-1) - key = l2norm(key.astype(jnp.float32), axis=-1) - value = value.astype(jnp.float32) - g = g.astype(jnp.float32) - beta = beta.astype(jnp.float32) + compute_dtype = dtype + query = l2norm(query.astype(compute_dtype), axis=-1) + key = l2norm(key.astype(compute_dtype), axis=-1) + value = value.astype(compute_dtype) + g = g.astype(compute_dtype) + beta = beta.astype(compute_dtype) query = query * (1.0 / math.sqrt(query.shape[-1])) @@ -74,9 +75,9 @@ def recurrent_gated_delta_rule( v_head_dim = value.shape[3] if initial_state is None: - initial_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=jnp.float32) + initial_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=compute_dtype) else: - initial_state = initial_state.astype(jnp.float32) + initial_state = initial_state.astype(compute_dtype) def step_fn( state: jax.Array, @@ -108,9 +109,10 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> ) def __call__(self, x: jax.Array) -> jax.Array: - out = x.astype(jnp.float32) + compute_dtype = x.dtype + out = x.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) - out = out * (1.0 + self.weight[...].astype(jnp.float32)) + out = out * (1.0 + self.weight[...].astype(compute_dtype)) return out.astype(x.dtype) @@ -127,10 +129,11 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, 
rngs: nnx.Rngs) -> def __call__(self, hidden_states: jax.Array, gate: jax.Array) -> jax.Array: input_dtype = hidden_states.dtype - out = hidden_states.astype(jnp.float32) + compute_dtype = hidden_states.dtype + out = hidden_states.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) - out = out * self.weight[...].astype(jnp.float32) - out = out * nnx.silu(gate.astype(jnp.float32)) + out = out * self.weight[...].astype(compute_dtype) + out = out * nnx.silu(gate.astype(compute_dtype)) return out.astype(input_dtype) From abf461bc896e9373ed5a0941044cbfa6bf85adf9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 18:55:24 +0000 Subject: [PATCH 05/12] workaround --- skyrl-tx/tx/models/qwen3_next.py | 2 ++ skyrl-tx/tx/utils/models.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index c1d9d977fc..511623ed2e 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -898,3 +898,5 @@ def __call__( kv_cache=outputs.kv_cache, hidden_states=outputs.hidden_states, ) + +Qwen3_5ForConditionalGeneration = Qwen3NextForCausalLM diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index a584a21bee..ffc6ede5d9 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -59,7 +59,8 @@ def get_dtype(dtype: str | torch.dtype) -> jnp.dtype: case "torch.float16" | "float16": return jnp.float16 case _: - raise ValueError(f"Unsupported torch dtype: {dtype}") + return jnp.bfloat16 + # raise ValueError(f"Unsupported torch dtype: {dtype}") def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: From 941db37bc620911df9c3a44bb8bd485e76464ef1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 14:06:27 -0800 Subject: [PATCH 06/12] fix --- skyrl-tx/tx/models/qwen3_next.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 511623ed2e..ce2b4c02ad 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -633,8 +633,7 @@ def __call__(self, hidden_states: jax.Array, adapter_indices: jax.Array | None = router_logits = self.gate(hidden_states_flat) routing_weights = nnx.softmax(router_logits, axis=-1) routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=self.config.num_experts_per_tok) - if self.config.norm_topk_prob: - routing_weights = routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True) + routing_weights = routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True) routing_weights = routing_weights.astype(hidden_states_flat.dtype) expert_output = self.experts(hidden_states_flat, selected_experts, routing_weights, adapter_flat) From d0de84fab9e10021da4e88e7bdb5749a88f1e131 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 15:38:09 -0800 Subject: [PATCH 07/12] debug --- skyrl-tx/tx/models/qwen3_next.py | 6 +++--- skyrl-tx/tx/utils/models.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index ce2b4c02ad..2d57170cd1 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -53,7 +53,7 @@ def recurrent_gated_delta_rule( initial_state: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array]: dtype = query.dtype - compute_dtype = dtype + compute_dtype = jnp.float32 query = 
l2norm(query.astype(compute_dtype), axis=-1) key = l2norm(key.astype(compute_dtype), axis=-1) value = value.astype(compute_dtype) @@ -109,7 +109,7 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> ) def __call__(self, x: jax.Array) -> jax.Array: - compute_dtype = x.dtype + compute_dtype = jnp.float32 out = x.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) out = out * (1.0 + self.weight[...].astype(compute_dtype)) @@ -129,7 +129,7 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> def __call__(self, hidden_states: jax.Array, gate: jax.Array) -> jax.Array: input_dtype = hidden_states.dtype - compute_dtype = hidden_states.dtype + compute_dtype = jnp.float32 out = hidden_states.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) out = out * self.weight[...].astype(compute_dtype) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index ffc6ede5d9..26f56bfec7 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -174,6 +174,14 @@ def load_safetensors( if skip_lora and is_connector_path(path): continue if key not in tensors: + if not ( + "lora_A" in path + or "lora_B" in path + or "lora_scaling" in path + or "lora_ranks" in path + or is_connector_path(path) + ): + logger.warning(f"Missing non-LoRA checkpoint key while loading from {checkpoint_dir}: {key}") continue if "experts" in path: tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) From 5a07de99d2c4fa25c3ad0cecdb74588f92c2becb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 15:43:50 -0800 Subject: [PATCH 08/12] checkpoints --- skyrl-tx/tx/utils/models.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 26f56bfec7..2f2040edd5 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -167,13 +167,19 @@ def load_safetensors( if filter_fn is not None and not filter_fn(path): continue key = get_param_key(path) + key_to_load = key + if key_to_load not in tensors and key_to_load.startswith("model."): + # Qwen3.5 checkpoints store language weights under `model.language_model.*`. + language_key = "model.language_model." + key_to_load[len("model.") :] + if language_key in tensors: + key_to_load = language_key # Skip LoRA parameters if requested if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue # Skip connector parameters if skip_lora and is_connector_path(path): continue - if key not in tensors: + if key_to_load not in tensors: if not ( "lora_A" in path or "lora_B" in path @@ -184,9 +190,15 @@ def load_safetensors( logger.warning(f"Missing non-LoRA checkpoint key while loading from {checkpoint_dir}: {key}") continue if "experts" in path: - tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) + def expert_key(i: int) -> str: + k = get_expert_key(path, i) + if key_to_load != key and k.startswith("model."): + return "model.language_model." 
+ k[len("model.") :] + return k + + tensor = np.stack([tensors[expert_key(i)].T for i in range(config.get_num_experts())], axis=0) else: - tensor = tensors[key] if "embed_tokens" in key else tensors[key].T + tensor = tensors[key_to_load] if "embed_tokens" in key_to_load else tensors[key_to_load].T adapter_idx = get_adapter_slice(path, adapter_index, rank) if adapter_idx is not None: # Load into specific adapter slot via ArrayRef write-through From 0919e5452488fdacbecd56ee88e887516e176250 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 15:44:53 -0800 Subject: [PATCH 09/12] update --- skyrl-tx/tx/models/qwen3_next.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 2d57170cd1..ce2b4c02ad 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -53,7 +53,7 @@ def recurrent_gated_delta_rule( initial_state: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array]: dtype = query.dtype - compute_dtype = jnp.float32 + compute_dtype = dtype query = l2norm(query.astype(compute_dtype), axis=-1) key = l2norm(key.astype(compute_dtype), axis=-1) value = value.astype(compute_dtype) @@ -109,7 +109,7 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> ) def __call__(self, x: jax.Array) -> jax.Array: - compute_dtype = jnp.float32 + compute_dtype = x.dtype out = x.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) out = out * (1.0 + self.weight[...].astype(compute_dtype)) @@ -129,7 +129,7 @@ def __init__(self, dim: int, *, eps: float, dtype: jnp.dtype, rngs: nnx.Rngs) -> def __call__(self, hidden_states: jax.Array, gate: jax.Array) -> jax.Array: input_dtype = hidden_states.dtype - compute_dtype = jnp.float32 + compute_dtype = hidden_states.dtype out = hidden_states.astype(compute_dtype) out = out * jax.lax.rsqrt(jnp.mean(out * out, axis=-1, keepdims=True) + self.eps) out = out * self.weight[...].astype(compute_dtype) From 2142df7592f84835d7dc12b43c358dae0683e585 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 16:00:30 -0800 Subject: [PATCH 10/12] update --- skyrl-tx/tx/models/qwen3_next.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index ce2b4c02ad..2870c98caf 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -264,13 +264,13 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, self.conv_kernel_size = config.linear_conv_kernel_dim self.conv_dim = self.key_dim * 2 + self.value_dim - projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - projection_size_ba = self.num_v_heads * 2 + projection_size_a = self.key_dim * 2 + self.value_dim * 2 + projection_size_b = self.num_v_heads * 2 # Keep linear-attention projections replicated across TP for simplicity/stability. 
- self.in_proj_qkvz = LoRALinear( + self.in_proj_a = LoRALinear( self.hidden_size, - projection_size_qkvz, + projection_size_a, sharding=("fsdp", None), max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, @@ -280,9 +280,9 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, kernel_init=nnx.initializers.lecun_normal(), rngs=rngs, ) - self.in_proj_ba = LoRALinear( + self.in_proj_b = LoRALinear( self.hidden_size, - projection_size_ba, + projection_size_b, sharding=("fsdp", None), max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, @@ -419,8 +419,8 @@ def __call__( hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape - projected_qkvz = self.in_proj_qkvz(hidden_states, adapter_indices=adapter_indices) - projected_ba = self.in_proj_ba(hidden_states, adapter_indices=adapter_indices) + projected_qkvz = self.in_proj_a(hidden_states, adapter_indices=adapter_indices) + projected_ba = self.in_proj_b(hidden_states, adapter_indices=adapter_indices) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_qkvz, projected_ba) query_flat = query.reshape(batch_size, seq_len, -1) From e31d69a21d8131283885975a731cacb15445d947 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 25 Feb 2026 16:07:34 -0800 Subject: [PATCH 11/12] only support qwen3.5 --- skyrl-tx/tx/models/qwen3_next.py | 78 +++++++++++++------------------- 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 2870c98caf..1b3bc713ae 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -264,13 +264,23 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, self.conv_kernel_size = config.linear_conv_kernel_dim self.conv_dim = self.key_dim * 2 + self.value_dim - projection_size_a = self.key_dim * 2 + self.value_dim * 2 - projection_size_b = self.num_v_heads * 2 - # Keep linear-attention projections replicated across TP for simplicity/stability. - self.in_proj_a = LoRALinear( + projection_size_qkv = self.key_dim * 2 + self.value_dim + self.in_proj_qkv = LoRALinear( self.hidden_size, - projection_size_a, + projection_size_qkv, + sharding=("fsdp", None), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.in_proj_z = LoRALinear( + self.hidden_size, + self.value_dim, sharding=("fsdp", None), max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, @@ -282,7 +292,19 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int, *, dtype: jnp.dtype, ) self.in_proj_b = LoRALinear( self.hidden_size, - projection_size_b, + self.num_v_heads, + sharding=("fsdp", None), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.initializers.lecun_normal(), + rngs=rngs, + ) + self.in_proj_a = LoRALinear( + self.hidden_size, + self.num_v_heads, sharding=("fsdp", None), max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, @@ -375,38 +397,6 @@ def _causal_conv_decode(self, x: jax.Array, conv_state: jax.Array) -> tuple[jax. 
out = nnx.silu(out_full[..., -seq_len:]) return out, new_state - def fix_query_key_value_ordering( - self, - mixed_qkvz: jax.Array, - mixed_ba: jax.Array, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: - qkvz_shape = mixed_qkvz.shape[:-1] + ( - self.num_k_heads, - 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, - ) - ba_shape = mixed_ba.shape[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) - mixed_qkvz = mixed_qkvz.reshape(qkvz_shape) - mixed_ba = mixed_ba.reshape(ba_shape) - - split_qkvz = [ - self.head_k_dim, - self.head_k_dim, - self.num_v_heads // self.num_k_heads * self.head_v_dim, - self.num_v_heads // self.num_k_heads * self.head_v_dim, - ] - split_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] - - split_qkvz_idx = [split_qkvz[0], split_qkvz[0] + split_qkvz[1], sum(split_qkvz[:-1])] - split_ba_idx = [split_ba[0]] - query, key, value, z = jnp.split(mixed_qkvz, split_qkvz_idx, axis=3) - b, a = jnp.split(mixed_ba, split_ba_idx, axis=3) - - value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) - z = z.reshape(z.shape[0], z.shape[1], -1, self.head_v_dim) - b = b.reshape(b.shape[0], b.shape[1], self.num_v_heads) - a = a.reshape(a.shape[0], a.shape[1], self.num_v_heads) - return query, key, value, z, b, a - def __call__( self, hidden_states: jax.Array, @@ -419,14 +409,10 @@ def __call__( hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape - projected_qkvz = self.in_proj_a(hidden_states, adapter_indices=adapter_indices) - projected_ba = self.in_proj_b(hidden_states, adapter_indices=adapter_indices) - query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_qkvz, projected_ba) - - query_flat = query.reshape(batch_size, seq_len, -1) - key_flat = key.reshape(batch_size, seq_len, -1) - value_flat = value.reshape(batch_size, seq_len, -1) - mixed_qkv = jnp.concatenate([query_flat, key_flat, value_flat], axis=-1).transpose((0, 2, 1)) + mixed_qkv = self.in_proj_qkv(hidden_states, adapter_indices=adapter_indices).transpose((0, 2, 1)) + z = self.in_proj_z(hidden_states, adapter_indices=adapter_indices).reshape(batch_size, seq_len, -1, self.head_v_dim) + b = self.in_proj_b(hidden_states, adapter_indices=adapter_indices) + a = self.in_proj_a(hidden_states, adapter_indices=adapter_indices) use_precomputed = conv_state is not None and recurrent_state is not None and seq_len == 1 if use_precomputed: From 334ace2b8ce43b3f1ee01e501f9b723767703a48 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 26 Feb 2026 02:32:03 +0000 Subject: [PATCH 12/12] update --- skyrl-tx/tx/models/qwen3_next.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/models/qwen3_next.py b/skyrl-tx/tx/models/qwen3_next.py index 1b3bc713ae..310b725d63 100644 --- a/skyrl-tx/tx/models/qwen3_next.py +++ b/skyrl-tx/tx/models/qwen3_next.py @@ -885,3 +885,4 @@ def __call__( ) Qwen3_5ForConditionalGeneration = Qwen3NextForCausalLM +Qwen3_5MoeForConditionalGeneration = Qwen3NextForCausalLM
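For reference, the split linear-attention projections introduced in the "only support qwen3.5" patch have widths that follow directly from the linear-attention config: the fused qkv branch spans two key blocks plus one value block, the z (gate) branch spans one value block, and the b and a branches carry one scalar per value head (roughly, the beta and decay inputs of the gated delta rule). Below is a minimal sketch of that arithmetic; the head counts and head dims are example values only, not taken from any real checkpoint.

# Per-branch widths of the split projections (in_proj_qkv / in_proj_z /
# in_proj_b / in_proj_a), computed the same way as in the linear-attention
# module above. All numbers here are illustrative.

linear_num_key_heads = 2
linear_num_value_heads = 4
linear_key_head_dim = 16
linear_value_head_dim = 32

key_dim = linear_num_key_heads * linear_key_head_dim        # 32
value_dim = linear_num_value_heads * linear_value_head_dim  # 128

projection_widths = {
    "in_proj_qkv": 2 * key_dim + value_dim,   # fused query + key + value: 192
    "in_proj_z": value_dim,                   # gating branch: 128
    "in_proj_b": linear_num_value_heads,      # one scalar per value head: 4
    "in_proj_a": linear_num_value_heads,      # one scalar per value head: 4
}
print(projection_widths)

The final patch only adds an alias, so both Qwen3_5ForConditionalGeneration and Qwen3_5MoeForConditionalGeneration point at the same Qwen3NextForCausalLM implementation, presumably so that checkpoints declaring either architecture name resolve to the same class.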