
Add Gemma3n text-only native server path #672

Open
Grape203 wants to merge 1 commit into UbiquitousLearning:main from
Grape203:pr/gemma3n-text-only-native-20260501_135528

Conversation

Contributor

Grape203 commented on May 1, 2026

Summary

This PR adds an initial Gemma3n text-only native path to pymllm.

The implementation focuses on the simplest text-only LLM path first, following the staged direction discussed with the maintainer. It supports loading Gemma3n E2B text weights, native text forward/generation, and basic pymllm-server decode with RadixCache disabled.

Main changes

  • Add pymllm.models.gemma3n with a text-only native Gemma3n implementation.
  • Register Gemma3n model classes in pymllm.models.
  • Add model-specific CPU-first weight loading support for Gemma3n.
  • Add Gemma3n text weight streaming loader.
  • Add text-only support for per-layer embedding, AltUp flow, sliding/full attention layer types, and KV sharing.
  • Fix the disable_radix_cache path so ChunkCache does not enter RadixCache-specific insertion logic.
  • Add a minimal batch=1 full-context recompute path for native server decode correctness.

Verification

Tested locally with Gemma3n E2B text weights.

  • Native text weight loading succeeds:
    • loaded=732
    • skipped=825
    • missing_in_ckpt=0
  • Native direct generation works.
  • pymllm-server works with RadixCache disabled.
  • /v1/completions multi-token decode returns:
The capital of France is **Paris**.

## Summary by CodeRabbit

* **New Features**
  * Added Gemma 3n model support with causal and conditional LM runtime paths.

* **Improvements**
  * Optional CPU-first model instantiation for certain model modes to improve loading stability.
  * Model-specific streaming weight loading with safe fallback to the existing loader.
  * Radix-cache handling refined to avoid radix-specific operations on non-radix caches.

Contributor

coderabbitai Bot commented May 1, 2026

📝 Walkthrough

Walkthrough

Adds a Gemma 3n text-model implementation and registry entries, modifies model instantiation to optionally perform CPU-first weight loading and prefer model-specific weight loaders, and restricts radix-cache-specific logic to radix-capable cache implementations.

Changes

  • Gemma 3n model & registry (pymllm/models/gemma3n.py, pymllm/models/__init__.py): adds a complete Gemma 3n text-model implementation (layers, attention, RoPE, AltUp/LAUREL residual flows, scaled embeddings, decoder wrapper, and weight-loading APIs) and registers Gemma3nForCausalLM / Gemma3nForConditionalGeneration.
  • Model instantiation & weight loading (pymllm/executor/model_runner.py): instantiation may first place weights on CPU when model_cls.requires_cpu_first_weight_loading is set or Gemma3n native mode is detected, logging the CPU-first behavior; weight loading prefers a model-provided load_weights_from_model_path when allowed, else falls back to the iterator-based load_weights.
  • Cache handling guards (pymllm/orchestrator/model_runner_process.py): guards radix-cache insert/cleanup paths so radix-specific methods run only for caches exposing radix semantics (e.g., a page_size attribute), avoiding radix-only APIs for non-radix caches like ChunkCache.

Sequence Diagram(s)

sequenceDiagram
    participant Runner as ModelRunner
    participant Model as ModelClass
    participant Storage as CheckpointStorage
    participant Device as DeviceMgr

    Runner->>Model: determine instantiate_device_str (runtime device or "cpu")
    alt CPU-first required
        Runner->>Device: set target="cpu"
        Runner-->>Runner: log "CPU-first enabled"
    else Instantiate on runtime device
        Runner->>Device: set target=runtime_device
    end

    Runner->>Model: instantiate model on target
    Runner->>Model: evaluate use_model_path_weight_loader (bool/callable)
    alt model provides load_weights_from_model_path and resolver true
        Runner->>Storage: request model_path chunks
        Storage-->>Model: stream checkpoint
        Model->>Model: load_weights_from_model_path(model_path)
    else
        Runner->>Storage: iterate weight tensors
        Storage-->>Runner: weight iterator
        Runner->>Model: load_weights(weight_iterator)
    end

    Note right of Model: model initialized and weights loaded
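The loading branch in the diagram can be sketched in plain Python. This is a minimal illustration of the described preference order, not the actual `model_runner.py` code; the attribute names `use_model_path_weight_loader` and `load_weights_from_model_path` follow the walkthrough text, and the bool-or-callable handling is an assumption noted there.

```python
def load_model_weights(model, model_path, weight_iterator):
    """Prefer a model-specific path loader, else fall back to the iterator loader.

    `use_model_path_weight_loader` may be a bool or a callable resolver,
    per the walkthrough; anything else about the interface is hypothetical.
    """
    use_path_loader = getattr(model, "use_model_path_weight_loader", False)
    if callable(use_path_loader):
        use_path_loader = use_path_loader()
    if use_path_loader and hasattr(model, "load_weights_from_model_path"):
        # Model streams its own checkpoint chunks from disk.
        model.load_weights_from_model_path(model_path)
        return "model_path"
    # Generic fallback: consume an iterator of (name, tensor) pairs.
    model.load_weights(weight_iterator)
    return "iterator"
```

A model class opts in simply by exposing the attribute and method; everything else keeps the existing iterator path.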

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

🐰 I nibbled code in moonlit hops,
Gemma threads and RoPE on my chops.
CPU-first I softly sing,
Weights stream home — a joyous spring.
🍃 Hop, load, predict — the warren hops!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 13.95%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check — ✅ Passed: the title clearly and specifically describes the main change, adding a Gemma3n text-only native server path.
  • Description check — ✅ Passed: the description provides a comprehensive summary, main changes, and verification results, exceeding the basic template requirements.
  • Linked Issues check — ✅ Passed: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed: skipped because no linked issues were found for this pull request.




Contributor

coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-727: The code infers prefill vs decode from sequence length
and cache emptiness which lets a fresh 1-token request reuse stale context;
update the logic in the native path (around input_ids_hf/position_ids_hf
handling and the is_prefill decision) to detect request boundaries and reset
cached state when a new decode begins: if forward_batch.forward_mode indicates
"decode" (or if position_ids_hf[...,0] == 0) then clear
self._native_cached_input_ids, self._native_cached_positions and related cached
tensors such as self._hf_past_key_values before computing is_prefill so a
single-token new request does not append to prior context (apply same change to
the other similar branch handling caches).
- Around line 1041-1046: The weights handling in ModelRunner.load_model()
materializes iterables with list(weights), which will drain generator-like
inputs such as self._iter_weights and cause OOM for large Gemma3n checkpoints;
change the logic to treat non-dict weights as an iterable and iterate over it
directly (e.g., use a for-loop over weights or assign weight_items = weights
when it's already an iterable) rather than calling list(weights); if you keep
the TypeError branch for non-iterables, re-raise or raise a new error using
"from err" to preserve exception chaining.
- Around line 939-946: The current Gemma3n streaming loader returns
self.load_weights([]) when no .safetensors are found, which silently yields an
uninitialized model; instead, change the behavior in the st_files check inside
the Gemma3n loader: either raise a clear exception (e.g., FileNotFoundError or
ValueError) with a message that no .safetensors were found for the given
model_path so callers (including ModelRunner.load_model when
use_model_path_weight_loader is enabled) fail fast, or implement a proper .bin
fallback by detecting "*.bin" files and delegating to the existing .bin loading
code (call the appropriate bin-weight loader method) rather than returning an
empty load_weights list; update references to load_weights and the loader
selection logic accordingly.
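A hedged sketch of the fail-fast option described above, assuming a loader that scans `model_path` for checkpoint shards; the file patterns follow the review text, while the function name and return shape are hypothetical.

```python
import glob
import os

def find_checkpoint_files(model_path):
    """Return (kind, files), preferring .safetensors and failing fast
    when neither .safetensors nor .bin shards exist."""
    st_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if st_files:
        return "safetensors", st_files
    bin_files = sorted(glob.glob(os.path.join(model_path, "*.bin")))
    if bin_files:
        # Caller would delegate to the existing .bin loading path here.
        return "bin", bin_files
    # Raising (rather than returning an empty list) prevents a silently
    # uninitialized model when use_model_path_weight_loader is enabled.
    raise FileNotFoundError(
        f"no .safetensors or .bin weight files found under {model_path!r}"
    )
```

Either branch of the fix works; the key property is that an empty directory surfaces an error at load time instead of producing garbage generations later.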

In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 501-504: The cleanup path incorrectly uses a simple "cache is not
None" check so ChunkCache (which lacks page_size) can still run radix cleanup
and set did_insert=True, preventing KV frees; update the cleanup logic to mirror
the insertion guard by checking for radix-capable cache (e.g., "if cache is not
None and hasattr(cache, 'page_size')" or equivalent) before calling
_free_rid_resources or setting did_insert, and ensure _free_rid_resources only
treats did_insert as true when a real radix insertion occurred (adjust
_free_rid_resources and caller logic in model_runner_process.py to use the same
hasattr(cache, 'page_size') predicate to decide insertion/cleanup and never mark
did_insert for ChunkCache).
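The shared predicate this comment suggests can be sketched with stand-in cache classes (the stubs are hypothetical; the `page_size` check comes from the review text):

```python
def is_radix_capable(cache):
    """Treat only caches exposing radix semantics (page_size) as radix caches."""
    return cache is not None and hasattr(cache, "page_size")

class RadixCacheStub:
    page_size = 16
    def insert(self, rid):
        return True  # a real radix insertion occurred

class ChunkCacheStub:
    pass  # no page_size: must never take the radix path

def cleanup_rid(cache, rid):
    """Run radix insert/cleanup only for radix-capable caches, so
    did_insert is never set for chunk-style caches and KV frees proceed."""
    did_insert = False
    if is_radix_capable(cache):
        did_insert = cache.insert(rid)
    return did_insert
```

Using the same predicate on both the insertion and cleanup paths is the point: the two sites can no longer disagree about what counts as a radix cache.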

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 63329a33-6882-4fdc-bc8e-b9cf3b71fcd7

📥 Commits

Reviewing files that changed from the base of the PR and between 729ca4c and 2126cae.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py

Comment thread pymllm/models/gemma3n.py Outdated
Comment thread pymllm/models/gemma3n.py
Comment thread pymllm/orchestrator/model_runner_process.py
@Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from 2126cae to b400087 on May 1, 2026 at 09:03
@Grape203
Contributor Author

Grape203 commented on May 1, 2026

Updated the PR to address the CodeRabbit comments:

  • reset native decode cache based on extend/prefill boundary
  • raise FileNotFoundError when safetensors weights are missing
  • avoid materializing streaming weight iterators in load_weights
  • avoid treating ChunkCache as RadixCache during cleanup

Also re-ran the local pymllm-server verification with Gemma3n E2B text weights and /v1/completions still returns:
"The capital of France is Paris."

Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
pymllm/models/gemma3n.py (2)

1062-1068: ⚡ Quick win

Chain the exception for better debugging context.

The TypeError raised when weights is not iterable should be chained with the original exception using from err to preserve the exception context and aid debugging.

♻️ Proposed fix
         else:
             try:
                 weight_items = iter(weights)
-            except TypeError:
+            except TypeError as err:
                 raise TypeError(
                     f"weights must be a dict-like state_dict, a module with state_dict(), "
                     f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+                ) from err
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 1062 - 1068, The TypeError raised when
attempting to call iter(weights) loses the original exception context; update
the try/except around iter(weights) to capture the original exception (e.g.,
except TypeError as err) and re-raise the new TypeError using "from err" so the
stack trace is chained and debugging shows the original error; locate the
try/except that wraps iter(weights) in the weights handling logic and modify
that raise to include the exception chaining.

356-359: 💤 Low value

Minor style: getattr with constant attribute after hasattr check.

Line 359 uses getattr(forward_batch, "kv_shared_cache") immediately after checking hasattr(forward_batch, "kv_shared_cache") on line 358. This can be simplified to direct attribute access since the attribute's existence is already verified.

♻️ Suggested simplification
-        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
+            shared_kv_cache = forward_batch.kv_shared_cache
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 356 - 359, The code checks
hasattr(forward_batch, "kv_shared_cache") then uses getattr(forward_batch,
"kv_shared_cache"); replace the getattr call with direct attribute access
(forward_batch.kv_shared_cache) to simplify the style while preserving
behavior—update the block that assigns shared_kv_cache from forward_batch in the
method where forward_batch is used (the snippet using isinstance(forward_batch,
dict) / hasattr(forward_batch, "kv_shared_cache")).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-742: The delete of forward_batch at the top removes the
variable so locals().get("forward_batch", None) can never find it; restore
extend-mode detection by moving or removing the del forward_batch so that
forward_batch is still present when creating forward_batch_obj, or alternatively
capture forward_batch into a local variable (e.g., forward_batch_obj =
forward_batch or via locals().get before the del) before deleting; update the
logic around forward_batch_obj / is_extend_mode (used in the is_prefill
computation) to use that captured value so extend-mode detection
(is_extend_mode) works for 1-token prompts.

---

Nitpick comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1062-1068: The TypeError raised when attempting to call
iter(weights) loses the original exception context; update the try/except around
iter(weights) to capture the original exception (e.g., except TypeError as err)
and re-raise the new TypeError using "from err" so the stack trace is chained
and debugging shows the original error; locate the try/except that wraps
iter(weights) in the weights handling logic and modify that raise to include the
exception chaining.
- Around line 356-359: The code checks hasattr(forward_batch, "kv_shared_cache")
then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call
with direct attribute access (forward_batch.kv_shared_cache) to simplify the
style while preserving behavior—update the block that assigns shared_kv_cache
from forward_batch in the method where forward_batch is used (the snippet using
isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c588324c-685b-4a73-ae39-477690450de5

📥 Commits

Reviewing files that changed from the base of the PR and between 2126cae and b400087.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
  • pymllm/models/__init__.py

Comment thread pymllm/models/gemma3n.py
@Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from b400087 to fe694cb on May 1, 2026 at 09:27
@Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from fe694cb to 35dce02 on May 1, 2026 at 09:28
Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)

1061-1067: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Preserve exception context when re-raising TypeError.

At Line 1064, re-raising without from err drops the original failure context and makes triage harder.

💡 Proposed fix
-        else:
-            try:
-                weight_items = iter(weights)
-            except TypeError:
-                raise TypeError(
-                    f"weights must be a dict-like state_dict, a module with state_dict(), "
-                    f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+        else:
+            try:
+                weight_items = iter(weights)
+            except TypeError as err:
+                raise TypeError(
+                    f"weights must be a dict-like state_dict, a module with state_dict(), "
+                    f"or an iterable of (name, tensor), got {type(weights)}"
+                ) from err
#!/bin/bash
# Verify non-chained re-raises for this code path.
rg -n -C2 'except TypeError( as \w+)?:' pymllm/models/gemma3n.py
rg -n -C2 'raise TypeError\(' pymllm/models/gemma3n.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 1061 - 1067, The except TypeError
block that attempts to validate weights should preserve the original exception
context: capture the caught exception (e.g., "except TypeError as e:") and
re-raise the new TypeError with "from e" so the original traceback is chained;
update the except block around the weights iterator creation (the try that sets
weight_items = iter(weights)) to use "except TypeError as e" and then "raise
TypeError(... ) from e" referencing the same error message.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 792-797: is_prefill detection currently uses only
input_ids_hf.shape[1] and self._hf_past_key_values, which lets a one-token new
request be misclassified as decode and reuse stale _hf_past_key_values; change
the logic so prefill is also true when the request boundary indicates a new
prompt (e.g., compare a stored request id/turn counter or track whether a
prefill has been initialized for the current request), and reset/clear
self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.

---

Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1061-1067: The except TypeError block that attempts to validate
weights should preserve the original exception context: capture the caught
exception (e.g., "except TypeError as e:") and re-raise the new TypeError with
"from e" so the original traceback is chained; update the except block around
the weights iterator creation (the try that sets weight_items = iter(weights))
to use "except TypeError as e" and then "raise TypeError(... ) from e"
referencing the same error message.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 77bef586-be68-4847-a4a0-57e7bc0bdd10

📥 Commits

Reviewing files that changed from the base of the PR and between b400087 and 35dce02.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
  • pymllm/models/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • pymllm/orchestrator/model_runner_process.py
  • pymllm/executor/model_runner.py

Comment thread pymllm/models/gemma3n.py
Comment on lines +792 to +797
# For batch_size=1 text-only serving:
# - prefill has sequence length > 1, so reset HF cache;
# - decode has sequence length == 1, so reuse stored HF cache.
is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None
past_key_values = None if is_prefill else self._hf_past_key_values

Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

HF cache reset misses one-token prefill boundaries (cross-request KV reuse risk).

At Line 795, is_prefill is inferred only from token count (> 1) plus cache emptiness. A new request with a one-token prompt can be misclassified as decode and incorrectly reuse _hf_past_key_values from a previous request.

💡 Proposed fix
-        # For batch_size=1 text-only serving:
-        # - prefill has sequence length > 1, so reset HF cache;
-        # - decode has sequence length == 1, so reuse stored HF cache.
-        is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None
+        # Prefer scheduler mode for boundary detection so 1-token prompts are
+        # still treated as prefill for a new request.
+        is_extend_mode = False
+        if forward_batch_obj is not None:
+            forward_mode = getattr(forward_batch_obj, "forward_mode", None)
+            is_extend = getattr(forward_mode, "is_extend", None)
+            if callable(is_extend):
+                is_extend_mode = bool(is_extend())
+            elif isinstance(forward_batch_obj, dict):
+                is_extend_mode = forward_batch_obj.get("forward_mode") == "extend"
+
+        is_prefill = (
+            is_extend_mode
+            or input_ids_hf.shape[1] > 1
+            or self._hf_past_key_values is None
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 792 - 797, is_prefill detection
currently uses only input_ids_hf.shape[1] and self._hf_past_key_values, which
lets a one-token new request be misclassified as decode and reuse stale
_hf_past_key_values; change the logic so prefill is also true when the request
boundary indicates a new prompt (e.g., compare a stored request id/turn counter
or track whether a prefill has been initialized for the current request), and
reset/clear self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.
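The failure mode is easy to reproduce with the token-count rule alone. This sketch contrasts the original heuristic with a boundary-aware variant driven by a scheduler mode flag; the `"extend"` mode string and the standalone function shapes are illustrative assumptions, not the PR's actual code.

```python
def is_prefill_token_count_only(seq_len, past_kv):
    # Original rule: a 1-token new prompt arriving while cache is still
    # populated from a previous request is misread as a decode step.
    return seq_len > 1 or past_kv is None

def is_prefill_boundary_aware(seq_len, past_kv, forward_mode):
    # Boundary-aware rule: an "extend" (prefill) batch always starts a
    # fresh prefill, even when the new prompt is a single token, so the
    # caller resets the cached past_key_values instead of reusing them.
    return forward_mode == "extend" or seq_len > 1 or past_kv is None

stale_cache = object()  # KV cache left over from a previous request
# One-token new request while stale cache is present:
assert is_prefill_token_count_only(1, stale_cache) is False      # bug: reuses stale KV
assert is_prefill_boundary_aware(1, stale_cache, "extend") is True  # resets correctly
```

Plain decode steps are unaffected: with `forward_mode` of `"decode"`, a 1-token batch with existing cache still evaluates to False and reuses the cache as intended.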
