Add Gemma3n text-only native server path #672

Grape203 wants to merge 1 commit into UbiquitousLearning:main
Conversation
📝 Walkthrough

Adds a Gemma 3n text-model implementation and registry entries, modifies model instantiation to optionally perform CPU-first weight loading and prefer model-specific weight loaders, and restricts radix-cache-specific logic to radix-capable cache implementations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as ModelRunner
    participant Model as ModelClass
    participant Storage as CheckpointStorage
    participant Device as DeviceMgr
    Runner->>Model: determine instantiate_device_str (runtime device or "cpu")
    alt CPU-first required
        Runner->>Device: set target="cpu"
        Runner-->>Runner: log "CPU-first enabled"
    else Instantiate on runtime device
        Runner->>Device: set target=runtime_device
    end
    Runner->>Model: instantiate model on target
    Runner->>Model: evaluate use_model_path_weight_loader (bool/callable)
    alt model provides load_weights_from_model_path and resolver true
        Runner->>Storage: request model_path chunks
        Storage-->>Model: stream checkpoint
        Model->>Model: load_weights_from_model_path(model_path)
    else
        Runner->>Storage: iterate weight tensors
        Storage-->>Runner: weight iterator
        Runner->>Model: load_weights(weight_iterator)
    end
    Note right of Model: model initialized and weights loaded
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-727: The code infers prefill vs decode from sequence length
and cache emptiness which lets a fresh 1-token request reuse stale context;
update the logic in the native path (around input_ids_hf/position_ids_hf
handling and the is_prefill decision) to detect request boundaries and reset
cached state when a new decode begins: if forward_batch.forward_mode indicates
"decode" (or if position_ids_hf[...,0] == 0) then clear
self._native_cached_input_ids, self._native_cached_positions and related cached
tensors such as self._hf_past_key_values before computing is_prefill so a
single-token new request does not append to prior context (apply same change to
the other similar branch handling caches).
- Around line 1041-1046: The weights handling in ModelRunner.load_model()
materializes iterables with list(weights), which will drain generator-like
inputs such as self._iter_weights and cause OOM for large Gemma3n checkpoints;
change the logic to treat non-dict weights as an iterable and iterate over it
directly (e.g., use a for-loop over weights or assign weight_items = weights
when it's already an iterable) rather than calling list(weights); if you keep
the TypeError branch for non-iterables, re-raise or raise a new error using
"from err" to preserve exception chaining.
- Around line 939-946: The current Gemma3n streaming loader returns
self.load_weights([]) when no .safetensors are found, which silently yields an
uninitialized model; instead, change the behavior in the st_files check inside
the Gemma3n loader: either raise a clear exception (e.g., FileNotFoundError or
ValueError) with a message that no .safetensors were found for the given
model_path so callers (including ModelRunner.load_model when
use_model_path_weight_loader is enabled) fail fast, or implement a proper .bin
fallback by detecting "*.bin" files and delegating to the existing .bin loading
code (call the appropriate bin-weight loader method) rather than returning an
empty load_weights list; update references to load_weights and the loader
selection logic accordingly.
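The fail-fast behavior suggested in the comment above can be sketched as a small checkpoint-discovery helper. This is a minimal sketch under assumptions: the function name, the `(kind, files)` return shape, and the delegation comment are illustrative, not the Gemma3n loader's real interface.

```python
# Fail-fast sketch for the missing-checkpoint case described above, with the
# .bin fallback the comment suggests as an alternative. Names are illustrative.
import glob
import os


def find_checkpoint_files(model_path):
    """Return (kind, files) for a checkpoint dir, or fail fast if empty."""
    st_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if st_files:
        return "safetensors", st_files
    bin_files = sorted(glob.glob(os.path.join(model_path, "*.bin")))
    if bin_files:
        # A real loader would delegate these to its existing .bin loading path.
        return "bin", bin_files
    # Never return an empty load: surface the problem to the caller instead of
    # silently yielding an uninitialized model.
    raise FileNotFoundError(
        f"no *.safetensors or *.bin checkpoint files found under {model_path!r}"
    )
```

Either branch of the comment's suggestion works; the key property is that an empty directory raises instead of succeeding with zero weights loaded.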
In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 501-504: The cleanup path incorrectly uses a simple "cache is not
None" check so ChunkCache (which lacks page_size) can still run radix cleanup
and set did_insert=True, preventing KV frees; update the cleanup logic to mirror
the insertion guard by checking for radix-capable cache (e.g., "if cache is not
None and hasattr(cache, 'page_size')" or equivalent) before calling
_free_rid_resources or setting did_insert, and ensure _free_rid_resources only
treats did_insert as true when a real radix insertion occurred (adjust
_free_rid_resources and caller logic in model_runner_process.py to use the same
hasattr(cache, 'page_size') predicate to decide insertion/cleanup and never mark
did_insert for ChunkCache).
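The shared-predicate idea from the comment above can be sketched like this. `ChunkCache`, `RadixCache`, `insert_prefix`, and `cleanup_request` here are minimal stand-ins for illustration, not pymllm's actual classes; only the `hasattr(cache, "page_size")` capability test comes from the comment.

```python
# Sketch: gate radix-specific cleanup on cache capability so insertion and
# cleanup can never disagree. Stand-in classes, not pymllm's real ones.

class ChunkCache:
    pass  # no page_size attribute: not radix-capable


class RadixCache:
    page_size = 16

    def __init__(self):
        self.inserted = []

    def insert_prefix(self, rid):
        self.inserted.append(rid)


def is_radix_capable(cache):
    # Single predicate reused by both the insertion and cleanup paths.
    return cache is not None and hasattr(cache, "page_size")


def cleanup_request(cache, rid):
    did_insert = False
    if is_radix_capable(cache):
        cache.insert_prefix(rid)
        did_insert = True
    # With ChunkCache (or no cache), did_insert stays False, so the caller
    # knows it must free the KV resources for this request id.
    return did_insert


radix = RadixCache()
print(cleanup_request(radix, "req-1"))      # True
print(cleanup_request(ChunkCache(), "r2"))  # False
```

Centralizing the predicate is the point of the fix: a bare `cache is not None` check lets ChunkCache sneak into the radix path and mark `did_insert=True`, which then blocks KV frees.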
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 63329a33-6882-4fdc-bc8e-b9cf3b71fcd7
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
2126cae to b400087 (force-push)
Updated the PR to address the CodeRabbit comments:
Also re-ran the local pymllm-server verification with Gemma3n E2B text weights, and `/v1/completions` still returns:
Actionable comments posted: 1
🧹 Nitpick comments (2)
pymllm/models/gemma3n.py (2)
1062-1068: ⚡ Quick win — Chain the exception for better debugging context.

The `TypeError` raised when `weights` is not iterable should be chained with the original exception using `from err` to preserve the exception context and aid debugging.

♻️ Proposed fix

```diff
 else:
     try:
         weight_items = iter(weights)
-    except TypeError:
+    except TypeError as err:
         raise TypeError(
             f"weights must be a dict-like state_dict, a module with state_dict(), "
             f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+        ) from err
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 1062 - 1068, The TypeError raised when attempting to call iter(weights) loses the original exception context; update the try/except around iter(weights) to capture the original exception (e.g., except TypeError as err) and re-raise the new TypeError using "from err" so the stack trace is chained and debugging shows the original error; locate the try/except that wraps iter(weights) in the weights handling logic and modify that raise to include the exception chaining.
356-359: 💤 Low value — Minor style: `getattr` with a constant attribute after a `hasattr` check.

Line 359 uses `getattr(forward_batch, "kv_shared_cache")` immediately after checking `hasattr(forward_batch, "kv_shared_cache")` on line 358. This can be simplified to direct attribute access since the attribute's existence is already verified.

♻️ Suggested simplification

```diff
 elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-    shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+    shared_kv_cache = forward_batch.kv_shared_cache
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 356 - 359, The code checks hasattr(forward_batch, "kv_shared_cache") then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call with direct attribute access (forward_batch.kv_shared_cache) to simplify the style while preserving behavior—update the block that assigns shared_kv_cache from forward_batch in the method where forward_batch is used (the snippet using isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-742: The delete of forward_batch at the top removes the
variable so locals().get("forward_batch", None) can never find it; restore
extend-mode detection by moving or removing the del forward_batch so that
forward_batch is still present when creating forward_batch_obj, or alternatively
capture forward_batch into a local variable (e.g., forward_batch_obj =
forward_batch or via locals().get before the del) before deleting; update the
logic around forward_batch_obj / is_extend_mode (used in the is_prefill
computation) to use that captured value so extend-mode detection
(is_extend_mode) works for 1-token prompts.
---
Nitpick comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1062-1068: The TypeError raised when attempting to call
iter(weights) loses the original exception context; update the try/except around
iter(weights) to capture the original exception (e.g., except TypeError as err)
and re-raise the new TypeError using "from err" so the stack trace is chained
and debugging shows the original error; locate the try/except that wraps
iter(weights) in the weights handling logic and modify that raise to include the
exception chaining.
- Around line 356-359: The code checks hasattr(forward_batch, "kv_shared_cache")
then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call
with direct attribute access (forward_batch.kv_shared_cache) to simplify the
style while preserving behavior—update the block that assigns shared_kv_cache
from forward_batch in the method where forward_batch is used (the snippet using
isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
- pymllm/models/__init__.py
b400087 to fe694cb (force-push)
fe694cb to 35dce02 (force-push)
Actionable comments posted: 1
♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)
1061-1067: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Preserve exception context when re-raising `TypeError`.

At Line 1064, re-raising without `from err` drops the original failure context and makes triage harder.

💡 Proposed fix

```diff
-else:
-    try:
-        weight_items = iter(weights)
-    except TypeError:
-        raise TypeError(
-            f"weights must be a dict-like state_dict, a module with state_dict(), "
-            f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+else:
+    try:
+        weight_items = iter(weights)
+    except TypeError as err:
+        raise TypeError(
+            f"weights must be a dict-like state_dict, a module with state_dict(), "
+            f"or an iterable of (name, tensor), got {type(weights)}"
+        ) from err
```

```bash
#!/bin/bash
# Verify non-chained re-raises for this code path.
rg -n -C2 'except TypeError( as \w+)?:' pymllm/models/gemma3n.py
rg -n -C2 'raise TypeError\(' pymllm/models/gemma3n.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 1061 - 1067, The except TypeError block that attempts to validate weights should preserve the original exception context: capture the caught exception (e.g., "except TypeError as e:") and re-raise the new TypeError with "from e" so the original traceback is chained; update the except block around the weights iterator creation (the try that sets weight_items = iter(weights)) to use "except TypeError as e" and then "raise TypeError(... ) from e" referencing the same error message.
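The two weights-loading fixes raised in this review — streaming over the iterable instead of calling `list(weights)`, and chaining the `TypeError` with `from err` — combine into one small sketch. `load_weights_streaming`, `load_one_weight`, and `fake_weight_stream` are hypothetical names for illustration; only the dict/iterable handling and the chained re-raise reflect the review's suggestions.

```python
# Sketch: consume dicts or (name, tensor) iterables lazily, never list()-ing
# a generator (which would drain it and risk OOM on large checkpoints).

def load_weights_streaming(weights, load_one_weight):
    if isinstance(weights, dict):
        items = weights.items()
    else:
        try:
            items = iter(weights)  # works for lists and generators alike
        except TypeError as err:
            # Chain with `from err` so the original context survives triage.
            raise TypeError(
                f"weights must be a dict-like state_dict or an iterable of "
                f"(name, tensor), got {type(weights)}"
            ) from err
    count = 0
    for name, tensor in items:  # streams one pair at a time; no full copy
        load_one_weight(name, tensor)
        count += 1
    return count


def fake_weight_stream():
    for i in range(3):
        yield f"w{i}", i  # stand-in for real tensors


seen = []
n = load_weights_streaming(fake_weight_stream(), lambda name, t: seen.append(name))
print(n, seen)  # 3 ['w0', 'w1', 'w2']
```

Because the loop pulls one pair at a time, peak memory is bounded by a single tensor rather than the whole checkpoint, which is the point of avoiding `list(weights)`.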
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 792-797: is_prefill detection currently uses only
input_ids_hf.shape[1] and self._hf_past_key_values, which lets a one-token new
request be misclassified as decode and reuse stale _hf_past_key_values; change
the logic so prefill is also true when the request boundary indicates a new
prompt (e.g., compare a stored request id/turn counter or track whether a
prefill has been initialized for the current request), and reset/clear
self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.
---
Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1061-1067: The except TypeError block that attempts to validate
weights should preserve the original exception context: capture the caught
exception (e.g., "except TypeError as e:") and re-raise the new TypeError with
"from e" so the original traceback is chained; update the except block around
the weights iterator creation (the try that sets weight_items = iter(weights))
to use "except TypeError as e" and then "raise TypeError(... ) from e"
referencing the same error message.
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
- pymllm/models/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- pymllm/orchestrator/model_runner_process.py
- pymllm/executor/model_runner.py
```python
# For batch_size=1 text-only serving:
# - prefill has sequence length > 1, so reset HF cache;
# - decode has sequence length == 1, so reuse stored HF cache.
is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None
past_key_values = None if is_prefill else self._hf_past_key_values
```
HF cache reset misses one-token prefill boundaries (cross-request KV reuse risk).
At Line 795, is_prefill is inferred only from token count (> 1) plus cache emptiness. A new request with a one-token prompt can be misclassified as decode and incorrectly reuse _hf_past_key_values from a previous request.
💡 Proposed fix

```diff
-# For batch_size=1 text-only serving:
-# - prefill has sequence length > 1, so reset HF cache;
-# - decode has sequence length == 1, so reuse stored HF cache.
-is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None
+# Prefer scheduler mode for boundary detection so 1-token prompts are
+# still treated as prefill for a new request.
+is_extend_mode = False
+if forward_batch_obj is not None:
+    forward_mode = getattr(forward_batch_obj, "forward_mode", None)
+    is_extend = getattr(forward_mode, "is_extend", None)
+    if callable(is_extend):
+        is_extend_mode = bool(is_extend())
+    elif isinstance(forward_batch_obj, dict):
+        is_extend_mode = forward_batch_obj.get("forward_mode") == "extend"
+
+is_prefill = (
+    is_extend_mode
+    or input_ids_hf.shape[1] > 1
+    or self._hf_past_key_values is None
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@pymllm/models/gemma3n.py` around lines 792 - 797, is_prefill detection
currently uses only input_ids_hf.shape[1] and self._hf_past_key_values, which
lets a one-token new request be misclassified as decode and reuse stale
_hf_past_key_values; change the logic so prefill is also true when the request
boundary indicates a new prompt (e.g., compare a stored request id/turn counter
or track whether a prefill has been initialized for the current request), and
reset/clear self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.
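The request-boundary reset this comment asks for can be reduced to a toy state machine. `NativeCacheState`, `decide_prefill`, and the `"extend"` mode string are illustrative assumptions, not pymllm's actual types; the behavior shown is only the review's core claim that a fresh one-token request must clear the stale HF cache.

```python
# Toy model of the per-request boundary reset suggested above.
# Attribute and parameter names are assumptions for illustration only.

class NativeCacheState:
    def __init__(self):
        self._hf_past_key_values = None
        self._native_cached_input_ids = None

    def decide_prefill(self, seq_len, first_position, forward_mode):
        # A request boundary is signaled either by the scheduler ("extend")
        # or by position ids restarting at 0; clear stale HF cache there so
        # a 1-token new prompt cannot reuse a previous request's KV state.
        new_request = forward_mode == "extend" or first_position == 0
        if new_request:
            self._hf_past_key_values = None
            self._native_cached_input_ids = None
        return seq_len > 1 or self._hf_past_key_values is None
```

Under this sketch, a one-token prompt at position 0 is classified as prefill even though its token count alone would suggest decode, which is exactly the misclassification the review flags.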
Summary
This PR adds an initial Gemma3n text-only native path to pymllm.
The implementation focuses on the simplest text-only LLM path first, following the staged direction discussed with the maintainer. It supports loading Gemma3n E2B text weights, native text forward/generation, and basic `pymllm-server` decode with RadixCache disabled.

Main changes

- Added `pymllm.models.gemma3n` with a text-only native Gemma3n implementation, registered in `pymllm.models`.
- Added a `disable_radix_cache` path so ChunkCache does not enter RadixCache-specific insertion logic.

Verification
Tested locally with Gemma3n E2B text weights.
- Weight loading: loaded=732, skipped=825, missing_in_ckpt=0
- `pymllm-server` works with RadixCache disabled
- `/v1/completions` multi-token decode returns: