
feat(pymllm): support Qwen3 Jetson BF16, W4A16, and W8A8 serving#670

Merged
chenghuaWang merged 35 commits into UbiquitousLearning:main from jialilve:feature/jetson-qwen3-family-bf16-w4a16-w8a8
Apr 30, 2026

Conversation

@jialilve
Contributor

@jialilve jialilve commented Apr 27, 2026

Summary

This PR adds Jetson-oriented pymllm support for the Qwen3 family, covering BF16 serving, W4A16/AWQ compressed-tensors serving, and W8A8 INT8 compressed-tensors serving. It also adds a bench_one_batch entrypoint for model-level prefill/decode benchmarking in pymllm.

The W8A8 path uses:

fp16/bf16 activation
-> Triton per-token INT8 activation quantization
-> CUTLASS int8_scaled_mm with fused per-row/per-column scales
-> fp16/bf16 output
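The quantization step in this pipeline can be sketched as a per-token symmetric INT8 quantizer. This is a NumPy reference for illustration only; the PR's actual kernel is written in Triton and runs on-GPU:

```python
import numpy as np

def per_token_quant_int8(x: np.ndarray):
    """Reference per-token symmetric INT8 quantization.

    For each row (token): scale = max(|x|) / 127, q = round(x / scale).
    The Triton kernel in this PR computes the same mapping on-GPU; this
    NumPy version is only an illustrative sketch of the math.
    """
    scales = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scales = np.maximum(scales, 1e-8)  # guard all-zero rows against div-by-zero
    q = np.clip(np.rint(x / scales), -128, 127).astype(np.int8)
    return q, scales.astype(np.float32)

x = np.random.randn(4, 16).astype(np.float32)
q, s = per_token_quant_int8(x)
# dequantized values land within one scale step of the original
assert np.max(np.abs(q.astype(np.float32) * s - x)) <= s.max()
```

The per-row scales produced here are what the CUTLASS epilogue later fuses back in, together with the per-column weight scales.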

What Changed

  • Add compressed-tensors support for Qwen3-VL quantized models:
    • W4A16 / AWQ path via GPTQ Marlin
    • W8A8 int-quantized path via Triton activation quantization + CUTLASS INT8 GEMM
  • Add Jetson-compatible INT8 kernel support in mllm-kernel:
    • CUTLASS int8_scaled_mm implementation for Ampere / SM8x GPUs (SM80-SM89),
      validated on Jetson Orin SM87
  • Add Qwen3 family model support in pymllm:
    • Qwen3ForCausalLM
    • model registry entry
    • weight-loading and timing tests
  • Improve Qwen3-VL serving flow on Jetson:
    • multimodal request handling updates
    • server timing fields
    • tokenizer / scheduler / detokenizer flow updates
  • Add a bench_one_batch benchmark entrypoint in pymllm:
    • direct ModelRunner execution without HTTP server, scheduler, tokenizer, or detokenizer
    • synthetic text-only token IDs for controlled prefill/decode measurements
    • JSONL result output for benchmark sweeps
    • optional torch profiler hooks, not used in the initial benchmark runs
  • Fix a KV slot leak when ChunkCache is used with radix cache disabled.
  • Add bilingual pymllm README docs for Jetson environment, launch commands, and W8A8 development notes.
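The bench_one_batch flow described above can be sketched as a simple prefill-then-decode timing loop. The `model_step` callable here is a stand-in, not the real ModelRunner API; the actual entrypoint also brackets each timed region with `torch.cuda.synchronize()` so wall-clock deltas cover GPU work:

```python
import time

def bench_one_batch(model_step, prompt_len: int, decode_steps: int) -> dict:
    """Sketch of a model-level prefill/decode measurement loop.

    `model_step(n)` stands in for one forward pass over n tokens; the real
    pymllm.bench_one_batch drives ModelRunner directly (no HTTP server,
    scheduler, tokenizer, or detokenizer) and synchronizes CUDA around
    each timed region.
    """
    t0 = time.perf_counter()
    model_step(prompt_len)                      # prefill over the full prompt
    prefill_ms = (time.perf_counter() - t0) * 1000.0

    decode_ms = []
    for _ in range(decode_steps):
        t0 = time.perf_counter()
        model_step(1)                           # one decode token per step
        decode_ms.append((time.perf_counter() - t0) * 1000.0)

    return {
        "prefill_ms": prefill_ms,
        "decode_ms_per_token": sum(decode_ms) / len(decode_ms),
    }

# toy stand-in workload; real runs use synthetic text-only token IDs
result = bench_one_batch(lambda n: sum(range(n * 1000)),
                         prompt_len=128, decode_steps=8)
```

Each such result row would then be appended to the JSONL output for benchmark sweeps.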

Key Implementation Notes

  • W8A8 weights are stored as (K, N) column-major tensors after loading so the CUTLASS kernel can consume them without per-call copies.
  • CUTLASS is JIT-compiled on first use for the current GPU architecture via
    torch.cuda.get_device_capability() and cached per architecture under
    ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/sm_XX/, for example sm_87
    on Jetson Orin.
  • The validated environment uses the official PyPI triton==3.6.0 aarch64 wheel.
  • pymllm.bench_one_batch measures model-level latency directly around prefill and per-token decode with CUDA synchronization.
  • CUDA graph replay applies to decode batches when enabled; prefill is still executed through the normal extend path.
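The column-major storage requirement in the first note can be illustrated with a small NumPy sketch: torch's `stride(0) == 1` condition corresponds to NumPy's Fortran order, and the actual repack happens once in torch inside `process_weights_after_loading`:

```python
import numpy as np

# CUTLASS consumes B as a (K, N) column-major tensor: elements within a
# column are contiguous, i.e. the stride along K is one element.
# Doing this repack once at load time (the torch equivalent is
# w.t().contiguous().t()) removes the per-forward-call copy.
w_row_major = np.arange(12, dtype=np.int8).reshape(3, 4)   # (K=3, N=4), row-major
w_col_major = np.asfortranarray(w_row_major)               # same values, column-major

assert w_col_major.strides[0] == w_col_major.itemsize      # "stride(0) == 1"
assert np.array_equal(w_col_major, w_row_major)            # layout change only
```

The values are untouched; only the memory layout changes, which is why the kernel can consume the stored tensor directly.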

Notes for Reviewers

  • The W8A8 CUTLASS path currently supports Ampere / SM8x GPUs only
    (SM80-SM89). The validated target is Jetson Orin SM87. Hopper / SM90 is not
    included in this PR.
  • The first W8A8 forward on a machine without the matching per-architecture
    CUTLASS JIT cache may spend extra time compiling the extension. Later runs on
    the same GPU architecture reuse the cache.
  • Qwen3-VL ViT, lm_head, embeddings, and LayerNorm are intentionally outside the current W8A8 quantized scope.
  • bench_one_batch results are model-level synthetic text-only measurements and should not be mixed with HTTP serving metrics such as TTFT, ITL, TPOT, or E2E latency.
  • Qwen3-VL benchmark runs here do not include real image prompts or ViT forward cost.
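The per-architecture JIT cache mentioned above can be sketched as follows. `cutlass_cache_dir` is a hypothetical helper name for illustration; the real loader derives the SM version from `torch.cuda.get_device_capability()`:

```python
import os

def cutlass_cache_dir(major: int, minor: int) -> str:
    """Illustrative sketch of the per-architecture JIT cache layout.

    The real code obtains (major, minor) from
    torch.cuda.get_device_capability() at first use, compiles the CUTLASS
    extension for that architecture, and reuses the cached build on
    subsequent runs on the same GPU architecture.
    """
    return os.path.expanduser(
        f"~/.cache/mllm_kernel/cutlass_int8_scaled_mm/sm_{major}{minor}"
    )

assert cutlass_cache_dir(8, 7).endswith("sm_87")  # Jetson Orin
```

Keying the cache on the SM version is what makes the first W8A8 forward slow on a fresh machine while later runs skip compilation.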

Summary by CodeRabbit

  • New Features

    • Qwen3 & Qwen3‑VL models with weight‑loading, fused projections, and forward timing capture
    • Compressed‑tensors quantization (W4A16 Marlin and W8A8 per‑token INT8 → CUTLASS) with repack/reshape utilities
    • CUTLASS INT8 scaled GEMM and Triton per‑token INT8 quant primitives; MergedLinear fused‑output layer
    • bench/benchmark scripts and a one‑batch benchmark runner
  • Bug Fixes

    • RMSNorm CPU fallback and safer multimodal tokenization/prompt handling
    • Improved timing propagation through scheduler/orchestrator
  • Tests

    • Extensive new unit and GPU tests covering kernels, repacking, quant methods, and Qwen3
  • Documentation

    • New comprehensive docs and Chinese README
  • Chores

    • JIT/build/bootstrap and config handling improvements; server debug_timing flag added

nuozhihan and others added 28 commits April 27, 2026 09:24
…nchmarks

Phase 0.1: baseline benchmark scripts for GEMM and activation quant
Phase 1.1: port sglang Triton per_token_quant_int8 to pymllm/quantization/kernels/

Triton kernel correctness: ±1 LSB rounding difference vs torch (0.01% of elements)
Triton kernel performance: 25-67% faster than the torch path on Jetson SM87

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Phase 1.2:
- Replace _per_token_quant_int8 torch impl with Triton kernel import
- Remove old mllm JIT kernel fallback from _int8_scaled_mm
- GEMM now uses torch._int_mm directly (intermediate state before CUTLASS)
- Update test to verify Triton quant + torch._int_mm path

All 26 tests pass (config + runtime + kernel).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Phase 2: Port sglang CUTLASS int8 GEMM to mllm-kernel.
- SM89 tile shapes (100K shared memory safe for Jetson Orin SM87)
- Per-row/col scale epilogue fused into GEMM
- JIT compiled via torch.utils.cpp_extension (~100s first run, cached after)
- Integrated into compressed_tensors.py W8A8 forward path

Performance on SM87 (93,2048,6144):
  CUTLASS: 0.295 ms (4.2x vs torch._int_mm, 67.8x vs old JIT)

cutlass_extensions ported from sglang sgl-kernel (Apache 2.0)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eprecated

28 test cases: 7 shapes x 2 dtypes x 2 bias configs, all pass.
Old naive JIT kernel marked DEPRECATED (kept for reference).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Spike files verify CUTLASS compilation and tile configs on SM87.
Baseline doc updated with Phase 1 (Triton) and Phase 2 (CUTLASS) results.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SM87 (Jetson Orin) has 164K shared memory per SM (same as SM80),
not 100K as initially assumed. However, benchmarking shows SM89's
3-stage tiles outperform SM80's 5-stage tiles at large M when the
per-row/col scale epilogue visitor is used, due to lower smem
pressure and better occupancy.

SM80 dispatch kept in source for future use on devices that
benefit from larger tiles.

Verified: 28 CUTLASS tests pass, performance unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Spike files kept on disk for reference but excluded from version control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
process_weights_after_loading now stores weight as (K, N) column-major
(stride(0)==1) instead of row-major. This eliminates a full weight
matrix copy (~12MB per linear layer) on every forward call.

Root cause: CUTLASS requires column-major B (stride(0)==1), but
weights were stored row-major, triggering .t().contiguous().t()
on every call — ~2.3GB of copies per decode step for Qwen3-VL-2B.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Cache Triton quant and CUTLASS mm function refs to avoid repeated
  import lookups (140 calls/decode step)
- Remove redundant .contiguous(), .reshape(-1), .to(float32) in
  CUTLASS wrapper — scales are already in correct format from
  Triton quant and process_weights_after_loading
- Only do scales_a.squeeze(-1) to convert (M,1) -> (M,)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Compares GEMM kernel performance at representative Qwen3-VL-2B shapes.
W8A8 columns show activation quant, GEMM, and total separately.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SM87 (Jetson Orin) has 164KB smem, same as SM80 — not 100KB like SM86/SM89.
Both sglang and vllm route SM87 to SM80 dispatch. E2E benchmark confirms
SM80 ≈ SM89 tiles on SM87 (<2% diff). Reverts the SM89 override from
5e6c634 and uses SM80 dispatch with deeper pipeline stages (5-6 stage).

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
When disable_radix_cache=True, ChunkCache is a no-op cache and should not be treated as cache-enabled.

Previously cache_enabled only checked cache is not None, which made the insert path report did_insert=True and skip Phase 4 free logic.

This change excludes ChunkCache from cache_enabled so KV slots are released correctly.
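The fix described in this commit amounts to a stricter enabled-check. The class and function names below are illustrative, not the exact pymllm symbols:

```python
class ChunkCache:
    """No-op cache used when radix caching is disabled (illustrative)."""

class RadixCache:
    """Real prefix cache that takes ownership of inserted KV slots (illustrative)."""

def cache_enabled(cache) -> bool:
    # Before the fix this was just `cache is not None`, so a no-op
    # ChunkCache counted as enabled: the insert path reported
    # did_insert=True and the KV-slot free phase was skipped, leaking slots.
    return cache is not None and not isinstance(cache, ChunkCache)

assert not cache_enabled(None)
assert not cache_enabled(ChunkCache())
assert cache_enabled(RadixCache())
```

With ChunkCache excluded, requests fall through to the normal free path and KV slots are released correctly.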
@coderabbitai
Contributor

coderabbitai Bot commented Apr 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds CUTLASS-backed INT8 scaled GEMM with per-row/column scaling, Triton per-token activation quant, Marlin GPTQ repack, a compressed-tensors quant pipeline (W4A16/W8A8), Qwen3/Qwen3-VL model implementations and fixes, many CUDA/JIT kernels and tests, benchmarks, runtime timing propagation, and extensive docs and tooling.

Changes

Cohort / File(s) Summary
CUTLASS epilogue & GEMM infra
mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h, mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h, mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
Adds epilogue visitor applying per-row/per-column scaling and bias, plus new GEMM device-layer compatibility/dispatch helpers (grid/workspace/launch/occupancy) for mixed-dtype GEMMs.
CUTLASS INT8 kernel & JIT wrapper
mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu, mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py, mllm-kernel/benchmarks/bench_int8_scaled_mm.py, mllm-kernel/tests/test_int8_scaled_mm_cutlass.py
Implements int8_scaled_mm CUDA kernel with SM dispatch, exposes via JIT loader, adds Python benchmark and correctness/build tests.
Marlin GPTQ repack & kernel JIT
mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py, mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py, mllm-kernel/mllm_kernel/cuda/jit/__init__.py, mllm-kernel/tests/test_gptq_marlin.py, mllm-kernel/tests/test_gptq_marlin_repack.py
Adds Marlin repack JIT, perm normalization/validation, JIT wiring tweaks, and GPU tests for repacking and Marlin GEMM.
Marlin header / namespace / build tweaks
mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh, mllm-kernel/include/mllm_kernel/scalar_type.hpp, mllm-kernel/cmake/CPM.cmake
Includes canonical namespace alias, removes local host alias in marlin header, and adjusts CPM bootstrap to prefer vendored copy and validate CPM availability.
Marlin GEMM kernel source
mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu
New CUTLASS-based INT8 GEMM implementation wiring per-row/col scales, bias, dtype handling, SM-style tile dispatch and PYBIND11 export.
Triton activation quantization
pymllm/quantization/kernels/int8_activation_triton.py, pymllm/tests/bench_w8a8_activation_quant.py
Adds Triton per-token int8 quant kernel and Python wrapper, plus a benchmark comparing Torch vs Triton quantization.
Compressed-tensors quantization method
pymllm/quantization/methods/compressed_tensors.py, pymllm/quantization/methods/__init__.py, pymllm/tests/test_compressed_tensors_config.py, pymllm/tests/test_compressed_tensors_runtime.py
New compressed-tensors implementation supporting W4A16 (Marlin) and W8A8 (Triton+CUTLASS), config parsing, weight repacking, lazy kernel loading, forward apply wiring, and tests.
GPT/Qwen3 model additions & VL updates
pymllm/models/qwen3.py, pymllm/models/qwen3_vl.py, pymllm/models/__init__.py, multiple tests (tests/test_qwen3_*, tests/test_qwen3_vl_*)
Adds Qwen3 text model, Qwen3-VL changes (RoPE dtype/device, vision interp/cu-seqlen, residual/deepstack handling), weight-loading helpers, registry entry, and related tests.
RMSNorm fallback & tests
pymllm/layers/rms_norm.py, pymllm/tests/test_rms_norm.py
Adds pure-Torch RMSNorm fallback and guards around FlashInfer fused calls while preserving residual-update semantics; includes tests.
MergedLinear & MLP changes
pymllm/layers/linear.py, pymllm/layers/mlp.py, pymllm/layers/__init__.py, pymllm/tests/test_linear_merged.py
Introduces MergedLinear (fused output partitions + weight_loader), switches fused gate/up to use MergedLinear, exports symbol, and adds tests for merged loading behavior.
Benchmarks: W4A16 vs W8A8 & helpers
mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py, mllm-kernel/benchmarks/bench_int8_scaled_mm.py, pymllm/tests/bench_w8a8_activation_quant.py
Benchmark scripts isolating Triton quant, CUTLASS GEMM, and Marlin GEMM latencies; standalone runners and microbench tests.
Orchestration timing, tokenizer & detokenizer
pymllm/orchestrator/scheduler_process.py, pymllm/orchestrator/model_runner_process.py, pymllm/orchestrator/detokenizer_process.py, pymllm/orchestrator/tokenizer_process.py
Adds per-request prefill/decode timing fields, wires timing through scheduler/runner/detokenizer, ensures tokenizer uses multimodal processor input_ids when available.
Server & runner updates
pymllm/server/launch.py, pymllm/executor/model_runner.py, pymllm/configs/server_config.py
Preserves multimodal message content for prompt rendering, adds enable_debug_timing flag and optional debug timing in responses, and unwraps nested quantization_config from config.json.
Bench/utility additions
pymllm/bench_one_batch.py, pymllm/tests/test_bench_one_batch.py
New single-batch benchmark runner with CLI, profiling/hooks, deterministic inputs, latency/throughput reporting, and tests.
Extensive tests & misc
mllm-kernel/..., pymllm/tests/..., pymllm/quantization/kernels/__init__.py
Many new GPU and unit tests covering repack, GEMM, quant, runtime timing, model behavior; small package init and test harness additions.
Docs / README additions
pymllm/README.md, pymllm/README-ZH.md, docs/pymllm_runtime/*
Adds comprehensive English/Chinese READMEs and runtime developer/user docs describing quant paths, env, usage, and troubleshooting.

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant Tokenizer as Tokenizer/Processor
    participant Triton as Triton Quant Kernel
    participant CUTLASS as CUTLASS int8_scaled_mm
    participant Marlin as Marlin gptq_marlin_gemm
    participant Detok as Detokenizer/Output

    App->>Tokenizer: prepare activations (fp16)
    Tokenizer->>Triton: per-token quantization -> (int8 activations, scales)
    Triton->>CUTLASS: send int8 activations + scales
    CUTLASS->>CUTLASS: INT8 GEMM + epilogue (per-row/col scaling, bias) -> logits
    alt W4A16 path (Marlin)
        App->>Marlin: prepacked int4 weights + activations
        Marlin->>Marlin: run gptq_marlin_gemm -> logits
    end
    CUTLASS->>Detok: logits
    Marlin->>Detok: logits
    Detok->>App: return final responses (include timing metadata)

Estimated Code Review Effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested Reviewers

  • oreomaker
  • liang1232018
  • chenghuaWang
  • yirongjie
  • UbiquitousLearning

Poem

🐰 I hopped through headers, kernels and glue,
bytes got packed, quant scales found their cue,
Triton trimmed tokens, CUTLASS did the rest,
Marlin marched weights into compacted best,
I nibble carrots and deliver this test.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: coverage is 21.05%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check ✅: the PR title clearly summarizes the main change, adding Qwen3 Jetson support with three quantization paths (BF16, W4A16, W8A8).
  • Description check ✅: the PR description is comprehensive, covering summary, detailed changes, key implementation notes, and reviewer considerations. It exceeds the template requirements significantly.
  • Linked Issues check ✅: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅: skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 12

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
pymllm/executor/model_runner.py (1)

486-505: ⚠️ Potential issue | 🟠 Major

Falls back to returning the full HF config when config.json lacks quantization_config.

When config.json is one of the unique candidate filenames (it is — compressed-tensors lists it via get_config_filenames()), the new branch only short-circuits the quantization_config unwrap when that key exists. If config.json is present but does not contain a quantization_config field (e.g. an unquantized checkpoint, or a checkpoint where the quant metadata only lives in quantize_config.json later in the list), Line 495 returns the full top-level HF config (architectures, hidden_size, …) and _resolve_quant_config will then read a bogus quant_method (likely None, but in general any string the upstream HF config contains), instead of either continuing the search for quantize_config.json or returning {}.

This also makes the post-loop fallback at lines 498–505 effectively dead whenever config.json is in unique.

🐛 Proposed fix
         for fname in unique:
             fpath = model_path / fname
             if fpath.exists():
                 with open(fpath) as fp:
                     cfg = json.load(fp)
-                # config.json stores model metadata at the top level and
-                # nests quantization details under quantization_config.
-                if fname == "config.json" and "quantization_config" in cfg:
-                    return cfg["quantization_config"]
-                return cfg
+                # config.json stores model metadata at the top level and
+                # nests quantization details under quantization_config.
+                # When the nested key is missing, fall through to the next
+                # candidate filename (e.g. quantize_config.json) instead of
+                # returning the full HF model config.
+                if fname == "config.json":
+                    if "quantization_config" in cfg:
+                        return cfg["quantization_config"]
+                    continue
+                return cfg

After this fix you can also drop the duplicated post-loop fallback (lines 497–505), since the loop already handles config.json correctly.

As per coding guidelines: "Validate inputs for public APIs and critical internal functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/executor/model_runner.py` around lines 486 - 505, In
_resolve_quant_config, don't return the full HF config when encountering
"config.json" that lacks "quantization_config"; instead, if fname ==
"config.json" and "quantization_config" in cfg then return
cfg["quantization_config"], otherwise continue the loop to allow later
candidates (e.g., "quantize_config.json") to be checked; remove the duplicate
post-loop fallback that re-reads config.json since the loop now handles it, and
keep the final return {} if nothing is found.
pymllm/server/launch.py (1)

707-768: ⚠️ Potential issue | 🟠 Major

r may be unbound when results is empty in /v1/completions.

r is the loop variable from for i, r in enumerate(results): at line 717. The new timing block at lines 743-767 references r after the loop. If _iter_with_disconnect_check yields nothing (e.g., client disconnect, engine error), results is empty, the loop never executes, and r is undefined — r.get(...) raises NameError, which then falls into the generic except Exception handler at line 777 and is reported to the client as a 500 with no actionable detail. Compare with /v1/chat/completions which initializes r = {} on line 965.

🐛 Suggested fix
     try:
         results = []
         async for item in _iter_with_disconnect_check(gen, request):
             results.append(item)
+        if not results:
+            raise HTTPException(status_code=500, detail="No output from engine")
         choices = []
         prompt_tokens = 0
         completion_tokens = 0
-        for i, r in enumerate(results):
+        r: Dict[str, Any] = {}
+        for i, r in enumerate(results):
             choices.append(

Alternatively, mirror the chat-completions path and bind a single r = results[-1] if results else {} before constructing the response.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/server/launch.py` around lines 707 - 768, The timing block uses the
loop variable r after the for-loop in the completion handler (the
results/choices assembly around engine.generate_async), which can be undefined
when results is empty; fix by binding r = results[-1] if results else {} (or an
empty dict) immediately after the for i, r in enumerate(results): loop and
before constructing the ORJSONResponse so all r.get(...) calls are safe (refer
to the loop producing choices and the timing dictionary construction).
pymllm/models/qwen3_vl.py (2)

1172-1226: ⚠️ Potential issue | 🟡 Minor

Wall-clock timings here don't include async GPU work.

time.perf_counter() around CUDA ops only captures kernel launch time unless preceded/followed by a torch.cuda.synchronize(). The _vit_t0/_llm_t0 deltas reported here can wildly under‑report real model latency, especially in extend mode where the work being measured is large.

For server-side coarse latency this may be acceptable (and pymllm/README-ZH.md does call this out), but the same values get bubbled into OpenAI-compatible responses via the orchestrator timing field, so users may interpret them as actual model time.

If accuracy matters, gate a sync via a config flag, e.g.:

_vit_t0 = time.perf_counter()
vision_features = self.visual(pixel_values, grid_thw=image_grid_thw)
if torch.cuda.is_available():
    torch.cuda.synchronize()
vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0

Same applies to the LLM forward timing block at Lines 1206–1216.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3_vl.py` around lines 1172 - 1226, The timing measurements
around vision and LLM forwards (_vit_t0/_llm_t0) only measure kernel launch time
and under-report real GPU latency; modify the blocks that call self.visual(...)
and self.model(...) to optionally synchronize the device before stopping the
timer when CUDA is available (e.g., check torch.cuda.is_available() and call
torch.cuda.synchronize() if a new config flag like self.sync_cuda_timings is
true), and propagate this gating to both the vision timing (vit_prefill_ms,
vit_prefill_tokens) and LLM timing (llm_prefill_ms/llm_decode_ms) so timings
reported from functions like self.visual and self.model reflect real GPU work
when configured.

1186-1192: ⚠️ Potential issue | 🟡 Minor

Include video tokens in the image token mask to match qwen3_5.py behavior and ensure accurate token counting.

The image_mask at line 1188 uses only self.image_token_id, but self.video_token_id is configured at line 1065 and should be included. The sibling model qwen3_5.py at line 517 correctly uses mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id). Without video tokens in the mask, the vit_prefill_tokens count becomes incomplete, and deepstack embeddings at lines 1194–1201 also skip video tokens.

♻️ Suggested fix
-            image_mask = input_ids == self.image_token_id
+            image_mask = (input_ids == self.image_token_id) | (
+                input_ids == self.video_token_id
+            )
             if image_mask.any():
                 vit_prefill_tokens = int(image_mask.sum().item())
                 input_embeds[image_mask] = vision_embeds.to(input_embeds.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3_vl.py` around lines 1186 - 1192, The image mask currently
only checks for self.image_token_id in the block that computes input_embeds via
self.model.embed_tokens; update the mask logic used in image_mask (and any
subsequent uses of vit_prefill_tokens) to include self.video_token_id as well
(i.e., mask = (input_ids == self.image_token_id) | (input_ids ==
self.video_token_id)) so that input_embeds[image_mask] = vision_embeds.to(...)
and vit_prefill_tokens = int(image_mask.sum().item()) correctly account for
video tokens; apply the same mask change wherever vit_prefill_tokens or
deepstack embedding replacements are computed so video tokens are handled
identically to image tokens.
pymllm/orchestrator/scheduler_process.py (1)

792-837: ⚠️ Potential issue | 🟡 Minor

Conflicting llm_decode_ms updates: per-batch accumulation is overwritten by cumulative recompute.

At Line 799 you accumulate the model-runner-reported decode time per batch:

req.llm_decode_ms = (req.llm_decode_ms or 0.0) + out["llm_decode_ms"]

but immediately afterwards (lines 833–837) for any decode batch you overwrite it with (now - decode_start_tic) * 1000.0, which is a wall-clock cumulative timer including scheduler overhead, queueing, etc.

The result is:

  • The accumulation at Line 799 is dead work for decode mode.
  • The reported llm_decode_ms is wall-clock cumulative since the first decode step, not the sum of per-step model forward time, which conflicts with how llm_prefill_ms/vit_prefill_ms are reported (single forward duration).

Pick one source of truth. If you want cumulative wall-clock, drop the per-batch accumulation. If you want sum of forward times, drop the cumulative recompute and keep the accumulator.

♻️ Suggested fix (keep model-reported per-step accumulation only)
-        if batch.forward_mode.is_decode():
-            _decode_now = time.perf_counter()
-            for req in batch.reqs:
-                if req.decode_start_tic is not None:
-                    req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0
-
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/orchestrator/scheduler_process.py` around lines 792 - 837, The code
currently both accumulates model-reported per-step decode time into
req.llm_decode_ms (using out["llm_decode_ms"]) and later overwrites it for
decode batches with a wall-clock recompute in the batch.forward_mode.is_decode()
loop; remove the latter overwrite so the per-step model-reported accumulation is
the single source of truth: delete or skip the assignment req.llm_decode_ms =
(_decode_now - req.decode_start_tic) * 1000.0 inside the
batch.forward_mode.is_decode() block (references: req.llm_decode_ms,
out["llm_decode_ms"], batch.forward_mode.is_decode(), req.decode_start_tic) so
llm_decode_ms remains the sum of model-runner-reported decode times.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py`:
- Around line 6-9: The docstring in bench_w4a16_vs_w8a8.py contains a hardcoded
developer worktree path; replace that line with a generic, portable instruction
such as "cd to the repository root" or "run from the repository root" and/or
show a relative invocation like "python3
mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py" so users know to run the script
from the repo root instead of the specific /workspace/.worktrees/... path.

In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h`:
- Around line 113-183: The constructor EpilogueVisitorPerRowPerCol has its
member-initializer list out of declaration order causing -Wreorder; reorder the
initializer list to match the class member declaration order (Params const&
params_, SharedStorage& shared_storage_, MatrixCoord extent_, MatrixCoord
extent_real_, ElementwiseFunctor elementwise_, bool with_bias_, bool
per_token_quant_, bool per_channel_quant_, AlphaScaleElementType*
ptr_alpha_row_, AlphaScaleElementType* ptr_alpha_col_, ScaleTileIterator
iterator_alpha_col_, OutputTileIterator iterator_C_, OutputTileIterator
iterator_D_, ...), i.e., move extent_real_ to be initialized immediately after
extent_ and ensure elementwise_ follows extent_real_ (and the rest follow their
declared order) while keeping the existing initialization expressions and
runtime logic unchanged.

In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`:
- Around line 184-202: The Params() constructor in gemm_with_epilogue_visitor.h
initializes params_alpha_col but omits params_alpha_row, leaving
params_alpha_row potentially uninitialized; update the Params() member
initializer list for the Params() constructor to include params_alpha_row(0)
(mirroring params_alpha_col(0)) so ScaleTileIterator::Params for the row scale
is set to the zero/empty state consistent with ptr_alpha_row being nullptr.
- Around line 204-224: The initializer list for Params mistakenly constructs
params_alpha_row from the column layout
(params_alpha_row(args.ref_alpha_col.layout())); change it to use the actual row
reference layout by initializing params_alpha_row with
args.ref_alpha_row.layout() in the Params constructor so the alpha_row scale
iterator gets correct ColumnMajor strides (update the initializer near
params_alpha_col/params_alpha_row and verify ptr_alpha_row uses
args.ref_alpha_row.data()).

In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py`:
- Around line 56-75: Add explicit input validation at the start of
gptq_marlin_repack: check num_bits is a positive divisor of 32 (raise ValueError
if num_bits <= 0 or 32 % num_bits != 0), ensure tile_size alignment by
validating size_k % tile_size == 0 (use tile_size = 16 as in the function) and
ensure (size_n * tile_size) % (32 // num_bits) == 0 to avoid truncation, and
verify b_q_weight.dtype is torch.int32 (raise TypeError if not). Use the
existing symbols gptq_marlin_repack, pack_factor (computed as 32 // num_bits),
tile_size, and b_q_weight to locate where to add these checks and raise clear,
descriptive exceptions before calling _normalize_perm or the kernel.

In `@mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py`:
- Around line 74-85: The hardcoded "-arch=sm_87" in the extra_cuda_cflags
restricts the JIT cubin to SM87+, so change the build step that constructs
extra_cuda_cflags (the list passed into the JIT compile call and the
build_directory=cache_dir) to detect the local GPU compute capability at runtime
(e.g., via torch.cuda.get_device_capability() or the CUDA runtime API) and emit
a matching "-arch=sm_{major}{minor}" flag instead of the fixed "-arch=sm_87";
also include that SM string in the cache directory name (e.g., append
"/sm{major}{minor}/" to cache_dir) so the cache is per-arch and avoids
collisions. Ensure the code path that currently builds extra_cuda_cflags and
calls the JIT compile uses the detected major/minor for both the flag and the
cache directory.
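The flag-plus-cache-dir derivation can be sketched as a pure function. The `(major, minor)` pair would come from `torch.cuda.get_device_capability()` in the real code; the cache-dir naming here is one possible convention, not the project's.

```python
import os

def arch_flag_and_cache_dir(cache_dir: str, major: int, minor: int) -> tuple[str, str]:
    """Build the nvcc arch flag and a per-architecture JIT cache directory."""
    sm = f"{major}{minor}"
    flag = f"-arch=sm_{sm}"
    # One cache subdirectory per architecture avoids cubin collisions
    # when the same cache root is shared across GPUs.
    return flag, os.path.join(cache_dir, f"sm{sm}")
```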

In `@pymllm/layers/rms_norm.py`:
- Around line 46-54: The fallback branch in the rms norm function returns the
original unmodified residual after flashinfer.norm.fused_add_rmsnorm fails;
change the fallback to explicitly compute and return the updated residual (e.g.,
compute new_residual = x + residual and return _torch_rmsnorm(new_residual,
self.weight, self.eps), new_residual) instead of returning the old residual, and
narrow the except clause to catch only expected runtime errors (e.g.,
RuntimeError/ValueError) and add a logger call (use a module logger like logger
= logging.getLogger(__name__)) to record unexpected failures before falling
back; locate these changes around flashinfer.norm.fused_add_rmsnorm and the
fallback that currently calls _torch_rmsnorm.
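The corrected control flow can be sketched with the torch/flashinfer calls injected as plain callables, so only the fallback logic is shown. The key points from the comment: the fallback returns the *updated* residual, and only expected error types trigger it.

```python
import logging

logger = logging.getLogger(__name__)

def fused_add_rmsnorm_with_fallback(x, residual, fused_fn, slow_norm_fn):
    """fused_fn stands in for flashinfer.norm.fused_add_rmsnorm,
    slow_norm_fn for _torch_rmsnorm."""
    try:
        return fused_fn(x, residual)  # fast path
    except (RuntimeError, ValueError):  # narrow catch, not bare Exception
        logger.warning("fused_add_rmsnorm failed; using eager fallback")
        new_residual = x + residual   # recompute the updated residual
        return slow_norm_fn(new_residual), new_residual
```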

In `@pymllm/models/qwen3.py`:
- Around line 296-304: The timing around self.model(...) currently measures CUDA
launch overhead not actual GPU execution; update the timing in the qwen3.py
block that sets forward_batch.llm_prefill_ms and forward_batch.llm_decode_ms to
either (a) call torch.cuda.synchronize() immediately after self.model(...)
before computing _llm_ms so the measured delta reflects real GPU work, or (b)
use torch.cuda.Event start/stop + elapsed_time for GPU timings; then set
forward_batch.llm_prefill_ms/llm_decode_ms accordingly and add a one-line
comment beside this logic (referencing forward_batch.forward_mode.is_extend(),
forward_batch.llm_prefill_ms, forward_batch.llm_decode_ms, and self.model)
documenting whether the field represents engine-perceived latency (asynchronous
launch time) or actual GPU execution time (synchronized or event-based).
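Option (a) above can be sketched with the synchronization hook injected, so the pattern is testable without CUDA; in `qwen3.py` the hook would be `torch.cuda.synchronize`.

```python
import time

def timed_forward(forward_fn, sync_fn=None):
    """Time forward_fn; if sync_fn is given, drain queued GPU work
    before stopping the clock so the delta reflects real execution,
    not just kernel-launch overhead."""
    t0 = time.perf_counter()
    out = forward_fn()
    if sync_fn is not None:
        sync_fn()
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    return out, elapsed_ms
```

Option (b), `torch.cuda.Event(enable_timing=True)` pairs with `start.elapsed_time(end)`, avoids a full device sync at the cost of reading the timing later.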

In `@pymllm/orchestrator/tokenizer_process.py`:
- Around line 374-386: When mm_inputs provides multimodal input_ids
(image_inputs["input_ids"]) we currently override input_ids without enforcing
the model context cap; after extracting proc_input_ids in the block handling
mm_inputs/image_inputs, truncate proc_input_ids to self._context_length before
converting to a Python list or assigning to input_ids (e.g., slice ndarray or
list to [:self._context_length] or use equivalent method for objects exposing
tolist), so input_ids never exceed the model's context length; reference
symbols: mm_inputs, image_inputs, proc_input_ids, input_ids,
self._context_length.
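The cap can be sketched as a small helper that works for anything sliceable (list, ndarray, tensor), converting to a plain list afterwards as the comment requires. The helper name is illustrative.

```python
def cap_to_context(proc_input_ids, context_length: int) -> list:
    """Truncate multimodal input_ids to the model context length."""
    ids = proc_input_ids[:context_length]
    # ndarrays/tensors expose tolist(); plain lists are copied as-is.
    return ids.tolist() if hasattr(ids, "tolist") else list(ids)
```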

In `@pymllm/quantization/kernels/int8_activation_triton.py`:
- Around line 59-82: Replace the brittle assert and the risky negative 2-index
stride access: in the function that calls _per_token_quant_int8 (the block
around the _per_token_quant_int8[(M,)](...) call), replace "assert
x.is_contiguous()" with an explicit check that raises ValueError if not
contiguous, and compute stride_x and stride_xq defensively (use x.stride(-2) /
x_q.stride(-2) when x.ndim >= 2 else fall back to x.stride(0) / x_q.stride(0));
pass those computed stride_x and stride_xq into the _per_token_quant_int8 call.
This keeps the documented "any shape" behavior and avoids IndexError on 1-D
inputs and also fixes the assert being stripped under python -O.
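The defensive stride selection can be sketched as follows; `t` only needs to expose `.ndim` and `.stride(i)`, as torch tensors do.

```python
def row_stride(t) -> int:
    """Second-to-last stride for >=2-D inputs, stride(0) for 1-D,
    instead of unconditionally indexing stride(-2)."""
    if t.ndim >= 2:
        return t.stride(-2)
    return t.stride(0)
```

The contiguity assert would similarly become `if not x.is_contiguous(): raise ValueError(...)` so the check survives `python -O`.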

In `@pymllm/README-ZH.md`:
- Around line 199-201: The example in the README currently hard-codes a
developer-specific absolute path ("/workspace/xcd_mllm/test.png") inside the
JSON snippet for the "image_url" entry; replace that literal with a neutral
placeholder (e.g., "<path/to/image.png>" or a sample relative/remote URL) so
users won’t copy a non-existent path—update the "image_url" -> "url" value in
the JSON array shown in the README-ZH.md accordingly.

In `@pymllm/README.md`:
- Around line 50-60: The README's environment table lists a non-existent PyPI
package/version "flashinfer: `0.6.7`" which makes the setup unreproducible;
update the README.md entry referencing "flashinfer 0.6.7" by either removing the
flashinfer line, replacing it with the correct source (e.g., a Git URL, private
repo, or build instructions), or adding explicit install instructions and
provenance (e.g., how to build/install from source or a link to the correct
release). Edit the list block in README.md where "flashinfer: `0.6.7`" appears
so the table accurately reflects how to obtain or install flashinfer, and
include the concrete install command or URL if you choose to keep the
dependency.

---

Outside diff comments:
In `@pymllm/executor/model_runner.py`:
- Around line 486-505: In _resolve_quant_config, don't return the full HF config
when encountering "config.json" that lacks "quantization_config"; instead, if
fname == "config.json" and "quantization_config" in cfg then return
cfg["quantization_config"], otherwise continue the loop to allow later
candidates (e.g., "quantize_config.json") to be checked; remove the duplicate
post-loop fallback that re-reads config.json since the loop now handles it, and
keep the final return {} if nothing is found.
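The corrected candidate loop can be sketched with the file reads replaced by a dict; the candidate list shown here is just the two names from the comment, not necessarily the full list in `_resolve_quant_config`.

```python
def resolve_quant_config(configs: dict) -> dict:
    """configs maps candidate filename -> parsed JSON contents."""
    for fname in ("config.json", "quantize_config.json"):
        cfg = configs.get(fname)
        if cfg is None:
            continue
        if fname == "config.json":
            if "quantization_config" in cfg:
                return cfg["quantization_config"]
            continue  # do NOT return the full HF config; try later candidates
        return cfg
    return {}
```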

In `@pymllm/models/qwen3_vl.py`:
- Around line 1172-1226: The timing measurements around vision and LLM forwards
(_vit_t0/_llm_t0) only measure kernel launch time and under-report real GPU
latency; modify the blocks that call self.visual(...) and self.model(...) to
optionally synchronize the device before stopping the timer when CUDA is
available (e.g., check torch.cuda.is_available() and call
torch.cuda.synchronize() if a new config flag like self.sync_cuda_timings is
true), and propagate this gating to both the vision timing (vit_prefill_ms,
vit_prefill_tokens) and LLM timing (llm_prefill_ms/llm_decode_ms) so timings
reported from functions like self.visual and self.model reflect real GPU work
when configured.
- Around line 1186-1192: The image mask currently only checks for
self.image_token_id in the block that computes input_embeds via
self.model.embed_tokens; update the mask logic used in image_mask (and any
subsequent uses of vit_prefill_tokens) to include self.video_token_id as well
(i.e., mask = (input_ids == self.image_token_id) | (input_ids ==
self.video_token_id)) so that input_embeds[image_mask] = vision_embeds.to(...)
and vit_prefill_tokens = int(image_mask.sum().item()) correctly account for
video tokens; apply the same mask change wherever vit_prefill_tokens or
deepstack embedding replacements are computed so video tokens are handled
identically to image tokens.
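The widened mask can be sketched over plain lists; in `qwen3_vl.py` the same boolean-or applies elementwise to the `input_ids` tensor.

```python
def vision_mask(input_ids, image_token_id, video_token_id):
    """Mark both image and video placeholder positions as vision tokens."""
    return [tok == image_token_id or tok == video_token_id for tok in input_ids]
```

`sum(mask)` then plays the role of `vit_prefill_tokens`, now counting video tokens too.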

In `@pymllm/orchestrator/scheduler_process.py`:
- Around line 792-837: The code currently both accumulates model-reported
per-step decode time into req.llm_decode_ms (using out["llm_decode_ms"]) and
later overwrites it for decode batches with a wall-clock recompute in the
batch.forward_mode.is_decode() loop; remove the latter overwrite so the per-step
model-reported accumulation is the single source of truth: delete or skip the
assignment req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0
inside the batch.forward_mode.is_decode() block (references: req.llm_decode_ms,
out["llm_decode_ms"], batch.forward_mode.is_decode(), req.decode_start_tic) so
llm_decode_ms remains the sum of model-runner-reported decode times.

In `@pymllm/server/launch.py`:
- Around line 707-768: The timing block uses the loop variable r after the
for-loop in the completion handler (the results/choices assembly around
engine.generate_async), which can be undefined when results is empty; fix by
binding r = results[-1] if results else {} (or an empty dict) immediately after
the for i, r in enumerate(results): loop and before constructing the
ORJSONResponse so all r.get(...) calls are safe (refer to the loop producing
choices and the timing dictionary construction).
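The binding fix can be sketched as a tiny helper; the name is illustrative and the per-choice loop body is elided.

```python
def pick_timing_source(results: list) -> dict:
    """Return the result dict the timing block should read from:
    the last entry, or an empty dict when results is empty, so that
    every subsequent r.get(...) call is safe."""
    # ... per-choice assembly iterates `results` above this point ...
    return results[-1] if results else {}
```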

---

Nitpick comments:
In `@mllm-kernel/benchmarks/bench_int8_scaled_mm.py`:
- Around line 145-151: The script crashes on CUDA-less machines because the
__main__ block calls torch.cuda.get_device_name(0) and
torch.cuda.get_device_capability(0) (and internal functions like
_torch_int_mm_scaled and _try_load_cutlass_kernel assume CUDA) unconditionally;
fix by guarding CUDA-specific behavior with a torch.cuda.is_available() check in
the __main__ path: if CUDA is available, print device name/capability and call
run_benchmarks() as now; if not, print a clear message and either skip CUDA-only
benchmarks or exit early so _torch_int_mm_scaled and _try_load_cutlass_kernel
are never invoked. Ensure all calls to torch.cuda.get_device_name,
torch.cuda.get_device_capability, _torch_int_mm_scaled, and
_try_load_cutlass_kernel are only reached when torch.cuda.is_available() is
True.
- Around line 81-142: run_benchmarks currently times each backend but never
verifies correctness; add a one-shot correctness check per shape before calling
bench_fn: compute a reference output via _torch_int_mm_scaled(mat_a, mat_b,
scales_a, scales_b, out_dtype=out_dtype), then for each backend in backends
(skip "torch._int_mm") call its fn with b_arg = mat_b_colmaj if name ==
"cutlass" else mat_b and out_dtype=out_dtype, compute a relative
mean-absolute-error (e.g.,
(out.float()-ref.float()).abs().mean()/ref.float().abs().mean().clamp(min=1e-6)),
and if rel > 1e-2 print a clear mismatch message including the backend name and
shape and mark that backend as invalid (skip timing or record error) so bench_fn
is only run for backends that pass the sanity check.
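The relative mean-absolute-error gate can be sketched over flat float lists; the real check would run on tensors via `.float().abs().mean()` with the same clamp on the denominator.

```python
def rel_mae(out, ref, eps=1e-6):
    """Relative mean-absolute-error of out against ref."""
    mae = sum(abs(o - r) for o, r in zip(out, ref)) / len(ref)
    denom = max(sum(abs(r) for r in ref) / len(ref), eps)
    return mae / denom

def backend_passes(out, ref, tol=1e-2):
    """One-shot sanity gate run per shape before timing a backend."""
    return rel_mae(out, ref) <= tol
```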

In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py`:
- Around line 68-71: The unpacking from load_marlin() in prepare_marlin_weights
currently binds gptq_marlin_gemm but never uses it; update the tuple unpack in
prepare_marlin_weights to either drop that value or rename it to
_gptq_marlin_gemm (or simply _ ) so the unused symbol is intentionally ignored
and the linter warning (RUF059) is silenced—keep the other symbols
(gptq_marlin_repack, marlin_make_workspace, marlin_make_empty_g_idx,
marlin_permute_scales, SCALAR_TYPE_UINT4B8) unchanged.
- Around line 158-162: The variable tag computed in the benchmark loop (tag =
"decode" if M <= 8 else "prefill") is never used; either remove that line or
include tag in the formatted output; update the print call (the one formatting
f"  ({M:>3},{K:>4},{N:>4}) {desc:<8s} ...") to incorporate tag (e.g., add
{tag:<8s} after {desc:<8s}) if you want rows labeled, otherwise delete the
unused tag assignment to avoid the dead variable.

In `@mllm-kernel/cmake/CPM.cmake`:
- Around line 12-23: The download block using file(DOWNLOAD) may leave a partial
file at ${CPM_DOWNLOAD_LOCATION} on failure; update the failure path in the
CPM.cmake snippet that handles download_status (the list(GET download_status 0
download_status_code) branch) to: if download_status_code is non-zero,
remove/unlink ${CPM_DOWNLOAD_LOCATION} if it exists to avoid poisoning the
cache, include the download_status/details in the message(FATAL_ERROR) for
better diagnostics, and then fail; reference the CPM_DOWNLOAD_LOCATION,
CPM_VERSION, download_status and download_status_code symbols to locate and
modify the file(DOWNLOAD) failure handling.
- Around line 14-18: Add integrity verification to the fallback file(DOWNLOAD
...) call by providing an EXPECTED_HASH for CPM v0.42.0 and failing the build if
the downloaded file's hash doesn't match; update the file(DOWNLOAD ...)
invocation that uses CPM_VERSION, CPM_DOWNLOAD_LOCATION and download_status to
include EXPECTED_HASH "<sha256-or-appropriate-hash-for-v0.42.0>" and check
download_status (and/or use file(DOWNLOAD ... EXPECTED_HASH) behavior) so the
script stops with an error when the hash verification fails rather than silently
using a tampered CPM.cmake.

In `@mllm-kernel/include/mllm_kernel/scalar_type.hpp`:
- Around line 260-262: Remove the global namespace alias `namespace host =
::mllm_kernel::host;` from the public header `scalar_type.hpp` (it pollutes
every TU); instead, delete that line and add a localized alias in each
translation unit that needs the short name (e.g., `marlin.cuh`,
`gptq_marlin.cuh`, `kernel.h`, `marlin_template.h`) by declaring a local
`namespace host = ::mllm_kernel::host;` inside the appropriate enclosing
namespace and, where useful, add selective using declarations like `using
host::div_ceil;` so call sites keep the short form without exporting `host` from
the header.

In `@mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu`:
- Around line 156-217: sm89_dispatch_shape is dead/unreachable because SM80–SM89
flows unconditionally through sm80_dispatch_shape; either remove the unused
template or guard it so it can't silently rot. Fix by either deleting the
sm89_dispatch_shape template entirely, or wrap its entire definition with a
clear opt-in preprocessor guard (e.g. `#ifdef` MLLM_KEEP_SM89_REFERENCE ...
`#endif`) and add a short comment explaining it's intentionally preserved for
reference; reference the sm89_dispatch_shape and sm80_dispatch_shape symbols so
reviewers can find the related code paths.
- Around line 27-34: getSMVersion currently calls cudaGetDevice and
cudaDeviceGetAttribute without checking their cudaError_t results, which
swallows CUDA errors and returns 0; update getSMVersion to check the return
values of cudaGetDevice and both cudaDeviceGetAttribute calls, and on failure
either propagate or surface the cudaError_t (e.g., log/throw a descriptive error
containing the cudaGetErrorString result) instead of silently returning 0 so
callers (and the user) see the real CUDA error; reference the getSMVersion
function and the cudaGetDevice / cudaDeviceGetAttribute calls when making the
change.
- Around line 121-135: Add a one-line comment next to the ldc assignment that
explains why ldc is set to 0 for bias broadcasting: note that setting ldc = 0
with bias_ptr passed into Gemm::Arguments causes CUTLASS to read each row from
the same base pointer (broadcasting the [N] bias across rows), so keep ldc = 0
and the bias_ptr usage as-is in the Gemm::Arguments construction. Reference
variables: ldc, bias_ptr, and Gemm::Arguments (and the
EpilogueOutputOp/EpilogueVisitor parameters) so reviewers can find the exact
spot to add the comment.

In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py`:
- Around line 12-25: The function _normalize_perm does two separate torch.any
checks which cause two GPU→CPU synchronizations; replace the dual checks
"torch.any(perm < 0) or torch.any(perm >= size_k)" with a single combined mask
and a single any call (e.g., torch.any((perm < 0) | (perm >= size_k))) so
there's only one GPU→CPU sync when validating perm; update the check in
_normalize_perm accordingly while keeping the surrounding device/dtype/length
validations unchanged.

In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py`:
- Around line 32-40: The code redundantly computes cpp_args via
make_cpp_args(dtype) and passes an explicit cuda_wrappers tuple to the `@jit`
decorator even though args=[dtype] already causes the framework to auto-generate
the identical cuda wrapper; remove the local cpp_args assignment and the
cuda_wrappers parameter from the `@jit` call, leaving args=[dtype] and
cpp_wrappers=[] intact so the auto-generated wrapper for gptq_marlin_gemm is
used; ensure you do not change func_name="gptq_marlin_gemm" or the cuda_files
entry ("gemm/marlin/gptq_marlin.cuh").

In `@mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py`:
- Around line 53-86: The _load_module() function has a TOCTOU race on the
globals _module and _CUTLASS_INC when called concurrently; protect the
lazy-initialization by adding a module-level threading.Lock (e.g.,
_load_module_lock) and wrap the critical section in _load_module() with that
lock, using a double-checked pattern (check _module before acquiring and
re-check after acquiring) so only one thread runs torch.utils.cpp_extension.load
and sets _module/_CUTLASS_INC while others return the initialized module.
- Around line 116-118: Validate the out_dtype argument instead of coercing: in
the wrapper that builds dtype_str and calls mod.int8_scaled_mm, check that
out_dtype is either torch.float16 or torch.bfloat16 and raise a clear ValueError
if not; only then map torch.float16 -> "float16" and torch.bfloat16 ->
"bfloat16" and pass dtype_str to mod.int8_scaled_mm (referencing the existing
variables out_dtype, dtype_str and the call to mod.int8_scaled_mm).

In `@mllm-kernel/tests/test_gptq_marlin_repack.py`:
- Around line 47-50: Add a short inline comment above the constants (tc_offsets
and pack_idx) in test_gptq_marlin_repack.py explaining that tc_offsets = [0, 1,
8, 9] and pack_idx = [0, 2, 4, 6, 1, 3, 5, 7] reproduce the ldmatrix / m16n8k16
tensor-core fragment row/element ordering used by Marlin's repack CUDA kernel;
include a note that these are the m16n8k16 fragment row/element offsets and add
a reference to the Marlin kernel source or file name and a brief reminder to
update these when changing bit-width or tile size (variables: tc_offsets,
pack_idx, tile_size, n_tiles).

In `@mllm-kernel/tests/test_gptq_marlin.py`:
- Around line 52-72: The test reimplements permutation logic in _get_scale_perms
and _marlin_permute_scales which duplicates get_scale_perms and
marlin_permute_scales in pymllm.quantization.methods.compressed_tensors and can
drift; replace the local copies by importing get_scale_perms and
marlin_permute_scales from that module (use them directly in the test) or, if
you intentionally want an independent reference implementation, add a concise
comment above _get_scale_perms/_marlin_permute_scales stating that they are a
deliberate independent copy to detect regressions in the production helpers and
must remain separate.

In `@mllm-kernel/tests/test_int8_scaled_mm_cutlass.py`:
- Around line 27-29: The top-level pytest.importorskip("torch") is unreachable
because torch is already imported at module scope; either remove the
importorskip call or make the module import conditional by moving the torch
import into the test setup/fixture so the module can be collected without torch.
Specifically, either delete the pytest.importorskip("torch") line near the start
of test_int8_scaled_mm_cutlass.py, or refactor the code so the torch import is
performed inside the fixture/function that checks torch.cuda.is_available()
(e.g., move the import into the fixture that contains the existing CUDA check),
ensuring collection works on environments without torch.

In `@pymllm/layers/rms_norm.py`:
- Around line 13-21: The current _torch_rmsnorm does a no-op cast and forces an
unnecessary fp32→input-dtype copy by returning x_norm.to(dtype=x.dtype) *
weight; change the sequence to keep computations in fp32, multiply x_norm by
weight converted to fp32 (e.g. weight_fp32 = weight.to(x_fp32.dtype)), then cast
the final product back to x.dtype once before return. Update _torch_rmsnorm to
compute x_fp32, var, x_norm in fp32, do x_norm * weight_fp32, and only at the
end .to(x.dtype) on that product.
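The suggested operation order can be sketched in plain Python floats (everything stays "fp32" here, so only the ordering is shown): normalize in fp32, multiply by the fp32 weight, and cast back to the input dtype once at the very end.

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    """RMSNorm over a flat list: x / sqrt(mean(x^2) + eps) * weight,
    all intermediate math in high precision, one final cast point."""
    var = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    # In the tensor version, the single .to(x.dtype) happens on this product.
    return [v * inv * w for v, w in zip(x, weight)]
```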

In `@pymllm/models/qwen3.py`:
- Around line 393-399: The substring check that skips lm_head weights is
intentionally done before _remap_weight_name to handle both VL and non‑VL
checkpoints (e.g. "lm_head.weight", "language_model.lm_head.weight", or
"model.language_model.lm_head.weight"); add a short clarifying comment next to
the existing skip (the code block that tests `"lm_head.weight" in name`)
explaining that the substring check is purposeful and order-sensitive relative
to the _remap_weight_name(name: str) function so future readers understand
parity handling between VL/non‑VL serialized names.

In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 592-595: Replace the fragile hasattr(cache, "page_size") probe
with an explicit type check for ChunkCache and short‑circuit before doing work:
update the caller (or the start of _insert_into_radix_cache) to return early
when isinstance(cache, ChunkCache) so you skip building RadixKey and calling
cache.insert / cache.match_prefix for ChunkCache instances; then remove the
now-dead hasattr-based block that sets self._rid_to_cache_protected_len[rid] =
0. Ensure you reference the ChunkCache type, _insert_into_radix_cache function,
and the code paths that call cache.insert and cache.match_prefix when making the
change.

In `@pymllm/orchestrator/scheduler_process.py`:
- Around line 849-852: The code unconditionally resets req.llm_decode_ms to 0.0
when appending to self._running_batch which can clobber a value set earlier in
process_batch_result (see the extend branch behavior around process_batch_result
and the extend output handling). Change the initialization logic in the block
where you set req.decode_start_tic and append to self._running_batch so that you
only set req.llm_decode_ms = 0.0 if req.llm_decode_ms is currently None (leave
existing non-None values intact); reference req.decode_start_tic,
req.llm_decode_ms, process_batch_result, and self._running_batch to find and
update the exact spot.

In `@pymllm/quantization/methods/compressed_tensors.py`:
- Around line 591-597: get_quant_method currently calls
_validate_supported_signature(self) on every invocation which repeats expensive
checks (including verify_marlin_supported and
torch.cuda.get_device_capability()) per layer; instead compute and cache the
signature once on the config object (e.g. in __init__ as self._cached_signature
or via a `@functools.cached_property`) and have get_quant_method use that cached
value, keeping the early ignore check (any(ignored and
prefix.startswith(ignored) for ignored in self.ignore)) and returning
CompressedTensorsLinearMethod(self, cached_signature).
- Around line 78-91: The error messages in verify_marlin_supports_shape use
hardcoded "64"/"128" which can go out of sync with the constants; update the
ValueError messages to reference GPTQ_MARLIN_MIN_THREAD_N and
GPTQ_MARLIN_MIN_THREAD_K (e.g., format or f-string) so the message reports the
actual constant values, and likewise ensure the third ValueError mentions
group_size and input_size_per_partition consistently; modify the raise calls in
verify_marlin_supports_shape to include the constant names/values instead of
literal numbers.
- Around line 1-13: Remove the eager top-level imports of gptq_marlin_gemm and
gptq_marlin_repack and instead import them lazily inside the runtime methods
that need Marlin (e.g., inside process_weights_after_loading and apply of the
W4A16 quantization path); specifically, delete the module-level import
references and add local imports of mllm_kernel.cuda.jit.gptq_marlin_gemm and
gptq_marlin_repack at the start of process_weights_after_loading (and any
W4A16.apply callsites that perform repacking/GEMM), so the Marlin extension is
only loaded when those functions execute and importing this module no longer
fails for W8A8-only environments.
- Around line 282-310: weight_shape is intentionally registered as a placeholder
to receive the per-tensor shape buffer from compressed-tensors checkpoints but
is never read; add a one-line comment next to the weight_shape
creation/registration explaining this intended purpose (so future readers don't
remove it), and in process_weights_after_loading add a defensive check that
reads the registered weight_shape and asserts or logs if its values differ from
input_size_per_partition/output_size_per_partition (reference weight_shape,
layer.register_parameter("weight_shape", ...), process_weights_after_loading,
input_size_per_partition, output_size_per_partition).
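The per-config caching suggested in the first bullet can be sketched with `functools.cached_property`; `_validate_supported_signature` here is a placeholder for the real check that calls `verify_marlin_supported` and `torch.cuda.get_device_capability()`.

```python
import functools

class QuantConfig:
    validations = 0  # counts how many times the expensive check ran

    @functools.cached_property
    def signature(self):
        QuantConfig.validations += 1
        return self._validate_supported_signature()

    def _validate_supported_signature(self):
        return ("w8a8", "int8")  # placeholder signature

    def get_quant_method(self, prefix: str):
        # get_quant_method now reuses the cached signature per layer.
        return ("linear_method", self.signature)
```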

In `@pymllm/server/launch.py`:
- Around line 743-767: Extract the repeated timing logic into a small helper
pair: implement _safe_tps(tokens, ms) to return None when tokens or ms are None
or ms <= 0 and otherwise compute tokens/(ms/1000.0), and implement
_build_timing(r, prompt_tokens, completion_tokens) which reads vit_prefill_ms,
llm_prefill_ms, llm_decode_ms from r and returns the timing dict
(vit_prefill_ms, llm_prefill_ms, llm_decode_ms, prefill_tokens, vit_prefill_tps
via _safe_tps(r.get("vit_prefill_tokens"), vit_ms), llm_prefill_tps via
_safe_tps(prompt_tokens, llm_pre_ms), and llm_decode_tps via
_safe_tps(completion_tokens, llm_dec_ms)); then replace the inline "timing": {
... } dict in the /v1/completions and /v1/chat/completions response construction
with "timing": _build_timing(r, prompt_tokens, completion_tokens) so both
endpoints share the same guarded logic.

In `@pymllm/tests/bench_w8a8_activation_quant.py`:
- Around line 91-98: The loop over backends currently swallows all exceptions
and prints a vague "ERR" (for name, fn in backends.items() with bench_fn), make
failures visible: remove the bare except so exceptions propagate during
CI/development, or if you must handle them, catch Exception as e and print its
message and stack trace (e.g., via traceback.print_exc or processLogger.error)
instead of discarding e; also if you don't use the loop variable name, iterate
over backends.values() (or rename it to _ ) to avoid the unused-variable
warning.
- Around line 101-107: Guard the CUDA-only entry point by checking
torch.cuda.is_available() at the top of the if __name__ == "__main__" block
before calling torch.cuda.get_device_name(0) or
torch.cuda.get_device_capability(0); if CUDA is not available, print a clear
message and exit (or return) instead of calling those functions. Modify the main
block that calls run_benchmarks() so the device name/capability prints only
after the availability check, referencing the existing __main__ block,
torch.cuda.is_available(), torch.cuda.get_device_name,
torch.cuda.get_device_capability, and run_benchmarks to locate where to add the
guard.

In `@pymllm/tests/test_compressed_tensors_config.py`:
- Around line 121-123: Add a missing blank line between the two top-level test
functions: insert one additional empty line after the end of
test_get_quant_method_respects_ignore so there are two blank lines before the
definition of test_get_quant_method_rejects_unsupported_signature to satisfy PEP
8 E302; locate the functions by their names
test_get_quant_method_respects_ignore and
test_get_quant_method_rejects_unsupported_signature and ensure exactly two blank
lines separate these top-level function definitions.

In `@pymllm/tests/test_compressed_tensors_runtime.py`:
- Around line 306-315: The test uses very loose tolerances (atol=2e-1,
rtol=2e-1) which can hide real regressions; update the assertion comparing out
(from qm.apply) and ref (computed using ct._per_token_quant_int8 and
torch._int_mm) to use tighter tolerances such as atol=5e-2 and rtol=5e-2 (or
stronger when running with controlled K/scales), so that differences from
CUTLASS INT8 GEMM or scale-broadcast bugs are caught early; locate the
comparison around the variables x, bias, out, qm.apply,
ct._per_token_quant_int8, torch._int_mm, and layer.weight_scale and replace the
tolerance values accordingly.

In `@pymllm/tests/test_qwen3_weight_loading.py`:
- Around line 67-115: The two tests
test_load_weights_stacks_qkv_and_gate_up_from_model_prefix and
test_load_weights_accepts_model_language_model_prefix should be merged and
parametrized over layer_prefix to also include the "language_model" (no leading
"model.") case that _remap_weight_name handles; change them into a single
pytest.mark.parametrize test that calls _build_language_weights(cfg,
layer_prefix=layer_prefix), constructs Qwen3ForCausalLM(cfg), calls
model.load_weights(...), and then runs the same q/k/v and gate/up and
embedding/lm_head assertions using the parameterized layer_prefix so all three
remap branches ( "model", "model.language_model", "language_model") are covered
and the duplicated test bodies (referencing Qwen3ForCausalLM,
_build_language_weights, model.load_weights, layerX.self_attn.qkv_proj, and
layerX.mlp.gate_up_proj) are consolidated.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ff521d21-fe30-425b-8334-d94821c10a06

📥 Commits

Reviewing files that changed from the base of the PR and between 7282b0f and efb65d0.

📒 Files selected for processing (38)
  • mllm-kernel/benchmarks/bench_int8_scaled_mm.py
  • mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
  • mllm-kernel/cmake/CPM.cmake
  • mllm-kernel/include/mllm_kernel/scalar_type.hpp
  • mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
  • mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
  • mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
  • mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu
  • mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh
  • mllm-kernel/mllm_kernel/cuda/jit/__init__.py
  • mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py
  • mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py
  • mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py
  • mllm-kernel/tests/test_gptq_marlin.py
  • mllm-kernel/tests/test_gptq_marlin_repack.py
  • mllm-kernel/tests/test_int8_scaled_mm_cutlass.py
  • pymllm/README-ZH.md
  • pymllm/README.md
  • pymllm/executor/model_runner.py
  • pymllm/layers/rms_norm.py
  • pymllm/models/__init__.py
  • pymllm/models/qwen3.py
  • pymllm/models/qwen3_vl.py
  • pymllm/orchestrator/detokenizer_process.py
  • pymllm/orchestrator/model_runner_process.py
  • pymllm/orchestrator/scheduler_process.py
  • pymllm/orchestrator/tokenizer_process.py
  • pymllm/quantization/kernels/__init__.py
  • pymllm/quantization/kernels/int8_activation_triton.py
  • pymllm/quantization/methods/__init__.py
  • pymllm/quantization/methods/compressed_tensors.py
  • pymllm/server/launch.py
  • pymllm/tests/bench_w8a8_activation_quant.py
  • pymllm/tests/test_compressed_tensors_config.py
  • pymllm/tests/test_compressed_tensors_runtime.py
  • pymllm/tests/test_qwen3_forward_timing.py
  • pymllm/tests/test_qwen3_model_registry.py
  • pymllm/tests/test_qwen3_weight_loading.py

Comment on lines +6 to +9
Usage:
cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8
python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
"""

⚠️ Potential issue | 🟡 Minor

Replace the hardcoded developer worktree path in the docstring.

cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8 is specific to a single developer's environment and won't help users running the benchmark from a fresh clone.

📝 Suggested doc fix
 Usage:
-    cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8
-    python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
+    # From the repository root:
+    python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Usage:
cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8
python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
"""
Usage:
# From the repository root:
python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py
"""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py` around lines 6 - 9, The
docstring in bench_w4a16_vs_w8a8.py contains a hardcoded developer worktree
path; replace that line with a generic, portable instruction such as "cd to the
repository root" or "run from the repository root" and/or show a relative
invocation like "python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py" so users
know to run the script from the repo root instead of the specific
/workspace/.worktrees/... path.

Comment on lines +113 to +183
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;

bool const with_bias_;
bool const per_token_quant_;
bool const per_channel_quant_;

AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;

AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;

ElementAccumulator beta_;

int column_offset_;

MatrixCoord thread_offset_;

public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
Params const& params,
SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
bool with_bias,
bool per_token_quant,
bool per_channel_quant,
AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
with_bias_(with_bias),
per_token_quant_(per_token_quant),
per_channel_quant_(per_channel_quant),
ptr_alpha_row_(ptr_alpha_row),
ptr_alpha_col_(ptr_alpha_col),
iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
extent_real_(problem_size_real) {
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
element_alpha_col_ = *ptr_alpha_col_;
}

if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
element_alpha_row_ = *ptr_alpha_row_;
}
}

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h | sed -n '110,185p'

Repository: UbiquitousLearning/mllm

Length of output: 3202


Member-initializer list order doesn't match declaration order — emits -Wreorder.

extent_real_ is declared at line 116 (4th member) but initialized last at line 175, and elementwise_ is declared at line 117 but initialized at line 166 before with_bias_/per_token_quant_/per_channel_quant_. Since C++ always initializes members in declaration order, the apparent order in the initializer list is misleading and any compiler with -Wreorder (gcc/clang default in many builds) will warn here. There's no functional bug because none of the earlier initializers read extent_real_, but reordering keeps the intent and the actual behavior in sync.

♻️ Proposed reorder of the initializer list
       : params_(params),
         shared_storage_(shared_storage),
         extent_(problem_size),
+        extent_real_(problem_size_real),
         elementwise_(params.elementwise),
         with_bias_(with_bias),
         per_token_quant_(per_token_quant),
         per_channel_quant_(per_channel_quant),
         ptr_alpha_row_(ptr_alpha_row),
         ptr_alpha_col_(ptr_alpha_col),
         iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
         iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
-        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
-        extent_real_(problem_size_real) {
+        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h`
around lines 113 - 183, The constructor EpilogueVisitorPerRowPerCol has its
member-initializer list out of declaration order causing -Wreorder; reorder the
initializer list to match the class member declaration order (Params const&
params_, SharedStorage& shared_storage_, MatrixCoord extent_, MatrixCoord
extent_real_, ElementwiseFunctor elementwise_, bool with_bias_, bool
per_token_quant_, bool per_channel_quant_, AlphaScaleElementType*
ptr_alpha_row_, AlphaScaleElementType* ptr_alpha_col_, ScaleTileIterator
iterator_alpha_col_, OutputTileIterator iterator_C_, OutputTileIterator
iterator_D_, ...), i.e., move extent_real_ to be initialized immediately after
extent_ and ensure elementwise_ follows extent_real_ (and the rest follow their
declared order) while keeping the existing initialization expressions and
runtime logic unchanged.

Comment on lines +184 to +202
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0),
params_A(0),
params_B(0),
params_alpha_col(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_alpha_col(nullptr),
ptr_alpha_row(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0) {}

⚠️ Potential issue | 🟡 Minor

Default Params() constructor omits params_alpha_row.

ptr_alpha_row is explicitly nulled but the corresponding params_alpha_row member is never listed in the initializer list (compare with params_alpha_col(0)). If ScaleTileIterator::Params is not trivially default-constructible with the right zero-stride state, this leaves it in an inconsistent state vs. its _col counterpart. Add the matching initializer for symmetry and to avoid surprises if ScaleTileIterator::Params changes.

🛡️ Proposed fix
           params_alpha_col(0),
+          params_alpha_row(0),
           params_C(0),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`
around lines 184 - 202, The Params() constructor in gemm_with_epilogue_visitor.h
initializes params_alpha_col but omits params_alpha_row, leaving
params_alpha_row potentially uninitialized; update the Params() member
initializer list for the Params() constructor to include params_alpha_row(0)
(mirroring params_alpha_col(0)) so ScaleTileIterator::Params for the row scale
is set to the zero/empty state consistent with ptr_alpha_row being nullptr.

Comment on lines +204 to +224
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
: problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_alpha_col(args.ref_alpha_col.layout()),
params_alpha_row(args.ref_alpha_col.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_alpha_col(args.ref_alpha_col.data()),
ptr_alpha_row(args.ref_alpha_row.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor) {

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Confirm both Params have separate row/col layout sources in upstream TRT-LLM.
curl -fsSL \
  "https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h" \
  | grep -nE 'params_alpha_(row|col)\s*\('

Repository: UbiquitousLearning/mllm

Length of output: 233


Fix copy-paste bug: params_alpha_row initialized from wrong layout reference.

The alpha_row scale iterator's Params are being constructed from the column reference's layout. Since LayoutAlphaCol is RowMajor and LayoutAlphaRow is ColumnMajor (lines 56-57), this produces incorrect strides for the row scale iterator. This bug also exists in the upstream TRT-LLM source (commit be178810...), so it should be fixed here regardless.

Suggested fix
           params_alpha_col(args.ref_alpha_col.layout()),
-          params_alpha_row(args.ref_alpha_col.layout()),
+          params_alpha_row(args.ref_alpha_row.layout()),
           params_C(args.ref_C.layout()),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h`
around lines 204 - 224, The initializer list for Params mistakenly constructs
params_alpha_row from the column layout
(params_alpha_row(args.ref_alpha_col.layout())); change it to use the actual row
reference layout by initializing params_alpha_row with
args.ref_alpha_row.layout() in the Params constructor so the alpha_row scale
iterator gets correct ColumnMajor strides (update the initializer near
params_alpha_col/params_alpha_row and verify ptr_alpha_row uses
args.ref_alpha_row.data()).

Comment on lines +56 to +75
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
perm: Optional[torch.Tensor],
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
"""Repack GPTQ/Compressed-Tensors weights into Marlin layout."""

pack_factor = 32 // num_bits
tile_size = 16
out = torch.empty(
(size_k // tile_size, size_n * tile_size // pack_factor),
dtype=b_q_weight.dtype,
device=b_q_weight.device,
)
kernel = _make_gptq_marlin_repack_kernel()
perm_t = _normalize_perm(perm, size_k, b_q_weight.device)
kernel(b_q_weight, perm_t, out, size_k, size_n, num_bits)
return out

⚠️ Potential issue | 🟡 Minor

Add input validation for num_bits, size_k, size_n, and b_q_weight.dtype.

pack_factor = 32 // num_bits and the output shape (size_k // tile_size, size_n * tile_size // pack_factor) silently truncate / divide-by-zero on bad inputs:

  • num_bits == 0 raises a confusing ZeroDivisionError (instead of a clear validation message).
  • num_bits values that don't evenly divide 32 (e.g. 5, 6, 7) silently produce a wrong layout.
  • size_k % tile_size != 0 or (size_n * tile_size) % pack_factor != 0 silently truncate the output dims.
  • The Marlin repack kernel expects b_q_weight to be int32 (matching how layer.weight_packed is stored in pymllm/quantization/methods/compressed_tensors.py lines 329–344); a wrong dtype would otherwise be caught only deep inside the CUDA kernel.

The function is invoked once per quantized linear at load time, so the cost of these checks is negligible.

🛡️ Proposed validation
 def gptq_marlin_repack(
     b_q_weight: torch.Tensor,
     perm: Optional[torch.Tensor],
     size_k: int,
     size_n: int,
     num_bits: int,
 ) -> torch.Tensor:
     """Repack GPTQ/Compressed-Tensors weights into Marlin layout."""
 
-    pack_factor = 32 // num_bits
-    tile_size = 16
+    if num_bits not in (4, 8):
+        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")
+    if b_q_weight.dtype != torch.int32:
+        raise ValueError(
+            f"b_q_weight must be int32, got {b_q_weight.dtype}"
+        )
+    tile_size = 16
+    pack_factor = 32 // num_bits
+    if size_k % tile_size != 0:
+        raise ValueError(
+            f"size_k ({size_k}) must be a multiple of tile_size ({tile_size})"
+        )
+    if (size_n * tile_size) % pack_factor != 0:
+        raise ValueError(
+            f"size_n*tile_size ({size_n * tile_size}) must be a multiple of "
+            f"pack_factor ({pack_factor})"
+        )
     out = torch.empty(
         (size_k // tile_size, size_n * tile_size // pack_factor),
         dtype=b_q_weight.dtype,
         device=b_q_weight.device,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py` around lines 56 - 75,
Add explicit input validation at the start of gptq_marlin_repack: check num_bits
is a positive divisor of 32 (raise ValueError if num_bits <= 0 or 32 % num_bits
!= 0), ensure tile_size alignment by validating size_k % tile_size == 0 (use
tile_size = 16 as in the function) and ensure (size_n * tile_size) % (32 //
num_bits) == 0 to avoid truncation, and verify b_q_weight.dtype is torch.int32
(raise TypeError if not). Use the existing symbols gptq_marlin_repack,
pack_factor (computed as 32 // num_bits), tile_size, and b_q_weight to locate
where to add these checks and raise clear, descriptive exceptions before calling
_normalize_perm or the kernel.
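The divisibility arithmetic behind the proposed checks can be sketched in plain Python. This is a hedged sketch of the validation only, mirroring the suggested diff (pack_factor = 32 // num_bits, tile_size = 16); `validate_repack_shapes` is a hypothetical helper name, not part of the repository's API:

```python
def validate_repack_shapes(size_k: int, size_n: int, num_bits: int) -> tuple:
    """Return the Marlin repack output shape, rejecting inputs that would
    otherwise silently truncate or divide by zero."""
    if num_bits <= 0 or 32 % num_bits != 0:
        raise ValueError(f"num_bits must be a positive divisor of 32, got {num_bits}")
    tile_size = 16
    pack_factor = 32 // num_bits  # how many quantized values fit in one int32
    if size_k % tile_size != 0:
        raise ValueError(f"size_k ({size_k}) must be a multiple of {tile_size}")
    if (size_n * tile_size) % pack_factor != 0:
        raise ValueError(
            f"size_n*tile_size ({size_n * tile_size}) must be a multiple of {pack_factor}"
        )
    # Same shape expression as the torch.empty(...) call in the wrapper.
    return (size_k // tile_size, size_n * tile_size // pack_factor)
```

For num_bits=4 the pack factor is 8, so a (128, 64) problem yields an (8, 128) packed buffer; the same inputs with num_bits=5 fail fast instead of producing a misaligned layout.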

Comment thread pymllm/models/qwen3.py
Comment on lines +296 to +304
_llm_t0 = time.perf_counter()
hidden_states = self.model(input_ids, positions, forward_batch)
_llm_ms = (time.perf_counter() - _llm_t0) * 1000.0

if forward_batch.forward_mode.is_extend():
forward_batch.llm_prefill_ms = _llm_ms
forward_batch.llm_decode_ms = None
else:
forward_batch.llm_decode_ms = _llm_ms

⚠️ Potential issue | 🟡 Minor

Timing without torch.cuda.synchronize() measures launch overhead, not GPU execution.

time.perf_counter() deltas around self.model(...) capture only the time spent issuing CUDA work to the stream — the kernels are still running asynchronously when _llm_ms is computed. For decode (single token), this will systematically under-report by orders of magnitude; for prefill it's noisier. Since these values feed forward_batch.llm_prefill_ms / llm_decode_ms (consumed downstream via the orchestrator/server timing path), the "prefill_ms" reported to operators won't reflect actual prefill latency.

If the goal is engine-level perceived latency (good enough for SLO accounting since something downstream eventually syncs), please document that. If the goal is GPU-only model latency, sync explicitly:

🐛 Proposed fix (sync-based timing)
-        _llm_t0 = time.perf_counter()
-        hidden_states = self.model(input_ids, positions, forward_batch)
-        _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        _llm_t0 = time.perf_counter()
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0

(or use torch.cuda.Event + elapsed_time to avoid host syncs in steady state)

As per coding guidelines: "Add comments for complex algorithms or non-obvious logic." Either way, please add a one-line comment clarifying which timing semantics this field carries.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3.py` around lines 296 - 304, The timing around
self.model(...) currently measures CUDA launch overhead not actual GPU
execution; update the timing in the qwen3.py block that sets
forward_batch.llm_prefill_ms and forward_batch.llm_decode_ms to either (a) call
torch.cuda.synchronize() immediately after self.model(...) before computing
_llm_ms so the measured delta reflects real GPU work, or (b) use
torch.cuda.Event start/stop + elapsed_time for GPU timings; then set
forward_batch.llm_prefill_ms/llm_decode_ms accordingly and add a one-line
comment beside this logic (referencing forward_batch.forward_mode.is_extend(),
forward_batch.llm_prefill_ms, forward_batch.llm_decode_ms, and self.model)
documenting whether the field represents engine-perceived latency (asynchronous
launch time) or actual GPU execution time (synchronized or event-based).

Comment on lines +374 to +386
# If AutoProcessor produced multimodal input_ids, they must override
# the plain tokenizer result. Otherwise the prompt contains only a
# single image placeholder token and won't match the visual features.
if mm_inputs is not None:
image_inputs = mm_inputs.get("image_inputs")
if image_inputs is not None and "input_ids" in image_inputs:
proc_input_ids = image_inputs["input_ids"]
if hasattr(proc_input_ids, "ndim") and proc_input_ids.ndim > 1:
proc_input_ids = proc_input_ids[0]
if hasattr(proc_input_ids, "tolist"):
input_ids = proc_input_ids.tolist()
else:
input_ids = list(proc_input_ids)

⚠️ Potential issue | 🟡 Minor

Multimodal input_ids bypass context_length truncation.

The text-only path applies truncation=True, max_length=self._context_length on line 361-364, but when mm_inputs["image_inputs"]["input_ids"] is present it unconditionally overrides input_ids without re-checking the context cap. For models that expand image placeholders into many vision tokens (Qwen3-VL), this can silently produce sequences exceeding the model's maximum position embeddings and trigger downstream errors or OOM.

🛡️ Suggested guard
                 if hasattr(proc_input_ids, "tolist"):
                     input_ids = proc_input_ids.tolist()
                 else:
                     input_ids = list(proc_input_ids)
+                if (
+                    self._context_length is not None
+                    and len(input_ids) > self._context_length
+                ):
+                    logger.warning(
+                        "Multimodal input_ids (%d) exceed context_length=%d; truncating.",
+                        len(input_ids),
+                        self._context_length,
+                    )
+                    input_ids = input_ids[: self._context_length]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/orchestrator/tokenizer_process.py` around lines 374 - 386, When
mm_inputs provides multimodal input_ids (image_inputs["input_ids"]) we currently
override input_ids without enforcing the model context cap; after extracting
proc_input_ids in the block handling mm_inputs/image_inputs, truncate
proc_input_ids to self._context_length before converting to a Python list or
assigning to input_ids (e.g., slice ndarray or list to [:self._context_length]
or use equivalent method for objects exposing tolist), so input_ids never exceed
the model's context length; reference symbols: mm_inputs, image_inputs,
proc_input_ids, input_ids, self._context_length.
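The guard reduces to a length clamp applied after the processor override. A sketch under the suggested-diff semantics (`cap_input_ids` is a hypothetical helper, and treating `None` as "no cap" is an assumption, not confirmed server behavior):

```python
def cap_input_ids(input_ids, context_length):
    """Clamp token ids to the model context window; None means uncapped."""
    if context_length is not None and len(input_ids) > context_length:
        # Processor-expanded vision tokens can exceed the tokenizer-side cap,
        # so re-check after the multimodal override, not only at encode time.
        return list(input_ids[:context_length])
    return list(input_ids)
```

Logging a warning on truncation, as in the suggested guard, keeps silently dropped vision tokens visible to operators.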

Comment on lines +59 to +82
assert x.is_contiguous(), "Input must be contiguous"

M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(
x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype
)

BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)

_per_token_quant_int8[(M,)](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

Does Triton 3.6.0 support tl.extra.cuda.libdevice.round on Jetson Orin (SM87) with CUDA 12.6?

💡 Result:

Yes, Triton 3.6.0 supports tl.extra.cuda.libdevice.round on Jetson Orin (SM_87) with CUDA 12.6.



🏁 Script executed:

# First, locate and examine the file
find . -name "int8_activation_triton.py" -type f

Repository: UbiquitousLearning/mllm

Length of output: 123


🏁 Script executed:

# Get the file structure to understand context
wc -l pymllm/quantization/kernels/int8_activation_triton.py

Repository: UbiquitousLearning/mllm

Length of output: 124


🏁 Script executed:

# Read the specific lines mentioned in the review (59-82) with context
sed -n '50,90p' pymllm/quantization/kernels/int8_activation_triton.py

Repository: UbiquitousLearning/mllm

Length of output: 971


🏁 Script executed:

# Also check the docstring mentioned (Line 52)
sed -n '45,58p' pymllm/quantization/kernels/int8_activation_triton.py

Repository: UbiquitousLearning/mllm

Length of output: 539


🏁 Script executed:

# Search for the production caller mentioned: CompressedTensorsW8A8Int8Scheme.apply
rg "CompressedTensorsW8A8Int8Scheme|apply_int8" --type py -A 10 -B 2

Repository: UbiquitousLearning/mllm

Length of output: 2123


🏁 Script executed:

# Check for any direct imports or calls to _per_token_quant_int8
rg "_per_token_quant_int8" --type py -B 5 -A 5

Repository: UbiquitousLearning/mllm

Length of output: 6167


🏁 Script executed:

# Search for all calls to per_token_quant_int8 function
rg "per_token_quant_int8" --type py

Repository: UbiquitousLearning/mllm

Length of output: 1685


🏁 Script executed:

# Check bench_w4a16_vs_w8a8.py for how per_token_quant_int8 is called
grep -A 20 -B 5 "per_token_quant_int8" mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py

Repository: UbiquitousLearning/mllm

Length of output: 1859


🏁 Script executed:

# Check bench_w8a8_activation_quant.py for how it's called
grep -A 30 "triton_fn = _try_load_triton_kernel" pymllm/tests/bench_w8a8_activation_quant.py

Repository: UbiquitousLearning/mllm

Length of output: 986


🏁 Script executed:

# Check what test shapes are used in bench_w8a8_activation_quant.py
sed -n '40,120p' pymllm/tests/bench_w8a8_activation_quant.py

Repository: UbiquitousLearning/mllm

Length of output: 1779


x.stride(-2) will raise IndexError on 1-D inputs despite the docstring claiming "any shape".

The docstring (Line 52) states "Input tensor, any shape with last dim = hidden_dim", but for x.ndim == 1, x.stride(-2) raises IndexError: Dimension out of range. Currently dormant—the sole production caller (compressed_tensors.py line ~156) reshapes to 2-D before calling—but this is an undocumented constraint that becomes a footgun for direct callers or benchmarks.

Also, assert x.is_contiguous() is stripped under python -O; use an explicit ValueError check instead.

Suggested fix
-    assert x.is_contiguous(), "Input must be contiguous"
-
-    M = x.numel() // x.shape[-1]
+    if not x.is_contiguous():
+        raise ValueError("Input must be contiguous")
+    if x.ndim < 2:
+        x = x.view(1, -1)
+
     N = x.shape[-1]
+    M = x.numel() // N
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
     scales = torch.empty(
-        x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype
+        (*x.shape[:-1], 1), device=x.device, dtype=scale_dtype
     )

The shape rewrite also picks up Ruff RUF005.

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 65-65: Consider (*x.shape[:-1], 1) instead of concatenation

Replace with (*x.shape[:-1], 1)

(RUF005)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/quantization/kernels/int8_activation_triton.py` around lines 59 - 82,
Replace the brittle assert and the risky negative 2-index stride access: in the
function that calls _per_token_quant_int8 (the block around the
_per_token_quant_int8[(M,)](...) call), replace "assert x.is_contiguous()" with
an explicit check that raises ValueError if not contiguous, and compute stride_x
and stride_xq defensively (use x.stride(-2) / x_q.stride(-2) when x.ndim >= 2
else fall back to x.stride(0) / x_q.stride(0)); pass those computed stride_x and
stride_xq into the _per_token_quant_int8 call. This keeps the documented "any
shape" behavior and avoids IndexError on 1-D inputs and also fixes the assert
being stripped under python -O.
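For readers unfamiliar with the kernel's contract, per-token INT8 quantization reduces to a per-row absmax scale. A dependency-free reference of the math only (a sketch; the real Triton kernel additionally handles strides, rounding mode, and dtype conversion, and `per_token_quant_int8_ref` is a hypothetical name):

```python
def per_token_quant_int8_ref(x):
    """x: list of rows of floats. Returns (quantized int rows, per-row scales)
    where scale = absmax / 127 so that q * scale approximately recovers x."""
    x_q, scales = [], []
    for row in x:
        absmax = max((abs(v) for v in row), default=0.0)
        scale = absmax / 127.0 if absmax > 0 else 1.0  # avoid div-by-zero rows
        x_q.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return x_q, scales
```

This is the shape contract the wrapper encodes: one scale per token (row), hence the `x.shape[:-1] + (1,)` scales tensor.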

Comment thread pymllm/README-ZH.md
Comment on lines +199 to +201
{"type": "text", "text": "请详细描述这张图片。"},
{"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}},
],

⚠️ Potential issue | 🟡 Minor

Hard-coded developer path leaked into example.

/workspace/xcd_mllm/test.png is a developer-specific absolute path that won't exist for anyone else. Replace it with a placeholder so users don't blindly copy it and hit a confusing file-not-found error:

📝 Suggested fix
                 {"type": "text", "text": "请详细描述这张图片。"},
-                {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}},
+                {"type": "image_url", "image_url": {"url": "/path/to/your/image.png"}},
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/README-ZH.md` around lines 199 - 201, The example in the README
currently hard-codes a developer-specific absolute path
("/workspace/xcd_mllm/test.png") inside the JSON snippet for the "image_url"
entry; replace that literal with a neutral placeholder (e.g.,
"<path/to/image.png>" or a sample relative/remote URL) so users won’t copy a
non-existent path—update the "image_url" -> "url" value in the JSON array shown
in the README-ZH.md accordingly.

Comment thread pymllm/README.md
Comment on lines +50 to +60
- JetPack / L4T: `R36.4.4` (`/etc/nv_tegra_release`)
- Python: `3.10.12`
- PyTorch: `2.4.0`
- torchvision: `0.19.0a0+48b1edf`
- transformers: `5.3.0`
- safetensors: `0.7.0`
- flashinfer: `0.6.7`
- Triton Language: official PyPI `triton==3.6.0` manylinux aarch64 wheel
- CUDA: `12.6`
- GPU: Jetson Orin NX, SM87


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify that the pinned versions are real.
for pkg in transformers safetensors flashinfer; do
  echo "=== $pkg ==="
  curl -s "https://pypi.org/pypi/${pkg}/json" | jq -r '.releases | keys | .[-5:]'
done

Repository: UbiquitousLearning/mllm

Length of output: 330


🏁 Script executed:

#!/bin/bash
# More targeted verification of specific versions

echo "=== Checking transformers 5.3.0 ==="
curl -s "https://pypi.org/pypi/transformers/5.3.0/json" | jq -r '.info.version // "NOT FOUND"'

echo "=== Checking safetensors 0.7.0 ==="
curl -s "https://pypi.org/pypi/safetensors/0.7.0/json" | jq -r '.info.version // "NOT FOUND"'

echo "=== Checking flashinfer 0.6.7 ==="
curl -s "https://pypi.org/pypi/flashinfer/0.6.7/json" | jq -r '.info.version // "NOT FOUND"'

echo ""
echo "=== Alternative: Check flashinfer existence (any version) ==="
curl -s "https://pypi.org/pypi/flashinfer/json" | jq -r '.info.name // "PACKAGE NOT FOUND"'

Repository: UbiquitousLearning/mllm

Length of output: 275


Fix flashinfer dependency reference — version and package do not exist on PyPI.

Verification found that transformers 5.3.0 and safetensors 0.7.0 exist on PyPI and are correct. However, flashinfer 0.6.7 cannot be obtained: neither this version nor any flashinfer package is available on PyPI. The validated environment table lists this non-existent dependency, making the documented setup irreproducible. Either remove flashinfer from the table, clarify if it should come from a different source (e.g., custom build, private repository), or provide correct installation instructions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/README.md` around lines 50 - 60, The README's environment table lists
a non-existent PyPI package/version "flashinfer: `0.6.7`" which makes the setup
unreproducible; update the README.md entry referencing "flashinfer 0.6.7" by
either removing the flashinfer line, replacing it with the correct source (e.g.,
a Git URL, private repo, or build instructions), or adding explicit install
instructions and provenance (e.g., how to build/install from source or a link to
the correct release). Edit the list block in README.md where "flashinfer:
`0.6.7`" appears so the table accurately reflects how to obtain or install
flashinfer, and include the concrete install command or URL if you choose to
keep the dependency.

@coderabbitai coderabbitai Bot left a comment

♻️ Duplicate comments (1)
pymllm/layers/rms_norm.py (1)

52-54: ⚠️ Potential issue | 🟠 Major

Avoid bare Exception in FlashInfer fallbacks; narrow and log instead.

Line 52 and Line 65 currently swallow all exceptions and silently fallback. In this core serving path, that can hide real defects and makes failures hard to diagnose. Catch expected runtime failures and emit a warning with context.

Proposed patch
 from typing import Optional, Tuple, Union
+import logging
 
 import torch
 import flashinfer
 from torch.nn import Parameter
@@
 from pymllm.layers.base import MllmBaseLayer
 from pymllm.layers.utils import set_weight_attrs
 
+logger = logging.getLogger(__name__)
+
@@
-            except Exception:
+            except (RuntimeError, ValueError) as e:
+                logger.warning(
+                    "flashinfer fused_add_rmsnorm failed; using torch fallback: %s", e
+                )
                 residual = x + residual
                 return _torch_rmsnorm(residual, self.weight, self.eps), residual
@@
-        except Exception:
+        except (RuntimeError, ValueError) as e:
+            logger.warning("flashinfer rmsnorm failed; using torch fallback: %s", e)
             return _torch_rmsnorm(x, self.weight, self.eps)

As per coding guidelines, "Ensure functions that can fail return appropriate error codes or raise exceptions" and "Add appropriate logging (e.g., debug, info, warning, error) for significant events and errors, avoiding sensitive data exposure".

For the FlashInfer version used by this repository, what exception classes are documented or commonly raised by `flashinfer.norm.fused_add_rmsnorm` and `flashinfer.norm.rmsnorm` on unsupported shapes/dtypes/runtime failures?

Also applies to: 65-66

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/layers/rms_norm.py` around lines 52 - 54, Replace the bare "except
Exception:" handlers around the FlashInfer fallbacks (the blocks calling
flashinfer.norm.fused_add_rmsnorm and flashinfer.norm.rmsnorm and returning
_torch_rmsnorm(...)) with a narrow catch for expected runtime errors (e.g.,
except (RuntimeError, ValueError, TypeError) as e:), log a warning that includes
function/context, tensor shapes/dtypes and the exception message, then perform
the fallback; re-raise any other unexpected exceptions (use "except Exception:
raise" or avoid catching them). Update both locations that currently swallow all
exceptions so only anticipated FlashInfer failures are logged and handled while
other errors bubble up.
🧹 Nitpick comments (2)
pymllm/bench_one_batch.py (2)

398-544: Document the runner lifecycle and failure contract.

PymllmBenchRunner allocates request/KV slots directly and depends on a specific create() → clear()/extend()/decode() → shutdown() lifecycle. A class docstring plus method docs covering preconditions, ownership, and errors would make this much safer to reuse.

As per coding guidelines, ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/bench_one_batch.py` around lines 398 - 544, Add a clear class-level
docstring to PymllmBenchRunner describing the required lifecycle (call create()
to get an instance, then use clear()/extend()/decode() in that order as needed,
and finally call shutdown()), the ownership semantics of resources (ModelRunner
and its req_to_token_pool and token_to_kv_pool_allocator), and the failure
contract (which methods raise RuntimeError or ValueError and in what
situations). Also add brief docstrings to create(), clear(), extend(), decode(),
and shutdown() noting their parameters, return types (e.g., extend()/decode()
return (next_token_ids, DecodeState)), preconditions (e.g., _require_initialized
or expected tensor shapes), and side effects (alloc/free of request/KV slots and
that extend() will free allocated req slots on allocation failure), referencing
the methods by name so callers can find the behavior easily.
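A minimal shape for that documentation might look like the skeleton below. This is an illustrative sketch only — `BenchRunnerSketch` and its placeholder bodies are hypothetical and merely mirror the lifecycle contract described above, not the actual `PymllmBenchRunner` implementation.

```python
class BenchRunnerSketch:
    """Illustrative lifecycle skeleton (not the real PymllmBenchRunner).

    Lifecycle: create() -> clear()/extend()/decode() -> shutdown().
    The instance owns its request/KV slots; extend()/decode() raise
    RuntimeError if called before create() or after shutdown().
    """

    def __init__(self):
        self._initialized = False

    @classmethod
    def create(cls):
        """Build and initialize a runner; the caller must later call shutdown()."""
        runner = cls()
        runner._initialized = True
        return runner

    def _require_initialized(self):
        if not self._initialized:
            raise RuntimeError("runner not initialized; call create() first")

    def decode(self, token_ids):
        """Run one decode step; raises RuntimeError when uninitialized."""
        self._require_initialized()
        return token_ids  # placeholder for (next_token_ids, DecodeState)

    def shutdown(self):
        """Release request/KV slots; the runner is unusable afterwards."""
        self._initialized = False
```

Documenting the failure contract this way lets callers handle `RuntimeError` at the call site instead of discovering the precondition by reading the implementation.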

43-43: Avoid hardcoded /tmp defaults for benchmark artifacts.

Both defaults pin the script to a POSIX temp path and keep outputs in a predictable shared location. Prefer a platform-derived temp/cache directory and per-run names by default so this stays portable and avoids accidental collisions.

As per coding guidelines, ensure code is portable across supported platforms (e.g., Linux, Windows) unless explicitly platform-specific, and identify potential security issues (e.g., insecure temporary files) and recommend using secure alternatives.

Also applies to: 378-378

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/bench_one_batch.py` at line 43, The hardcoded POSIX temp path assigned
to result_filename should be replaced with a platform-safe, unique temp file
path: import tempfile (and uuid.uuid4 or use
tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl")) and set
result_filename to Path(tempfile.gettempdir()) /
f"pymllm_bench_one_batch_{uuid4()}.jsonl" or create a NamedTemporaryFile and use
its name, ensuring you add the necessary imports and use uuid4() or tempfile to
avoid collisions and insecure predictable names; apply the same change to the
other occurrence referenced (result filename at the later location).
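A portable default along the lines the prompt describes can be sketched as follows (the helper name `default_result_path` is hypothetical; the actual fix would change the argparse default in `bench_one_batch.py`):

```python
import tempfile
import uuid
from pathlib import Path


def default_result_path(prefix: str = "pymllm_bench_one_batch") -> Path:
    # Hypothetical helper: derive the temp directory from the platform
    # (works on Linux and Windows) and embed a random suffix so concurrent
    # benchmark runs never collide on a shared, predictable path.
    return Path(tempfile.gettempdir()) / f"{prefix}_{uuid.uuid4().hex}.jsonl"
```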
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@pymllm/layers/rms_norm.py`:
- Around line 52-54: Replace the bare "except Exception:" handlers around the
FlashInfer fallbacks (the blocks calling flashinfer.norm.fused_add_rmsnorm and
flashinfer.norm.rmsnorm and returning _torch_rmsnorm(...)) with a narrow catch
for expected runtime errors (e.g., except (RuntimeError, ValueError, TypeError)
as e:), log a warning that includes function/context, tensor shapes/dtypes and
the exception message, then perform the fallback; re-raise any other unexpected
exceptions (use "except Exception: raise" or avoid catching them). Update both
locations that currently swallow all exceptions so only anticipated FlashInfer
failures are logged and handled while other errors bubble up.

---

Nitpick comments:
In `@pymllm/bench_one_batch.py`:
- Around line 398-544: Add a clear class-level docstring to PymllmBenchRunner
describing the required lifecycle (call create() to get an instance, then use
clear()/extend()/decode() in that order as needed, and finally call shutdown()),
the ownership semantics of resources (ModelRunner and its req_to_token_pool and
token_to_kv_pool_allocator), and the failure contract (which methods raise
RuntimeError or ValueError and in what situations). Also add brief docstrings to
create(), clear(), extend(), decode(), and shutdown() noting their parameters,
return types (e.g., extend()/decode() return (next_token_ids, DecodeState)),
preconditions (e.g., _require_initialized or expected tensor shapes), and side
effects (alloc/free of request/KV slots and that extend() will free allocated
req slots on allocation failure), referencing the methods by name so callers can
find the behavior easily.
- Line 43: The hardcoded POSIX temp path assigned to result_filename should be
replaced with a platform-safe, unique temp file path: import tempfile (and
uuid.uuid4 or use tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl"))
and set result_filename to Path(tempfile.gettempdir()) /
f"pymllm_bench_one_batch_{uuid4()}.jsonl" or create a NamedTemporaryFile and use
its name, ensuring you add the necessary imports and use uuid4() or tempfile to
avoid collisions and insecure predictable names; apply the same change to the
other occurrence referenced (result filename at the later location).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 245f8fba-3829-44a3-8397-37411ca4a31c

📥 Commits

Reviewing files that changed from the base of the PR and between efb65d0 and c2870b5.

📒 Files selected for processing (4)
  • pymllm/bench_one_batch.py
  • pymllm/layers/rms_norm.py
  • pymllm/tests/test_bench_one_batch.py
  • pymllm/tests/test_rms_norm.py


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
pymllm/models/qwen3_vl.py (1)

1176-1190: ⚠️ Potential issue | 🟠 Major

These timing fields are batch totals, but the scheduler stores them per request.

The values written onto forward_batch here are single scalars for the whole batch. Downstream, pymllm/orchestrator/scheduler_process.py lines 792-806 copy them onto each Request, so multi-request batches will stamp the same aggregate vit_prefill_ms, vit_prefill_tokens, llm_prefill_ms, and llm_decode_ms onto every request.

Also applies to: 1224-1232

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3_vl.py` around lines 1176 - 1190, The batch-level
timing/count fields (vit_prefill_ms, vit_prefill_tokens, llm_prefill_ms,
llm_decode_ms) written into forward_batch must be per-request, not single
scalars; change the logic around forward_batch assignment so it stores
per-request arrays. Specifically, compute per-sample image token counts from
image_mask (use image_mask.sum(dim=1) or equivalent) and set vit_prefill_tokens
as a list/array per request; split vit_prefill_ms (and llm_prefill_ms /
llm_decode_ms at the other region noted around 1224-1232) across requests
proportionally to each request's token count (or compute timings per-request
earlier before batching) and write per-request values into forward_batch so
scheduler_process.py will copy correct per-request metrics. Ensure references:
forward_batch, vit_prefill_ms, vit_prefill_tokens, llm_prefill_ms,
llm_decode_ms, image_mask, vision_embeds, input_embeds.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu`:
- Around line 328-343: Add strict same-device CUDA checks for all inputs before
extracting raw pointers and launching the kernel: verify scales_a, scales_b, and
bias (if provided) are CUDA tensors, have correct dtypes/numel as already
checked, and are on the same device as mat_a (use mat_a.device() /
mat_b.device() for comparison); also assert mat_b.device() == mat_a.device().
Update the validation block that currently checks mat_a and mat_b to include
these additional TORCH_CHECKs so the kernel never receives CPU or cross-GPU
pointers when launching on mat_a's stream.

In `@mllm-kernel/tests/test_int8_scaled_mm_cutlass.py`:
- Around line 39-45: The fixture cutlass_module should skip tests on GPUs below
SM80; before importing mllm_kernel.cuda.jit.int8_scaled_mm_cutlass, query
torch.cuda.get_device_properties(0).major and .minor, compute compute_capability
= major*10 + minor and if compute_capability < 80 call pytest.skip("SM80+ GPU
required") so the runtime check in _current_cuda_arch() doesn't raise
RuntimeError during import; add this check right after the existing CUDA
availability check in cutlass_module.

---

Outside diff comments:
In `@pymllm/models/qwen3_vl.py`:
- Around line 1176-1190: The batch-level timing/count fields (vit_prefill_ms,
vit_prefill_tokens, llm_prefill_ms, llm_decode_ms) written into forward_batch
must be per-request, not single scalars; change the logic around forward_batch
assignment so it stores per-request arrays. Specifically, compute per-sample
image token counts from image_mask (use image_mask.sum(dim=1) or equivalent) and
set vit_prefill_tokens as a list/array per request; split vit_prefill_ms (and
llm_prefill_ms / llm_decode_ms at the other region noted around 1224-1232)
across requests proportionally to each request's token count (or compute timings
per-request earlier before batching) and write per-request values into
forward_batch so scheduler_process.py will copy correct per-request metrics.
Ensure references: forward_batch, vit_prefill_ms, vit_prefill_tokens,
llm_prefill_ms, llm_decode_ms, image_mask, vision_embeds, input_embeds.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 056f15e3-9555-41ba-8a82-9d1ef20af346

📥 Commits

Reviewing files that changed from the base of the PR and between c2870b5 and 4f13ff1.

📒 Files selected for processing (5)
  • mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu
  • mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py
  • mllm-kernel/tests/test_int8_scaled_mm_cutlass.py
  • pymllm/models/qwen3_vl.py
  • pymllm/tests/test_qwen3_vl_deepstack.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py

Comment on lines +328 to +343
```cpp
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be 2D");
TORCH_CHECK(mat_b.dim() == 2, "mat_b must be 2D");
TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be row-major");
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be column-major");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "shape mismatch");
TORCH_CHECK(mat_a.size(1) % 16 == 0, "K must be multiple of 16");
TORCH_CHECK(mat_b.size(1) % 8 == 0, "N must be multiple of 8");
TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8");
TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8");
TORCH_CHECK(scales_a.numel() == mat_a.size(0), "scales_a size mismatch");
TORCH_CHECK(scales_b.numel() == mat_b.size(1), "scales_b size mismatch");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be fp32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be fp32");
```


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

fd -t f "int8_scaled_mm_cutlass.cu" --exec cat -n {} \; | head -400

Repository: UbiquitousLearning/mllm

Length of output: 19121


Add strict same-device CUDA checks for all kernel inputs.

The current validation at lines 328-342 checks only mat_a/mat_b CUDA residency. scales_a, scales_b, bias, and cross-GPU device equality are not enforced. Since raw pointers are extracted at lines 111-115 and the kernel launches on mat_a's stream (line 141), mixed-device or CPU tensors will cause the kernel to receive invalid pointers, resulting in segfaults or silent corruption.

🔧 Proposed fix

```diff
 torch::Tensor int8_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
@@
   TORCH_CHECK(mat_a.is_cuda(), "mat_a must be CUDA tensor");
   TORCH_CHECK(mat_b.is_cuda(), "mat_b must be CUDA tensor");
+  TORCH_CHECK(scales_a.is_cuda(), "scales_a must be CUDA tensor");
+  TORCH_CHECK(scales_b.is_cuda(), "scales_b must be CUDA tensor");
+  TORCH_CHECK(mat_a.get_device() == mat_b.get_device(),
+              "mat_a and mat_b must be on the same CUDA device");
+  TORCH_CHECK(mat_a.get_device() == scales_a.get_device(),
+              "scales_a must be on the same CUDA device as mat_a");
+  TORCH_CHECK(mat_a.get_device() == scales_b.get_device(),
+              "scales_b must be on the same CUDA device as mat_a");
@@
   if (bias) {
+    TORCH_CHECK(bias->is_cuda(), "bias must be CUDA tensor");
+    TORCH_CHECK(bias->get_device() == mat_a.get_device(),
+                "bias must be on the same CUDA device as mat_a");
     TORCH_CHECK(bias->numel() == mat_b.size(1), "bias size mismatch");
     TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match out_dtype");
   }
```

Also applies to: 353-356

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu` around
lines 328 - 343, Add strict same-device CUDA checks for all inputs before
extracting raw pointers and launching the kernel: verify scales_a, scales_b, and
bias (if provided) are CUDA tensors, have correct dtypes/numel as already
checked, and are on the same device as mat_a (use mat_a.device() /
mat_b.device() for comparison); also assert mat_b.device() == mat_a.device().
Update the validation block that currently checks mat_a and mat_b to include
these additional TORCH_CHECKs so the kernel never receives CPU or cross-GPU
pointers when launching on mat_a's stream.

Comment on lines +39 to +45
```python
def cutlass_module():
    """Load CUTLASS module once for all tests."""
    pytest.importorskip("torch")
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")
    from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm
    return int8_scaled_mm
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find the CUTLASS kernel implementation
find . -name "*int8_scaled_mm_cutlass*" -type f

Repository: UbiquitousLearning/mllm

Length of output: 251


🏁 Script executed:

# Search for SM version constraints in the codebase
rg "SM80|SM89|compute_capability|major.*8" --type py --type cpp --type cu -A 2 -B 2

Repository: UbiquitousLearning/mllm

Length of output: 94


🏁 Script executed:

# Examine the test file to see the current fixture implementation
cat -n mllm-kernel/tests/test_int8_scaled_mm_cutlass.py

Repository: UbiquitousLearning/mllm

Length of output: 5833


🏁 Script executed:

# Check the kernel CUDA source for SM requirements
head -100 ./mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu

Repository: UbiquitousLearning/mllm

Length of output: 4167


🏁 Script executed:

# Check JIT wrapper for SM version enforcement
cat -n ./mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py

Repository: UbiquitousLearning/mllm

Length of output: 5092


Add SM80+ GPU check to fixture to skip on unsupported hardware.

The kernel enforces SM80-SM89 at runtime via _current_cuda_arch() and raises RuntimeError on incompatible GPUs. The fixture only checks CUDA availability, causing tests to fail instead of skip on pre-SM80 GPUs. Add a device capability check in the fixture before importing the kernel module.

Proposed fix

```diff
 @pytest.fixture(scope="module")
 def cutlass_module():
     """Load CUTLASS module once for all tests."""
     pytest.importorskip("torch")
     if not torch.cuda.is_available():
         pytest.skip("CUDA required")
+    major, _ = torch.cuda.get_device_capability()
+    if major < 8:
+        pytest.skip("CUTLASS int8_scaled_mm requires SM80+ GPUs")
     from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm
     return int8_scaled_mm
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@mllm-kernel/tests/test_int8_scaled_mm_cutlass.py` around lines 39 - 45, The
fixture cutlass_module should skip tests on GPUs below SM80; before importing
mllm_kernel.cuda.jit.int8_scaled_mm_cutlass, query
torch.cuda.get_device_properties(0).major and .minor, compute compute_capability
= major*10 + minor and if compute_capability < 80 call pytest.skip("SM80+ GPU
required") so the runtime check in _current_cuda_arch() doesn't raise
RuntimeError during import; add this check right after the existing CUDA
availability check in cutlass_module.

Comment thread pymllm/models/qwen3_vl.py
Comment on lines +1172 to +1176
```python
_vit_t0 = time.perf_counter()
vision_features = (
    self.visual(pixel_values, grid_thw=image_grid_thw)
)
vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and read the target lines
wc -l pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 98


🏁 Script executed:

# Read lines around 1172-1176
sed -n '1165,1180p' pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 715


🏁 Script executed:

# Read lines around 1212-1222
sed -n '1205,1230p' pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 982


🏁 Script executed:

# Search for how timing fields are used downstream
rg "vit_prefill_ms|llm_.*_ms" --type py -B 2 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 17766


🏁 Script executed:

# Search for cuda.synchronize usage
rg "cuda\.synchronize|torch\.cuda\.synchronize" --type py -B 2 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 6006


🏁 Script executed:

# Check if device detection exists
rg "device\.type|\.cuda\(\)|to\(.*cuda" pymllm/models/qwen3_vl.py | head -20

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Check ForwardBatch definition
rg "class ForwardBatch" --type py -A 20

Repository: UbiquitousLearning/mllm

Length of output: 1524


🏁 Script executed:

# Look at model runner forward to understand context
sed -n '1,50p' pymllm/orchestrator/model_runner_process.py

Repository: UbiquitousLearning/mllm

Length of output: 1820


🏁 Script executed:

# Check the comment context in model_runner_process around CUDA timing
sed -n '90,140p' pymllm/orchestrator/model_runner_process.py

Repository: UbiquitousLearning/mllm

Length of output: 2210


🏁 Script executed:

# Look at how ForwardBatch timing fields are extracted and used
sed -n '95,180p' pymllm/orchestrator/model_runner_process.py

Repository: UbiquitousLearning/mllm

Length of output: 3619


🏁 Script executed:

# Verify if timing should be per-batch or per-request
rg "forward_batch\.llm_prefill_ms|forward_batch\.vit_prefill_ms" --type py -B 3 -A 3

Repository: UbiquitousLearning/mllm

Length of output: 1106


🏁 Script executed:

# Check the exact code where batch timing is copied to per-request
sed -n '280,295p' pymllm/orchestrator/scheduler_process.py

Repository: UbiquitousLearning/mllm

Length of output: 670


🏁 Script executed:

# Verify if requests are batched together
rg "batch\.reqs|len.*batch.*reqs" pymllm/orchestrator/scheduler_process.py -B 2 -A 2 | head -40

Repository: UbiquitousLearning/mllm

Length of output: 1394


🏁 Script executed:

# Check if timing should really be per-request or batch-level by examining how it's used
sed -n '320,340p' pymllm/orchestrator/scheduler_process.py

Repository: UbiquitousLearning/mllm

Length of output: 919


🏁 Script executed:

# Check if device type is available in the forward function context
sed -n '1140,1180p' pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 1772


🏁 Script executed:

# Check if self.visual has device attribute
rg "self\.visual\(" pymllm/models/qwen3_vl.py -B 2 -A 5 | head -30

Repository: UbiquitousLearning/mllm

Length of output: 407


🏁 Script executed:

# Check the full forward function signature
sed -n '1100,1140p' pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 1925


🏁 Script executed:

# Check the entire flow more carefully - see if per-request vs batch timing is documented anywhere
rg "forward_batch\.|ForwardBatch" pymllm/engine/forward_batch.py | head -20

Repository: UbiquitousLearning/mllm

Length of output: 252


🏁 Script executed:

# Examine timing field definitions in ForwardBatch
sed -n '1,150p' pymllm/engine/forward_batch.py | tail -60

Repository: UbiquitousLearning/mllm

Length of output: 2508


🏁 Script executed:

# Find timing field definitions in ForwardBatch
rg "vit_prefill_ms|llm_prefill_ms|llm_decode_ms" pymllm/engine/forward_batch.py -B 2 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Check if ForwardBatch is truly batch-level (one timing value per batch)
sed -n '1,300p' pymllm/engine/forward_batch.py | grep -A 5 "vit_prefill\|llm_prefill\|llm_decode"

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Look at complete ForwardBatch class definition
wc -l pymllm/engine/forward_batch.py
sed -n '80,200p' pymllm/engine/forward_batch.py

Repository: UbiquitousLearning/mllm

Length of output: 4514


🏁 Script executed:

# Check the rest of ForwardBatch for timing fields
sed -n '160,191p' pymllm/engine/forward_batch.py

Repository: UbiquitousLearning/mllm

Length of output: 1442


🏁 Script executed:

# Search for where timing fields are actually added to ForwardBatch
rg "ForwardBatch\(\)" pymllm/models/qwen3_vl.py -B 3 -A 3

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Check if timing fields are added dynamically as attributes
rg "setattr.*forward_batch|forward_batch\.(vit|llm).*=" pymllm/models/qwen3_vl.py -B 1 -A 1

Repository: UbiquitousLearning/mllm

Length of output: 440


🏁 Script executed:

# Check the complete implementation in qwen3_vl.py forward method
sed -n '1220,1235p' pymllm/models/qwen3_vl.py

Repository: UbiquitousLearning/mllm

Length of output: 695


🏁 Script executed:

# Check how many requests typically end up in a batch
rg "batch_size|len.*batch.*reqs" pymllm/orchestrator/scheduler_process.py -B 2 -A 2 | head -50

Repository: UbiquitousLearning/mllm

Length of output: 1056


🏁 Script executed:

# Check if there's any per-request timing accumulation elsewhere
sed -n '440,480p' pymllm/orchestrator/scheduler_process.py

Repository: UbiquitousLearning/mllm

Length of output: 1646


🏁 Script executed:

# Verify: does the model runner set timing per-batch or per-request?
sed -n '150,200p' pymllm/orchestrator/model_runner_process.py

Repository: UbiquitousLearning/mllm

Length of output: 1991


🏁 Script executed:

# Final verification: Check if there are tests that validate timing correctness
rg "vit_prefill_ms|llm.*_ms" pymllm/tests/ --type py -l

Repository: UbiquitousLearning/mllm

Length of output: 109


🏁 Script executed:

# Check test_qwen3_forward_timing.py to see what's actually being tested
cat pymllm/tests/test_qwen3_forward_timing.py

Repository: UbiquitousLearning/mllm

Length of output: 3068


Synchronize CUDA before measuring wall-clock timing around model calls.

time.perf_counter() around self.visual(...) and self.model(...) only measures kernel submission time on CUDA. Without device synchronization, these timers capture only the async launch overhead, not actual compute time. On GPU paths, vit_prefill_ms and llm_*_ms will be systematically underestimated.

Additionally, when a batch contains multiple requests, the batch-level timing value is copied to each request object downstream (in scheduler_process.py), misattributing a single batch duration to all individual requests within that batch.

⏱️ Suggested fix

```diff
+def _sync_if_cuda(device: torch.device) -> None:
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+
 ...
-            _vit_t0 = time.perf_counter()
+            _sync_if_cuda(self.visual.device)
+            _vit_t0 = time.perf_counter()
             vision_features = (
                 self.visual(pixel_values, grid_thw=image_grid_thw)
             )
-            vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
+            _sync_if_cuda(vision_features.device)
+            vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0
 ...
-        _llm_t0 = time.perf_counter()
+        _sync_if_cuda(self.model.embed_tokens.weight.device)
+        _llm_t0 = time.perf_counter()
         hidden_states = (
             self.model(
                 input_ids,
@@
-        _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0
+        _sync_if_cuda(hidden_states.device)
+        _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0
```

Also applies to: 1212-1222


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/README-ZH.md`:
- Around line 133-143: The BF16 example in the "BF16 原生模型" section is
inconsistent: change the launch command flag value from "--server.dtype float16"
to "--server.dtype bf16" (or "bfloat16" if that is the CLI's accepted token) so
the --server.dtype setting matches the BF16 section title; update the example
invocation in the README snippet (the line containing "--server.dtype float16")
to use the bf16 token used by the server CLI.

In `@pymllm/server/launch.py`:
- Around line 793-798: The code calls _maybe_add_debug_timing(payload,
result=results[-1], prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens) which mixes timing from only the last
result with token counts aggregated across all results; instead compute an
aggregated timing dict by iterating results and summing numeric timing fields
from each result.get("timing", {}) (or merge/accumulate per-field totals) and
pass that aggregated timing dict as result, or alternatively call
_maybe_add_debug_timing once per individual result with that result's own
prompt/completion token counts; update the call site in launch.py to replace
result=results[-1] with the aggregated timing dict (or loop per-result) and keep
prompt_tokens/completion_tokens consistent with the chosen aggregation approach.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ec5a8fe4-ef55-440c-9d17-88355b4ab1fe

📥 Commits

Reviewing files that changed from the base of the PR and between 4f13ff1 and 6e6b900.

⛔ Files ignored due to path filters (1)
  • docs/_static/img/pymllm-arch.png is excluded by !**/*.png
📒 Files selected for processing (12)
  • docs/index.rst
  • docs/pymllm_runtime/developer_guide.rst
  • docs/pymllm_runtime/index.rst
  • docs/pymllm_runtime/kernels_and_acceleration.rst
  • docs/pymllm_runtime/models_and_quantization.rst
  • docs/pymllm_runtime/runtime_design.rst
  • docs/pymllm_runtime/setup_and_usage.rst
  • pymllm/README-ZH.md
  • pymllm/README.md
  • pymllm/configs/server_config.py
  • pymllm/server/launch.py
  • pymllm/tests/test_server_debug_timing.py
✅ Files skipped from review due to trivial changes (9)
  • docs/index.rst
  • docs/pymllm_runtime/index.rst
  • docs/pymllm_runtime/developer_guide.rst
  • docs/pymllm_runtime/models_and_quantization.rst
  • docs/pymllm_runtime/setup_and_usage.rst
  • docs/pymllm_runtime/kernels_and_acceleration.rst
  • pymllm/configs/server_config.py
  • pymllm/README.md
  • docs/pymllm_runtime/runtime_design.rst

Comment thread pymllm/README-ZH.md
Comment on lines +133 to +143
### BF16 原生模型

```bash
cd <repo-root>

python3 -m pymllm.server.launch \
--server.model_path <model-path> \
--server.tokenizer_path <model-path> \
--server.load_format safetensors \
--server.dtype float16 \
--server.host 0.0.0.0 \
```

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

BF16 section uses float16 in the launch example.

The section title says BF16 native serving, but the command at Line 142 sets --server.dtype float16. Please align the flag with the section intent to avoid user confusion.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/README-ZH.md` around lines 133 - 143, The BF16 example in the "BF16
原生模型" section is inconsistent: change the launch command flag value from
"--server.dtype float16" to "--server.dtype bf16" (or "bfloat16" if that is the
CLI's accepted token) so the --server.dtype setting matches the BF16 section
title; update the example invocation in the README snippet (the line containing
"--server.dtype float16") to use the bf16 token used by the server CLI.

Comment thread pymllm/server/launch.py
Comment on lines +793 to 798
```python
_maybe_add_debug_timing(
    payload,
    result=results[-1] if results else {},
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
```

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Avoid mixing aggregated token counts with single-result timing.

At Line 795, timing is taken from results[-1], but prompt_tokens/completion_tokens are aggregated across all results. For list prompts, this produces misleading TPS.

Suggested fix

```diff
-        _maybe_add_debug_timing(
-            payload,
-            result=results[-1] if results else {},
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-        )
+        # debug_timing currently represents a single request/result.
+        # Avoid publishing mixed-scope TPS for batched prompt lists.
+        if len(results) == 1:
+            r0 = results[0]
+            _maybe_add_debug_timing(
+                payload,
+                result=r0,
+                prompt_tokens=r0.get("prompt_tokens", 0),
+                completion_tokens=r0.get("completion_tokens", 0),
+            )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/server/launch.py` around lines 793 - 798, the code calls
_maybe_add_debug_timing(payload, result=results[-1],
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) which mixes
timing from only the last result with token counts aggregated across all
results; instead compute an aggregated timing dict by iterating results and
summing numeric timing fields from each result.get("timing", {}) (or
merge/accumulate per-field totals) and pass that aggregated timing dict as
result, or alternatively call _maybe_add_debug_timing once per individual result
with that result's own prompt/completion token counts; update the call site in
launch.py to replace result=results[-1] with the aggregated timing dict (or loop
per-result) and keep prompt_tokens/completion_tokens consistent with the chosen
aggregation approach.
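
As an illustration of the per-field aggregation the prompt above describes, here is a minimal sketch; the `timing` field names (`prefill_ms`, `decode_ms`) are hypothetical placeholders, not the actual pymllm result schema:

```python
def aggregate_timing(results):
    """Sum numeric timing fields across per-prompt results, so that
    aggregated token counts and aggregated timing share the same scope."""
    total = {}
    for result in results:
        for key, value in result.get("timing", {}).items():
            if isinstance(value, (int, float)):
                total[key] = total.get(key, 0) + value
    return total

results = [
    {"timing": {"prefill_ms": 120.0, "decode_ms": 340.0}},
    {"timing": {"prefill_ms": 95.0, "decode_ms": 410.0}},
]
print(aggregate_timing(results))  # {'prefill_ms': 215.0, 'decode_ms': 750.0}
```

The aggregated dict would then be passed as `result` to `_maybe_add_debug_timing`, keeping it consistent with the summed `prompt_tokens`/`completion_tokens`.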

Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (2)
pymllm/models/qwen3.py (2)

65-96: 💤 Low value

Dead code: non-fused path is unreachable.

use_fused_qkv = True is hardcoded, making lines 77-96 unreachable. Either parameterize use_fused_qkv in the constructor or remove the dead code path.

♻️ Option A: Remove dead code
-        self.use_fused_qkv = True
-
-        if self.use_fused_qkv:
-            self.qkv_proj = MergedLinear(
-                hidden_size,
-                [self.q_size, self.kv_size, self.kv_size],
-                bias=attention_bias,
-                quant_method=_get_qm("qkv_proj"),
-            )
-            self.q_proj = None
-            self.k_proj = None
-            self.v_proj = None
-        else:
-            self.qkv_proj = None
-            self.q_proj = Linear(
-                hidden_size,
-                self.q_size,
-                bias=attention_bias,
-                quant_method=_get_qm("q_proj"),
-            )
-            self.k_proj = Linear(
-                hidden_size,
-                self.kv_size,
-                bias=attention_bias,
-                quant_method=_get_qm("k_proj"),
-            )
-            self.v_proj = Linear(
-                hidden_size,
-                self.kv_size,
-                bias=attention_bias,
-                quant_method=_get_qm("v_proj"),
-            )
+        self.qkv_proj = MergedLinear(
+            hidden_size,
+            [self.q_size, self.kv_size, self.kv_size],
+            bias=attention_bias,
+            quant_method=_get_qm("qkv_proj"),
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3.py` around lines 65 - 96, the non-fused attention branch
is dead because use_fused_qkv is hardcoded True; update the class constructor to
accept a parameter (e.g., use_fused_qkv) and set self.use_fused_qkv from that
param (or remove the else block if fused-only is intended), then adjust
initialization of qkv_proj, q_proj, k_proj, and v_proj accordingly (referencing
self.use_fused_qkv, MergedLinear, and Linear to locate the code paths); ensure
default behavior preserves current fused behavior by defaulting the new
parameter to True and update any callers to pass False if they need the
non-fused path.
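
A minimal sketch of the parameterized constructor suggested above, using stand-in layer classes (the real `Linear`/`MergedLinear` live in `pymllm.layers`; the sizes and class bodies here are illustrative only):

```python
class Linear:
    """Stand-in for the project's Linear layer (illustrative only)."""
    def __init__(self, in_features, out_features, bias=False):
        self.in_features, self.out_features, self.bias = in_features, out_features, bias

class MergedLinear(Linear):
    """Stand-in for the fused projection: one weight covering several outputs."""
    def __init__(self, in_features, output_sizes, bias=False):
        super().__init__(in_features, sum(output_sizes), bias)
        self.output_sizes = output_sizes

class Qwen3AttentionSketch:
    def __init__(self, hidden_size, q_size, kv_size, attention_bias=False,
                 use_fused_qkv=True):
        # Defaulting to True preserves the current fused-only behavior;
        # callers that need separate projections can pass False.
        self.use_fused_qkv = use_fused_qkv
        if use_fused_qkv:
            self.qkv_proj = MergedLinear(hidden_size, [q_size, kv_size, kv_size],
                                         bias=attention_bias)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv_proj = None
            self.q_proj = Linear(hidden_size, q_size, bias=attention_bias)
            self.k_proj = Linear(hidden_size, kv_size, bias=attention_bias)
            self.v_proj = Linear(hidden_size, kv_size, bias=attention_bias)

fused = Qwen3AttentionSketch(1024, 2048, 512)
split = Qwen3AttentionSketch(1024, 2048, 512, use_fused_qkv=False)
```

Keeping both paths behind a constructor flag avoids the dead-code branch while still letting tests exercise the non-fused layout.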

349-351: 💤 Low value

Move import to module level.

Placing the import inside forward() incurs lookup overhead on every call. If there's no circular import issue, move it to the top with other imports.

♻️ Proposed fix

At the top of the file (around line 24):

from pymllm.executor.model_runner import LogitsProcessorOutput

Then simplify lines 349-351:

-        from pymllm.executor.model_runner import LogitsProcessorOutput
-
         return LogitsProcessorOutput(next_token_logits=logits)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/qwen3.py` around lines 349 - 351, the local import of
LogitsProcessorOutput inside the forward() method causes repeated lookup
overhead; move "from pymllm.executor.model_runner import LogitsProcessorOutput"
to the module-level imports at the top of pymllm/models/qwen3.py and remove the
in-method import in forward(), keeping the return as return
LogitsProcessorOutput(next_token_logits=logits); if a circular import prevents
this, add a short comment explaining the circular dependency and leave the local
import as-is.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2458cdf9-12ee-43fb-905a-aa8ce2e7126c

📥 Commits

Reviewing files that changed from the base of the PR and between 6e6b900 and e06c3a5.

📒 Files selected for processing (10)
  • pymllm/layers/__init__.py
  • pymllm/layers/linear.py
  • pymllm/layers/mlp.py
  • pymllm/models/qwen3.py
  • pymllm/models/qwen3_vl.py
  • pymllm/tests/test_linear_merged.py
  • pymllm/tests/test_qwen3_residual_carry.py
  • pymllm/tests/test_qwen3_vl_deepstack.py
  • pymllm/tests/test_qwen3_vl_weight_loading.py
  • pymllm/tests/test_qwen3_weight_loading.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • pymllm/tests/test_qwen3_vl_deepstack.py

Collaborator

@chenghuaWang chenghuaWang left a comment

LGTM

@chenghuaWang chenghuaWang merged commit 729ca4c into UbiquitousLearning:main Apr 30, 2026
2 checks passed